├── .gitignore
├── COPYING
├── MANIFEST.in
├── README.rst
├── doc
└── benchmark_logistic.ipynb
├── example_logistic.py
├── pytron
├── __init__.py
├── src
│ ├── blas
│ │ ├── Makefile
│ │ ├── blas.h
│ │ ├── blasp.h
│ │ ├── daxpy.c
│ │ ├── ddot.c
│ │ ├── dnrm2.c
│ │ └── dscal.c
│ ├── tron.cpp
│ ├── tron.h
│ ├── tron_helper.cpp
│ └── tron_helper.h
└── tron.pyx
├── setup.py
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.so
3 | .idea
4 | pytron/tron.cpp
--------------------------------------------------------------------------------
/COPYING:
--------------------------------------------------------------------------------
1 | New BSD License
2 |
3 | Redistribution and use in source and binary forms, with or without
4 | modification, are permitted provided that the following conditions are met:
5 |
6 | a. Redistributions of source code must retain the above copyright notice,
7 | this list of conditions and the following disclaimer.
8 | b. Redistributions in binary form must reproduce the above copyright
9 | notice, this list of conditions and the following disclaimer in the
10 | documentation and/or other materials provided with the distribution.
11 | c. Neither the name of the Scikit-learn Developers nor the names of
12 | its contributors may be used to endorse or promote products
13 | derived from this software without specific prior written
14 | permission.
15 |
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
27 | DAMAGE.
28 |
29 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.py
2 | recursive-include pytron *.cpp *.h
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | A Trust-Region Newton Method in Python
2 | ======================================
3 |
4 | .. DANGER::
5 | This is alpha-quality software and still quite rough around the edges.
6 | Specifically the error management is still lacking (which means that
7 | if something goes wrong in the optimization you won't see an error
8 | message but just get garbage). These things are being worked out but
9 | we're not quite there yet.
10 |
11 | .. image:: http://fa.bianp.net/blog/static/images/2013/comparison_logistic_corr_10.png
12 |
13 | The main function is pytron.minimize::
14 |
15 | def minimize(func, grad_hess, x0, args=(), max_iter=1000, tol=1e-6):
16 |
17 | Parameters
18 | ----------
19 | func : callable
20 | func(w, *args) evaluates the objective function at w. It
21 | should return a float.
22 | grad_hess : callable
23 | grad_hess(w, *args) returns the gradient at w and a callable that
24 | computes the product of the Hessian at w with an arbitrary vector.
25 | tol : float
26 | stopping criterion. XXX TODO. what is the stopping criterion ?
27 |
28 | Returns
29 | -------
30 | w : array
31 |
32 |
33 |
34 | Stopping criterion
35 | ------------------
36 |
37 | It stops whenever the infinity norm of the gradient, ||grad(x)||_inf, is
38 | smaller than gtol, or when the maximum number of iterations is reached.
39 |
40 | TODO: add tol
41 |
42 | Examples
43 | --------
44 |
45 | Code
46 | ----
47 | This software uses the C++ implementation of the TRON optimization software
48 | (files src/tron.{h,cpp})
49 | distributed with the LIBLINEAR sources (v1.93), which is BSD licensed.
50 | Note that the original Fortran TRON implementation (see the TRON website
51 | listed under References) is not open
52 | source and is not used in this project.
53 |
54 | The modifications with respect to the original code are:
55 |
56 | * Do not initialize values to zero, allow arbitrary initializations
57 |
58 | * Modify the stopping criterion to comply with the scipy.optimize API: stop
59 | whenever the norm of the gradient is smaller than a given quantity, specified
60 | by the gtol argument.
61 |
62 | * Return the gradient from TRON::tron (pass by reference)
63 |
64 | * Add `tol` option to TRON
65 |
66 | * Rename `eps` to `gtol`.
67 |
68 | * Use infinity norm as stopping criterion for gradient instead of L2.
69 |
70 | TODO
71 | ----
72 | * return status from TRON::TRON
73 | * callback argument
74 |
75 |
76 | References
77 | ----------
78 | If you use the software please consider citing some of the references below.
79 |
80 | The method is described in the paper "Newton's Method for Large
81 | Bound-Constrained Optimization Problems", Chih-Jen Lin and Jorge J. Moré
82 | (http://epubs.siam.org/doi/abs/10.1137/S1052623498345075)
83 |
84 | It is also discussed in the context of logistic regression in the paper "Trust
85 | Region Newton Method for Logistic Regression", Chih-Jen Lin, Ruby C. Weng,
86 | S. Sathiya Keerthi (http://dl.acm.org/citation.cfm?id=1390703)
87 |
88 | The website http://www.mcs.anl.gov/~more/tron/ contains a reference to this
89 | implementation, although the links to the software seem to be broken
90 | (as of May 2013).
91 |
92 |
93 | License
94 | -------
95 | This code is licensed under the terms of the BSD license. See file COPYING
96 | for more details.
97 |
98 |
99 | Acknowledgement
100 | ---------------
101 | The source code for the
102 |
--------------------------------------------------------------------------------
/example_logistic.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | from pytron import minimize
4 |
def phi(t):
    """Logistic sigmoid ``1 / (1 + exp(-t))``, applied element-wise.

    Evaluated as ``exp(-log(1 + exp(-t)))`` with ``np.logaddexp`` so the
    intermediate never overflows.  The naive ``1 / (1 + np.exp(-t))`` emits
    an overflow RuntimeWarning for ``t`` below about -709.

    Parameters
    ----------
    t : float or array_like
        Input value(s).

    Returns
    -------
    float or ndarray
        Sigmoid of ``t``, in the closed interval [0, 1].
    """
    return np.exp(-np.logaddexp(0, -t))
8 |
9 |
def loss(w, X, y, alpha):
    """L2-regularized logistic loss.

    Computes ``sum_i log(1 + exp(-y_i * x_i.w)) + 0.5 * alpha * ||w||^2``.
    Each term is evaluated in a form whose ``exp`` argument is never
    positive, so large margins cannot overflow.

    Parameters
    ----------
    w : ndarray
        Coefficient vector.
    X : ndarray
        Design matrix; ``X.dot(w)`` gives one score per sample.
    y : ndarray
        Labels, multiplied element-wise with the scores.
    alpha : float
        Strength of the L2 penalty.

    Returns
    -------
    float
        Value of the objective at ``w``.
    """
    margins = y * X.dot(w)
    positive = margins > 0
    negative = ~positive

    terms = np.empty_like(margins)
    # margin > 0: log(1 + e^{-m}) directly; the exponent is negative.
    terms[positive] = np.log(1 + np.exp(-margins[positive]))
    # margin <= 0: the identity log(1 + e^{-m}) = -m + log(1 + e^{m})
    # keeps the exponent non-positive.
    terms[negative] = np.log(1 + np.exp(margins[negative])) - margins[negative]

    return terms.sum() + .5 * alpha * w.dot(w)
20 |
21 |
def grad_hess(w, X, y, alpha):
    """Gradient of ``loss`` at ``w`` plus a Hessian-vector product closure.

    Parameters
    ----------
    w : ndarray
        Coefficient vector.
    X : ndarray
        Design matrix.
    y : ndarray
        Labels.  The Hessian weights below omit a ``y**2`` factor, so this
        assumes labels in {-1, +1} (as produced by ``np.sign`` in the
        example data).
    alpha : float
        Strength of the L2 penalty.

    Returns
    -------
    grad : ndarray
        ``X.T (y * (phi(y * Xw) - 1)) + alpha * w``.
    Hs : callable
        ``Hs(s)`` returns the Hessian of ``loss`` at ``w`` times ``s``.
    """
    z = phi(y * X.dot(w))
    grad = X.T.dot((z - 1) * y) + alpha * w

    # The Hessian's per-sample weights depend only on w.  Compute them once
    # here instead of inside Hs, so repeated Hs calls do not redo this work.
    d = z * (1 - z)

    def Hs(s):
        return X.T.dot(d * X.dot(s)) + alpha * s

    return grad, Hs
33 |
34 |
35 |
if __name__ == '__main__':
    # Synthetic binary problem: labels are the sign of a random linear model,
    # so they lie in {-1, +1}.  A fixed seed makes the printed output
    # reproducible between runs.
    rng = np.random.RandomState(0)
    n_samples, n_features = 100, 10
    X = rng.randn(n_samples, n_features)
    y = np.sign(X.dot(5 * rng.randn(n_features)))
    alpha = 1.
    x0 = np.zeros(n_features)

    def callback(xk):
        # Print the objective at the iterate the solver passes in.
        print(loss(xk, X, y, alpha))

    # call the solver
    res = minimize(loss, grad_hess, x0, args=(X, y, alpha),
                   max_iter=15, gtol=1e-3, tol=1e-12, callback=callback)
    print(res)

    # Reference solution.  scikit-learn's L2-penalized logistic regression
    # without intercept, with C = 1 / alpha, has the same minimizer as
    # loss().
    from sklearn import linear_model
    clf = linear_model.LogisticRegression(C=1./alpha, fit_intercept=False)
    clf.fit(X, y)

    print()
    print('Solution using TRON: %s' % res.x)
    print('Solution using scikit-learn: %s' % clf.coef_)
--------------------------------------------------------------------------------
/pytron/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | __version__ = '0.3'
3 | from .tron import minimize
4 |
--------------------------------------------------------------------------------
/pytron/src/blas/Makefile:
--------------------------------------------------------------------------------
# Builds the static archive blas.a from the small BLAS subset in this
# directory (dnrm2, daxpy, ddot, dscal).
AR = ar rcv
RANLIB = ranlib

HEADERS = blas.h blasp.h
FILES = dnrm2.o daxpy.o ddot.o dscal.o

# OPTFLAGS is not defined in this file; presumably the caller supplies it
# (e.g. `make OPTFLAGS=-O3`) -- confirm against the parent build.
# FFLAGS is not used by any rule below.
CFLAGS = $(OPTFLAGS)
FFLAGS = $(OPTFLAGS)

# Default (first) target: archive the objects into blas.a and index it.
blas: $(FILES) $(HEADERS)
	$(AR) blas.a $(FILES)
	$(RANLIB) blas.a

# Remove objects, archives and editor backups.  The leading '-' makes make
# ignore a failing rm.
clean:
	- rm -f *.o
	- rm -f *.a
	- rm -f *~

# Suffix rule: compile each .c file into the matching .o file.
.c.o:
	$(CC) $(CFLAGS) -c $*.c
21 |
22 |
23 |
--------------------------------------------------------------------------------
/pytron/src/blas/blas.h:
--------------------------------------------------------------------------------
1 | /* blas.h -- C header file for BLAS Ver 1.0 */
2 | /* Jesse Bennett March 23, 2000 */
3 |
4 | /** barf [ba:rf] 2. "He suggested using FORTRAN, and everybody barfed."
5 |
6 | - From The Shogakukan DICTIONARY OF NEW ENGLISH (Second edition) */
7 |
#ifndef BLAS_INCLUDE
#define BLAS_INCLUDE

/* Data types specific to BLAS implementation */
/* Single- and double-precision complex numbers stored as (real, imaginary). */
typedef struct { float r, i; } fcomplex;
typedef struct { double r, i; } dcomplex;
typedef int blasbool;

/* Must come after the typedefs above: the prototypes use fcomplex/dcomplex. */
#include "blasp.h" /* Prototypes for all BLAS functions */

#define FALSE 0
#define TRUE 1

/* Macro functions */
/* NOTE: each argument may be evaluated twice, so do not pass expressions
   with side effects (e.g. MIN(i++, j)). */
#define MIN(a,b) ((a) <= (b) ? (a) : (b))
#define MAX(a,b) ((a) >= (b) ? (a) : (b))

#endif
26 |
--------------------------------------------------------------------------------
/pytron/src/blas/blasp.h:
--------------------------------------------------------------------------------
1 | /* blasp.h -- C prototypes for BLAS Ver 1.0 */
2 | /* Jesse Bennett March 23, 2000 */
3 |
4 | /* Functions listed in alphabetical order */
5 |
6 | #ifdef F2C_COMPAT
7 |
8 | void cdotc_(fcomplex *dotval, int *n, fcomplex *cx, int *incx,
9 | fcomplex *cy, int *incy);
10 |
11 | void cdotu_(fcomplex *dotval, int *n, fcomplex *cx, int *incx,
12 | fcomplex *cy, int *incy);
13 |
14 | double sasum_(int *n, float *sx, int *incx);
15 |
16 | double scasum_(int *n, fcomplex *cx, int *incx);
17 |
18 | double scnrm2_(int *n, fcomplex *x, int *incx);
19 |
20 | double sdot_(int *n, float *sx, int *incx, float *sy, int *incy);
21 |
22 | double snrm2_(int *n, float *x, int *incx);
23 |
24 | void zdotc_(dcomplex *dotval, int *n, dcomplex *cx, int *incx,
25 | dcomplex *cy, int *incy);
26 |
27 | void zdotu_(dcomplex *dotval, int *n, dcomplex *cx, int *incx,
28 | dcomplex *cy, int *incy);
29 |
30 | #else
31 |
32 | fcomplex cdotc_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy);
33 |
34 | fcomplex cdotu_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy);
35 |
36 | float sasum_(int *n, float *sx, int *incx);
37 |
38 | float scasum_(int *n, fcomplex *cx, int *incx);
39 |
40 | float scnrm2_(int *n, fcomplex *x, int *incx);
41 |
42 | float sdot_(int *n, float *sx, int *incx, float *sy, int *incy);
43 |
44 | float snrm2_(int *n, float *x, int *incx);
45 |
46 | dcomplex zdotc_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy);
47 |
48 | dcomplex zdotu_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy);
49 |
50 | #endif
51 |
52 | /* Remaining functions listed in alphabetical order */
53 |
54 | int caxpy_(int *n, fcomplex *ca, fcomplex *cx, int *incx, fcomplex *cy,
55 | int *incy);
56 |
57 | int ccopy_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy);
58 |
59 | int cgbmv_(char *trans, int *m, int *n, int *kl, int *ku,
60 | fcomplex *alpha, fcomplex *a, int *lda, fcomplex *x, int *incx,
61 | fcomplex *beta, fcomplex *y, int *incy);
62 |
63 | int cgemm_(char *transa, char *transb, int *m, int *n, int *k,
64 | fcomplex *alpha, fcomplex *a, int *lda, fcomplex *b, int *ldb,
65 | fcomplex *beta, fcomplex *c, int *ldc);
66 |
67 | int cgemv_(char *trans, int *m, int *n, fcomplex *alpha, fcomplex *a,
68 | int *lda, fcomplex *x, int *incx, fcomplex *beta, fcomplex *y,
69 | int *incy);
70 |
71 | int cgerc_(int *m, int *n, fcomplex *alpha, fcomplex *x, int *incx,
72 | fcomplex *y, int *incy, fcomplex *a, int *lda);
73 |
74 | int cgeru_(int *m, int *n, fcomplex *alpha, fcomplex *x, int *incx,
75 | fcomplex *y, int *incy, fcomplex *a, int *lda);
76 |
77 | int chbmv_(char *uplo, int *n, int *k, fcomplex *alpha, fcomplex *a,
78 | int *lda, fcomplex *x, int *incx, fcomplex *beta, fcomplex *y,
79 | int *incy);
80 |
81 | int chemm_(char *side, char *uplo, int *m, int *n, fcomplex *alpha,
82 | fcomplex *a, int *lda, fcomplex *b, int *ldb, fcomplex *beta,
83 | fcomplex *c, int *ldc);
84 |
85 | int chemv_(char *uplo, int *n, fcomplex *alpha, fcomplex *a, int *lda,
86 | fcomplex *x, int *incx, fcomplex *beta, fcomplex *y, int *incy);
87 |
88 | int cher_(char *uplo, int *n, float *alpha, fcomplex *x, int *incx,
89 | fcomplex *a, int *lda);
90 |
91 | int cher2_(char *uplo, int *n, fcomplex *alpha, fcomplex *x, int *incx,
92 | fcomplex *y, int *incy, fcomplex *a, int *lda);
93 |
94 | int cher2k_(char *uplo, char *trans, int *n, int *k, fcomplex *alpha,
95 | fcomplex *a, int *lda, fcomplex *b, int *ldb, float *beta,
96 | fcomplex *c, int *ldc);
97 |
98 | int cherk_(char *uplo, char *trans, int *n, int *k, float *alpha,
99 | fcomplex *a, int *lda, float *beta, fcomplex *c, int *ldc);
100 |
101 | int chpmv_(char *uplo, int *n, fcomplex *alpha, fcomplex *ap, fcomplex *x,
102 | int *incx, fcomplex *beta, fcomplex *y, int *incy);
103 |
104 | int chpr_(char *uplo, int *n, float *alpha, fcomplex *x, int *incx,
105 | fcomplex *ap);
106 |
107 | int chpr2_(char *uplo, int *n, fcomplex *alpha, fcomplex *x, int *incx,
108 | fcomplex *y, int *incy, fcomplex *ap);
109 |
110 | int crotg_(fcomplex *ca, fcomplex *cb, float *c, fcomplex *s);
111 |
112 | int cscal_(int *n, fcomplex *ca, fcomplex *cx, int *incx);
113 |
114 | int csscal_(int *n, float *sa, fcomplex *cx, int *incx);
115 |
116 | int cswap_(int *n, fcomplex *cx, int *incx, fcomplex *cy, int *incy);
117 |
118 | int csymm_(char *side, char *uplo, int *m, int *n, fcomplex *alpha,
119 | fcomplex *a, int *lda, fcomplex *b, int *ldb, fcomplex *beta,
120 | fcomplex *c, int *ldc);
121 |
122 | int csyr2k_(char *uplo, char *trans, int *n, int *k, fcomplex *alpha,
123 | fcomplex *a, int *lda, fcomplex *b, int *ldb, fcomplex *beta,
124 | fcomplex *c, int *ldc);
125 |
126 | int csyrk_(char *uplo, char *trans, int *n, int *k, fcomplex *alpha,
127 | fcomplex *a, int *lda, fcomplex *beta, fcomplex *c, int *ldc);
128 |
129 | int ctbmv_(char *uplo, char *trans, char *diag, int *n, int *k,
130 | fcomplex *a, int *lda, fcomplex *x, int *incx);
131 |
132 | int ctbsv_(char *uplo, char *trans, char *diag, int *n, int *k,
133 | fcomplex *a, int *lda, fcomplex *x, int *incx);
134 |
135 | int ctpmv_(char *uplo, char *trans, char *diag, int *n, fcomplex *ap,
136 | fcomplex *x, int *incx);
137 |
138 | int ctpsv_(char *uplo, char *trans, char *diag, int *n, fcomplex *ap,
139 | fcomplex *x, int *incx);
140 |
141 | int ctrmm_(char *side, char *uplo, char *transa, char *diag, int *m,
142 | int *n, fcomplex *alpha, fcomplex *a, int *lda, fcomplex *b,
143 | int *ldb);
144 |
145 | int ctrmv_(char *uplo, char *trans, char *diag, int *n, fcomplex *a,
146 | int *lda, fcomplex *x, int *incx);
147 |
148 | int ctrsm_(char *side, char *uplo, char *transa, char *diag, int *m,
149 | int *n, fcomplex *alpha, fcomplex *a, int *lda, fcomplex *b,
150 | int *ldb);
151 |
152 | int ctrsv_(char *uplo, char *trans, char *diag, int *n, fcomplex *a,
153 | int *lda, fcomplex *x, int *incx);
154 |
155 | int daxpy_(int *n, double *sa, double *sx, int *incx, double *sy,
156 | int *incy);
157 |
158 | int dcopy_(int *n, double *sx, int *incx, double *sy, int *incy);
159 |
160 | int dgbmv_(char *trans, int *m, int *n, int *kl, int *ku,
161 | double *alpha, double *a, int *lda, double *x, int *incx,
162 | double *beta, double *y, int *incy);
163 |
164 | int dgemm_(char *transa, char *transb, int *m, int *n, int *k,
165 | double *alpha, double *a, int *lda, double *b, int *ldb,
166 | double *beta, double *c, int *ldc);
167 |
168 | int dgemv_(char *trans, int *m, int *n, double *alpha, double *a,
169 | int *lda, double *x, int *incx, double *beta, double *y,
170 | int *incy);
171 |
172 | int dger_(int *m, int *n, double *alpha, double *x, int *incx,
173 | double *y, int *incy, double *a, int *lda);
174 |
175 | int drot_(int *n, double *sx, int *incx, double *sy, int *incy,
176 | double *c, double *s);
177 |
178 | int drotg_(double *sa, double *sb, double *c, double *s);
179 |
180 | int dsbmv_(char *uplo, int *n, int *k, double *alpha, double *a,
181 | int *lda, double *x, int *incx, double *beta, double *y,
182 | int *incy);
183 |
184 | int dscal_(int *n, double *sa, double *sx, int *incx);
185 |
186 | int dspmv_(char *uplo, int *n, double *alpha, double *ap, double *x,
187 | int *incx, double *beta, double *y, int *incy);
188 |
189 | int dspr_(char *uplo, int *n, double *alpha, double *x, int *incx,
190 | double *ap);
191 |
192 | int dspr2_(char *uplo, int *n, double *alpha, double *x, int *incx,
193 | double *y, int *incy, double *ap);
194 |
195 | int dswap_(int *n, double *sx, int *incx, double *sy, int *incy);
196 |
197 | int dsymm_(char *side, char *uplo, int *m, int *n, double *alpha,
198 | double *a, int *lda, double *b, int *ldb, double *beta,
199 | double *c, int *ldc);
200 |
201 | int dsymv_(char *uplo, int *n, double *alpha, double *a, int *lda,
202 | double *x, int *incx, double *beta, double *y, int *incy);
203 |
204 | int dsyr_(char *uplo, int *n, double *alpha, double *x, int *incx,
205 | double *a, int *lda);
206 |
207 | int dsyr2_(char *uplo, int *n, double *alpha, double *x, int *incx,
208 | double *y, int *incy, double *a, int *lda);
209 |
210 | int dsyr2k_(char *uplo, char *trans, int *n, int *k, double *alpha,
211 | double *a, int *lda, double *b, int *ldb, double *beta,
212 | double *c, int *ldc);
213 |
214 | int dsyrk_(char *uplo, char *trans, int *n, int *k, double *alpha,
215 | double *a, int *lda, double *beta, double *c, int *ldc);
216 |
217 | int dtbmv_(char *uplo, char *trans, char *diag, int *n, int *k,
218 | double *a, int *lda, double *x, int *incx);
219 |
220 | int dtbsv_(char *uplo, char *trans, char *diag, int *n, int *k,
221 | double *a, int *lda, double *x, int *incx);
222 |
223 | int dtpmv_(char *uplo, char *trans, char *diag, int *n, double *ap,
224 | double *x, int *incx);
225 |
226 | int dtpsv_(char *uplo, char *trans, char *diag, int *n, double *ap,
227 | double *x, int *incx);
228 |
229 | int dtrmm_(char *side, char *uplo, char *transa, char *diag, int *m,
230 | int *n, double *alpha, double *a, int *lda, double *b,
231 | int *ldb);
232 |
233 | int dtrmv_(char *uplo, char *trans, char *diag, int *n, double *a,
234 | int *lda, double *x, int *incx);
235 |
236 | int dtrsm_(char *side, char *uplo, char *transa, char *diag, int *m,
237 | int *n, double *alpha, double *a, int *lda, double *b,
238 | int *ldb);
239 |
240 | int dtrsv_(char *uplo, char *trans, char *diag, int *n, double *a,
241 | int *lda, double *x, int *incx);
242 |
243 |
244 | int saxpy_(int *n, float *sa, float *sx, int *incx, float *sy, int *incy);
245 |
246 | int scopy_(int *n, float *sx, int *incx, float *sy, int *incy);
247 |
248 | int sgbmv_(char *trans, int *m, int *n, int *kl, int *ku,
249 | float *alpha, float *a, int *lda, float *x, int *incx,
250 | float *beta, float *y, int *incy);
251 |
252 | int sgemm_(char *transa, char *transb, int *m, int *n, int *k,
253 | float *alpha, float *a, int *lda, float *b, int *ldb,
254 | float *beta, float *c, int *ldc);
255 |
256 | int sgemv_(char *trans, int *m, int *n, float *alpha, float *a,
257 | int *lda, float *x, int *incx, float *beta, float *y,
258 | int *incy);
259 |
260 | int sger_(int *m, int *n, float *alpha, float *x, int *incx,
261 | float *y, int *incy, float *a, int *lda);
262 |
263 | int srot_(int *n, float *sx, int *incx, float *sy, int *incy,
264 | float *c, float *s);
265 |
266 | int srotg_(float *sa, float *sb, float *c, float *s);
267 |
268 | int ssbmv_(char *uplo, int *n, int *k, float *alpha, float *a,
269 | int *lda, float *x, int *incx, float *beta, float *y,
270 | int *incy);
271 |
272 | int sscal_(int *n, float *sa, float *sx, int *incx);
273 |
274 | int sspmv_(char *uplo, int *n, float *alpha, float *ap, float *x,
275 | int *incx, float *beta, float *y, int *incy);
276 |
277 | int sspr_(char *uplo, int *n, float *alpha, float *x, int *incx,
278 | float *ap);
279 |
280 | int sspr2_(char *uplo, int *n, float *alpha, float *x, int *incx,
281 | float *y, int *incy, float *ap);
282 |
283 | int sswap_(int *n, float *sx, int *incx, float *sy, int *incy);
284 |
285 | int ssymm_(char *side, char *uplo, int *m, int *n, float *alpha,
286 | float *a, int *lda, float *b, int *ldb, float *beta,
287 | float *c, int *ldc);
288 |
289 | int ssymv_(char *uplo, int *n, float *alpha, float *a, int *lda,
290 | float *x, int *incx, float *beta, float *y, int *incy);
291 |
292 | int ssyr_(char *uplo, int *n, float *alpha, float *x, int *incx,
293 | float *a, int *lda);
294 |
295 | int ssyr2_(char *uplo, int *n, float *alpha, float *x, int *incx,
296 | float *y, int *incy, float *a, int *lda);
297 |
298 | int ssyr2k_(char *uplo, char *trans, int *n, int *k, float *alpha,
299 | float *a, int *lda, float *b, int *ldb, float *beta,
300 | float *c, int *ldc);
301 |
302 | int ssyrk_(char *uplo, char *trans, int *n, int *k, float *alpha,
303 | float *a, int *lda, float *beta, float *c, int *ldc);
304 |
305 | int stbmv_(char *uplo, char *trans, char *diag, int *n, int *k,
306 | float *a, int *lda, float *x, int *incx);
307 |
308 | int stbsv_(char *uplo, char *trans, char *diag, int *n, int *k,
309 | float *a, int *lda, float *x, int *incx);
310 |
311 | int stpmv_(char *uplo, char *trans, char *diag, int *n, float *ap,
312 | float *x, int *incx);
313 |
314 | int stpsv_(char *uplo, char *trans, char *diag, int *n, float *ap,
315 | float *x, int *incx);
316 |
317 | int strmm_(char *side, char *uplo, char *transa, char *diag, int *m,
318 | int *n, float *alpha, float *a, int *lda, float *b,
319 | int *ldb);
320 |
321 | int strmv_(char *uplo, char *trans, char *diag, int *n, float *a,
322 | int *lda, float *x, int *incx);
323 |
324 | int strsm_(char *side, char *uplo, char *transa, char *diag, int *m,
325 | int *n, float *alpha, float *a, int *lda, float *b,
326 | int *ldb);
327 |
328 | int strsv_(char *uplo, char *trans, char *diag, int *n, float *a,
329 | int *lda, float *x, int *incx);
330 |
331 | int zaxpy_(int *n, dcomplex *ca, dcomplex *cx, int *incx, dcomplex *cy,
332 | int *incy);
333 |
334 | int zcopy_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy);
335 |
336 | int zdscal_(int *n, double *sa, dcomplex *cx, int *incx);
337 |
338 | int zgbmv_(char *trans, int *m, int *n, int *kl, int *ku,
339 | dcomplex *alpha, dcomplex *a, int *lda, dcomplex *x, int *incx,
340 | dcomplex *beta, dcomplex *y, int *incy);
341 |
342 | int zgemm_(char *transa, char *transb, int *m, int *n, int *k,
343 | dcomplex *alpha, dcomplex *a, int *lda, dcomplex *b, int *ldb,
344 | dcomplex *beta, dcomplex *c, int *ldc);
345 |
346 | int zgemv_(char *trans, int *m, int *n, dcomplex *alpha, dcomplex *a,
347 | int *lda, dcomplex *x, int *incx, dcomplex *beta, dcomplex *y,
348 | int *incy);
349 |
350 | int zgerc_(int *m, int *n, dcomplex *alpha, dcomplex *x, int *incx,
351 | dcomplex *y, int *incy, dcomplex *a, int *lda);
352 |
353 | int zgeru_(int *m, int *n, dcomplex *alpha, dcomplex *x, int *incx,
354 | dcomplex *y, int *incy, dcomplex *a, int *lda);
355 |
356 | int zhbmv_(char *uplo, int *n, int *k, dcomplex *alpha, dcomplex *a,
357 | int *lda, dcomplex *x, int *incx, dcomplex *beta, dcomplex *y,
358 | int *incy);
359 |
360 | int zhemm_(char *side, char *uplo, int *m, int *n, dcomplex *alpha,
361 | dcomplex *a, int *lda, dcomplex *b, int *ldb, dcomplex *beta,
362 | dcomplex *c, int *ldc);
363 |
364 | int zhemv_(char *uplo, int *n, dcomplex *alpha, dcomplex *a, int *lda,
365 | dcomplex *x, int *incx, dcomplex *beta, dcomplex *y, int *incy);
366 |
367 | int zher_(char *uplo, int *n, double *alpha, dcomplex *x, int *incx,
368 | dcomplex *a, int *lda);
369 |
370 | int zher2_(char *uplo, int *n, dcomplex *alpha, dcomplex *x, int *incx,
371 | dcomplex *y, int *incy, dcomplex *a, int *lda);
372 |
373 | int zher2k_(char *uplo, char *trans, int *n, int *k, dcomplex *alpha,
374 | dcomplex *a, int *lda, dcomplex *b, int *ldb, double *beta,
375 | dcomplex *c, int *ldc);
376 |
377 | int zherk_(char *uplo, char *trans, int *n, int *k, double *alpha,
378 | dcomplex *a, int *lda, double *beta, dcomplex *c, int *ldc);
379 |
380 | int zhpmv_(char *uplo, int *n, dcomplex *alpha, dcomplex *ap, dcomplex *x,
381 | int *incx, dcomplex *beta, dcomplex *y, int *incy);
382 |
383 | int zhpr_(char *uplo, int *n, double *alpha, dcomplex *x, int *incx,
384 | dcomplex *ap);
385 |
386 | int zhpr2_(char *uplo, int *n, dcomplex *alpha, dcomplex *x, int *incx,
387 | dcomplex *y, int *incy, dcomplex *ap);
388 |
389 | int zrotg_(dcomplex *ca, dcomplex *cb, double *c, dcomplex *s);
390 |
391 | int zscal_(int *n, dcomplex *ca, dcomplex *cx, int *incx);
392 |
393 | int zswap_(int *n, dcomplex *cx, int *incx, dcomplex *cy, int *incy);
394 |
395 | int zsymm_(char *side, char *uplo, int *m, int *n, dcomplex *alpha,
396 | dcomplex *a, int *lda, dcomplex *b, int *ldb, dcomplex *beta,
397 | dcomplex *c, int *ldc);
398 |
399 | int zsyr2k_(char *uplo, char *trans, int *n, int *k, dcomplex *alpha,
400 | dcomplex *a, int *lda, dcomplex *b, int *ldb, dcomplex *beta,
401 | dcomplex *c, int *ldc);
402 |
403 | int zsyrk_(char *uplo, char *trans, int *n, int *k, dcomplex *alpha,
404 | dcomplex *a, int *lda, dcomplex *beta, dcomplex *c, int *ldc);
405 |
406 | int ztbmv_(char *uplo, char *trans, char *diag, int *n, int *k,
407 | dcomplex *a, int *lda, dcomplex *x, int *incx);
408 |
409 | int ztbsv_(char *uplo, char *trans, char *diag, int *n, int *k,
410 | dcomplex *a, int *lda, dcomplex *x, int *incx);
411 |
412 | int ztpmv_(char *uplo, char *trans, char *diag, int *n, dcomplex *ap,
413 | dcomplex *x, int *incx);
414 |
415 | int ztpsv_(char *uplo, char *trans, char *diag, int *n, dcomplex *ap,
416 | dcomplex *x, int *incx);
417 |
418 | int ztrmm_(char *side, char *uplo, char *transa, char *diag, int *m,
419 | int *n, dcomplex *alpha, dcomplex *a, int *lda, dcomplex *b,
420 | int *ldb);
421 |
422 | int ztrmv_(char *uplo, char *trans, char *diag, int *n, dcomplex *a,
423 | int *lda, dcomplex *x, int *incx);
424 |
425 | int ztrsm_(char *side, char *uplo, char *transa, char *diag, int *m,
426 | int *n, dcomplex *alpha, dcomplex *a, int *lda, dcomplex *b,
427 | int *ldb);
428 |
429 | int ztrsv_(char *uplo, char *trans, char *diag, int *n, dcomplex *a,
430 | int *lda, dcomplex *x, int *incx);
431 |
--------------------------------------------------------------------------------
/pytron/src/blas/daxpy.c:
--------------------------------------------------------------------------------
1 | #include "blas.h"
2 |
/* daxpy_: y := a*x + y over n strided elements (BLAS level 1).
 *
 * All arguments are passed by pointer, Fortran-style.  Nothing is touched
 * when *n <= 0 or *sa == 0.0.  A negative increment walks that vector from
 * its far end, as in the reference BLAS.  Always returns 0.
 *
 * Origin: jack dongarra, linpack, 3/11/78;
 * modified 12/3/93, array(1) declarations changed to array(*).
 */
