├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── a9a ├── a9a ├── a9a.t └── readme.txt └── src ├── IRLS_jax.py ├── IRLS_jittor.py ├── IRLS_megengine.py ├── IRLS_paddle.py ├── IRLS_paddle3.py ├── IRLS_pytorch.py ├── IRLS_tf.py ├── IRLS_tf_v2.py └── tf_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | a9a/* linguist-vendored 2 | *.py linguist-language=Python 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/log/ 2 | logs 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # IPython Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | 91 | # Rope project settings 92 | .ropeproject 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {2017-2019} {Gu Wang} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IRLS using DL frameworks 2 | * IRLS (iteratively reweighted least squares) for Logistic Regression, implemented using 3 | * tensorflow < 2.0 4 | * tensorflow 2.0 5 | * pytorch 6 | * paddlepaddle 7 | * megengine 8 | * jax 9 | * jittor 10 | 11 | * Note that IRLS is a second-order optimization method, which is equivalent to Newton's method. 12 | 13 | * We show that these DL frameworks can express general matrix-based algorithms and accelerate them on the GPU. 14 | 15 | * In this implementation, we use SVD to compute the pseudo-inverse of (possibly singular) matrices. 16 | -------------------------------------------------------------------------------- /a9a/readme.txt: -------------------------------------------------------------------------------- 1 | # of classes: 2 2 | # of data: 32,561 / 16,281 (testing) 3 | # of features: 123 / 123 (testing) 4 | 5 | 6 | a9a : training dataset 7 | a9a.t: testing dataset 8 | 9 | format: Label [feature id]:[feature] [feature id]:[feature] [feature id]:[feature] [feature id]:[feature] ...... 10 | 11 | -------------------------------------------------------------------------------- /src/IRLS_jax.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | from functools import partial 4 | import os.path as osp 5 | import argparse 6 | import random 7 | from loguru import logger 8 | import numpy as np 9 | import time 10 | import jax 11 | import jax.numpy as jnp 12 | from jax import lax 13 | from jax import grad, jit, vmap 14 | 15 | from sklearn.datasets import load_svmlight_file 16 | from scipy.sparse import csr_matrix 17 | 18 | cur_dir = osp.dirname(osp.abspath(__file__)) 19 | path_train = osp.join(cur_dir, "../a9a/a9a") 20 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 21 | MAX_ITER = 100 22 | np_dtype = np.float32 23 | 24 | 25 | # manual seed 26 | manualSeed = random.randint(1, 10000) # fix seed 27 | print("Random Seed: ", manualSeed) 28 | random.seed(manualSeed) 29 | np.random.seed(manualSeed) 30 | 31 | # load all data 32 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 33 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 34 | # X: scipy.sparse.csr.csr_matrix 35 | 36 | # X_train: (32561, 123), y_train: (32561,) 37 | # X_test: (16281, 123), y_test:(16281,) 38 | 39 | # stack a dimension of ones to X to simplify computation 40 | N_train = X_train.shape[0] 41 | N_test = X_test.shape[0] 42 | X_train = np.hstack((np.ones((N_train, 1), dtype=np_dtype), X_train.toarray())) 43 | X_test = np.hstack((np.ones((N_test, 1), dtype=np_dtype), X_test.toarray())) 44 | 45 | # X_train = csr_matrix(X_train, dtype=dtype) 46 | # X_test = csr_matrix(X_test, dtype=dtype) 47 | 48 | y_train = y_train.reshape((N_train, 1)) 49 | y_test = y_test.reshape((N_test, 1)) 50 | 51 | # label: -1, +1 ==> 0, 1 52 | y_train = np.float32(np.where(y_train == -1, 0, 1)) 53 | y_test = np.float32(np.where(y_test == -1, 0, 1)) 54 | 55 | 56 | # NB: here X's shape is (N,d), which differs from the derivation 57 | 58 | 59 | @jit 60 | def neg_log_likelihood(w, X, y, L2_param): 61 | """ 62 | w: dx1 63 | X: Nxd 64 | y: Nx1 65 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 66 | """ 67 | Xw = jnp.dot(X, w) 68 | res = jnp.dot(Xw.transpose(), y) - jnp.log(1 + jnp.exp(Xw)).sum() 69 | 
# if L2_param is not None and L2_param > 0: 70 | # res += -0.5 * L2_param * (jnp.dot(w.transpose(), w)) 71 | res += lax.cond( 72 | L2_param is not None and L2_param > 0, 73 | lambda a, b: -0.5 * a * jnp.dot(b.transpose(), b), 74 | lambda a, b: 0.0 * a * jnp.dot(b.transpose(), b), 75 | L2_param, 76 | w, 77 | ) 78 | return -res 79 | 80 | 81 | @jit 82 | def prob(X, w): 83 | """ 84 | X: Nxd 85 | w: dx1 86 | --- 87 | prob: N x num_classes(2)""" 88 | Xw = jnp.dot(X, w) 89 | y = jnp.array([[0.0, 1.0]]) # 1x2 90 | return jnp.exp(Xw * y) / (1 + jnp.exp(Xw)) # Nx2 91 | 92 | 93 | @jit 94 | def compute_acc(X, y, w): 95 | p = prob(X, w) 96 | y_pred = jnp.argmax(p, 1) 97 | return (y.flatten() == y_pred).astype("float32").mean() 98 | 99 | 100 | @jit 101 | def pinv_naive(A): 102 | # dtype = A.dtype 103 | U, S, Vh = jnp.linalg.svd(A, full_matrices=True) 104 | threshold = jnp.max(S) * 1e-5 105 | S_pinv = jnp.where(S > threshold, 1 / S, jnp.zeros_like(S)) 106 | # S_mask = S[S > threshold] 107 | # S_pinv = jnp.concatenate([1.0 / S_mask, jnp.full([S.size - S_mask.size], 0.0, dtype=dtype)], 0) 108 | # A_pinv = V @ S_pinv.diag() @ U.transpose() 109 | A_pinv = jnp.dot(jnp.dot(Vh.transpose(), jnp.diag(S_pinv)), U.transpose()) 110 | return A_pinv 111 | 112 | 113 | # @partial(jit, static_argnums=(3,)) 114 | @jit 115 | def update_weight(w_old, X, y, L2_param, identity): 116 | """ 117 | w_new = w_old - w_update 118 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 119 | lambda is L2_param 120 | 121 | w_old: dx1 122 | X: Nxd 123 | y: Nx1 124 | --- 125 | w_new: dx1 126 | """ 127 | mu = jax.nn.sigmoid(jnp.dot(X, w_old)) # Nx1 128 | 129 | R_flat = mu * (1 - mu) # element-wise, Nx1 130 | 131 | XRX = jnp.dot(X.transpose(), (R_flat * X)) # dxd 132 | # if L2_param > 0: 133 | XRX += L2_param * identity 134 | 135 | # XRX += lax.cond(L2_param>0, lambda a, b: a * jnp.identity((b, b)), lambda a,b: jnp.zeros((b, b)), L2_param, XRX) 136 | # np.save('XRX_pytorch.npy', XRX.cpu().numpy()) 137 | 138 | # Method 1: Calculate pseudo inverse via SVD 139 | # For singular matrices, we invert the singular 140 | # values above certain threshold (computed with the max singular value) 141 | # this is slightly better than torch.pinverse when L2_param=0 142 | XRX_pinv = pinv_naive(XRX) 143 | 144 | # method 2 145 | # XRX_pinv = jnp.linalg.pinv(XRX) 146 | 147 | # w = w - (X^T R X)^(-1) X^T (mu-y) 148 | val = jnp.dot(X.transpose(), (mu - y)) 149 | # if L2_param > 0: 150 | # val += L2_param * w_old 151 | val += lax.cond(L2_param > 0, lambda a, b: a * b, lambda a, b: jnp.zeros_like(b), L2_param, w_old) 152 | 153 | w_update = jnp.dot(XRX_pinv, val) 154 | w_new = w_old - w_update 155 | return w_new 156 | 157 | 158 | @logger.catch 159 | # @partial(jit, static_argnums=(5,)) 160 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER, identity=None): 161 | """train Logistic Regression via IRLS algorithm 162 | X: Nxd 163 | y: Nx1 164 | --- 165 | 166 | """ 167 | N, d = X_train.shape 168 | X_train = jax.device_put(jnp.array(X_train)) 169 | X_test = jax.device_put(jnp.array(X_test)) 170 | y_train = jax.device_put(jnp.array(y_train)) 171 | y_test = jax.device_put(jnp.array(y_test)) 172 | 173 | w = jax.device_put(jnp.full((d, 1), 0.01, dtype="float32")) 174 | 175 | print("start training...") 176 | tic = time.time() 177 | # print("Device: {}".format(device)) 178 | print("L2 param(lambda): {}".format(L2_param)) 179 | i = 0 180 | # iteration 181 | while i <= max_iter: 182 | print("iter: {}".format(i)) 183 | 184 | neg_L = 
neg_log_likelihood(w, X_train, y_train, L2_param) 185 | print("\t neg log likelihood: {}".format(neg_L.sum())) 186 | 187 | train_acc = compute_acc(X_train, y_train, w) 188 | test_acc = compute_acc(X_test, y_test, w) 189 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 190 | 191 | L2_norm_w = jnp.linalg.norm(w) 192 | print("\t L2 norm of w: {}".format(L2_norm_w.item())) 193 | 194 | if i > 0: 195 | diff_w = jnp.linalg.norm(w - w_old_data) 196 | print("\t diff of w_old and w: {}".format(diff_w.item())) 197 | if diff_w < 1e-2: 198 | break 199 | 200 | w_old_data = jnp.array(w, copy=True) 201 | w = update_weight(w, X_train, y_train, L2_param, identity) 202 | i += 1 203 | print(f"training done, using {time.time() - tic}s.") 204 | # still much slower than pytorch 205 | 206 | 207 | if __name__ == "__main__": 208 | lambda_ = 20 # 0 209 | d = X_train.shape[1] 210 | identity = jax.device_put(jnp.identity(d, dtype="float32")) 211 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=MAX_ITER, identity=identity) 212 | 213 | # from sklearn.linear_model import LogisticRegression 214 | # classifier = LogisticRegression() 215 | # classifier.fit(X_train, y_train.reshape(N_train,)) 216 | # y_pred_train = classifier.predict(X_train) 217 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 218 | # print('train_acc: {}'.format(train_acc)) 219 | # y_pred_test = classifier.predict(X_test) 220 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 221 | # print('test acc: {}'.format(test_acc)) 222 | -------------------------------------------------------------------------------- /src/IRLS_jittor.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | 4 | import os.path as osp 5 | import argparse 6 | import random 7 | from loguru import logger 8 | import numpy as np 9 | import time 10 | 11 | # from numpy import linalg 12 | import jittor as jt 13 | 14 | from sklearn.datasets import load_svmlight_file 15 | from scipy.sparse import csr_matrix 16 | 17 | # from scipy.sparse import linalg 18 | import matplotlib 19 | 20 | matplotlib.use("Agg") 21 | import matplotlib.pyplot as plt 22 | 23 | jt.flags.use_cuda = jt.has_cuda 24 | 25 | cur_dir = osp.dirname(osp.abspath(__file__)) 26 | path_train = osp.join(cur_dir, "../a9a/a9a") 27 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 28 | MAX_ITER = 100 29 | np_dtype = np.float32 30 | 31 | # manual seed 32 | manualSeed = random.randint(1, 10000) # fix seed 33 | print("Random Seed: ", manualSeed) 34 | random.seed(manualSeed) 35 | np.random.seed(manualSeed) 36 | 37 | # load all data 38 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 39 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 40 | # X: scipy.sparse.csr.csr_matrix 41 | 42 | # X_train: (32561, 123), y_train: (32561,) 43 | # X_test: (16281, 123), y_test:(16281,) 44 | 45 | # stack a dimension of ones to X to simplify computation 46 | N_train = X_train.shape[0] 47 | N_test = X_test.shape[0] 48 | X_train = np.hstack((np.ones((N_train, 1), dtype=np_dtype), X_train.toarray())) 49 | X_test = np.hstack((np.ones((N_test, 1), dtype=np_dtype), X_test.toarray())) 50 | 51 | # X_train = csr_matrix(X_train, dtype=dtype) 52 | # X_test = csr_matrix(X_test, dtype=dtype) 53 | 54 | y_train = y_train.reshape((N_train, 1)) 55 | y_test = y_test.reshape((N_test, 1)) 56 | 57 | # label: -1, +1 ==> 
0, 1 58 | y_train = np.float32(np.where(y_train == -1, 0, 1)) 59 | y_test = np.float32(np.where(y_test == -1, 0, 1)) 60 | 61 | 62 | # NB: here X's shape is (N,d), which differs to the derivation 63 | 64 | 65 | def neg_log_likelihood(w, X, y, L2_param=None): 66 | """ 67 | w: dx1 68 | X: Nxd 69 | y: Nx1 70 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 71 | """ 72 | Xw = X @ w 73 | res = Xw.t() @ y - jt.log(1 + Xw.exp()).sum() 74 | if L2_param is not None and L2_param > 0: 75 | res += -0.5 * L2_param * (w.t() @ w) 76 | return -res 77 | 78 | 79 | def prob(X, w): 80 | """ 81 | X: Nxd 82 | w: dx1 83 | --- 84 | prob: N x num_classes(2)""" 85 | Xw = X @ w 86 | y = jt.array([[0.0, 1.0]]) # 1x2 87 | return (Xw * y).exp() / (1 + Xw.exp()) # Nx2 88 | 89 | 90 | def compute_acc(X, y, w): 91 | p = prob(X, w) 92 | y_pred = jt.argmax(p, dim=1)[0] 93 | # print(y_pred.shape) 94 | # print(y.shape) 95 | return (y.flatten() == y_pred).float().mean() 96 | 97 | 98 | # def pinv_naive(A): 99 | # dtype = A.dtype 100 | # # U, S, V = torch.svd(A, some=False) 101 | # # does not support full_matrices=True yet 102 | # U, S, Vh = jt.linalg.svd(A, full_matrices=True) 103 | # threshold = jt.max(S) * 1e-5 104 | # # S_pinv = torch.where(S > threshold, 1/S, torch.zeros_like(S)) 105 | # S_mask = S[S > threshold] 106 | # S_pinv = jt.cat([1.0 / S_mask, jt.full([S.numel() - S_mask.numel()], 0.0, dtype=dtype)], 0) 107 | # # A_pinv = V @ S_pinv.diag() @ U.t() 108 | # A_pinv = Vh.t() @ S_pinv.diag() @ U.t() 109 | # return A_pinv 110 | 111 | 112 | def update_weight(w_old, X, y, L2_param=0): 113 | """ 114 | w_new = w_old - w_update 115 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 116 | lambda is L2_param 117 | 118 | w_old: dx1 119 | X: Nxd 120 | y: Nx1 121 | --- 122 | w_new: dx1 123 | """ 124 | mu = (X @ w_old).sigmoid() # Nx1 125 | 126 | R_flat = mu * (1 - mu) # element-wise, Nx1 127 | 128 | XRX = X.t() @ (R_flat.expand_as(X) * X) # dxd 129 | if L2_param > 0: 130 | # for i in range(XRX.shape[0]): 131 | # XRX[i, i] += L2_param 132 | XRX += L2_param * jt.init.eye(XRX.shape[0]) 133 | # jt.misc.diag(XRX).add_(L2_param) 134 | 135 | # np.save('XRX_pytorch.npy', XRX.cpu().numpy()) 136 | 137 | # Method 1: Calculate pseudo inverse via SVD 138 | # For singular matrices, we invert the singular 139 | # values above certain threshold (computed with the max singular value) 140 | # this is slightly better than torch.pinverse when L2_param=0 141 | # XRX_pinv = pinv_naive(XRX) 142 | XRX_pinv = jt.linalg.pinv(XRX) 143 | # method 2 144 | # XRX_pinv = torch.pinverse(XRX) 145 | 146 | # w = w - (X^T R X)^(-1) X^T (mu-y) 147 | val = X.t() @ (mu - y) 148 | if L2_param > 0: 149 | val += L2_param * w_old 150 | 151 | w_update = XRX_pinv @ val 152 | w_new = w_old - w_update 153 | return w_new 154 | 155 | 156 | @logger.catch 157 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 158 | """train Logistic Regression via IRLS algorithm 159 | X: Nxd 160 | y: Nx1 161 | --- 162 | 163 | """ 164 | N, d = X_train.shape 165 | X_train = jt.array(X_train) 166 | X_test = jt.array(X_test) 167 | y_train = jt.array(y_train) 168 | y_test = jt.array(y_test) 169 | 170 | w = jt.full((d, 1), 0.01) 171 | jt.sync_all(True) 172 | print("start training...") 173 | tic = time.time() 174 | print("L2 param(lambda): {}".format(L2_param)) 175 | i = 0 176 | # iteration 177 | while i <= max_iter: 178 | print("iter: {}".format(i)) 179 | 180 | neg_L = neg_log_likelihood(w, X_train, y_train, L2_param) 181 | print("\t neg log 
likelihood: {}".format(neg_L.sum())) 182 | 183 | train_acc = compute_acc(X_train, y_train, w) 184 | test_acc = compute_acc(X_test, y_test, w) 185 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 186 | 187 | L2_norm_w = jt.norm(w, dim=(0, 1)) 188 | print("\t L2 norm of w: {}".format(L2_norm_w.item())) 189 | 190 | if i > 0: 191 | diff_w = jt.norm(w - w_old_data, dim=(0, 1)) 192 | print("\t diff of w_old and w: {}".format(diff_w.item())) 193 | if diff_w < 1e-2: 194 | break 195 | 196 | w_old_data = w.clone() 197 | w = update_weight(w, X_train, y_train, L2_param) 198 | i += 1 199 | jt.sync_all(True) 200 | print(f"training done, using {time.time() - tic}s.") 201 | 202 | 203 | if __name__ == "__main__": 204 | lambda_ = 20 # 0 205 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 206 | 207 | # from sklearn.linear_model import LogisticRegression 208 | # classifier = LogisticRegression() 209 | # classifier.fit(X_train, y_train.reshape(N_train,)) 210 | # y_pred_train = classifier.predict(X_train) 211 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 212 | # print('train_acc: {}'.format(train_acc)) 213 | # y_pred_test = classifier.predict(X_test) 214 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 215 | # print('test acc: {}'.format(test_acc)) 216 | -------------------------------------------------------------------------------- /src/IRLS_megengine.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | from loguru import logger 7 | import time 8 | 9 | # from numpy import linalg 10 | import megengine as mge 11 | import megengine.functional as F 12 | 13 | from sklearn.datasets import load_svmlight_file 14 | 15 | # mge.core.set_option("async_level", 0) 16 | 17 | cur_dir = osp.abspath(osp.dirname(__file__)) 18 | path_train = osp.join(cur_dir, "../a9a/a9a") 19 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 20 | MAX_ITER = 100 21 | np_dtype = np.float32 22 | 23 | if mge.is_cuda_available(): 24 | device = "gpu0:0" 25 | else: 26 | device = "cpu0:0" 27 | 28 | # manual seed 29 | manualSeed = random.randint(1, 10000) # fix seed 30 | print("Random Seed: ", manualSeed) 31 | random.seed(manualSeed) 32 | np.random.seed(manualSeed) 33 | 34 | # load all data 35 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 36 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 37 | # X: scipy.sparse.csr.csr_matrix 38 | 39 | # X_train: (32561, 123), y_train: (32561,) 40 | # X_test: (16281, 123), y_test:(16281,) 41 | 42 | # stack a dimension of ones to X to simplify computation 43 | N_train = X_train.shape[0] 44 | N_test = X_test.shape[0] 45 | X_train = np.hstack((np.ones((N_train, 1), dtype=np_dtype), X_train.toarray())) 46 | X_test = np.hstack((np.ones((N_test, 1), dtype=np_dtype), X_test.toarray())) 47 | 48 | # X_train = csr_matrix(X_train, dtype=dtype) 49 | # X_test = csr_matrix(X_test, dtype=dtype) 50 | 51 | y_train = y_train.reshape((N_train, 1)) 52 | y_test = y_test.reshape((N_test, 1)) 53 | 54 | # label: -1, +1 ==> 0, 1 55 | y_train = np.float32(np.where(y_train == -1, 0, 1)) 56 | y_test = np.float32(np.where(y_test == -1, 0, 1)) 57 | 58 | 59 | # NB: here X's shape is (N,d), which differs to the derivation 60 | 61 | 62 | def neg_log_likelihood(w, X, y, L2_param=None): 63 | """ 64 | w: 
dx1 65 | X: Nxd 66 | y: Nx1 67 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 68 | """ 69 | Xw = X @ w 70 | res = Xw.transpose() @ y - F.log(1 + F.exp(Xw)).sum() 71 | if L2_param is not None and L2_param > 0: 72 | res += -0.5 * L2_param * (w.transpose() @ w) 73 | return -res 74 | 75 | 76 | def prob(X, w): 77 | """ 78 | X: Nxd 79 | w: dx1 80 | --- 81 | prob: N x num_classes(2)""" 82 | Xw = X @ w 83 | y = mge.tensor([[0.0, 1.0]]) # 1x2 84 | return F.exp(Xw * y) / (1 + F.exp(Xw)) # Nx2 85 | 86 | 87 | def compute_acc(X, y, w): 88 | p = prob(X, w) 89 | y_pred = F.argmax(p, 1).to(y.device).astype(y.dtype) 90 | return (y.flatten() == y_pred).astype("float32").mean() 91 | 92 | 93 | def get_mat_diag(X): 94 | n = X.shape[0] 95 | mask = F.eye(n).astype("bool") 96 | return F.cond_take(mask, X)[0] 97 | 98 | 99 | def vec_to_diag_mat(v): 100 | n = np.prod(v.shape) 101 | D = v.flatten() * F.eye(n, dtype=v.dtype) 102 | return D 103 | 104 | 105 | def pinv_naive(A): 106 | device = A.device 107 | dtype = A.dtype 108 | U, S, Vh = F.svd(A, full_matrices=True) 109 | threshold = F.max(S) * 1e-5 110 | S_mask = S[S > threshold] 111 | S_pinv = F.concat( 112 | [1.0 / S_mask, F.full([np.prod(S.shape) - np.prod(S_mask.shape)], 0.0, device=device, dtype=dtype)], 0 113 | ) 114 | A_pinv = Vh.transpose() @ vec_to_diag_mat(S_pinv) @ U.transpose() 115 | return A_pinv 116 | 117 | 118 | def update_weight(w_old, X, y, L2_param=0): 119 | """ 120 | w_new = w_old - w_update 121 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 122 | lambda is L2_param 123 | 124 | w_old: dx1 125 | X: Nxd 126 | y: Nx1 127 | --- 128 | w_new: dx1 129 | """ 130 | mu = F.sigmoid(X @ w_old) # Nx1 131 | 132 | R_flat = mu * (1 - mu) # element-wise, Nx1 133 | 134 | XRX = X.transpose() @ (F.broadcast_to(R_flat, X.shape) * X) # dxd 135 | if L2_param > 0: 136 | XRX += L2_param * F.eye(XRX.shape[0]) 137 | 138 | # np.save('XRX_pytorch.npy', XRX.cpu().numpy()) 139 | 140 | # Method 1: Calculate pseudo inverse via SVD 141 | # For singular matrices, we invert the singular 142 | # values above certain threshold (computed with the max singular value) 143 | # this is slightly better than torch.pinverse when L2_param=0 144 | XRX_pinv = pinv_naive(XRX) 145 | 146 | # method 2 147 | # XRX_pinv = torch.pinverse(XRX) 148 | 149 | # w = w - (X^T R X)^(-1) X^T (mu-y) 150 | val = X.transpose() @ (mu - y) 151 | if L2_param > 0: 152 | val += L2_param * w_old 153 | 154 | w_update = XRX_pinv @ val 155 | w_new = w_old - w_update 156 | return w_new 157 | 158 | 159 | @logger.catch 160 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 161 | """train Logistic Regression via IRLS algorithm 162 | X: Nxd 163 | y: Nx1 164 | --- 165 | 166 | """ 167 | N, d = X_train.shape 168 | X_train = mge.tensor(X_train) 169 | X_test = mge.tensor(X_test) 170 | y_train = mge.tensor(y_train) 171 | y_test = mge.tensor(y_test) 172 | 173 | w = F.full((d, 1), 0.01) 174 | 175 | print("start training...") 176 | tic = time.time() 177 | print("Device: {}".format(device)) 178 | print("L2 param(lambda): {}".format(L2_param)) 179 | i = 0 180 | # iteration 181 | while i <= max_iter: 182 | print("iter: {}".format(i)) 183 | 184 | neg_L = neg_log_likelihood(w, X_train, y_train, L2_param) 185 | print("\t neg log likelihood: {}".format(neg_L.sum())) 186 | 187 | train_acc = compute_acc(X_train, y_train, w) 188 | test_acc = compute_acc(X_test, y_test, w) 189 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 190 | 191 | L2_norm_w = 
F.norm(w.flatten()) 192 | print("\t L2 norm of w: {}".format(L2_norm_w.item())) 193 | 194 | if i > 0: 195 | diff_w = F.norm((w - w_old_data).flatten()) 196 | print("\t diff of w_old and w: {}".format(diff_w.item())) 197 | if diff_w < 1e-2: 198 | break 199 | 200 | w_old_data = F.copy(w) 201 | w = update_weight(w, X_train, y_train, L2_param) 202 | i += 1 203 | print(f"training done, using {time.time() - tic}s.") 204 | # slower than pytorch 205 | 206 | 207 | if __name__ == "__main__": 208 | lambda_ = 20 # 0 209 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 210 | 211 | # from sklearn.linear_model import LogisticRegression 212 | # classifier = LogisticRegression() 213 | # classifier.fit(X_train, y_train.reshape(N_train,)) 214 | # y_pred_train = classifier.predict(X_train) 215 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 216 | # print('train_acc: {}'.format(train_acc)) 217 | # y_pred_test = classifier.predict(X_test) 218 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 219 | # print('test acc: {}'.format(test_acc)) 220 | -------------------------------------------------------------------------------- /src/IRLS_paddle.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | 4 | import os.path as osp 5 | import argparse 6 | import random 7 | from loguru import logger 8 | import numpy as np 9 | import time 10 | 11 | # from numpy import linalg 12 | import paddle 13 | 14 | 15 | from sklearn.datasets import load_svmlight_file 16 | from scipy.sparse import csr_matrix 17 | 18 | # from scipy.sparse import linalg 19 | import matplotlib 20 | 21 | matplotlib.use("Agg") 22 | import matplotlib.pyplot as plt 23 | 24 | cur_dir = osp.dirname(osp.abspath(__file__)) 25 | path_train = osp.join(cur_dir, "../a9a/a9a") 26 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 27 | MAX_ITER = 100 28 | np_dtype = np.float32 29 | 30 | paddle.device.set_device("gpu") 31 | 32 | # manual seed 33 | manualSeed = random.randint(1, 10000) # fix seed 34 | print("Random Seed: ", manualSeed) 35 | random.seed(manualSeed) 36 | np.random.seed(manualSeed) 37 | 38 | # load all data 39 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 40 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 41 | # X: scipy.sparse.csr.csr_matrix 42 | 43 | # X_train: (32561, 123), y_train: (32561,) 44 | # X_test: (16281, 123), y_test:(16281,) 45 | 46 | # stack a dimension of ones to X to simplify computation 47 | N_train = X_train.shape[0] 48 | N_test = X_test.shape[0] 49 | X_train = np.hstack((np.ones((N_train, 1), dtype=np_dtype), X_train.toarray())) 50 | X_test = np.hstack((np.ones((N_test, 1), dtype=np_dtype), X_test.toarray())) 51 | 52 | # X_train = csr_matrix(X_train, dtype=dtype) 53 | # X_test = csr_matrix(X_test, dtype=dtype) 54 | 55 | y_train = y_train.reshape((N_train, 1)) 56 | y_test = y_test.reshape((N_test, 1)) 57 | 58 | # label: -1, +1 ==> 0, 1 59 | y_train = np.float32(np.where(y_train == -1, 0, 1)) 60 | y_test = np.float32(np.where(y_test == -1, 0, 1)) 61 | 62 | 63 | # NB: here X's shape is (N,d), which differs to the derivation 64 | 65 | 66 | def neg_log_likelihood(w, X, y, L2_param=None): 67 | """ 68 | w: dx1 69 | X: Nxd 70 | y: Nx1 71 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 72 | """ 73 | Xw = X.mm(w) 74 | res = paddle.mm(Xw.t(), y) - paddle.log(1 + Xw.exp()).sum() 75 | 
if L2_param != None and L2_param > 0: 76 | res += -0.5 * L2_param * paddle.mm(w.t(), w) 77 | return -res 78 | 79 | 80 | def prob(X, w): 81 | """ 82 | X: Nxd 83 | w: dx1 84 | --- 85 | prob: N x num_classes(2)""" 86 | Xw = X.mm(w) 87 | y = paddle.to_tensor([[0.0, 1.0]]) # 1x2 88 | return (Xw * y).exp() / (1 + Xw.exp()) # Nx2 89 | 90 | 91 | def compute_acc(X, y, w): 92 | p = prob(X, w) 93 | y_pred = paddle.argmax(p, 1) 94 | return (y.flatten() == y_pred).to("float32").mean() 95 | 96 | 97 | def pinv_naive(A): 98 | dtype = A.dtype 99 | U, S, Vh = paddle.linalg.svd(A, full_matrices=True) 100 | threshold = paddle.max(S) * 1e-5 101 | S_mask = S[S > threshold] 102 | S_pinv = paddle.concat([1.0 / S_mask, paddle.full([S.numel() - S_mask.numel()], 0.0, dtype=dtype)], 0) 103 | # A_pinv = V @ S_pinv.diag() @ U.t() 104 | A_pinv = Vh.t() @ S_pinv.diag() @ U.t() 105 | return A_pinv 106 | 107 | 108 | def update_weight(w_old, X, y, L2_param=0): 109 | """ 110 | w_new = w_old - w_update 111 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 112 | lambda is L2_param 113 | 114 | w_old: dx1 115 | X: Nxd 116 | y: Nx1 117 | --- 118 | w_new: dx1 119 | """ 120 | mu = X.mm(w_old).sigmoid() # Nx1 121 | 122 | R_flat = mu * (1 - mu) # element-wise, Nx1 123 | 124 | XRX = paddle.mm(X.t(), R_flat.expand_as(X) * X) # dxd 125 | if L2_param > 0: 126 | # XRX.diagonal().add_(L2_param) 127 | # for i in range(XRX.shape[0]): 128 | # XRX[i, i] += L2_param 129 | XRX += L2_param * paddle.eye(XRX.shape[0]) 130 | 131 | # np.save('XRX_paddle.npy', XRX.cpu().numpy()) 132 | 133 | # Method 1: Calculate pseudo inverse via SVD 134 | # For singular matrices, we invert the singular 135 | # values above certain threshold (computed with the max singular value) 136 | # this is slightly better than paddle.linalg.pinv when L2_param=0 137 | XRX_pinv = pinv_naive(XRX) 138 | 139 | # method 2 140 | # XRX_pinv = paddle.linalg.pinv(XRX) 141 | 142 | # w = w - (X^T R X)^(-1) X^T (mu-y) 143 | val = paddle.mm(X.t(), mu - y) 144 | if L2_param > 0: 145 | val += L2_param * w_old 146 | 147 | w_update = paddle.mm(XRX_pinv, val) 148 | w_new = w_old - w_update 149 | return w_new 150 | 151 | 152 | @logger.catch 153 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 154 | """train Logistic Regression via IRLS algorithm 155 | X: Nxd 156 | y: Nx1 157 | --- 158 | 159 | """ 160 | N, d = X_train.shape 161 | X_train = paddle.to_tensor(X_train) 162 | X_test = paddle.to_tensor(X_test) 163 | y_train = paddle.to_tensor(y_train) 164 | y_test = paddle.to_tensor(y_test) 165 | 166 | w = paddle.full((d, 1), 0.01) 167 | 168 | print("start training...") 169 | tic = time.time() 170 | print("Device: {}".format(paddle.device.get_device())) 171 | print("L2 param(lambda): {}".format(L2_param)) 172 | i = 0 173 | # iteration 174 | while i <= max_iter: 175 | print("iter: {}".format(i)) 176 | 177 | neg_L = neg_log_likelihood(w, X_train, y_train, L2_param) 178 | print("\t neg log likelihood: {}".format(neg_L.sum())) 179 | 180 | train_acc = compute_acc(X_train, y_train, w) 181 | test_acc = compute_acc(X_test, y_test, w) 182 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 183 | 184 | L2_norm_w = paddle.norm(w) 185 | print("\t L2 norm of w: {}".format(L2_norm_w.item())) 186 | 187 | if i > 0: 188 | diff_w = paddle.norm(w - w_old_data) 189 | print("\t diff of w_old and w: {}".format(diff_w.item())) 190 | if diff_w < 1e-2: 191 | break 192 | 193 | w_old_data = w.clone() 194 | w = update_weight(w, X_train, y_train, L2_param) 195 | 
i += 1 196 | print(f"training done, using {time.time() - tic}s.") 197 | 198 | 199 | if __name__ == "__main__": 200 | lambda_ = 20 # 0 201 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 202 | 203 | # from sklearn.linear_model import LogisticRegression 204 | # classifier = LogisticRegression() 205 | # classifier.fit(X_train, y_train.reshape(N_train,)) 206 | # y_pred_train = classifier.predict(X_train) 207 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 208 | # print('train_acc: {}'.format(train_acc)) 209 | # y_pred_test = classifier.predict(X_test) 210 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 211 | # print('test acc: {}'.format(test_acc)) 212 | -------------------------------------------------------------------------------- /src/IRLS_paddle3.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | 4 | import os.path as osp 5 | import argparse 6 | import random 7 | from loguru import logger 8 | import numpy as np 9 | import time 10 | 11 | # from numpy import linalg 12 | import paddle 13 | # python -m pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/ 14 | 15 | 16 | from sklearn.datasets import load_svmlight_file 17 | from scipy.sparse import csr_matrix 18 | 19 | # from scipy.sparse import linalg 20 | import matplotlib 21 | 22 | matplotlib.use("Agg") 23 | import matplotlib.pyplot as plt 24 | 25 | cur_dir = osp.dirname(osp.abspath(__file__)) 26 | path_train = osp.join(cur_dir, "../a9a/a9a") 27 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 28 | MAX_ITER = 100 29 | np_dtype = np.float32 30 | 31 | paddle.device.set_device("gpu") 32 | 33 | # manual seed 34 | manualSeed = random.randint(1, 10000) # fix seed 35 | print("Random Seed: ", manualSeed) 36 | random.seed(manualSeed) 37 | np.random.seed(manualSeed) 38 | 39 | # load all data 40 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 41 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 42 | # X: scipy.sparse.csr.csr_matrix 43 | 44 | # X_train: (32561, 123), y_train: (32561,) 45 | # X_test: (16281, 123), y_test:(16281,) 46 | 47 | # stack a dimension of ones to X to simplify computation 48 | N_train = X_train.shape[0] 49 | N_test = X_test.shape[0] 50 | X_train = np.hstack((np.ones((N_train, 1), dtype=np_dtype), X_train.toarray())) 51 | X_test = np.hstack((np.ones((N_test, 1), dtype=np_dtype), X_test.toarray())) 52 | 53 | # X_train = csr_matrix(X_train, dtype=dtype) 54 | # X_test = csr_matrix(X_test, dtype=dtype) 55 | 56 | y_train = y_train.reshape((N_train, 1)) 57 | y_test = y_test.reshape((N_test, 1)) 58 | 59 | # label: -1, +1 ==> 0, 1 60 | y_train = np.float32(np.where(y_train == -1, 0, 1)) 61 | y_test = np.float32(np.where(y_test == -1, 0, 1)) 62 | 63 | 64 | # NB: here X's shape is (N,d), which differs to the derivation 65 | 66 | 67 | def neg_log_likelihood(w, X, y, L2_param=None): 68 | """ 69 | w: dx1 70 | X: Nxd 71 | y: Nx1 72 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 73 | """ 74 | Xw = X.mm(w) 75 | res = paddle.mm(Xw.t(), y) - paddle.log(1 + Xw.exp()).sum() 76 | if L2_param != None and L2_param > 0: 77 | res += -0.5 * L2_param * paddle.mm(w.t(), w) 78 | return -res 79 | 80 | 81 | def prob(X, w): 82 | """ 83 | X: Nxd 84 | w: dx1 85 | --- 86 | prob: N x num_classes(2)""" 87 | Xw = X.mm(w) 88 | y = paddle.to_tensor([[0.0, 
1.0]]) # 1x2 89 | return (Xw * y).exp() / (1 + Xw.exp()) # Nx2 90 | 91 | 92 | def compute_acc(X, y, w): 93 | p = prob(X, w) 94 | y_pred = paddle.argmax(p, 1) 95 | return (y.flatten() == y_pred.to("float32")).to("float32").mean() 96 | 97 | 98 | def pinv_naive(A): 99 | dtype = A.dtype 100 | U, S, Vh = paddle.linalg.svd(A, full_matrices=True) 101 | threshold = paddle.max(S) * 1e-5 102 | S_mask = S[S > threshold] 103 | S_pinv = paddle.concat([1.0 / S_mask, paddle.full([S.numel() - S_mask.numel()], 0.0, dtype=dtype)], 0) 104 | # A_pinv = V @ S_pinv.diag() @ U.t() 105 | A_pinv = Vh.t() @ S_pinv.diag() @ U.t() 106 | return A_pinv 107 | 108 | 109 | def update_weight(w_old, X, y, L2_param=0): 110 | """ 111 | w_new = w_old - w_update 112 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 113 | lambda is L2_param 114 | 115 | w_old: dx1 116 | X: Nxd 117 | y: Nx1 118 | --- 119 | w_new: dx1 120 | """ 121 | mu = X.mm(w_old).sigmoid() # Nx1 122 | 123 | R_flat = mu * (1 - mu) # element-wise, Nx1 124 | 125 | XRX = paddle.mm(X.t(), R_flat.expand_as(X) * X) # dxd 126 | if L2_param > 0: 127 | # XRX.diagonal().add_(L2_param) 128 | # for i in range(XRX.shape[0]): 129 | # XRX[i, i] += L2_param 130 | XRX += L2_param * paddle.eye(XRX.shape[0]) 131 | 132 | # np.save('XRX_paddle.npy', XRX.cpu().numpy()) 133 | 134 | # Method 1: Calculate pseudo inverse via SVD 135 | # For singular matrices, we invert the singular 136 | # values above certain threshold (computed with the max singular value) 137 | # this is slightly better than paddle.linalg.pinv when L2_param=0 138 | XRX_pinv = pinv_naive(XRX) 139 | 140 | # method 2 141 | # XRX_pinv = paddle.linalg.pinv(XRX) 142 | 143 | # w = w - (X^T R X)^(-1) X^T (mu-y) 144 | val = paddle.mm(X.t(), mu - y) 145 | if L2_param > 0: 146 | val += L2_param * w_old 147 | 148 | w_update = paddle.mm(XRX_pinv, val) 149 | w_new = w_old - w_update 150 | return w_new 151 | 152 | 153 | @logger.catch 154 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 155 | """train Logistic Regression via IRLS algorithm 156 | X: Nxd 157 | y: Nx1 158 | --- 159 | 160 | """ 161 | N, d = X_train.shape 162 | X_train = paddle.to_tensor(X_train) 163 | X_test = paddle.to_tensor(X_test) 164 | y_train = paddle.to_tensor(y_train) 165 | y_test = paddle.to_tensor(y_test) 166 | 167 | w = paddle.full((d, 1), 0.01) 168 | 169 | print("start training...") 170 | tic = time.time() 171 | print("Device: {}".format(paddle.device.get_device())) 172 | print("L2 param(lambda): {}".format(L2_param)) 173 | i = 0 174 | # iteration 175 | while i <= max_iter: 176 | print("iter: {}".format(i)) 177 | 178 | neg_L = neg_log_likelihood(w, X_train, y_train, L2_param) 179 | print("\t neg log likelihood: {}".format(neg_L.sum())) 180 | 181 | train_acc = compute_acc(X_train, y_train, w) 182 | test_acc = compute_acc(X_test, y_test, w) 183 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 184 | 185 | L2_norm_w = paddle.norm(w) 186 | print("\t L2 norm of w: {}".format(L2_norm_w.item())) 187 | 188 | if i > 0: 189 | diff_w = paddle.norm(w - w_old_data) 190 | print("\t diff of w_old and w: {}".format(diff_w.item())) 191 | if diff_w < 1e-2: 192 | break 193 | 194 | w_old_data = w.clone() 195 | w = update_weight(w, X_train, y_train, L2_param) 196 | i += 1 197 | print(f"training done, using {time.time() - tic}s.") 198 | 199 | 200 | if __name__ == "__main__": 201 | lambda_ = 20 # 0 202 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 203 | 204 | # from 
sklearn.linear_model import LogisticRegression 205 | # classifier = LogisticRegression() 206 | # classifier.fit(X_train, y_train.reshape(N_train,)) 207 | # y_pred_train = classifier.predict(X_train) 208 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 209 | # print('train_acc: {}'.format(train_acc)) 210 | # y_pred_test = classifier.predict(X_test) 211 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 212 | # print('test acc: {}'.format(test_acc)) 213 | 214 | # L2 param(lambda): 20 215 | # iter: 0 216 | # neg log likelihood: 23909.580078125 217 | # train acc: 0.2408095747232437, test acc: 0.23622629046440125 218 | # L2 norm of w: 0.11135528981685638 219 | # iter: 1 220 | # neg log likelihood: 12469.4580078125 221 | # train acc: 0.8444765210151672, test acc: 0.8458327054977417 222 | # L2 norm of w: 2.1622190475463867 223 | # diff of w_old and w: 2.169222354888916 224 | # iter: 2 225 | # neg log likelihood: 11109.1318359375 226 | # train acc: 0.8469641208648682, test acc: 0.8494566082954407 227 | # L2 norm of w: 3.1962289810180664 228 | # diff of w_old and w: 1.1507704257965088 229 | # iter: 3 230 | # neg log likelihood: 10792.71875 231 | # train acc: 0.8475170135498047, test acc: 0.8501322269439697 232 | # L2 norm of w: 3.8568596839904785 233 | # diff of w_old and w: 0.854336678981781 234 | # iter: 4 235 | # neg log likelihood: 10746.015625 236 | # train acc: 0.8479468822479248, test acc: 0.8507464528083801 237 | # L2 norm of w: 4.167862892150879 238 | # diff of w_old and w: 0.49177077412605286 239 | # iter: 5 240 | # neg log likelihood: 10743.85546875 241 | # train acc: 0.8479161858558655, test acc: 0.8506850004196167 242 | # L2 norm of w: 4.240683078765869 243 | # diff of w_old and w: 0.14320610463619232 244 | # iter: 6 245 | # neg log likelihood: 10743.845703125 246 | # train acc: 0.8479161858558655, test acc: 0.8507464528083801 247 | # L2 norm of w: 4.244931221008301 248 | # diff of w_old and w: 0.010079155676066875 249 | # iter: 7 250 | # neg log likelihood: 10743.8466796875 251 | # train acc: 0.8479161858558655, test acc: 0.8507464528083801 252 | # L2 norm of w: 4.244947910308838 253 | # diff of w_old and w: 4.5704247895628214e-05 254 | # training done, using 0.15619826316833496s. 
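# Note: a minimal sketch (added for reference, using the same notation as neg_log_likelihood() and
# update_weight() above) of why this IRLS update is a Newton step on the regularized negative log-likelihood:
#   L(w)      = -(y^T X w - sum_i log(1 + exp(x_i^T w))) + (lambda/2) * ||w||_2^2
#   grad L(w) = X^T (mu - y) + lambda * w,    with mu = sigmoid(X w)
#   hess L(w) = X^T R X + lambda * I,         with R = diag(mu * (1 - mu))
# so w_new = w_old - hess^{-1} grad, which is exactly what update_weight() computes
# (with the inverse replaced by the SVD-based pseudo-inverse).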
255 | -------------------------------------------------------------------------------- /src/IRLS_pytorch.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | 4 | import os.path as osp 5 | import argparse 6 | import random 7 | from loguru import logger 8 | import numpy as np 9 | import time 10 | 11 | # from numpy import linalg 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | 15 | cudnn.benchmark = True 16 | 17 | from sklearn.datasets import load_svmlight_file 18 | from scipy.sparse import csr_matrix 19 | 20 | # from scipy.sparse import linalg 21 | import matplotlib 22 | 23 | matplotlib.use("Agg") 24 | import matplotlib.pyplot as plt 25 | 26 | cur_dir = osp.dirname(osp.abspath(__file__)) 27 | path_train = osp.join(cur_dir, "../a9a/a9a") 28 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 29 | MAX_ITER = 100 30 | np_dtype = np.float32 31 | 32 | device = torch.device("cpu") 33 | if torch.cuda.is_available(): 34 | device = torch.device("cuda") 35 | 36 | 37 | # manual seed 38 | manualSeed = random.randint(1, 10000) # fix seed 39 | print("Random Seed: ", manualSeed) 40 | random.seed(manualSeed) 41 | np.random.seed(manualSeed) 42 | 43 | # load all data 44 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 45 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 46 | # X: scipy.sparse.csr.csr_matrix 47 | 48 | # X_train: (32561, 123), y_train: (32561,) 49 | # X_test: (16281, 123), y_test:(16281,) 50 | 51 | # stack a dimension of ones to X to simplify computation 52 | N_train = X_train.shape[0] 53 | N_test = X_test.shape[0] 54 | X_train = np.hstack((np.ones((N_train, 1), dtype=np_dtype), X_train.toarray())) 55 | X_test = np.hstack((np.ones((N_test, 1), dtype=np_dtype), X_test.toarray())) 56 | 57 | # X_train = csr_matrix(X_train, dtype=dtype) 58 | # X_test = csr_matrix(X_test, dtype=dtype) 59 | 60 | y_train = y_train.reshape((N_train, 1)) 61 | y_test = y_test.reshape((N_test, 1)) 62 | 63 | # label: -1, +1 ==> 0, 1 64 | y_train = np.float32(np.where(y_train == -1, 0, 1)) 65 | y_test = np.float32(np.where(y_test == -1, 0, 1)) 66 | 67 | 68 | # NB: here X's shape is (N,d), which differs to the derivation 69 | 70 | 71 | def neg_log_likelihood(w, X, y, L2_param=None): 72 | """ 73 | w: dx1 74 | X: Nxd 75 | y: Nx1 76 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 77 | """ 78 | Xw = X.mm(w) 79 | res = torch.mm(Xw.t(), y) - torch.log(1 + Xw.exp()).sum() 80 | if L2_param != None and L2_param > 0: 81 | res += -0.5 * L2_param * torch.mm(w.t(), w) 82 | return -res 83 | 84 | 85 | def prob(X, w): 86 | """ 87 | X: Nxd 88 | w: dx1 89 | --- 90 | prob: N x num_classes(2)""" 91 | Xw = X.mm(w) 92 | y = torch.as_tensor([[0.0, 1.0]], device=device) # 1x2 93 | return (Xw * y).exp() / (1 + Xw.exp()) # Nx2 94 | 95 | 96 | def compute_acc(X, y, w): 97 | p = prob(X, w) 98 | y_pred = torch.argmax(p, 1).to(y) 99 | return (y.flatten() == y_pred).float().mean() 100 | 101 | 102 | def pinv_naive(A): 103 | device = A.device 104 | dtype = A.dtype 105 | # U, S, V = torch.svd(A, some=False) 106 | U, S, Vh = torch.linalg.svd(A, full_matrices=True) 107 | threshold = torch.max(S) * 1e-5 108 | # S_pinv = torch.where(S > threshold, 1/S, torch.zeros_like(S)) 109 | S_mask = S[S > threshold] 110 | S_pinv = torch.cat([1.0 / S_mask, torch.full([S.numel() - S_mask.numel()], 0.0, device=device, dtype=dtype)], 0) 111 | # A_pinv = V @ S_pinv.diag() @ 
U.t() 112 | A_pinv = Vh.t() @ S_pinv.diag() @ U.t() 113 | return A_pinv 114 | 115 | 116 | def update_weight(w_old, X, y, L2_param=0): 117 | """ 118 | w_new = w_old - w_update 119 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 120 | lambda is L2_param 121 | 122 | w_old: dx1 123 | X: Nxd 124 | y: Nx1 125 | --- 126 | w_new: dx1 127 | """ 128 | mu = X.mm(w_old).sigmoid() # Nx1 129 | 130 | R_flat = mu * (1 - mu) # element-wise, Nx1 131 | 132 | XRX = torch.mm(X.t(), R_flat.expand_as(X) * X) # dxd 133 | if L2_param > 0: 134 | # XRX.diagonal().add_(L2_param) 135 | XRX += L2_param * torch.eye(XRX.shape[0], device=device) 136 | 137 | # np.save('XRX_pytorch.npy', XRX.cpu().numpy()) 138 | 139 | # Method 1: Calculate pseudo inverse via SVD 140 | # For singular matrices, we invert the singular 141 | # values above certain threshold (computed with the max singular value) 142 | # this is slightly better than torch.pinverse when L2_param=0 143 | XRX_pinv = pinv_naive(XRX) 144 | 145 | # method 2 146 | # XRX_pinv = torch.pinverse(XRX) 147 | 148 | # w = w - (X^T R X)^(-1) X^T (mu-y) 149 | val = torch.mm(X.t(), mu - y) 150 | if L2_param > 0: 151 | val += L2_param * w_old 152 | 153 | w_update = torch.mm(XRX_pinv, val) 154 | w_new = w_old - w_update 155 | return w_new 156 | 157 | 158 | @logger.catch 159 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 160 | """train Logistic Regression via IRLS algorithm 161 | X: Nxd 162 | y: Nx1 163 | --- 164 | 165 | """ 166 | N, d = X_train.shape 167 | X_train = torch.as_tensor(X_train, device=device) 168 | X_test = torch.as_tensor(X_test, device=device) 169 | y_train = torch.as_tensor(y_train, device=device) 170 | y_test = torch.as_tensor(y_test, device=device) 171 | 172 | w = torch.full((d, 1), 0.01, device=device) 173 | 174 | print("start training...") 175 | tic = time.time() 176 | print("Device: {}".format(device)) 177 | print("L2 param(lambda): {}".format(L2_param)) 178 | i = 0 179 | # iteration 180 | while i <= max_iter: 181 | print("iter: {}".format(i)) 182 | 183 | neg_L = neg_log_likelihood(w, X_train, y_train, L2_param) 184 | print("\t neg log likelihood: {}".format(neg_L.sum())) 185 | 186 | train_acc = compute_acc(X_train, y_train, w) 187 | test_acc = compute_acc(X_test, y_test, w) 188 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 189 | 190 | L2_norm_w = torch.norm(w) 191 | print("\t L2 norm of w: {}".format(L2_norm_w.item())) 192 | 193 | if i > 0: 194 | diff_w = torch.norm(w - w_old_data) 195 | print("\t diff of w_old and w: {}".format(diff_w.item())) 196 | if diff_w < 1e-2: 197 | break 198 | 199 | w_old_data = w.clone() 200 | w = update_weight(w, X_train, y_train, L2_param) 201 | i += 1 202 | print(f"training done, using {time.time() - tic}s.") 203 | 204 | 205 | if __name__ == "__main__": 206 | lambda_ = 20 # 0 207 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 208 | 209 | # from sklearn.linear_model import LogisticRegression 210 | # classifier = LogisticRegression() 211 | # classifier.fit(X_train, y_train.reshape(N_train,)) 212 | # y_pred_train = classifier.predict(X_train) 213 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 214 | # print('train_acc: {}'.format(train_acc)) 215 | # y_pred_test = classifier.predict(X_test) 216 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 217 | # print('test acc: {}'.format(test_acc)) 218 | 219 | # Device: cuda 220 | # L2 param(lambda): 20 221 | # iter: 0 222 | # neg log 
likelihood: 23909.580078125 223 | # train acc: 0.24080955982208252, test acc: 0.23622629046440125 224 | # L2 norm of w: 0.11135528236627579 225 | # iter: 1 226 | # neg log likelihood: 12469.443359375 227 | # train acc: 0.8444765210151672, test acc: 0.8458325862884521 228 | # L2 norm of w: 2.162217617034912 229 | # diff of w_old and w: 2.1692135334014893 230 | # iter: 2 231 | # neg log likelihood: 11109.1337890625 232 | # train acc: 0.8469641804695129, test acc: 0.8494564294815063 233 | # L2 norm of w: 3.1962287425994873 234 | # diff of w_old and w: 1.1507691144943237 235 | # iter: 3 236 | # neg log likelihood: 10792.7197265625 237 | # train acc: 0.8475169539451599, test acc: 0.8501321077346802 238 | # L2 norm of w: 3.8568592071533203 239 | # diff of w_old and w: 0.8543345928192139 240 | # iter: 4 241 | # neg log likelihood: 10746.015625 242 | # train acc: 0.8479469418525696, test acc: 0.8507463335990906 243 | # L2 norm of w: 4.1678619384765625 244 | # diff of w_old and w: 0.4917723834514618 245 | # iter: 5 246 | # neg log likelihood: 10743.853515625 247 | # train acc: 0.8479162454605103, test acc: 0.8506848812103271 248 | # L2 norm of w: 4.240682601928711 249 | # diff of w_old and w: 0.1432085931301117 250 | # iter: 6 251 | # neg log likelihood: 10743.8447265625 252 | # train acc: 0.8479162454605103, test acc: 0.8507463335990906 253 | # L2 norm of w: 4.244931221008301 254 | # diff of w_old and w: 0.010079198516905308 255 | # iter: 7 256 | # neg log likelihood: 10743.845703125 257 | # train acc: 0.8479162454605103, test acc: 0.8507463335990906 258 | # L2 norm of w: 4.244947910308838 259 | # diff of w_old and w: 4.562574395094998e-05 260 | # training done, using 0.14194560050964355s. -------------------------------------------------------------------------------- /src/IRLS_tf.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | from __future__ import print_function, division, absolute_import 3 | 4 | import os 5 | import argparse 6 | import random 7 | import numpy as np 8 | import time 9 | 10 | # from numpy import linalg 11 | import os.path as osp 12 | 13 | cur_dir = osp.dirname(osp.abspath(__file__)) 14 | 15 | from sklearn.datasets import load_svmlight_file 16 | from scipy.sparse import csr_matrix 17 | 18 | # from scipy.sparse import linalg 19 | import matplotlib 20 | 21 | matplotlib.use("Agg") 22 | import matplotlib.pyplot as plt 23 | import tensorflow as tf 24 | 25 | path_train = osp.join(cur_dir, "../a9a/a9a") 26 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 27 | MAX_ITER = 100 28 | np_dtype = np.float32 29 | tf_dtype = tf.float32 30 | 31 | # manual seed 32 | manualSeed = random.randint(1, 10000) # fix seed 33 | print("Random Seed: ", manualSeed) 34 | random.seed(manualSeed) 35 | np.random.seed(manualSeed) 36 | 37 | # load all data 38 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 39 | X_test, y_test = load_svmlight_file(path_test, n_features=123, dtype=np_dtype) 40 | # X: scipy.sparse.csr.csr_matrix 41 | 42 | # X_train: (32561, 123), y_train: (32561,) 43 | # X_test: (16281, 123), y_test:(16281,) 44 | 45 | # stack a dimension of ones to X to simplify computation 46 | N_train = X_train.shape[0] 47 | N_test = X_test.shape[0] 48 | X_train = np.hstack((np.ones((N_train, 1)), X_train.toarray())) 49 | X_test = np.hstack((np.ones((N_test, 1)), X_test.toarray())) 50 | 51 | 52 | y_train = y_train.reshape((N_train, 1)) 53 | y_test = y_test.reshape((N_test, 1)) 54 | 55 | # label: -1, +1 ==> 0, 1 56 | y_train = 
np.where(y_train == -1, 0, 1) 57 | y_test = np.where(y_test == -1, 0, 1) 58 | 59 | # NB: here X's shape is (N,d), which differs from the derivation 60 | 61 | 62 | # def sigmoid(v, a=1): 63 | # ''' 64 | # 1./(1+exp(-a*v)) 65 | # v: input, can be a ndarray, in this case, sigmoid is applied element-wise 66 | # ''' 67 | # res = np.zeros(v.shape, dtype=dtype) 68 | # res = np.where(a*v>=0, 1./(1+np.exp(-a*v)), np.exp(a*v)/(1 + np.exp(a*v))) 69 | # return res #1./(1+np.exp(-a*v)) 70 | 71 | 72 | def neg_log_likelihood(w, X, y, L2_param=None): 73 | r""" 74 | w: dx1 75 | X: Nxd 76 | y: Nx1 77 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 78 | """ 79 | res = tf.matmul(tf.matmul(tf.transpose(w), tf.transpose(X)), y) - tf.reduce_sum(tf.log(1 + tf.exp(tf.matmul(X, w)))) 80 | if L2_param is not None and L2_param > 0: 81 | res += -0.5 * L2_param * tf.matmul(tf.transpose(w), w) 82 | return -res[0][0] 83 | 84 | 85 | def prob(X, w): 86 | """ 87 | X: Nxd 88 | w: dx1 89 | --- 90 | prob: N x num_classes(2)""" 91 | y = tf.constant(np.array([0.0, 1.0]), dtype=tf.float32) 92 | prob = tf.exp(tf.matmul(X, w) * y) / (1 + tf.exp(tf.matmul(X, w))) 93 | return prob 94 | 95 | 96 | def compute_acc(X, y, w): 97 | p = prob(X, w) 98 | y_pred = tf.cast(tf.argmax(p, axis=1), tf.float32) 99 | y = tf.squeeze(y) 100 | acc = tf.reduce_mean(tf.cast(tf.equal(y, y_pred), tf.float32)) 101 | return acc 102 | 103 | 104 | def update(w_old, X, y, L2_param=0): 105 | """ 106 | w_new = w_old - w_update 107 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 108 | lambda is L2_param 109 | 110 | w_old: dx1 111 | X: Nxd 112 | y: Nx1 113 | --- 114 | w_update: dx1 115 | """ 116 | d = X.shape.as_list()[1] 117 | mu = tf.sigmoid(tf.matmul(X, w_old)) # Nx1 118 | 119 | R_flat = mu * (1 - mu) # element-wise, Nx1 120 | 121 | L2_reg_term = L2_param * tf.eye(d) 122 | XRX = tf.matmul(tf.transpose(X), R_flat * X) + L2_reg_term # dxd 123 | 124 | S, U, V = tf.svd(XRX, full_matrices=True, compute_uv=True) 125 | S = tf.expand_dims(S, 1) 126 | 127 | # calculate pseudo inverse via SVD 128 | S_pinv = tf.where(tf.not_equal(S, 0), 1 / S, tf.zeros_like(S)) # not good, will produce inf when dividing by 0 129 | XRX_pinv = tf.matmul(V, S_pinv * tf.transpose(U)) 130 | 131 | # w = w - (X^T R X)^(-1) X^T (mu-y) 132 | # w_new = tf.assign(w_old, w_old - tf.matmul(tf.matmul(XRX_pinv, tf.transpose(X)), mu - y)) 133 | w_update = tf.matmul(XRX_pinv, tf.matmul(tf.transpose(X), mu - y) + L2_param * w_old) 134 | return w_update 135 | 136 | 137 | def optimize(w_old, w_update): 138 | """custom update op, instead of using SGD variants""" 139 | return w_old.assign(w_old - w_update) 140 | 141 | 142 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 143 | """train Logistic Regression via IRLS algorithm 144 | X: Nxd 145 | y: Nx1 146 | --- 147 | 148 | """ 149 | N, d = X_train.shape 150 | X = tf.placeholder(dtype=tf.float32, shape=(None, d), name="X") 151 | y = tf.placeholder(dtype=tf.float32, shape=(None, 1), name="y") 152 | 153 | w = tf.Variable(0.01 * tf.ones((d, 1), dtype=tf.float32), name="w") 154 | w_update = update(w, X, y, L2_param) 155 | with tf.variable_scope("neg_L"): 156 | neg_L = neg_log_likelihood(w, X, y, L2_param) 157 | neg_L_summ = tf.summary.scalar("neg_L", neg_L) 158 | 159 | with tf.variable_scope("accuracy"): 160 | acc = compute_acc(X, y, w) 161 | acc_summ = tf.summary.scalar("acc", acc) 162 | 163 | optimize_op = optimize(w, w_update) 164 | 165 | merged_all = tf.summary.merge_all() 166 | 167 | config =
tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 168 | config.gpu_options.allow_growth = True 169 | config.gpu_options.per_process_gpu_memory_fraction = 0.8 170 | 171 | sess = tf.Session(config=config) 172 | sess.run(tf.global_variables_initializer()) 173 | summary_writer = tf.summary.FileWriter("./log", sess.graph) 174 | 175 | train_feed_dict = {X: X_train, y: y_train} 176 | test_feed_dict = {X: X_test, y: y_test} 177 | 178 | print("start training...") 179 | print("L2 param(lambda): {}".format(L2_param)) 180 | i = 0 181 | # iteration 182 | w_old = w 183 | while i <= max_iter: 184 | print("iter: {}".format(i)) 185 | 186 | print("\t neg log likelihood: {}".format(sess.run(neg_L, feed_dict=train_feed_dict))) 187 | 188 | train_acc, merged = sess.run([acc, merged_all], feed_dict=train_feed_dict) 189 | summary_writer.add_summary(merged, i) 190 | 191 | test_acc = sess.run(acc, feed_dict=test_feed_dict) 192 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 193 | 194 | L2_norm_w = np.linalg.norm(sess.run(w)) 195 | print("\t L2 norm of w: {}".format(L2_norm_w)) 196 | 197 | if i > 0: 198 | diff_w = np.linalg.norm(sess.run(w_update, feed_dict=train_feed_dict)) 199 | print("\t diff of w_old and w: {}".format(diff_w)) 200 | if diff_w < 1e-2: 201 | break 202 | 203 | w_new = sess.run(optimize_op, feed_dict=train_feed_dict) 204 | i += 1 205 | print("training done.") 206 | 207 | 208 | if __name__ == "__main__": 209 | lambda_ = 20 # 0 210 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 211 | 212 | # from sklearn.linear_model import LogisticRegression 213 | # classifier = LogisticRegression() 214 | # classifier.fit(X_train, y_train.reshape(N_train,)) 215 | # y_pred_train = classifier.predict(X_train) 216 | # train_acc = np.sum(y_train.reshape(N_train,) == y_pred_train)/N_train 217 | # print('train_acc: {}'.format(train_acc)) 218 | # y_pred_test = classifier.predict(X_test) 219 | # test_acc = np.sum(y_test.reshape(N_test,) == y_pred_test)/N_test 220 | # print('test acc: {}'.format(test_acc)) 221 | -------------------------------------------------------------------------------- /src/IRLS_tf_v2.py: -------------------------------------------------------------------------------- 1 | # python 3 2 | # tensorflow 2 3 | # pip install tensorflow # for both cpu and gpu 4 | from __future__ import print_function, division, absolute_import 5 | 6 | import os 7 | import argparse 8 | import random 9 | import numpy as np 10 | import time 11 | import datetime 12 | 13 | # from numpy import linalg 14 | import os.path as osp 15 | import sys 16 | 17 | cur_dir = osp.dirname(osp.abspath(__file__)) 18 | sys.path.insert(1, osp.join(cur_dir, ".")) 19 | from sklearn.datasets import load_svmlight_file 20 | from scipy.sparse import csr_matrix 21 | 22 | # from scipy.sparse import linalg 23 | import matplotlib 24 | 25 | matplotlib.use("Agg") 26 | import matplotlib.pyplot as plt 27 | import tensorflow as tf 28 | from tf_utils import pinv_naive, pinv 29 | 30 | 31 | path_train = osp.join(cur_dir, "../a9a/a9a") 32 | path_test = osp.join(cur_dir, "../a9a/a9a.t") 33 | MAX_ITER = 100 34 | np_dtype = np.float32 35 | tf_dtype = tf.float32 36 | 37 | # manual seed 38 | manualSeed = random.randint(1, 10000) # fix seed 39 | print("Random Seed: ", manualSeed) 40 | random.seed(manualSeed) 41 | np.random.seed(manualSeed) 42 | 43 | # load all data 44 | X_train, y_train = load_svmlight_file(path_train, n_features=123, dtype=np_dtype) 45 | X_test, y_test = load_svmlight_file(path_test, 
n_features=123, dtype=np_dtype) 46 | # X: scipy.sparse.csr.csr_matrix 47 | 48 | # X_train: (32561, 123), y_train: (32561,) 49 | # X_test: (16281, 123), y_test:(16281,) 50 | 51 | # stack a dimension of ones to X to simplify computation 52 | N_train = X_train.shape[0] 53 | N_test = X_test.shape[0] 54 | X_train = np.hstack((np.ones((N_train, 1)), X_train.toarray())).astype(np_dtype) 55 | X_test = np.hstack((np.ones((N_test, 1)), X_test.toarray())).astype(np_dtype) 56 | # print(X_train.shape, X_test.shape) 57 | 58 | y_train = y_train.reshape((N_train, 1)) 59 | y_test = y_test.reshape((N_test, 1)) 60 | 61 | # label: -1, +1 ==> 0, 1 62 | y_train = np.where(y_train == -1, 0, 1) 63 | y_test = np.where(y_test == -1, 0, 1) 64 | 65 | # NB: here X's shape is (N,d), which differs from the derivation 66 | 67 | 68 | def neg_log_likelihood(w, X, y, L2_param=None): 69 | r""" 70 | w: dx1 71 | X: Nxd 72 | y: Nx1 73 | L2_param: \lambda>0, will introduce -\lambda/2 ||w||_2^2 74 | """ 75 | # print(type(X), X.dtype) 76 | res = tf.matmul(tf.matmul(tf.transpose(w), tf.transpose(X)), y.astype(np_dtype)) - tf.reduce_sum( 77 | tf.math.log(1 + tf.exp(tf.matmul(X, w))) 78 | ) 79 | if L2_param is not None and L2_param > 0: 80 | res += -0.5 * L2_param * tf.matmul(tf.transpose(w), w) 81 | return -res[0][0] 82 | 83 | 84 | def prob(X, w): 85 | """ 86 | X: Nxd 87 | w: dx1 88 | --- 89 | prob: N x num_classes(2)""" 90 | y = tf.constant(np.array([0.0, 1.0]), dtype=tf.float32) 91 | prob = tf.exp(tf.matmul(X, w) * y) / (1 + tf.exp(tf.matmul(X, w))) 92 | return prob 93 | 94 | 95 | def compute_acc(X, y, w): 96 | p = prob(X, w) 97 | y_pred = tf.cast(tf.argmax(p, axis=1), tf.float32) 98 | y = tf.cast(tf.squeeze(y), tf.float32) 99 | acc = tf.reduce_mean(tf.cast(tf.equal(y, y_pred), tf.float32)) 100 | return acc 101 | 102 | 103 | def update(w_old, X, y, L2_param=0): 104 | """ 105 | w_new = w_old - w_update 106 | w_update = (X'RX+lambda*I)^(-1) (X'(mu-y) + lambda*w_old) 107 | lambda is L2_param 108 | 109 | w_old: dx1 110 | X: Nxd 111 | y: Nx1 112 | --- 113 | w_update: dx1 114 | """ 115 | d = X.shape[1] 116 | mu = tf.sigmoid(tf.matmul(X, w_old)) # Nx1 117 | 118 | R_flat = mu * (1 - mu) # element-wise, Nx1 119 | 120 | L2_reg_term = L2_param * tf.eye(d) 121 | XRX = tf.matmul(tf.transpose(X), R_flat * X) + L2_reg_term # dxd 122 | # np.save('XRX_tf.npy', XRX.numpy()) 123 | 124 | # calculate pseudo inverse via SVD 125 | # method 1 126 | # slightly better than tfp.math.pinv when L2_param=0 127 | XRX_pinv = pinv_naive(XRX) 128 | 129 | # method 2 130 | # XRX_pinv = pinv(XRX) 131 | 132 | # w = w - (X^T R X)^(-1) X^T (mu-y) 133 | # w_new = tf.assign(w_old, w_old - tf.matmul(tf.matmul(XRX_pinv, tf.transpose(X)), mu - y)) 134 | y = tf.cast(y, tf_dtype) 135 | w_update = tf.matmul(XRX_pinv, tf.matmul(tf.transpose(X), mu - y) + L2_param * w_old) 136 | return w_update 137 | 138 | 139 | def optimize(w_old, w_update): 140 | """custom update op, instead of using SGD variants""" 141 | return w_old.assign(w_old - w_update) 142 | 143 | 144 | def train_IRLS(X_train, y_train, X_test=None, y_test=None, L2_param=0, max_iter=MAX_ITER): 145 | """train Logistic Regression via IRLS algorithm 146 | X: Nxd 147 | y: Nx1 148 | --- 149 | """ 150 | N, d = X_train.shape 151 | w = tf.Variable(0.01 * tf.ones((d, 1), dtype=tf.float32), name="w") 152 | current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 153 | summary_writer = tf.summary.create_file_writer(f"./logs/{current_time}") 154 | print("start training...") 155 | tic = time.time() 156 | print("L2 param(lambda): 
{}".format(L2_param)) 157 | i = 0 158 | # iteration 159 | while i <= max_iter: 160 | print("iter: {}".format(i)) 161 | 162 | # print('\t neg log likelihood: {}'.format(sess.run(neg_L, feed_dict=train_feed_dict))) 163 | neg_L = neg_log_likelihood(w, X_train, y_train, L2_param) 164 | print("\t neg log likelihood: {}".format(neg_L)) 165 | train_acc = compute_acc(X_train, y_train, w) 166 | with summary_writer.as_default(): 167 | tf.summary.scalar("train_acc", train_acc, step=i) 168 | tf.summary.scalar("train_neg_L", neg_L, step=i) 169 | 170 | test_acc = compute_acc(X_test, y_test, w) 171 | with summary_writer.as_default(): 172 | tf.summary.scalar("test_acc", test_acc, step=i) 173 | print("\t train acc: {}, test acc: {}".format(train_acc, test_acc)) 174 | 175 | L2_norm_w = np.linalg.norm(w.numpy()) 176 | print("\t L2 norm of w: {}".format(L2_norm_w)) 177 | 178 | if i > 0: 179 | diff_w = np.linalg.norm(w_update.numpy()) 180 | print("\t diff of w_old and w: {}".format(diff_w)) 181 | if diff_w < 1e-2: 182 | break 183 | w_update = update(w, X_train, y_train, L2_param) 184 | w = optimize(w, w_update) 185 | i += 1 186 | print(f"training done, using {time.time() - tic}s.") 187 | 188 | 189 | if __name__ == "__main__": 190 | # test_acc should be about 0.85 191 | lambda_ = 20 # 0 192 | train_IRLS(X_train, y_train, X_test, y_test, L2_param=lambda_, max_iter=100) 193 | 194 | # from sklearn.linear_model import LogisticRegression 195 | 196 | # classifier = LogisticRegression() 197 | # classifier.fit( 198 | # X_train, 199 | # y_train.reshape( 200 | # N_train, 201 | # ), 202 | # ) 203 | # y_pred_train = classifier.predict(X_train) 204 | # train_acc = ( 205 | # np.sum( 206 | # y_train.reshape( 207 | # N_train, 208 | # ) 209 | # == y_pred_train 210 | # ) 211 | # / N_train 212 | # ) 213 | # print("train_acc: {}".format(train_acc)) 214 | # y_pred_test = classifier.predict(X_test) 215 | # test_acc = ( 216 | # np.sum( 217 | # y_test.reshape( 218 | # N_test, 219 | # ) 220 | # == y_pred_test 221 | # ) 222 | # / N_test 223 | # ) 224 | # print("test acc: {}".format(test_acc)) 225 | -------------------------------------------------------------------------------- /src/tf_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def _maybe_validate_matrix(a, validate_args): 6 | """Checks that input is a `float` matrix.""" 7 | assertions = [] 8 | if not a.dtype.is_floating: 9 | raise TypeError("Input `a` must have `float`-like `dtype` " "(saw {}).".format(a.dtype.name)) 10 | if a.shape.ndims is not None: 11 | if a.shape.ndims < 2: 12 | raise ValueError("Input `a` must have at least 2 dimensions " "(saw: {}).".format(a.shape.ndims)) 13 | elif validate_args: 14 | assertions.append( 15 | tf.compat.v1.assert_rank_at_least(a, rank=2, message="Input `a` must have at least 2 dimensions.") 16 | ) 17 | return assertions 18 | 19 | 20 | def pinv_naive(a): 21 | """Returns the Moore-Penrose pseudo-inverse""" 22 | # dtype = a.dtype.as_numpy_dtype 23 | # S, U, V = tf.linalg.svd(a, full_matrices=True, compute_uv=True) 24 | # S = tf.expand_dims(S, 1) 25 | # 26 | # # calculate pseudo inverse via SVD 27 | # # not good, will produce inf when divide by 0 28 | # threshold = tf.reduce_max(S) * 1e-5 29 | # S = tf.where(S > threshold, S, tf.fill(tf.shape(input=S), np.array(np.inf, dtype))) 30 | # a_pinv = tf.matmul(V/S, tf.transpose(U)) 31 | # return a_pinv 32 | s, u, v = tf.linalg.svd(a) 33 | 34 | threshold = tf.reduce_max(s) * 1e-5 35 | s_mask = 
tf.boolean_mask(s, s > threshold) # s[s>threshold] 36 | s_inv = tf.linalg.diag(tf.concat([1.0 / s_mask, tf.zeros([tf.size(s) - tf.size(s_mask)])], 0)) 37 | 38 | return tf.matmul(v, tf.matmul(s_inv, tf.transpose(u))) 39 | 40 | 41 | def pinv(a, rcond=None, validate_args=False, name=None): 42 | """ 43 | https://github.com/tensorflow/probability/blob/d674d79bc8175bff2f415bf3b38a42f51ffc999c/tensorflow_probability/python/math/linalg.py 44 | Compute the Moore-Penrose pseudo-inverse of a matrix. 45 | Calculate the [generalized inverse of a matrix]( 46 | https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its 47 | singular-value decomposition (SVD) and including all large singular values. 48 | The pseudo-inverse of a matrix `A`, is defined as: "the matrix that 'solves' 49 | [the least-squares problem] `A @ x = b`," i.e., if `x_hat` is a solution, then 50 | `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if 51 | `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then 52 | `A_pinv = V @ inv(Sigma) U^T`. [(Strang, 1980)][1] 53 | This function is analogous to [`numpy.linalg.pinv`]( 54 | https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html). 55 | It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the 56 | default `rcond` is `1e-15`. Here the default is 57 | `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`. 58 | Args: 59 | a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be 60 | pseudo-inverted. 61 | rcond: `Tensor` of small singular value cutoffs. Singular values smaller 62 | (in modulus) than `rcond` * largest_singular_value (again, in modulus) are 63 | set to zero. Must broadcast against `tf.shape(a)[:-2]`. 64 | Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`. 65 | validate_args: When `True`, additional assertions might be embedded in the 66 | graph. 67 | Default value: `False` (i.e., no graph assertions are added). 68 | name: Python `str` prefixed to ops created by this function. 69 | Default value: "pinv". 70 | Returns: 71 | a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except 72 | rightmost two dimensions are transposed. 73 | Raises: 74 | TypeError: if input `a` does not have `float`-like `dtype`. 75 | ValueError: if input `a` has fewer than 2 dimensions. 76 | #### Examples 77 | ```python 78 | import tensorflow as tf 79 | import tensorflow_probability as tfp 80 | a = tf.constant([[1., 0.4, 0.5], 81 | [0.4, 0.2, 0.25], 82 | [0.5, 0.25, 0.35]]) 83 | tf.matmul(tfp.math.pinv(a), a) 84 | # ==> array([[1., 0., 0.], 85 | [0., 1., 0.], 86 | [0., 0., 1.]], dtype=float32) 87 | a = tf.constant([[1., 0.4, 0.5, 1.], 88 | [0.4, 0.2, 0.25, 2.], 89 | [0.5, 0.25, 0.35, 3.]]) 90 | tf.matmul(tfp.math.pinv(a), a) 91 | # ==> array([[ 0.76, 0.37, 0.21, -0.02], 92 | [ 0.37, 0.43, -0.33, 0.02], 93 | [ 0.21, -0.33, 0.81, 0.01], 94 | [-0.02, 0.02, 0.01, 1. ]], dtype=float32) 95 | ``` 96 | #### References 97 | [1]: G. Strang. "Linear Algebra and Its Applications, 2nd Ed." Academic Press, 98 | Inc., 1980, pp. 139-142. 
99 | """ 100 | with tf.compat.v1.name_scope(name, "pinv", [a, rcond]): 101 | a = tf.convert_to_tensor(value=a, name="a") 102 | 103 | assertions = _maybe_validate_matrix(a, validate_args) 104 | if assertions: 105 | with tf.control_dependencies(assertions): 106 | a = tf.identity(a) 107 | 108 | dtype = a.dtype.as_numpy_dtype 109 | 110 | if rcond is None: 111 | 112 | def get_dim_size(dim): 113 | if tf.compat.dimension_value(a.shape[dim]) is not None: 114 | return tf.compat.dimension_value(a.shape[dim]) 115 | return tf.shape(input=a)[dim] 116 | 117 | num_rows = get_dim_size(-2) 118 | num_cols = get_dim_size(-1) 119 | if isinstance(num_rows, int) and isinstance(num_cols, int): 120 | max_rows_cols = float(max(num_rows, num_cols)) 121 | else: 122 | max_rows_cols = tf.cast(tf.maximum(num_rows, num_cols), dtype) 123 | rcond = 10.0 * max_rows_cols * np.finfo(dtype).eps 124 | 125 | rcond = tf.convert_to_tensor(value=rcond, dtype=dtype, name="rcond") 126 | 127 | # Calculate pseudo inverse via SVD. 128 | # Note: if a is symmetric then u == v. (We might observe additional 129 | # performance by explicitly setting `v = u` in such cases.) 130 | [ 131 | singular_values, # Sigma 132 | left_singular_vectors, # U 133 | right_singular_vectors, # V 134 | ] = tf.linalg.svd(a, full_matrices=False, compute_uv=True) 135 | 136 | # Saturate small singular values to inf. This has the effect of make 137 | # `1. / s = 0.` while not resulting in `NaN` gradients. 138 | cutoff = rcond * tf.reduce_max(input_tensor=singular_values, axis=-1) 139 | singular_values = tf.where( 140 | singular_values > cutoff[..., tf.newaxis], 141 | singular_values, 142 | tf.fill(tf.shape(input=singular_values), np.array(np.inf, dtype)), 143 | ) 144 | 145 | # Although `a == tf.matmul(u, s * v, transpose_b=True)` we swap 146 | # `u` and `v` here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e., 147 | # a matrix inverse has "transposed" semantics. 148 | a_pinv = tf.matmul( 149 | right_singular_vectors / singular_values[..., tf.newaxis, :], 150 | left_singular_vectors, 151 | adjoint_b=True, 152 | ) 153 | 154 | if a.shape.ndims is not None: 155 | a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]])) 156 | 157 | return a_pinv 158 | --------------------------------------------------------------------------------