├── .gitignore ├── COPYING ├── MANIFEST.in ├── README.rst ├── doc └── benchmark_logistic.ipynb ├── example_logistic.py ├── pytron ├── __init__.py ├── src │ ├── blas │ │ ├── Makefile │ │ ├── blas.h │ │ ├── blasp.h │ │ ├── daxpy.c │ │ ├── ddot.c │ │ ├── dnrm2.c │ │ └── dscal.c │ ├── tron.cpp │ ├── tron.h │ ├── tron_helper.cpp │ └── tron_helper.h └── tron.pyx ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | .idea 4 | pytron/tron.cpp -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | New BSD License 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | a. Redistributions of source code must retain the above copyright notice, 7 | this list of conditions and the following disclaimer. 8 | b. Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | c. Neither the name of the Scikit-learn Developers nor the names of 12 | its contributors may be used to endorse or promote products 13 | derived from this software without specific prior written 14 | permission. 15 | 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 25 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 26 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 27 | DAMAGE. 28 | 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.py 2 | recursive-include pytron *.cpp *.h -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | A Trust-Region Newton Method in Python 2 | ====================================== 3 | 4 | .. DANGER:: 5 | This is alpha quality software and still quite rough around the edges. 6 | Specifically, the error management is still lacking (which means that 7 | if something goes wrong in the optimization you won't see an error 8 | message but just get garbage). These things are being worked out but 9 | we're not quite there yet. 10 | 11 | .. image:: http://fa.bianp.net/blog/static/images/2013/comparison_logistic_corr_10.png 12 | 13 | The main function is pytron.minimize:: 14 | 15 | def minimize(func, grad_hess, x0, args=(), max_iter=1000, tol=1e-6): 16 | 17 | Parameters 18 | ---------- 19 | func : callable 20 | func(w, *args) is the evaluation of the function at w. It 21 | should return a float. 22 | grad_hess: callable 23 | returns the gradient and a callable with the hessian times 24 | an arbitrary vector. 25 | tol: float 26 | stopping criterion. XXX TODO. what is the stopping criterion ? 
27 | 28 | Returns 29 | ------- 30 | w : array 31 | 32 | 33 | 34 | Stopping criterion 35 | ------------------ 36 | 37 | It stops whenever ||grad(x)|| < eps or the maximum number of iterations is 38 | attained. 39 | 40 | TODO: add tol 41 | 42 | Examples 43 | -------- 44 | 45 | Code 46 | ---- 47 | This software uses the C++ implementation of `TRON optimization software 48 | `_ (files src/tron.{h,cpp}) 49 | distributed from the LIBLINEAR sources (v1.93), which is BSD licensed. 50 | Note that the original Fortran TRON implementation (available 51 | `here `_) is not open 52 | source and is not used in this project. 53 | 54 | The modifications with respect to the original code are: 55 | 56 | * Do not initialize values to zero, allow arbitrary initializations 57 | 58 | * Modify stopping criterion to comply with scipy.optimize API. Stop 59 | whenever gradient is smaller than a given quantity, specified in the 60 | gtol argument 61 | 62 | * Return the gradient from TRON::tron (pass by reference) 63 | 64 | * Add `tol` option to TRON 65 | 66 | * Rename `eps` to `gtol`. 67 | 68 | * Use infinity norm as stopping criterion for gradient instead of L2. 69 | 70 | TODO 71 | ---- 72 | * return status from TRON::TRON 73 | * callback argument 74 | 75 | 76 | References 77 | ---------- 78 | If you use the software please consider citing some of the references below. 79 | 80 | The method is described in the paper "Newton's Method for Large 81 | Bound-Constrained Optimization Problems", Chih-Jen Lin and Jorge J. Moré 82 | (http://epubs.siam.org/doi/abs/10.1137/S1052623498345075) 83 | 84 | It is also discussed in the context of Logistic Regression in the paper "Trust 85 | Region Newton Method for Logistic Regression", Chih-Jen Lin, Ruby C. Weng, 86 | S. Sathiya Keerthi (http://dl.acm.org/citation.cfm?id=1390703) 87 | 88 | The website http://www.mcs.anl.gov/~more/tron/ contains a reference to this 89 | implementation, although the links to the software seem to be currently 90 | broken (May 2013). 
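As an illustration of the interface and stopping rule described earlier in this README, here is a minimal, dependency-free sketch. It is not pytron's implementation and does not call pytron. It only shows the shape of a ``grad_hess`` callable (the gradient plus a callable that multiplies the Hessian by an arbitrary vector) and an infinity-norm test on the gradient. The quadratic objective and the names ``matvec``, ``quad_func``, ``quad_grad_hess`` and ``converged`` are assumptions made for this example.

```python
# Illustrative sketch only: every name here is hypothetical, not part of pytron.

A = [[3.0, 1.0], [1.0, 2.0]]  # symmetric positive definite matrix
b = [1.0, 1.0]


def matvec(M, v):
    """Dense matrix-vector product on plain Python lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def quad_func(w):
    """Objective f(w) = 0.5 * w'Aw - b'w, returned as a float."""
    Aw = matvec(A, w)
    quad = 0.5 * sum(w_i * aw_i for w_i, aw_i in zip(w, Aw))
    lin = sum(b_i * w_i for b_i, w_i in zip(b, w))
    return quad - lin


def quad_grad_hess(w):
    """Return the gradient A w - b and a callable s -> H s (here H = A)."""
    grad = [aw_i - b_i for aw_i, b_i in zip(matvec(A, w), b)]

    def Hs(s):
        # Hessian-vector product; the Hessian itself is never formed by the caller.
        return matvec(A, s)

    return grad, Hs


def converged(grad, gtol):
    """Infinity-norm test: the largest absolute gradient entry is at most gtol."""
    return max(abs(g) for g in grad) <= gtol
```

Returning a Hessian-vector callable rather than a full matrix keeps the cost per inner iteration at one matrix-vector product, which is the reason the ``grad_hess`` argument is shaped this way.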
91 | 92 | 93 | License 94 | ------- 95 | This code is licensed under the terms of the BSD license. See file COPYING 96 | for more details. 97 | 98 | 99 | Acknowledgement 100 | --------------- 101 | The source code for the 102 | -------------------------------------------------------------------------------- /example_logistic.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from pytron import minimize 4 | 5 | def phi(t): 6 | # helper function 7 | return 1. / (1 + np.exp(-t)) 8 | 9 | 10 | def loss(w, X, y, alpha): 11 | # loss function to be optimized, it's the logistic loss 12 | z = X.dot(w) 13 | yz = y * z 14 | idx = yz > 0 15 | out = np.zeros_like(yz) 16 | out[idx] = np.log(1 + np.exp(-yz[idx])) 17 | out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx]))) 18 | out = out.sum() + .5 * alpha * w.dot(w) 19 | return out 20 | 21 | 22 | def grad_hess(w, X, y, alpha): 23 | # gradient of the logistic loss 24 | z = X.dot(w) 25 | z = phi(y * z) 26 | z0 = (z - 1) * y 27 | grad = X.T.dot(z0) + alpha * w 28 | def Hs(s): 29 | d = z * (1 - z) 30 | wa = d * X.dot(s) 31 | return X.T.dot(wa) + alpha * s 32 | return grad, Hs 33 | 34 | 35 | 36 | if __name__ == '__main__': 37 | # set the data 38 | n_samples, n_features = 100, 10 39 | X = np.random.randn(n_samples, n_features) 40 | y = np.sign(X.dot(5 * np.random.randn(n_features))) 41 | alpha = 1. 
42 | x0 = np.zeros(n_features) 43 | 44 | 45 | def callback(x0): 46 | print(loss(x0, X, y, alpha)) 47 | # call the solver 48 | res = minimize(loss, grad_hess, x0, args=(X, y, alpha), 49 | max_iter=15, gtol=1e-3, tol=1e-12, callback=callback) 50 | print(res) 51 | 52 | from sklearn import linear_model 53 | clf = linear_model.LogisticRegression(C=1./alpha, fit_intercept=False) 54 | clf.fit(X, y) 55 | 56 | print() 57 | print('Solution using TRON: %s' % res.x) 58 | print('Solution using scikit-learn: %s' % clf.coef_) -------------------------------------------------------------------------------- /pytron/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | __version__ = '0.3' 3 | from .tron import minimize 4 | -------------------------------------------------------------------------------- /pytron/src/blas/Makefile: -------------------------------------------------------------------------------- 1 | AR = ar rcv 2 | RANLIB = ranlib 3 | 4 | HEADERS = blas.h blasp.h 5 | FILES = dnrm2.o daxpy.o ddot.o dscal.o 6 | 7 | CFLAGS = $(OPTFLAGS) 8 | FFLAGS = $(OPTFLAGS) 9 | 10 | blas: $(FILES) $(HEADERS) 11 | $(AR) blas.a $(FILES) 12 | $(RANLIB) blas.a 13 | 14 | clean: 15 | - rm -f *.o 16 | - rm -f *.a 17 | - rm -f *~ 18 | 19 | .c.o: 20 | $(CC) $(CFLAGS) -c $*.c 21 | 22 | 23 | -------------------------------------------------------------------------------- /pytron/src/blas/blas.h: -------------------------------------------------------------------------------- 1 | /* blas.h -- C header file for BLAS Ver 1.0 */ 2 | /* Jesse Bennett March 23, 2000 */ 3 | 4 | /** barf [ba:rf] 2. "He suggested using FORTRAN, and everybody barfed." 
5 | 6 | - From The Shogakukan DICTIONARY OF NEW ENGLISH (Second edition) */ 7 | 8 | #ifndef BLAS_INCLUDE 9 | #define BLAS_INCLUDE 10 | 11 | /* Data types specific to BLAS implementation */ 12 | typedef struct { float r, i; } fcomplex; 13 | typedef struct { double r, i; } dcomplex; 14 | typedef int blasbool; 15 | 16 | #include "blasp.h" /* Prototypes for all BLAS functions */ 17 | 18 | #define FALSE 0 19 | #define TRUE 1 20 | 21 | /* Macro functions */ 22 | #define MIN(a,b) ((a) <= (b) ? (a) : (b)) 23 | #define MAX(a,b) ((a) >= (b) ? (a) : (b)) 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /pytron/src/blas/blasp.h: -------------------------------------------------------------------------------- 1 | /* blasp.h -- C prototypes for BLAS Ver 1.0 */ 2 | /* Jesse Bennett March 23, 2000 */ 3 | 4 | /* Functions listed in alphabetical order */ 5 | 6 | #ifdef F2C_COMPAT 7 | 8 | void cdotc_(fcomplex *dotval, int *n, fcomplex *cx, int *incx, 9 | fcomplex *cy, int *incy); 10 | 11 | void cdotu_(fcomplex *dotval, int *n, fcomplex *cx, int *incx, 12 | fcomplex *cy, int *incy); 13 | 14 | double sasum_(int *n, float *sx, int *incx); 15 | 16 | double scasum_(int *n, fcomplex *cx, int *incx); 17 | 18 | double scnrm2_(int *n, fcomplex *x, int *incx); 19 | 20 | double sdot_(int *n, float *sx, int *incx, float *sy, int *incy); 21 | 22 | double snrm2_(int *n, float *x, int *incx); 23 | 24 | void zdotc_(dcomplex *dotval, int *n, dcomplex *cx, int *incx, 25 | dcomplex *cy, int *incy); 26 | 27 | void zdotu_(dcomplex *dotval, int *n, dcomplex *cx, int *incx, 28 | dcomplex *cy, int *incy); 29 | 30 | #else 31 | 32 | fcomplex cdotc_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy); 33 | 34 | fcomplex cdotu_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy); 35 | 36 | float sasum_(int *n, float *sx, int *incx); 37 | 38 | float scasum_(int *n, fcomplex *cx, int *incx); 39 | 40 | float scnrm2_(int *n, fcomplex *x, int *incx); 
41 | 42 | float sdot_(int *n, float *sx, int *incx, float *sy, int *incy); 43 | 44 | float snrm2_(int *n, float *x, int *incx); 45 | 46 | dcomplex zdotc_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy); 47 | 48 | dcomplex zdotu_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy); 49 | 50 | #endif 51 | 52 | /* Remaining functions listed in alphabetical order */ 53 | 54 | int caxpy_(int *n, fcomplex *ca, fcomplex *cx, int *incx, fcomplex *cy, 55 | int *incy); 56 | 57 | int ccopy_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy); 58 | 59 | int cgbmv_(char *trans, int *m, int *n, int *kl, int *ku, 60 | fcomplex *alpha, fcomplex *a, int *lda, fcomplex *x, int *incx, 61 | fcomplex *beta, fcomplex *y, int *incy); 62 | 63 | int cgemm_(char *transa, char *transb, int *m, int *n, int *k, 64 | fcomplex *alpha, fcomplex *a, int *lda, fcomplex *b, int *ldb, 65 | fcomplex *beta, fcomplex *c, int *ldc); 66 | 67 | int cgemv_(char *trans, int *m, int *n, fcomplex *alpha, fcomplex *a, 68 | int *lda, fcomplex *x, int *incx, fcomplex *beta, fcomplex *y, 69 | int *incy); 70 | 71 | int cgerc_(int *m, int *n, fcomplex *alpha, fcomplex *x, int *incx, 72 | fcomplex *y, int *incy, fcomplex *a, int *lda); 73 | 74 | int cgeru_(int *m, int *n, fcomplex *alpha, fcomplex *x, int *incx, 75 | fcomplex *y, int *incy, fcomplex *a, int *lda); 76 | 77 | int chbmv_(char *uplo, int *n, int *k, fcomplex *alpha, fcomplex *a, 78 | int *lda, fcomplex *x, int *incx, fcomplex *beta, fcomplex *y, 79 | int *incy); 80 | 81 | int chemm_(char *side, char *uplo, int *m, int *n, fcomplex *alpha, 82 | fcomplex *a, int *lda, fcomplex *b, int *ldb, fcomplex *beta, 83 | fcomplex *c, int *ldc); 84 | 85 | int chemv_(char *uplo, int *n, fcomplex *alpha, fcomplex *a, int *lda, 86 | fcomplex *x, int *incx, fcomplex *beta, fcomplex *y, int *incy); 87 | 88 | int cher_(char *uplo, int *n, float *alpha, fcomplex *x, int *incx, 89 | fcomplex *a, int *lda); 90 | 91 | int cher2_(char *uplo, int *n, 
fcomplex *alpha, fcomplex *x, int *incx, 92 | fcomplex *y, int *incy, fcomplex *a, int *lda); 93 | 94 | int cher2k_(char *uplo, char *trans, int *n, int *k, fcomplex *alpha, 95 | fcomplex *a, int *lda, fcomplex *b, int *ldb, float *beta, 96 | fcomplex *c, int *ldc); 97 | 98 | int cherk_(char *uplo, char *trans, int *n, int *k, float *alpha, 99 | fcomplex *a, int *lda, float *beta, fcomplex *c, int *ldc); 100 | 101 | int chpmv_(char *uplo, int *n, fcomplex *alpha, fcomplex *ap, fcomplex *x, 102 | int *incx, fcomplex *beta, fcomplex *y, int *incy); 103 | 104 | int chpr_(char *uplo, int *n, float *alpha, fcomplex *x, int *incx, 105 | fcomplex *ap); 106 | 107 | int chpr2_(char *uplo, int *n, fcomplex *alpha, fcomplex *x, int *incx, 108 | fcomplex *y, int *incy, fcomplex *ap); 109 | 110 | int crotg_(fcomplex *ca, fcomplex *cb, float *c, fcomplex *s); 111 | 112 | int cscal_(int *n, fcomplex *ca, fcomplex *cx, int *incx); 113 | 114 | int csscal_(int *n, float *sa, fcomplex *cx, int *incx); 115 | 116 | int cswap_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy); 117 | 118 | int csymm_(char *side, char *uplo, int *m, int *n, fcomplex *alpha, 119 | fcomplex *a, int *lda, fcomplex *b, int *ldb, fcomplex *beta, 120 | fcomplex *c, int *ldc); 121 | 122 | int csyr2k_(char *uplo, char *trans, int *n, int *k, fcomplex *alpha, 123 | fcomplex *a, int *lda, fcomplex *b, int *ldb, fcomplex *beta, 124 | fcomplex *c, int *ldc); 125 | 126 | int csyrk_(char *uplo, char *trans, int *n, int *k, fcomplex *alpha, 127 | fcomplex *a, int *lda, fcomplex *beta, fcomplex *c, int *ldc); 128 | 129 | int ctbmv_(char *uplo, char *trans, char *diag, int *n, int *k, 130 | fcomplex *a, int *lda, fcomplex *x, int *incx); 131 | 132 | int ctbsv_(char *uplo, char *trans, char *diag, int *n, int *k, 133 | fcomplex *a, int *lda, fcomplex *x, int *incx); 134 | 135 | int ctpmv_(char *uplo, char *trans, char *diag, int *n, fcomplex *ap, 136 | fcomplex *x, int *incx); 137 | 138 | int ctpsv_(char *uplo, 
char *trans, char *diag, int *n, fcomplex *ap, 139 | fcomplex *x, int *incx); 140 | 141 | int ctrmm_(char *side, char *uplo, char *transa, char *diag, int *m, 142 | int *n, fcomplex *alpha, fcomplex *a, int *lda, fcomplex *b, 143 | int *ldb); 144 | 145 | int ctrmv_(char *uplo, char *trans, char *diag, int *n, fcomplex *a, 146 | int *lda, fcomplex *x, int *incx); 147 | 148 | int ctrsm_(char *side, char *uplo, char *transa, char *diag, int *m, 149 | int *n, fcomplex *alpha, fcomplex *a, int *lda, fcomplex *b, 150 | int *ldb); 151 | 152 | int ctrsv_(char *uplo, char *trans, char *diag, int *n, fcomplex *a, 153 | int *lda, fcomplex *x, int *incx); 154 | 155 | int daxpy_(int *n, double *sa, double *sx, int *incx, double *sy, 156 | int *incy); 157 | 158 | int dcopy_(int *n, double *sx, int *incx, double *sy, int *incy); 159 | 160 | int dgbmv_(char *trans, int *m, int *n, int *kl, int *ku, 161 | double *alpha, double *a, int *lda, double *x, int *incx, 162 | double *beta, double *y, int *incy); 163 | 164 | int dgemm_(char *transa, char *transb, int *m, int *n, int *k, 165 | double *alpha, double *a, int *lda, double *b, int *ldb, 166 | double *beta, double *c, int *ldc); 167 | 168 | int dgemv_(char *trans, int *m, int *n, double *alpha, double *a, 169 | int *lda, double *x, int *incx, double *beta, double *y, 170 | int *incy); 171 | 172 | int dger_(int *m, int *n, double *alpha, double *x, int *incx, 173 | double *y, int *incy, double *a, int *lda); 174 | 175 | int drot_(int *n, double *sx, int *incx, double *sy, int *incy, 176 | double *c, double *s); 177 | 178 | int drotg_(double *sa, double *sb, double *c, double *s); 179 | 180 | int dsbmv_(char *uplo, int *n, int *k, double *alpha, double *a, 181 | int *lda, double *x, int *incx, double *beta, double *y, 182 | int *incy); 183 | 184 | int dscal_(int *n, double *sa, double *sx, int *incx); 185 | 186 | int dspmv_(char *uplo, int *n, double *alpha, double *ap, double *x, 187 | int *incx, double *beta, double *y, int 
*incy); 188 | 189 | int dspr_(char *uplo, int *n, double *alpha, double *x, int *incx, 190 | double *ap); 191 | 192 | int dspr2_(char *uplo, int *n, double *alpha, double *x, int *incx, 193 | double *y, int *incy, double *ap); 194 | 195 | int dswap_(int *n, double *sx, int *incx, double *sy, int *incy); 196 | 197 | int dsymm_(char *side, char *uplo, int *m, int *n, double *alpha, 198 | double *a, int *lda, double *b, int *ldb, double *beta, 199 | double *c, int *ldc); 200 | 201 | int dsymv_(char *uplo, int *n, double *alpha, double *a, int *lda, 202 | double *x, int *incx, double *beta, double *y, int *incy); 203 | 204 | int dsyr_(char *uplo, int *n, double *alpha, double *x, int *incx, 205 | double *a, int *lda); 206 | 207 | int dsyr2_(char *uplo, int *n, double *alpha, double *x, int *incx, 208 | double *y, int *incy, double *a, int *lda); 209 | 210 | int dsyr2k_(char *uplo, char *trans, int *n, int *k, double *alpha, 211 | double *a, int *lda, double *b, int *ldb, double *beta, 212 | double *c, int *ldc); 213 | 214 | int dsyrk_(char *uplo, char *trans, int *n, int *k, double *alpha, 215 | double *a, int *lda, double *beta, double *c, int *ldc); 216 | 217 | int dtbmv_(char *uplo, char *trans, char *diag, int *n, int *k, 218 | double *a, int *lda, double *x, int *incx); 219 | 220 | int dtbsv_(char *uplo, char *trans, char *diag, int *n, int *k, 221 | double *a, int *lda, double *x, int *incx); 222 | 223 | int dtpmv_(char *uplo, char *trans, char *diag, int *n, double *ap, 224 | double *x, int *incx); 225 | 226 | int dtpsv_(char *uplo, char *trans, char *diag, int *n, double *ap, 227 | double *x, int *incx); 228 | 229 | int dtrmm_(char *side, char *uplo, char *transa, char *diag, int *m, 230 | int *n, double *alpha, double *a, int *lda, double *b, 231 | int *ldb); 232 | 233 | int dtrmv_(char *uplo, char *trans, char *diag, int *n, double *a, 234 | int *lda, double *x, int *incx); 235 | 236 | int dtrsm_(char *side, char *uplo, char *transa, char *diag, int *m, 237 | 
int *n, double *alpha, double *a, int *lda, double *b, 238 | int *ldb); 239 | 240 | int dtrsv_(char *uplo, char *trans, char *diag, int *n, double *a, 241 | int *lda, double *x, int *incx); 242 | 243 | 244 | int saxpy_(int *n, float *sa, float *sx, int *incx, float *sy, int *incy); 245 | 246 | int scopy_(int *n, float *sx, int *incx, float *sy, int *incy); 247 | 248 | int sgbmv_(char *trans, int *m, int *n, int *kl, int *ku, 249 | float *alpha, float *a, int *lda, float *x, int *incx, 250 | float *beta, float *y, int *incy); 251 | 252 | int sgemm_(char *transa, char *transb, int *m, int *n, int *k, 253 | float *alpha, float *a, int *lda, float *b, int *ldb, 254 | float *beta, float *c, int *ldc); 255 | 256 | int sgemv_(char *trans, int *m, int *n, float *alpha, float *a, 257 | int *lda, float *x, int *incx, float *beta, float *y, 258 | int *incy); 259 | 260 | int sger_(int *m, int *n, float *alpha, float *x, int *incx, 261 | float *y, int *incy, float *a, int *lda); 262 | 263 | int srot_(int *n, float *sx, int *incx, float *sy, int *incy, 264 | float *c, float *s); 265 | 266 | int srotg_(float *sa, float *sb, float *c, float *s); 267 | 268 | int ssbmv_(char *uplo, int *n, int *k, float *alpha, float *a, 269 | int *lda, float *x, int *incx, float *beta, float *y, 270 | int *incy); 271 | 272 | int sscal_(int *n, float *sa, float *sx, int *incx); 273 | 274 | int sspmv_(char *uplo, int *n, float *alpha, float *ap, float *x, 275 | int *incx, float *beta, float *y, int *incy); 276 | 277 | int sspr_(char *uplo, int *n, float *alpha, float *x, int *incx, 278 | float *ap); 279 | 280 | int sspr2_(char *uplo, int *n, float *alpha, float *x, int *incx, 281 | float *y, int *incy, float *ap); 282 | 283 | int sswap_(int *n, float *sx, int *incx, float *sy, int *incy); 284 | 285 | int ssymm_(char *side, char *uplo, int *m, int *n, float *alpha, 286 | float *a, int *lda, float *b, int *ldb, float *beta, 287 | float *c, int *ldc); 288 | 289 | int ssymv_(char *uplo, int *n, float 
*alpha, float *a, int *lda, 290 | float *x, int *incx, float *beta, float *y, int *incy); 291 | 292 | int ssyr_(char *uplo, int *n, float *alpha, float *x, int *incx, 293 | float *a, int *lda); 294 | 295 | int ssyr2_(char *uplo, int *n, float *alpha, float *x, int *incx, 296 | float *y, int *incy, float *a, int *lda); 297 | 298 | int ssyr2k_(char *uplo, char *trans, int *n, int *k, float *alpha, 299 | float *a, int *lda, float *b, int *ldb, float *beta, 300 | float *c, int *ldc); 301 | 302 | int ssyrk_(char *uplo, char *trans, int *n, int *k, float *alpha, 303 | float *a, int *lda, float *beta, float *c, int *ldc); 304 | 305 | int stbmv_(char *uplo, char *trans, char *diag, int *n, int *k, 306 | float *a, int *lda, float *x, int *incx); 307 | 308 | int stbsv_(char *uplo, char *trans, char *diag, int *n, int *k, 309 | float *a, int *lda, float *x, int *incx); 310 | 311 | int stpmv_(char *uplo, char *trans, char *diag, int *n, float *ap, 312 | float *x, int *incx); 313 | 314 | int stpsv_(char *uplo, char *trans, char *diag, int *n, float *ap, 315 | float *x, int *incx); 316 | 317 | int strmm_(char *side, char *uplo, char *transa, char *diag, int *m, 318 | int *n, float *alpha, float *a, int *lda, float *b, 319 | int *ldb); 320 | 321 | int strmv_(char *uplo, char *trans, char *diag, int *n, float *a, 322 | int *lda, float *x, int *incx); 323 | 324 | int strsm_(char *side, char *uplo, char *transa, char *diag, int *m, 325 | int *n, float *alpha, float *a, int *lda, float *b, 326 | int *ldb); 327 | 328 | int strsv_(char *uplo, char *trans, char *diag, int *n, float *a, 329 | int *lda, float *x, int *incx); 330 | 331 | int zaxpy_(int *n, dcomplex *ca, dcomplex *cx, int *incx, dcomplex *cy, 332 | int *incy); 333 | 334 | int zcopy_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy); 335 | 336 | int zdscal_(int *n, double *sa, dcomplex *cx, int *incx); 337 | 338 | int zgbmv_(char *trans, int *m, int *n, int *kl, int *ku, 339 | dcomplex *alpha, dcomplex *a, int *lda, 
dcomplex *x, int *incx, 340 | dcomplex *beta, dcomplex *y, int *incy); 341 | 342 | int zgemm_(char *transa, char *transb, int *m, int *n, int *k, 343 | dcomplex *alpha, dcomplex *a, int *lda, dcomplex *b, int *ldb, 344 | dcomplex *beta, dcomplex *c, int *ldc); 345 | 346 | int zgemv_(char *trans, int *m, int *n, dcomplex *alpha, dcomplex *a, 347 | int *lda, dcomplex *x, int *incx, dcomplex *beta, dcomplex *y, 348 | int *incy); 349 | 350 | int zgerc_(int *m, int *n, dcomplex *alpha, dcomplex *x, int *incx, 351 | dcomplex *y, int *incy, dcomplex *a, int *lda); 352 | 353 | int zgeru_(int *m, int *n, dcomplex *alpha, dcomplex *x, int *incx, 354 | dcomplex *y, int *incy, dcomplex *a, int *lda); 355 | 356 | int zhbmv_(char *uplo, int *n, int *k, dcomplex *alpha, dcomplex *a, 357 | int *lda, dcomplex *x, int *incx, dcomplex *beta, dcomplex *y, 358 | int *incy); 359 | 360 | int zhemm_(char *side, char *uplo, int *m, int *n, dcomplex *alpha, 361 | dcomplex *a, int *lda, dcomplex *b, int *ldb, dcomplex *beta, 362 | dcomplex *c, int *ldc); 363 | 364 | int zhemv_(char *uplo, int *n, dcomplex *alpha, dcomplex *a, int *lda, 365 | dcomplex *x, int *incx, dcomplex *beta, dcomplex *y, int *incy); 366 | 367 | int zher_(char *uplo, int *n, double *alpha, dcomplex *x, int *incx, 368 | dcomplex *a, int *lda); 369 | 370 | int zher2_(char *uplo, int *n, dcomplex *alpha, dcomplex *x, int *incx, 371 | dcomplex *y, int *incy, dcomplex *a, int *lda); 372 | 373 | int zher2k_(char *uplo, char *trans, int *n, int *k, dcomplex *alpha, 374 | dcomplex *a, int *lda, dcomplex *b, int *ldb, double *beta, 375 | dcomplex *c, int *ldc); 376 | 377 | int zherk_(char *uplo, char *trans, int *n, int *k, double *alpha, 378 | dcomplex *a, int *lda, double *beta, dcomplex *c, int *ldc); 379 | 380 | int zhpmv_(char *uplo, int *n, dcomplex *alpha, dcomplex *ap, dcomplex *x, 381 | int *incx, dcomplex *beta, dcomplex *y, int *incy); 382 | 383 | int zhpr_(char *uplo, int *n, double *alpha, dcomplex *x, int *incx, 
384 | dcomplex *ap); 385 | 386 | int zhpr2_(char *uplo, int *n, dcomplex *alpha, dcomplex *x, int *incx, 387 | dcomplex *y, int *incy, dcomplex *ap); 388 | 389 | int zrotg_(dcomplex *ca, dcomplex *cb, double *c, dcomplex *s); 390 | 391 | int zscal_(int *n, dcomplex *ca, dcomplex *cx, int *incx); 392 | 393 | int zswap_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy); 394 | 395 | int zsymm_(char *side, char *uplo, int *m, int *n, dcomplex *alpha, 396 | dcomplex *a, int *lda, dcomplex *b, int *ldb, dcomplex *beta, 397 | dcomplex *c, int *ldc); 398 | 399 | int zsyr2k_(char *uplo, char *trans, int *n, int *k, dcomplex *alpha, 400 | dcomplex *a, int *lda, dcomplex *b, int *ldb, dcomplex *beta, 401 | dcomplex *c, int *ldc); 402 | 403 | int zsyrk_(char *uplo, char *trans, int *n, int *k, dcomplex *alpha, 404 | dcomplex *a, int *lda, dcomplex *beta, dcomplex *c, int *ldc); 405 | 406 | int ztbmv_(char *uplo, char *trans, char *diag, int *n, int *k, 407 | dcomplex *a, int *lda, dcomplex *x, int *incx); 408 | 409 | int ztbsv_(char *uplo, char *trans, char *diag, int *n, int *k, 410 | dcomplex *a, int *lda, dcomplex *x, int *incx); 411 | 412 | int ztpmv_(char *uplo, char *trans, char *diag, int *n, dcomplex *ap, 413 | dcomplex *x, int *incx); 414 | 415 | int ztpsv_(char *uplo, char *trans, char *diag, int *n, dcomplex *ap, 416 | dcomplex *x, int *incx); 417 | 418 | int ztrmm_(char *side, char *uplo, char *transa, char *diag, int *m, 419 | int *n, dcomplex *alpha, dcomplex *a, int *lda, dcomplex *b, 420 | int *ldb); 421 | 422 | int ztrmv_(char *uplo, char *trans, char *diag, int *n, dcomplex *a, 423 | int *lda, dcomplex *x, int *incx); 424 | 425 | int ztrsm_(char *side, char *uplo, char *transa, char *diag, int *m, 426 | int *n, dcomplex *alpha, dcomplex *a, int *lda, dcomplex *b, 427 | int *ldb); 428 | 429 | int ztrsv_(char *uplo, char *trans, char *diag, int *n, dcomplex *a, 430 | int *lda, dcomplex *x, int *incx); 431 | 
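Of the prototypes listed in blasp.h, the Makefile shown earlier builds only four level-1 routines: dnrm2, daxpy, ddot and dscal. As a reading aid, the sketch below mirrors their semantics in plain Python. It is an illustration under simplifying assumptions, not a port: it handles non-negative increments only, takes values instead of pointers, and the function names without the trailing underscore are chosen for this example.

```python
import math


def ddot(n, x, incx, y, incy):
    """Dot product of two strided vectors."""
    return sum(x[i * incx] * y[i * incy] for i in range(n))


def daxpy(n, a, x, incx, y, incy):
    """In-place update y <- a * x + y."""
    for i in range(n):
        y[i * incy] += a * x[i * incx]


def dscal(n, a, x, incx):
    """In-place scaling x <- a * x."""
    for i in range(n):
        x[i * incx] *= a


def dnrm2(n, x, incx):
    """Euclidean norm with a running scale, so squares never overflow."""
    scale, ssq = 0.0, 1.0
    for i in range(n):
        xi = abs(x[i * incx])
        if xi == 0.0:
            continue
        if scale < xi:
            # Rescale the accumulated sum of squares to the new largest entry.
            ssq = 1.0 + ssq * (scale / xi) ** 2
            scale = xi
        else:
            ssq += (xi / scale) ** 2
    return scale * math.sqrt(ssq)
```

The running-scale form of dnrm2 matters for vectors with very large entries: a naive sum of squares of values near 1e200 overflows, while the scaled version stays finite.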
-------------------------------------------------------------------------------- /pytron/src/blas/daxpy.c: -------------------------------------------------------------------------------- 1 | #include "blas.h" 2 | 3 | int daxpy_(int *n, double *sa, double *sx, int *incx, double *sy, 4 | int *incy) 5 | { 6 | long int i, m, ix, iy, nn, iincx, iincy; 7 | register double ssa; 8 | 9 | /* constant times a vector plus a vector. 10 | uses unrolled loop for increments equal to one. 11 | jack dongarra, linpack, 3/11/78. 12 | modified 12/3/93, array(1) declarations changed to array(*) */ 13 | 14 | /* Dereference inputs */ 15 | nn = *n; 16 | ssa = *sa; 17 | iincx = *incx; 18 | iincy = *incy; 19 | 20 | if( nn > 0 && ssa != 0.0 ) 21 | { 22 | if (iincx == 1 && iincy == 1) /* code for both increments equal to 1 */ 23 | { 24 | m = nn-3; 25 | for (i = 0; i < m; i += 4) 26 | { 27 | sy[i] += ssa * sx[i]; 28 | sy[i+1] += ssa * sx[i+1]; 29 | sy[i+2] += ssa * sx[i+2]; 30 | sy[i+3] += ssa * sx[i+3]; 31 | } 32 | for ( ; i < nn; ++i) /* clean-up loop */ 33 | sy[i] += ssa * sx[i]; 34 | } 35 | else /* code for unequal increments or equal increments not equal to 1 */ 36 | { 37 | ix = iincx >= 0 ? 0 : (1 - nn) * iincx; 38 | iy = iincy >= 0 ? 0 : (1 - nn) * iincy; 39 | for (i = 0; i < nn; i++) 40 | { 41 | sy[iy] += ssa * sx[ix]; 42 | ix += iincx; 43 | iy += iincy; 44 | } 45 | } 46 | } 47 | 48 | return 0; 49 | } /* daxpy_ */ 50 | -------------------------------------------------------------------------------- /pytron/src/blas/ddot.c: -------------------------------------------------------------------------------- 1 | #include "blas.h" 2 | 3 | double ddot_(int *n, double *sx, int *incx, double *sy, int *incy) 4 | { 5 | long int i, m, nn, iincx, iincy; 6 | double stemp; 7 | long int ix, iy; 8 | 9 | /* forms the dot product of two vectors. 10 | uses unrolled loops for increments equal to one. 11 | jack dongarra, linpack, 3/11/78. 
12 | modified 12/3/93, array(1) declarations changed to array(*) */ 13 | 14 | /* Dereference inputs */ 15 | nn = *n; 16 | iincx = *incx; 17 | iincy = *incy; 18 | 19 | stemp = 0.0; 20 | if (nn > 0) 21 | { 22 | if (iincx == 1 && iincy == 1) /* code for both increments equal to 1 */ 23 | { 24 | m = nn-4; 25 | for (i = 0; i < m; i += 5) 26 | stemp += sx[i] * sy[i] + sx[i+1] * sy[i+1] + sx[i+2] * sy[i+2] + 27 | sx[i+3] * sy[i+3] + sx[i+4] * sy[i+4]; 28 | 29 | for ( ; i < nn; i++) /* clean-up loop */ 30 | stemp += sx[i] * sy[i]; 31 | } 32 | else /* code for unequal increments or equal increments not equal to 1 */ 33 | { 34 | ix = 0; 35 | iy = 0; 36 | if (iincx < 0) 37 | ix = (1 - nn) * iincx; 38 | if (iincy < 0) 39 | iy = (1 - nn) * iincy; 40 | for (i = 0; i < nn; i++) 41 | { 42 | stemp += sx[ix] * sy[iy]; 43 | ix += iincx; 44 | iy += iincy; 45 | } 46 | } 47 | } 48 | 49 | return stemp; 50 | } /* ddot_ */ 51 | -------------------------------------------------------------------------------- /pytron/src/blas/dnrm2.c: -------------------------------------------------------------------------------- 1 | #include <math.h> /* Needed for fabs() and sqrt() */ 2 | #include "blas.h" 3 | 4 | double dnrm2_(int *n, double *x, int *incx) 5 | { 6 | long int ix, nn, iincx; 7 | double norm, scale, absxi, ssq, temp; 8 | 9 | /* DNRM2 returns the euclidean norm of a vector via the function 10 | name, so that 11 | 12 | DNRM2 := sqrt( x'*x ) 13 | 14 | -- This version written on 25-October-1982. 15 | Modified on 14-October-1993 to inline the call to SLASSQ. 16 | Sven Hammarling, Nag Ltd. 
*/ 17 | 18 | /* Dereference inputs */ 19 | nn = *n; 20 | iincx = *incx; 21 | 22 | if( nn > 0 && iincx > 0 ) 23 | { 24 | if (nn == 1) 25 | { 26 | norm = fabs(x[0]); 27 | } 28 | else 29 | { 30 | scale = 0.0; 31 | ssq = 1.0; 32 | 33 | /* The following loop is equivalent to this call to the LAPACK 34 | auxiliary routine: CALL SLASSQ( N, X, INCX, SCALE, SSQ ) */ 35 | 36 | for (ix=(nn-1)*iincx; ix>=0; ix-=iincx) 37 | { 38 | if (x[ix] != 0.0) 39 | { 40 | absxi = fabs(x[ix]); 41 | if (scale < absxi) 42 | { 43 | temp = scale / absxi; 44 | ssq = ssq * (temp * temp) + 1.0; 45 | scale = absxi; 46 | } 47 | else 48 | { 49 | temp = absxi / scale; 50 | ssq += temp * temp; 51 | } 52 | } 53 | } 54 | norm = scale * sqrt(ssq); 55 | } 56 | } 57 | else 58 | norm = 0.0; 59 | 60 | return norm; 61 | 62 | } /* dnrm2_ */ 63 | -------------------------------------------------------------------------------- /pytron/src/blas/dscal.c: -------------------------------------------------------------------------------- 1 | #include "blas.h" 2 | 3 | int dscal_(int *n, double *sa, double *sx, int *incx) 4 | { 5 | long int i, m, nincx, nn, iincx; 6 | double ssa; 7 | 8 | /* scales a vector by a constant. 9 | uses unrolled loops for increment equal to 1. 10 | jack dongarra, linpack, 3/11/78. 11 | modified 3/93 to return if incx .le. 0. 
12 | modified 12/3/93, array(1) declarations changed to array(*) */ 13 | 14 | /* Dereference inputs */ 15 | nn = *n; 16 | iincx = *incx; 17 | ssa = *sa; 18 | 19 | if (nn > 0 && iincx > 0) 20 | { 21 | if (iincx == 1) /* code for increment equal to 1 */ 22 | { 23 | m = nn-4; 24 | for (i = 0; i < m; i += 5) 25 | { 26 | sx[i] = ssa * sx[i]; 27 | sx[i+1] = ssa * sx[i+1]; 28 | sx[i+2] = ssa * sx[i+2]; 29 | sx[i+3] = ssa * sx[i+3]; 30 | sx[i+4] = ssa * sx[i+4]; 31 | } 32 | for ( ; i < nn; ++i) /* clean-up loop */ 33 | sx[i] = ssa * sx[i]; 34 | } 35 | else /* code for increment not equal to 1 */ 36 | { 37 | nincx = nn * iincx; 38 | for (i = 0; i < nincx; i += iincx) 39 | sx[i] = ssa * sx[i]; 40 | } 41 | } 42 | 43 | return 0; 44 | } /* dscal_ */ 45 | -------------------------------------------------------------------------------- /pytron/src/tron.cpp: -------------------------------------------------------------------------------- 1 | #include <math.h> 2 | #include <stdio.h> 3 | #include <string.h> 4 | #include <stdarg.h> 5 | #include "tron.h" 6 | 7 | #ifndef min 8 | template <class T> static inline T min(T x,T y) { return (x<y)?x:y; } 9 | #endif 10 | 11 | #ifndef max 12 | template <class T> static inline T max(T x,T y) { return (x>y)?x:y; } 13 | #endif 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | extern double dnrm2_(int *, double *, int *); 20 | extern double ddot_(int *, double *, int *, double *, int *); 21 | extern int daxpy_(int *, double *, double *, int *, double *, int *); 22 | extern int dscal_(int *, double *, double *, int *); 23 | 24 | #ifdef __cplusplus 25 | } 26 | #endif 27 | 28 | static void default_print(const char *buf) 29 | { 30 | fputs(buf,stdout); 31 | fflush(stdout); 32 | } 33 | 34 | void TRON::info(const char *fmt,...) 
{
    char buf[BUFSIZ];
    va_list ap;
    va_start(ap,fmt);
    // Bounded formatting: never write past the end of buf.
    vsnprintf(buf,sizeof(buf),fmt,ap);
    va_end(ap);
    (*tron_print_string)(buf);
}

TRON::TRON(const function *fun_obj, double tol, double gtol, int max_iter)
{
    this->fun_obj=const_cast<function *>(fun_obj);
    this->gtol=gtol;
    this->tol=tol;
    this->max_iter=max_iter;
    tron_print_string = default_print;
    this->n_iter = 0;
    this->gnorm = 0.;
}

TRON::~TRON()
{
}

void TRON::tron(double *w, double *g, int verbose)
{
    /*
       actred : Actual reduction
       prered : predicted reduction
    */
    // Parameters for updating the iterates.
    double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;

    // Parameters for updating the trust region size delta.
    double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;

    int n = fun_obj->get_nr_variable();
    int i, cg_iter;
    double delta, snorm, one=1.0;
    double alpha, f, fnew, prered, actred, gs;
    int search = 1, iter = 1, inc = 1;
    double *s = new double[n];
    double *r = new double[n];
    double *w_new = new double[n];

    /* Edit (Fabian): allow for warm restarts
    for (i=0; i<n; i++)
        w[i] = 0;
    */

    f = fun_obj->fun(w);
    fun_obj->grad(w, g);
    delta = dnrm2_(&n, g, &inc); // TODO: use infinity norm
    double gnorm1 = delta;
    double gnorm = gnorm1;

    if (gnorm <= gtol)
        search = 0;

    iter = 1;

    while (iter <= max_iter && search)
    {

        cg_iter = trcg(delta, g, s, r);

        memcpy(w_new, w, sizeof(double)*n);
        daxpy_(&n, &one, s, &inc, w_new, &inc);

        gs = ddot_(&n, g, &inc, s, &inc);
        prered = -0.5*(gs-ddot_(&n, s, &inc, r, &inc));
        fnew = fun_obj->fun(w_new);

        // Compute the actual reduction.
        actred = f - fnew;

        // On the first iteration, adjust the initial step bound.
        snorm = dnrm2_(&n, s, &inc);
        if (iter == 1)
            delta = min(delta, snorm);

        // Compute prediction alpha*snorm of the step.
        if (fnew - f - gs <= 0)
            alpha = sigma3;
        else
            alpha = max(sigma1, -0.5*(gs/(fnew - f - gs)));

        // Update the trust region bound according to the ratio of actual to predicted reduction.
        if (actred < eta0*prered)
            delta = min(max(alpha, sigma1)*snorm, sigma2*delta);
        else if (actred < eta1*prered)
            delta = max(sigma1*delta, min(alpha*snorm, sigma2*delta));
        else if (actred < eta2*prered)
            delta = max(sigma1*delta, min(alpha*snorm, sigma3*delta));
        else
            delta = max(delta, min(alpha*snorm, sigma3*delta));

        if (verbose)
            info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);

        if (actred > eta0*prered)
        {
            iter++;
            memcpy(w, w_new, sizeof(double)*n);
            f = fnew;
            fun_obj->grad(w, g);

            gnorm = norm_inf(n, g);
            // gnorm = dnrm2_(&n, g, &inc);
            if (gnorm <= gtol)
                break;
        }
        fun_obj->callback(w);

        if (f < -1.0e+32)
        {
            info("WARNING: f < -1.0e+32\n");
            break;
        }
        if (fabs(actred) <= 0 && prered <= 0)
        {
            info("WARNING: actred and prered <= 0\n");
            break;
        }
        if (fabs(actred) <= tol*fabs(f) &&
            fabs(prered) <= tol*fabs(f))
        {
            info("WARNING: actred and prered too small\n");
            break;
        }
    }

    this->n_iter = iter;
    this->gnorm = gnorm;
    // Store the final objective value (the local f), not the member itself.
    this->fun = f;

    delete[] r;
    delete[] w_new;
    delete[] s;
}

int TRON::trcg(double delta, double *g, double *s, double *r)
{
    int i, inc = 1;
    int n = fun_obj->get_nr_variable();
    double one = 1;
    double *d = new double[n];
    double *Hd = new double[n];
    double rTr, rnewTrnew, alpha, beta,
        cgtol;

    for (i=0; i<n; i++)
    {
        s[i] = 0;
        r[i] = -g[i];
        d[i] = r[i];
    }
    cgtol = 0.1*dnrm2_(&n, g, &inc);

    int cg_iter = 0;
    rTr = ddot_(&n, r, &inc, r, &inc);
    while (1)
    {
        if (dnrm2_(&n, r, &inc) <= cgtol)
            break;
        cg_iter++;
        fun_obj->Hv(d, Hd);

        alpha = rTr/ddot_(&n, d, &inc, Hd, &inc);
        daxpy_(&n, &alpha, d, &inc, s, &inc);
        if (dnrm2_(&n, s, &inc) > delta)
        {
            info("cg reaches trust region boundary\n");
            alpha = -alpha;
            daxpy_(&n, &alpha, d, &inc, s, &inc);

            double std = ddot_(&n, s, &inc, d, &inc);
            double sts = ddot_(&n, s, &inc, s, &inc);
            double dtd = ddot_(&n, d, &inc, d, &inc);
            double dsq = delta*delta;
            double rad = sqrt(std*std + dtd*(dsq-sts));
            if (std >= 0)
                alpha = (dsq - sts)/(std + rad);
            else
                alpha = (rad - std)/dtd;
            daxpy_(&n, &alpha, d, &inc, s, &inc);
            alpha = -alpha;
            daxpy_(&n, &alpha, Hd, &inc, r, &inc);
            break;
        }
        alpha = -alpha;
        daxpy_(&n, &alpha, Hd, &inc, r, &inc);
        rnewTrnew = ddot_(&n, r, &inc, r, &inc);
        beta = rnewTrnew/rTr;
        dscal_(&n, &beta, d, &inc);
        daxpy_(&n, &one, r, &inc, d, &inc);
        rTr = rnewTrnew;
    }

    delete[] d;
    delete[] Hd;

    return(cg_iter);
}

double TRON::norm_inf(int n, double *x)
{
    double dmax = fabs(x[0]);
    for (int i=1; i<n; i++)
        if (fabs(x[i]) >= dmax)
            dmax = fabs(x[i]);
    return(dmax);
}

void TRON::set_print_string(void (*print_string) (const char *buf))
{
    tron_print_string = print_string;
}
--------------------------------------------------------------------------------
/pytron/src/tron.h:
--------------------------------------------------------------------------------
#ifndef _TRON_H
#define _TRON_H

class function
{
public:
    virtual double fun(double *w) = 0 ;
    virtual void grad(double *w, double *g) = 0 ;
    virtual void Hv(double *s, double *Hs) = 0 ;
    virtual void callback(double *w) = 0 ;

    virtual int get_nr_variable(void) = 0 ;
    virtual ~function(void){}
};

class TRON
{
public:
    TRON(const function *fun_obj, double tol=0.1, double gtol = 0.1,
         int max_iter = 1000);
    ~TRON();

    void tron(double *w, double *g, int verbose);
    void set_print_string(void (*i_print) (const char *buf));
    int n_iter;
    double gnorm;
    double fun;

private:
    int trcg(double delta, double *g, double *s, double *r);
    double norm_inf(int n, double *x);

    double gtol;
    double tol;
    int max_iter;
    function *fun_obj;
    void info(const char *fmt,...);
    void (*tron_print_string)(const char *buf);
};
#endif
--------------------------------------------------------------------------------
/pytron/src/tron_helper.cpp:
--------------------------------------------------------------------------------
#include "tron_helper.h"
#include <stddef.h>  // NULL

double func_callback::fun(double *w)
{
    double t;
    t = c_func(w, py_func, this->nr_variable, this->py_args);
    return t;
}

void func_callback::grad(double *w, double *g)
{
    c_grad(w, py_grad_hess, &this->py_hess, g, this->nr_variable,
           this->py_args);
}

void func_callback::Hv(double *s, double *Hs)
{
    c_hess(s, this->py_hess, Hs, this->nr_variable, this->py_args);
}

int func_callback::get_nr_variable(void)
{
    return nr_variable;
}

void func_callback::callback(double *w)
{
    if (this->py_callback != NULL)
        c_callback(w, this->py_callback, this->nr_variable, this->py_args);
}
--------------------------------------------------------------------------------
/pytron/src/tron_helper.h:
--------------------------------------------------------------------------------
#include "tron.h"

typedef double (*func_cb)(double *, void *, int, void *);
typedef int (*grad_cb)(double *, void *, void **, double *, int, void *);
typedef int (*hess_cb)(double *, void *, double *, int, void *);

class
func_callback: public function {

    public:
        func_callback(double *x0, void *py_func, func_cb c_func,
                      void *py_grad_hess, grad_cb c_grad, hess_cb c_hess,
                      void *py_callback, func_cb c_callback, int nr_variable, void *py_args) {
            this->w = new double[nr_variable];
            this->py_func = py_func;
            this->py_grad_hess = py_grad_hess;
            this->py_callback = py_callback;
            this->c_func = c_func;
            this->c_grad = c_grad;
            this->c_hess = c_hess;
            this->c_callback = c_callback;
            this->nr_variable = nr_variable;
            this->py_args = py_args;
        };

        ~func_callback() {
            // w was allocated with new[], so it must be released with delete[].
            delete[] this->w;
        };
        double fun(double *w);
        void grad(double *w, double *g);
        void Hv(double *s, double *Hs);
        void callback(double *w);
        int get_nr_variable(void);

    protected:
        double tmp;
        double *w;
        func_cb c_func;
        grad_cb c_grad;
        hess_cb c_hess;
        func_cb c_callback;
        void *py_func;
        void *py_grad_hess;
        void *py_hess;
        void *py_args;
        void *py_callback;
        int nr_variable;
};

--------------------------------------------------------------------------------
/pytron/tron.pyx:
--------------------------------------------------------------------------------
from cython cimport view
import numpy as np
from scipy import optimize
cimport numpy as np
from libc cimport string
from cpython cimport Py_INCREF, Py_XDECREF, PyObject

cdef extern from "tron_helper.h":
    ctypedef double (*func_cb)(double *, void *, int, void *)
    ctypedef int (*grad_cb)(double *, void *, void **, double *, int, void *)
    ctypedef int (*hess_cb)(double *, void *, double *, int, void *)
    cdef cppclass func_callback:
        func_callback(double *, void *, func_cb,
            void *, grad_cb, hess_cb, void *, func_cb, int nr_variable, void *)


cdef extern from "tron.h":
    cdef cppclass TRON:
        TRON(func_callback *, double, double, int)
        void tron(double *, double *, int)
        int n_iter
        double gnorm
        double fun


cdef double c_func(double *w, void *f_py, int nr_variable, void *py_args):
    cdef view.array w0 = view.array(shape=(nr_variable,), itemsize=sizeof(double),
        mode='c', format='d', allocate_buffer=False)
    cdef double b
    w0.data = <char *> w
    # TODO: exception return value
    out = (<object> f_py)(np.asarray(w0), *(<object> py_args))
    b = out
    return b


cdef int c_grad(double *w, void *grad_hess_py, void **hess_py,
                double *b, int nr_variable, void *py_args):
    cdef view.array b0 = view.array(shape=(nr_variable,), itemsize=sizeof(double),
        mode='c', format='d', allocate_buffer=False)
    b0.data = <char *> b
    cdef view.array w0 = view.array(shape=(nr_variable,), itemsize=sizeof(double),
        mode='c', format='d', allocate_buffer=False)
    w0.data = <char *> w
    try:
        out = (<object> grad_hess_py)(np.asarray(w0), *(<object> py_args))
    except:
        return -1
    #Py_XDECREF(hess_py[0]) # liberate previous one
    grad, hess = out[0], out[1]
    Py_INCREF(hess) # segfault otherwise
    b0[:] = grad[:]
    hess_py[0] = <void *> hess
    return 0


cdef int c_hess(double *s, void *hess_py, double *b, int nr_variable,
                void *py_args):
    cdef view.array b0 = view.array(shape=(nr_variable,),
        itemsize=sizeof(double), format='d',
        mode='c', allocate_buffer=False)
    cdef view.array s0 = view.array(shape=(nr_variable,),
        itemsize=sizeof(double), format='d',
        mode='c', allocate_buffer=False)
    s0.data = <char *> s
    b0.data = <char *> b
    try:
        out = (<object> hess_py)(np.asarray(s0))
    except:
        return -1
    # np.float was a deprecated alias of the builtin float; use float64 explicitly.
    out = np.asarray(out, dtype=np.float64)
    b0[:] = out[:]
    return 0

cdef double c_callback(double *w, void *py_callback, int nr_variable,
                       void *py_args):

    cdef view.array w0 = view.array(shape=(nr_variable,), itemsize=sizeof(double),
        mode='c', format='d', allocate_buffer=False)
    w0.data = <char *> w
    out = (<object> py_callback)(np.asarray(w0))
    return 0.


def minimize(func, grad_hess, x0, args=(), max_iter=500, tol=1e-6, gtol=1.,
             callback=None, verbose=False):
    """Minimize func using the Trust Region Newton algorithm.

    Parameters
    ----------
    func : callable
        func(w, *args) is the evaluation of the function at w. It
        should return a float.
    grad_hess : callable
        grad_hess(w, *args) returns the gradient at w and a callable
        that computes the product of the Hessian with an arbitrary
        vector.
    x0 : array
        Starting point for the iteration.
    gtol : float
        Gradient infinity norm must be less than gtol
        before successful termination.

    Returns
    -------
    res : scipy.optimize.OptimizeResult
        The optimization result represented as a
        scipy.optimize.OptimizeResult object. Important attributes are:
        ``x`` the solution array, ``success`` a boolean flag indicating
        whether the optimizer exited successfully, and ``nit`` an
        integer for the number of iterations performed.
    """

    cdef np.ndarray[np.float64_t, ndim=1] x0_np
    cdef np.ndarray[np.float64_t, ndim=1] grad
    cdef int nr_variable = x0.size
    cdef double c_gtol = gtol
    cdef double c_tol = tol
    cdef int c_max_iter = max_iter
    cdef void *py_callback
    cur_w = None
    x0_np = np.asarray(x0, dtype=np.float64)
    grad = np.empty(x0_np.size, dtype=np.float64)
    if callback is None:
        py_callback = NULL
    else:
        py_callback = <void *> callback

    cdef func_callback * fc = new func_callback(
        <double *> x0_np.data,
        <void *> func, c_func,
        <void *> grad_hess, c_grad,
        c_hess, py_callback, c_callback, nr_variable, <void *> args)

    cdef TRON *solver = new TRON(fc, c_tol, c_gtol, c_max_iter)
    solver.tron(<double *> x0_np.data, <double *> grad.data, verbose)
    success = solver.gnorm < gtol
    result = optimize.OptimizeResult(
        x=x0_np, success=success, nit=solver.n_iter, gnorm=solver.gnorm,
        fun=solver.fun, jac=grad, message='TODO')

    del fc
    del solver

    return result
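The `grad_hess` contract used by `minimize` (one call returns the gradient together with a callable for Hessian-vector products) can be sketched on a small quadratic. This is an illustration only and not part of pytron: the names `A`, `b`, `matvec` and `hess_vec` are hypothetical, and plain Python lists are used so the snippet runs without any dependency. With pytron itself the vectors would presumably need to be NumPy arrays, since `minimize` reads `x0.size`.

```python
# Hypothetical example, not part of pytron:
#   f(w) = 0.5 * w^T A w - b^T w, with A symmetric positive definite.
A = [[4.0, 1.0],
     [1.0, 3.0]]
b = [1.0, 2.0]


def matvec(M, v):
    """Dense matrix-vector product for nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def func(w):
    """Objective value at w, returned as a float."""
    Aw = matvec(A, w)
    quad = 0.5 * sum(w_i * aw_i for w_i, aw_i in zip(w, Aw))
    lin = sum(b_i * w_i for b_i, w_i in zip(b, w))
    return quad - lin


def grad_hess(w):
    """Return (gradient at w, callable computing Hessian times a vector)."""
    grad = [aw_i - b_i for aw_i, b_i in zip(matvec(A, w), b)]

    def hess_vec(s):
        # The Hessian of this quadratic is A itself, so H s = A s.
        return matvec(A, s)

    return grad, hess_vec


grad, hess_vec = grad_hess([0.0, 0.0])
print(grad, hess_vec([1.0, 0.0]))
```

Returning a closure rather than a dense matrix means the solver only ever needs products `H s`, which is all the conjugate-gradient loop in `trcg` asks for through `Hv`.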
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from Cython.Distutils import build_ext
import numpy as np
from glob import glob
from setuptools import setup, Extension

CLASSIFIERS = """\
Development Status :: 3 - Alpha
Intended Audience :: Science/Research
Intended Audience :: Developers
License :: OSI Approved
Programming Language :: Python
Programming Language :: Python :: 2
Programming Language :: Python :: 2.6
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.2
Programming Language :: Python :: 3.3
Topic :: Software Development
Operating System :: POSIX
Operating System :: Unix

"""

sources =['pytron/tron.pyx', 'pytron/src/tron.cpp', 'pytron/src/tron_helper.cpp'] + \
    glob('pytron/src/blas/*.c')


setup(
    name='pytron',
    description='Python bindings for TRON optimizer',
    long_description=open('README.rst').read(),
    version='0.3',
    author='Fabian Pedregosa',
    author_email='f@bianp.net',
    url='http://pypi.python.org/pypi/pytron',
    packages=['pytron'],
    classifiers=[_f for _f in CLASSIFIERS.split('\n') if _f],
    license='Simplified BSD',
    requires=['numpy', 'scipy'],
    cmdclass={'build_ext': build_ext},
    ext_modules=[Extension('pytron.tron',
        sources=sources,
        language='c++', include_dirs=[np.get_include(), 'pytron/src/'])],

)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy import optimize
from example_logistic import loss, grad_hess
from sklearn import datasets
from nose import tools


def test_grad_logistic():
    X, y = datasets.make_classification()
    y[y==0] = -1
    # np.float was a deprecated alias of the builtin float; use float64 explicitly.
    y = y.astype(np.float64)

    f = lambda x: loss(x, X, y, 1.)
    f_grad = lambda x: grad_hess(x, X, y, 1.)[0]

    small = optimize.check_grad(f, f_grad, np.random.randn(X.shape[1]))
    tools.assert_less(small, 1.)

--------------------------------------------------------------------------------
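The test above relies on `scipy.optimize.check_grad`, which returns the 2-norm of the difference between an analytic gradient and a forward finite-difference approximation. A dependency-free sketch of that idea follows. The helper names `finite_diff_grad` and `grad_error` and the toy objective are assumptions made for illustration; they are not the logistic loss from `example_logistic.py` and not SciPy's actual implementation.

```python
import math


def finite_diff_grad(f, x, eps=1.5e-8):
    """Forward-difference approximation of the gradient of f at x."""
    fx = f(x)
    grad = []
    for i in range(len(x)):
        shifted = list(x)
        shifted[i] += eps          # perturb one coordinate at a time
        grad.append((f(shifted) - fx) / eps)
    return grad


def grad_error(f, grad, x, eps=1.5e-8):
    """2-norm of (analytic gradient - finite-difference gradient)."""
    approx = finite_diff_grad(f, x, eps)
    exact = grad(x)
    return math.sqrt(sum((e - a) ** 2 for e, a in zip(exact, approx)))


# Toy objective, hypothetical and unrelated to the repository's loss.
def f(x):
    return x[0] ** 2 + 3.0 * x[0] * x[1] + math.exp(x[1])


def f_grad(x):
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0] + math.exp(x[1])]


def wrong_grad(x):
    # Deliberately wrong: the first component drops the 3 * x[1] term.
    return [2.0 * x[0], 3.0 * x[0] + math.exp(x[1])]


x0 = [0.5, -1.0]
print(grad_error(f, f_grad, x0))      # tiny: the gradient is consistent
print(grad_error(f, wrong_grad, x0))  # large: the missing term shows up
```

A correct gradient typically gives an error several orders of magnitude below 1, so the `1.` bound in `test_grad_logistic` is a loose check; a tighter bound may catch subtler mistakes, at the cost of occasional failures from finite-difference noise.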