int daxpy_(int *n, double *sa, double *sx, int *incx, double *sy,
           int *incy)
{
  const long int len = *n;
  const double alpha = *sa;
  const long int stride_x = *incx;
  const long int stride_y = *incy;
  long int k;

  if (len <= 0 || alpha == 0.0)
    return 0;

  if (stride_x == 1 && stride_y == 1)
  {
    /* Contiguous case: unrolled by four, followed by a scalar tail. */
    const long int limit = len - 3;
    for (k = 0; k < limit; k += 4)
    {
      sy[k]     += alpha * sx[k];
      sy[k + 1] += alpha * sx[k + 1];
      sy[k + 2] += alpha * sx[k + 2];
      sy[k + 3] += alpha * sx[k + 3];
    }
    for ( ; k < len; ++k)
      sy[k] += alpha * sx[k];
    return 0;
  }

  {
    /* General strides: start at the far end for negative increments. */
    long int px = (stride_x >= 0) ? 0 : (1 - len) * stride_x;
    long int py = (stride_y >= 0) ? 0 : (1 - len) * stride_y;
    for (k = 0; k < len; ++k, px += stride_x, py += stride_y)
      sy[py] += alpha * sx[px];
  }
  return 0;
} /* daxpy_ */
50 |
--------------------------------------------------------------------------------
/pytron/src/blas/ddot.c:
--------------------------------------------------------------------------------
1 | #include "blas.h"
2 |
/* ddot_: dot product sum_k x_k * y_k over n strided elements (BLAS level 1).
 *
 * All arguments are passed by pointer, Fortran-style.  Returns 0.0 when
 * *n <= 0.  A negative increment walks that vector from its far end.
 */
