├── .gitignore ├── LICENSE ├── README.md ├── bin ├── PackedChol0.mexa64 ├── PackedChol0.mexmaci64 ├── PackedChol0.mexw64 ├── PackedChol0SSE.mexa64 ├── PackedChol0SSE.mexw64 ├── PackedChol0arm.mexmaci64 ├── PackedChol1.mexa64 ├── PackedChol1.mexw64 ├── PackedChol1SSE.mexa64 ├── PackedChol1SSE.mexw64 ├── PackedChol1arm.mexmaci64 ├── PackedChol4.mexa64 ├── PackedChol4.mexmaci64 ├── PackedChol4.mexw64 ├── PackedChol4SSE.mexa64 ├── PackedChol4SSE.mexw64 ├── PackedChol4arm.mexmaci64 ├── cpuInfo.mexa64 ├── cpuInfo.mexmaci64 └── cpuInfo.mexw64 ├── code ├── Hamiltonian.m ├── barrier │ ├── TwoSidedBarrier.m │ └── WeightedTwoSidedBarrier.m ├── default_options.m ├── diagnostics │ ├── distribution_test.m │ ├── effective_sample_size.m │ ├── rhat.m │ └── summary.m ├── gauss_legendre.m ├── implicit_midpoint.m ├── integrate.m ├── module │ ├── DebugLogger.m │ ├── DynamicRegularizer.m │ ├── DynamicStepSize.m │ ├── DynamicWeight.m │ ├── MemoryStorage.m │ ├── MixingTimeEstimator.m │ ├── ProgressBar.m │ └── Sampler.m ├── prepare │ ├── Polytope.m │ ├── analytic_center.m │ ├── gmscale.m │ ├── lewis_center.m │ └── standardize_problem.m ├── sample.m ├── solver │ ├── FeatureDetector │ │ ├── cpuInfo.cpp │ │ ├── cpu_x86.cpp │ │ ├── cpu_x86.h │ │ ├── cpu_x86_Linux.ipp │ │ ├── cpu_x86_Windows.ipp │ │ └── readme.txt │ ├── MatlabSolver.m │ ├── MexSolver.m │ ├── MultiMatlabSolver.m │ ├── PackedCSparse │ │ ├── FloatArray.h │ │ ├── FloatArrayAVX2.h │ │ ├── PackedChol.h │ │ ├── SparseMatrix.h │ │ ├── add.h │ │ ├── chol.h │ │ ├── leverage.h │ │ ├── leverageJL.h │ │ ├── multiply.h │ │ ├── outerprod.h │ │ ├── projinv.h │ │ └── transpose.h │ ├── PackedChol.cpp │ ├── RHMC_compile_all.m │ ├── Solver.m │ ├── batch_pcg.m │ ├── compile_solver.m │ ├── mex_utils.h │ └── qd │ │ ├── COPYING │ │ ├── NEWS │ │ ├── README │ │ ├── bits.cc │ │ ├── bits.h │ │ ├── c_dd.cc │ │ ├── c_dd.h │ │ ├── c_qd.cc │ │ ├── c_qd.h │ │ ├── dd_const.cc │ │ ├── dd_inline.h │ │ ├── dd_real.cc │ │ ├── dd_real.h │ │ ├── fpu.cc │ │ ├── fpu.h 
│ │ ├── inline.h │ │ ├── qd.pdf │ │ ├── qd_config.h │ │ ├── qd_const.cc │ │ ├── qd_inline.h │ │ ├── qd_real.cc │ │ ├── qd_real.h │ │ ├── util.cc │ │ └── util.h └── utils │ ├── Setfield.m │ ├── TableDisplay.m │ ├── blendv.m │ ├── dblcmp.m │ ├── nonempty.m │ ├── spdiag.m │ └── timeit.m ├── coverage ├── Recon1.mat ├── TestSuite.m ├── coverage_test.m ├── p_value_test.m ├── presolve_test.m ├── problems │ ├── .DS_Store │ ├── SOURCE.txt │ ├── basic │ │ ├── birkhoff.m │ │ ├── long_box.m │ │ ├── polytope_box.m │ │ ├── random_sparse.m │ │ ├── simplex.m │ │ ├── tv_ball.m │ │ └── tv_ball2.m │ ├── loadProblem.m │ ├── metabolic │ │ ├── Acidaminococcus_sp_D21.mat │ │ ├── Recon1.mat │ │ ├── Recon2.mat │ │ ├── Recon3.mat │ │ └── cardiac_mit_glcuptake_atpmax.mat │ ├── netlib │ │ ├── 25fv47.mat │ │ ├── 80bau3b.mat │ │ ├── afiro.mat │ │ ├── agg.mat │ │ ├── beaconfd.mat │ │ ├── blend.mat │ │ ├── degen2.mat │ │ ├── degen3.mat │ │ ├── etamacro.mat │ │ ├── scorpion.mat │ │ ├── sierra.mat │ │ └── truss.mat │ └── problemList.m ├── sample_test.m └── solver │ ├── solver_scalar_test.m │ ├── solver_simd_test.m │ ├── solver_test.m │ ├── solver_test_2.m │ └── solver_zero_test.m ├── demo.m └── initSampler.m /.gitignore: -------------------------------------------------------------------------------- 1 | *.asv 2 | *.log 3 | *.icloud 4 | *.m~ 5 | *.pdb 6 | *_ignore* 7 | .DS_Store 8 | coverage/problems/*extra/ 9 | .MATLABDriveTag 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PolytopeSampler 2 | 3 | PolytopeSampler is a `Matlab` implementation of constrained Riemannian Hamiltonian Monte Carlo for sampling from high-dimensional distributions on polytopes. It can sample efficiently from sets and distributions with more than 100K dimensions. 
4 | 5 | ## Quick Tutorial 6 | 7 | PolytopeSampler samples from distributions of the form `exp(-f(x))`, for a convex function `f`, subject to constraints `Aineq * x <= bineq`, `Aeq * x = beq` and `lb <= x <= ub`. 8 | 9 | The function `f` can be specified by arrays containing its first and second derivative or function handles. Only the first derivative is required. By default, `f` is empty, which represents a uniform distribution. If the first derivative is a function handle, then the function and its second derivatives must also be provided. 10 | 11 | To sample `N` points from a polytope `P`, you can call `sample(P, N)`. The function `sample` will 12 | 1. Find an initial feasible point 13 | 2. Run constrained Hamiltonian Monte Carlo 14 | 3. Test convergence of the sampling algorithm by computing Effective Sample Size (ESS) and terminate when `ESS >= N`. If the target distribution is uniform, a uniformity test will also be performed. 15 | 16 | Extra parameters can be set up using `opts`. Some useful parameters include `maxTime` and `maxStep`. By default, they are set to 17 | ``` 18 | maxTime: 86400 (max sampling time in seconds) 19 | maxStep: 300000 (maximum number of steps) 20 | ``` 21 | The output is a struct `o`, which stores samples generated in `o.samples` and a summary of the sample in `o.summary`. `o.samples` is an array of size `dim x #steps`. 22 | 23 | 24 | ### Example 25 | 26 | We demonstrate PolytopeSampler using a simple example, sampling uniformly from a simplex. 27 | The polytope is defined by 28 | 29 | ``` 30 | >> P = struct; 31 | >> d = 10; 32 | >> P.Aeq = ones(1, d); 33 | >> P.beq = 1; 34 | >> P.lb = zeros(d, 1); 35 | ``` 36 | The polytope has dimension `d = 10` with constraint `sum_i x_i = 1` and `x >= 0`. This is a simplex. 37 | To generate 200 samples uniformly from the polytope `P`, we call the function `sample()`. 
38 | ``` 39 | >> o = sample(P, 200); 40 | Time spent | Time reamin | Progress | Samples | AccProb | StepSize | MixTime 41 | 00d:00:00:01 | 00d:00:00:00 | ######################### | 211/200 | 0.989903 | 0.200000 | 11.2 42 | Done! 43 | ``` 44 | We can access the samples generated using 45 | ``` 46 | >> o.samples 47 | ``` 48 | We can print a summary of the samples: 49 | ``` 50 | >> o.summary 51 | 52 | ans = 53 | 54 | 10×7 table 55 | 56 | mean std 25% 50% 75% n_ess r_hat 57 | ________ ________ ________ ________ _______ ______ _______ 58 | 59 | samples[1] 0.093187 0.091207 0.026222 0.064326 0.13375 221.51 0.99954 60 | samples[2] 0.092815 0.086905 0.027018 0.066017 0.13221 234.59 1.0301 61 | samples[3] 0.10034 0.090834 0.030968 0.075631 0.13788 216.56 1.0159 62 | samples[4] 0.10531 0.092285 0.035363 0.077519 0.1481 235.25 1.0062 63 | samples[5] 0.10437 0.087634 0.034946 0.080095 0.1533 212.54 0.99841 64 | samples[6] 0.1029 0.093724 0.028774 0.074354 0.15135 227.6 1.0052 65 | samples[7] 0.1042 0.083084 0.038431 0.081964 0.15352 231.54 1.0008 66 | samples[8] 0.088778 0.086902 0.025565 0.062473 0.11837 229.69 1.0469 67 | samples[9] 0.10627 0.09074 0.036962 0.084294 0.15125 211.64 0.99856 68 | samples[10] 0.10184 0.084699 0.035981 0.074923 0.14578 230.63 1.0277 69 | ``` 70 | `n_ess` shows the effective sample size of the samples generated. 71 | `r_hat` tests the convergence of the sampling algorithm. 72 | A value of `r_hat` close to 1 indicates that the algorithm has converged properly. 73 | 74 | See `demo.m` for more examples, including examples of sampling from non-uniform distributions. 
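
The options and the non-uniform case described above can be combined in one short sketch, alongside the examples in `demo.m`. It assumes that `initSampler` (shipped at the repository root) has already been run so that `sample` and `default_options` are on the path, that `sample` accepts `opts` as a third argument (as the call in `code/integrate.m` does), and that a vector-valued `P.df` encodes the linear function `f(x) = df' * x`. The cost vector `c` and the two limits are illustrative values, not recommendations.

```matlab
% Sketch only: assumes initSampler has already put the sampler on the path.
d = 10;
P = struct;
P.Aeq = ones(1, d);      % sum_i x_i = 1
P.beq = 1;
P.lb  = zeros(d, 1);     % x >= 0

% Non-uniform target exp(-c' * x): pass the first derivative as a vector.
c = (1:d)';              % illustrative cost vector (assumption)
P.df = c;

% Start from the defaults and tighten the two limits named above.
opts = default_options();
opts.maxTime = 600;      % stop after 10 minutes (illustrative value)
opts.maxStep = 50000;    % or after 50000 steps (illustrative value)

o = sample(P, 200, opts);

size(o.samples)          % dim x #steps
o.summary                % per-coordinate table, including r_hat
```

With this target, coordinates with a larger `c(i)` should show a smaller mean in `o.summary`, which gives a quick sanity check that the density was picked up.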
75 | -------------------------------------------------------------------------------- /bin/PackedChol0.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol0.mexa64 -------------------------------------------------------------------------------- /bin/PackedChol0.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol0.mexmaci64 -------------------------------------------------------------------------------- /bin/PackedChol0.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol0.mexw64 -------------------------------------------------------------------------------- /bin/PackedChol0SSE.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol0SSE.mexa64 -------------------------------------------------------------------------------- /bin/PackedChol0SSE.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol0SSE.mexw64 -------------------------------------------------------------------------------- /bin/PackedChol0arm.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol0arm.mexmaci64 
-------------------------------------------------------------------------------- /bin/PackedChol1.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol1.mexa64 -------------------------------------------------------------------------------- /bin/PackedChol1.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol1.mexw64 -------------------------------------------------------------------------------- /bin/PackedChol1SSE.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol1SSE.mexa64 -------------------------------------------------------------------------------- /bin/PackedChol1SSE.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol1SSE.mexw64 -------------------------------------------------------------------------------- /bin/PackedChol1arm.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol1arm.mexmaci64 -------------------------------------------------------------------------------- /bin/PackedChol4.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol4.mexa64 
-------------------------------------------------------------------------------- /bin/PackedChol4.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol4.mexmaci64 -------------------------------------------------------------------------------- /bin/PackedChol4.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol4.mexw64 -------------------------------------------------------------------------------- /bin/PackedChol4SSE.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol4SSE.mexa64 -------------------------------------------------------------------------------- /bin/PackedChol4SSE.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol4SSE.mexw64 -------------------------------------------------------------------------------- /bin/PackedChol4arm.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/PackedChol4arm.mexmaci64 -------------------------------------------------------------------------------- /bin/cpuInfo.mexa64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/cpuInfo.mexa64 
-------------------------------------------------------------------------------- /bin/cpuInfo.mexmaci64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/cpuInfo.mexmaci64 -------------------------------------------------------------------------------- /bin/cpuInfo.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/bin/cpuInfo.mexw64 -------------------------------------------------------------------------------- /code/Hamiltonian.m: -------------------------------------------------------------------------------- 1 | classdef Hamiltonian < handle 2 | % H(x, v) = U(x) + K(x,v) 3 | % where 4 | % U(x) = f(x) + 1/2 (log det g + log det A g^-1 A') 5 | % K(x,v) = 1/2 v' (g^-1 - g^-1 A'(A g^-1 A')^-1 A g^-1) v 6 | 7 | properties 8 | A % constraint matrix A 9 | b % constraint vector b 10 | f % the objective function and its derivatives in the original space 11 | barrier % TwoSidedBarrier 12 | P % domain 13 | lsc 14 | df 15 | ddf 16 | solver 17 | opts 18 | end 19 | 20 | % Dependent Variables 21 | properties 22 | m 23 | n 24 | k 25 | x 26 | fx % f(Tx + y) 27 | dfx % T' * f'(Tx + y) 28 | hess 29 | prepared 30 | last_dUdx = [] 31 | end 32 | 33 | methods 34 | function o = Hamiltonian(P, opts) 35 | P.set_vdim(2); 36 | 37 | o.P = P; 38 | o.A = P.A; 39 | o.b = P.b; 40 | o.f = P.f; 41 | o.df = P.df; 42 | o.ddf = P.ddf; 43 | o.opts = opts; 44 | o.m = size(P.A,1); 45 | o.n = size(P.A,2); 46 | o.k = opts.simdLen; 47 | o.x = randn(o.k, o.n); 48 | o.barrier = P.barrier; 49 | o.solver = Solver(P.A, opts.solverThreshold, o.k); 50 | o.prepared = false; 51 | end 52 | 53 | % when we prepare 54 | function prepare(o, x) 55 | o.move(x); 56 | if ~o.prepared 57 | o.solver.setScale(1./o.hess); 58 
| o.last_dUdx = []; 59 | end 60 | o.prepared = true; 61 | end 62 | 63 | % Resample v = g^{1/2} * N(0, I_d) 64 | function v = resample(o, x, v, momentum) 65 | if nargin == 3, momentum = 0; end 66 | o.move(x); 67 | sqrtHess = sqrt(o.hess); 68 | v = sqrt(momentum) * v + sqrt(1-momentum) * (sqrtHess .* randn(o.k, o.n)); 69 | end 70 | 71 | % Compute H(x,v) 72 | function E = H(o, x, v) 73 | o.prepare(x) 74 | K = 0.5 * sum(v .* o.DK(x, v),2); 75 | U = 0.5 * (o.solver.logdet() + sum(log(o.hess),2)); 76 | U = U + o.fx; 77 | E = U + K; 78 | end 79 | 80 | function dUdx = DU(o, x) 81 | o.move(x); 82 | if ~o.prepared || isempty(o.last_dUdx) 83 | o.prepare(x); 84 | o.lsc = o.solver.leverageScoreComplement(o.opts.nSketch); 85 | o.last_dUdx = o.barrier.tensor(x) .* o.lsc ./ (2*o.hess) + o.dfx; 86 | end 87 | dUdx = o.last_dUdx; 88 | end 89 | 90 | % Project to Ax = b 91 | function x = project(o, x) 92 | o.move(x); 93 | % col vector: x = x + step * (o.A' * o.solver.approxSolve(o.b - o.A*x))./o.hess; 94 | x = x + (o.solver.approxSolve(o.b' - x*o.A') * o.A)./o.hess; 95 | end 96 | 97 | % Compute dK/dv = (g^-1 - g^-1 A'(A g^-1 A')^-1 A g^-1) v and 98 | % dK/dx = -Dg[dK/dv,dK/dv]/2 99 | function [dKdv, dKdx] = DK(o, x, v) 100 | o.move(x); 101 | invHessV = v./o.hess; 102 | % col vector: dKdv = invHessV - (o.A' * o.solver.solve(o.A * invHessV))./o.hess; 103 | dKdv = invHessV - (o.solver.solve(invHessV * o.A') * o.A)./o.hess; 104 | if nargout > 1 105 | dKdx = -o.barrier.quadratic_form_gradient(x, dKdv)/2; 106 | end 107 | end 108 | 109 | % Approximate dK/dv = (g^-1 - g^-1 A'(A g^-1 A')^-1 A g^-1) v and 110 | % dK/dx = -Dg[dK/dv,dK/dv]/2 111 | function [dKdv, dKdx, nu] = approxDK(o, x, v, nu) 112 | o.move(x); 113 | % col vector: dUdv_b = o.A * ((v - o.A' * nu)./o.hess); 114 | dUdv_b = ((v - nu * o.A)./o.hess) * o.A'; 115 | nu = nu + o.solver.approxSolve(dUdv_b); 116 | % col vector: dKdv = (v - o.A' * nu)./o.hess; 117 | dKdv = (v - nu * o.A)./o.hess; 118 | dKdx = 
-o.barrier.quadratic_form_gradient(x, dKdv)/2; 119 | end 120 | 121 | function t = step_size(o, x, dx) 122 | t1 = o.barrier.step_size(x, dx); 123 | t2 = 1 / max(sqrt(o.barrier.hessian_norm(x, dx)),[],2); 124 | t = min(t1,t2); 125 | end 126 | 127 | % Test if the values of x and v are valid and if x is feasible 128 | function r = feasible(o, x, v) 129 | r = ~any(isnan(x),2) & o.barrier.feasible(x); 130 | 131 | if nargin == 3 132 | r = r & ~any(isnan(v), 2); 133 | end 134 | end 135 | 136 | function r = v_norm(o, x, dv) 137 | o.move(x); 138 | r = sum((dv .* dv) ./ o.hess,2); 139 | end 140 | 141 | function r = x_norm(o, x, dx) 142 | o.move(x); 143 | r = sum((dx .* dx) .* o.hess,2); 144 | end 145 | 146 | function move(o, x, forceUpdate) 147 | if nargin == 2, forceUpdate = false; end 148 | if ~all(size(o.x) == size(x)), return; end 149 | if ~forceUpdate && all(o.x == x, 'all'), return; end 150 | 151 | 152 | o.x = x; 153 | [o.fx, o.dfx, ddfx] = o.P.f_oracle(x); 154 | o.hess = o.barrier.hessian(x) + ddfx; 155 | o.prepared = false; 156 | end 157 | end 158 | end -------------------------------------------------------------------------------- /code/barrier/TwoSidedBarrier.m: -------------------------------------------------------------------------------- 1 | classdef TwoSidedBarrier < handle 2 | % The log barrier for the domain {lb <= x <= ub}: 3 | % phi(x) = - sum log(x - lb) - sum log(ub - x). 4 | properties (SetAccess = private) 5 | ub % Upper bound 6 | lb % Lower bound 7 | vdim % Each point is stored along the dimension vdim 8 | n % Number of variables 9 | upperIdx % Indices where lb == -Inf 10 | lowerIdx % Indices where ub == Inf 11 | freeIdx % Indices where ub == Inf and lb == -Inf 12 | center % Some feasible point x 13 | end 14 | 15 | properties 16 | extraHessian = 0 % Extra factor added when computing Hessian. Used to handle free constraints. 17 | end 18 | 19 | methods 20 | function o = TwoSidedBarrier(lb, ub, vdim) 21 | % o = TwoSidedBarrier(lb, ub, vdim) 22 | % Construct the barrier for the bounds lb and ub. 
23 | 24 | if nargin < 3, vdim = 1; end 25 | o.set_bound(lb, ub); 26 | o.vdim = vdim; 27 | end 28 | 29 | function set_bound(o, lb, ub) 30 | % o.set_bound(lb, ub) 31 | % Update the bounds lb and ub. 32 | 33 | o.n = length(lb); 34 | assert(numel(lb) == o.n); 35 | assert(numel(ub) == o.n); 36 | assert(all(lb < ub)); 37 | 38 | o.lb = lb; 39 | o.ub = ub; 40 | o.upperIdx = find(o.lb == -Inf); 41 | o.lowerIdx = find(o.ub == Inf); 42 | o.freeIdx = find((o.lb == -Inf) & (o.ub == Inf)); 43 | 44 | c = (o.ub+o.lb)/2; 45 | c(o.lowerIdx) = o.lb(o.lowerIdx) + 1e6; 46 | c(o.upperIdx) = o.ub(o.upperIdx) - 1e6; 47 | c(o.freeIdx) = 0; 48 | o.center = c; 49 | end 50 | 51 | function set_vdim(o, vdim) 52 | % o.set_vdim(vdim) 53 | % Update the storage dimension vdim. 54 | 55 | assert(vdim == 1 || vdim == 2); 56 | 57 | o.vdim = vdim; 58 | if vdim == 1 59 | o.lb = reshape(o.lb, [o.n, 1]); 60 | o.ub = reshape(o.ub, [o.n, 1]); 61 | o.center = reshape(o.center, [o.n, 1]); 62 | else 63 | o.lb = reshape(o.lb, [1, o.n]); 64 | o.ub = reshape(o.ub, [1, o.n]); 65 | o.center = reshape(o.center, [1, o.n]); 66 | end 67 | end 68 | 69 | function r = feasible(o, x) 70 | % r = o.feasible(x) 71 | % Output whether x is feasible. 72 | 73 | r = all((x > o.lb) & (x < o.ub), o.vdim); 74 | end 75 | 76 | function t = step_size(o, x, v) 77 | % t = o.step_size(x, v) 78 | % Output the maximum step size from x with direction v. 
79 | 80 | max_step = 1e16; % largest step size 81 | if (o.vdim == 2) 82 | max_step = max_step * ones(size(x,1),1); 83 | end 84 | 85 | % check positive direction 86 | posIdx = v > 0; 87 | t1 = min((o.ub(posIdx) - x(posIdx))./v(posIdx), [], o.vdim); 88 | if isempty(t1), t1 = max_step; end 89 | 90 | % check negative direction 91 | negIdx = v < 0; 92 | t2 = min((o.lb(negIdx) - x(negIdx))./v(negIdx), [], o.vdim); 93 | if isempty(t2), t2 = max_step; end 94 | 95 | t = min(min(t1, t2), max_step); 96 | end 97 | 98 | function [A, b] = boundary(o, x) 99 | % [A, b] = o.boundary(x) 100 | % Output the normal at the boundary around x for each barrier. 101 | % Assume: only 1 vector is given 102 | 103 | assert(size(x, 3-o.vdim) == 1); 104 | 105 | c = o.center; 106 | 107 | b = o.ub; 108 | b(x o.lb) & (x < o.ub), o.vdim); 76 | end 77 | 78 | function t = step_size(o, x, v) 79 | % t = o.stepsize(x, v) 80 | % Output the maximum step size from x with direction v. 81 | 82 | max_step = 1e16; % largest step size 83 | if (o.vdim == 2) 84 | max_step = max_step * ones(size(x,1),1); 85 | end 86 | 87 | % check positive direction 88 | posIdx = v > 0; 89 | t1 = min((o.ub(posIdx) - x(posIdx))./v(posIdx), [], o.vdim); 90 | if isempty(t1), t1 = max_step; end 91 | 92 | % check negative direction 93 | negIdx = v < 0; 94 | t2 = min((o.lb(negIdx) - x(negIdx))./v(negIdx), [], o.vdim); 95 | if isempty(t2), t2 = max_step; end 96 | 97 | t = min(min(t1, t2), max_step); 98 | end 99 | 100 | function [A, b] = boundary(o, x) 101 | % [A, b] = o.boundary(x) 102 | % Output the normal at the boundary around x for each barrier. 
103 | % Assume: only 1 vector is given 104 | 105 | assert(size(x, 3-o.vdim) == 1); 106 | 107 | c = o.center; 108 | 109 | b = o.ub; 110 | b(x opts.tol; 74 | r1 = min(pUbS(posIdx)./u(posIdx)); 75 | 76 | % check the constraint x >= P.lb 77 | negIdx = u < -opts.tol; 78 | r2 = min(pLbS(negIdx)./(u(negIdx))); 79 | 80 | % check the constraint P.Aineq * x <= P.bineq 81 | Au = P.Aineq * u; 82 | posAIdx = (Au > opts.tol); 83 | r3 = min(pIneqS(posAIdx)./(Au(posAIdx))); 84 | 85 | r = min([r1, r2, r3]); 86 | if (vectorMode) 87 | a = u' * P.df; 88 | else 89 | xt = P.f(x_i); 90 | xt = xt + exprnd(1,size(xt,1),size(xt,2)); 91 | ut = xt - pt; 92 | g = @(t) P.f(p + t * u) - (pt + t * ut); 93 | r4 = binary_search(g, 1, r, 1e-12); 94 | r = min(r, r4); 95 | 96 | a = sum(ut); 97 | end 98 | 99 | if (a * r > dim) 100 | v = gammainc(a, dim) / gammainc(a * r, dim); 101 | else 102 | v = scaledgammainc(a, dim) / scaledgammainc(a * r, dim); 103 | v = v * exp(a * (r-1) - dim * log(r)); 104 | end 105 | unif_vals(i) = real(v); 106 | end 107 | 108 | % To ensure pVal1 and pVal2 are independent, 109 | % we run adtest and kstest on disjoint subsets 110 | try 111 | [~,pVal1] = adtest(norminv(unif_vals(1:floor(N/2)))); 112 | catch 113 | pVal1 = 0; 114 | end 115 | 116 | try 117 | [~,pVal2] = kstest(norminv(unif_vals((floor(N/2)+1):end))); 118 | catch 119 | pVal2 = 0; 120 | end 121 | 122 | % We merge the two independent p-values into one p-value 123 | z = pVal1 * pVal2; 124 | pVal = z - z * log(z); 125 | 126 | if opts.toPlot 127 | figure; 128 | cdfplot(unif_vals); 129 | hold on; 130 | plot(0:0.01:1, 0:0.01:1, '.') 131 | title(sprintf('Empirical CDF of the radial distribution (pval = %.4f)', pVal)) 132 | end 133 | end 134 | 135 | % find maximum t in [a, b] such that g(t) <= 0 136 | % assume g(a) <= 0 137 | function t = binary_search(g, a, b, tol) 138 | assert(all(g(a) <= 0) && b >= a); 139 | if all(g(b)<=0) 140 | t = b; 141 | else 142 | while (b > a + tol) 143 | m = (b+a)/2; 144 | if all(g(m)<=0) 145 | a = m; 146 | else 
147 | b = m; 148 | end 149 | end 150 | t = (a+b)/2; 151 | end 152 | end 153 | 154 | function t = scaledgammainc(a, d) 155 | assert(a <= d); 156 | if (a >= -d-20) 157 | t = gammainc(a, d, 'scaledlower'); 158 | else 159 | t = (d)/(d-1-a)-(d*(d-1))/(d-1-a)^3 - 2*d*(d-1)/(d-1-a)^4 + 3 * d * (d-1)^2 / (d-1-a)^5; 160 | % this case should not happen. But I am adding this to avoid NaN 161 | end 162 | end -------------------------------------------------------------------------------- /code/diagnostics/effective_sample_size.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/code/diagnostics/effective_sample_size.m -------------------------------------------------------------------------------- /code/diagnostics/rhat.m: -------------------------------------------------------------------------------- 1 | function [rhat_val] = rhat(x) 2 | %rhat_val = rhat(x) 3 | %compute rhat of each parameter in the matrix x 4 | % 5 | %Input: 6 | % x - a dim x N matrix, where N is the length of the chain. 7 | % 8 | %Output: 9 | % rhat_val - a dim x 1 vector where rhat_val(i) is the rhat statistic of x(i,:). 10 | 11 | % TODO: this is not a correct algorithm. Fix it. 12 | % Now the input is always 13 | % cell of matrices dim x N. Each matrix is a chain. 14 | % What is the correct formula in this case? 15 | % Update the comment also. It is wrong now. 
16 | 17 | [~, N] = size(x); 18 | if mod(N, 2) == 1 19 | x = x(:, 1:N-1); 20 | N = N-1; 21 | end 22 | 23 | N = N/2; 24 | y = x(:, 1: N); 25 | x = x(:, N+1: 2*N); 26 | 27 | b_div_n = var([mean(x, 2), mean(y, 2)], 0, 2); 28 | w = mean([var(x, 0, 2), var(y, 0, 2)], 2) + eps; 29 | 30 | sig_2p = (N-1)/N .* w + b_div_n; 31 | 32 | rhat_val = 3/2 * sig_2p ./w - (N-1)/ (2*N); 33 | end -------------------------------------------------------------------------------- /code/diagnostics/summary.m: -------------------------------------------------------------------------------- 1 | function smry = summary(o) 2 | %smry = summary(o) 3 | %compute summary of samples 4 | % 5 | %Input: 6 | % o - the samples object outputted by sample 7 | % 8 | %Output: 9 | % smry - a table summarizing the mean and rhat of samples. 10 | 11 | % compute ess by summing over all chains 12 | 13 | % compute the rest of the summary 14 | rh = rhat(o.samples); 15 | st = std(o.samples, 0, 2); 16 | m = mean(o.samples,2); 17 | d = size(o.samples, 1); 18 | 19 | if isempty(ver('stats')) 20 | smry = table(m, st, rh, 'VariableNames',... 21 | {'mean', 'std', 'r_hat'}, ... 22 | 'RowNames', 'samples['+string(1:d)+']'); 23 | else 24 | Y = prctile(o.samples, [25 50 75], 2); 25 | per25 = Y(:, 1); 26 | per50 = Y(:, 2); 27 | per75 = Y(:, 3); 28 | 29 | smry = table(m, st, per25, per50, per75, rh, 'VariableNames',... 30 | {'mean', 'std', 'Q1', 'Q2', 'Q3', 'r_hat'}, ... 
31 | 'RowNames', 'samples['+string(1:d)+']'); 32 | end 33 | end -------------------------------------------------------------------------------- /code/gauss_legendre.m: -------------------------------------------------------------------------------- 1 | function [x4, v4, step] = gauss_legendre(x0, v0, h, ham, opts) 2 | % Step 1 3 | x1 = x0; 4 | v1 = v0 - (h/2) * ham.DU(x0); 5 | done = 0; 6 | 7 | % Step 2 8 | c1 = h/4; c2 = h * (1/4-sqrt(3)/6); 9 | c3 = h * (1/4+sqrt(3)/6); c4 = h/4; 10 | nu1 = zeros(size(x1,1),size(ham.A,1),1); nu2 = zeros(size(x1,1),size(ham.A,1),1); 11 | k1x = zeros(size(x1)); k1v = zeros(size(v1)); 12 | k2x = zeros(size(x1)); k2v = zeros(size(v1)); 13 | 14 | for step = 1:opts.maxODEStep 15 | k2x_old = k2x; 16 | 17 | [k1x, k1v, nu1] = ham.approxDK... 18 | (x1 + c1 * k1x + c2 * k2x, ... 19 | v1 - c1 * k1v - c2 * k2v, nu1); 20 | 21 | [k2x, k2v, nu2] = ham.approxDK... 22 | (x1 + c3 * k1x + c4 * k2x, ... 23 | v1 - c3 * k1v - c4 * k2v, nu2); 24 | 25 | dist = ham.x_norm(x1, k2x_old-k2x); 26 | if (max(dist,[],'all') < opts.implicitTol) 27 | done = 1; 28 | break; 29 | elseif any(dist > 1e16, 'all') 30 | break; 31 | end 32 | end 33 | 34 | if done == 0 35 | x4 = NaN; v4 = NaN; 36 | return 37 | end 38 | 39 | x2 = x1 + h/2 * (k1x + k2x); 40 | v2 = v1 - h/2 * (k1v + k2v); 41 | 42 | % Step 3 43 | x3 = x2; 44 | v3 = v2 - (h/2) * ham.DU(x3); 45 | 46 | % Step 4 (Project to Ax = b) 47 | ham.prepare(x3); 48 | v4 = v3; 49 | x4 = ham.project(x3); 50 | end 51 | -------------------------------------------------------------------------------- /code/implicit_midpoint.m: -------------------------------------------------------------------------------- 1 | function [x4, v4, step] = implicit_midpoint(x0, v0, h, ham, opts) 2 | % Step 1 3 | x1 = x0; 4 | v1 = v0 - (h/2) * ham.DU(x0); 5 | done = 0; 6 | 7 | % Step 2 8 | x2 = x1; v2 = v1; 9 | nu = zeros(size(x1,1),size(ham.A,1),1); 10 | for step = 1:opts.maxODEStep 11 | x2_old = x2; 12 | xmid = (x1+x2)/2; 13 | vmid = (v1+v2)/2; 14 
| [dKdv, dKdx, nu] = ham.approxDK(xmid, vmid, nu); 15 | x2 = x1 + h * dKdv; 16 | v2 = v1 - h * dKdx; 17 | dist = ham.x_norm(xmid, x2-x2_old)/ h; 18 | if (max(dist,[],'all') < opts.implicitTol) 19 | done = 1; 20 | break; 21 | elseif any(dist > 1e16, 'all') 22 | break; 23 | end 24 | end 25 | 26 | if done == 0 27 | x4 = NaN; v4 = NaN; 28 | return 29 | end 30 | 31 | % Step 3 32 | x3 = x2; 33 | v3 = v2 - (h/2) * ham.DU(x3); 34 | 35 | % Step 4 (Project to Ax = b) 36 | ham.prepare(x3); 37 | v4 = v3; 38 | x4 = ham.project(x3); 39 | end 40 | -------------------------------------------------------------------------------- /code/integrate.m: -------------------------------------------------------------------------------- 1 | function W = integrate(problem, eps, opts, B, k) 2 | %Input: a structure problem with the following fields 3 | % .Aineq 4 | % .bineq 5 | % .Aeq 6 | % .beq 7 | % .lb 8 | % .ub 9 | % .f 10 | % .df 11 | % .ddf 12 | % describing a logconcave distribution given by 13 | % exp(-sum f_i(x_i)) 14 | % over 15 | % {Aineq x <= bineq, Aeq x = beq, lb <= x <= ub} 16 | % where f is given together with its first (df) and second (ddf) 17 | % derivatives. 18 | % 19 | % Case 1: df is not defined 20 | % f(x) = 0. 21 | % In this case, f and ddf must be empty. This is the uniform distribution, 22 | % and the feasible region must be bounded. 23 | % 24 | % Case 2: df is a vector 25 | % f_i(x_i) = df_i x_i. 26 | % In this case, f and ddf must be empty. 27 | % 28 | % Case 3: df is a function handle 29 | % f needs to be defined as a function handle. 30 | % df needs to be the derivative of f. 31 | % ddf is optional. Providing ddf could improve the mixing time. 32 | % 33 | % eps - multiplicative error of the integral 34 | % 35 | % opts - sampling options 36 | % 37 | % B - optional. Upper bound on the quantity max(f) - min(f) 38 | % 39 | % k - optional. Number of samples in each iteration. If k is provided, eps 40 | % error is not guaranteed. 
41 | % 42 | %Output: 43 | % W - the integral of exp(-f) restricted to the polytope 44 | % {Aineq x <= bineq, Aeq x = beq, lb <= x <= ub} 45 | 46 | %% Initialize parameters and compile if needed 47 | if (nargin <= 2) 48 | opts = default_options(); 49 | else 50 | opts = Setfield(default_options(), opts); 51 | end 52 | 53 | compile_solver(0); compile_solver(opts.simdLen); 54 | 55 | %% Presolve 56 | if isempty(opts.seed) 57 | opts.seed = randi(2^31); 58 | end 59 | 60 | if ischar(opts.logging) || isstring(opts.logging) % logging for Polytope 61 | fid = fopen(opts.logging, 'a'); 62 | opts.presolve.logFunc = @(tag, msg) fprintf(fid, '%s', msg); 63 | elseif ~isempty(opts.logging) 64 | opts.presolve.logFunc = opts.logging; 65 | else 66 | opts.presolve.logFunc = @(tag, msg) 0; 67 | end 68 | 69 | polytope = Polytope(problem, opts); 70 | 71 | if ischar(opts.logging) || isstring(opts.logging) 72 | fclose(fid); 73 | end 74 | 75 | d = polytope.n; 76 | % fprintf('d is %d\n', d); 77 | if (nargin <= 3) 78 | B = 2 * d; 79 | end 80 | num_of_iter = ceil(sqrt(d)*log(B)); 81 | 82 | if (nargin <= 4) 83 | k = 512/eps^2 * sqrt(d)*log(B); 84 | end 85 | 86 | W = 1; 87 | a = 0; 88 | f = polytope.f; 89 | df = polytope.df; 90 | ddf = polytope.ddf; 91 | 92 | density = @(x) exp(-f(x)); 93 | 94 | problem.f = []; 95 | problem.df = []; 96 | problem.ddf = []; 97 | 98 | for i = 1:num_of_iter-1 99 | 100 | o = sample(problem, k, opts); 101 | W = W * mean(density(o.samples).^(1/B *(1+1/sqrt(d))^i - a)); 102 | % fprintf('W after iteration %d is %d\n', i, W); 103 | clear o 104 | a = 1/B * (1+1/sqrt(d))^i; 105 | % fprintf('a after iteration %d is %d\n', i, a); 106 | problem.f = @(x) a * f(x); 107 | problem.df = @(x) a * df(x); 108 | if isa(ddf, 'function_handle') 109 | problem.ddf = @(x) a * ddf(x); 110 | end 111 | end 112 | 113 | 114 | o = sample(problem, k, opts); 115 | W = W * mean(density(o.samples).^(1 - a)); 116 | clear o 117 | 118 | 
-------------------------------------------------------------------------------- /code/module/DebugLogger.m: -------------------------------------------------------------------------------- 1 | classdef DebugLogger < handle 2 | % Module that writes the following information to the sample output 3 | % - acceptedStep, the number of accepted steps for each chain 4 | % - totalStep, the total number of steps taken 5 | % - numCholesky(1:end-1), the number of high precision Cholesky decompositions we performed 6 | % - numCholesky(end), the number of Cholesky decompositions (in any precision) we performed 7 | % - sampler, the sampler object we used 8 | % 9 | % If the sampler fails, dump all variables to 'dump_#_ignore.mat' 10 | methods 11 | function o = DebugLogger(s) 12 | s.output.averageAccuracy = 0; 13 | end 14 | 15 | function o = step(o, s) 16 | s.output.averageAccuracy = s.output.averageAccuracy + s.ham.solver.accuracy; 17 | if (any(s.ham.solver.accuracy > s.opts.solverThreshold)) 18 | s.log('Solver:inaccurate', 'low double accuracy %s.\n', num2str(s.ham.solver.accuracy')); 19 | end 20 | end 21 | 22 | function o = finalize(o, s) 23 | s.output.acceptedStep = s.acceptedStep; 24 | s.output.totalStep = s.i; % ignore the warm up phase 25 | s.output.numCholesky = s.ham.solver.getDecomposeCount(); 26 | s.output.averageAccuracy = s.output.averageAccuracy / s.i; 27 | s.output.sampler = s; 28 | if s.terminate == 3 29 | save(sprintf('dump_%i_ignore.mat', s.nWorkers)); 30 | end 31 | end 32 | end 33 | end -------------------------------------------------------------------------------- /code/module/DynamicRegularizer.m: -------------------------------------------------------------------------------- 1 | classdef DynamicRegularizer < handle 2 | % Module for updating the extra term we add to the barrier 3 | % This is necessary for any polytope with free variables 4 | properties 5 | bound = 1 6 | end 7 | 8 | methods 9 | function o = DynamicRegularizer(s) 10 | s.ham.barrier.extraHessian = 
1; 11 | end 12 | 13 | function o = step(o, s) 14 | o.bound = max(max(abs(s.x), 1), o.bound); 15 | % the s.ham.n factor is due to the tv_ball example 16 | if (~s.freezed) 17 | idx = find(2./(o.bound.*o.bound) < s.ham.n * s.ham.barrier.extraHessian, 1); 18 | if ~isempty(idx) 19 | idx = find(1./(o.bound.*o.bound) < s.ham.n * s.ham.barrier.extraHessian); 20 | s.ham.barrier.extraHessian = 0.5./(s.ham.n * o.bound.*o.bound); 21 | s.ham.move(s.x, true); % Update the cache used in barrier 22 | s.v = s.ham.resample(s.x, zeros(size(s.x))); 23 | s.log('DynamicRegularizer:set_bound', 'The bound of %i coordinates are changed.\n', length(idx)); 24 | end 25 | end 26 | end 27 | end 28 | end -------------------------------------------------------------------------------- /code/module/DynamicStepSize.m: -------------------------------------------------------------------------------- 1 | classdef DynamicStepSize < handle 2 | % Module for dynamically choosing the step size 3 | properties 4 | opts 5 | 6 | consecutiveBadStep 7 | iterSinceShrink 8 | rejectSinceShrink 9 | ODEStepSinceShrink 10 | effectiveStep 11 | warmupFinished = false 12 | end 13 | 14 | methods 15 | function o = DynamicStepSize(s) 16 | o.opts = s.opts.DynamicStepSize; 17 | o.consecutiveBadStep = 0; 18 | o.iterSinceShrink = 0; 19 | o.rejectSinceShrink = 0; 20 | o.ODEStepSinceShrink = 0; 21 | o.effectiveStep = 0; 22 | 23 | if o.opts.warmUpStep > 0 24 | s.stepSize = 1e-3; 25 | else 26 | o.warmupFinished = true; 27 | end 28 | end 29 | 30 | function o = step(o, s) 31 | % Warmup phase 32 | bad_step = s.prob < 0.5 | s.ODEStep == s.opts.maxODEStep; 33 | o.consecutiveBadStep = bad_step .* o.consecutiveBadStep + bad_step; 34 | warmupRatio = mean(s.nEffectiveStep) / o.opts.warmUpStep; 35 | 36 | if warmupRatio < 1 && ~o.warmupFinished && max(o.consecutiveBadStep) < o.opts.maxConsecutiveBadStep 37 | s.stepSize = s.opts.initalStepSize * min(warmupRatio+1e-2, 1); 38 | s.momentum = 1 - min(1, s.stepSize / s.opts.effectiveStepSize); 39 | 
return; 40 | end 41 | 42 | if (~o.warmupFinished) 43 | s.i = 1; 44 | s.acceptedStep = 0; 45 | s.nEffectiveStep = 0; 46 | s.chains = zeros(s.opts.simdLen, s.ham.n, 0); 47 | o.warmupFinished = true; 48 | end 49 | 50 | o.iterSinceShrink = o.iterSinceShrink + 1; 51 | o.rejectSinceShrink = o.rejectSinceShrink + 1-s.prob; 52 | o.ODEStepSinceShrink = o.ODEStepSinceShrink + s.ODEStep; 53 | 54 | % Shrink the step during the warmup phase 55 | if (~s.freezed) 56 | shrink = 0; 57 | shiftedIter = o.iterSinceShrink + 20 / (1-s.momentum); 58 | 59 | targetProbability = (1-s.momentum)^(2/3)/4; 60 | if (max(o.rejectSinceShrink) > targetProbability * shiftedIter) 61 | shrink = sprintf('Failure Probability is %.4f, which is larger than the target %.4f', max(o.rejectSinceShrink) / o.iterSinceShrink, targetProbability); 62 | end 63 | 64 | if (max(o.consecutiveBadStep) > o.opts.maxConsecutiveBadStep) 65 | shrink = sprintf('Consecutive %i Bad Steps', max(o.consecutiveBadStep)); 66 | end 67 | 68 | if (max(o.ODEStepSinceShrink) > o.opts.targetODEStep * shiftedIter) 69 | shrink = sprintf('ODE solver requires %.4f steps in average, which is larger than the target %.4f', max(o.ODEStepSinceShrink) / o.iterSinceShrink, o.opts.targetODEStep); 70 | end 71 | 72 | if ischar(shrink) 73 | o.iterSinceShrink = 0; 74 | o.rejectSinceShrink = 0; 75 | o.ODEStepSinceShrink = 0; 76 | o.consecutiveBadStep = 0; 77 | 78 | s.stepSize = s.stepSize / o.opts.shrinkFactor; 79 | s.momentum = 1 - min(0.999, s.stepSize / s.opts.effectiveStepSize); 80 | 81 | s.log('DynamicStepSize:step', 'Step shrinks to h = %f due to %s\n', s.stepSize, shrink); 82 | 83 | if s.stepSize < o.opts.minStepSize 84 | s.log('warning', 'Algorithm fails to converge even with step size h = %f.\n', s.stepSize); 85 | s.terminate = 3; 86 | end 87 | end 88 | 89 | o.iterSinceShrink = o.iterSinceShrink + 1; 90 | elseif max(o.consecutiveBadStep) > o.opts.maxConsecutiveBadStep 91 | s.x = ones(size(s.x,1),1) * mean(s.chains, [1 3]); 92 | s.v = 
s.ham.resample(s.x, zeros(size(s.x))); 93 | s.log('DynamicStepSize:step', 'Sampler reset to the center of gravity due to consecutive bad steps.\n'); 94 | 95 | o.iterSinceShrink = 0; 96 | o.rejectSinceShrink = 0; 97 | o.ODEStepSinceShrink = 0; 98 | o.consecutiveBadStep = 0; 99 | end 100 | end 101 | end 102 | end -------------------------------------------------------------------------------- /code/module/DynamicWeight.m: -------------------------------------------------------------------------------- 1 | classdef DynamicWeight < handle 2 | % Module for updating the weight in the log barrier. 3 | properties 4 | consecutiveBadStep = 0; 5 | end 6 | 7 | methods 8 | function o = DynamicWeight(s) 9 | extrHess = s.ham.barrier.extraHessian; 10 | s.ham.barrier = WeightedTwoSidedBarrier(s.ham.barrier.lb, s.ham.barrier.ub, s.problem.w, 2); 11 | s.ham.barrier.extraHessian = extrHess; 12 | end 13 | 14 | function o = step(o, s) 15 | bad_step = s.prob < 0.5 | s.ODEStep == s.opts.maxODEStep; 16 | o.consecutiveBadStep = bad_step .* o.consecutiveBadStep + bad_step; 17 | if (~s.freezed && ~all(s.accept)) 18 | lsc = max(s.ham.lsc, [], 1); 19 | w = reshape(s.ham.barrier.w, size(lsc)); 20 | if max(o.consecutiveBadStep) > 2 21 | threshold = 4; 22 | else 23 | threshold = 16; 24 | end 25 | 26 | idx = find(lsc > threshold * w); 27 | if ~isempty(idx) 28 | s.ham.barrier.w(idx) = min(s.ham.barrier.w(idx) * threshold, 1); 29 | s.ham.move(s.x, true); % Update the cache used in barrier 30 | s.v = s.ham.resample(s.x, zeros(size(s.x))); 31 | s.log('DynamicWeight:update_weight', 'The weight of %i coordinates are changed.\n', length(idx)); 32 | end 33 | end 34 | end 35 | end 36 | end -------------------------------------------------------------------------------- /code/module/MemoryStorage.m: -------------------------------------------------------------------------------- 1 | classdef MemoryStorage < handle 2 | % Module for maintaining the chain in memory 3 | 4 | properties 5 | opts 6 | end 7 | 8 | 
methods 9 | function o = MemoryStorage(s) 10 | o.opts = s.opts.MemoryStorage; 11 | s.chains = zeros(s.opts.simdLen, s.ham.n, 0); 12 | end 13 | 14 | function o = step(o, s) 15 | if mod(s.i, s.iterPerRecord) == 0 16 | s.chains(:,:,end+1) = s.x; 17 | 18 | % Thin the chain if 19 | % it exceeds the memory limit or more samples than opts.maxRecordsPerIndependentSample required 20 | if (isnan(s.mixingTime)) 21 | mixingTime = s.i; 22 | recordMax = 1; 23 | else 24 | mixingTime = s.mixingTime; 25 | recordMax = mixingTime / o.opts.maxRecordsPerIndependentSample; 26 | 27 | if (size(s.chains, 3) * s.iterPerRecord < 20 * mixingTime) 28 | recordMax = recordMax / 10; % if ess < 20, records more 29 | end 30 | end 31 | 32 | mem = numel(s.chains) * 8; 33 | if (mem > o.opts.memoryLimit || s.iterPerRecord < recordMax) 34 | if (2 * s.iterPerRecord < mixingTime) 35 | s.chains = s.chains(:, :, 2:2:end); 36 | s.iterPerRecord = s.iterPerRecord * 2; 37 | else 38 | s.log('warning', 'Algorithm stops due to out of memory (opts.MemoryStorage.memoryLimit = %f byte).\n', o.opts.memoryLimit); 39 | s.terminate = 2; 40 | end 41 | end 42 | end 43 | end 44 | 45 | function o = finalize(o, s) 46 | if s.opts.rawOutput 47 | s.output.chains = s.chains; 48 | else 49 | N = size(s.chains,3); 50 | if o.opts.thinOutput 51 | ess = min(effective_sample_size(s.chains), [], 1); 52 | else 53 | ess = N * ones(size(s.chains,1),1); 54 | end 55 | 56 | out = []; 57 | for i = 1:numel(ess) 58 | gap = ceil(N/ess(i)); 59 | out_i = s.chains(i, :, 1:gap:N); 60 | out_i = reshape(out_i, [size(out_i,2) size(out_i,3)]); 61 | out_i = s.problem.T * out_i + s.problem.y; 62 | out = [out out_i]; 63 | end 64 | s.output.chains = out; 65 | end 66 | end 67 | end 68 | end -------------------------------------------------------------------------------- /code/module/MixingTimeEstimator.m: -------------------------------------------------------------------------------- 1 | classdef MixingTimeEstimator < handle 2 | % Module for estimate mixing 
time 3 | 4 | properties 5 | opts 6 | 7 | removedInitial = false 8 | sampleRate = 0 9 | sampleRateOutside = 0 10 | estNumSamples = 0 11 | estNumSamplesOutside = 0 12 | nextEstimateStep 13 | end 14 | 15 | methods 16 | function o = MixingTimeEstimator(s) 17 | o.opts = s.opts.MixingTimeEstimator; 18 | o.nextEstimateStep = o.opts.initialStep; 19 | end 20 | 21 | function o = step(o, s) 22 | if mean(s.nEffectiveStep) > o.nextEstimateStep 23 | ess = effective_sample_size(s.chains); 24 | ess = min(ess, [], 'all'); 25 | 26 | if (o.removedInitial == false && ess > 2 * s.opts.nRemoveInitialSamples) 27 | k = ceil(s.opts.nRemoveInitialSamples * (size(s.chains, 3) / ess)); 28 | s.i = ceil(s.i * (1-k / size(s.chains, 3))); 29 | s.acceptedStep = s.acceptedStep * (1-k / size(s.chains, 3)); 30 | s.chains = s.chains(:,:,k:end); 31 | o.removedInitial = true; 32 | ess = effective_sample_size(s.chains); 33 | ess = min(ess, [], 'all'); 34 | end 35 | 36 | s.mixingTime = s.iterPerRecord * size(s.chains, 3) / ess; 37 | o.sampleRate = size(s.chains,1) / s.mixingTime; 38 | o.estNumSamples = s.i * o.sampleRate; 39 | s.share('sampleRate', o.sampleRate); 40 | s.share('estNumSamples', o.estNumSamples); 41 | o.update(s); 42 | end 43 | end 44 | 45 | function o = sync(o, s) 46 | o.estNumSamplesOutside = 0; 47 | o.sampleRateOutside = 0; 48 | for i = 1:s.nWorkers 49 | if i ~= s.labindex 50 | if isfield(s.shared{i}, 'estNumSamples') 51 | o.estNumSamplesOutside = o.estNumSamplesOutside + s.shared{i}.estNumSamples; 52 | end 53 | 54 | if isfield(s.shared{i}, 'sampleRate') 55 | o.sampleRateOutside = o.sampleRateOutside + s.shared{i}.sampleRate; 56 | end 57 | end 58 | end 59 | o.update(s); 60 | end 61 | 62 | function o = update(o, s) 63 | s.sampleRate = o.sampleRate + o.sampleRateOutside; 64 | s.totalNumSamples = o.estNumSamples + o.estNumSamplesOutside; 65 | if o.estNumSamples > s.opts.freezeMCMCAfterSamples 66 | s.freezed = true; 67 | end 68 | 69 | if s.totalNumSamples > s.N && o.removedInitial 70 | 
s.share('estNumSamples', o.estNumSamples); 71 | s.terminate = 1; 72 | s.log('sample:end', '%i samples found.\n', s.totalNumSamples); 73 | elseif o.removedInitial 74 | estimateEndingStep = s.N / s.sampleRate * (mean(s.nEffectiveStep) / s.i); 75 | o.nextEstimateStep = min(o.nextEstimateStep * o.opts.stepMultiplier, estimateEndingStep); 76 | else 77 | estimateEndingStep = (2 * s.opts.nRemoveInitialSamples * size(s.chains,1)) / o.sampleRate * (mean(s.nEffectiveStep) / s.i); 78 | o.nextEstimateStep = min(o.nextEstimateStep * o.opts.stepMultiplier, estimateEndingStep); 79 | end 80 | end 81 | end 82 | end -------------------------------------------------------------------------------- /code/module/ProgressBar.m: -------------------------------------------------------------------------------- 1 | classdef ProgressBar < handle 2 | % Module for printing out progress bar on Matlab console 3 | 4 | properties 5 | barLength = 25 % the length of the field "Progress" 6 | nSamplesTextLength = 12 % the length of "samples number" 7 | nSamplesFieldLength = 12 % the length of the field "Est Samples" 8 | nextBackspaceLength 9 | nExtraBackspacePerPrint 10 | 11 | refreshInterval = 0.2 12 | lastRefresh = tic(); 13 | end 14 | 15 | methods 16 | function o = ProgressBar(s) 17 | o.refreshInterval = s.opts.ProgressBar.refreshInterval; 18 | 19 | if (s.N ~= Inf) 20 | o.nSamplesTextLength = max(length(num2str(s.N)), 5); 21 | o.nSamplesFieldLength = 2 * o.nSamplesTextLength + 3; 22 | end 23 | 24 | if s.labindex == 1 25 | fmt = sprintf('%s%i%s%i%s', '%12s | %12s | %', o.barLength, 's | %', o.nSamplesFieldLength, 's | %8s | %8s | %8s\n'); 26 | str = sprintf(fmt, 'Time spent', 'Time left', 'Progress', 'Est Samples', 'AccProb', 'StepSize', 'MixTime'); 27 | disp(str); 28 | 29 | if (s.nWorkers > 1) 30 | o.nExtraBackspacePerPrint = 3; 31 | else 32 | o.nExtraBackspacePerPrint = 1; 33 | end 34 | o.nextBackspaceLength = o.nExtraBackspacePerPrint; 35 | end 36 | end 37 | 38 | function o = step(o, s) 39 | if 
toc(o.lastRefresh) < o.refreshInterval, return, end 40 | if s.labindex == 1 41 | o.refresh_bar(s); 42 | end 43 | end 44 | 45 | function o = finalize(o, s) 46 | if s.labindex == 1 47 | o.refresh_bar(s); 48 | fprintf('Done!\n'); 49 | end 50 | end 51 | 52 | function o = refresh_bar(o, s) 53 | prob = mean(s.acceptedStep) / s.i; 54 | timeSpent = toc(s.startTime); % s.startTime = time after presolve 55 | 56 | if isnan(s.sampleRate) 57 | avgMixingTime = NaN; 58 | nSamples = 0; 59 | timeRemain = Inf; 60 | else 61 | avgMixingTime = (size(s.x,1) * s.nWorkers) / s.sampleRate; 62 | nSamples = floor(max(s.i * s.sampleRate, s.totalNumSamples)); 63 | r = min(1, nSamples / s.N); 64 | timeRemain = (timeSpent / r) * (1-r); 65 | end 66 | timeRemain = min(timeRemain, s.opts.maxTime - toc(s.opts.startTime)); % s.opts.startTime = time before presolve 67 | timeRemain = min(timeRemain, (s.opts.maxStep - s.i) / s.i * timeSpent); 68 | progress = timeSpent/(timeRemain+timeSpent); 69 | 70 | str1 = sprintf(repmat('\b', 1, o.nextBackspaceLength)); 71 | 72 | if (s.N == Inf) 73 | fmt = sprintf('%s%i%s%i%s', '%12s | %12s | %s | %', o.nSamplesFieldLength, 'i | %8.6f | %8.6f | %8.1f'); 74 | str2 = sprintf(fmt, durationString(timeSpent), durationString(timeRemain), progressString(progress, o.barLength), nSamples, prob, s.stepSize, avgMixingTime); 75 | else 76 | fmt = sprintf('%s%i%s%i%s', '%12s | %12s | %s | %', o.nSamplesTextLength, 'i / %', o.nSamplesTextLength, 'i | %8.6f | %8.6f | %8.1f'); 77 | str2 = sprintf(fmt, durationString(timeSpent), durationString(timeRemain), progressString(progress, o.barLength), nSamples, s.N, prob, s.stepSize, avgMixingTime); 78 | end 79 | 80 | disp([str1 str2]); 81 | o.nextBackspaceLength = length(str2) + o.nExtraBackspacePerPrint; 82 | o.lastRefresh = tic; 83 | 84 | function o = durationString(str) 85 | if isnan(str) || isinf(str) 86 | o = ' NaN'; 87 | else 88 | str = max(str,0); 89 | second = mod(floor(str),60); 90 | minute = mod(floor(str/60),60); 91 | hour = 
mod(floor(str/3600),24); 92 | day = floor(str/86400); 93 | o = sprintf('%02id:%02i:%02i:%02i', day, hour, minute, second); 94 | end 95 | end 96 | 97 | function o = progressString(str, len) 98 | str = max(min(str, 1),0); 99 | l = round(str*len); 100 | o = repmat('#', 1, l); 101 | o = [o, repmat(' ', 1, len - l)]; 102 | end 103 | end 104 | end 105 | end -------------------------------------------------------------------------------- /code/prepare/analytic_center.m: -------------------------------------------------------------------------------- 1 | function [x, C, d] = analytic_center(A, b, f, opts, x) 2 | %[x, C, d, output] = analytic_center(A, b, f, opts, x) 3 | %compute the analytic center for the domain {Ax=b} intersect the domain of f 4 | % 5 | %Input: 6 | % A - a m x n constraint matrix. The domain of the problem is given by Ax=b. 7 | % b - a m x 1 constraint vector. 8 | % f - a barrier class (Currently, we only have TwoSidedBarrier) 9 | % opts - a structure for options with the following properties 10 | % ipmMaxIter - maximum number of iterations 11 | % ipmDualTol - stop when ||A' * lambda - gradf(x)||_2 < dualTol 12 | % ipmDistanceTol - if x_i is ipmDistanceTol close to some boundary, we assume the coordinate i is tight. 
13 | % 14 | %Output: 15 | % x - It outputs the minimizer of min f(x) subjects to {Ax=b} 16 | % C - detected constraint matrix 17 | % If the domain ({Ax=b} intersect dom(f)) is not full dimensional in {Ax=b} 18 | % because of the dom(f), the algorithm will detect the collapsed dimension 19 | % and output the detected constraint C x = d 20 | % d - detected constraint vector 21 | 22 | %% prepare the printout 23 | formats = struct; 24 | formats.iter = struct('label', 'Iter', 'format', '5i'); 25 | formats.t = struct('label', 'Step Size', 'format', '13.2e'); 26 | formats.primalErr = struct('label', 'Primal Error', 'format', '13.2e'); 27 | formats.dualErr = struct('label', 'Dual Error', 'format', '13.2e'); 28 | output = TableDisplay(formats); 29 | opts.logFunc('analytic_center', output.header()); 30 | 31 | %% initial conditions 32 | if exist('x', 'var') == 0 || isempty(x) || ~f.barrier.feasible(x) 33 | x = f.barrier.center; 34 | end 35 | lambda = zeros(size(A,2),1); 36 | fullStep = 0; tConst = 0; 37 | primalErr = Inf; dualErr = Inf; primalErrMin = Inf; 38 | primalFactor = 1; dualFactor = 1 + norm(b); 39 | idx = []; 40 | solver = Solver(A, 'doubledouble'); 41 | 42 | %% find the central path 43 | for iter = 1:opts.ipmMaxIter 44 | [grad, hess] = f.analytic_center_oracle(x); 45 | 46 | % compute the residual 47 | rx = lambda - grad; 48 | rs = b - A * x; 49 | 50 | % check stagnation 51 | primalErrMin = min(primalErr,primalErrMin); primalErr = norm(rx)/primalFactor; 52 | dualErrLast = dualErr; dualErr = norm(rs)/dualFactor; 53 | feasible = f.barrier.feasible(x); 54 | if ((dualErr > (1-0.9*tConst)*dualErrLast) || (primalErr > 10 * primalErrMin) || ~feasible) 55 | dist = f.barrier.boundary_distance(x); 56 | idx = find(dist < opts.ipmDistanceTol); 57 | if ~isempty(idx), break; end 58 | end 59 | 60 | % compute the step direction 61 | Hinv = 1./hess; 62 | solver.setScale(Hinv); 63 | dr1 = A' * solver.solve(rs); dr2 = A' * solver.solve(A * (Hinv .* rx)); 64 | dx1 = Hinv .* (dr1); 65 | 
dx2 = Hinv .* (rx - dr2); 66 | 67 | % compute the step size 68 | dx = dx1 + dx2; 69 | tGrad = min(f.barrier.step_size(x, dx),1); 70 | dx = dx1 + tGrad * dx2; 71 | tConst = min(0.99*f.barrier.step_size(x, dx),1); 72 | tGrad = tGrad * tConst; 73 | 74 | % make the step 75 | x = x + tConst * dx; 76 | lambda = lambda - dr2; 77 | 78 | if ~f.barrier.feasible(x), break; end 79 | 80 | % printout 81 | o = struct('iter', iter, 't', tGrad, 'primalErr', primalErr, 'dualErr', dualErr); 82 | opts.logFunc('analytic_center', output.print(o)); 83 | 84 | % stop if converged 85 | if (tGrad == 1) 86 | fullStep = fullStep + 1; 87 | if (fullStep > log(dualErr/opts.ipmDualTol) && fullStep > 8) 88 | break; 89 | end 90 | else 91 | fullStep = 0; 92 | end 93 | end 94 | 95 | if isempty(idx) 96 | dist = f.barrier.boundary_distance(x); 97 | idx = find(dist < opts.ipmDistanceTol); 98 | end 99 | 100 | if ~isempty(idx) 101 | [A_, b_] = f.barrier.boundary(x); 102 | C = A_(idx,:); 103 | d = b_(idx); 104 | else 105 | C = zeros(0, size(A,2)); 106 | d = zeros(0, 1); 107 | end -------------------------------------------------------------------------------- /code/prepare/gmscale.m: -------------------------------------------------------------------------------- 1 | function [cscale,rscale] = gmscale(A,iprint,scltol) 2 | 3 | % [cscale,rscale] = gmscale(A,iprint,scltol); 4 | %------------------------------------------------------------------ 5 | % gmscale (Geometric-Mean Scaling) finds the scale values for the 6 | % m x n sparse matrix A. 7 | % 8 | % On entry: 9 | % A(i,j) contains entries of A. 10 | % iprint > 0 requests messages to the screen (0 means no output). 11 | % scltol should be in the range (0.0, 1.0). 12 | % Typically scltol = 0.9. A bigger value like 0.99 asks 13 | % gmscale to work a little harder (more passes). 
14 | % 15 | % On exit: 16 | % cscale, rscale are column vectors of column and row scales such that 17 | % R(inverse) A C(inverse) should have entries near 1.0, 18 | % where R = diag(rscale), C = diag(cscale). 19 | % 20 | % Method: 21 | % An iterative procedure based on geometric means is used, 22 | % following a routine written by Robert Fourer, 1979. 23 | % Several passes are made through the columns and rows of A. 24 | % The main steps are: 25 | % 26 | % 1. Compute aratio = max_j (max_i Aij / min_i Aij). 27 | % 2. Divide each row i by sqrt( max_j Aij * min_j Aij ). 28 | % 3. Divide each column j by sqrt( max_i Aij * min_i Aij ). 29 | % 4. Compute sratio as in Step 1. 30 | % 5. If sratio < aratio * scltol, 31 | % set aratio = sratio and repeat from Step 2. 32 | % 33 | % To dampen the effect of very small elements, on each pass, 34 | % a new row or column scale will not be smaller than sqrt(damp) 35 | % times the largest (scaled) element in that row or column. 36 | % 37 | % Use of the scales: 38 | % To apply the scales to a linear program, 39 | % min c'x st Ax = b, l <= x <= u, 40 | % we need to define "barred" quantities by the following relations: 41 | % A = R Abar C, b = R bbar, C cbar = c, 42 | % C l = lbar, C u = ubar, C x = xbar. 43 | % This gives the scaled problem 44 | % min cbar'xbar st Abar xbar = bbar, lbar <= xbar <= ubar. 45 | 46 | % Maintainer: Michael Saunders, Systems Optimization Laboratory, 47 | % Stanford University. 48 | % 07 Jun 1996: First f77 version, based on MINOS 5.5 routine m2scal. 49 | % 24 Apr 1998: Added final pass to make column norms = 1. 50 | % 18 Nov 1999: Fixed up documentation. 51 | % 26 Mar 2006: (Leo Tenenblat) First Matlab version based on Fortran version. 52 | % 21 Mar 2008: (MAS) Inner loops j = 1:n optimized. 53 | % 09 Apr 2008: (MAS) All loops replaced by sparse-matrix operations. 54 | % We can't find the biggest and smallest Aij 55 | % on each scaling pass, so no longer print them. 
56 | % 24 Apr 2008: (MAS, Kaustuv) Allow for empty rows and columns. 57 | % 13 Nov 2009: gmscal.m renamed gmscale.m. 58 | %------------------------------------------------------------------ 59 | 60 | if iprint > 0 61 | fprintf('\ngmscale: Geometric-Mean scaling of matrix') 62 | fprintf('\n-------\n Max col ratio') 63 | end 64 | 65 | [m,n] = size(A); 66 | A = abs(A); % Work with |Aij| 67 | maxpass = 10; 68 | aratio = 1e+50; 69 | damp = 1e-4; 70 | small = 1e-8; 71 | rscale = ones(m,1); 72 | cscale = ones(n,1); 73 | 74 | %--------------------------------------------------------------- 75 | % Main loop. 76 | %--------------------------------------------------------------- 77 | for npass = 0:maxpass 78 | 79 | % Find the largest column ratio. 80 | % Also set new column scales (except on pass 0). 81 | 82 | rscale(rscale==0) = 1; 83 | Rinv = diag(sparse(1./rscale)); 84 | SA = Rinv*A; 85 | [I,J,V] = find(SA); 86 | invSA = sparse(I,J,1./V,m,n); 87 | cmax = full(max(SA))'; % column vector 88 | cmin = full(max(invSA))'; % column vector 89 | cmin = 1./(cmin + eps); 90 | sratio = max( cmax./cmin ); % Max col ratio 91 | if npass > 0 92 | cscale = sqrt( max(cmin, damp*cmax) .* cmax ); 93 | end 94 | 95 | if iprint > 0 96 | fprintf('\n After %2g %19.2f', npass, sratio) 97 | end 98 | 99 | if npass >= 2 && sratio >= aratio*scltol, break; end 100 | if npass == maxpass, break; end 101 | aratio = sratio; 102 | 103 | % Set new row scales for the next pass. 104 | 105 | cscale(cscale==0) = 1; 106 | cscale = 2.^nextpow2(full(cscale)); 107 | Cinv = diag(sparse(1./cscale)); 108 | SA = A*Cinv; % Scaled A 109 | [I,J,V] = find(SA); 110 | invSA = sparse(I,J,1./V,m,n); 111 | rmax = full(max(SA,[],2)); % column vector 112 | rmin = full(max(invSA,[],2)); % column vector 113 | rmin = 1./(rmin + eps); 114 | rscale = sqrt( max(rmin, damp*rmax) .* rmax ); 115 | rscale = 2.^nextpow2(full(rscale)); 116 | end 117 | %--------------------------------------------------------------- 118 | % End of main loop. 
119 | %--------------------------------------------------------------- 120 | 121 | % Reset column scales so the biggest element 122 | % in each scaled column will be 1. 123 | % Again, allow for empty rows and columns. 124 | 125 | rscale(rscale==0) = 1; 126 | Rinv = diag(sparse(1./rscale)); 127 | SA = Rinv*A; 128 | [I,J,V] = find(SA); 129 | cscale = full(max(SA))'; % column vector 130 | cscale(cscale==0) = 1; 131 | cscale = 2.^nextpow2(full(cscale)); 132 | % Find the min and max scales. 133 | 134 | if iprint>0 135 | [rmin,imin] = min(rscale); 136 | [rmax,imax] = max(rscale); 137 | [cmin,jmin] = min(cscale); 138 | [cmax,jmax] = max(cscale); 139 | 140 | fprintf('\n\n Min scale Max scale') 141 | fprintf('\n Row %6g %9.1e Row %6g %9.1e', imin, rmin, imax, rmax) 142 | fprintf('\n Col %6g %9.1e Col %6g %9.1e', jmin, cmin, jmax, cmax) 143 | end 144 | 145 | % end of gmscale 146 | -------------------------------------------------------------------------------- /code/prepare/lewis_center.m: -------------------------------------------------------------------------------- 1 | function [x, C, d, wp] = lewis_center(A, b, f, opts, x) 2 | %[x, C, d, wp] = lewis_center(A, b, f, opts, x) 3 | %compute the lewis center for the domain {Ax=b} intersect the domain of f 4 | % 5 | %Input: 6 | % A - a m x n constraint matrix. The domain of the problem is given by Ax=b. 7 | % b - a m x 1 constraint vector. 8 | % f - a barrier class (Currently, we only have TwoSidedBarrier) 9 | % opts - a structure for options with the following properties 10 | % ipmMaxIter - maximum number of iterations 11 | % ipmDualTol - stop when ||A' * lambda - gradf(x)||_2 < dualTol 12 | % ipmDistanceTol - if x_i is ipmDistanceTol close to some boundary, we assume the coordinate i is tight. 
13 | % 14 | %Output: 15 | % x - It outputs the minimizer of min f(x) subjects to {Ax=b} 16 | % C - detected constraint matrix 17 | % If the domain ({Ax=b} intersect dom(f)) is not full dimensional in {Ax=b} 18 | % because of the dom(f), the algorithm will detect the collapsed dimension 19 | % and output the detected constraint C x = d 20 | % d - detected constraint vector 21 | 22 | %% prepare the printout 23 | formats = struct; 24 | formats.iter = struct('label', 'Iter', 'format', '5i'); 25 | formats.t = struct('label', 'Step Size', 'format', '13.2e'); 26 | formats.primalErr = struct('label', 'Primal Error', 'format', '13.2e'); 27 | formats.dualErr = struct('label', 'Dual Error', 'format', '13.2e'); 28 | output = TableDisplay(formats); 29 | opts.logFunc('analytic_center', output.header()); 30 | 31 | %% initial conditions 32 | if exist('x', 'var') == 0 || isempty(x) || ~f.barrier.feasible(x) 33 | x = f.barrier.center; 34 | end 35 | lambda = zeros(size(A,2),1); 36 | fullStep = 0; tConst = 0; 37 | primalErr = Inf; dualErr = Inf; primalErrMin = Inf; 38 | primalFactor = 1; dualFactor = 1 + norm(b); 39 | idx = []; 40 | solver = Solver(A, 'doubledouble'); 41 | w = ones(size(x)); 42 | wp = w; 43 | 44 | %% find the central path 45 | for iter = 1:opts.ipmMaxIter 46 | [grad, hess] = f.lewis_center_oracle(x, wp); 47 | 48 | % compute the residual 49 | rx = lambda - grad; 50 | rs = b - A * x; 51 | 52 | % check stagnation 53 | primalErrMin = min(primalErr,primalErrMin); primalErr = norm(rx)/primalFactor; 54 | dualErrLast = dualErr; dualErr = norm(rs)/dualFactor; 55 | feasible = f.barrier.feasible(x); 56 | if ((dualErr > (1-0.9*tConst)*dualErrLast) || (primalErr > 10 * primalErrMin) || ~feasible) 57 | dist = f.barrier.boundary_distance(x); 58 | idx = find(dist < opts.ipmDistanceTol); 59 | if ~isempty(idx), break; end 60 | end 61 | 62 | % compute the step direction 63 | Hinv = 1./hess; 64 | solver.setScale(Hinv); 65 | dr1 = A' * solver.solve(rs); dr2 = A' * solver.solve(A * (Hinv 
.* rx)); 66 | dx1 = Hinv .* (dr1); 67 | dx2 = Hinv .* (rx - dr2); 68 | 69 | % compute the step size 70 | dx = dx1 + dx2; 71 | tGrad = min(f.barrier.step_size(x, dx),1); 72 | dx = dx1 + tGrad * dx2; 73 | tConst = min(0.99*f.barrier.step_size(x, dx),1); 74 | tGrad = tGrad * tConst; 75 | 76 | % make the step 77 | x = x + tConst * dx; 78 | lambda = lambda - dr2; 79 | 80 | % update weight 81 | wNew = max(double(solver.leverageScoreComplement(0)), 0) + 1e-8;%1 / length(w); 82 | w = (w + wNew)/2; 83 | wp = w.^(1-1/8); 84 | 85 | if ~f.barrier.feasible(x), break; end 86 | 87 | % printout 88 | o = struct('iter', iter, 't', tGrad, 'primalErr', primalErr, 'dualErr', dualErr); 89 | opts.logFunc('analytic_center', output.print(o)); 90 | 91 | % stop if converged 92 | if (tGrad == 1) 93 | fullStep = fullStep + 1; 94 | if (fullStep > log(dualErr/opts.ipmDualTol) && fullStep > 8) 95 | break; 96 | end 97 | else 98 | fullStep = 0; 99 | end 100 | end 101 | 102 | if isempty(idx) 103 | dist = f.barrier.boundary_distance(x); 104 | idx = find(dist < opts.ipmDistanceTol); 105 | end 106 | 107 | if ~isempty(idx) 108 | [A_, b_] = f.barrier.boundary(x); 109 | C = A_(idx,:); 110 | d = b_(idx); 111 | else 112 | C = zeros(0, size(A,2)); 113 | d = zeros(0, 1); 114 | end -------------------------------------------------------------------------------- /code/prepare/standardize_problem.m: -------------------------------------------------------------------------------- 1 | function P = standardize_problem(P) 2 | if nonempty(P, 'A') 3 | error('Polytope:standardize', 'Use Aeq or Aineq instead of A in the model structure.'); 4 | end 5 | 6 | if nonempty(P, 'Aeq') 7 | n = size(P.Aeq, 2); 8 | elseif nonempty(P, 'Aineq') 9 | n = size(P.Aineq, 2); 10 | elseif nonempty(P, 'lb') 11 | n = length(P.lb); 12 | elseif nonempty(P, 'ub') 13 | n = length(P.ub); 14 | elseif nonempty(P, 'center') 15 | n = length(P.center); 16 | else 17 | error('Polytope:standardize', 'For unconstrained problems, an initial point "center" 
is required.'); 18 | end 19 | 20 | %% Set all non-existence fields 21 | if ~nonempty(P, 'Aeq') 22 | P.Aeq = sparse(zeros(0, n)); 23 | end 24 | 25 | if ~nonempty(P, 'beq') 26 | P.beq = zeros(size(P.Aeq, 1), 1); 27 | end 28 | 29 | if ~nonempty(P, 'Aineq') 30 | P.Aineq = sparse(zeros(0, n)); 31 | end 32 | 33 | if ~nonempty(P, 'bineq') 34 | P.bineq = zeros(size(P.Aineq, 1), 1); 35 | end 36 | 37 | if ~nonempty(P, 'lb') 38 | P.lb = -Inf * ones(n, 1); 39 | end 40 | 41 | if ~nonempty(P,'ub') 42 | P.ub = Inf * ones(n,1); 43 | end 44 | 45 | if ~nonempty(P,'center') 46 | P.center = []; 47 | end 48 | 49 | %% Store f, df, ddf, dddf 50 | randVec = randn(n, 1); 51 | hasf = isfield(P, 'f') && ~isempty(P.f); 52 | hasdf = isfield(P, 'df') && ~isempty(P.df); 53 | hasddf = isfield(P, 'ddf') && ~isempty(P.ddf); 54 | 55 | if (hasdf && isfloat(P.df) && norm(P.df) == 0) 56 | hasdf = false; 57 | end 58 | 59 | % Case 1: df is empty 60 | if ~hasdf 61 | assert(~hasf && ~hasddf);% && ~hasdddf); 62 | P.f = []; 63 | P.df = []; 64 | P.ddf = []; 65 | elseif isfloat(P.df) % Case 2: df is a vector 66 | assert(all(size(P.df) == [n 1])); 67 | P.f = []; 68 | P.df = P.df; 69 | P.ddf = []; 70 | elseif isa(P.df, 'function_handle') % Case 3: df is handle 71 | assert(hasf); 72 | assert(isa(P.f,'function_handle')); 73 | assert(all(size(P.f(randVec)) == [1 1])); 74 | assert(all(size(P.df(randVec)) == [n 1])); 75 | P.f = P.f; 76 | P.df = P.df; 77 | if hasddf 78 | if isa(P.ddf, 'function_handle') 79 | assert(all(size(P.ddf(randVec)) == [n 1])); 80 | else 81 | assert(all(size(P.ddf) == [n 1])); 82 | end 83 | P.ddf = P.ddf; 84 | else 85 | P.ddf = []; 86 | end 87 | 88 | %% Verify f, df, ddf, dddf 89 | % TODO 90 | end 91 | 92 | %% Check the input dimensions 93 | assert(all(size(P.Aineq) == [length(P.bineq) n])); 94 | assert(all(size(P.Aeq) == [length(P.beq) n])); 95 | assert(all(size(P.lb) == [n 1])); 96 | assert(all(size(P.ub) == [n 1])); 97 | assert(all(P.lb <= P.ub)); 98 | end 
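The defaulting logic in standardize_problem.m can be tried on a tiny problem. The sketch below is an assumption-laden illustration rather than the package's code: `hasField` is a hypothetical stand-in for the `nonempty` helper (whose definition is not shown in this file), and the unit-square problem is made up for the example.

```matlab
% Illustrative sketch only: hasField is a hypothetical stand-in for the
% nonempty helper, which is not shown in this file.
hasField = @(P, name) isfield(P, name) && ~isempty(P.(name));

% Uniform distribution on the unit square, given by bounds alone (Case 1).
P = struct('lb', [0; 0], 'ub', [1; 1]);
n = length(P.lb);                % dimension inferred from lb

% Fill in the missing constraint blocks with empty 0-by-n matrices.
if ~hasField(P, 'Aeq'),   P.Aeq   = sparse(0, n); end
if ~hasField(P, 'beq'),   P.beq   = zeros(size(P.Aeq, 1), 1); end
if ~hasField(P, 'Aineq'), P.Aineq = sparse(0, n); end
if ~hasField(P, 'bineq'), P.bineq = zeros(size(P.Aineq, 1), 1); end

% The same dimension checks as at the end of standardize_problem.
assert(all(size(P.Aineq) == [length(P.bineq) n]));
assert(all(size(P.Aeq) == [length(P.beq) n]));
assert(all(P.lb <= P.ub));
disp(size(P.Aeq));
```

Filling every absent block with an empty matrix of the right width means later code can multiply by `P.Aeq` or `P.Aineq` without checking whether the user supplied them.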
-------------------------------------------------------------------------------- /code/sample.m: -------------------------------------------------------------------------------- 1 | function o = sample(problem, N, opts) 2 | %Input: a structure problem with the following fields 3 | % .Aineq 4 | % .bineq 5 | % .Aeq 6 | % .beq 7 | % .lb 8 | % .ub 9 | % .f 10 | % .df 11 | % .ddf 12 | % describing a logconcave distribution given by 13 | % exp(-sum f_i(x_i)) 14 | % over 15 | % {Aineq x <= bineq, Aeq x = beq, lb <= x <= ub} 16 | % where f is specified together with its first (df) and second (ddf) 17 | % derivatives, each with one entry per coordinate. 18 | % 19 | % Case 1: df is not defined 20 | % f(x) = 0. 21 | % In this case, f and ddf must be empty. This is the uniform distribution, 22 | % and the feasible region must be bounded. 23 | % 24 | % Case 2: df is a vector 25 | % f_i(x_i) = df_i x_i. 26 | % In this case, f and ddf must be empty. 27 | % 28 | % Case 3: df is a function handle 29 | % f needs to be defined as a function handle. 30 | % df needs to be the derivative of f. 31 | % ddf is optional. Providing ddf could improve the mixing time. 
32 | % 33 | % N - number of independent samples 34 | % 35 | % opts - sampling options 36 | % 37 | %Output: 38 | % o - a structure containing the following properties: 39 | % samples - a cell of dim x N vectors containing each chain of samples 40 | % prepareTime - time to pre-process the input (including finding an interior 41 | % point, removing redundant constraints, and reducing the dimension) 42 | % sampleTime - total sampling time in seconds (sum over all workers) 43 | t = tic; 44 | 45 | %% Initialize parameters and compile if needed 46 | if (nargin <= 2) 47 | opts = default_options(); 48 | else 49 | opts = Setfield(default_options(), opts); 50 | end 51 | opts.startTime = t; 52 | 53 | compile_solver(0); compile_solver(opts.simdLen); 54 | 55 | %% Presolve 56 | if isempty(opts.seed) 57 | opts.seed = randi(2^31); 58 | end 59 | 60 | if ischar(opts.logging) || isstring(opts.logging) % logging for Polytope 61 | fid = fopen(opts.logging, 'a'); 62 | opts.presolve.logFunc = @(tag, msg) fprintf(fid, '%s', msg); 63 | elseif ~isempty(opts.logging) 64 | opts.presolve.logFunc = opts.logging; 65 | else 66 | opts.presolve.logFunc = @(tag, msg) 0; 67 | end 68 | 69 | polytope = Polytope(problem, opts); 70 | 71 | if ischar(opts.logging) || isstring(opts.logging) 72 | fclose(fid); 73 | end 74 | 75 | prepareTime = toc(t); 76 | 77 | 78 | %% Check the trivial case 79 | if polytope.n == 0 80 | warning('The domain consists of only a single point.'); 81 | o = struct; 82 | o.prepareTime = prepareTime; 83 | o.sampleTime = 0; 84 | o.problem = polytope; 85 | o.samples = polytope.center; 86 | return 87 | end 88 | 89 | %% Set up workers if nWorkers ~= 1 90 | if opts.nWorkers ~= 1 && ~isempty(ver('parallel')) 91 | % create pool with size nWorkers 92 | p = gcp('nocreate'); 93 | if isempty(p) 94 | if opts.nWorkers ~= Inf 95 | p = parpool(opts.nWorkers); 96 | else 97 | p = parpool(); 98 | end 99 | elseif opts.nWorkers ~= Inf && p.NumWorkers ~= opts.nWorkers 100 | delete(p); 101 | p = parpool(opts.nWorkers); 102 | 
end 103 | opts.nWorkers = p.NumWorkers; 104 | opts.N = N; 105 | 106 | spmd(opts.nWorkers) 107 | if opts.profiling 108 | mpiprofile on 109 | end 110 | 111 | rng(opts.seed + labindex, 'simdTwister'); 112 | s = Sampler(polytope, opts); 113 | while s.terminate == 0 114 | s.step(); 115 | end 116 | s.finalize(); 117 | workerOutput = s.output; 118 | 119 | if opts.profiling 120 | mpiprofile viewer 121 | end 122 | end 123 | 124 | o = struct; 125 | o.workerOutput = cell(opts.nWorkers, 1); 126 | o.sampleTime = 0; 127 | for i = 1:opts.nWorkers 128 | o.workerOutput{i} = workerOutput{i}; 129 | o.sampleTime = o.sampleTime + o.workerOutput{i}.sampleTime; 130 | end 131 | 132 | if ~opts.rawOutput 133 | o.chains = o.workerOutput{1}.chains; 134 | for i = 2:opts.nWorkers 135 | o.chains = [o.chains o.workerOutput{i}.chains]; 136 | o.workerOutput{i}.chains = []; 137 | end 138 | end 139 | else 140 | opts.N = N; 141 | 142 | if opts.profiling 143 | profile on 144 | end 145 | 146 | rng(opts.seed, 'simdTwister'); 147 | s = Sampler(polytope, opts); 148 | while s.terminate == 0 149 | s.step(); 150 | end 151 | s.finalize(); 152 | o = s.output; 153 | 154 | if opts.profiling 155 | profile report 156 | end 157 | end 158 | o.problem = polytope; 159 | o.opts = opts; 160 | o.prepareTime = prepareTime; 161 | if ~opts.rawOutput 162 | o.samples = o.chains; 163 | o = rmfield(o, 'chains'); 164 | o.summary = summary(o); 165 | end 166 | -------------------------------------------------------------------------------- /code/solver/FeatureDetector/cpuInfo.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "mex.h" 3 | #include "cpu_x86.h" 4 | #include "cpu_x86.cpp" 5 | 6 | using namespace std; 7 | using namespace FeatureDetector; 8 | 9 | template 10 | mxArray* mxCreateScalar(T x); 11 | 12 | template <> 13 | mxArray* mxCreateScalar(bool x) 14 | { 15 | return mxCreateLogicalScalar(x); 16 | } 17 | 18 | template 19 | mxArray* createStructure(const 
unordered_map &map) 20 | { 21 | const char** fieldnames = new const char* [map.size()]; 22 | mxArray** values = new mxArray *[map.size()]; 23 | 24 | size_t k = 0; 25 | 26 | for (const auto& nv : map) 27 | { 28 | fieldnames[k] = (char*)mxMalloc(64); 29 | values[k] = mxCreateScalar(nv.second); 30 | memcpy((void*)fieldnames[k], nv.first, strlen(nv.first) + 1); 31 | ++k; 32 | } 33 | 34 | auto pt = mxCreateStructMatrix(1, 1, (int)map.size(), fieldnames); 35 | for (size_t i = 0; i < map.size(); ++i) 36 | { 37 | mxFree((void*)fieldnames[i]); 38 | mxSetFieldByNumber(pt, 0, i, values[i]); 39 | } 40 | 41 | delete[] fieldnames; 42 | delete[] values; 43 | 44 | return pt; 45 | } 46 | 47 | void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) 48 | { 49 | if (nlhs == 1) 50 | { 51 | cpu_x86 features; 52 | features.detect_host(); 53 | 54 | unordered_map info; 55 | info["SSE"] = features.HW_SSE; 56 | info["SSE2"] = features.HW_SSE2; 57 | info["SSE3"] = features.HW_SSE3; 58 | info["SSE41"] = features.HW_SSE41; 59 | info["SSE42"] = features.HW_SSE42; 60 | info["AVX"] = features.HW_AVX; 61 | info["AVX2"] = features.HW_AVX2; 62 | info["FMA"] = features.HW_FMA3; 63 | info["AVX512F"] = features.HW_AVX512_F; 64 | info["OS_AVX"] = features.OS_AVX; 65 | info["OS_AVX512"] = features.OS_AVX512; 66 | 67 | plhs[0] = createStructure(info); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /code/solver/FeatureDetector/cpu_x86.h: -------------------------------------------------------------------------------- 1 | /* cpu_x86.h 2 | * 3 | * Author : Alexander J. 
Yee 4 | * Date Created : 04/12/2014 5 | * Last Modified : 04/12/2014 6 | * 7 | */ 8 | 9 | #pragma once 10 | #ifndef _cpu_x86_H 11 | #define _cpu_x86_H 12 | //////////////////////////////////////////////////////////////////////////////// 13 | //////////////////////////////////////////////////////////////////////////////// 14 | //////////////////////////////////////////////////////////////////////////////// 15 | //////////////////////////////////////////////////////////////////////////////// 16 | // Dependencies 17 | #include 18 | #include 19 | namespace FeatureDetector{ 20 | //////////////////////////////////////////////////////////////////////////////// 21 | //////////////////////////////////////////////////////////////////////////////// 22 | //////////////////////////////////////////////////////////////////////////////// 23 | //////////////////////////////////////////////////////////////////////////////// 24 | struct cpu_x86{ 25 | // Vendor 26 | bool Vendor_AMD; 27 | bool Vendor_Intel; 28 | 29 | // OS Features 30 | bool OS_x64; 31 | bool OS_AVX; 32 | bool OS_AVX512; 33 | 34 | // Misc. 
35 | bool HW_MMX; 36 | bool HW_x64; 37 | bool HW_ABM; 38 | bool HW_RDRAND; 39 | bool HW_RDSEED; 40 | bool HW_BMI1; 41 | bool HW_BMI2; 42 | bool HW_ADX; 43 | bool HW_MPX; 44 | bool HW_PREFETCHW; 45 | bool HW_PREFETCHWT1; 46 | bool HW_RDPID; 47 | 48 | // SIMD: 128-bit 49 | bool HW_SSE; 50 | bool HW_SSE2; 51 | bool HW_SSE3; 52 | bool HW_SSSE3; 53 | bool HW_SSE41; 54 | bool HW_SSE42; 55 | bool HW_SSE4a; 56 | bool HW_AES; 57 | bool HW_SHA; 58 | 59 | // SIMD: 256-bit 60 | bool HW_AVX; 61 | bool HW_XOP; 62 | bool HW_FMA3; 63 | bool HW_FMA4; 64 | bool HW_AVX2; 65 | 66 | // SIMD: 512-bit 67 | bool HW_AVX512_F; 68 | bool HW_AVX512_CD; 69 | 70 | // Knights Landing 71 | bool HW_AVX512_PF; 72 | bool HW_AVX512_ER; 73 | 74 | // Skylake Purley 75 | bool HW_AVX512_VL; 76 | bool HW_AVX512_BW; 77 | bool HW_AVX512_DQ; 78 | 79 | // Cannon Lake 80 | bool HW_AVX512_IFMA; 81 | bool HW_AVX512_VBMI; 82 | 83 | // Knights Mill 84 | bool HW_AVX512_VPOPCNTDQ; 85 | bool HW_AVX512_4FMAPS; 86 | bool HW_AVX512_4VNNIW; 87 | 88 | // Cascade Lake 89 | bool HW_AVX512_VNNI; 90 | 91 | // Cooper Lake 92 | bool HW_AVX512_BF16; 93 | 94 | // Ice Lake 95 | bool HW_AVX512_VBMI2; 96 | bool HW_GFNI; 97 | bool HW_VAES; 98 | bool HW_AVX512_VPCLMUL; 99 | bool HW_AVX512_BITALG; 100 | 101 | public: 102 | cpu_x86(); 103 | void detect_host(); 104 | 105 | void print() const; 106 | static void print_host(); 107 | 108 | static void cpuid(int32_t out[4], int32_t eax, int32_t ecx); 109 | static std::string get_vendor_string(); 110 | 111 | private: 112 | static void print(const char* label, bool yes); 113 | 114 | static bool detect_OS_x64(); 115 | static bool detect_OS_AVX(); 116 | static bool detect_OS_AVX512(); 117 | }; 118 | //////////////////////////////////////////////////////////////////////////////// 119 | //////////////////////////////////////////////////////////////////////////////// 120 | //////////////////////////////////////////////////////////////////////////////// 121 | 
//////////////////////////////////////////////////////////////////////////////// 122 | } 123 | #endif 124 | -------------------------------------------------------------------------------- /code/solver/FeatureDetector/cpu_x86_Linux.ipp: -------------------------------------------------------------------------------- 1 | /* cpu_x86_Linux.ipp 2 | * 3 | * Author : Alexander J. Yee 4 | * Date Created : 04/12/2014 5 | * Last Modified : 04/12/2014 6 | * 7 | */ 8 | 9 | //////////////////////////////////////////////////////////////////////////////// 10 | //////////////////////////////////////////////////////////////////////////////// 11 | //////////////////////////////////////////////////////////////////////////////// 12 | //////////////////////////////////////////////////////////////////////////////// 13 | // Dependencies 14 | #include 15 | #include "cpu_x86.h" 16 | namespace FeatureDetector{ 17 | //////////////////////////////////////////////////////////////////////////////// 18 | //////////////////////////////////////////////////////////////////////////////// 19 | //////////////////////////////////////////////////////////////////////////////// 20 | //////////////////////////////////////////////////////////////////////////////// 21 | void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx){ 22 | __cpuid_count(eax, ecx, out[0], out[1], out[2], out[3]); 23 | } 24 | uint64_t xgetbv(unsigned int index){ 25 | uint32_t eax, edx; 26 | __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); 27 | return ((uint64_t)edx << 32) | eax; 28 | } 29 | #define _XCR_XFEATURE_ENABLED_MASK 0 30 | //////////////////////////////////////////////////////////////////////////////// 31 | //////////////////////////////////////////////////////////////////////////////// 32 | // Detect 64-bit 33 | bool cpu_x86::detect_OS_x64(){ 34 | // We only support x64 on Linux. 
35 | return true; 36 | } 37 | //////////////////////////////////////////////////////////////////////////////// 38 | //////////////////////////////////////////////////////////////////////////////// 39 | //////////////////////////////////////////////////////////////////////////////// 40 | //////////////////////////////////////////////////////////////////////////////// 41 | } 42 | -------------------------------------------------------------------------------- /code/solver/FeatureDetector/cpu_x86_Windows.ipp: -------------------------------------------------------------------------------- 1 | /* cpu_x86_Windows.ipp 2 | * 3 | * Author : Alexander J. Yee 4 | * Date Created : 04/12/2014 5 | * Last Modified : 04/12/2014 6 | * 7 | */ 8 | 9 | //////////////////////////////////////////////////////////////////////////////// 10 | //////////////////////////////////////////////////////////////////////////////// 11 | //////////////////////////////////////////////////////////////////////////////// 12 | //////////////////////////////////////////////////////////////////////////////// 13 | // Dependencies 14 | #include 15 | #include 16 | #include "cpu_x86.h" 17 | namespace FeatureDetector{ 18 | //////////////////////////////////////////////////////////////////////////////// 19 | //////////////////////////////////////////////////////////////////////////////// 20 | //////////////////////////////////////////////////////////////////////////////// 21 | //////////////////////////////////////////////////////////////////////////////// 22 | void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx){ 23 | __cpuidex(out, eax, ecx); 24 | } 25 | __int64 xgetbv(unsigned int x){ 26 | return _xgetbv(x); 27 | } 28 | //////////////////////////////////////////////////////////////////////////////// 29 | //////////////////////////////////////////////////////////////////////////////// 30 | // Detect 64-bit - Note that this snippet of code for detecting 64-bit has been copied from MSDN. 
31 | typedef BOOL (WINAPI *LPFN_ISWOW64PROCESS) (HANDLE, PBOOL); 32 | BOOL IsWow64() 33 | { 34 | BOOL bIsWow64 = FALSE; 35 | 36 | LPFN_ISWOW64PROCESS fnIsWow64Process = (LPFN_ISWOW64PROCESS) GetProcAddress( 37 | GetModuleHandle(TEXT("kernel32")), "IsWow64Process"); 38 | 39 | if (NULL != fnIsWow64Process) 40 | { 41 | if (!fnIsWow64Process(GetCurrentProcess(), &bIsWow64)) 42 | { 43 | printf("Error Detecting Operating System.\n"); 44 | printf("Defaulting to 32-bit OS.\n\n"); 45 | bIsWow64 = FALSE; 46 | } 47 | } 48 | return bIsWow64; 49 | } 50 | bool cpu_x86::detect_OS_x64(){ 51 | #ifdef _M_X64 52 | return true; 53 | #else 54 | return IsWow64() != 0; 55 | #endif 56 | } 57 | //////////////////////////////////////////////////////////////////////////////// 58 | //////////////////////////////////////////////////////////////////////////////// 59 | //////////////////////////////////////////////////////////////////////////////// 60 | //////////////////////////////////////////////////////////////////////////////// 61 | } 62 | -------------------------------------------------------------------------------- /code/solver/FeatureDetector/readme.txt: -------------------------------------------------------------------------------- 1 | All files except cpuInfo.cpp come from the open source project 2 | https://github.com/Mysticial/FeatureDetector -------------------------------------------------------------------------------- /code/solver/MatlabSolver.m: -------------------------------------------------------------------------------- 1 | classdef MatlabSolver < handle 2 | properties (GetAccess = public) 3 | % the constraint matrix 4 | A 5 | w = NaN 6 | w_solve = NaN 7 | 8 | % privates 9 | initialized = false 10 | L 11 | precision 12 | usingExact = false 13 | numExact 14 | accuracy 15 | exactSolver 16 | end 17 | 18 | methods 19 | % precision is either double or doubledouble 20 | function o = MatlabSolver(A, precision) 21 | count = symbfact(A,'row'); 22 | n = size(A,1); 23 | if 
sum(count.^2) > n^3 / 100 24 | 25 | A = sparse(A); 26 | H = A * A' + speye(n); 27 | t1 = timeit(@() chol(H, 'lower')); 28 | H = full(H); 29 | t2 = timeit(@() chol(H, 'lower')); 30 | 31 | if t1 > t2 32 | A = full(A); 33 | end 34 | end 35 | 36 | o.A = A; 37 | o.precision = precision; 38 | o.numExact = zeros(2,1); 39 | o.exactSolver = MexSolver(A, 0.0, 0); 40 | end 41 | 42 | function err = estimateAccuracy(o) 43 | k = 4; 44 | JLdir = (rand(size(o.A,1), 2 * k)-0.5)*sqrt(12); 45 | V = (o.A' * (o.L'\JLdir)) .* sqrt(o.w); 46 | err = 0; 47 | for i = 1:k 48 | err = err + (sum(V(:,2*i-1).*V(:,2*i)) - sum(JLdir(:,2*i-1).*JLdir(:,2*i)))^2; 49 | end 50 | err = sqrt(err / k); 51 | end 52 | 53 | function setScale(o, w) 54 | o.initialized = true; 55 | 56 | if ~all(w == o.w) 57 | o.w = w; 58 | o.numExact(2) = o.numExact(2) + 1; 59 | o.usingExact = false; 60 | if (o.precision > 0.0) 61 | H = o.A * (spdiag(w) * o.A'); 62 | [o.L, p] = chol(H, 'lower'); 63 | if p~= 0 64 | o.usingExact = true; 65 | else 66 | o.accuracy = o.estimateAccuracy(); 67 | if (isnan(o.accuracy) || o.accuracy > o.precision) 68 | o.usingExact = true; 69 | end 70 | end 71 | else 72 | o.usingExact = true; 73 | end 74 | 75 | if (o.usingExact) 76 | o.numExact(1) = o.numExact(1) + 1; 77 | o.exactSolver.setScale(w); 78 | end 79 | end 80 | 81 | o.w_solve = w; 82 | end 83 | 84 | function L = diagL(o) 85 | assert(o.initialized); 86 | 87 | if (o.usingExact) 88 | L = o.exactSolver.diagL(); 89 | else 90 | L = diag(o.L); 91 | end 92 | end 93 | 94 | function r = logdet(o) 95 | assert(o.initialized); 96 | 97 | if (o.usingExact) 98 | r = o.exactSolver.logdet(); 99 | else 100 | r = 2 * sum(log(diag(o.L))); 101 | end 102 | end 103 | 104 | % Note that sigma can be negative. 
105 | % We cannot truncate it to >= 0 because the estimate must stay unbiased. 106 | function sigma = leverageScoreComplement(o, nSketch) 107 | assert(o.initialized); 108 | 109 | if (o.usingExact) 110 | if nargin == 1 111 | sigma = o.exactSolver.leverageScoreComplement(); 112 | else 113 | sigma = o.exactSolver.leverageScoreComplement(nSketch); 114 | end 115 | else 116 | % V = full((o.L\(o.A .* sqrt(o.w)'))'); 117 | % sigma = 1-sum(V.^2,2); 118 | if nargin == 1 || nSketch == 0 119 | nSketch = 32; 120 | end 121 | 122 | JLdir = sign(rand(size(o.A,1), nSketch)-0.5); 123 | V = (o.A' * (o.L'\JLdir)) .* sqrt(o.w); 124 | sigma = 1 - sum(V.^2,2) / nSketch; 125 | end 126 | end 127 | 128 | function y = approxSolve(o, b) 129 | assert(o.initialized); 130 | 131 | if (o.usingExact) 132 | y = o.exactSolver.approxSolve(b); 133 | else 134 | y = o.L'\(o.L\b); 135 | end 136 | end 137 | 138 | function y = AwAt(o, b) 139 | assert(o.initialized); 140 | 141 | y = o.A * (o.w_solve .* (o.A' * b)); 142 | end 143 | 144 | function x = solve(o, b, w, x0) 145 | assert(o.initialized); 146 | 147 | if nargin < 4, x0 = zeros(size(b)); end 148 | if nargin >= 3, o.w_solve = w; end 149 | x = batch_pcg(x0, b, o, 1e-12, 20); 150 | end 151 | 152 | function counts = getDecomposeCount(o) 153 | counts = o.numExact; 154 | end 155 | end 156 | end -------------------------------------------------------------------------------- /code/solver/MexSolver.m: -------------------------------------------------------------------------------- 1 | classdef MexSolver < handle 2 | properties 3 | % the constraint matrix 4 | A 5 | w = NaN 6 | w_solve = NaN 7 | precision 8 | 9 | % private 10 | uid 11 | initialized = false 12 | accuracy 13 | solver 14 | k 15 | end 16 | 17 | methods (Static) 18 | function o = loadobj(s) 19 | s.uid = s.solver('init', uint64(randi(2^32-1,'uint32')), s.A); 20 | s.solver('setAccuracyTarget', s.uid, s.precision); 21 | if ~any(isnan(s.w)) 22 | w = s.w; s.w = NaN; 23 | s.setScale(w); 24 | end 25 | o = s; 26 | end 27 | 28 
| function func_name = solverName(simd_len) 29 | if ismac() 30 | [~,result] = system('sysctl -n machdep.cpu.brand_string'); 31 | if contains(result,'Apple') 32 | chip = 'arm'; 33 | else 34 | chip = 'x86'; 35 | end 36 | else 37 | chip = 'x86'; 38 | end 39 | if strcmp(chip, 'arm') 40 | func_name = ['PackedChol' num2str(simd_len) 'arm']; 41 | else 42 | func_name = ['PackedChol' num2str(simd_len)]; 43 | end 44 | 45 | if exist([func_name 'native']) == 3 46 | func_name = [func_name 'native']; 47 | elseif strcmp(chip, 'x86') 48 | s = cpuInfo(); 49 | if (~(s.AVX2 && s.OS_AVX && s.FMA)) 50 | func_name = [func_name 'SSE']; 51 | end 52 | end 53 | end 54 | end 55 | 56 | methods 57 | % precision is either double or doubledouble 58 | function o = MexSolver(A, precision, k) 59 | o.A = A; 60 | o.k = k; 61 | o.solver = str2func(MexSolver.solverName(k)); 62 | o.uid = o.solver('init', uint64(randi(2^32-1,'uint32')), A); 63 | o.solver('setAccuracyTarget', o.uid, precision); 64 | o.precision = precision; 65 | 66 | if size(A, 2) == 0 67 | if o.k == 0 68 | o.w = zeros(0, 1); 69 | else 70 | o.w = zeros(o.k, 0); 71 | end 72 | o.accuracy = o.solver('decompose', o.uid, o.w); 73 | end 74 | end 75 | 76 | function b = saveobj(a) 77 | b = a; 78 | b.uid = []; 79 | end 80 | 81 | function delete(o) 82 | if ~isempty(o.uid) 83 | o.solver('delete', o.uid); 84 | end 85 | end 86 | 87 | function setScale(o, w) 88 | o.initialized = true; 89 | if ~all(w == o.w, 'all') 90 | o.accuracy = o.solver('decompose', o.uid, w); 91 | o.w = w; 92 | end 93 | 94 | o.w_solve = w; 95 | end 96 | 97 | function r = diagL(o) 98 | assert(o.initialized); 99 | 100 | r = o.solver('diagL', o.uid); 101 | end 102 | 103 | function r = logdet(o) 104 | assert(o.initialized); 105 | 106 | r = o.solver('logdet', o.uid); 107 | end 108 | 109 | % Note that sigma can be negative. 
110 | % We cannot truncate it to >= 0 because the estimate must stay unbiased. 111 | function sigma = leverageScoreComplement(o, nSketch) 112 | assert(o.initialized); 113 | 114 | if nargin == 1, nSketch = 0; end 115 | sigma = o.solver('leverageScoreComplement', o.uid, nSketch); 116 | end 117 | 118 | function counts = getDecomposeCount(o) 119 | counts = o.solver('getDecomposeCount', o.uid); 120 | end 121 | 122 | function y2 = approxSolve(o, b) 123 | assert(o.initialized); 124 | 125 | y2 = o.solver('solve', o.uid, b); 126 | end 127 | 128 | function x = solve(o, b, w, x0) 129 | assert(o.initialized); 130 | 131 | if nargin < 4, x0 = zeros(size(b)); end 132 | if nargin >= 3, o.w_solve = w; end 133 | x = batch_pcg(x0, b, o, 1e-12, 20); 134 | end 135 | 136 | % used only in batch_pcg 137 | function y = AwAt(o, b) 138 | assert(o.initialized); 139 | if (o.k == 0) 140 | y = o.A * (o.w_solve .* (o.A' * b)); 141 | else 142 | y = ((b * o.A) .* o.w_solve) * o.A'; 143 | end 144 | end 145 | end 146 | end -------------------------------------------------------------------------------- /code/solver/MultiMatlabSolver.m: -------------------------------------------------------------------------------- 1 | classdef MultiMatlabSolver < handle 2 | properties (GetAccess = public) 3 | solvers 4 | k 5 | m 6 | n 7 | w 8 | accuracy 9 | end 10 | 11 | methods 12 | function o = MultiMatlabSolver(A, precision, k) 13 | o.k = k; 14 | o.m = size(A,2); 15 | o.n = size(A,1); 16 | for i = 1:o.k 17 | o.solvers{i} = MexSolver(A, precision, 0); 18 | end 19 | end 20 | 21 | function setScale(o, w) 22 | o.w = w; 23 | o.accuracy = zeros(o.k, 1); 24 | for i = 1:o.k 25 | o.solvers{i}.setScale(reshape(w(i,:), [o.m, 1])); 26 | o.accuracy(i) = o.solvers{i}.accuracy; 27 | end 28 | end 29 | 30 | function L = diagL(o) 31 | L = zeros(o.k, o.n); 32 | for i = 1:o.k 33 | L(i, :) = o.solvers{i}.diagL(); 34 | end 35 | end 36 | 37 | function r = logdet(o) 38 | r = zeros(o.k, 1); 39 | for i = 1:o.k 40 | r(i) = o.solvers{i}.logdet(); 41 | end 42 
| end 43 | 44 | % Note that sigma can be negative. 45 | % We cannot truncate it to >= 0 because the estimate must stay unbiased. 46 | function sigma = leverageScoreComplement(o, nSketch) 47 | if nargin == 1, nSketch = 0; end 48 | 49 | sigma = zeros(o.k, o.m); 50 | for i = 1:o.k 51 | sigma(i,:) = o.solvers{i}.leverageScoreComplement(nSketch); 52 | end 53 | end 54 | 55 | function y = approxSolve(o, b) 56 | y = zeros(o.k, o.n); 57 | for i = 1:o.k 58 | y(i,:) = o.solvers{i}.approxSolve(b(i,:)'); 59 | end 60 | end 61 | 62 | function x = solve(o, b, w, x0) 63 | if nargin < 4 64 | x0 = zeros([o.k, o.n]); 65 | end 66 | 67 | if nargin < 3, w = o.w; end 68 | 69 | x = zeros(o.k, o.n); 70 | for i = 1:o.k 71 | x(i,:,:) = o.solvers{i}.solve(b(i,:)', w(i,:)', x0(i,:)'); 72 | end 73 | end 74 | 75 | function counts = getDecomposeCount(o) 76 | counts = zeros(o.k+1, 1); 77 | for i = 1:o.k 78 | c = o.solvers{i}.getDecomposeCount(); 79 | counts(i) = c(1); 80 | counts(end) = c(2); 81 | end 82 | end 83 | end 84 | end -------------------------------------------------------------------------------- /code/solver/PackedCSparse/SparseMatrix.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "FloatArray.h" 5 | 6 | namespace PackedCSparse { 7 | static void pcs_assert(bool value, const char* message) 8 | { 9 | if (value == false) 10 | throw std::logic_error(message); 11 | } 12 | 13 | template 14 | T* pcs_aligned_new(size_t size) 15 | { 16 | int alignment = 64; // size of memory cache line 17 | int offset = alignment - 1 + sizeof(void*); 18 | void* p1 = (void*)new char[size * sizeof(T) + offset]; 19 | void** p2 = (void**)(((size_t)(p1)+offset) & ~(alignment - 1)); 20 | p2[-1] = p1; 21 | return (T*)p2; 22 | } 23 | 24 | template 25 | struct AlignedDeleter 26 | { 27 | void operator()(T* p) const 28 | { 29 | delete[](char*)(((void**)p)[-1]); 30 | } 31 | }; 32 | 33 | template 34 | using UniqueAlignedPtr = std::unique_ptr>; 35 | 36 | 
template 37 | using UniquePtr = std::unique_ptr; 38 | 39 | // Tx = Type for entries, Ti = Type for indices. 40 | // if Tx == bool, the matrix stores only sparsity information 41 | template 42 | struct SparseMatrix 43 | { 44 | Ti m = 0; /* number of rows */ 45 | Ti n = 0; /* number of columns */ 46 | UniquePtr p; /* column pointers (size n+1) */ 47 | UniquePtr i; /* row indices, size nnz */ 48 | UniqueAlignedPtr x; /* numerical values, size nnz */ 49 | 50 | SparseMatrix() = default; 51 | 52 | SparseMatrix(Ti m_, Ti n_, Ti nzmax_) 53 | { 54 | initialize(m_, n_, nzmax_); 55 | } 56 | 57 | bool initialized() const 58 | { 59 | return p && i; 60 | } 61 | 62 | void initialize(Ti m_, Ti n_, Ti nzmax) 63 | { 64 | if (nzmax < 1) nzmax = 1; 65 | m = m_; n = n_; 66 | p.reset(new Ti[n + 1]); 67 | i.reset(new Ti[nzmax]); 68 | if (!std::is_same::value) 69 | x.reset(pcs_aligned_new(nzmax)); 70 | } 71 | 72 | Ti nnz() const 73 | { 74 | return p[n]; 75 | } 76 | 77 | template 78 | SparseMatrix clone() const 79 | { 80 | SparseMatrix C(m, n, nnz()); 81 | Ti* Ap = p.get(), * Ai = i.get(); Tx* Ax = x.get(); 82 | Ti2* Cp = C.p.get(), * Ci = C.i.get(); Tx2* Cx = C.x.get(); 83 | 84 | for (Ti s = 0; s <= n; s++) 85 | Cp[s] = Ti2(Ap[s]); 86 | 87 | Ti nz = nnz(); 88 | for (Ti s = 0; s < nz; s++) 89 | Ci[s] = Ti2(Ai[s]); 90 | 91 | if (Cx) 92 | { 93 | for (Ti s = 0; s < nz; s++) 94 | Cx[s] = Ax? 
Tx2(Ax[s]): Tx2(1.0); 95 | } 96 | 97 | return C; 98 | } 99 | }; 100 | 101 | template 102 | struct DenseVector 103 | { 104 | Ti n = 0; /* number of columns */ 105 | UniqueAlignedPtr x; /* numerical values, size nnz */ 106 | 107 | DenseVector() = default; 108 | 109 | DenseVector(Ti n_) 110 | { 111 | initialize(n_); 112 | } 113 | 114 | bool initialized() const 115 | { 116 | return bool(x); 117 | } 118 | 119 | void initialize(Ti n_) 120 | { 121 | n = n_; 122 | x.reset(pcs_aligned_new(n_)); 123 | } 124 | }; 125 | 126 | 127 | // basic functions 128 | template 129 | SparseMatrix speye(Ti n, Tx* d = nullptr) 130 | { 131 | SparseMatrix D(n, n, n); 132 | 133 | for (Ti k = 0; k < n; k++) 134 | { 135 | D.i[k] = k; 136 | D.p[k] = k; 137 | } 138 | D.p[n] = n; 139 | 140 | Tx Tx1 = Tx(1.0); 141 | for (Ti k = 0; k < n; k++) 142 | D.x[k] = (d ? d[k] : Tx1); 143 | return D; 144 | } 145 | 146 | // Solve L out = x 147 | // Input: L in Tx^{n by n}, x in Tx2^{n} 148 | // Output: out in Tx2^{n}. 149 | // If out is provided, we will output to out. Else, output to x. 150 | template 151 | void lsolve(const SparseMatrix& L, Tx2* x, Tx2* out = nullptr) 152 | { 153 | pcs_assert(L.initialized(), "lsolve: bad inputs."); 154 | pcs_assert(L.n == L.m, "lsolve: dimensions mismatch."); 155 | 156 | Ti n = L.n, * Lp = L.p.get(), * Li = L.i.get(); Tx* Lx = L.x.get(); 157 | 158 | if (!out) out = x; 159 | if (x != out) std::copy(x, x + n, out); 160 | 161 | for (Ti j = 0; j < n; j++) 162 | { 163 | Tx2 out_j = out[j] / Lx[Lp[j]]; 164 | out[j] = out_j; 165 | 166 | Ti p_start = Lp[j] + 1, p_end = Lp[j + 1]; 167 | for (Ti p = p_start; p < p_end; p++) 168 | { //out[Li[p]] -= Lx[p] * out[j]; 169 | fnmadd(out[Li[p]], out_j, Lx[p]); 170 | } 171 | } 172 | } 173 | 174 | // Solve L' out = x 175 | // Input: L in Tx^{n by n}, x in Tx2^{n} 176 | // Output: out in Tx2^{n}. 177 | // If out is provided, we will output to out. Else, output to x. 
178 | template 179 | void ltsolve(const SparseMatrix& L, Tx2* x, Tx2* out = nullptr) 180 | { 181 | pcs_assert(L.initialized(), "ltsolve: bad inputs."); 182 | pcs_assert(L.n == L.m, "ltsolve: dimensions mismatch."); 183 | 184 | Ti n = L.n, * Lp = L.p.get(), * Li = L.i.get(); Tx* Lx = L.x.get(); 185 | 186 | if (!out) out = x; 187 | if (x != out) std::copy(x, x + n, out); 188 | 189 | for (Ti j = n - 1; j != -1; j--) 190 | { 191 | Tx2 out_j = out[j]; 192 | 193 | Ti p_start = Lp[j] + 1, p_end = Lp[j + 1]; 194 | for (Ti p = p_start; p < p_end; p++) 195 | { //out[j] -= Lx[p] * out[Li[p]]; 196 | fnmadd(out_j, out[Li[p]], Lx[p]); 197 | } 198 | 199 | out[j] = out_j / Tx2(Lx[Lp[j]]); 200 | } 201 | } 202 | 203 | // Update y <-- y + A x 204 | // Input: A in Tx^{m by n}, x in Tx2^{n}, y in Tx2^{m} 205 | template 206 | void gaxpy(const SparseMatrix& A, const Tx2* x, Tx2* y) 207 | { 208 | pcs_assert(A.initialized(), "gaxpy: bad inputs."); 209 | Ti m = A.m, n = A.n, * Ap = A.p.get(), * Ai = A.i.get(); Tx* Ax = A.x.get(); 210 | 211 | for (Ti j = 0; j < n; j++) 212 | { 213 | Tx2 x_j = x[j]; 214 | 215 | Ti p_start = Ap[j], p_end = Ap[j + 1]; 216 | for (Ti p = p_start; p < p_end; p++) 217 | { //y[Ai[p]] += Ax[p] * x[j]; 218 | fmadd(y[Ai[p]], x_j, Ax[p]); 219 | } 220 | } 221 | } 222 | }; 223 | -------------------------------------------------------------------------------- /code/solver/PackedCSparse/add.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "SparseMatrix.h" 4 | 5 | // Problem: 6 | // Compute M = A + B 7 | 8 | // Algorithm: 9 | // M = 0 10 | // M(A != 0) += A(A != 0) 11 | // M(B != 0) += B(B != 0) 12 | 13 | namespace PackedCSparse { 14 | template 15 | struct AddOutput : SparseMatrix 16 | { 17 | UniquePtr forwardA; 18 | UniquePtr forwardB; 19 | 20 | template 21 | void initialize(const SparseMatrix& A, const SparseMatrix& B) 22 | { 23 | pcs_assert(A.initialized() && B.initialized(), "add: bad inputs."); 24 | 
pcs_assert(A.n == B.n && A.m == B.m, "add: dimensions mismatch."); 25 | 26 | Ti m = A.m, n = A.n; 27 | Ti Anz = A.nnz(); Ti* Ap = A.p.get(), * Ai = A.i.get(); 28 | Ti Bnz = B.nnz(); Ti* Bp = B.p.get(), * Bi = B.i.get(); 29 | this->m = A.m; this->n = A.n; 30 | 31 | std::vector Ci; 32 | Ti* Cp = new Ti[n + 1]; 33 | forwardA.reset(new Ti[Anz]); 34 | forwardB.reset(new Ti[Bnz]); 35 | 36 | Cp[0] = 0; 37 | for (Ti i = 0; i < n; i++) 38 | { 39 | Ti s1 = Ap[i], s2 = Bp[i], end1 = Ap[i + 1], end2 = Bp[i + 1]; 40 | while ((s1 < end1) || (s2 < end2)) 41 | { 42 | Ti q = Ti(Ci.size()); 43 | Ti i1 = (s1 < end1) ? Ai[s1] : m; 44 | Ti i2 = (s2 < end2) ? Bi[s2] : m; 45 | Ti min_i = std::min(i1, i2); 46 | Ci.push_back(min_i); 47 | 48 | if (i1 == min_i) 49 | forwardA[s1++] = q; 50 | 51 | if (i2 == min_i) 52 | forwardB[s2++] = q; 53 | } 54 | Cp[i + 1] = Ti(Ci.size()); 55 | } 56 | 57 | this->p.reset(Cp); 58 | this->i.reset(new Ti[Ci.size()]); 59 | this->x.reset(pcs_aligned_new(Ci.size())); 60 | std::copy(Ci.begin(), Ci.end(), this->i.get()); 61 | } 62 | }; 63 | 64 | template 65 | void add(AddOutput& o, const SparseMatrix& A, const SparseMatrix& B) 66 | { 67 | if (!o.initialized()) 68 | o.initialize(A, B); 69 | 70 | Ti m = o.m, n = o.n; 71 | Ti Anz = A.nnz(); Ti* Ap = A.p.get(), * Ai = A.i.get(); Tx* Ax = A.x.get(); 72 | Ti Bnz = B.nnz(); Ti* Bp = B.p.get(), * Bi = B.i.get(); Tx* Bx = B.x.get(); 73 | Ti Cnz = o.nnz(); Ti* Cp = o.p.get(), * Ci = o.i.get(); Tx2* Cx = o.x.get(); 74 | Ti* forwardA = o.forwardA.get(), *forwardB = o.forwardB.get(); 75 | 76 | for (Ti s = 0; s < Cnz; s++) 77 | Cx[s] = 0; 78 | 79 | for (Ti s = 0; s < Anz; s++) 80 | Cx[forwardA[s]] = Ax[s]; 81 | 82 | for (Ti s = 0; s < Bnz; s++) 83 | Cx[forwardB[s]] += Bx[s]; 84 | } 85 | 86 | template 87 | AddOutput add(const SparseMatrix& A, const SparseMatrix& B) 88 | { 89 | AddOutput o; 90 | add(o, A, B); 91 | return o; 92 | } 93 | } -------------------------------------------------------------------------------- 
/code/solver/PackedCSparse/chol.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "SparseMatrix.h" 5 | #include "transpose.h" 6 | 7 | // Problem: 8 | // Compute chol(A) 9 | 10 | // Algorithm: 11 | // We need to study this later as this is the bottleneck. 12 | // Document it as a lyx. 13 | // chol_up_looking: 14 | // Compute L row by row 15 | // This is faster when it is compute bound. 16 | // 17 | // chol_left_looking: 18 | // Compute L col by col 19 | // This is faster when it is memory bound. 20 | 21 | 22 | namespace PackedCSparse { 23 | template 24 | struct CholOutput : SparseMatrix 25 | { 26 | TransposeOutput Lt; // sparsity pattern of the Lt 27 | UniquePtr diag; // the index for diagonal element. Ax[diag[k]] is A_kk 28 | UniquePtr c; // c[i] = index the last nonzero on column i in the current L 29 | UniqueAlignedPtr w; // the row of L we are computing 30 | 31 | // The cost of this is roughly 3 times larger than chol 32 | // One can optimize it by using other data structure 33 | void initialize(const SparseMatrix& A) 34 | { 35 | pcs_assert(A.initialized(), "chol: bad inputs."); 36 | pcs_assert(A.n == A.m, "chol: dimensions mismatch."); 37 | 38 | Ti n = A.n, * Ap = A.p.get(), * Ai = A.i.get(); 39 | 40 | // initialize 41 | this->diag.reset(new Ti[n]); 42 | this->c.reset(new Ti[n]); 43 | this->w.reset(pcs_aligned_new(n)); 44 | 45 | // compute the sparsity pattern of L and diag 46 | using queue = std::priority_queue, std::greater>; 47 | queue q; // sol'n of the current row of L 48 | Ti* mark = new Ti[n]; // used to prevent same indices push to q twice 49 | std::vector* cols = new std::vector[n]; // stores the non-zeros of each col of L 50 | Ti nz = 0, Anz = Ap[n]; 51 | 52 | for (Ti i = 0; i < n; i++) 53 | mark[i] = -1; 54 | 55 | // for each row of A 56 | for (Ti i = 0; i < n; i++) 57 | { // push i-th row of A, called a_12, into mark 58 | Ti s; 59 | for (s = Ap[i]; s < Ap[i + 
1]; s++) 60 | { 61 | Ti j = Ai[s]; 62 | if (j >= i) break; 63 | 64 | q.push(j); 65 | mark[j] = i; 66 | } 67 | if (s >= Anz) // this case happens only if the diagonal is 0. No cholesky in this case. 68 | this->diag[i] = 0; 69 | else 70 | this->diag[i] = s; 71 | 72 | // Solve L_11 l_12 = a_12 73 | while (!q.empty()) 74 | { 75 | Ti j = q.top(); 76 | 77 | for (Ti k : cols[j]) 78 | { 79 | if (mark[k] != i) 80 | { 81 | q.push(k); 82 | mark[k] = i; 83 | } 84 | } 85 | q.pop(); 86 | 87 | // update j col 88 | cols[j].push_back(i); 89 | ++nz; 90 | } 91 | 92 | // diag 93 | cols[i].push_back(i); 94 | ++nz; 95 | } 96 | delete[] mark; 97 | 98 | // write it as the compress form 99 | SparseMatrix::initialize(n, n, nz); 100 | 101 | Ti s_start = 0; Ti s = 0; 102 | for (Ti i = 0; i < n; i++) 103 | { 104 | this->p[i] = s_start; 105 | for (Ti k : cols[i]) 106 | this->i[s++] = k; 107 | s_start += Ti(cols[i].size()); 108 | } 109 | this->p[n] = s_start; 110 | delete[] cols; 111 | 112 | this->Lt = transpose(*this); 113 | 114 | // initialize w to 0 115 | Tx Tv0 = Tx(0); 116 | for (Ti k = 0; k < n; k++) 117 | w[k] = Tv0; 118 | } 119 | }; 120 | 121 | template 122 | void chol(CholOutput& o, const SparseMatrix& A) 123 | { 124 | if (!o.initialized()) 125 | o.initialize(A); 126 | 127 | //chol_up_looking(o, A); 128 | chol_left_looking(o, A); 129 | } 130 | 131 | template 132 | void chol_up_looking(CholOutput& o, const SparseMatrix& A) 133 | { 134 | Ti *Ap = A.p.get(), * Ai = A.i.get(); Tx* Ax = A.x.get(); 135 | Ti nzmax = o.nzmax; Ti n = A.n; 136 | Ti *Lp = o.p.get(); Ti* Li = o.i.get(); 137 | Ti *Ltp = o.Lt.p.get(); Ti* Lti = o.Lt.i.get(); 138 | 139 | Tx T0 = Tx(0); 140 | Tx* Lx = o.x.get(); Tx* w = o.w.get(); Ti* c = o.c.get(); 141 | Ti* diag = o.diag.get(); 142 | 143 | Ti* Lti_ptr = Lti; 144 | for (Ti k = 0; k < n; ++k) 145 | { 146 | c[k] = Lp[k]; 147 | 148 | Ti s_end = diag[k]; 149 | for (Ti s = Ap[k]; s < s_end; ++s) 150 | w[Ai[s]] = Ax[s]; 151 | 152 | // Solve L_11 l_12 = a_12 153 | Tx d = 
Ax[s_end]; Ti i; 154 | for (; (i = *(Lti_ptr++)) < k;) 155 | { 156 | Ti dLi = Lp[i], ci = c[i]++; 157 | Tx Lki = w[i] / Lx[dLi]; 158 | w[i] = T0; // maintain x = 0 for the (k+1) iteration 159 | 160 | for (Ti q = dLi + 1; q < ci; ++q) 161 | fnmadd(w[Li[q]], Lx[q], Lki); 162 | 163 | d -= Lki * Lki; 164 | Lx[ci] = Lki; 165 | } 166 | 167 | // l_22 = sqrt(a22 - ) 168 | Lx[c[k]++] = clipped_sqrt(d); 169 | } 170 | } 171 | 172 | template 173 | void chol_left_looking(CholOutput& o, const SparseMatrix& A) 174 | { 175 | Ti* Ap = A.p.get(), * Ai = A.i.get(); Tx* Ax = A.x.get(); 176 | Ti nzmax = o.nnz(); Ti n = A.n; 177 | Ti* Lp = o.p.get(); Ti* Li = o.i.get(); 178 | Ti* Ltp = o.Lt.p.get(); Ti* Lti = o.Lt.i.get(); 179 | 180 | Tx T0 = Tx(0), T1 = Tx(1); 181 | Tx* Lx = o.x.get(); 182 | Tx* w = o.w.get(); Ti* c = o.c.get(); 183 | Ti* diag = o.diag.get(); 184 | 185 | for (Ti j = 0; j < n; ++j) 186 | { 187 | c[j] = Lp[j]; 188 | 189 | // x = A_{j:n, j} 190 | { 191 | Ti is_start = diag[j], is_end = Ap[j + 1]; 192 | for (Ti is = is_start; is < is_end; ++is) 193 | w[Ai[is]] = Ax[is]; 194 | } 195 | 196 | // for each p in L_{j, 1:j-1} 197 | Ti ps_start = Ltp[j], ps_end = Ltp[j + 1] - 1; 198 | for (Ti ps = ps_start; ps < ps_end; ++ps) 199 | { 200 | Ti p = Lti[ps]; 201 | Ti cp = c[p]++; 202 | Tx Ljp = Lx[cp]; 203 | 204 | // for each i in L_{j:n,p} 205 | Ti is_start = cp, is_end = Lp[p + 1]; 206 | for (Ti is = is_start; is < is_end; ++is) 207 | { 208 | Ti i = Li[is]; 209 | fnmadd(w[i], Lx[is], Ljp); 210 | } 211 | } 212 | 213 | Tx Ljj = clipped_sqrt(w[j], 1e128); 214 | Lx[c[j]++] = Ljj; 215 | Tx inv_Ljj = T1 / Ljj; 216 | w[j] = T0; 217 | 218 | // for each i in L_{:,j} 219 | { 220 | Ti is_start = Lp[j] + 1, is_end = Lp[j + 1]; 221 | for (Ti is = is_start; is < is_end; ++is) 222 | { 223 | Ti i = Li[is]; 224 | Lx[is] = w[i] * inv_Ljj; 225 | w[i] = T0; 226 | } 227 | } 228 | } 229 | } 230 | 231 | template 232 | CholOutput chol(const SparseMatrix& A) 233 | { 234 | CholOutput o; 235 | chol(o, A); 
236 | return o; 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /code/solver/PackedCSparse/leverage.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "SparseMatrix.h" 3 | #include "projinv.h" 4 | #include "outerprod.h" 5 | 6 | // Problem: 7 | // Compute M = diag(A' inv(LL') A) 8 | 9 | namespace PackedCSparse { 10 | template 11 | struct LeverageOutput : DenseVector 12 | { 13 | ProjinvOutput Hinv; // Hinv = inv(H)|_L 14 | OuterprodOutput tau; // tau = diag(A' Hinv A) 15 | 16 | template 17 | void initialize(const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At) 18 | { 19 | pcs_assert(L.initialized() && A.initialized() && At.initialized(), "leverage: bad inputs."); 20 | pcs_assert(L.m == L.n && L.n == A.m && L.n == At.n && A.n == At.m, "leverage: dimensions mismatch."); 21 | DenseVector::initialize(A.n); 22 | Hinv.initialize(L); 23 | tau.initialize(A, Hinv, At); 24 | } 25 | }; 26 | 27 | template 28 | void leverage(LeverageOutput& o, const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At) 29 | { 30 | if (!o.initialized()) 31 | o.initialize(L, A, At); 32 | 33 | Tx T1 = Tx(1.0), T2 = Tx(2.0); 34 | projinv(o.Hinv, L); 35 | 36 | Ti m = A.m, n = A.n; 37 | Ti* Sp = o.Hinv.p.get(); Tx* Sv = o.Hinv.x.get(); 38 | for (Ti k = 0; k < m; ++k) 39 | Sv[Sp[k]] /= T2; 40 | 41 | outerprod(o.tau, A, o.Hinv, At); 42 | 43 | Tx* x = o.x.get(), * tau = o.tau.x.get(); 44 | for (Ti j = 0; j < n; j++) 45 | x[j] = T2 * tau[j]; 46 | } 47 | 48 | 49 | template 50 | LeverageOutput leverage(const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At) 51 | { 52 | LeverageOutput o; 53 | leverage(o, L, A, At); 54 | return o; 55 | } 56 | } -------------------------------------------------------------------------------- /code/solver/PackedCSparse/leverageJL.h: -------------------------------------------------------------------------------- 1 | #pragma 
once 2 | #include 3 | #include "SparseMatrix.h" 4 | 5 | // Problem: 6 | // Approximate M = diag(A' inv(LL') A) 7 | namespace PackedCSparse { 8 | const size_t JLPackedSize = 4; 9 | 10 | template 11 | struct LeverageJLOutput : DenseVector 12 | { 13 | UniqueAlignedPtr d; // random direction d 14 | UniqueAlignedPtr L_d; // L^{-T} d 15 | UniqueAlignedPtr AtL_d; // A' L^{-T} d 16 | Ti m = 0; 17 | std::mt19937_64 gen; 18 | 19 | template 20 | void initialize(const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At) 21 | { 22 | pcs_assert(L.initialized() && A.initialized() && At.initialized(), "leverageJL: bad inputs."); 23 | pcs_assert(L.m == L.n && L.n == A.m && L.n == At.n && A.n == At.m, "leverageJL: dimensions mismatch."); 24 | this->n = A.n; this->m = A.m; 25 | this->x.reset(pcs_aligned_new(this->n)); 26 | this->d.reset(pcs_aligned_new(this->m * 2 * JLPackedSize)); 27 | this->L_d.reset(pcs_aligned_new(this->m * 2 * JLPackedSize)); 28 | this->AtL_d.reset(pcs_aligned_new(this->n * 2 * JLPackedSize)); 29 | } 30 | }; 31 | 32 | // compute sum_{j=1}^{k} (A' L^{-T} u_j) .* (A' L^{-T} u_j) 33 | template 34 | void projectionJL(LeverageJLOutput& o, const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At) 35 | { 36 | Ti m = A.m, n = A.n; 37 | Tx T0 = Tx(0.0), T1 = Tx(1.0); 38 | Tx* d = o.d.get(), * L_d = o.L_d.get(), * AtL_d = o.AtL_d.get(), * x = o.x.get(); 39 | 40 | for (Ti i = 0; i < m * k; i++) 41 | d[i] = sign(o.gen); 42 | 43 | for (Ti i = 0; i < n * k; i++) 44 | AtL_d[i] = T0; 45 | 46 | ltsolve(L, (BaseImpl*)d, (BaseImpl*)L_d); 47 | gaxpy(At, (BaseImpl*)L_d, (BaseImpl*)AtL_d); 48 | 49 | for (Ti i = 0; i < n; i++) 50 | { 51 | Tx ret_i = T0; 52 | for (Ti j = 0; j < k; j++) 53 | ret_i += AtL_d[i * k + j] * AtL_d[i * k + j]; 54 | 55 | x[i] += ret_i; 56 | } 57 | } 58 | 59 | template 60 | void leverageJL(LeverageJLOutput& o, const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At, size_t k) 61 | { 62 | if (!o.initialized()) 63 | 
o.initialize(L, A, At); 64 | 65 | Ti n = A.n; Tx* x = o.x.get(); 66 | for (Ti i = 0; i < n; i++) 67 | x[i] = Tx(0.0); 68 | 69 | constexpr size_t k_step = JLPackedSize; 70 | for(size_t i = 1; i <= k / k_step; i++) 71 | projectionJL(o, L, A, At); 72 | 73 | for (size_t i = 1; i <= k % k_step; i++) 74 | projectionJL<1>(o, L, A, At); 75 | 76 | Tx ratio = Tx(1 / double(k)); 77 | for (Ti i = 0; i < n; i++) 78 | x[i] *= ratio; 79 | } 80 | 81 | template 82 | LeverageJLOutput leverageJL(const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At, size_t k) 83 | { 84 | LeverageJLOutput o; 85 | leverageJL(o, L, A, At, k); 86 | return o; 87 | } 88 | 89 | 90 | // compute (A' L^{-T} u_j) .* (A' L^{-T} v_j) for j = 1, 2, ... k 91 | template 92 | Tx cholAccuracy(LeverageJLOutput& o, const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& At, const Tx* w) 93 | { 94 | if (!o.initialized()) 95 | o.initialize(L, A, At); 96 | 97 | constexpr Ti k = JLPackedSize; 98 | constexpr Ti k_ = 2 * k; 99 | 100 | 101 | Ti m = A.m, n = A.n; 102 | Tx T0 = Tx(0.0), T1 = Tx(1.0); 103 | Tx* d = o.d.get(), * L_d = o.L_d.get(), * AtL_d = o.AtL_d.get(), * x = o.x.get(); 104 | 105 | std::uniform_real_distribution distribution(-sqrt(3.0),sqrt(3.0)); 106 | for (Ti i = 0; i < m * k_; i++) 107 | d[i] = Tx(distribution(o.gen)); // roughly uniform distribution with variance 1 108 | 109 | for (Ti i = 0; i < n * k_; i++) 110 | AtL_d[i] = T0; 111 | 112 | ltsolve(L, (BaseImpl*)d, (BaseImpl*)L_d); 113 | gaxpy(At, (BaseImpl*)L_d, (BaseImpl*)AtL_d); 114 | 115 | Tx result[k]; 116 | for (Ti j = 0; j < k; j++) 117 | result[j] = Tx(0.0); 118 | 119 | for (Ti i = 0; i < m; i++) 120 | { 121 | Tx* d = o.d.get() + i * (2 * k); 122 | for (Ti j = 0; j < k; j++) 123 | result[j] -= d[j] * d[j + k]; 124 | } 125 | 126 | for (Ti i = 0; i < n; i++) 127 | { 128 | Tx w_i = w[i]; 129 | for (Ti j = 0; j < k; j++) 130 | result[j] += AtL_d[i * k_ + j] * AtL_d[i * k_ + j + k] * w_i; 131 | } 132 | 133 | Tx est = Tx(0.0); 134 
| for (Ti j = 0; j < k; j++) 135 | est += result[j] * result[j]; 136 | 137 | return clipped_sqrt(est/Tx(double(k)), 0.0); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /code/solver/PackedCSparse/multiply.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "SparseMatrix.h" 5 | 6 | // Problem: 7 | // Compute M = A diag(w) B 8 | 9 | // Algorithm: 10 | // Compute M col by col 11 | 12 | namespace PackedCSparse { 13 | template 14 | struct MultiplyOutput : SparseMatrix 15 | { 16 | UniqueAlignedPtr c; 17 | 18 | template 19 | void initialize(const SparseMatrix& A, const SparseMatrix& B) 20 | { 21 | pcs_assert(A.initialized() && B.initialized(), "multiply: bad inputs."); 22 | pcs_assert(A.n == B.m, "multiply: dimensions mismatch."); 23 | 24 | Ti m = A.m, n = B.n; 25 | Ti* Ap = A.p.get(), * Ai = A.i.get(); 26 | Ti* Bp = B.p.get(), * Bi = B.i.get(); 27 | 28 | this->c.reset(pcs_aligned_new(m)); 29 | 30 | Ti* last_j = new Ti[m]; 31 | for (Ti i = 0; i < m; i++) 32 | { 33 | last_j[i] = -1; 34 | this->c[i] = Tx(0.0); 35 | } 36 | 37 | Ti* Cp = new Ti[size_t(n)+1]; 38 | std::vector Ci; 39 | 40 | Cp[0] = 0; 41 | for (Ti j1 = 0; j1 < n; j1++) 42 | { 43 | for (Ti p1 = Bp[j1]; p1 < Bp[j1 + 1]; p1++) 44 | { 45 | Ti j2 = Bi[p1]; 46 | for (Ti p2 = Ap[j2]; p2 < Ap[j2 + 1]; p2++) 47 | { 48 | Ti i = Ai[p2]; 49 | if (last_j[i] != j1) 50 | { 51 | last_j[i] = j1; 52 | Ci.push_back(i); 53 | } 54 | } 55 | } 56 | Cp[j1 + 1] = Ti(Ci.size()); 57 | } 58 | delete[] last_j; 59 | 60 | for (Ti j = 0; j < n; j++) 61 | std::sort(Ci.begin() + Cp[j], Ci.begin() + Cp[j + 1]); 62 | 63 | this->m = m; this->n = n; 64 | this->x.reset(pcs_aligned_new(Ci.size())); 65 | this->p.reset(Cp); 66 | this->i.reset(new Ti[Ci.size()]); 67 | std::copy(Ci.begin(), Ci.end(), this->i.get()); 68 | } 69 | }; 70 | 71 | template 72 | void multiply_general(MultiplyOutput& o, const 
SparseMatrix& A, const Tx2* w, const SparseMatrix& B) 73 | { 74 | if (!o.initialized()) 75 | o.initialize(A, B); 76 | 77 | Ti m = o.m, n = o.n; 78 | Ti* Ap = A.p.get(), * Ai = A.i.get(); Tx* Ax = A.x.get(); 79 | Ti* Bp = B.p.get(), * Bi = B.i.get(); Tx* Bx = B.x.get(); 80 | Ti* Cp = o.p.get(), * Ci = o.i.get(); Tx2* Cx = o.x.get(); 81 | Tx2* c = o.c.get(); // initialized to 0 82 | 83 | const Tx2 T0 = Tx2(0); 84 | for (Ti j1 = 0; j1 < n; j1++) 85 | { 86 | for (Ti p1 = Bp[j1]; p1 < Bp[j1 + 1]; p1++) 87 | { 88 | Ti j2 = Bi[p1]; 89 | Tx2 beta = has_weight? (Tx2(Bx[p1]) * w[j2]) : Tx2(Bx[p1]); 90 | 91 | for (Ti p2 = Ap[j2]; p2 < Ap[j2 + 1]; p2++) 92 | { 93 | //x[Ai[p2]] += beta * Ax[p2]; 94 | fmadd(c[Ai[p2]], beta, Ax[p2]); 95 | } 96 | } 97 | 98 | for (Ti p1 = Cp[j1]; p1 < Cp[j1 + 1]; p1++) 99 | { 100 | Cx[p1] = c[Ci[p1]]; 101 | c[Ci[p1]] = T0; // ensure c is 0 after the call 102 | } 103 | } 104 | } 105 | 106 | template 107 | void multiply(MultiplyOutput& o, const SparseMatrix& A, const SparseMatrix& B) 108 | { 109 | multiply_general(o, A, nullptr, B); 110 | } 111 | 112 | template 113 | void multiply(MultiplyOutput& o, const SparseMatrix& A, const Tx2* w, const SparseMatrix& B) 114 | { 115 | multiply_general(o, A, w, B); 116 | } 117 | 118 | template 119 | MultiplyOutput multiply(const SparseMatrix& A, const Tx2* w, const SparseMatrix& B) 120 | { 121 | MultiplyOutput o; 122 | multiply(o, A, w, B); 123 | return o; 124 | } 125 | 126 | template 127 | MultiplyOutput multiply(const SparseMatrix& A, const SparseMatrix& B) 128 | { 129 | MultiplyOutput o; 130 | multiply(o, A, B); 131 | return o; 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /code/solver/PackedCSparse/outerprod.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "SparseMatrix.h" 3 | 4 | // Problem: 5 | // Compute x = diag(At S Bt) 6 | 7 | // Algorithm: 8 | // Note that x = diag(B St A) = grad_H 
Tr(St A H B) 9 | // We run autodiff on the function Tr(St A H B). 10 | // Hence, the algorithm is essentially same as multiply(A, B) with the same runtime. 11 | 12 | namespace PackedCSparse { 13 | template 14 | struct OuterprodOutput : DenseVector 15 | { 16 | UniqueAlignedPtr s_col; 17 | UniquePtr s_mark; 18 | 19 | template 20 | void initialize(const SparseMatrix& A, const SparseMatrix& S, const SparseMatrix& B) 21 | { 22 | pcs_assert(A.initialized() && B.initialized() && S.initialized(), "outerprod: bad inputs."); 23 | pcs_assert(A.m == S.m && S.n == B.n, "outerprod: dimensions mismatch."); 24 | 25 | DenseVector::initialize(A.n); 26 | s_col.reset(pcs_aligned_new(S.m)); 27 | s_mark.reset(new Ti[S.m]); 28 | } 29 | }; 30 | 31 | template 32 | void outerprod(OuterprodOutput& o, const SparseMatrix& A, const SparseMatrix& S, const SparseMatrix& B) 33 | { 34 | if (!o.initialized()) 35 | o.initialize(A, S, B); 36 | 37 | Ti Sn = S.n, Sm = S.m, An = A.n; 38 | Ti* Ap = A.p.get(), * Ai = A.i.get(); Tx2* Ax = A.x.get(); 39 | Ti* Bp = B.p.get(), * Bi = B.i.get(); Tx2* Bx = B.x.get(); 40 | Ti* Sp = S.p.get(), * Si = S.i.get(); Tx* Sx = S.x.get(); 41 | Tx* s_col = o.s_col.get(); 42 | Ti* s_mark = o.s_mark.get(); 43 | Tx* x = o.x.get(); 44 | 45 | std::fill(s_mark, s_mark + Sm, Ti(-1)); 46 | std::fill(x, x + An, Tx(0.0)); 47 | 48 | for (Ti j = 0; j < Sn; j++) 49 | { 50 | for (Ti p = Sp[j]; p < Sp[j + 1]; p++) 51 | { 52 | s_col[Si[p]] = Sx[p]; 53 | s_mark[Si[p]] = j; 54 | } 55 | 56 | for (Ti p = Bp[j]; p < Bp[j + 1]; p++) 57 | { 58 | Ti i = Bi[p]; Tx b = Bx[p]; 59 | for (Ti q = Ap[i]; q < Ap[i + 1]; q++) 60 | { 61 | Tx a = Ax[q]; Ti a_i = Ai[q]; 62 | if (s_mark[a_i] == j) 63 | { //x[i] += s_col[a_i] * a * b; 64 | fmadd(x[i], s_col[a_i], a * b); 65 | } 66 | } 67 | } 68 | } 69 | } 70 | 71 | template 72 | OuterprodOutput outerprod(const SparseMatrix& A, const SparseMatrix& S, const SparseMatrix& B) 73 | { 74 | OuterprodOutput o; 75 | outerprod(o, A, S, B); 76 | return o; 77 | } 78 | } 
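The quantity computed by outerprod above, x = diag(A' S B'), can be cross-checked against a plain dense loop. The following is a minimal sketch under stated assumptions, not code from PackedCSparse: the names `Dense` and `diag_AtSBt_dense` are hypothetical, matrices are stored as row-major nested vectors, and S is treated as fully dense (entries outside its sparsity pattern are simply zero).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper names for illustration only; not part of PackedCSparse.
using Dense = std::vector<std::vector<double>>; // row-major: M[row][col]

// Dense reference for x = diag(A' S B'):
//   x[i] = sum_k sum_j A[k][i] * S[k][j] * B[i][j]
// Assumed shapes: A is m x n, S is m x r, B is n x r
// (outerprod asserts A.m == S.m and S.n == B.n).
std::vector<double> diag_AtSBt_dense(const Dense& A, const Dense& S, const Dense& B)
{
    const std::size_t m = A.size();
    const std::size_t n = m ? A[0].size() : 0;
    const std::size_t r = m ? S[0].size() : 0;

    std::vector<double> x(n, 0.0);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < m; ++k)
            for (std::size_t j = 0; j < r; ++j)
                x[i] += A[k][i] * S[k][j] * B[i][j];
    return x;
}
```

With B set to the transpose of A, this reduces to diag(A' S A), which is the form leverage.h requests when it calls outerprod(o.tau, A, o.Hinv, At). The dense loop costs O(n m r) and is only meant for validating small cases.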
-------------------------------------------------------------------------------- /code/solver/PackedCSparse/projinv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "SparseMatrix.h" 3 | 4 | // Problem: 5 | // Compute inv(L L') restricted to the sparsity pattern of L 6 | 7 | // Algorithm: 8 | // We need to study this later as this is the bottleneck. 9 | // Document it as a lyx. 10 | 11 | namespace PackedCSparse { 12 | template 13 | struct ProjinvOutput : SparseMatrix 14 | { 15 | TransposeOutput Lt; // sparsity pattern of the Lt 16 | UniqueAlignedPtr w; // the row of L we are computing 17 | UniquePtr c; // c[i] = index the last nonzero on column i in the current L 18 | 19 | void initialize(const SparseMatrix& L) 20 | { 21 | pcs_assert(L.initialized(), "projinv: bad inputs."); 22 | pcs_assert(L.n == L.m, "projinv: dimensions mismatch."); 23 | 24 | // Copy the sparsity of L 25 | SparseMatrix::operator=(std::move(L.clone())); 26 | 27 | // allocate workspaces 28 | Ti n = L.n; 29 | w.reset(pcs_aligned_new(n)); 30 | c.reset(new Ti[n]); 31 | Lt = transpose(L); 32 | } 33 | }; 34 | 35 | template 36 | void projinv(ProjinvOutput& o, const SparseMatrix& L) 37 | { 38 | if (!o.initialized()) 39 | o.initialize(L); 40 | 41 | Tx* Sx = o.x.get(); Ti n = o.n; 42 | Ti* Li = L.i.get(), * Lp = L.p.get(); Tx* Lv = L.x.get(); 43 | Ti* Lti = o.Lt.i.get(), * Ltp = o.Lt.p.get(); 44 | Tx* w = o.w.get(); 45 | Ti* c = o.c.get(); 46 | Tx T0 = Tx(0), T1 = Tx(1); 47 | 48 | for (Ti k = 0; k < n; k++) 49 | c[k] = Lp[k + 1] - 1; 50 | 51 | for (Ti k = n - 1; k != -1; k--) 52 | { 53 | for (Ti p = Lp[k] + 1; p < Lp[k + 1]; p++) 54 | w[Li[p]] = Sx[p]; 55 | 56 | Tx sum = T1 / Lv[Lp[k]]; 57 | for (Ti p = Ltp[k + 1] - 1; p != Ltp[k] - 1; p--) 58 | { 59 | Ti i = Lti[p], Lpi = Lp[i]; 60 | 61 | for (Ti q = Lp[i + 1] - 1; q != Lpi; q--) 62 | fnmadd(sum, Lv[q], w[Li[q]]); 63 | //sum -= Lv[q] * w[Li[q]]; 64 | 65 | sum = sum / Lv[Lpi]; 66 | w[i] = sum; 67 | Sx[c[i]] = sum; 68 | 
c[i]--; 69 | sum = T0; 70 | } 71 | } 72 | } 73 | 74 | template 75 | ProjinvOutput projinv(const SparseMatrix& L) 76 | { 77 | ProjinvOutput o; 78 | projinv(o, L); 79 | return o; 80 | } 81 | } -------------------------------------------------------------------------------- /code/solver/PackedCSparse/transpose.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "SparseMatrix.h" 3 | 4 | // Problem: 5 | // Compute M = A' 6 | 7 | // Algorithm: 8 | // We precompute the mapping from entries of A to entries of At 9 | 10 | namespace PackedCSparse { 11 | template 12 | struct TransposeOutput : SparseMatrix 13 | { 14 | UniquePtr forward; 15 | 16 | template 17 | void initialize(const SparseMatrix& A) 18 | { 19 | pcs_assert(A.initialized(), "transpose: bad inputs."); 20 | SparseMatrix::initialize(A.n, A.m, A.nnz()); 21 | 22 | Ti Am = A.m, An = A.n, * Ap = A.p.get(), * Ai = A.i.get(); 23 | Ti Bm = this->m, Bn = this->n, * Bp = this->p.get(), * Bi = this->i.get(); 24 | Ti nz = A.nnz(); 25 | 26 | // compute row counts of A 27 | Ti* count = new Ti[Bn + 1](); 28 | 29 | for (Ti p = 0; p < nz; p++) 30 | count[Ai[p]]++; 31 | 32 | // compute this->p 33 | Bp[0] = 0; 34 | for (Ti i = 0; i < Bn; i++) 35 | { 36 | Bp[i + 1] = Bp[i] + count[i]; 37 | count[i] = Bp[i]; // Now, cnt[i] stores the index of the first element in the i-th row 38 | } 39 | 40 | // compute i and forward 41 | if (!std::is_same::value) 42 | forward.reset(new Ti[nz]); 43 | for (Ti j = 0; j < An; j++) 44 | { 45 | for (Ti p = Ap[j]; p < Ap[j + 1]; p++) 46 | { 47 | Ti q = count[Ai[p]]; 48 | Bi[q] = j; 49 | if (!std::is_same::value) 50 | forward[p] = q; 51 | count[Ai[p]]++; 52 | } 53 | } 54 | 55 | delete[] count; 56 | } 57 | }; 58 | 59 | template 60 | void transpose(TransposeOutput& o, const SparseMatrix& A) 61 | { 62 | if (!o.initialized()) 63 | o.initialize(A); 64 | 65 | Tx* Ax = A.x.get(); Tx2 *Bx = o.x.get(); 66 | Ti nz = o.nnz(), *forward = o.forward.get(); 67 | 
68 | if (!std::is_same::value) 69 | { 70 | for (Ti s = 0; s < nz; s++) 71 | Bx[forward[s]] = Tx2(Ax[s]); 72 | } 73 | } 74 | 75 | template 76 | TransposeOutput transpose(const SparseMatrix& A) 77 | { 78 | TransposeOutput o; 79 | transpose(o, A); 80 | return o; 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /code/solver/PackedChol.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "mex_utils.h" 3 | #include "PackedCSparse/PackedChol.h" 4 | 5 | // SIMD_LEN = 0 means 1 and the vectors are inputted not as a matrix 6 | #ifndef SIMD_LEN 7 | #define SIMD_LEN 0 8 | #endif 9 | 10 | namespace env = MexEnvironment; 11 | typedef env::mexIdx Index; 12 | const size_t chol_k = (SIMD_LEN == 0) ? 1 : SIMD_LEN; 13 | using CholObj = PackedChol; 14 | using Matrix = SparseMatrix; 15 | using Tx2 = FloatArray; 16 | 17 | int main() 18 | { 19 | size_t simd_len = SIMD_LEN; 20 | const char* cmd = env::inputString(); 21 | uint64_t uid = env::inputScalar(); 22 | if (!strcmp(cmd, "init")) 23 | { 24 | Matrix A = std::move(env::inputSparseArray()); 25 | auto solver = new CholObj(A); 26 | solver->setSeed(uid); 27 | env::outputScalar((uint64_t)solver); 28 | } 29 | else 30 | { 31 | CholObj* solver = (CholObj*)uid; 32 | size_t n = solver->A.n; size_t m = solver->A.m; 33 | if (!strcmp(cmd, "solve")) 34 | { 35 | const double* b; double * out; 36 | if (SIMD_LEN == 0) 37 | { 38 | b = env::inputArray(m); 39 | out = env::outputArray(m); 40 | } 41 | else 42 | { 43 | b = env::inputArray(simd_len, m); 44 | out = env::outputArray(simd_len, m); 45 | } 46 | solver->solve((Tx2*)b, (Tx2*)out); 47 | } 48 | else if (!strcmp(cmd, "decompose")) 49 | { 50 | const double* w; 51 | if (SIMD_LEN == 0) 52 | w = env::inputArray(n); 53 | else 54 | w = env::inputArray(simd_len, n); 55 | Tx2 ret = solver->decompose((Tx2*)w); 56 | 57 | if (SIMD_LEN == 0) 58 | env::outputScalar(get(ret, 0)); 59 | else 60 | { 61 | 
double* out = env::outputArray(simd_len, 1); 62 | for (size_t i = 0; i < SIMD_LEN; i++) 63 | out[i] = get(ret, i); 64 | } 65 | } 66 | else if (!strcmp(cmd, "leverageScoreComplement")) 67 | { 68 | size_t k = (size_t)env::inputScalar(); 69 | double* out; 70 | if (SIMD_LEN == 0) 71 | out = env::outputArray(n); 72 | else 73 | out = env::outputArray(simd_len, n); 74 | 75 | if (k == 0) 76 | solver->leverageScoreComplement((Tx2*)out); 77 | else 78 | solver->leverageScoreComplementJL((Tx2*)out, k); 79 | } 80 | else if (!strcmp(cmd, "logdet")) 81 | { 82 | Tx2 ret = solver->logdet(); 83 | if (SIMD_LEN == 0) 84 | env::outputScalar(get(ret, 0)); 85 | else 86 | { 87 | double* out = env::outputArray(simd_len, 1); 88 | for (size_t i = 0; i < SIMD_LEN; i++) 89 | out[i] = get(ret, i); 90 | } 91 | } 92 | else if (!strcmp(cmd, "diagL")) 93 | { 94 | double* out; 95 | if (SIMD_LEN == 0) 96 | out = env::outputArray(m); 97 | else 98 | out = env::outputArray(simd_len, m); 99 | solver->diagL((Tx2*)out); 100 | } 101 | else if (!strcmp(cmd, "L")) 102 | { 103 | size_t k = (size_t)env::inputScalar(0.0); 104 | auto L = solver->getL(k); 105 | env::outputSparseArray(L); 106 | } 107 | else if (!strcmp(cmd, "setAccuracyTarget")) 108 | { 109 | solver->accuracyThreshold = env::inputScalar(); 110 | } 111 | else if (!strcmp(cmd, "getDecomposeCount")) 112 | { 113 | env::outputDoubleArray(solver->numExact.data(), chol_k + 1); 114 | } 115 | else if (!strcmp(cmd, "delete")) 116 | { 117 | delete solver; 118 | } 119 | else 120 | throw "Invalid operation."; 121 | } 122 | mxFree((void*)cmd); 123 | } -------------------------------------------------------------------------------- /code/solver/RHMC_compile_all.m: -------------------------------------------------------------------------------- 1 | function RHMC_compile_all(internal_use) 2 | [path, ~, ~] = fileparts(mfilename('fullpath')); 3 | outputfile = fullfile(path, '..', '..', 'bin', 'cpuInfo'); 4 | mex_string = ['mex ' fullfile(path, 'FeatureDetector', 
'cpuInfo.cpp') ' -output ' outputfile]; 5 | disp(mex_string); 6 | eval(mex_string); 7 | 8 | %% 9 | if (nargin == 1 && internal_use) 10 | arch = 'AVX'; 11 | compile_solver(0, true, arch); 12 | compile_solver(1, true, arch); 13 | compile_solver(4, true, arch); 14 | arch = 'SSE'; 15 | compile_solver(0, true, arch); 16 | compile_solver(1, true, arch); 17 | compile_solver(4, true, arch); 18 | else 19 | arch = 'native'; 20 | compile_solver(0, true, arch); 21 | compile_solver(1, true, arch); 22 | compile_solver(4, true, arch); 23 | end 24 | -------------------------------------------------------------------------------- /code/solver/Solver.m: -------------------------------------------------------------------------------- 1 | function o = Solver(A, precision, k) 2 | if nargin < 2, precision = 'double'; end 3 | if nargin < 3, k = 0; end 4 | if (strcmp(precision,'double')) 5 | precision = 1e-6; 6 | elseif (strcmp(precision,'doubledouble')) 7 | precision = 0.0; 8 | elseif ~isfloat(precision) 9 | error('Unsupported precision mode'); 10 | end 11 | 12 | o = MexSolver(A, precision, k); 13 | %if nargin < 3 14 | % o = MatlabSolver(A, precision); 15 | %else 16 | % o = MultiMatlabSolver(A, precision, k); 17 | %end 18 | end -------------------------------------------------------------------------------- /code/solver/batch_pcg.m: -------------------------------------------------------------------------------- 1 | % Solve the equation o.AwAt(x) = b, starting from the initial point x, 2 | % using the approximate solver o.approxSolve as the preconditioner 3 | % 4 | % flag = 0 means success 5 | % flag = 1 means failure (breakdown in the iteration) 6 | % flag = 2 means it did not converge within max_iter iterations 7 | function [x, iter, flag] = batch_pcg(x, b, o, tol, max_iter) 8 | if (o.k == 0) 9 | vdim = 1; 10 | else 11 | vdim = 2; 12 | end 13 | 14 | flag = 0; % success 15 | r = b - o.AwAt(x); 16 | rz = 1; 17 | b_norm = sqrt(sum(b.^2,vdim)); 18 | r_norm = sqrt(sum(r.^2,vdim)); 19 | 20 | if all(r_norm <= tol * b_norm) 21 | iter = 0; 22 | return 23 | end 24 | 25 | for iter = 1:max_iter 
26 | z = o.approxSolve(r); 27 | 28 | rz_prev = rz; 29 | rz = sum(r.*z,vdim); 30 | 31 | if any(rz < -eps) || any(~isfinite(rz)) 32 | flag = 1; % stagnation 33 | break; 34 | end 35 | 36 | if iter == 1 37 | p = z; 38 | else 39 | beta = rz ./ rz_prev; 40 | p = z + beta .* p; 41 | end 42 | 43 | Ap = o.AwAt(p); 44 | pAp = sum(p.*Ap,vdim); 45 | if any(pAp <= 0) || any(~isfinite(pAp)) 46 | flag = 1; % stagnation 47 | break; 48 | end 49 | alpha = rz./sum(p.*Ap,vdim); 50 | 51 | x = x + alpha .* p; 52 | r = r - alpha .* Ap; 53 | 54 | r_norm = sqrt(sum(r.^2,vdim)); 55 | 56 | if all(r_norm <= tol * b_norm) 57 | return; 58 | end 59 | end 60 | 61 | if (flag ~= 1), flag = 2; end 62 | x = o.approxSolve(b); 63 | end -------------------------------------------------------------------------------- /code/solver/compile_solver.m: -------------------------------------------------------------------------------- 1 | function compile_solver(simd_len, recompile, arch) 2 | if nargin < 1, simd_len = 0; end 3 | if nargin < 2, recompile = false; end 4 | if nargin < 3, arch = 'native'; end 5 | 6 | % check if the solver exists 7 | func_name = MexSolver.solverName(simd_len); 8 | if (~recompile && exist(func_name) == 3) 9 | try 10 | pchol = str2func(func_name); 11 | uid = pchol('init', uint64(123), speye(3)); 12 | pchol('delete', uid); 13 | recompile = false; 14 | catch 15 | recompile = true; 16 | end 17 | else 18 | recompile = true; 19 | end 20 | 21 | if (recompile) 22 | clear mex 23 | [path,~,~] = fileparts(mfilename('fullpath')); 24 | libpath = fullfile(path); 25 | qdpath = fullfile(path, 'qd'); 26 | outputfile = fullfile(path, '..', '..', 'bin', func_name); 27 | 28 | if (isempty(mex.getCompilerConfigurations('C++','Selected'))) 29 | error('No C++ mex compiler is available.'); 30 | end 31 | 32 | compiler = mex.getCompilerConfigurations('C++','Selected').ShortName; 33 | mex_string = ['mex -R2018a -silent -DSIMD_LEN=' num2str(simd_len) ' -O -I"%libpath" ']; 34 | 35 | if (contains(func_name, 'arm')) 
36 | archflag = ''; 37 | nameflag = ''; 38 | mex_string2 = ['CFLAGS="$CFLAGS ' archflag '"']; 39 | elseif (contains(compiler, 'MSVCPP')) 40 | switch arch 41 | case 'native' 42 | s = cpuInfo(); 43 | if (s.AVX512F && s.OS_AVX512) 44 | archflag = '/arch:AVX512'; 45 | elseif (s.AVX2 && s.OS_AVX) 46 | archflag = '/arch:AVX2'; 47 | else 48 | archflag = ''; 49 | end 50 | nameflag = 'native'; 51 | case 'AVX' 52 | archflag = '/arch:AVX2'; 53 | nameflag = ''; 54 | case 'SSE' 55 | archflag = ''; 56 | nameflag = 'SSE'; 57 | end 58 | mex_string2 = ['COMPFLAGS="$COMPFLAGS /O2 ' archflag '"']; 59 | elseif (contains(compiler, 'Clang++')) 60 | switch arch 61 | case 'native' 62 | archflag = '-march=native'; 63 | nameflag = 'native'; 64 | case 'AVX' 65 | archflag = '-mavx2 -mfma'; 66 | nameflag = ''; 67 | case 'SSE' 68 | archflag = ''; 69 | nameflag = 'SSE'; 70 | end 71 | mex_string2 = ['CFLAGS="$CFLAGS ' archflag '"']; 72 | elseif (contains(compiler, 'g++')) 73 | switch arch 74 | case 'native' 75 | archflag = '-march=native'; 76 | nameflag = 'native'; 77 | case 'AVX' 78 | archflag = '-mavx2 -mfma'; 79 | nameflag = ''; 80 | case 'SSE' 81 | archflag = ''; 82 | nameflag = 'SSE'; 83 | end 84 | mex_string2 = ['CFLAGS="$CFLAGS ' archflag '"']; 85 | else 86 | error('Currently, we only support MSVCPP, Clang++ or g++ as the compiler.'); 87 | end 88 | outputfile = [outputfile nameflag]; 89 | mex_string = [mex_string mex_string2]; 90 | 91 | c = {}; 92 | c{end+1} = '%mex_string -output "%outputfile" "%libpath/PackedChol.cpp" "%qdpath/util.cc" "%qdpath/bits.cc" "%qdpath/dd_real.cc" "%qdpath/dd_const.cc" "%qdpath/qd_real.cc" "%qdpath/qd_const.cc"'; 93 | 94 | %% replace keywords 95 | keywords = {'%mex_string', '%qdpath', '%libpath', '%outputfile'}; 96 | replaces = {mex_string, qdpath, libpath, outputfile}; 97 | replaces = replace(replaces, keywords, replaces); 98 | 99 | for i = 1:length(c) 100 | c{i} = replace(c{i}, keywords, replaces); 101 | end 102 | 103 | %% Run the commands 104 | for i = 
1:length(c) 105 | disp(c{i}); 106 | eval(c{i}); 107 | end 108 | end -------------------------------------------------------------------------------- /code/solver/qd/COPYING: -------------------------------------------------------------------------------- 1 | This work was supported by the Director, Office of Science, Division 2 | of Mathematical, Information, and Computational Sciences of the 3 | U.S. Department of Energy under contract numbers DE-AC03-76SF00098 and 4 | DE-AC02-05CH11231. 5 | 6 | Copyright (c) 2003-2009, The Regents of the University of California, 7 | through Lawrence Berkeley National Laboratory (subject to receipt of 8 | any required approvals from U.S. Dept. of Energy) All rights reserved. 9 | 10 | By downloading or using this software you are agreeing to the modified 11 | BSD license that is in file "BSD-LBNL-License.doc" in the main ARPREC 12 | directory. If you wish to use the software for commercial purposes 13 | please contact the Technology Transfer Department at TTD@lbl.gov or 14 | call 510-286-6457." 15 | 16 | 17 | -------------------------------------------------------------------------------- /code/solver/qd/NEWS: -------------------------------------------------------------------------------- 1 | Changes for 2.3.22 2 | Made changes suggested by Vasiliy Sotnikov 3 | 4 | Changes for 2.3.21 5 | Changed renorm in include/qd/qd_inline.h 6 | 7 | Changes for 2.3.20 8 | added #include to quadt_test.cpp 9 | changed references to 2.3.20 from 2.3.18 10 | 11 | Changes for 2.3.19 12 | - Updated qd_real.cpp and dd_real.cpp to fix a buffer overflow problem. 13 | 14 | Changes for 2.3.18 15 | - Updated qd_real.cpp and dd_real.cpp to fix a problem in output. 16 | 17 | Changes for 2.3.17 18 | - updated qd_real.cpp, to fix a problem with improper treatment of 19 | negative arguments in nroot. 20 | 21 | Changes for 2.3.16 22 | - Updated dd_real.cpp, to fix a problem with inaccurate values of 23 | tanh for small arguments. 
24 | 25 | Changes for 2.3.15 26 | - Updated qd_real.cpp, to fix a problem with static definitions. 27 | 28 | Changes for 2.3.14 29 | - Updated autoconfig (replaced config.sub and config.guess) 30 | 31 | Changes for 2.3.7 32 | - Fixed bug in to_digits where digits larger than 10 33 | were output occasionally. 34 | 35 | Changes for 2.3.6 36 | - Added fmod (C++) and mod (Fortran) functions. 37 | 38 | Changes for 2.3.5 39 | - Fixed bug in division of qd_real by dd_real. 40 | - Fixed bug in ddoutc (Fortran ddmod.f). 41 | - Now compiles with g++ 4.3. 42 | - Distribute tests/coeff.dat. 43 | 44 | Changes for 2.3.4 45 | - Fixed bug in Makefile for cygwin / mingw systems. 46 | 47 | Changes for 2.3.3 48 | - Fixed bug in atan2. 49 | 50 | Changes for 2.3.2 51 | - Fixed bug in sin / cos / sincos where too much accuracy was 52 | lost for (moderately) large angles. 53 | - Use fused-multiply add intrinsics on IA-64 platforms if 54 | compiled by Intel compiler. 55 | - Fixed bug in c_dd_write and c_qd_write. 56 | - Fixed bug where qdext.mod was not being installed. 57 | 58 | Changes for 2.3.1 59 | - Fixed bug in sincos and cos_taylor. This affected the result 60 | of trigonometric functions in some cases. 61 | 62 | Changes for 2.3.0 63 | This is a fairly significant change, breaking API compatibility. 64 | - Moved C++ main entry in libqdmod.a to libqd_f_main.a. 65 | This allows linking Fortran code using QD with a custom 66 | C++ main function. Pure Fortran code will need to be linked 67 | with qd_f_main library in addition to qdmod and qd library. 68 | - Constructors accepting pointers made explicit. 69 | - Fortran routines labeled as elemental or pure, where appropriate. 70 | - Write() is now to_string(), and now takes a single fmtflag. 71 | - dd_real addition and multiplication made commutative. 72 | - dd_real now represented as array of two doubles, instead of 73 | two discrete scalars. 74 | - New Fortran generic routines to read / write, operations with 75 | complex and integers.
76 | - Improved exp, sin, and cos functions. 77 | - Removed unused constants and obscure constants only used internally 78 | from public interface. 79 | 80 | Changes for 2.2.6 81 | - Fixed bug in mixed precision multiplication: qd_real * dd_real. 82 | 83 | Changes for 2.2.5 84 | - Bug fix in qd_real addition when --enable-ieee-add is specified. 85 | - Debugging routines dump and dump_bits updated; 86 | dump_components removed (just use dump). 87 | - Fortran support for Fortran strings. Use character arrays instead. 88 | - Return NaN under error conditions. 89 | - Added _inf constant; exp now returns Inf when argument is too large. 90 | - Output formatting fixes for Inf and NaNs. 91 | - Added more real-complex mixed arithmetic routines in Fortran 92 | interface. 93 | 94 | Changes for 2.2.4 95 | - Added random_number interface for Fortran modules. 96 | - Use slightly more conservative values for eps. 97 | - Avoid unnecessary overflow near overflow threshold. 98 | - Added radix, digits, min/maxexponent, range, and precision 99 | intrinsics to Fortran interface. 100 | - Added safe_max (C++) and safe_huge (Fortran). 101 | 102 | Changes for 2.2.3 103 | - Fix sign function bug in Fortran modules. 104 | 105 | Changes for 2.2.2 106 | - Do not bother setting uninitialized dd_real and qd_reals to zero. 107 | - Use clock_gettime if available for timing. 108 | - Fortran I/O should be more consistent with C++ version. 109 | - fpu.h is now included with dd_real.h. 110 | 111 | Changes for 2.2.1 112 | - Minor fixes when printing in scientific format. 113 | - Change search order of C++ compilers in Apple systems to avoid 114 | case insensitive filesystems. 115 | 116 | Changes for 2.2.0 117 | - Added F95 interface for complex types. 118 | - Renamed dd.h and qd.h to dd_real.h and qd_real.h, respectively. 119 | This will break older C++ code using 2.1.x library, but it was 120 | conflicting with QuickDraw libraries on Macs. (Hence the version 121 | bump to 2.2). 
122 | - Removed overloaded typecast operators for int and double. These 123 | permitted *automatic* conversion of dd_real/qd_real to double or 124 | int, which is somewhat dangerous. Instead to_int and to_double 125 | routines are added. 126 | 127 | Changes for 2.1.214 128 | - Updated pslq_test. 129 | - Implemented numeric_limits<>. 130 | - Better polyroot. 131 | - Added isnan, isfinite, isinf functions. 132 | - Fix / improve input output functions. 133 | - Drop Microsoft Visual C++ 6.0 support. 134 | - More efficient dd_real::sin. 135 | 136 | Changes for 2.1.213 137 | - Support for x86_64 platforms. 138 | - Drop libtool support for now. 139 | 140 | Changes for 2.1.212 141 | - Support for pathCC compiler. 142 | - Added accurate and sloppy versions of add / sub / mul / div available. 143 | - Added autodetection of fma functions. 144 | 145 | Changes for 2.1 (2003-12-30) 146 | - added automake scripts. 147 | - use libtool to compile / link and build libraries. 148 | - supports standard installation targets (make install). 149 | - support for Intel C++ compilers (icc / ecc). 150 | - Fortran programs are now linked by C++ compiler. 151 | - support for building shared library. 152 | - minor bug fixes. 153 | 154 | Changes for 2.0 (2003-12-08) 155 | - all header files are in "include/qd" directory. 156 | - added autoconf scripts. 157 | - added config.h and qd_config.h to store configuration information. 158 | - renamed x86_* routines to fpu_* routines. 159 | - added separate Fortran interface (f_* routines). 160 | - options for sloppy multiply and sloppy division separated. 161 | - fixed C interface to be actually in C syntax. 162 | - updated / added README, AUTHORS, NEWS, and LICENSE files. 163 | - minor bug fixes.
164 | 165 | Changes for 1.2 (2003-12-04) 166 | - added "dist-clean" target in Makefile 167 | - initialize dd and qd variables to zero 168 | - increases tolerance for qd / dd tests 169 | - changed .cc extension to .cpp 170 | - updated README, COPYING, and NEWS files 171 | - added ChangeLog file 172 | - fixed bug in '-all' flag in qd_test 173 | - minor bug fixes 174 | 175 | Changes for 1.1 (2002-10-22) 176 | - added "Changes" file (this file) 177 | - fixed to 178 | - fixed constant (3/4) * pi 179 | - fixed exp(x) to return zero if x is a large negative number 180 | - removed "docs" target in Makefile 181 | 182 | -------------------------------------------------------------------------------- /code/solver/qd/bits.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * src/bits.cc 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * Defines various routines to get / set bits of an IEEE floating point 11 | * number. This is used by the library for debugging purposes.
12 | */ 13 | 14 | #include <iostream> 15 | #include <iomanip> 16 | #include <cmath> 17 | #include <climits> 18 | 19 | #include "qd_config.h" 20 | #include "inline.h" 21 | #include "bits.h" 22 | 23 | #ifdef HAVE_IEEEFP_H 24 | #include <ieeefp.h> 25 | #endif 26 | 27 | using std::setw; 28 | 29 | int get_double_expn(double x) { 30 | if (x == 0.0) 31 | return INT_MIN; 32 | if (QD_ISINF(x) || QD_ISNAN(x)) 33 | return INT_MAX; 34 | 35 | double y = std::abs(x); 36 | int i = 0; 37 | if (y < 1.0) { 38 | while (y < 1.0) { 39 | y *= 2.0; 40 | i++; 41 | } 42 | return -i; 43 | } else if (y >= 2.0) { 44 | while (y >= 2.0) { 45 | y *= 0.5; 46 | i++; 47 | } 48 | return i; 49 | } 50 | return 0; 51 | } 52 | 53 | void print_double_info(std::ostream &os, double x) { 54 | std::streamsize old_prec = os.precision(19); 55 | std::ios_base::fmtflags old_flags = os.flags(); 56 | os << std::scientific; 57 | 58 | os << setw(27) << x << ' '; 59 | if (QD_ISNAN(x) || QD_ISINF(x) || (x == 0.0)) { 60 | os << " "; 61 | } else { 62 | 63 | x = std::abs(x); 64 | int expn = get_double_expn(x); 65 | double d = std::ldexp(1.0, expn); 66 | os << setw(5) << expn << " "; 67 | for (int i = 0; i < 53; i++) { 68 | if (x >= d) { 69 | x -= d; 70 | os << '1'; 71 | } else 72 | os << '0'; 73 | d *= 0.5; 74 | } 75 | 76 | if (x != 0.0) { 77 | // should not happen 78 | os << " +trailing stuff"; 79 | } 80 | } 81 | 82 | os.precision(old_prec); 83 | os.flags(old_flags); 84 | } 85 | 86 | -------------------------------------------------------------------------------- /code/solver/qd/bits.h: -------------------------------------------------------------------------------- 1 | /* 2 | * include/bits.h 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * This file defines various routines to get / set bits of an IEEE floating 11 | * point number.
This is used by the library for debugging purposes. 12 | */ 13 | 14 | #ifndef _QD_BITS_H 15 | #define _QD_BITS_H 16 | 17 | #include <iostream> 18 | #include "qd_config.h" 19 | 20 | /* Returns the exponent of the double precision number. 21 | Returns INT_MIN if x is zero, and INT_MAX if x is INF or NaN. */ 22 | int get_double_expn(double x); 23 | 24 | /* Prints 25 | SIGN EXPN MANTISSA 26 | of the given double. If x is NaN, INF, or Zero, this 27 | prints out the strings NaN, +/- INF, and 0. */ 28 | void print_double_info(std::ostream &os, double x); 29 | 30 | 31 | #endif /* _QD_BITS_H */ 32 | 33 | -------------------------------------------------------------------------------- /code/solver/qd/c_dd.h: -------------------------------------------------------------------------------- 1 | /* 2 | * include/c_dd.h 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * Contains C wrapper function prototypes for double-double precision 11 | * arithmetic. This can also be used from Fortran code.
12 | */ 13 | #ifndef _QD_C_DD_H 14 | #define _QD_C_DD_H 15 | 16 | #include "qd_config.h" 17 | #include "fpu.h" 18 | 19 | #ifdef __cplusplus 20 | extern "C" { 21 | #endif 22 | 23 | /* add */ 24 | void c_dd_add(const double *a, const double *b, double *c); 25 | void c_dd_add_d_dd(double a, const double *b, double *c); 26 | void c_dd_add_dd_d(const double *a, double b, double *c); 27 | 28 | /* sub */ 29 | void c_dd_sub(const double *a, const double *b, double *c); 30 | void c_dd_sub_d_dd(double a, const double *b, double *c); 31 | void c_dd_sub_dd_d(const double *a, double b, double *c); 32 | 33 | /* mul */ 34 | void c_dd_mul(const double *a, const double *b, double *c); 35 | void c_dd_mul_d_dd(double a, const double *b, double *c); 36 | void c_dd_mul_dd_d(const double *a, double b, double *c); 37 | 38 | /* div */ 39 | void c_dd_div(const double *a, const double *b, double *c); 40 | void c_dd_div_d_dd(double a, const double *b, double *c); 41 | void c_dd_div_dd_d(const double *a, double b, double *c); 42 | 43 | /* copy */ 44 | void c_dd_copy(const double *a, double *b); 45 | void c_dd_copy_d(double a, double *b); 46 | 47 | void c_dd_sqrt(const double *a, double *b); 48 | void c_dd_sqr(const double *a, double *b); 49 | 50 | void c_dd_abs(const double *a, double *b); 51 | 52 | void c_dd_npwr(const double *a, int b, double *c); 53 | void c_dd_nroot(const double *a, int b, double *c); 54 | 55 | void c_dd_nint(const double *a, double *b); 56 | void c_dd_aint(const double *a, double *b); 57 | void c_dd_floor(const double *a, double *b); 58 | void c_dd_ceil(const double *a, double *b); 59 | 60 | void c_dd_exp(const double *a, double *b); 61 | void c_dd_log(const double *a, double *b); 62 | void c_dd_log10(const double *a, double *b); 63 | 64 | void c_dd_sin(const double *a, double *b); 65 | void c_dd_cos(const double *a, double *b); 66 | void c_dd_tan(const double *a, double *b); 67 | 68 | void c_dd_asin(const double *a, double *b); 69 | void c_dd_acos(const double *a, 
double *b); 70 | void c_dd_atan(const double *a, double *b); 71 | void c_dd_atan2(const double *a, const double *b, double *c); 72 | 73 | void c_dd_sinh(const double *a, double *b); 74 | void c_dd_cosh(const double *a, double *b); 75 | void c_dd_tanh(const double *a, double *b); 76 | 77 | void c_dd_asinh(const double *a, double *b); 78 | void c_dd_acosh(const double *a, double *b); 79 | void c_dd_atanh(const double *a, double *b); 80 | 81 | void c_dd_sincos(const double *a, double *s, double *c); 82 | void c_dd_sincosh(const double *a, double *s, double *c); 83 | 84 | void c_dd_read(const char *s, double *a); 85 | void c_dd_swrite(const double *a, int precision, char *s, int len); 86 | void c_dd_write(const double *a); 87 | void c_dd_neg(const double *a, double *b); 88 | void c_dd_rand(double *a); 89 | void c_dd_comp(const double *a, const double *b, int *result); 90 | void c_dd_comp_dd_d(const double *a, double b, int *result); 91 | void c_dd_comp_d_dd(double a, const double *b, int *result); 92 | void c_dd_pi(double *a); 93 | 94 | #ifdef __cplusplus 95 | } 96 | #endif 97 | 98 | #endif /* _QD_C_DD_H */ 99 | -------------------------------------------------------------------------------- /code/solver/qd/c_qd.h: -------------------------------------------------------------------------------- 1 | /* 2 | * include/c_qd.h 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * Contains C wrapper function prototypes for quad-double precision 11 | * arithmetic. This can also be used from fortran code. 
12 | */ 13 | #ifndef _QD_C_QD_H 14 | #define _QD_C_QD_H 15 | 16 | #include "c_dd.h" 17 | #include "qd_config.h" 18 | 19 | #ifdef __cplusplus 20 | extern "C" { 21 | #endif 22 | 23 | /* add */ 24 | void c_qd_add(const double *a, const double *b, double *c); 25 | void c_qd_add_dd_qd(const double *a, const double *b, double *c); 26 | void c_qd_add_qd_dd(const double *a, const double *b, double *c); 27 | void c_qd_add_d_qd(double a, const double *b, double *c); 28 | void c_qd_add_qd_d(const double *a, double b, double *c); 29 | void c_qd_selfadd(const double *a, double *b); 30 | void c_qd_selfadd_dd(const double *a, double *b); 31 | void c_qd_selfadd_d(double a, double *b); 32 | 33 | /* sub */ 34 | void c_qd_sub(const double *a, const double *b, double *c); 35 | void c_qd_sub_dd_qd(const double *a, const double *b, double *c); 36 | void c_qd_sub_qd_dd(const double *a, const double *b, double *c); 37 | void c_qd_sub_d_qd(double a, const double *b, double *c); 38 | void c_qd_sub_qd_d(const double *a, double b, double *c); 39 | void c_qd_selfsub(const double *a, double *b); 40 | void c_qd_selfsub_dd(const double *a, double *b); 41 | void c_qd_selfsub_d(double a, double *b); 42 | 43 | /* mul */ 44 | void c_qd_mul(const double *a, const double *b, double *c); 45 | void c_qd_mul_dd_qd(const double *a, const double *b, double *c); 46 | void c_qd_mul_qd_dd(const double *a, const double *b, double *c); 47 | void c_qd_mul_d_qd(double a, const double *b, double *c); 48 | void c_qd_mul_qd_d(const double *a, double b, double *c); 49 | void c_qd_selfmul(const double *a, double *b); 50 | void c_qd_selfmul_dd(const double *a, double *b); 51 | void c_qd_selfmul_d(double a, double *b); 52 | 53 | /* div */ 54 | void c_qd_div(const double *a, const double *b, double *c); 55 | void c_qd_div_dd_qd(const double *a, const double *b, double *c); 56 | void c_qd_div_qd_dd(const double *a, const double *b, double *c); 57 | void c_qd_div_d_qd(double a, const double *b, double *c); 58 | void 
c_qd_div_qd_d(const double *a, double b, double *c); 59 | void c_qd_selfdiv(const double *a, double *b); 60 | void c_qd_selfdiv_dd(const double *a, double *b); 61 | void c_qd_selfdiv_d(double a, double *b); 62 | 63 | /* copy */ 64 | void c_qd_copy(const double *a, double *b); 65 | void c_qd_copy_dd(const double *a, double *b); 66 | void c_qd_copy_d(double a, double *b); 67 | 68 | void c_qd_sqrt(const double *a, double *b); 69 | void c_qd_sqr(const double *a, double *b); 70 | 71 | void c_qd_abs(const double *a, double *b); 72 | 73 | void c_qd_npwr(const double *a, int b, double *c); 74 | void c_qd_nroot(const double *a, int b, double *c); 75 | 76 | void c_qd_nint(const double *a, double *b); 77 | void c_qd_aint(const double *a, double *b); 78 | void c_qd_floor(const double *a, double *b); 79 | void c_qd_ceil(const double *a, double *b); 80 | 81 | void c_qd_exp(const double *a, double *b); 82 | void c_qd_log(const double *a, double *b); 83 | void c_qd_log10(const double *a, double *b); 84 | 85 | void c_qd_sin(const double *a, double *b); 86 | void c_qd_cos(const double *a, double *b); 87 | void c_qd_tan(const double *a, double *b); 88 | 89 | void c_qd_asin(const double *a, double *b); 90 | void c_qd_acos(const double *a, double *b); 91 | void c_qd_atan(const double *a, double *b); 92 | void c_qd_atan2(const double *a, const double *b, double *c); 93 | 94 | void c_qd_sinh(const double *a, double *b); 95 | void c_qd_cosh(const double *a, double *b); 96 | void c_qd_tanh(const double *a, double *b); 97 | 98 | void c_qd_asinh(const double *a, double *b); 99 | void c_qd_acosh(const double *a, double *b); 100 | void c_qd_atanh(const double *a, double *b); 101 | 102 | void c_qd_sincos(const double *a, double *s, double *c); 103 | void c_qd_sincosh(const double *a, double *s, double *c); 104 | 105 | void c_qd_read(const char *s, double *a); 106 | void c_qd_swrite(const double *a, int precision, char *s, int len); 107 | void c_qd_write(const double *a); 108 | void 
c_qd_neg(const double *a, double *b); 109 | void c_qd_rand(double *a); 110 | void c_qd_comp(const double *a, const double *b, int *result); 111 | void c_qd_comp_qd_d(const double *a, double b, int *result); 112 | void c_qd_comp_d_qd(double a, const double *b, int *result); 113 | void c_qd_pi(double *a); 114 | 115 | #ifdef __cplusplus 116 | } 117 | #endif 118 | 119 | #endif /* _QD_C_QD_H */ 120 | -------------------------------------------------------------------------------- /code/solver/qd/dd_const.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * src/dd_const.cc 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2007 9 | */ 10 | #include "qd_config.h" 11 | #include "dd_real.h" 12 | 13 | const dd_real dd_real::_2pi = dd_real(6.283185307179586232e+00, 14 | 2.449293598294706414e-16); 15 | const dd_real dd_real::_pi = dd_real(3.141592653589793116e+00, 16 | 1.224646799147353207e-16); 17 | const dd_real dd_real::_pi2 = dd_real(1.570796326794896558e+00, 18 | 6.123233995736766036e-17); 19 | const dd_real dd_real::_pi4 = dd_real(7.853981633974482790e-01, 20 | 3.061616997868383018e-17); 21 | const dd_real dd_real::_3pi4 = dd_real(2.356194490192344837e+00, 22 | 9.1848509936051484375e-17); 23 | const dd_real dd_real::_e = dd_real(2.718281828459045091e+00, 24 | 1.445646891729250158e-16); 25 | const dd_real dd_real::_log2 = dd_real(6.931471805599452862e-01, 26 | 2.319046813846299558e-17); 27 | const dd_real dd_real::_log10 = dd_real(2.302585092994045901e+00, 28 | -2.170756223382249351e-16); 29 | const dd_real dd_real::_nan = dd_real(qd::_d_nan, qd::_d_nan); 30 | const dd_real dd_real::_inf = dd_real(qd::_d_inf, qd::_d_inf); 31 | 32 | const double dd_real::_eps = 4.93038065763132e-32; // 2^-104 33 | const double 
dd_real::_min_normalized = 2.0041683600089728e-292; // = 2^(-1022 + 53) 34 | const dd_real dd_real::_max = 35 | dd_real(1.79769313486231570815e+308, 9.97920154767359795037e+291); 36 | const dd_real dd_real::_safe_max = 37 | dd_real(1.7976931080746007281e+308, 9.97920154767359795037e+291); 38 | const int dd_real::_ndigits = 31; 39 | 40 | 41 | -------------------------------------------------------------------------------- /code/solver/qd/fpu.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * src/fpu.cc 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * Contains functions to set and restore the round-to-double flag in the 11 | * control word of an x86 FPU. 12 | */ 13 | 14 | #include "qd_config.h" 15 | #include "fpu.h" 16 | 17 | #ifdef X86 18 | #ifdef _WIN32 19 | #include <float.h> 20 | #else 21 | 22 | #ifdef HAVE_FPU_CONTROL_H 23 | #include <fpu_control.h> 24 | #endif 25 | 26 | #ifndef _FPU_GETCW 27 | #define _FPU_GETCW(x) asm volatile ("fnstcw %0":"=m" (x)); 28 | #endif 29 | 30 | #ifndef _FPU_SETCW 31 | #define _FPU_SETCW(x) asm volatile ("fldcw %0": :"m" (x)); 32 | #endif 33 | 34 | #ifndef _FPU_EXTENDED 35 | #define _FPU_EXTENDED 0x0300 36 | #endif 37 | 38 | #ifndef _FPU_DOUBLE 39 | #define _FPU_DOUBLE 0x0200 40 | #endif 41 | 42 | #endif 43 | #endif /* X86 */ 44 | 45 | extern "C" { 46 | 47 | void fpu_fix_start(unsigned int *old_cw) { 48 | #ifdef X86 49 | #ifdef _WIN32 50 | #ifdef __BORLANDC__ 51 | /* Win 32 Borland C */ 52 | unsigned short cw = _control87(0, 0); 53 | _control87(0x0200, 0x0300); 54 | if (old_cw) { 55 | *old_cw = cw; 56 | } 57 | #else 58 | /* Win 32 MSVC */ 59 | unsigned int cw = _control87(0, 0); 60 | _control87(0x00010000, 0x00030000); 61 | if (old_cw) { 62 | *old_cw = cw; 63 | } 64 | #endif 65 | #else 66 | /* 
Linux */ 67 | volatile unsigned short cw, new_cw; 68 | _FPU_GETCW(cw); 69 | 70 | new_cw = (cw & ~_FPU_EXTENDED) | _FPU_DOUBLE; 71 | _FPU_SETCW(new_cw); 72 | 73 | if (old_cw) { 74 | *old_cw = cw; 75 | } 76 | #endif 77 | #endif 78 | } 79 | 80 | void fpu_fix_end(unsigned int *old_cw) { 81 | #ifdef X86 82 | #ifdef _WIN32 83 | 84 | #ifdef __BORLANDC__ 85 | /* Win 32 Borland C */ 86 | if (old_cw) { 87 | unsigned short cw = (unsigned short) *old_cw; 88 | _control87(cw, 0xFFFF); 89 | } 90 | #else 91 | /* Win 32 MSVC */ 92 | if (old_cw) { 93 | _control87(*old_cw, 0xFFFFFFFF); 94 | } 95 | #endif 96 | 97 | #else 98 | /* Linux */ 99 | if (old_cw) { 100 | int cw; 101 | cw = *old_cw; 102 | _FPU_SETCW(cw); 103 | } 104 | #endif 105 | #endif 106 | } 107 | 108 | #ifdef HAVE_FORTRAN 109 | 110 | #define f_fpu_fix_start FC_FUNC_(f_fpu_fix_start, F_FPU_FIX_START) 111 | #define f_fpu_fix_end FC_FUNC_(f_fpu_fix_end, F_FPU_FIX_END) 112 | 113 | void f_fpu_fix_start(unsigned int *old_cw) { 114 | fpu_fix_start(old_cw); 115 | } 116 | 117 | void f_fpu_fix_end(unsigned int *old_cw) { 118 | fpu_fix_end(old_cw); 119 | } 120 | 121 | #endif 122 | 123 | } 124 | 125 | -------------------------------------------------------------------------------- /code/solver/qd/fpu.h: -------------------------------------------------------------------------------- 1 | /* 2 | * include/fpu.h 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2001 9 | * 10 | * Contains functions to set and restore the round-to-double flag in the 11 | * control word of an x86 FPU. The algorithms in the double-double and 12 | * quad-double packages do not function with the extended mode found in 13 | * these FPUs.
14 | */ 15 | #ifndef _QD_FPU_H 16 | #define _QD_FPU_H 17 | 18 | #include "qd_config.h" 19 | 20 | #ifdef __cplusplus 21 | extern "C" { 22 | #endif 23 | 24 | /* 25 | * Set the round-to-double flag, and save the old control word in old_cw. 26 | * If old_cw is NULL, the old control word is not saved. 27 | */ 28 | QD_API void fpu_fix_start(unsigned int *old_cw); 29 | 30 | /* 31 | * Restore the control word. 32 | */ 33 | QD_API void fpu_fix_end(unsigned int *old_cw); 34 | 35 | #ifdef __cplusplus 36 | } 37 | #endif 38 | 39 | #endif /* _QD_FPU_H */ 40 | -------------------------------------------------------------------------------- /code/solver/qd/inline.h: -------------------------------------------------------------------------------- 1 | /* 2 | * include/inline.h 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * This file contains the basic functions used by both the double-double 11 | * and quad-double packages. These are declared as inline functions as 12 | * they are the smallest building blocks of the double-double and 13 | * quad-double arithmetic. 14 | */ 15 | #ifndef _QD_INLINE_H 16 | #define _QD_INLINE_H 17 | 18 | #define _QD_SPLITTER 134217729.0 // = 2^27 + 1 19 | #define _QD_SPLIT_THRESH 6.69692879491417e+299 // = 2^996 20 | 21 | #ifdef QD_VACPP_BUILTINS_H 22 | /* For VisualAge C++ __fmadd */ 23 | #include <builtins.h> 24 | #endif 25 | 26 | #include <cmath> 27 | #include <limits> 28 | 29 | namespace qd { 30 | 31 | static const double _d_nan = std::numeric_limits<double>::quiet_NaN(); 32 | static const double _d_inf = std::numeric_limits<double>::infinity(); 33 | 34 | /*********** Basic Functions ************/ 35 | /* Computes fl(a+b) and err(a+b). Assumes |a| >= |b|. 
*/ 36 | inline double quick_two_sum(double a, double b, double &err) { 37 | double s = a + b; 38 | err = b - (s - a); 39 | return s; 40 | } 41 | 42 | /* Computes fl(a-b) and err(a-b). Assumes |a| >= |b| */ 43 | inline double quick_two_diff(double a, double b, double &err) { 44 | double s = a - b; 45 | err = (a - s) - b; 46 | return s; 47 | } 48 | 49 | /* Computes fl(a+b) and err(a+b). */ 50 | inline double two_sum(double a, double b, double &err) { 51 | double s = a + b; 52 | double bb = s - a; 53 | err = (a - (s - bb)) + (b - bb); 54 | return s; 55 | } 56 | 57 | /* Computes fl(a-b) and err(a-b). */ 58 | inline double two_diff(double a, double b, double &err) { 59 | double s = a - b; 60 | double bb = s - a; 61 | err = (a - (s - bb)) - (b + bb); 62 | return s; 63 | } 64 | 65 | #ifndef QD_FMS 66 | /* Computes high word and lo word of a */ 67 | inline void split(double a, double &hi, double &lo) { 68 | double temp; 69 | if (a > _QD_SPLIT_THRESH || a < -_QD_SPLIT_THRESH) { 70 | a *= 3.7252902984619140625e-09; // 2^-28 71 | temp = _QD_SPLITTER * a; 72 | hi = temp - (temp - a); 73 | lo = a - hi; 74 | hi *= 268435456.0; // 2^28 75 | lo *= 268435456.0; // 2^28 76 | } else { 77 | temp = _QD_SPLITTER * a; 78 | hi = temp - (temp - a); 79 | lo = a - hi; 80 | } 81 | } 82 | #endif 83 | 84 | /* Computes fl(a*b) and err(a*b). */ 85 | inline double two_prod(double a, double b, double &err) { 86 | #ifdef QD_FMS 87 | double p = a * b; 88 | err = QD_FMS(a, b, p); 89 | return p; 90 | #else 91 | double a_hi, a_lo, b_hi, b_lo; 92 | double p = a * b; 93 | split(a, a_hi, a_lo); 94 | split(b, b_hi, b_lo); 95 | err = ((a_hi * b_hi - p) + a_hi * b_lo + a_lo * b_hi) + a_lo * b_lo; 96 | return p; 97 | #endif 98 | } 99 | 100 | /* Computes fl(a*a) and err(a*a). Faster than the above method. 
*/ 101 | inline double two_sqr(double a, double &err) { 102 | #ifdef QD_FMS 103 | double p = a * a; 104 | err = QD_FMS(a, a, p); 105 | return p; 106 | #else 107 | double hi, lo; 108 | double q = a * a; 109 | split(a, hi, lo); 110 | err = ((hi * hi - q) + 2.0 * hi * lo) + lo * lo; 111 | return q; 112 | #endif 113 | } 114 | 115 | /* Computes the nearest integer to d. */ 116 | inline double nint(double d) { 117 | if (d == std::floor(d)) 118 | return d; 119 | return std::floor(d + 0.5); 120 | } 121 | 122 | /* Computes the truncated integer. */ 123 | inline double aint(double d) { 124 | return (d >= 0.0) ? std::floor(d) : std::ceil(d); 125 | } 126 | 127 | /* These are provided to give consistent 128 | interface for double with double-double and quad-double. */ 129 | inline void sincosh(double t, double &sinh_t, double &cosh_t) { 130 | sinh_t = std::sinh(t); 131 | cosh_t = std::cosh(t); 132 | } 133 | 134 | inline double sqr(double t) { 135 | return t * t; 136 | } 137 | 138 | inline double to_double(double a) { return a; } 139 | inline int to_int(double a) { return static_cast<int>(a); } 140 | 141 | } 142 | 143 | #endif /* _QD_INLINE_H */ 144 | -------------------------------------------------------------------------------- /code/solver/qd/qd.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/code/solver/qd/qd.pdf -------------------------------------------------------------------------------- /code/solver/qd/qd_config.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | /* include/qd/qd_config.h. Generated from qd_config.h.in by configure. */ 5 | #ifndef _QD_QD_CONFIG_H 6 | #define _QD_QD_CONFIG_H 1 7 | 8 | #ifndef QD_API 9 | #define QD_API /**/ 10 | #endif 11 | 12 | /* Set to 1 if using VisualAge C++ compiler for __fmadd builtin. 
*/ 13 | #ifndef QD_VACPP_BUILTINS_H 14 | /* #undef QD_VACPP_BUILTINS_H */ 15 | #endif 16 | 17 | /* If fused multiply-add is available, define to correct macro for 18 | using it. It is invoked as QD_FMA(a, b, c) to compute fl(a * b + c). 19 | If correctly rounded multiply-add is not available (or if unsure), 20 | keep it undefined.*/ 21 | #ifndef QD_FMA 22 | /* #undef QD_FMA */ 23 | #endif 24 | 25 | /* If fused multiply-subtract is available, define to correct macro for 26 | using it. It is invoked as QD_FMS(a, b, c) to compute fl(a * b - c). 27 | If correctly rounded multiply-add is not available (or if unsure), 28 | keep it undefined.*/ 29 | #ifndef QD_FMS 30 | #define QD_FMS(a, b, c) std::fma(a,b,-c) 31 | /* #undef QD_FMS */ 32 | #endif 33 | 34 | /* Set the following to 1 to define commonly used function 35 | to be inlined. This should be set to 1 unless the compiler 36 | does not support the "inline" keyword, or if building for 37 | debugging purposes. */ 38 | #ifndef QD_INLINE 39 | #define QD_INLINE 1 40 | #endif 41 | 42 | /* Set the following to 1 to use ANSI C++ standard header files 43 | such as cmath, iostream, etc. If set to zero, it will try to 44 | include math.h, iostream.h, etc, instead. */ 45 | #ifndef QD_HAVE_STD 46 | #define QD_HAVE_STD 1 47 | #endif 48 | 49 | /* Set the following to 1 to make the addition and subtraction 50 | to satisfy the IEEE-style error bound 51 | 52 | fl(a + b) = (1 + d) * (a + b) 53 | 54 | where |d| <= eps. If set to 0, the addition and subtraction 55 | will satisfy the weaker Cray-style error bound 56 | 57 | fl(a + b) = (1 + d1) * a + (1 + d2) * b 58 | 59 | where |d1| <= eps and |d2| <= eps. */ 60 | #ifndef QD_IEEE_ADD 61 | /* #undef QD_IEEE_ADD */ 62 | #endif 63 | 64 | /* Set the following to 1 to use slightly inaccurate but faster 65 | version of multiplication. 
*/ 66 | #ifndef QD_SLOPPY_MUL 67 | #define QD_SLOPPY_MUL 1 68 | #endif 69 | 70 | /* Set the following to 1 to use slightly inaccurate but faster 71 | version of division. */ 72 | #ifndef QD_SLOPPY_DIV 73 | #define QD_SLOPPY_DIV 1 74 | #endif 75 | 76 | /* Define this macro to be the isfinite(x) function. */ 77 | #ifndef QD_ISFINITE 78 | #define QD_ISFINITE(x) std::isfinite(x) 79 | #endif 80 | 81 | /* Define this macro to be the isinf(x) function. */ 82 | #ifndef QD_ISINF 83 | #define QD_ISINF(x) std::isinf(x) 84 | #endif 85 | 86 | /* Define this macro to be the isnan(x) function. */ 87 | #ifndef QD_ISNAN 88 | #define QD_ISNAN(x) std::isnan(x) 89 | #endif 90 | 91 | 92 | #endif /* _QD_QD_CONFIG_H */ 93 | -------------------------------------------------------------------------------- /code/solver/qd/qd_const.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * src/qd_const.cc 3 | * 4 | * This work was supported by the Director, Office of Science, Division 5 | * of Mathematical, Information, and Computational Sciences of the 6 | * U.S. Department of Energy under contract number DE-AC03-76SF00098. 7 | * 8 | * Copyright (c) 2000-2001 9 | * 10 | * Defines constants used in quad-double package. 11 | */ 12 | #include "qd_config.h" 13 | #include "qd_real.h" 14 | 15 | /* Some useful constants. 
*/ 16 | const qd_real qd_real::_2pi = qd_real(6.283185307179586232e+00, 17 | 2.449293598294706414e-16, 18 | -5.989539619436679332e-33, 19 | 2.224908441726730563e-49); 20 | const qd_real qd_real::_pi = qd_real(3.141592653589793116e+00, 21 | 1.224646799147353207e-16, 22 | -2.994769809718339666e-33, 23 | 1.112454220863365282e-49); 24 | const qd_real qd_real::_pi2 = qd_real(1.570796326794896558e+00, 25 | 6.123233995736766036e-17, 26 | -1.497384904859169833e-33, 27 | 5.562271104316826408e-50); 28 | const qd_real qd_real::_pi4 = qd_real(7.853981633974482790e-01, 29 | 3.061616997868383018e-17, 30 | -7.486924524295849165e-34, 31 | 2.781135552158413204e-50); 32 | const qd_real qd_real::_3pi4 = qd_real(2.356194490192344837e+00, 33 | 9.1848509936051484375e-17, 34 | 3.9168984647504003225e-33, 35 | -2.5867981632704860386e-49); 36 | const qd_real qd_real::_e = qd_real(2.718281828459045091e+00, 37 | 1.445646891729250158e-16, 38 | -2.127717108038176765e-33, 39 | 1.515630159841218954e-49); 40 | const qd_real qd_real::_log2 = qd_real(6.931471805599452862e-01, 41 | 2.319046813846299558e-17, 42 | 5.707708438416212066e-34, 43 | -3.582432210601811423e-50); 44 | const qd_real qd_real::_log10 = qd_real(2.302585092994045901e+00, 45 | -2.170756223382249351e-16, 46 | -9.984262454465776570e-33, 47 | -4.023357454450206379e-49); 48 | const qd_real qd_real::_nan = qd_real(qd::_d_nan, qd::_d_nan, 49 | qd::_d_nan, qd::_d_nan); 50 | const qd_real qd_real::_inf = qd_real(qd::_d_inf, qd::_d_inf, 51 | qd::_d_inf, qd::_d_inf); 52 | 53 | const double qd_real::_eps = 1.21543267145725e-63; // = 2^-209 54 | const double qd_real::_min_normalized = 1.6259745436952323e-260; // = 2^(-1022 + 3*53) 55 | const qd_real qd_real::_max = qd_real( 56 | 1.79769313486231570815e+308, 9.97920154767359795037e+291, 57 | 5.53956966280111259858e+275, 3.07507889307840487279e+259); 58 | const qd_real qd_real::_safe_max = qd_real( 59 | 1.7976931080746007281e+308, 9.97920154767359795037e+291, 60 | 5.53956966280111259858e+275, 
3.07507889307840487279e+259); 61 | const int qd_real::_ndigits = 62; 62 | 63 | -------------------------------------------------------------------------------- /code/solver/qd/util.cc: -------------------------------------------------------------------------------- 1 | #include <cstdlib> 2 | #include "util.h" 3 | 4 | void append_expn(std::string &str, int expn) { 5 | int k; 6 | 7 | str += (expn < 0 ? '-' : '+'); 8 | expn = std::abs(expn); 9 | 10 | if (expn >= 100) { 11 | k = (expn / 100); 12 | str += '0' + k; 13 | expn -= 100*k; 14 | } 15 | 16 | k = (expn / 10); 17 | str += '0' + k; 18 | expn -= 10*k; 19 | 20 | str += '0' + expn; 21 | } 22 | 23 | -------------------------------------------------------------------------------- /code/solver/qd/util.h: -------------------------------------------------------------------------------- 1 | #include <string> 2 | 3 | void append_expn(std::string &str, int expn); 4 | 5 | -------------------------------------------------------------------------------- /code/utils/Setfield.m: -------------------------------------------------------------------------------- 1 | % obj holds the default value 2 | function obj = Setfield(obj, source) 3 | key = fieldnames(obj); 4 | for i = 1:length(key) 5 | if isfield(source, key{i}) || isprop(source, key{i}) 6 | obj.(key{i}) = source.(key{i}); 7 | end 8 | end 9 | 10 | if isstruct(obj) 11 | key = fieldnames(source); 12 | for i = 1:length(key) 13 | obj.(key{i}) = source.(key{i}); 14 | end 15 | end 16 | end 17 | -------------------------------------------------------------------------------- /code/utils/TableDisplay.m: -------------------------------------------------------------------------------- 1 | % This class is used internally for testing only 2 | % It prints out a table (to a string) row by row. 3 | classdef TableDisplay < handle 4 | properties 5 | format 6 | end 7 | 8 | methods 9 | function o = TableDisplay(format) 10 | % o = TableDisplay(format) 11 | % format is a struct where each field is a column in the table. 
12 | % If that field is a string, 13 | % it represents the format (according to printf) 14 | % otherwise 15 | % it is a structure with fields 16 | % format 17 | % default (default value for the field) 18 | % length 19 | % label 20 | % type (double or string) 21 | 22 | fields = fieldnames(format); 23 | for i = 1:length(fields) 24 | name = fields{i}; 25 | field = format.(fields{i}); 26 | 27 | % if the field contains only a string, 28 | % convert it into the structure format. 29 | if ischar(field) 30 | formattmp = field; 31 | field = struct; 32 | field.format = formattmp; 33 | end 34 | 35 | % read off the type if not specified 36 | if ~(isfield(field, 'type')) 37 | if endsWith(field.format, 's') 38 | field.type = 'string'; 39 | else 40 | field.type = 'double'; 41 | end 42 | end 43 | 44 | % set the default if not specified 45 | if ~(isfield(field, 'default')) 46 | if strcmp(field.type, 'string') 47 | field.default = ''; 48 | else 49 | field.default = NaN; 50 | end 51 | end 52 | 53 | % read off the length from the format if not specified 54 | if ~(isfield(field, 'length')) 55 | matchStr = regexp(field.format, '[0-9]*', 'match', 'once'); 56 | if ismissing(matchStr) 57 | field.length = +Inf; 58 | else 59 | field.length = str2double(matchStr); 60 | end 61 | end 62 | 63 | % use the field id as label if not specified 64 | if ~(isfield(field, 'label')) 65 | field.label = name; 66 | end 67 | 68 | format.(name) = field; 69 | end 70 | o.format = format; 71 | end 72 | 73 | function s = header(o) 74 | % s = o.header(); 75 | % Print out the header of the table 76 | 77 | s = ''; 78 | fields = fieldnames(o.format); 79 | total_length = 0; 80 | for i = 1:length(fields) 81 | field = o.format.(fields{i}); 82 | if field.length == +Inf 83 | f = '%s'; 84 | total_length = total_length + strlength(field.label); 85 | else 86 | f = strcat('%', num2str(field.length), 's'); 87 | total_length = total_length + field.length + 1; 88 | end 89 | s = strcat(s, sprintf(f, field.label), ' '); 90 | end 91 
| total_length = total_length - 1; 92 | 93 | s = [s, newline, repmat('-', 1, total_length), newline]; 94 | end 95 | 96 | function s = print(o, data) 97 | % s = o.print(item); 98 | % Print out a row of the table with the data 99 | 100 | s = ''; 101 | fields = fieldnames(o.format); 102 | for i = 1:length(fields) 103 | name = fields{i}; 104 | if (isfield(data, name)) 105 | data_i = data.(name); 106 | else 107 | data_i = o.format.(fields{i}).default; 108 | end 109 | field = o.format.(name); 110 | if strcmp(field.type, 'string') && ... 111 | strlength(data_i) > field.length-1 112 | data_i = extractBetween(data_i, 1, field.length-1); 113 | data_i = data_i{1}; 114 | end 115 | s = strcat(s, sprintf(strcat('%', field.format), data_i), ' '); 116 | end 117 | 118 | s = [s, newline]; 119 | end 120 | end 121 | end 122 | -------------------------------------------------------------------------------- /code/utils/blendv.m: -------------------------------------------------------------------------------- 1 | function v = blendv(m1, m2, mask) 2 | v = m1; 3 | v(mask,:,:) = m2(mask,:,:); 4 | end -------------------------------------------------------------------------------- /code/utils/dblcmp.m: -------------------------------------------------------------------------------- 1 | % compares a and b and returns true if they are identical up to tol 2 | function c = dblcmp(a, b, tol) 3 | if nargin < 3, tol = eps; end 4 | 5 | % <= does not work for the case with Inf 6 | c = (abs(a - b) < tol * (1+abs(a)+abs(b))); 7 | end -------------------------------------------------------------------------------- /code/utils/nonempty.m: -------------------------------------------------------------------------------- 1 | function b = nonempty(obj, field) 2 | b = isfield(obj, field) && ~isempty(obj.(field)); -------------------------------------------------------------------------------- /code/utils/spdiag.m: -------------------------------------------------------------------------------- 1 | function D 
= spdiag(x) 2 | % D = spdiag(x) 3 | % Output the diagonal matrix given by the vector x 4 | 5 | D = diag(sparse(x)); 6 | end -------------------------------------------------------------------------------- /code/utils/timeit.m: -------------------------------------------------------------------------------- 1 | function t = timeit(f) 2 | t = tic; 3 | f(); % warm up 4 | t = toc(t); 5 | 6 | repeat = ceil(2 + 1e-5 / t); % the program takes at least 1e-5 sec. 7 | t = tic; 8 | for i = 1:repeat 9 | f(); 10 | end 11 | t = toc(t) / repeat; 12 | end 13 | -------------------------------------------------------------------------------- /coverage/Recon1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/Recon1.mat -------------------------------------------------------------------------------- /coverage/TestSuite.m: -------------------------------------------------------------------------------- 1 | classdef TestSuite < handle 2 | properties 3 | printFormat 4 | problemFilter 5 | problems = []; 6 | debug = 0 7 | nCores = +Inf % +Inf means default number of cores 8 | randomSeed = 0 % 0 means do not fix randomSeed 9 | testFunc % function for testing 10 | end 11 | 12 | methods 13 | function o = TestSuite() 14 | %% Prepare the default printout info 15 | o.printFormat = struct; 16 | o.printFormat.id = '5i'; 17 | o.printFormat.name = '50s'; 18 | o.printFormat.success = struct('label', 'ok', 'format', '4i'); 19 | o.printFormat.time = '10.2f'; 20 | 21 | %% Setup the problem lists 22 | o.problemFilter = struct; 23 | o.problemFilter.ignoreProblems = ... 24 | {'extra/fit1p$', 'extra/fit2p$', ... %matlab can't solve it 25 | 'basic/random_dense@\d\d\d\d', 'basic/random_sparse@\d\d\d', 'basic/birkhoff@\d\d\d\d', ... 
% problem too large 26 | }; 27 | o.problemFilter.fileSizeLimit = [0 120000]; 28 | end 29 | 30 | function test(o) 31 | output = TableDisplay(o.printFormat); 32 | fprintf(output.header()); 33 | if isempty(o.problems) 34 | o.problems = problemList(o.problemFilter); 35 | end 36 | success = 0; total_time = 0; 37 | 38 | if (o.debug || o.nCores == 1) 39 | for k = 1:length(o.problems) 40 | ret = o.runStep(k); 41 | fprintf(output.print(ret)); 42 | 43 | total_time = total_time + ret.time; 44 | success = success + ret.success; 45 | end 46 | else 47 | if (o.nCores ~= +Inf) 48 | delete(gcp('nocreate')); 49 | parpool('local', o.nCores); 50 | end 51 | 52 | parfor k = 1:length(o.problems) 53 | ret = o.runStep(k); 54 | fprintf(output.print(ret)); 55 | 56 | total_time = total_time + ret.time; 57 | success = success + ret.success; 58 | end 59 | end 60 | 61 | fprintf('%i/%i success\n', success, length(o.problems)) 62 | fprintf('Total time: %f\n', total_time) 63 | end 64 | end 65 | 66 | methods(Access = private) 67 | function ret = runStep(o, id) 68 | name = o.problems{id}; 69 | 70 | if o.randomSeed ~= 0, rng(o.randomSeed); end 71 | warning('off', 'MATLAB:nearlySingularMatrix'); 72 | warning('off', 'MATLAB:singularMatrix'); 73 | warning('off', 'MATLAB:rankDeficientMatrix'); 74 | warning('off', 'uniformtest:size'); 75 | warning('off', 'stats:adtest:OutOfRangePLow'); 76 | warning('off', 'stats:adtest:OutOfRangePHigh'); 77 | warning('off', 'stats:adtest:SmallSampleSize'); 78 | warning('off', 'stats:adtest:NotEnoughData'); 79 | 80 | t = tic; 81 | if (o.debug) 82 | ret = o.testFunc(name); 83 | else 84 | try 85 | ret = o.testFunc(name); 86 | catch s 87 | ret = struct; 88 | ret.success = 0; 89 | warning('problem = %s\n%s\n\n\n', name, getReport(s,'extended')); 90 | end 91 | end 92 | ret.time = toc(t); 93 | ret.id = id; 94 | ret.name = name; 95 | end 96 | end 97 | end -------------------------------------------------------------------------------- /coverage/coverage_test.m: 
-------------------------------------------------------------------------------- 1 | debug = 0; 2 | problems = []; 3 | folders = {'basic', 'metabolic', 'netlib'}; 4 | presolve_test(debug, folders, problems) 5 | %% 6 | 7 | folders = {'basic', 'metabolic', 'netlib'}; 8 | sample_test(debug, folders, problems, 100, 'uniform') 9 | sample_test(debug, folders, problems, 100, 'exponential') 10 | sample_test(debug, folders, problems, 100, 'normal') 11 | 12 | 13 | p_value_test(); 14 | -------------------------------------------------------------------------------- /coverage/p_value_test.m: -------------------------------------------------------------------------------- 1 | initSampler 2 | 3 | N = 1000; % number of experiment 4 | 5 | warning('off', 'uniformtest:size'); 6 | warning('off', 'stats:adtest:OutOfRangePLow'); 7 | warning('off', 'stats:adtest:OutOfRangePHigh'); 8 | warning('off', 'stats:adtest:SmallSampleSize'); 9 | warning('off', 'stats:adtest:NotEnoughData'); 10 | 11 | P = struct; d = 100; 12 | P.Aineq = ones(1, d); 13 | P.bineq = 1; 14 | P.ub = ones(d, 1); 15 | P.lb = zeros(d, 1); 16 | p = zeros(N,1); 17 | parfor i = 1:N 18 | opts = default_options(); 19 | opts.module = {'MixingTimeEstimator', 'MemoryStorage', 'DynamicRegularizer', 'DynamicWeight', 'DynamicStepSize'}; 20 | opts.seed = i; 21 | o = sample(P, 1000, opts); 22 | p(i) = uniformtest(o); 23 | end 24 | 25 | % Look by eye to see if this is uniform 26 | histogram(p, 20) 27 | [~, z] = kstest(norminv(p)); 28 | assert(z > 0.05); 29 | 30 | %% understand the worst case 31 | if (0) 32 | opts = default_options(); 33 | opts.module = {'MixingTimeEstimator', 'MemoryStorage', 'DynamicRegularizer', 'DynamicWeight', 'DynamicStepSize'}; 34 | opts.seed = find(p == min(p)); 35 | o = sample(P, 1000, opts); 36 | uniformtest(o, struct('toPlot', true)); 37 | end -------------------------------------------------------------------------------- /coverage/presolve_test.m: 
-------------------------------------------------------------------------------- 1 | function presolve_test(debug, folders, problems) 2 | s = TestSuite; 3 | if nargin >= 2 && ~isempty(folders) 4 | s.problemFilter.folders = folders; 5 | end 6 | if nargin >= 3 && ~isempty(problems) 7 | s.problems = problems; 8 | end 9 | s.randomSeed = 123456; 10 | s.nCores = +Inf; 11 | s.debug = debug; 12 | s.printFormat.m = '8i'; 13 | s.printFormat.n = '8i'; 14 | s.printFormat.nnz = '10i'; 15 | s.printFormat.mNew = '8i'; 16 | s.printFormat.nNew = '8i'; 17 | s.printFormat.nnzNew = '10i'; 18 | s.printFormat.opt = '14e'; 19 | s.printFormat.optNew = '14e'; 20 | s.printFormat.error = '10.3e'; 21 | s.printFormat.feasible = '10i'; 22 | s.printFormat.minDist = '10.3e'; 23 | s.testFunc = @test_func; 24 | s.test(); 25 | end 26 | 27 | function o = test_func(name) 28 | o = {}; 29 | P = loadProblem(name); 30 | rng(123456); % make sure it is reproducible 31 | 32 | %% Test 1: Check if the solution remains the same. 33 | 34 | if isempty(P.df) 35 | P.df = randn(size(P.lb)); 36 | end 37 | 38 | % first solve it using matlab LP solver 39 | P_opts = default_options(); 40 | P_opts.presolve.runSimplify = false; 41 | P.center = P.df; % avoid Polytope raise error of not finding a center 42 | P_opts.presolve.logFunc = @(tag, msg) 0; 43 | P0 = Polytope(P, P_opts); 44 | P.center = []; 45 | df = P0.df; 46 | 47 | opts = optimoptions('linprog','Display','none'); 48 | x0 = linprog(P0.T' * df, [], [], P0.A, P0.b, P0.barrier.lb, P0.barrier.ub, opts); 49 | if numel(x0) ~= 0 50 | o.opt = df' * (P0.T * x0 + P0.y); 51 | else 52 | o.opt = NaN; 53 | end 54 | 55 | % simplify the polytope and solve it again 56 | P_opts.presolve.runSimplify = true; 57 | P1 = Polytope(P, P_opts); 58 | df = P1.df; 59 | 60 | x1 = linprog(P1.T' * df, [], [], P1.A, P1.b, P1.barrier.lb, P1.barrier.ub, opts); 61 | if numel(x1) ~= 0 62 | o.optNew = df' * (P1.T * x1 + P1.y); 63 | else 64 | o.optNew = NaN; 65 | end 66 | 67 | o.feasible = 
P1.barrier.feasible(P1.center); 68 | o.error = abs(o.opt-o.optNew)/(abs(o.opt)+1e-4); 69 | 70 | %% Test 2: check if the analytic center is deep inside 71 | P.df = zeros(size(P.lb)); 72 | P3 = Polytope(P, P_opts); 73 | c = P3.center; 74 | o.minDist = min(c - P3.barrier.lb, P3.barrier.ub - c); 75 | o.minDist = o.minDist ./ min(P3.barrier.ub-P3.barrier.lb, 1); 76 | o.minDist = min(o.minDist); 77 | 78 | o.m = size(P0.A,1); o.n = size(P0.A,2); o.nnz = nnz(P0.A); 79 | o.mNew = size(P1.A,1); o.nNew = size(P1.A,2); o.nnzNew = nnz(P1.A); 80 | 81 | if (o.error <= 1e-4 && o.feasible && o.minDist > 1e-6) % = is important for NaN case 82 | o.success = 1; 83 | else 84 | o.success = 0; 85 | end 86 | end -------------------------------------------------------------------------------- /coverage/problems/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/.DS_Store -------------------------------------------------------------------------------- /coverage/problems/SOURCE.txt: -------------------------------------------------------------------------------- 1 | Source of the polytopes: 2 | 3 | Recon1 http://bigg.ucsd.edu/models/RECON1 4 | Recon2.04 https://www.vmh.life/#downloadview 5 | Recon3D_3.01 https://www.vmh.life/#downloadview 6 | netlib https://www.netlib.org/lp/data/index.html 7 | -------------------------------------------------------------------------------- /coverage/problems/basic/birkhoff.m: -------------------------------------------------------------------------------- 1 | function P = birkhoff(dim) 2 | d = ceil(sqrt(dim)); 3 | P = struct; 4 | P.lb = zeros(d^2,1); 5 | P.Aeq = sparse(2*d,d^2); 6 | P.beq = ones(2*d,1); 7 | for i=1:d 8 | P.Aeq(i,(i-1)*d+1:i*d) = 1; 9 | P.Aeq(d+i,i:d:d^2)=1; 10 | end 11 | end 12 | -------------------------------------------------------------------------------- 
/coverage/problems/basic/long_box.m: -------------------------------------------------------------------------------- 1 | function P = long_box(dim) 2 | P = struct; 3 | P.lb = -0.5*ones(dim,1); 4 | P.ub = 0.5*ones(dim,1); 5 | P.ub(1) = 1e6; 6 | end -------------------------------------------------------------------------------- /coverage/problems/basic/polytope_box.m: -------------------------------------------------------------------------------- 1 | function P = polytope_box(dim) 2 | P = struct; 3 | P.lb = -0.5*ones(dim,1); 4 | P.ub = 0.5*ones(dim,1); 5 | end 6 | -------------------------------------------------------------------------------- /coverage/problems/basic/random_sparse.m: -------------------------------------------------------------------------------- 1 | function P = random_sparse(dim) 2 | facets = 5*dim; 3 | P = struct; 4 | P.Aineq = zeros(facets,dim); 5 | k = 5; 6 | for i=1:facets 7 | coords = randperm(dim,k); 8 | coord_signs = sign(randn(1,k)); 9 | P.Aineq(i,coords) = coord_signs/sqrt(k); 10 | end 11 | P.lb = -sqrt(dim)*ones(dim,1); 12 | P.ub = +sqrt(dim)*ones(dim,1); 13 | P.bineq = ones(facets,1); 14 | end -------------------------------------------------------------------------------- /coverage/problems/basic/simplex.m: -------------------------------------------------------------------------------- 1 | function P = simplex(dim) 2 | P = struct; 3 | P.Aeq = ones(1,dim); 4 | P.beq = 1; 5 | P.lb = zeros(dim,1); 6 | end -------------------------------------------------------------------------------- /coverage/problems/basic/tv_ball.m: -------------------------------------------------------------------------------- 1 | function P = tv_ball(dim) 2 | e = ones(dim,1); 3 | P = struct; 4 | P.Aeq = [spdiags([e -e], 0:1, dim-1, dim) spdiags(e, 0, dim-1, dim-1)]; 5 | P.beq = zeros(dim-1,1); 6 | P.lb = -ones(2*dim-1,1); 7 | P.ub = ones(2*dim-1,1); 8 | P.lb(2:(dim-1)) = -Inf; 9 | P.ub(2:(dim-1)) = Inf; 10 | end 
-------------------------------------------------------------------------------- /coverage/problems/basic/tv_ball2.m: -------------------------------------------------------------------------------- 1 | function P = tv_ball2(dim) 2 | e = ones(dim,1); 3 | P = struct; 4 | P.Aeq = [spdiags([e -e], 0:1, dim-1, dim) spdiags(e, 0, dim-1, dim-1)]; 5 | P.beq = zeros(dim-1,1); 6 | P.lb = -ones(2*dim-1,1); 7 | P.ub = ones(2*dim-1,1); 8 | P.lb(1:dim) = -10*sqrt(dim); 9 | P.ub(1:dim) = 10*sqrt(dim); 10 | end -------------------------------------------------------------------------------- /coverage/problems/loadProblem.m: -------------------------------------------------------------------------------- 1 | function P = loadProblem(name) 2 | path_size = split(name,'@'); 3 | path_folders = split(path_size{1},'/'); 4 | curFolder = fileparts(mfilename('fullpath')); 5 | path = fullfile(curFolder, path_folders{:}); 6 | 7 | % check if the file exists as a mat 8 | if exist([path '.mat'], 'file') && length(path_size) == 1 9 | load([path '.mat'], 'problem'); 10 | P = problem; 11 | elseif exist([path '.m'], 'file') && length(path_size) == 2 12 | [folder,name] = fileparts(path); 13 | prevDir = pwd; 14 | cd(folder); 15 | h = str2func(['@' name]); 16 | cd(prevDir); 17 | 18 | scurr = rng; 19 | rng(123456); % fix the seed for random generator 20 | P = h(str2double(path_size{2})); 21 | rng(scurr); 22 | else 23 | error(['Problem ' name ' does not exist']); 24 | end 25 | 26 | P = standardize_problem(P); 27 | 28 | % netlib LPs are not bounded 29 | if contains(name, 'netlib') 30 | df = P.df; % the df for netlib problem is a fixed vector 31 | x = linprog(df, P.Aineq, P.bineq, P.Aeq, P.beq, P.lb, P.ub, struct('Display','none')); 32 | threshold = df' * x + abs(df)' * abs(x); 33 | P.f = []; 34 | P.df = []; 35 | P.ddf = []; 36 | P.Aineq = [P.Aineq; df']; 37 | P.bineq = [P.bineq; threshold]; 38 | P.ub = min(P.ub, 2 * max(abs(x))); 39 | P.lb = max(P.lb, -2 * max(abs(x))); 40 | end 41 | 
-------------------------------------------------------------------------------- /coverage/problems/metabolic/Acidaminococcus_sp_D21.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/metabolic/Acidaminococcus_sp_D21.mat -------------------------------------------------------------------------------- /coverage/problems/metabolic/Recon1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/metabolic/Recon1.mat -------------------------------------------------------------------------------- /coverage/problems/metabolic/Recon2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/metabolic/Recon2.mat -------------------------------------------------------------------------------- /coverage/problems/metabolic/Recon3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/metabolic/Recon3.mat -------------------------------------------------------------------------------- /coverage/problems/metabolic/cardiac_mit_glcuptake_atpmax.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/metabolic/cardiac_mit_glcuptake_atpmax.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/25fv47.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/25fv47.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/80bau3b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/80bau3b.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/afiro.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/afiro.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/agg.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/agg.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/beaconfd.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/beaconfd.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/blend.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/blend.mat 
-------------------------------------------------------------------------------- /coverage/problems/netlib/degen2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/degen2.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/degen3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/degen3.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/etamacro.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/etamacro.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/scorpion.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/scorpion.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/sierra.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/sierra.mat -------------------------------------------------------------------------------- /coverage/problems/netlib/truss.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ConstrainedSampler/PolytopeSamplerMatlab/f85505cc33ddef08982c0b1bf78eaccfe6d4ce47/coverage/problems/netlib/truss.mat -------------------------------------------------------------------------------- /coverage/problems/problemList.m: -------------------------------------------------------------------------------- 1 | function l = problemList(options) 2 | if ~exist('options', 'var'), options = struct(); end 3 | problems = {}; 4 | 5 | default.fileSizeLimit = [0 +Inf]; % in bytes 6 | default.folders = {'basic', 'metabolic', 'netlib'}; 7 | default.ignoreProblems = {}; 8 | default.generateDimensions = [10 100 1000 10000]; 9 | 10 | o = Setfield(default, options); 11 | 12 | curFolder = fileparts(mfilename('fullpath')); 13 | 14 | for j = 1:length(o.folders) 15 | files = dir(fullfile(curFolder, o.folders{j}, '*.m*')); 16 | for k = 1:length(files) 17 | file = fullfile(files(k).folder, files(k).name); 18 | [~,name,ext] = fileparts(file); 19 | name = strcat(o.folders{j}, '/', name); 20 | 21 | if strcmp(ext, '.mat') == 1 22 | s = dir(file); 23 | if (s.bytes < o.fileSizeLimit(1) || s.bytes > o.fileSizeLimit(2)) 24 | continue; 25 | end 26 | problems{end+1} = name; 27 | elseif strcmp(ext, '.m') == 1 28 | for l = 1:length(o.generateDimensions) 29 | problems{end+1} = [name '@' num2str(o.generateDimensions(l))]; 30 | end 31 | end 32 | end 33 | end 34 | 35 | % remove ignored problems 36 | l = {}; 37 | for j = 1:length(problems) 38 | name = problems{j}; 39 | ignored = 0; 40 | for k = 1:length(o.ignoreProblems) 41 | if ~isempty(regexp(name, o.ignoreProblems{k})) 42 | ignored = 1; 43 | break; 44 | end 45 | end 46 | if ignored == 0 47 | l{end+1} = name; 48 | end 49 | end 50 | end 51 | -------------------------------------------------------------------------------- /coverage/sample_test.m: -------------------------------------------------------------------------------- 1 | function sample_test(debug, folders, problems, num_samples, func) 2 | s = TestSuite; 
3 | if nargin >= 2 && ~isempty(folders) 4 | s.problemFilter.folders = folders; 5 | end 6 | if nargin >= 3 && ~isempty(problems) 7 | s.problems = problems; 8 | end 9 | s.randomSeed = 123456; 10 | s.nCores = +Inf; 11 | s.debug = debug; 12 | s.printFormat.m = '8i'; 13 | s.printFormat.n = '8i'; 14 | s.printFormat.nnz = '10i'; 15 | s.printFormat.mixing = '15f'; 16 | s.printFormat.pVal = '10f'; 17 | s.printFormat.preTime = '8.2f'; 18 | s.printFormat.stepSize = '10f'; 19 | s.printFormat.nStep = '10i'; 20 | s.printFormat.avgAcc = '15.3e'; 21 | s.testFunc = @(name) test_func(name, num_samples, debug, func); 22 | s.test(); 23 | end 24 | 25 | function o = test_func(name, num_samples, debug, func) 26 | 27 | % load the problem and truncate it to make it bounded 28 | P = loadProblem(name); 29 | P_opts = default_options(); 30 | P_opts.maxTime = 3600*8; 31 | P_opts.module = {'MixingTimeEstimator', 'MemoryStorage', 'DynamicRegularizer', 'DynamicStepSize', 'DynamicWeight', 'DebugLogger'}; 32 | if debug 33 | P_opts.module{end+1} = 'ProgressBar'; 34 | end 35 | d = size(P.lb,1); 36 | 37 | switch func 38 | case 'uniform' 39 | case 'exponential' 40 | P.df = randn(d, 1); 41 | case 'normal' 42 | P.f = @(x) x'*x/2; 43 | P.df = @(x) x; 44 | P.ddf = @(x) ones(d,1); 45 | end 46 | 47 | sample_out = sample(P, num_samples, P_opts); 48 | 49 | o = {}; 50 | o.m = size(sample_out.sampler.ham.A,1); 51 | o.n = sample_out.sampler.ham.n; 52 | o.nnz = nnz(sample_out.sampler.ham.A); 53 | o.preTime = sample_out.prepareTime; 54 | o.stepSize = sample_out.sampler.stepSize; 55 | o.nStep = sample_out.totalStep; 56 | o.avgAcc = mean(sample_out.averageAccuracy); 57 | [o.pVal] = distribution_test(sample_out); 58 | o.mixing = sample_out.sampler.mixingTime; 59 | 60 | if (o.mixing < 500 && o.pVal > 0.005 && o.pVal < 0.995) 61 | o.success = 1; 62 | else 63 | o.success = 0; 64 | end 65 | end 66 | 67 | -------------------------------------------------------------------------------- /coverage/solver/solver_scalar_test.m: 
--------------------------------------------------------------------------------
function solver_scalar_test(file, high_acc)
solver = @PackedChol0;

load(file);
A = problem.Aeq;
A = [A speye(size(A,1))];
w = ones(size(A,2),1);
uid = solver('init', uint64(1234), A);
if (high_acc)
    solver('setAccuracyTarget', uid, 0.0);
end
acc = solver('decompose', uid, w);

%% test lsc
lsc = solver('leverageScoreComplement', uid, 0);

corank1 = sum(lsc);
corank2 = size(A,2)-size(A,1);

assert(abs(corank1 - corank2)<0.01);
assert(all(lsc > -1e-5));
assert(all(lsc < 1+1e-5));

%% test lsc with JL
lsc_apx = solver('leverageScoreComplement', uid, 32);
assert(max(abs(lsc_apx-lsc)) < 1.0);

%% test diagL
diagL = solver('diagL', uid);

L = chol(A * diag(sparse(w)) * A', 'lower');
diagL2 = diag(L);
assert(sum(abs(diagL - diagL2)) < 0.01);

%% test logdet
logdet = solver('logdet', uid);
logdet2 = sum(log(diagL2)) * 2;

assert(abs(logdet - logdet2) < 0.01);

%% test solve
b = randn(size(A,1),1);
x = solver('solve', uid, b);
x2 = L'\ (L \ b);

assert(sum(abs(x - x2)) < 0.01);


solver('delete', uid);
end
--------------------------------------------------------------------------------
/coverage/solver/solver_simd_test.m:
--------------------------------------------------------------------------------
function solver_simd_test(file, high_acc)
solver = @PackedChol4;

load(file);
A = problem.Aeq;
A = [A speye(size(A,1))];
w = rand(4, size(A,2)) + 0.2;
%w = ones(4,1) * rand(1, size(A,2)) + 0.2;
%w = 5 * ones(4, size(A,2));
uid = solver('init', uint64(1234), A);
if (high_acc)
    solver('setAccuracyTarget', uid, 0.0);
end
acc = solver('decompose', uid, w);

%% test L
L = solver('L', uid, 3);
H = A * diag(sparse(w(4,:))) * A';
assert(sum(abs(H - L * L'),'all') < 0.01)

%% test lsc
lsc = solver('leverageScoreComplement', uid, 0);

corank1 = sum(lsc, 2);
corank2 = size(A,2)-size(A,1);

assert(all(abs(corank1 - corank2)<0.01));
assert(all(lsc > -1e-5, 'all'));
assert(all(lsc < 1+1e-5, 'all'));

%% test lsc with JL
lsc_apx = solver('leverageScoreComplement', uid, 32);
assert(max(abs(lsc_apx-lsc), [], 'all') < 1.0);

%% test diagL
diagL = solver('diagL', uid);

for k = 1:4
    L = chol(A * diag(sparse(w(k,:))) * A', 'lower');
    diagL2 = diag(L);
    assert(sum(abs(diagL(k,:)' - diagL2)) < 0.01);
end

%% test logdet
logdet = solver('logdet', uid);
logdet2 = sum(log(diagL), 2) * 2;

assert(all(abs(logdet - logdet2) < 0.01));

%% test solve
b = randn(4, size(A,1));
x = solver('solve', uid, b);
x2 = L'\ (L \ b(4,:)');

assert(sum(abs(x(4,:)' - x2)) < 0.01);

solver('delete', uid);
end
--------------------------------------------------------------------------------
/coverage/solver/solver_test.m:
--------------------------------------------------------------------------------
function solver_test()
[path,~,~] = fileparts(mfilename('fullpath'));
matrix_file = fullfile(path, '..', 'problems', 'netlib', 'degen2');
solver_scalar_test(matrix_file, false);
solver_scalar_test(matrix_file, true);
solver_zero_test(true, 0);
solver_zero_test(true, 4);
solver_zero_test(false, 0);
solver_zero_test(false, 4);
solver_simd_test(matrix_file, false);
solver_simd_test(matrix_file, true);

--------------------------------------------------------------------------------
/coverage/solver/solver_test_2.m:
--------------------------------------------------------------------------------
[path,~,~] = fileparts(mfilename('fullpath'));
file = fullfile(path, '..', 'problems', 'netlib', 'degen2');
load(file);
A = problem.Aeq;
A = [A 1e-8*speye(size(A,1))];
%rand('seed', 1);

for k = 1:1000
    rand('seed', k);
    so1 = MultiMatlabSolver(A, 1e+100, 4);
    so2 = MexSolver(A, 1e+100, 4);

    w = rand(4, size(A,2)) + 0.2;
    so1.setScale(w);
    so2.setScale(w);

    [so1.accuracy so2.accuracy]
    %assert(norm(so1.diagL-so2.diagL) < 1)
    assert(all([so1.accuracy so2.accuracy] < 1000000,'all'))
    assert(~any(isnan([so1.accuracy so2.accuracy]),'all'))
end
--------------------------------------------------------------------------------
/coverage/solver/solver_zero_test.m:
--------------------------------------------------------------------------------
function solver_zero_test(high_acc, simd_len)
solver = str2func(['PackedChol' num2str(simd_len)]);

A = sparse(5,5);

if simd_len == 0
    w = ones(size(A,2),1);
else
    w = ones(simd_len, size(A,2));
end

uid = solver('init', uint64(1234), A);
if (high_acc)
    solver('setAccuracyTarget', uid, 0.0);
end
acc = solver('decompose', uid, w);
x = solver('leverageScoreComplement', uid, 0);
x = solver('leverageScoreComplement', uid, 32);
x = solver('diagL', uid);
x = solver('L', uid);
x = solver('logdet', uid);
if simd_len == 0
    b = randn(size(A,1),1);
else
    b = randn(simd_len, size(A,1));
end
x = solver('solve', uid, b);
--------------------------------------------------------------------------------
/demo.m:
--------------------------------------------------------------------------------
%% Example 1: Sample uniform from a simplex
initSampler

P = struct; d = 100;
P.Aineq = ones(1, d);
P.bineq = 1;
P.lb = zeros(d, 1);
o = sample(P, 2000); % Number of samples = 2000
s = o.samples;
histogram(sum(s), 0.9:0.005:1)
title('distribution of l1 norm of simplex');
distribution_test(o, struct('toPlot', true));
drawnow()

%% Example 2: Sample uniform from Birkhoff polytope
initSampler

P = struct; d = 10;
P.lb = zeros(d^2,1);
P.Aeq = sparse(2*d,d^2);
P.beq = ones(2*d,1);
for i=1:d
    P.Aeq(i,(i-1)*d+1:i*d) = 1;
    P.Aeq(d+i,i:d:d^2)=1;
end

opts = default_options();
opts.maxTime = 20; % Stop in 20 sec
opts.logging = 'demo_ignore.log'; % Output the debug log to demo_ignore.log
o = sample(P, +Inf, opts);
s = o.samples;
figure;
histogram(s(1,:), 'BinLimits', [0, 0.5])
title('Marginal of first coordinate of Birkhoff polytope');
drawnow()

%% Example 3: Sample Gaussian distribution restricted to a hypercube
initSampler

P = struct; d = 1000;
P.lb = -ones(d,1);
P.ub = ones(d,1);
P.f = @(x) x'*x/2;
P.df = @(x) x;
P.ddf = @(x) ones(d,1);

opts = default_options();
opts.maxStep = 10000; % Stop after 10000 iter
o = sample(P, +Inf, opts);
figure;
histogram(o.samples(:), 'BinLimits', [-1, 1])
title('Marginal of Gaussian distribution restricted to hypercube');
drawnow()

%% Example 4: Brownian bridge
initSampler

P = struct; d = 1000;
e = ones(d,1);
P.Aeq = [spdiags([e -e], 0:1, d-1, d) spdiags(e, 0, d-1, d-1)];
P.beq = zeros(d-1,1);
P.lb = -Inf*ones(2*d-1,1);
P.ub = Inf*ones(2*d-1,1);
P.lb([1 d]) = 0;
P.ub([1 d]) = 0;

P.f = @(x) x((d+1):end)'*x((d+1):end)/2;
P.df = @(x) [zeros(d,1);x((d+1):end)];
P.ddf = @(x) [zeros(d,1);ones(d-1,1)];

o = sample(P, 100);
figure;
plot(o.samples(1:d, end))
title('Brownian bridge');
drawnow()

%% Example 5: Read a polytope according to Cobra format
initSampler

load(fullfile('coverage','Recon1.mat'))
P = struct; % Warning: Other Cobra models may have optional constraints (C,d)
P.lb = model.lb;
P.ub = model.ub;
P.beq = model.b;
P.Aeq = model.S;
o = sample(P, 100);

%% Example 6: Run the sampler in parallel
initSampler
if isempty(ver('parallel'))
    fprintf('Parallel Computing Toolbox is required for this example')
else
    load(fullfile('coverage','Recon1.mat'))
    P = struct;
    P.lb = model.lb;
    P.ub = model.ub;
    P.beq = model.b;
    P.Aeq = model.S;
    opts = default_options();
    opts.nWorkers = Inf; % Inf means the default number of workers in the Parallel Computing Toolbox
    o = sample(P, 200, opts);
end

--------------------------------------------------------------------------------
/initSampler.m:
--------------------------------------------------------------------------------
function initSampler()
addpath(genpath(fullfile('code')));
addpath(fullfile('bin'));
end

--------------------------------------------------------------------------------