double ddot_(int *n, double *sx, int *incx, double *sy, int *incy)
{
  long int i, m, nn, iincx, iincy;
  double stemp;
  long int ix, iy;

  /* forms the dot product of two vectors.
     uses unrolled loops for increments equal to one.
     jack dongarra, linpack, 3/11/78.
     modified 12/3/93, array(1) declarations changed to array(*) */

  /* Dereference inputs */
  nn = *n;
  iincx = *incx;
  iincy = *incy;

  stemp = 0.0;
  if (nn > 0)
  {
    if (iincx == 1 && iincy == 1) /* code for both increments equal to 1 */
    {
      /* Unrolled by five.  The grouping of the products into one expression
         fixes the floating-point summation order; regrouping changes the
         rounding of the result. */
      m = nn-4;
      for (i = 0; i < m; i += 5)
        stemp += sx[i] * sy[i] + sx[i+1] * sy[i+1] + sx[i+2] * sy[i+2] +
                 sx[i+3] * sy[i+3] + sx[i+4] * sy[i+4];

      for ( ; i < nn; i++)        /* clean-up loop */
        stemp += sx[i] * sy[i];
    }
    else /* code for unequal increments or equal increments not equal to 1 */
    {
      ix = 0;
      iy = 0;
      /* Negative increments start at the last element and step backwards. */
      if (iincx < 0)
        ix = (1 - nn) * iincx;
      if (iincy < 0)
        iy = (1 - nn) * iincy;
      for (i = 0; i < nn; i++)
      {
        stemp += sx[ix] * sy[iy];
        ix += iincx;
        iy += iincy;
      }
    }
  }

  return stemp;
} /* ddot_ */
51 |
--------------------------------------------------------------------------------
/pytron/src/blas/dnrm2.c:
--------------------------------------------------------------------------------
1 | #include /* Needed for fabs() and sqrt() */
2 | #include "blas.h"
3 |
4 | double dnrm2_(int *n, double *x, int *incx)
5 | {
6 | long int ix, nn, iincx;
7 | double norm, scale, absxi, ssq, temp;
8 |
9 | /* DNRM2 returns the euclidean norm of a vector via the function
10 | name, so that
11 |
12 | DNRM2 := sqrt( x'*x )
13 |
14 | -- This version written on 25-October-1982.
15 | Modified on 14-October-1993 to inline the call to SLASSQ.
16 | Sven Hammarling, Nag Ltd. */
17 |
18 | /* Dereference inputs */
19 | nn = *n;
20 | iincx = *incx;
21 |
22 | if( nn > 0 && iincx > 0 )
23 | {
24 | if (nn == 1)
25 | {
26 | norm = fabs(x[0]);
27 | }
28 | else
29 | {
30 | scale = 0.0;
31 | ssq = 1.0;
32 |
33 | /* The following loop is equivalent to this call to the LAPACK
34 | auxiliary routine: CALL SLASSQ( N, X, INCX, SCALE, SSQ ) */
35 |
36 | for (ix=(nn-1)*iincx; ix>=0; ix-=iincx)
37 | {
38 | if (x[ix] != 0.0)
39 | {
40 | absxi = fabs(x[ix]);
41 | if (scale < absxi)
42 | {
43 | temp = scale / absxi;
44 | ssq = ssq * (temp * temp) + 1.0;
45 | scale = absxi;
46 | }
47 | else
48 | {
49 | temp = absxi / scale;
50 | ssq += temp * temp;
51 | }
52 | }
53 | }
54 | norm = scale * sqrt(ssq);
55 | }
56 | }
57 | else
58 | norm = 0.0;
59 |
60 | return norm;
61 |
62 | } /* dnrm2_ */
63 |
--------------------------------------------------------------------------------
/pytron/src/blas/dscal.c:
--------------------------------------------------------------------------------
1 | #include "blas.h"
2 |
/* dscal_: x := a*x over n strided elements (BLAS level 1).
 *
 * All arguments are passed by pointer, Fortran-style.  Does nothing when
 * *n <= 0 or *incx <= 0.  Always returns 0.
 *
 * Origin: jack dongarra, linpack, 3/11/78;
 * modified 3/93 to return if incx .le. 0;
 * modified 12/3/93, array(1) declarations changed to array(*).
 */
int dscal_(int *n, double *sa, double *sx, int *incx)
{
  const long int count = *n;
  const long int step = *incx;
  const double factor = *sa;
  long int j;

  if (count <= 0 || step <= 0)
    return 0;

  if (step == 1)
  {
    /* Contiguous case: unrolled by five, followed by a scalar tail. */
    const long int stop = count - 4;
    for (j = 0; j < stop; j += 5)
    {
      sx[j]     = factor * sx[j];
      sx[j + 1] = factor * sx[j + 1];
      sx[j + 2] = factor * sx[j + 2];
      sx[j + 3] = factor * sx[j + 3];
      sx[j + 4] = factor * sx[j + 4];
    }
    for ( ; j < count; ++j)
      sx[j] = factor * sx[j];
  }
  else
  {
    /* Strided case: touch every step-th element, count elements in total. */
    const long int end = count * step;
    for (j = 0; j < end; j += step)
      sx[j] = factor * sx[j];
  }

  return 0;
} /* dscal_ */
45 |
--------------------------------------------------------------------------------
/pytron/src/tron.cpp:
--------------------------------------------------------------------------------
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <stdarg.h>
#include "tron.h"
6 |
// Small generic min/max helpers.  Each is only defined when no macro of the
// same name is already visible, so a platform header that #defines min/max
// is not clobbered.
#ifndef min
template <class T> static inline T min(T x, T y) { return (x < y) ? x : y; }
#endif

#ifndef max
template <class T> static inline T max(T x, T y) { return (x > y) ? x : y; }
#endif
14 |
15 | #ifdef __cplusplus
16 | extern "C" {
17 | #endif
18 |
19 | extern double dnrm2_(int *, double *, int *);
20 | extern double ddot_(int *, double *, int *, double *, int *);
21 | extern int daxpy_(int *, double *, double *, int *, double *, int *);
22 | extern int dscal_(int *, double *, double *, int *);
23 |
24 | #ifdef __cplusplus
25 | }
26 | #endif
27 |
// Default sink for TRON diagnostics: write the message to stdout and flush
// right away so each progress line appears immediately.
static void default_print(const char *msg)
{
    FILE *const out = stdout;
    fputs(msg, out);
    fflush(out);
}
33 |
34 | void TRON::info(const char *fmt,...)
35 | {
36 | char buf[BUFSIZ];
37 | va_list ap;
38 | va_start(ap,fmt);
39 | vsprintf(buf,fmt,ap);
40 | va_end(ap);
41 | (*tron_print_string)(buf);
42 | }
43 |
44 | TRON::TRON(const function *fun_obj, double tol, double gtol, int max_iter)
45 | {
46 | this->fun_obj=const_cast(fun_obj);
47 | this->gtol=gtol;
48 | this->tol=tol;
49 | this->max_iter=max_iter;
50 | tron_print_string = default_print;
51 | this->n_iter = 0;
52 | this->gnorm = 0.;
53 | }
54 |
55 | TRON::~TRON()
56 | {
57 | }
58 |
59 | void TRON::tron(double *w, double *g, int verbose)
60 | {
61 | /*
62 | actred : Actual reduction
63 | prered : predicted reduction
64 | */
65 | // Parameters for updating the iterates.
66 | double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
67 |
68 | // Parameters for updating the trust region size delta.
69 | double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;
70 |
71 | int n = fun_obj->get_nr_variable();
72 | int i, cg_iter;
73 | double delta, snorm, one=1.0;
74 | double alpha, f, fnew, prered, actred, gs;
75 | int search = 1, iter = 1, inc = 1;
76 | double *s = new double[n];
77 | double *r = new double[n];
78 | double *w_new = new double[n];
79 |
80 | /* Edit (Fabian): allow for warm restarts
81 | for (i=0; ifun(w);
86 | fun_obj->grad(w, g);
87 | delta = dnrm2_(&n, g, &inc); // TODO: use infinity norm
88 | double gnorm1 = delta;
89 | double gnorm = gnorm1;
90 |
91 | if (gnorm <= gtol)
92 | search = 0;
93 |
94 | iter = 1;
95 |
96 | while (iter <= max_iter && search)
97 | {
98 |
99 | cg_iter = trcg(delta, g, s, r);
100 |
101 | memcpy(w_new, w, sizeof(double)*n);
102 | daxpy_(&n, &one, s, &inc, w_new, &inc);
103 |
104 | gs = ddot_(&n, g, &inc, s, &inc);
105 | prered = -0.5*(gs-ddot_(&n, s, &inc, r, &inc));
106 | fnew = fun_obj->fun(w_new);
107 |
108 | // Compute the actual reduction.
109 | actred = f - fnew;
110 |
111 | // On the first iteration, adjust the initial step bound.
112 | snorm = dnrm2_(&n, s, &inc);
113 | if (iter == 1)
114 | delta = min(delta, snorm);
115 |
116 | // Compute prediction alpha*snorm of the step.
117 | if (fnew - f - gs <= 0)
118 | alpha = sigma3;
119 | else
120 | alpha = max(sigma1, -0.5*(gs/(fnew - f - gs)));
121 |
122 | // Update the trust region bound according to the ratio of actual to predicted reduction.
123 | if (actred < eta0*prered)
124 | delta = min(max(alpha, sigma1)*snorm, sigma2*delta);
125 | else if (actred < eta1*prered)
126 | delta = max(sigma1*delta, min(alpha*snorm, sigma2*delta));
127 | else if (actred < eta2*prered)
128 | delta = max(sigma1*delta, min(alpha*snorm, sigma3*delta));
129 | else
130 | delta = max(delta, min(alpha*snorm, sigma3*delta));
131 |
132 | if (verbose)
133 | info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
134 |
135 | if (actred > eta0*prered)
136 | {
137 | iter++;
138 | memcpy(w, w_new, sizeof(double)*n);
139 | f = fnew;
140 | fun_obj->grad(w, g);
141 |
142 | gnorm = norm_inf(n, g);
143 | // gnorm = dnrm2_(&n, g, &inc);
144 | if (gnorm <= gtol)
145 | break;
146 | }
147 | fun_obj->callback(w);
148 |
149 | if (f < -1.0e+32)
150 | {
151 | info("WARNING: f < -1.0e+32\n");
152 | break;
153 | }
154 | if (fabs(actred) <= 0 && prered <= 0)
155 | {
156 | info("WARNING: actred and prered <= 0\n");
157 | break;
158 | }
159 | if (fabs(actred) <= tol*fabs(f) &&
160 | fabs(prered) <= tol*fabs(f))
161 | {
162 | info("WARNING: actred and prered too small\n");
163 | break;
164 | }
165 | }
166 |
167 | this->n_iter = iter;
168 | this->gnorm = gnorm;
169 | this->fun = fun;
170 |
171 | delete[] r;
172 | delete[] w_new;
173 | delete[] s;
174 | }
175 |
176 | int TRON::trcg(double delta, double *g, double *s, double *r)
177 | {
178 | int i, inc = 1;
179 | int n = fun_obj->get_nr_variable();
180 | double one = 1;
181 | double *d = new double[n];
182 | double *Hd = new double[n];
183 | double rTr, rnewTrnew, alpha, beta, cgtol;
184 |
185 | for (i=0; iHv(d, Hd);
201 |
202 | alpha = rTr/ddot_(&n, d, &inc, Hd, &inc);
203 | daxpy_(&n, &alpha, d, &inc, s, &inc);
204 | if (dnrm2_(&n, s, &inc) > delta)
205 | {
206 | info("cg reaches trust region boundary\n");
207 | alpha = -alpha;
208 | daxpy_(&n, &alpha, d, &inc, s, &inc);
209 |
210 | double std = ddot_(&n, s, &inc, d, &inc);
211 | double sts = ddot_(&n, s, &inc, s, &inc);
212 | double dtd = ddot_(&n, d, &inc, d, &inc);
213 | double dsq = delta*delta;
214 | double rad = sqrt(std*std + dtd*(dsq-sts));
215 | if (std >= 0)
216 | alpha = (dsq - sts)/(std + rad);
217 | else
218 | alpha = (rad - std)/dtd;
219 | daxpy_(&n, &alpha, d, &inc, s, &inc);
220 | alpha = -alpha;
221 | daxpy_(&n, &alpha, Hd, &inc, r, &inc);
222 | break;
223 | }
224 | alpha = -alpha;
225 | daxpy_(&n, &alpha, Hd, &inc, r, &inc);
226 | rnewTrnew = ddot_(&n, r, &inc, r, &inc);
227 | beta = rnewTrnew/rTr;
228 | dscal_(&n, &beta, d, &inc);
229 | daxpy_(&n, &one, r, &inc, d, &inc);
230 | rTr = rnewTrnew;
231 | }
232 |
233 | delete[] d;
234 | delete[] Hd;
235 |
236 | return(cg_iter);
237 | }
238 |
239 | double TRON::norm_inf(int n, double *x)
240 | {
241 | double dmax = fabs(x[0]);
242 | for (int i=1; i= dmax)
244 | dmax = fabs(x[i]);
245 | return(dmax);
246 | }
247 |
248 | void TRON::set_print_string(void (*print_string) (const char *buf))
249 | {
250 | tron_print_string = print_string;
251 | }
252 |
--------------------------------------------------------------------------------
/pytron/src/tron.h:
--------------------------------------------------------------------------------
#ifndef TRON_H
#define TRON_H

// Objective interface consumed by the TRON solver.  Every vector argument
// has get_nr_variable() elements.
class function
{
public:
    // Objective value at w.
    virtual double fun(double *w) = 0 ;
    // Gradient at w, written into g.
    virtual void grad(double *w, double *g) = 0 ;
    // Hessian-vector product H*s, written into Hs.
    virtual void Hv(double *s, double *Hs) = 0 ;
    // Hook invoked by TRON::tron() once per trust-region step with the
    // current iterate.
    virtual void callback(double *w) = 0 ;

    // Number of optimization variables.
    virtual int get_nr_variable(void) = 0 ;
    virtual ~function(void){}
};

// Trust-region Newton solver driven by a `function` objective.
class TRON
{
public:
    // tol      : stop when both |actual| and |predicted| reduction are
    //            <= tol * |f|.
    // gtol     : stop when the gradient norm is <= gtol.
    // max_iter : upper bound on the iteration counter.
    // fun_obj is not owned; it must stay alive while tron() runs.
    TRON(const function *fun_obj, double tol=0.1, double gtol = 0.1,
         int max_iter = 1000);
    ~TRON();

    // Minimize starting from w (updated in place); g receives the gradient.
    void tron(double *w, double *g, int verbose);
    // Redirect the diagnostic messages printed by the solver.
    void set_print_string(void (*i_print) (const char *buf));

    // Result fields written by tron().
    int n_iter;    // iteration counter at exit
    double gnorm;  // gradient norm at the final iterate
    double fun;    // objective value recorded by tron()

private:
    int trcg(double delta, double *g, double *s, double *r);
    double norm_inf(int n, double *x);

    double gtol;
    double tol;
    int max_iter;
    function *fun_obj;  // not owned
    void info(const char *fmt,...);
    void (*tron_print_string)(const char *buf);
};
#endif
41 |
--------------------------------------------------------------------------------
/pytron/src/tron_helper.cpp:
--------------------------------------------------------------------------------
1 | #include "tron_helper.h"
2 | #include
3 | #include
4 |
5 | double func_callback::fun(double *w)
6 | {
7 | double t;
8 | t = c_func(w, py_func, this->nr_variable, this->py_args);
9 | return t;
10 | }
11 |
12 | void func_callback::grad(double *w, double *g)
13 | {
14 | c_grad(w, py_grad_hess, &this->py_hess, g, this->nr_variable,
15 | this->py_args);
16 | }
17 |
18 | void func_callback::Hv(double *s, double *Hs)
19 | {
20 | c_hess(s, this->py_hess, Hs, this->nr_variable, this->py_args);
21 | }
22 |
23 | int func_callback::get_nr_variable(void)
24 | {
25 | return nr_variable;
26 | }
27 |
28 | void func_callback::callback(double *w)
29 | {
30 | if (this->py_callback != NULL)
31 | c_callback(w, this->py_callback, this->nr_variable, this->py_args);
32 | }
--------------------------------------------------------------------------------
/pytron/src/tron_helper.h:
--------------------------------------------------------------------------------
1 | #include "tron.h"
2 |
3 | typedef double (*func_cb)(double *, void *, int, void *);
4 | typedef int (*grad_cb)(double *, void *, void **, double *, int, void *);
5 | typedef int (*hess_cb)(double *, void *, double *, int, void *);
6 |
7 | class func_callback: public function {
8 |
9 | public:
10 | func_callback(double *x0, void *py_func, func_cb c_func,
11 | void *py_grad_hess, grad_cb c_grad, hess_cb c_hess,
12 | void *py_callback, func_cb c_callback, int nr_variable, void *py_args) {
13 | this->w = new double[nr_variable];
14 | this->py_func = py_func;
15 | this->py_grad_hess = py_grad_hess;
16 | this->py_callback = py_callback;
17 | this->c_func = c_func;
18 | this->c_grad = c_grad;
19 | this->c_hess = c_hess;
20 | this->c_callback = c_callback;
21 | this->nr_variable = nr_variable;
22 | this->py_args = py_args;
23 | };
24 |
25 | ~func_callback() {
26 | delete this->w;
27 | };
28 | double fun(double *w);
29 | void grad(double *w, double *g);
30 | void Hv(double *s, double *Hs);
31 | void callback(double *w);
32 | int get_nr_variable(void);
33 |
34 | protected:
35 | double tmp;
36 | double *w;
37 | func_cb c_func;
38 | grad_cb c_grad;
39 | hess_cb c_hess;
40 | func_cb c_callback;
41 | void *py_func;
42 | void *py_grad_hess;
43 | void *py_hess;
44 | void *py_args;
45 | void *py_callback;
46 | int nr_variable;
47 | };
48 |
49 |
--------------------------------------------------------------------------------
/pytron/tron.pyx:
--------------------------------------------------------------------------------
1 | from cython cimport view
2 | import numpy as np
3 | from scipy import optimize
4 | cimport numpy as np
5 | from libc cimport string
6 | from cpython cimport Py_INCREF, Py_XDECREF, PyObject
7 |
8 | cdef extern from "tron_helper.h":
9 | ctypedef double (*func_cb)(double *, void *, int, void *)
10 | ctypedef int (*grad_cb)(double *, void *, void **, double *, int, void *)
11 | ctypedef int (*hess_cb)(double *, void *, double *, int, void *)
12 | cdef cppclass func_callback:
13 | func_callback(double *, void *, func_cb,
14 | void *, grad_cb, hess_cb, void *, func_cb, int nr_variable, void *)
15 |
16 |
17 | cdef extern from "tron.h":
18 | cdef cppclass TRON:
19 | TRON(func_callback *, double, double, int)
20 | void tron(double *, double *, int)
21 | int n_iter
22 | double gnorm
23 | double fun
24 |
25 |
26 | cdef double c_func(double *w, void *f_py, int nr_variable, void *py_args):
27 | cdef view.array w0 = view.array(shape=(nr_variable,), itemsize=sizeof(double),
28 | mode='c', format='d', allocate_buffer=False)
29 | cdef double b
30 | w0.data = w
31 | # TODO: exception return value
32 | out = (