├── .gitignore ├── .travis.yml ├── test.lua ├── README.md ├── amoeba.lua ├── lua_amoeba.c ├── test.c └── amoeba.h /.gitignore: -------------------------------------------------------------------------------- 1 | *.dll 2 | *.exp 3 | *.ilk 4 | *.lib 5 | *.pdb 6 | *.exe 7 | 8 | *.so 9 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | 3 | compiler: 4 | - gcc 5 | 6 | before_install: 7 | - pip install --user urllib3[secure] cpp-coveralls 8 | 9 | # Work around https://github.com/eddyxu/cpp-coveralls/issues/108 by manually 10 | # installing the pyOpenSSL module and injecting it into urllib3 as per 11 | # https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2 12 | - sed -i -e '/^import sys$/a import urllib3.contrib.pyopenssl\nurllib3.contrib.pyopenssl.inject_into_urllib3()' `which coveralls` 13 | 14 | 15 | install: 16 | - gcc -shared -Wall -O3 -Wextra -pedantic -std=c89 -xc amoeba.h -o amoeba.so 17 | - gcc -Wall -fprofile-arcs -ftest-coverage -O0 -Wextra -pedantic -std=c89 test.c -o test 18 | 19 | script: 20 | - ./test 21 | 22 | after_success: 23 | - coveralls 24 | 25 | notifications: 26 | email: 27 | on_success: change 28 | on_failure: always 29 | 30 | -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | package.path = "" 2 | local amoeba = require "amoeba" 3 | 4 | local S = amoeba.new() 5 | print(S) 6 | local xl, xm, xr = 7 | S:var "xl", S:var "xm", S:var "xr" 8 | print(xl) 9 | print(xm) 10 | print(xr) 11 | print(S:constraint() 12 | :add(xl):add(10) 13 | :relation "le" -- or "<=" 14 | :add(xr)) 15 | S:addconstraint((xm*2) :eq (xl + xr)) 16 | S:addconstraint( 17 | S:constraint() 18 | :add(xl):add(10) 19 | :relation "le" -- or "<=" 20 | :add(xr)) -- (xl + 10) :le (xr) 21 | S:addconstraint( 22 | 
S:constraint()(xr) "<=" (100)) -- (xr) :le (100) 23 | S:addconstraint((xl) :ge (0)) 24 | print(S) 25 | print(xl) 26 | print(xm) 27 | print(xr) 28 | 29 | print('suggest xm to 0') 30 | S:suggest(xm, 0) 31 | print(S) 32 | print(xl) 33 | print(xm) 34 | print(xr) 35 | 36 | print('suggest xm to 70') 37 | S:suggest(xm, 70) 38 | print(S) 39 | print(xl) 40 | print(xm) 41 | print(xr) 42 | 43 | print('delete edit xm') 44 | S:deledit(xm) 45 | print(S) 46 | print(xl) 47 | print(xm) 48 | print(xr) 49 | 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Amoeba -- the constraint solving algorithm in pure C 2 | 3 | [![Build Status](https://travis-ci.org/starwing/amoeba.svg?branch=master)](https://travis-ci.org/starwing/amoeba) 4 | [![Coverage Status](https://coveralls.io/repos/github/starwing/amoeba/badge.svg?branch=master)](https://coveralls.io/github/starwing/amoeba?branch=master) 5 | 6 | Amoeba is a pure C implementation of the Cassowary algorithm. 7 | Amoeba uses Clean C, which is the intersection of ANSI C89 and C++, like 8 | the Lua language. 9 | 10 | Amoeba is a single-file library; for more single-file libraries, see the 11 | stb project [here][1]. 12 | 13 | Amoeba is largely inspired by [kiwi][2], the C++ implementation of the Cassowary 14 | algorithm, and the algorithm [paper][3]. 15 | 16 | Amoeba ships a hand-written Lua binding. 17 | 18 | Amoeba has the same license as the [Lua language][4]. 
19 | 20 | [1]: https://github.com/nothings/stb 21 | [2]: https://github.com/nucleic/kiwi 22 | [3]: http://constraints.cs.washington.edu/solvers/uist97.html 23 | [4]: https://www.lua.org/license.html 24 | 25 | -------------------------------------------------------------------------------- /amoeba.lua: -------------------------------------------------------------------------------- 1 | 2 | local function meta(name, parent) --[[ build a class-like table: carries `name`, is its own __index, and gets `parent` as metatable ]] 3 | local t = {} --[[ keep the table local; a bare `t = {}` writes to a global ]] 4 | t.__name = name 5 | t.__index = t 6 | return setmetatable(t, parent) 7 | end 8 | 9 | local function approx(a, b) --[[ true when a and b differ by less than 1e-6 ]] 10 | if a > b then return a - b < 1e-6 end 11 | return b - a < 1e-6 12 | end 13 | 14 | local function near_zero(n) 15 | return approx(n, 0.0) 16 | end 17 | 18 | local function default(t, k, nv) --[[ return the value stored under k, first storing `nv` (or a new table) when that slot is nil or false ]] 19 | local v = t[k] 20 | if not v then v = nv or {}; t[k] = v end 21 | return v 22 | end 23 | 24 | local Variable, Expression, Constraint do 25 | 26 | Variable = meta "Variable" 27 | Expression = meta "Expression" 28 | Constraint = meta "Constraint" 29 | 30 | Constraint.REQUIRED = 1000000000.0 31 | Constraint.STRONG = 1000000.0 32 | Constraint.MEDIUM = 1000.0 33 | Constraint.WEAK = 1.0 34 | 35 | function Variable:__unm() return Expression.new(self, -1.0) end 36 | function Expression:__unm() return Expression.new(self):multiply(-1) end 37 | 38 | function Variable:__add(other) return Expression.new(self) + other end 39 | function Variable:__sub(other) return Expression.new(self) - other end 40 | function Variable:__mul(other) return Expression.new(self) * other end 41 | function Variable:__div(other) return Expression.new(self) / other end 42 | 43 | function Variable:le(other) return Expression.new(self):le(other) end 44 | function Variable:eq(other) return Expression.new(self):eq(other) end 45 | function Variable:ge(other) return Expression.new(self):ge(other) end 46 | 47 | function Expression:__add(other) return Expression.new(self):add(other) end 48 | function Expression:__sub(other) return Expression.new(self):add(-other) end 49 
| function Expression:__mul(other) return Expression.new(self):multiply(other) end 50 | function Expression:__div(other) local divisor = tonumber(other) or (getmetatable(other) == Expression and other:is_constant() and other.constant); assert(divisor, "attempt to divide by a non-constant expression"); return Expression.new(self):multiply(1.0/divisor) end --[[ the divisor must be a number or a constant expression: evaluating 1.0/other on a Variable or Expression would invoke __div again without end ]] 51 | 52 | function Expression:le(other) return Constraint.new("<=", self, other) end 53 | function Expression:eq(other) return Constraint.new("==", self, other) end 54 | function Expression:ge(other) return Constraint.new(">=", self, other) end 55 | 56 | function Constraint:__call(...) return self:add(...) end --[[ calling a constraint forwards to :add ]] 57 | 58 | function Variable.new(name, type, id) --[[ type is "external" (the default), "slack", "error" or "dummy"; the is_* flags are derived from it ]] 59 | type = type or "external" 60 | assert(type == "external" or 61 | type == "slack" or 62 | type == "error" or 63 | type == "dummy", type) 64 | local self = { 65 | id = id, 66 | name = name, 67 | value = 0.0, 68 | type = type or "external", 69 | is_dummy = type == "dummy", 70 | is_slack = type == "slack", 71 | is_error = type == "error", 72 | is_external = type == "external", 73 | is_pivotable = type == "slack" or type == "error", 74 | is_restricted = type ~= "external", 75 | } 76 | return setmetatable(self, Variable) 77 | end 78 | 79 | function Variable:__tostring() 80 | return ("amoeba.Variable: %s = %g"):format(self.name, self.value) 81 | end 82 | 83 | function Expression.new(other, multiplier, constant) --[[ new expression equal to multiplier*other + constant; stays empty when other is nil ]] 84 | local self = setmetatable({}, Expression) 85 | return self:add(other, multiplier, constant) 86 | end 87 | 88 | function Expression:tostring() --[[ render as "constant + coeff*name - ..."; coefficients within 1e-6 of 1 are omitted ]] 89 | local t = { ("%g"):format(self.constant or 0.0) } 90 | for k, v in self:iter_vars() do 91 | t[#t+1] = v < 0.0 and ' - ' or ' + ' 92 | v = math.abs(v) 93 | if not approx(v, 1.0) then 94 | t[#t+1] = ("%g*"):format(v) 95 | end 96 | t[#t+1] = k.name 97 | end 98 | return table.concat(t) 99 | end 100 | 101 | function Expression:__tostring() 102 | return "Exp: "..self:tostring() 103 | end 104 | 105 | function Expression:add(other, multiplier, constant) --[[ in place: self = self + multiplier*other + constant; other is a number, Variable or Expression; returns self ]] 106 | if other == nil then return self end 107 | self.constant = (self.constant or 0.0) + (constant or 0.0) 108 | multiplier = multiplier or 1.0 
109 | if tonumber(other) then 110 | self.constant = self.constant + other*multiplier 111 | return self 112 | end 113 | local mt = getmetatable(other) 114 | if mt == Variable then 115 | local multiplier = (self[other] or 0.0) + multiplier 116 | self[other] = not near_zero(multiplier) and multiplier or nil --[[ coefficients that cancel out are removed ]] 117 | elseif mt == Expression then 118 | for k, v in pairs(other) do --[[ pairs also visits the 'constant' key, so constants are merged by the same loop ]] 119 | local multiplier = (self[k] or 0.0) + multiplier * v 120 | self[k] = not near_zero(multiplier) and multiplier or nil 121 | end 122 | self.constant = self.constant or 0.0 --[[ the loop above clears a near-zero constant; restore it as 0.0 ]] 123 | else 124 | error("constant/variable/expression expected") 125 | end 126 | return self 127 | end 128 | 129 | function Expression:multiply(other) --[[ in place: scale self by a number or a constant expression, or scale `other` by self when self is constant; raises when both sides contain variables; returns self ]] 130 | if tonumber(other) then 131 | for k, v in pairs(self) do --[[ scales every coefficient and the constant ]] 132 | self[k] = v * other 133 | end 134 | return self 135 | end 136 | local mt = getmetatable(other) 137 | if mt == Variable then 138 | return self:multiply(Expression.new(other)) 139 | elseif mt == Expression then 140 | if other:is_constant() then 141 | return self:multiply(other.constant) 142 | elseif self.constant and self:is_constant() then --[[ self must hold no variables; a non-nil constant alone (0.0 included) does not make it constant ]] 143 | local constant = self.constant 144 | self.constant = 0.0 145 | return self:add(other):multiply(constant) 146 | end 147 | error("attempt to multiply two non-constant expression") 148 | else 149 | error("number/variable/constant expression expected") 150 | end 151 | end 152 | 153 | function Expression:choose_pivotable() --[[ return some slack or error variable of this expression, or nil ]] 154 | for k, v in self:iter_vars() do --[[ iter_vars skips the 'constant' key, so only variables are inspected ]] 155 | if k.is_pivotable then 156 | return k 157 | end 158 | end 159 | end 160 | 161 | function Expression:is_constant() --[[ true when the expression holds no variable terms ]] 162 | for k, v in self:iter_vars() do 163 | return false 164 | end 165 | return true 166 | end 167 | 168 | function Expression:solve_for(new, old) 169 | -- expr: old == a[n] *new + constant + a[i]* v[i]... 170 | -- => new == (1/a[n])*old - 1/a[n]*constant - (1/a[n])*a[i]*v[i]... 
171 | local multiplier = assert(self[new]) 172 | assert(new ~= old and not near_zero(multiplier)) 173 | self[new] = nil 174 | local reciprocal = 1.0 / multiplier 175 | self:multiply(-reciprocal) 176 | if old then self[old] = reciprocal end 177 | return new 178 | end 179 | 180 | function Expression:substitute_out(var, expr) --[[ replace every occurrence of var by expr, scaled by var's coefficient; no-op when var is absent ]] 181 | assert(var ~= "constant") 182 | local multiplier = self[var] 183 | if not multiplier then return end 184 | self[var] = nil 185 | self:add(expr, multiplier) 186 | end 187 | 188 | function Expression:iter_vars() --[[ generic-for iterator over (variable, coefficient) pairs that skips the 'constant' key ]] 189 | return function(self, k) 190 | local k, v = next(self, k) 191 | if k == 'constant' then 192 | return next(self, k) 193 | end 194 | return k, v 195 | end, self 196 | end 197 | 198 | function Constraint.new(op, expr1, expr2, strength) --[[ without op: an empty constraint to be filled via :add/:relation; with op: stores expr1 op expr2 as "expression >= 0" or "expression == 0"; strength defaults to REQUIRED ]] 199 | local self = setmetatable({}, Constraint) 200 | if not op then 201 | self.expression = Expression.new() 202 | else 203 | self:relation(op) 204 | if self.op == '<=' then 205 | self.expression = Expression.new(expr2 or 0.0):add(expr1, -1.0) 206 | else 207 | self.expression = Expression.new(expr1 or 0.0):add(expr2, -1.0) 208 | end 209 | end 210 | return self:strength(strength or Constraint.REQUIRED) 211 | end 212 | 213 | function Constraint:__tostring() 214 | local repr = "amoeba.Constraint: ["..self.expression:tostring() 215 | if self.is_inequality then 216 | repr = repr .. " >= 0.0]" 217 | else 218 | repr = repr .. 
" == 0.0]" 219 | end 220 | return repr 221 | end 222 | 223 | function Constraint:add(other, multiplier, constant) --[[ append a term, or set the relation when `other` is a relation token; terms added after '>=' are subtracted; returns self ]] 224 | if other == ">=" or other == "<=" or other == "==" or other == "ge" or other == "le" or other == "eq" then --[[ accepts the same tokens as :relation ]] 225 | self:relation(other) 226 | else 227 | multiplier = multiplier or 1.0 228 | if self.op == '>=' then multiplier = -multiplier end 229 | self.expression:add(other, multiplier, constant) 230 | end 231 | return self 232 | end 233 | 234 | function Constraint:relation(op) --[[ set the relation ('==', '<=', '>=' or 'eq', 'le', 'ge'); unless it is '>=', an already built left-hand side is negated so later terms can simply be added; returns self ]] 235 | assert(op == '==' or op == '<=' or op == '>=' or 236 | op == 'eq' or op == 'le' or op == 'ge', 237 | "op must be '==', '>=' or '<='") 238 | if op == 'eq' then op = '==' 239 | elseif op == 'le' then op = '<=' 240 | elseif op == 'ge' then op = '>=' end 241 | self.op = op 242 | --[[ recomputed on every call so that switching back to '==' clears the flag ]] 243 | self.is_inequality = self.op ~= '==' 244 | 245 | if self.op ~= '>=' and self.expression then 246 | self.expression:multiply(-1.0) 247 | end 248 | return self 249 | end 250 | 251 | function Constraint:strength(strength) --[[ set the weight from a number or a strength name such as "STRONG"; delegates to the owning solver once the constraint has been added; returns self ]] 252 | if self.solver then 253 | self.solver:setstrength(self, strength) 254 | else 255 | self.weight = Constraint[strength] or tonumber(strength) or self.weight 256 | self.is_required = self.weight >= Constraint.REQUIRED 257 | end 258 | return self 259 | end 260 | 261 | function Constraint:clone(strength) --[[ detached copy of this constraint; strength defaults to this constraint's weight ]] 262 | local new = Constraint.new():strength(strength or self.weight) 263 | new.expression:add(self.expression) --[[ copy the terms directly: Constraint:add only accepts numbers, variables and expressions ]] 264 | new.op = self.op 265 | new.is_inequality = self.is_inequality 266 | return new 267 | end 268 | 269 | end 270 | 271 | local SimplexSolver = meta "SimplexSolver" do 272 | 273 | -- implements 274 | 275 | local function update_external_variables(self) 276 | for var in pairs(self.vars) do 277 | local row = self.rows[var] 278 | var.value = row and row.constant or 0.0 279 | end 280 | end 281 | 282 | local function substitute_out(self, var, expr) 283 | for k, row in pairs(self.rows) do 284 | row:substitute_out(var, expr) 285 | if k.is_restricted and row.constant < 0.0 then 286 | self.infeasible_rows[#self.infeasible_rows+1] = k 287 | end 
288 | end 289 | self.objective:substitute_out(var, expr) 290 | end 291 | 292 | local function optimize(self, objective) --[[ primal simplex loop: pivot until `objective` (default self.objective) has no negative coefficient on a non-dummy variable; asserts when no row can leave ]] 293 | objective = objective or self.objective 294 | while true do 295 | local entry, exit --[[ entry: variable entering the basis; exit: basic variable whose row leaves it ]] 296 | for var, multiplier in objective:iter_vars() do 297 | if not var.is_dummy and multiplier < 0.0 then 298 | entry = var 299 | break 300 | end 301 | end 302 | if not entry then return end 303 | 304 | local r = 0.0 305 | local min_ratio = math.huge --[[ ratio test over pivotable basic variables; ratios within 1e-6 are broken by the smaller id ]] 306 | for var, row in pairs(self.rows) do 307 | local multiplier = row[entry] 308 | if multiplier and var.is_pivotable and multiplier < 0.0 then 309 | r = -row.constant / multiplier 310 | if r < min_ratio or (approx(r, min_ratio) and 311 | var.id < exit.id) then 312 | min_ratio, exit = r, var 313 | end 314 | end 315 | end 316 | assert(exit, "objective function is unbounded") 317 | 318 | -- do pivot 319 | local row = self.rows[exit] 320 | self.rows[exit] = nil 321 | row:solve_for(entry, exit) 322 | substitute_out(self, entry, row) 323 | if objective ~= self.objective then 324 | objective:substitute_out(entry, row) 325 | end 326 | self.rows[entry] = row 327 | end 328 | end 329 | 330 | local function make_variable(self, type) 331 | local id = self.last_varid 332 | self.last_varid = id + 1 333 | local prefix = type == "eplus" and "ep" or 334 | type == "eminus" and "em" or 335 | type == "dummy" and "d" or 336 | type == "artificial" and "a" or "s" 337 | if not type or type == "artificial" then 338 | type = "slack" 339 | elseif type == "eplus" or type == "eminus" then 340 | type = "error" 341 | end 342 | return Variable.new(prefix..id, type, id) 343 | end 344 | 345 | local function make_expression(self, cons) 346 | local expr = Expression.new(cons.expression.constant) 347 | local var1, var2 348 | for k, v in cons.expression:iter_vars() do 349 | if not k.id then 350 | k.id = self.last_varid 351 | self.last_varid = k.id + 1 352 | end 353 | if not self.vars[k] then 354 | self.vars[k] = true 355 | end 356 | 
expr:add(self.rows[k] or k, v) 357 | end 358 | if cons.is_inequality then --[[ inequality: subtract a slack; a non-required one also gets an error variable weighted in the objective ]] 359 | var1 = make_variable(self) -- slack 360 | expr[var1] = -1.0 361 | if not cons.is_required then 362 | var2 = make_variable(self, "eminus") 363 | expr[var2] = 1.0 364 | self.objective[var2] = cons.weight 365 | end 366 | elseif cons.is_required then 367 | var1 = make_variable(self, 'dummy') 368 | expr[var1] = 1.0 369 | else 370 | var1 = make_variable(self, 'eplus') 371 | var2 = make_variable(self, 'eminus') 372 | expr[var1] = -1.0 373 | expr[var2] = 1.0 374 | self.objective[var1] = cons.weight 375 | self.objective[var2] = cons.weight 376 | end 377 | if expr.constant < 0.0 then expr:multiply(-1.0) end 378 | return expr, var1, var2 379 | end 380 | 381 | local function choose_subject(self, expr, var1, var2) 382 | for k, v in expr:iter_vars() do 383 | if k.is_external then return k end 384 | end 385 | if var1 and var1.is_pivotable then return var1 end 386 | if var2 and var2.is_pivotable then return var2 end 387 | for k, v in expr:iter_vars() do 388 | if not k.is_dummy then return nil end -- no luck 389 | end 390 | if not near_zero(expr.constant) then 391 | return nil, "unsatisfiable required constraint added" 392 | end 393 | return var1 394 | end 395 | 396 | local function add_with_artificial_variable(self, expr) 397 | local a = make_variable(self, 'artificial') 398 | self.last_varid = self.last_varid - 1 399 | 400 | self.rows[a] = expr 401 | optimize(self, expr) 402 | local row = self.rows[a] 403 | self.rows[a] = nil 404 | 405 | local success = near_zero(expr.constant) 406 | if row then 407 | if row:is_constant() then 408 | return success 409 | end 410 | local entering = row:choose_pivotable() 411 | if not entering then return false end 412 | 413 | row:solve_for(entering, a) 414 | self.rows[entering] = row --[[ NOTE(review): the other rows are not substituted for `entering` here; confirm this is intended ]] 415 | end 416 | 417 | for var, row in pairs(self.rows) do row[a] = nil end 418 | self.objective[a] = nil 419 | return success 420 | end 421 | 422 | local function 
get_marker_leaving_row(self, marker) --[[ pick the basic variable whose row should be pivoted to bring `marker` into the basis: a non-external row with negative coefficient (smallest -constant/coeff) first, then one with non-negative coefficient (smallest constant/coeff), then any external row ]] 423 | local r1, r2 = math.huge, math.huge 424 | local first, second, third 425 | for var, row in pairs(self.rows) do 426 | local multiplier = row[marker] 427 | if multiplier then 428 | if var.is_external then 429 | third = var 430 | elseif multiplier < 0.0 then 431 | local r = -row.constant / multiplier 432 | if r < r1 then r1 = r; first = var end 433 | else 434 | local r = row.constant / multiplier 435 | if r < r2 then r2 = r; second = var end 436 | end 437 | end 438 | end 439 | return first or second or third 440 | end 441 | 442 | local function delta_edit_constant(self, delta, var1, var2) --[[ shift an edit by delta: adjust the row of var1 (minus delta) or of var2 (plus delta) when either is basic, otherwise every row by its var1 coefficient times delta; rows whose constant turns negative are queued in infeasible_rows ]] 443 | local row = self.rows[var1] 444 | if row then 445 | row.constant = row.constant - delta 446 | if row.constant < 0.0 then 447 | self.infeasible_rows[#self.infeasible_rows+1] = var1 448 | end 449 | return 450 | end 451 | local row = self.rows[var2] 452 | if row then 453 | row.constant = row.constant + delta 454 | if row.constant < 0.0 then 455 | self.infeasible_rows[#self.infeasible_rows+1] = var2 456 | end 457 | return 458 | end 459 | for var, row in pairs(self.rows) do 460 | row.constant = row.constant + (row[var1] or 0.0)*delta 461 | if var.is_restricted and row.constant < 0.0 then 462 | self.infeasible_rows[#self.infeasible_rows+1] = var 463 | end 464 | end 465 | end 466 | 467 | local function dual_optimize(self) --[[ drain infeasible_rows: each queued row that still has a negative constant is pivoted on the non-dummy variable with positive coefficient that minimises objective coefficient / row coefficient ]] 468 | while true do 469 | local count = #self.infeasible_rows 470 | if count == 0 then return end 471 | local exit = self.infeasible_rows[count] 472 | self.infeasible_rows[count] = nil 473 | 474 | local row = self.rows[exit] 475 | if row and row.constant < 0.0 then 476 | local entry 477 | local min_ratio = math.huge 478 | for var, multiplier in row:iter_vars() do 479 | if multiplier > 0.0 and not var.is_dummy then 480 | local r = (self.objective[var] or 0.0) / multiplier 481 | if r < min_ratio then 482 | min_ratio, entry = r, var 483 | end 484 | end 485 | end 486 | assert(entry, "dual optimize failed") 487 | 488 | -- pivot 
489 | self.rows[exit] = nil 490 | row:solve_for(entry, exit) 491 | substitute_out(self, entry, row) 492 | self.rows[entry] = row 493 | end 494 | end 495 | end 496 | 497 | -- interface 498 | 499 | function SimplexSolver:hasvariable(var) return self.vars[var] end 500 | function SimplexSolver:hasconstraint(cons) return self.constraints[cons] end 501 | function SimplexSolver:hasedit(var) return self.edits[var] end 502 | function SimplexSolver:var(...) return Variable.new(...) end 503 | function SimplexSolver:constraint(...) return Constraint.new(...) end 504 | 505 | function SimplexSolver.new() --[[ create an empty solver ]] 506 | local self = {} 507 | self.last_varid = 1 --[[ next id handed out to a variable ]] 508 | 509 | self.vars = {} --[[ set keyed by the Variable objects seen in added constraints ]] 510 | self.edits = {} --[[ Variable -> edit info, filled by addedit ]] 511 | self.constraints = {} --[[ Constraint -> { marker, other }, filled by addconstraint ]] 512 | 513 | self.objective = Expression.new() 514 | self.rows = {} --[[ basic Variable -> row Expression ]] 515 | self.infeasible_rows = {} --[[ array of basic variables whose row constant went negative ]] 516 | 517 | return setmetatable(self, SimplexSolver) 518 | end 519 | 520 | function SimplexSolver:__tostring() --[[ multi-line dump of the solver state; rows are listed in order of variable id ]] 521 | local t = { "amoeba.Solver: {\n" } 522 | t[#t+1] = (" objective = %s\n"):format(self.objective:tostring()) 523 | if next(self.rows) then 524 | t[#t+1] = " rows:\n" 525 | local keys = {} 526 | for k, v in pairs(self.rows) do 527 | keys[#keys+1] = k 528 | end 529 | table.sort(keys, function(a, b) return a.id < b.id end) 530 | for idx, k in ipairs(keys) do local v = self.rows[k] 531 | t[#t+1] = (" %d. %s(%g) = %s\n"):format(idx, k.name, k.value, v:tostring()) 532 | end 533 | end 534 | if next(self.edits) then 535 | t[#t+1] = " edits:\n" 536 | local idx = 1 537 | for k, v in pairs(self.edits) do 538 | t[#t+1] = (" %d. 
%s = %s; info = { %s, %s, %g }\n"):format( 539 | idx, k.name, k.value, v.plus.name, v.minus.name, 540 | v.prev_constant) 541 | idx = idx + 1 542 | end 543 | end 544 | if #self.infeasible_rows ~= 0 then 545 | t[#t+1] = " infeasible_rows: {" 546 | for _, var in ipairs(self.infeasible_rows) do 547 | t[#t+1] = (" %s"):format(var.name) 548 | end 549 | t[#t+1] = " }\n" 550 | end 551 | if next(self.vars) then --[[ vars is keyed by Variable objects, so the length operator cannot tell whether it is empty ]] 552 | t[#t+1] = " vars: {" 553 | for var in pairs(self.vars) do 554 | t[#t+1] = (" %s"):format(var.name) 555 | end 556 | t[#t+1] = " }\n" 557 | end 558 | t[#t+1] = "}" 559 | return table.concat(t) 560 | end 561 | 562 | function SimplexSolver:addconstraint(cons, ...) --[[ add a Constraint (or build one from the arguments) and re-optimise; returns the constraint, or nil plus a message on failure ]] 563 | if getmetatable(cons) ~= Constraint then 564 | cons = Constraint.new(cons, ...) 565 | end 566 | if self.constraints[cons] then return cons end 567 | local expr, var1, var2 = make_expression(self, cons) 568 | local subject, err = choose_subject(self, expr, var1, var2) 569 | if subject then 570 | expr:solve_for(subject) 571 | substitute_out(self, subject, expr) 572 | self.rows[subject] = expr 573 | elseif err then 574 | return nil, err 575 | elseif not add_with_artificial_variable(self, expr) then 576 | return nil, "constraint added may unbounded" 577 | end 578 | self.constraints[cons] = { 579 | marker = var1, 580 | other = var2, 581 | } 582 | cons.solver = self 583 | optimize(self) 584 | update_external_variables(self) 585 | return cons 586 | end 587 | 588 | function SimplexSolver:delconstraint(cons) --[[ remove a previously added constraint and re-optimise; returns the constraint, or nothing when it was not in this solver ]] 589 | local info = self.constraints[cons] 590 | if not info then return end 591 | self.constraints[cons] = nil 592 | 593 | if info.marker and info.marker.is_error then --[[ take the weight of each error variable back out of the objective ]] 594 | self.objective:add(self.rows[info.marker] or info.marker, -cons.weight) 595 | end 596 | if info.other and info.other.is_error then 597 | self.objective:add(self.rows[info.other] or info.other, -cons.weight) 598 | end 599 | if self.objective:is_constant() then 600 | self.objective.constant = 0.0 601 | end 602 | 603 | local row = 
self.rows[info.marker] 604 | if row then 605 | self.rows[info.marker] = nil 606 | else 607 | local var = assert(get_marker_leaving_row(self, info.marker), 608 | "failed to find leaving row") 609 | local row = self.rows[var] 610 | self.rows[var] = nil 611 | row:solve_for(info.marker, var) 612 | substitute_out(self, info.marker, row) 613 | end 614 | cons.solver = nil 615 | optimize(self) 616 | update_external_variables(self) 617 | return cons 618 | end 619 | 620 | function SimplexSolver:addedit(var, strength) --[[ make var editable by adding a non-required "var == current value" constraint (default strength MEDIUM); returns the solver, or nothing when var is already editable ]] 621 | if self.edits[var] then return end 622 | strength = strength or Constraint.MEDIUM 623 | assert(strength < Constraint.REQUIRED, "attempt to edit a required var") 624 | local cons = Constraint.new("==", var, var.value, strength) 625 | assert(self:addconstraint(cons)) 626 | local info = self.constraints[cons] 627 | self.edits[var] = { 628 | constraint = cons, 629 | plus = info.marker, 630 | minus = info.other, 631 | prev_constant = var.value or 0.0, 632 | } 633 | return self 634 | end 635 | 636 | function SimplexSolver:deledit(var) --[[ drop the edit constraint of var, if any ]] 637 | local info = self.edits[var] 638 | if info then 639 | self:delconstraint(info.constraint) 640 | self.edits[var] = nil 641 | end 642 | end 643 | 644 | function SimplexSolver:suggest(var, value) --[[ ask for var to take `value`; var is made editable first when needed ]] 645 | local info = self.edits[var] 646 | if not info then self:addedit(var); info = self.edits[var] end 647 | local delta = value - info.prev_constant 648 | info.prev_constant = value 649 | delta_edit_constant(self, delta, info.plus, info.minus) 650 | dual_optimize(self) 651 | update_external_variables(self) 652 | end 653 | 654 | function SimplexSolver:setstrength(cons, strength) --[[ change the weight of cons; strength is a number or a Constraint strength name (REQUIRED/STRONG/MEDIUM/WEAK); raises when a constraint already in the solver is, or would become, required; returns the solver ]] 655 | strength = tonumber(Constraint[strength]) or tonumber(strength) or cons.weight 656 | local info = self.constraints[cons] 657 | if not info then --[[ not added to this solver: only the constraint itself needs updating ]] 658 | cons.weight = strength 659 | cons.is_required = strength >= Constraint.REQUIRED 660 | return self 661 | end 662 | assert(not cons.is_required and strength < Constraint.REQUIRED, 663 | "attempt to change required strength") 664 | local diff = strength - cons.weight --[[ the objective already carries the old weight, so only the difference is applied ]] 665 | cons.weight = strength 666 | if near_zero(diff) then return self end 667 | 668 | if info.marker and info.marker.is_error then --[[ only error variables are weighted in the objective ]] 669 | self.objective:add(self.rows[info.marker] or info.marker, diff) 670 | end 671 | if info.other and info.other.is_error then 672 | self.objective:add(self.rows[info.other] or info.other, diff) 673 | end 674 | optimize(self) 675 | update_external_variables(self) 676 | return self 677 | end 678 | 
679 | function SimplexSolver:resolve() --[[ restore feasibility after row constants changed, then refresh the variable values ]] 680 | dual_optimize(self) 681 | update_external_variables(self) 682 | self.infeasible_rows = {} 683 | end 684 | 685 | function SimplexSolver:set_constant(cons, constant) --[[ change the constant term of a constraint that is already in the solver; no-op for unknown constraints ]] 686 | local info = self.constraints[cons] 687 | if not info then return end 688 | local prev = info.prev_constant or cons.expression.constant or 0.0 --[[ first call: start from the constant stored in the constraint's expression. NOTE(review): assumes `constant` is expressed in that same form; confirm with callers ]] 689 | local delta = prev - constant 690 | info.prev_constant = constant 691 | 692 | if info.marker.is_slack or cons.is_required then 693 | for var, row in pairs(self.rows) do 694 | row:add((row[info.marker] or 0.0) * -delta) 695 | if var.is_restricted and row.constant < 0.0 then 696 | self.infeasible_rows[#self.infeasible_rows+1] = var 697 | end 698 | end 699 | else 700 | delta_edit_constant(self, delta, info.marker, info.other) 701 | end 702 | dual_optimize(self) 703 | update_external_variables(self) 704 | end 705 | 706 | end 707 | 708 | return SimplexSolver 709 | -------------------------------------------------------------------------------- /lua_amoeba.c: -------------------------------------------------------------------------------- 1 | #define LUA_LIB 2 | #include <lua.h> 3 | #include <lauxlib.h> 4 | #include <string.h> 5 | 6 | #define AM_STATIC_API 7 | #include "amoeba.h" 8 | 9 | 10 | #define AML_SOLVER_TYPE "amoeba.Solver" 11 | #define AML_VAR_TYPE "amoeba.Variable" 12 | #define AML_CONS_TYPE "amoeba.Constraint" 13 | 14 | enum aml_ItemType { AML_VAR, AML_CONS, AML_CONSTANT }; 15 | 16 | typedef struct aml_Solver { 17 | am_Solver *solver; 18 | int ref_vars; /* registry reference to the table of variables */ 19 | int ref_cons; /* registry reference to the table of constraints */ 20 | } aml_Solver; 21 | 22 | typedef struct aml_Var { 23 | am_Variable *var; 24 | aml_Solver *S; 25 | const char *name; 26 | } aml_Var; 27 | 28 | typedef struct aml_Cons { 29 | am_Constraint *cons; 30 | aml_Solver *S; 31 | } aml_Cons; 32 | 33 | 
typedef struct aml_Item { 34 | int type; /* one of enum aml_ItemType; selects which member below is meaningful */ 35 | am_Variable *var; 36 | am_Constraint *cons; 37 | am_Float value; 38 | } aml_Item; 39 | 40 | 41 | /* utils */ 42 | 43 | static int aml_argferror(lua_State *L, int idx, const char *fmt, ...) { /* luaL_argerror with a printf-style message */ 44 | va_list l; 45 | va_start(l, fmt); 46 | lua_pushvfstring(L, fmt, l); 47 | va_end(l); 48 | return luaL_argerror(L, idx, lua_tostring(L, -1)); 49 | } 50 | 51 | static int aml_typeerror(lua_State *L, int idx, const char *tname) { 52 | return aml_argferror(L, idx, "%s expected, got %s", 53 | tname, luaL_typename(L, idx)); 54 | } 55 | 56 | static void aml_setweak(lua_State *L, const char *mode) { /* give the table on top of the stack a metatable whose __mode is `mode` */ 57 | lua_createtable(L, 0, 1); 58 | lua_pushstring(L, mode); 59 | lua_setfield(L, -2, "__mode"); 60 | lua_setmetatable(L, -2); 61 | } 62 | 63 | static am_Variable *aml_checkvar(lua_State *L, aml_Solver *S, int idx) { /* argument idx is a variable userdata, or the name of one stored in the solver's vars table; raises otherwise */ 64 | aml_Var *lvar = (aml_Var*)luaL_testudata(L, idx, AML_VAR_TYPE); 65 | const char *name; 66 | if (lvar != NULL) { 67 | if (lvar->var == NULL) luaL_argerror(L, idx, "invalid variable"); 68 | return lvar->var; 69 | } 70 | name = luaL_checkstring(L, idx); 71 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_vars); 72 | if (lua_getfield(L, -1, name) == LUA_TUSERDATA) { /* the vars table was just pushed, so it sits at -1 */ 73 | lua_remove(L, -2); /* drop the table, leaving the userdata on top */ 74 | return aml_checkvar(L, S, -1); 75 | } 76 | lua_pop(L, 2); 77 | aml_argferror(L, idx, "variable named '%s' not exists", 78 | lua_tostring(L, idx)); 79 | return NULL; 80 | } 81 | 82 | static aml_Cons *aml_newcons(lua_State *L, aml_Solver *S, am_Float strength) { /* push a new constraint userdata and record it in the solver's constraint table under its own address */ 83 | aml_Cons *lcons = (aml_Cons*)lua_newuserdata(L, sizeof(aml_Cons)); 84 | lcons->cons = am_newconstraint(S->solver, strength); 85 | lcons->S = S; 86 | luaL_setmetatable(L, AML_CONS_TYPE); 87 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_cons); 88 | lua_pushvalue(L, -2); 89 | lua_rawsetp(L, -2, lcons); 90 | lua_pop(L, 1); 91 | return lcons; 92 | } 93 | 94 | static aml_Item aml_checkitem(lua_State *L, aml_Solver *S, int idx) { 95 | aml_Item item = { 0 }; 96 | aml_Cons *lcons; 97 | aml_Var 
*lvar; 98 | switch (lua_type(L, idx)) { /* classify the argument: variable name, number, constraint or variable userdata */ 99 | case LUA_TSTRING: 100 | item.var = aml_checkvar(L, S, idx); 101 | item.type = AML_VAR; 102 | return item; 103 | case LUA_TNUMBER: 104 | item.value = lua_tonumber(L, idx); 105 | item.type = AML_CONSTANT; 106 | return item; 107 | case LUA_TUSERDATA: 108 | lcons = (aml_Cons*)luaL_testudata(L, idx, AML_CONS_TYPE); 109 | if (lcons) { 110 | if (lcons->cons == NULL) luaL_argerror(L, idx, "invalid constraint"); 111 | item.cons = lcons->cons; 112 | item.type = AML_CONS; 113 | return item; 114 | } 115 | lvar = (aml_Var*)luaL_testudata(L, idx, AML_VAR_TYPE); /* explicit cast keeps the file valid as C++ too */ 116 | if (lvar) { 117 | if (lvar->var == NULL) luaL_argerror(L, idx, "invalid variable"); 118 | item.var = lvar->var; 119 | item.type = AML_VAR; 120 | return item; 121 | } 122 | /* FALLTHROUGH */ 123 | default: 124 | aml_typeerror(L, idx, "number/string/variable/constraint"); 125 | } 126 | return item; 127 | } 128 | 129 | static aml_Solver *aml_checkitems(lua_State *L, int start, aml_Item *items) { /* read the two operands at start and start+1; at least one must be a constraint or variable userdata, and its solver is returned */ 130 | aml_Var *lvar; 131 | aml_Cons *lcons; 132 | if ((lcons = (aml_Cons*)luaL_testudata(L, start, AML_CONS_TYPE)) != NULL) { 133 | if (lcons->cons == NULL) luaL_argerror(L, start, "invalid constraint"); items[0].type = AML_CONS, items[0].cons = lcons->cons; /* same validity check as aml_checkitem */ 134 | items[1] = aml_checkitem(L, lcons->S, start+1); 135 | return lcons->S; 136 | } 137 | if ((lcons = (aml_Cons*)luaL_testudata(L, start+1, AML_CONS_TYPE)) != NULL) { 138 | if (lcons->cons == NULL) luaL_argerror(L, start+1, "invalid constraint"); items[1].type = AML_CONS, items[1].cons = lcons->cons; 139 | items[0] = aml_checkitem(L, lcons->S, start); 140 | return lcons->S; 141 | } 142 | if ((lvar = (aml_Var*)luaL_testudata(L, start, AML_VAR_TYPE)) != NULL) { 143 | if (lvar->var == NULL) luaL_argerror(L, start, "invalid variable"); 144 | items[0].type = AML_VAR, items[0].var = lvar->var; 145 | items[1] = aml_checkitem(L, lvar->S, start+1); 146 | return lvar->S; 147 | } 148 | if ((lvar = (aml_Var*)luaL_testudata(L, start+1, AML_VAR_TYPE)) != NULL) { 149 | if (lvar->var == NULL) luaL_argerror(L, start+1, "invalid variable"); 150 | items[1].type = AML_VAR, items[1].var = lvar->var; 
151 | items[0] = aml_checkitem(L, lvar->S, start); 152 | return lvar->S; 153 | } 154 | aml_typeerror(L, start, "variable/constraint"); 155 | return NULL; 156 | } 157 | 158 | static int aml_performitem(am_Constraint *cons, aml_Item *item, am_Float multiplier) { /* add item, scaled by multiplier, to cons; returns the amoeba status code */ 159 | switch (item->type) { 160 | case AML_CONSTANT: return am_addconstant(cons, item->value*multiplier); 161 | case AML_VAR: return am_addterm(cons, item->var, multiplier); 162 | case AML_CONS: return am_mergeconstraint(cons, item->cons, multiplier); 163 | } 164 | return AM_FAILED; 165 | } 166 | 167 | static am_Float aml_checkstrength(lua_State *L, int idx, am_Float def) { /* strength argument: a number, one of "required"/"strong"/"medium"/"weak", or def when absent or nil */ 168 | int type = lua_type(L, idx); 169 | const char *s; 170 | switch (type) { 171 | case LUA_TSTRING: 172 | s = lua_tostring(L, idx); 173 | if (strcmp(s, "required") == 0) return AM_REQUIRED; 174 | if (strcmp(s, "strong") == 0) return AM_STRONG; 175 | if (strcmp(s, "medium") == 0) return AM_MEDIUM; 176 | if (strcmp(s, "weak") == 0) return AM_WEAK; 177 | aml_argferror(L, idx, "invalid strength value '%s'", s); 178 | break; 179 | case LUA_TNONE: 180 | case LUA_TNIL: return def; 181 | case LUA_TNUMBER: return lua_tonumber(L, idx); 182 | } 183 | aml_typeerror(L, idx, "number/string"); 184 | return 0.0f; 185 | } 186 | 187 | static int aml_checkrelation(lua_State *L, int idx) { /* relation argument at idx: "==", "<=", ">=" or "eq", "le", "ge" */ 188 | const char *op = luaL_checkstring(L, idx); 189 | if (strcmp(op, "==") == 0) return AM_EQUAL; 190 | else if (strcmp(op, "<=") == 0) return AM_LESSEQUAL; 191 | else if (strcmp(op, ">=") == 0) return AM_GREATEQUAL; 192 | else if (strcmp(op, "eq") == 0) return AM_EQUAL; 193 | else if (strcmp(op, "le") == 0) return AM_LESSEQUAL; 194 | else if (strcmp(op, "ge") == 0) return AM_GREATEQUAL; 195 | return aml_argferror(L, idx, "invalid relation operator: '%s'", op); 196 | } 197 | 198 | static aml_Cons *aml_makecons(lua_State *L, aml_Solver *S, int start) { 199 | aml_Cons *lcons; 200 | int op = aml_checkrelation(L, start); 201 | am_Float strength = 
aml_checkstrength(L, start+3, AM_REQUIRED); 202 | aml_Item items[2]; 203 | aml_checkitems(L, start+1, items); 204 | lcons = aml_newcons(L, S, strength); 205 | aml_performitem(lcons->cons, &items[0], 1.0f); 206 | am_setrelation(lcons->cons, op); 207 | aml_performitem(lcons->cons, &items[1], 1.0f); 208 | return lcons; 209 | } 210 | 211 | static void aml_dumpkey(luaL_Buffer *B, int idx, am_Symbol sym) { 212 | lua_State *L = B->L; 213 | aml_Var *lvar; 214 | lua_rawgeti(L, idx, sym.id); 215 | lvar = (aml_Var*)luaL_testudata(L, -1, AML_VAR_TYPE); 216 | lua_pop(L, 1); 217 | if (lvar) luaL_addstring(B, lvar->name); 218 | else { 219 | int ch = 'v'; 220 | switch (sym.type) { 221 | case AM_EXTERNAL: ch = 'v'; break; 222 | case AM_SLACK: ch = 's'; break; 223 | case AM_ERROR: ch = 'e'; break; 224 | case AM_DUMMY: ch = 'd'; break; 225 | } 226 | lua_pushfstring(L, "%c%d", ch, sym.id); 227 | luaL_addvalue(B); 228 | } 229 | } 230 | 231 | static void aml_dumprow(luaL_Buffer *B, int idx, am_Row *row) { 232 | lua_State *L = B->L; 233 | am_Term *term = NULL; 234 | lua_pushfstring(L, "%f", row->constant); 235 | luaL_addvalue(B); 236 | while ((am_nextentry(&row->terms, (am_Entry**)&term))) { 237 | am_Float multiplier = term->multiplier; 238 | lua_pushfstring(L, " %c ", multiplier > 0.0f ? 
'+' : '-'); 239 | luaL_addvalue(B); 240 | if (multiplier < 0.0f) multiplier = -multiplier; 241 | if (!am_approx(multiplier, 1.0f)) { 242 | lua_pushfstring(L, "%f*", multiplier); 243 | luaL_addvalue(B); 244 | } 245 | aml_dumpkey(B, idx, am_key(term)); 246 | } 247 | } 248 | 249 | 250 | /* expression */ 251 | 252 | static int Lexpr_neg(lua_State *L) { 253 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 254 | aml_Cons *newcons = aml_newcons(L, lcons->S, AM_REQUIRED); 255 | am_mergeconstraint(newcons->cons, lcons->cons, -1.0f); 256 | return 1; 257 | } 258 | 259 | static int Lexpr_add(lua_State *L) { 260 | aml_Cons *lcons; 261 | aml_Item items[2]; 262 | aml_Solver *S = aml_checkitems(L, 1, items); 263 | lcons = aml_newcons(L, S, AM_REQUIRED); 264 | aml_performitem(lcons->cons, &items[0], 1.0f); 265 | aml_performitem(lcons->cons, &items[1], 1.0f); 266 | return 1; 267 | } 268 | 269 | static int Lexpr_sub(lua_State *L) { 270 | aml_Cons *lcons; 271 | aml_Item items[2]; 272 | aml_Solver *S = aml_checkitems(L, 1, items); 273 | lcons = aml_newcons(L, S, AM_REQUIRED); 274 | aml_performitem(lcons->cons, &items[0], 1.0f); 275 | aml_performitem(lcons->cons, &items[1], -1.0f); 276 | return 1; 277 | } 278 | 279 | static int Lexpr_mul(lua_State *L) { 280 | aml_Item items[2]; 281 | aml_Solver *S = aml_checkitems(L, 1, items); 282 | if (items[0].type == AML_CONSTANT) { 283 | aml_Cons *lcons = aml_newcons(L, S, AM_REQUIRED); 284 | aml_performitem(lcons->cons, &items[1], items[0].value); 285 | } 286 | else if (items[1].type == AML_CONSTANT) { 287 | aml_Cons *lcons = aml_newcons(L, S, AM_REQUIRED); 288 | aml_performitem(lcons->cons, &items[0], items[1].value); 289 | } 290 | else luaL_error(L, "attempt to multiply two expression"); 291 | return 1; 292 | } 293 | 294 | static int Lexpr_div(lua_State *L) { 295 | aml_Item items[2]; 296 | aml_Solver *S = aml_checkitems(L, 1, items); 297 | if (items[0].type == AML_CONSTANT) 298 | luaL_error(L, "attempt to divide a 
expression"); 299 | if (items[1].type == AML_CONSTANT) { 300 | aml_Cons *lcons = aml_newcons(L, S, AM_REQUIRED); 301 | aml_performitem(lcons->cons, &items[0], 1.0f/items[1].value); 302 | } 303 | else luaL_error(L, "attempt to divide two expression"); 304 | return 1; 305 | } 306 | 307 | static int Lexpr_cmp(lua_State *L, int op) { 308 | aml_Item items[2]; 309 | aml_Solver *S = aml_checkitems(L, 1, items); 310 | aml_Cons *lcons = aml_newcons(L, S, AM_REQUIRED); 311 | aml_performitem(lcons->cons, &items[0], 1.0f); 312 | am_setrelation(lcons->cons, op); 313 | aml_performitem(lcons->cons, &items[1], 1.0f); 314 | return 1; 315 | } 316 | 317 | static int Lexpr_le(lua_State *L) { return Lexpr_cmp(L, AM_LESSEQUAL); } 318 | static int Lexpr_eq(lua_State *L) { return Lexpr_cmp(L, AM_EQUAL); } 319 | static int Lexpr_ge(lua_State *L) { return Lexpr_cmp(L, AM_GREATEQUAL); } 320 | 321 | 322 | /* variable */ 323 | 324 | static int Lvar_new(lua_State *L) { 325 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 326 | aml_Var *lvar; 327 | int type = lua_type(L, 2); 328 | if (type != LUA_TNONE || type != LUA_TNIL) { 329 | if (type != LUA_TSTRING && type != LUA_TNUMBER) 330 | return aml_typeerror(L, 2, "number/string"); 331 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_vars); 332 | lua_pushvalue(L, 2); 333 | if (lua_rawget(L, -2) != LUA_TNIL) return 1; 334 | if (type == LUA_TNUMBER) 335 | aml_argferror(L, 2, "variable#%d not exists", lua_tointeger(L, 2)); 336 | lua_pop(L, 1); 337 | } 338 | lvar = (aml_Var*)lua_newuserdata(L, sizeof(aml_Var)); 339 | lvar->var = am_newvariable(S->solver); 340 | lvar->S = S; 341 | lvar->name = lua_tostring(L, 2); 342 | if (lvar->name == NULL) { 343 | lua_settop(L, 1); 344 | lua_pushfstring(L, "v%d", am_variableid(lvar->var)); 345 | lvar->name = lua_tostring(L, 2); 346 | } 347 | luaL_setmetatable(L, AML_VAR_TYPE); 348 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_vars); 349 | lua_pushvalue(L, -2); 350 | lua_setfield(L, -2, lvar->name); 351 | 
lua_pushvalue(L, -2); 352 | lua_rawseti(L, -2, am_variableid(lvar->var)); 353 | lua_pop(L, 1); 354 | return 1; 355 | } 356 | 357 | static int Lvar_delete(lua_State *L) { 358 | aml_Var *lvar = (aml_Var*)luaL_checkudata(L, 1, AML_VAR_TYPE); 359 | if (lvar->var == NULL) return 0; 360 | am_delvariable(lvar->var); 361 | lua_rawgeti(L, LUA_REGISTRYINDEX, lvar->S->ref_vars); 362 | lua_pushnil(L); 363 | lua_setfield(L, -2, lvar->name); 364 | lvar->var = NULL; 365 | lvar->name = NULL; 366 | return 0; 367 | } 368 | 369 | static int Lvar_value(lua_State *L) { 370 | aml_Var *lvar = (aml_Var*)luaL_checkudata(L, 1, AML_VAR_TYPE); 371 | if (lvar->var == NULL) luaL_argerror(L, 1, "invalid variable"); 372 | lua_pushnumber(L, am_value(lvar->var)); 373 | return 1; 374 | } 375 | 376 | static int Lvar_tostring(lua_State *L) { 377 | aml_Var *lvar = (aml_Var*)luaL_checkudata(L, 1, AML_VAR_TYPE); 378 | if (lvar->var) lua_pushfstring(L, AML_VAR_TYPE "(%p): %s = %f", 379 | lvar->var, lvar->name, am_value(lvar->var)); 380 | else lua_pushstring(L, AML_VAR_TYPE ": deleted"); 381 | return 1; 382 | } 383 | 384 | static void open_variable(lua_State *L) { 385 | luaL_Reg libs[] = { 386 | { "__neg", Lexpr_neg }, 387 | { "__add", Lexpr_add }, 388 | { "__sub", Lexpr_sub }, 389 | { "__mul", Lexpr_mul }, 390 | { "__div", Lexpr_div }, 391 | { "le", Lexpr_le }, 392 | { "eq", Lexpr_eq }, 393 | { "ge", Lexpr_ge }, 394 | { "__tostring", Lvar_tostring }, 395 | { "__gc", Lvar_delete }, 396 | #define ENTRY(name) { #name, Lvar_##name } 397 | ENTRY(new), 398 | ENTRY(delete), 399 | ENTRY(value), 400 | #undef ENTRY 401 | { NULL, NULL } 402 | }; 403 | if (luaL_newmetatable(L, AML_VAR_TYPE)) { 404 | luaL_setfuncs(L, libs, 0); 405 | lua_pushvalue(L, -1); 406 | lua_setfield(L, -2, "__index"); 407 | } 408 | } 409 | 410 | 411 | /* constraint */ 412 | 413 | static int Lcons_new(lua_State *L) { 414 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 415 | if (lua_gettop(L) >= 3) aml_makecons(L, S, 2); 
416 | else aml_newcons(L, S, aml_checkstrength(L, 2, AM_REQUIRED)); 417 | return 1; 418 | } 419 | 420 | static int Lcons_delete(lua_State *L) { 421 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 422 | if (lcons->cons == NULL) return 0; 423 | am_delconstraint(lcons->cons); 424 | lcons->cons = NULL; 425 | lua_rawgeti(L, LUA_REGISTRYINDEX, lcons->S->ref_vars); 426 | lua_pushnil(L); 427 | lua_rawsetp(L, -2, lcons); 428 | return 0; 429 | } 430 | 431 | static int Lcons_reset(lua_State *L) { 432 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 433 | if (lcons->cons == NULL) luaL_argerror(L, 1, "invalid constraint"); 434 | am_resetconstraint(lcons->cons); 435 | lua_settop(L, 1); return 1; 436 | } 437 | 438 | static int Lcons_add(lua_State *L) { 439 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 440 | aml_Item item; 441 | int ret; 442 | if (lcons->cons == NULL) luaL_argerror(L, 1, "invalid constraint"); 443 | if (lua_type(L, 2) == LUA_TSTRING) { 444 | const char *s = lua_tostring(L, 2); 445 | if (s[0] == '<' || s[0] == '>' || s[0] == '=') { 446 | ret = am_setrelation(lcons->cons, aml_checkrelation(L, 2)); 447 | goto out; 448 | } 449 | } 450 | item = aml_checkitem(L, lcons->S, 2); 451 | ret = aml_performitem(lcons->cons, &item, 1.0f); 452 | out: 453 | if (ret != AM_OK) luaL_error(L, "constraint has been added to solver!"); 454 | lua_settop(L, 1); return 1; 455 | } 456 | 457 | static int Lcons_relation(lua_State *L) { 458 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 459 | int op = aml_checkrelation(L, 2); 460 | if (lcons->cons == NULL) luaL_argerror(L, 1, "invalid constraint"); 461 | if (am_setrelation(lcons->cons, op) != AM_OK) 462 | luaL_error(L, "constraint has been added to solver!"); 463 | lua_settop(L, 1); return 1; 464 | } 465 | 466 | static int Lcons_strength(lua_State *L) { 467 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 468 | am_Float strength = 
aml_checkstrength(L, 2, AM_REQUIRED); 469 | if (lcons->cons == NULL) luaL_argerror(L, 1, "invalid constraint"); 470 | if (am_setstrength(lcons->cons, strength) != AM_OK) 471 | luaL_error(L, "constraint has been added to solver!"); 472 | lua_settop(L, 1); return 1; 473 | } 474 | 475 | static int Lcons_tostring(lua_State *L) { 476 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 1, AML_CONS_TYPE); 477 | luaL_Buffer B; 478 | if (lcons->cons == NULL) { 479 | lua_pushstring(L, AML_CONS_TYPE ": deleted"); 480 | return 1; 481 | } 482 | lua_settop(L, 1); 483 | lua_rawgeti(L, LUA_REGISTRYINDEX, lcons->S->ref_vars); 484 | luaL_buffinit(L, &B); 485 | lua_pushfstring(L, AML_CONS_TYPE "(%p): [", lcons->cons); 486 | luaL_addvalue(&B); 487 | aml_dumprow(&B, 2, &lcons->cons->expression); 488 | if (lcons->cons->relation == AM_EQUAL) 489 | luaL_addstring(&B, " == 0.0]"); 490 | else 491 | luaL_addstring(&B, " >= 0.0]"); 492 | if (lcons->cons->marker.id != 0) { 493 | luaL_addstring(&B, "(added:"); 494 | aml_dumpkey(&B, 2, lcons->cons->marker); 495 | luaL_addchar(&B, '-'); 496 | aml_dumpkey(&B, 2, lcons->cons->other); 497 | luaL_addchar(&B, ')'); 498 | } 499 | luaL_pushresult(&B); 500 | return 1; 501 | } 502 | 503 | static void open_constraint(lua_State *L) { 504 | luaL_Reg libs[] = { 505 | { "__call", Lcons_add }, 506 | { "__neg", Lexpr_neg }, 507 | { "__add", Lexpr_add }, 508 | { "__sub", Lexpr_sub }, 509 | { "__mul", Lexpr_mul }, 510 | { "__div", Lexpr_div }, 511 | { "le", Lexpr_le }, 512 | { "eq", Lexpr_eq }, 513 | { "ge", Lexpr_ge }, 514 | { "__bor", Lcons_strength }, 515 | { "__tostring", Lcons_tostring }, 516 | { "__gc", Lcons_delete }, 517 | #define ENTRY(name) { #name, Lcons_##name } 518 | ENTRY(new), 519 | ENTRY(delete), 520 | ENTRY(reset), 521 | ENTRY(add), 522 | ENTRY(relation), 523 | ENTRY(strength), 524 | #undef ENTRY 525 | { NULL, NULL } 526 | }; 527 | if (luaL_newmetatable(L, AML_CONS_TYPE)) { 528 | luaL_setfuncs(L, libs, 0); 529 | lua_pushvalue(L, -1); 530 | 
lua_setfield(L, -2, "__index"); 531 | } 532 | } 533 | 534 | 535 | /* solver */ 536 | 537 | static int Lnew(lua_State *L) { 538 | aml_Solver *S = lua_newuserdata(L, sizeof(aml_Solver)); 539 | if ((S->solver = am_newsolver(NULL, NULL, NULL)) == NULL) 540 | return 0; 541 | lua_createtable(L, 0, 4); aml_setweak(L, "v"); 542 | S->ref_vars = luaL_ref(L, LUA_REGISTRYINDEX); 543 | lua_createtable(L, 0, 4); aml_setweak(L, "v"); 544 | S->ref_cons = luaL_ref(L, LUA_REGISTRYINDEX); 545 | luaL_setmetatable(L, AML_SOLVER_TYPE); 546 | return 1; 547 | } 548 | 549 | static int Ldelete(lua_State *L) { 550 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 551 | if (S->solver == NULL) return 0; 552 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_vars); 553 | lua_pushnil(L); 554 | while (lua_next(L, -2)) { 555 | aml_Var *lvar = (aml_Var*)luaL_testudata(L, -1, AML_VAR_TYPE); 556 | if (lvar && lvar->var) { 557 | am_delvariable(lvar->var); 558 | lvar->var = NULL; 559 | lvar->name = NULL; 560 | } 561 | lua_pop(L, 1); 562 | } 563 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_cons); 564 | lua_pushnil(L); 565 | while (lua_next(L, -2)) { 566 | aml_Cons *lcons = (aml_Cons*)luaL_testudata(L, -1, AML_CONS_TYPE); 567 | if (lcons && lcons->cons) { 568 | am_delconstraint(lcons->cons); 569 | lcons->cons = NULL; 570 | } 571 | lua_pop(L, 1); 572 | } 573 | luaL_unref(L, LUA_REGISTRYINDEX, S->ref_vars); 574 | luaL_unref(L, LUA_REGISTRYINDEX, S->ref_cons); 575 | am_delsolver(S->solver); 576 | S->solver = NULL; 577 | return 0; 578 | } 579 | 580 | static int Ltostring(lua_State *L) { 581 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 582 | luaL_Buffer B; 583 | lua_settop(L, 1); 584 | lua_rawgeti(L, LUA_REGISTRYINDEX, S->ref_vars); 585 | luaL_buffinit(L, &B); 586 | lua_pushfstring(L, AML_SOLVER_TYPE "(%p): {", S->solver); 587 | luaL_addvalue(&B); 588 | luaL_addstring(&B, "\n objective = "); 589 | aml_dumprow(&B, 2, &S->solver->objective); 590 | if (S->solver->rows.count 
!= 0) { 591 | am_Row *row = NULL; 592 | int idx = 0; 593 | lua_pushfstring(L, "\n rows(%d):", S->solver->rows.count); 594 | luaL_addvalue(&B); 595 | while (am_nextentry(&S->solver->rows, (am_Entry**)&row)) { 596 | lua_pushfstring(L, "\n %d. ", ++idx); 597 | luaL_addvalue(&B); 598 | aml_dumpkey(&B, 2, am_key(row)); 599 | luaL_addstring(&B, " = "); 600 | aml_dumprow(&B, 2, row); 601 | } 602 | } 603 | if (S->solver->infeasible_rows.id != 0) { 604 | am_Row *row = (am_Row*)am_gettable(&S->solver->rows, 605 | S->solver->infeasible_rows); 606 | luaL_addstring(&B, "\n infeasible rows: "); 607 | aml_dumpkey(&B, 2, am_key(row)); 608 | while (row != NULL) { 609 | luaL_addstring(&B, ", "); 610 | aml_dumpkey(&B, 2, am_key(row)); 611 | row = (am_Row*)am_gettable(&S->solver->rows, row->infeasible_next); 612 | } 613 | } 614 | luaL_addstring(&B, "\n}"); 615 | luaL_pushresult(&B); 616 | return 1; 617 | } 618 | 619 | static int Lreset(lua_State *L) { 620 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 621 | int clear = lua_toboolean(L, 2); 622 | am_resetsolver(S->solver, clear); 623 | lua_settop(L, 1); return 1; 624 | } 625 | 626 | static int Laddconstraint(lua_State *L) { 627 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 628 | aml_Cons *lcons = (aml_Cons*)luaL_testudata(L, 2, AML_CONS_TYPE); 629 | int ret; 630 | if (lcons == NULL) lcons = aml_makecons(L, S, 2); 631 | if ((ret = am_add(lcons->cons)) == AM_OK) 632 | { lua_settop(L, 1); return 1; } 633 | switch (ret) { 634 | case AM_UNSATISFIED: luaL_argerror(L, 2, "constraint unsatisfied"); 635 | case AM_UNBOUND: luaL_argerror(L, 2, "constraint unbound"); 636 | } 637 | return 0; 638 | } 639 | 640 | static int Ldelconstraint(lua_State *L) { 641 | luaL_checkudata(L, 1, AML_SOLVER_TYPE); 642 | aml_Cons *lcons = (aml_Cons*)luaL_checkudata(L, 2, AML_CONS_TYPE); 643 | am_remove(lcons->cons); 644 | lua_settop(L, 1); return 1; 645 | } 646 | 647 | static int Laddedit(lua_State *L) { 648 | 
aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 649 | am_Variable *var = aml_checkvar(L, S, 2); 650 | am_Float strength = aml_checkstrength(L, 3, AM_MEDIUM); 651 | am_addedit(var, strength); 652 | lua_settop(L, 1); return 1; 653 | } 654 | 655 | static int Ldeledit(lua_State *L) { 656 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 657 | am_Variable *var = aml_checkvar(L, S, 2); 658 | am_deledit(var); 659 | lua_settop(L, 1); return 1; 660 | } 661 | 662 | static int Lsuggest(lua_State *L) { 663 | aml_Solver *S = (aml_Solver*)luaL_checkudata(L, 1, AML_SOLVER_TYPE); 664 | am_Variable *var = aml_checkvar(L, S, 2); 665 | am_Float value = (am_Float)luaL_checknumber(L, 3); 666 | am_suggest(var, value); 667 | lua_settop(L, 1); return 1; 668 | } 669 | 670 | LUALIB_API int luaopen_amoeba(lua_State *L) { 671 | luaL_Reg libs[] = { 672 | { "var", Lvar_new }, 673 | { "constraint", Lcons_new }, 674 | { "__tostring", Ltostring }, 675 | #define ENTRY(name) { #name, L##name } 676 | ENTRY(new), 677 | ENTRY(delete), 678 | ENTRY(reset), 679 | ENTRY(addconstraint), 680 | ENTRY(delconstraint), 681 | ENTRY(addedit), 682 | ENTRY(deledit), 683 | ENTRY(suggest), 684 | #undef ENTRY 685 | { NULL, NULL } 686 | }; 687 | open_variable(L); 688 | open_constraint(L); 689 | if (luaL_newmetatable(L, AML_SOLVER_TYPE)) { 690 | luaL_setfuncs(L, libs, 0); 691 | lua_pushvalue(L, -1); 692 | lua_setfield(L, -2, "__index"); 693 | } 694 | return 1; 695 | } 696 | 697 | /* maccc: flags+='-undefined dynamic_lookup -bundle -O2' output='amoeba.so' 698 | * win32cc: flags+='-DLUA_BUILD_AS_DLL -shared -O3' libs+='-llua53' output='amoeba.dll' */ 699 | 700 | -------------------------------------------------------------------------------- /test.c: -------------------------------------------------------------------------------- 1 | #define AM_IMPLEMENTATION 2 | #include "amoeba.h" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | static jmp_buf jbuf; 
11 | static size_t allmem = 0; 12 | static size_t maxmem = 0; 13 | static void *END = NULL; 14 | 15 | static void *debug_allocf(void *ud, void *ptr, size_t ns, size_t os) { 16 | void *newptr = NULL; 17 | (void)ud; 18 | allmem += ns; 19 | allmem -= os; 20 | if (maxmem < allmem) maxmem = allmem; 21 | if (ns == 0) free(ptr); 22 | else { 23 | newptr = realloc(ptr, ns); 24 | if (newptr == NULL) longjmp(jbuf, 1); 25 | } 26 | #ifdef DEBUG_MEMORY 27 | printf("new(%p):\t+%d, old(%p):\t-%d\n", newptr, (int)ns, ptr, (int)os); 28 | #endif 29 | return newptr; 30 | } 31 | 32 | static void *null_allocf(void *ud, void *ptr, size_t ns, size_t os) 33 | { (void)ud, (void)ptr, (void)ns, (void)os; return NULL; } 34 | 35 | static void am_dumpkey(am_Symbol sym) { 36 | int ch = 'v'; 37 | switch (sym.type) { 38 | case AM_EXTERNAL: ch = 'v'; break; 39 | case AM_SLACK: ch = 's'; break; 40 | case AM_ERROR: ch = 'e'; break; 41 | case AM_DUMMY: ch = 'd'; break; 42 | } 43 | printf("%c%d", ch, (int)sym.id); 44 | } 45 | 46 | static void am_dumprow(am_Row *row) { 47 | am_Term *term = NULL; 48 | printf("%g", row->constant); 49 | while (am_nextentry(&row->terms, (am_Entry**)&term)) { 50 | am_Float multiplier = term->multiplier; 51 | printf(" %c ", multiplier > 0.0 ? '+' : '-'); 52 | if (multiplier < 0.0) multiplier = -multiplier; 53 | if (!am_approx(multiplier, 1.0f)) 54 | printf("%g*", multiplier); 55 | am_dumpkey(am_key(term)); 56 | } 57 | printf("\n"); 58 | } 59 | 60 | static void am_dumpsolver(am_Solver *solver) { 61 | am_Row *row = NULL; 62 | int idx = 0; 63 | printf("-------------------------------\n"); 64 | printf("solver: "); 65 | am_dumprow(&solver->objective); 66 | printf("rows(%d):\n", (int)solver->rows.count); 67 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) { 68 | printf("%d. 
", ++idx); 69 | am_dumpkey(am_key(row)); 70 | printf(" = "); 71 | am_dumprow(row); 72 | } 73 | printf("-------------------------------\n"); 74 | } 75 | 76 | static am_Constraint* new_constraint(am_Solver* in_solver, double in_strength, 77 | am_Variable* in_term1, double in_factor1, int in_relation, 78 | double in_constant, ...) 79 | { 80 | int result; 81 | va_list argp; 82 | am_Constraint* c; 83 | assert(in_solver && in_term1); 84 | c = am_newconstraint(in_solver, (am_Float)in_strength); 85 | if(!c) return 0; 86 | am_addterm(c, in_term1, (am_Float)in_factor1); 87 | am_setrelation(c, in_relation); 88 | if(in_constant) am_addconstant(c, (am_Float)in_constant); 89 | va_start(argp, in_constant); 90 | while(1) { 91 | am_Variable* va_term = va_arg(argp, am_Variable*); 92 | double va_factor = va_arg(argp, double); 93 | if(va_term == 0) break; 94 | am_addterm(c, va_term, (am_Float)va_factor); 95 | } 96 | va_end(argp); 97 | result = am_add(c); 98 | assert(result == AM_OK); 99 | return c; 100 | } 101 | 102 | static void test_all(void) { 103 | am_Solver *solver; 104 | am_Variable *xl; 105 | am_Variable *xm; 106 | am_Variable *xr; 107 | am_Variable *xd; 108 | am_Constraint *c1, *c2, *c3, *c4, *c5, *c6; 109 | int ret = setjmp(jbuf); 110 | printf("\n\n==========\ntest all\n"); 111 | printf("ret = %d\n", ret); 112 | if (ret < 0) { perror("setjmp"); return; } 113 | else if (ret != 0) { printf("out of memory!\n"); return; } 114 | 115 | solver = am_newsolver(null_allocf, NULL, NULL); 116 | assert(solver == NULL); 117 | 118 | solver = am_newsolver(NULL, NULL, NULL); 119 | assert(solver != NULL); 120 | am_delsolver(solver); 121 | 122 | solver = am_newsolver(debug_allocf, NULL, NULL); 123 | xl = am_newvariable(solver); 124 | xm = am_newvariable(solver); 125 | xr = am_newvariable(solver); 126 | 127 | assert(am_variableid(NULL) == -1); 128 | assert(am_variableid(xl) == 1); 129 | assert(am_variableid(xm) == 2); 130 | assert(am_variableid(xr) == 3); 131 | assert(!am_hasedit(NULL)); 132 | 
assert(!am_hasedit(xl));
    assert(!am_hasedit(xm));
    assert(!am_hasedit(xr));
    assert(!am_hasconstraint(NULL));

    /* a variable can be created and deleted without ever being used */
    xd = am_newvariable(solver);
    am_delvariable(xd);

    /* a NULL constraint is rejected */
    assert(am_setrelation(NULL, AM_GREATEQUAL) == AM_FAILED);

    /* c1: xl >= 0 */
    c1 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c1, xl, 1.0);
    am_setrelation(c1, AM_GREATEQUAL);
    ret = am_add(c1);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* once added, the relation is expected to be frozen while the strength
     * may still be changed */
    assert(am_setrelation(c1, AM_GREATEQUAL) == AM_FAILED);
    assert(am_setstrength(c1, AM_REQUIRED-10) == AM_OK);
    assert(am_setstrength(c1, AM_REQUIRED) == AM_OK);

    assert(am_hasconstraint(c1));
    assert(!am_hasedit(xl));

    /* c2: xl == 0 */
    c2 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c2, xl, 1.0);
    am_setrelation(c2, AM_EQUAL);
    ret = am_add(c2);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* start over: reset the solver and drop both constraints */
    am_resetsolver(solver, 1);
    am_delconstraint(c1);
    am_delconstraint(c2);
    am_dumpsolver(solver);

    /* c1: 2*xm == xl + xr */
    c1 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c1, xm, 2.0);
    am_setrelation(c1, AM_EQUAL);
    am_addterm(c1, xl, 1.0);
    am_addterm(c1, xr, 1.0);
    ret = am_add(c1);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* c2: xl + 10 <= xr */
    c2 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c2, xl, 1.0);
    am_addconstant(c2, 10.0);
    am_setrelation(c2, AM_LESSEQUAL);
    am_addterm(c2, xr, 1.0);
    ret = am_add(c2);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* c3: xr <= 100 */
    c3 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c3, xr, 1.0);
    am_setrelation(c3, AM_LESSEQUAL);
    am_addconstant(c3, 100.0);
    ret = am_add(c3);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* c4: xl >= 0 */
    c4 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c4, xl, 1.0);
    am_setrelation(c4, AM_GREATEQUAL);
    am_addconstant(c4, 0.0);
    ret = am_add(c4);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* a clone of c4 can be added next to it and removed again */
    c5 = am_cloneconstraint(c4, AM_REQUIRED);
    ret = am_add(c5);
    assert(ret == AM_OK);
    am_dumpsolver(solver);
    am_remove(c5);

    /* c5: xl == 0 (the c5 handle is re-used for a new constraint) */
    c5 = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c5, xl, 1.0);
    am_setrelation(c5, AM_EQUAL);
    am_addconstant(c5, 0.0);
    ret = am_add(c5);
    assert(ret == AM_OK);

    c6 = am_cloneconstraint(c4, AM_REQUIRED);
    ret = am_add(c6);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    am_resetconstraint(c6);
    am_delconstraint(c6);

    /* remove c1..c4 and re-add them in reverse order */
    am_remove(c1);
    am_remove(c2);
    am_remove(c3);
    am_remove(c4);
    am_dumpsolver(solver);
    ret |= am_add(c4);
    ret |= am_add(c3);
    ret |= am_add(c2);
    ret |= am_add(c1);
    assert(ret == AM_OK);

    /* reset with both flag values, then add c1..c4 once more */
    am_resetsolver(solver, 0);
    am_resetsolver(solver, 1);
    printf("after reset\n");
    am_dumpsolver(solver);
    ret |= am_add(c1);
    ret |= am_add(c2);
    ret |= am_add(c3);
    ret |= am_add(c4);
    assert(ret == AM_OK);

    printf("after initialize\n");
    am_dumpsolver(solver);
    am_updatevars(solver);
    printf("xl: %f, xm: %f, xr: %f\n",
            am_value(xl),
            am_value(xm),
            am_value(xr));

    /* make xm editable, then suggest values for it */
    am_addedit(xm, AM_MEDIUM);
    am_dumpsolver(solver);
    am_updatevars(solver);
    printf("xl: %f, xm: %f, xr: %f\n",
            am_value(xl),
            am_value(xm),
            am_value(xr));

    assert(am_hasedit(xm));

    printf("suggest to 0.0\n");
    am_suggest(xm, 0.0);
    am_dumpsolver(solver);
    am_updatevars(solver);
    printf("xl: %f, xm: %f, xr: %f\n",
            am_value(xl),
            am_value(xm),
            am_value(xr));

    printf("suggest to 70.0\n");
    am_suggest(xm, 70.0);
am_updatevars(solver);
    am_dumpsolver(solver);

    printf("xl: %f, xm: %f, xr: %f\n",
            am_value(xl),
            am_value(xm),
            am_value(xr));

    am_deledit(xm);
    am_updatevars(solver);
    am_dumpsolver(solver);

    printf("xl: %f, xm: %f, xr: %f\n",
            am_value(xl),
            am_value(xm),
            am_value(xr));

    /* deleting the solver must give back every byte debug_allocf handed out */
    am_delsolver(solver);
    printf("allmem = %d\n", (int)allmem);
    printf("maxmem = %d\n", (int)maxmem);
    assert(allmem == 0);
    maxmem = 0;
}

/* Lays out the vertices of a complete binary tree with NUM_ROWS rows.
 * arrX/arrY share one malloc'ed block of 2048 pointers (arrY is the second
 * half), indexed by vertex number in row-major order.
 *
 * NOTE(review): unlike the other tests this one does not call setjmp(jbuf)
 * before using debug_allocf, so an allocation failure here would longjmp
 * into a stale jmp_buf -- confirm whether that is intended. */
static void test_binarytree(void) {
    const int NUM_ROWS = 9;
    const int X_OFFSET = 0;
    int nPointsCount, nResult, nRow;
    int nCurrentRowPointsCount = 1;
    int nCurrentRowFirstPointIndex = 0;
    am_Constraint *pC;
    am_Solver *pSolver;
    am_Variable **arrX, **arrY;

    printf("\n\n==========\ntest binarytree\n");
    arrX = (am_Variable**)malloc(2048 * sizeof(am_Variable*));
    if (arrX == NULL) return;
    arrY = arrX + 1024;

    /* Create set of rules to distribute vertexes of a binary tree like this one:
     *      0
     *     / \
     *    /   \
     *   1     2
     *  / \   / \
     * 3   4 5   6
     */

    pSolver = am_newsolver(debug_allocf, NULL, NULL);

    /* Xroot=500, Yroot=10 */
    arrX[0] = am_newvariable(pSolver);
    arrY[0] = am_newvariable(pSolver);
    am_addedit(arrX[0], AM_STRONG);
    am_addedit(arrY[0], AM_STRONG);
    am_suggest(arrX[0], 500.0f + X_OFFSET);
    am_suggest(arrY[0], 10.0f);

    for (nRow = 1; nRow < NUM_ROWS; nRow++) {
        int nPreviousRowFirstPointIndex = nCurrentRowFirstPointIndex;
        int nPoint, nParentPoint = 0;
        /* advance to the next row, which has twice as many points */
        nCurrentRowFirstPointIndex += nCurrentRowPointsCount;
        nCurrentRowPointsCount *= 2;

        for (nPoint = 0; nPoint < nCurrentRowPointsCount; nPoint++) {
            arrX[nCurrentRowFirstPointIndex + nPoint] = am_newvariable(pSolver);
            arrY[nCurrentRowFirstPointIndex + nPoint] = am_newvariable(pSolver);

            /* Ycur = Yprev_row + 15 */
            pC = am_newconstraint(pSolver, AM_REQUIRED);
            am_addterm(pC, arrY[nCurrentRowFirstPointIndex + nPoint], 1.0);
            am_setrelation(pC, AM_EQUAL);
            /* index FirstPointIndex-1 is the last point of the previous row */
            am_addterm(pC, arrY[nCurrentRowFirstPointIndex - 1], 1.0);
            am_addconstant(pC, 15.0);
            nResult = am_add(pC);
            assert(nResult == AM_OK);

            if (nPoint > 0) {
                /* Xcur >= XPrev + 5 */
                pC = am_newconstraint(pSolver, AM_REQUIRED);
                am_addterm(pC, arrX[nCurrentRowFirstPointIndex + nPoint], 1.0);
                am_setrelation(pC, AM_GREATEQUAL);
                am_addterm(pC, arrX[nCurrentRowFirstPointIndex + nPoint - 1], 1.0);
                am_addconstant(pC, 5.0);
                nResult = am_add(pC);
                assert(nResult == AM_OK);
            } else {
                /* first point of the row: Xcur >= 0 */
                /* When these lines added it crashes at the line 109 */
                pC = am_newconstraint(pSolver, AM_REQUIRED);
                am_addterm(pC, arrX[nCurrentRowFirstPointIndex + nPoint], 1.0);
                am_setrelation(pC, AM_GREATEQUAL);
                am_addconstant(pC, 0.0);
                nResult = am_add(pC);
                assert(nResult == AM_OK);
            }

            /* every odd point completes a pair of siblings */
            if ((nPoint % 2) == 1) {
                /* Xparent = 0.5 * Xcur + 0.5 * Xprev */
                pC = am_newconstraint(pSolver, AM_REQUIRED);
                am_addterm(pC, arrX[nPreviousRowFirstPointIndex + nParentPoint], 1.0);
                am_setrelation(pC, AM_EQUAL);
                am_addterm(pC, arrX[nCurrentRowFirstPointIndex + nPoint], 0.5);
                am_addterm(pC, arrX[nCurrentRowFirstPointIndex + nPoint - 1], 0.5);
                /* It crashes here (at the 3rd call of am_add(...))! */
                nResult = am_add(pC);
                assert(nResult == AM_OK);

                nParentPoint++;
            }
        }
    }
    /* only consumed by the disabled dump below */
    nPointsCount = nCurrentRowFirstPointIndex + nCurrentRowPointsCount;

    /*{
        int i;
        for (i = 0; i < nPointsCount; i++)
            printf("Point %d: (%f, %f)\n", i,
                    am_value(arrX[i]), am_value(arrY[i]));
    }*/

    am_delsolver(pSolver);
    printf("allmem = %d\n", (int)allmem);
    printf("maxmem = %d\n", (int)maxmem);
    assert(allmem == 0);
    free(arrX);
    maxmem = 0;
}

/* Checks the result codes am_add() reports for trivially true, contradictory
 * and conflicting required constraints. */
static void test_unbounded(void) {
    am_Solver *solver;
    am_Variable *x, *y;
    am_Constraint *c;
    int ret = setjmp(jbuf);
    printf("\n\n==========\ntest unbound\n");
    printf("ret = %d\n", ret);
    if (ret < 0) { perror("setjmp"); return; }
    else if (ret != 0) { printf("out of memory!\n"); return; }

    solver = am_newsolver(debug_allocf, NULL, NULL);
    x = am_newvariable(solver);
    y = am_newvariable(solver);

    /* 10.0 == 0.0 */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addconstant(c, 10.0);
    am_setrelation(c, AM_EQUAL);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_UNSATISFIED);
    am_dumpsolver(solver);

    /* 0.0 == 0.0 */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addconstant(c, 0.0);
    am_setrelation(c, AM_EQUAL);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    am_resetsolver(solver, 1);

    /* x >= 10.0 */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
    am_setrelation(c, AM_GREATEQUAL);
    am_addconstant(c, 10.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* x == 2*y */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
am_setrelation(c, AM_EQUAL);
    am_addterm(c, y, 2.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* y == 3*x */
    /* together with x >= 10 and x == 2*y this is expected to be reported
     * as AM_UNBOUND */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, y, 1.0);
    am_setrelation(c, AM_EQUAL);
    am_addterm(c, x, 3.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_UNBOUND);
    am_dumpsolver(solver);

    am_resetsolver(solver, 1);

    /* x >= 10.0 */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
    am_setrelation(c, AM_GREATEQUAL);
    am_addconstant(c, 10.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* x <= 0.0 */
    /* no constant is added, so the right-hand side is 0 */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
    am_setrelation(c, AM_LESSEQUAL);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_UNBOUND);
    am_dumpsolver(solver);

    printf("x: %f\n", am_value(x));

    am_resetsolver(solver, 1);

    /* x == 10.0 */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
    am_setrelation(c, AM_EQUAL);
    am_addconstant(c, 10.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_OK);
    am_dumpsolver(solver);

    /* x == 20.0 */
    /* conflicts with the required x == 10.0 above */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
    am_setrelation(c, AM_EQUAL);
    am_addconstant(c, 20.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_UNSATISFIED);
    am_dumpsolver(solver);

    /* x == 10.0 */
    /* a duplicate of the first constraint is still accepted */
    c = am_newconstraint(solver, AM_REQUIRED);
    am_addterm(c, x, 1.0);
    am_setrelation(c, AM_EQUAL);
    am_addconstant(c, 10.0);
    ret = am_add(c);
    printf("ret = %d\n", ret);
    assert(ret == AM_OK);
    am_dumpsolver(solver);
| 524 | am_delsolver(solver); 525 | printf("allmem = %d\n", (int)allmem); 526 | printf("maxmem = %d\n", (int)maxmem); 527 | assert(allmem == 0); 528 | maxmem = 0; 529 | } 530 | 531 | static void test_strength(void) { 532 | am_Solver *solver; 533 | am_Variable *x, *y; 534 | am_Constraint *c; 535 | int ret = setjmp(jbuf); 536 | printf("\n\n==========\ntest strength\n"); 537 | printf("ret = %d\n", ret); 538 | if (ret < 0) { perror("setjmp"); return; } 539 | else if (ret != 0) { printf("out of memory!\n"); return; } 540 | 541 | solver = am_newsolver(debug_allocf, NULL, NULL); 542 | am_autoupdate(solver, 1); 543 | x = am_newvariable(solver); 544 | y = am_newvariable(solver); 545 | 546 | /* x <= y */ 547 | new_constraint(solver, AM_STRONG, x, 1.0, AM_LESSEQUAL, 0.0, 548 | y, 1.0, END); 549 | new_constraint(solver, AM_MEDIUM, x, 1.0, AM_EQUAL, 50, END); 550 | c = new_constraint(solver, AM_MEDIUM-10, y, 1.0, AM_EQUAL, 40, END); 551 | printf("%f, %f\n", am_value(x), am_value(y)); 552 | assert(am_value(x) == 50); 553 | assert(am_value(y) == 50); 554 | 555 | am_setstrength(c, AM_MEDIUM+10); 556 | printf("%f, %f\n", am_value(x), am_value(y)); 557 | assert(am_value(x) == 40); 558 | assert(am_value(y) == 40); 559 | 560 | am_setstrength(c, AM_MEDIUM-10); 561 | printf("%f, %f\n", am_value(x), am_value(y)); 562 | assert(am_value(x) == 50); 563 | assert(am_value(y) == 50); 564 | 565 | am_delsolver(solver); 566 | printf("allmem = %d\n", (int)allmem); 567 | printf("maxmem = %d\n", (int)maxmem); 568 | assert(allmem == 0); 569 | maxmem = 0; 570 | } 571 | 572 | static void test_suggest(void) { 573 | #if 1 574 | /* This should be valid but fails the (enter.id != 0) assertion in am_dual_optimize() */ 575 | am_Float strength1 = AM_REQUIRED; 576 | am_Float strength2 = AM_REQUIRED; 577 | am_Float width = 76; 578 | #else 579 | /* This mostly works, but still insists on forcing left_child_l = 0 which it should not */ 580 | am_Float strength1 = AM_STRONG; 581 | am_Float strength2 = AM_WEAK; 582 
| am_Float width = 76; 583 | #endif 584 | am_Float delta = 0; 585 | am_Float pos; 586 | am_Solver *solver; 587 | am_Variable *splitter_l, *splitter_w, *splitter_r; 588 | am_Variable *left_child_l, *left_child_w, *left_child_r; 589 | am_Variable *splitter_bar_l, *splitter_bar_w, *splitter_bar_r; 590 | am_Variable *right_child_l, *right_child_w, *right_child_r; 591 | int ret = setjmp(jbuf); 592 | printf("\n\n==========\ntest suggest\n"); 593 | printf("ret = %d\n", ret); 594 | if (ret < 0) { perror("setjmp"); return; } 595 | else if (ret != 0) { printf("out of memory!\n"); return; } 596 | 597 | solver = am_newsolver(debug_allocf, NULL, NULL); 598 | splitter_l = am_newvariable(solver); 599 | splitter_w = am_newvariable(solver); 600 | splitter_r = am_newvariable(solver); 601 | left_child_l = am_newvariable(solver); 602 | left_child_w = am_newvariable(solver); 603 | left_child_r = am_newvariable(solver); 604 | splitter_bar_l = am_newvariable(solver); 605 | splitter_bar_w = am_newvariable(solver); 606 | splitter_bar_r = am_newvariable(solver); 607 | right_child_l = am_newvariable(solver); 608 | right_child_w = am_newvariable(solver); 609 | right_child_r = am_newvariable(solver); 610 | 611 | /* splitter_r = splitter_l + splitter_w */ 612 | /* left_child_r = left_child_l + left_child_w */ 613 | /* splitter_bar_r = splitter_bar_l + splitter_bar_w */ 614 | /* right_child_r = right_child_l + right_child_w */ 615 | new_constraint(solver, AM_REQUIRED, splitter_r, 1.0, AM_EQUAL, 0.0, 616 | splitter_l, 1.0, splitter_w, 1.0, END); 617 | new_constraint(solver, AM_REQUIRED, left_child_r, 1.0, AM_EQUAL, 0.0, 618 | left_child_l, 1.0, left_child_w, 1.0, END); 619 | new_constraint(solver, AM_REQUIRED, splitter_bar_r, 1.0, AM_EQUAL, 0.0, 620 | splitter_bar_l, 1.0, splitter_bar_w, 1.0, END); 621 | new_constraint(solver, AM_REQUIRED, right_child_r, 1.0, AM_EQUAL, 0.0, 622 | right_child_l, 1.0, right_child_w, 1.0, END); 623 | 624 | /* splitter_bar_w = 6 */ 625 | /* splitter_bar_l >= 
splitter_l + delta */ 626 | /* splitter_bar_r <= splitter_r - delta */ 627 | /* left_child_r = splitter_bar_l */ 628 | /* right_child_l = splitter_bar_r */ 629 | new_constraint(solver, AM_REQUIRED, splitter_bar_w, 1.0, AM_EQUAL, 6.0, END); 630 | new_constraint(solver, AM_REQUIRED, splitter_bar_l, 1.0, AM_GREATEQUAL, 631 | delta, splitter_l, 1.0, END); 632 | new_constraint(solver, AM_REQUIRED, splitter_bar_r, 1.0, AM_LESSEQUAL, 633 | -delta, splitter_r, 1.0, END); 634 | new_constraint(solver, AM_REQUIRED, left_child_r, 1.0, AM_EQUAL, 0.0, 635 | splitter_bar_l, 1.0, END); 636 | new_constraint(solver, AM_REQUIRED, right_child_l, 1.0, AM_EQUAL, 0.0, 637 | splitter_bar_r, 1.0, END); 638 | 639 | /* right_child_r >= splitter_r + 1 */ 640 | /* left_child_w = 256 */ 641 | new_constraint(solver, strength1, right_child_r, 1.0, AM_GREATEQUAL, 1.0, 642 | splitter_r, 1.0, END); 643 | new_constraint(solver, strength2, left_child_w, 1.0, AM_EQUAL, 256.0, END); 644 | 645 | /* splitter_l = 0 */ 646 | /* splitter_r = 76 */ 647 | new_constraint(solver, AM_REQUIRED, splitter_l, 1.0, AM_EQUAL, 0.0, END); 648 | new_constraint(solver, AM_REQUIRED, splitter_r, 1.0, AM_EQUAL, width, END); 649 | 650 | printf("\n\n==========\ntest suggest\n"); 651 | for(pos = -10; pos < 86; pos++) { 652 | am_suggest(splitter_bar_l, pos); 653 | printf("pos: %4g | ", pos); 654 | printf("splitter_l l=%2g, w=%2g, r=%2g | ", am_value(splitter_l), 655 | am_value(splitter_w), am_value(splitter_r)); 656 | printf("left_child_l l=%2g, w=%2g, r=%2g | ", am_value(left_child_l), 657 | am_value(left_child_w), am_value(left_child_r)); 658 | printf("splitter_bar_l l=%2g, w=%2g, r=%2g | ", am_value(splitter_bar_l), 659 | am_value(splitter_bar_w), am_value(splitter_bar_r)); 660 | printf("right_child_l l=%2g, w=%2g, r=%2g | ", am_value(right_child_l), 661 | am_value(right_child_w), am_value(right_child_r)); 662 | printf("\n"); 663 | } 664 | 665 | am_delsolver(solver); 666 | printf("allmem = %d\n", (int)allmem); 667 | 
printf("maxmem = %d\n", (int)maxmem); 668 | assert(allmem == 0); 669 | maxmem = 0; 670 | } 671 | 672 | void test_cycling() { 673 | am_Solver * solver = am_newsolver(NULL, NULL, NULL); 674 | 675 | am_Variable * va = am_newvariable(solver); 676 | am_Variable * vb = am_newvariable(solver); 677 | am_Variable * vc = am_newvariable(solver); 678 | am_Variable * vd = am_newvariable(solver); 679 | 680 | am_addedit(va, AM_STRONG); 681 | printf("after edit\n"); 682 | am_dumpsolver(solver); 683 | 684 | /* vb == va */ 685 | { 686 | am_Constraint * c = am_newconstraint(solver, AM_REQUIRED); 687 | int ret = 0; 688 | ret |= am_addterm(c, vb, 1.0); 689 | ret |= am_setrelation(c, AM_EQUAL); 690 | ret |= am_addterm(c, va, 1.0); 691 | ret |= am_add(c); 692 | assert(ret == AM_OK); 693 | am_dumpsolver(solver); 694 | } 695 | 696 | /* vb == vc */ 697 | { 698 | am_Constraint * c = am_newconstraint(solver, AM_REQUIRED); 699 | int ret = 0; 700 | ret |= am_addterm(c, vb, 1.0); 701 | ret |= am_setrelation(c, AM_EQUAL); 702 | ret |= am_addterm(c, vc, 1.0); 703 | ret |= am_add(c); 704 | assert(ret == AM_OK); 705 | am_dumpsolver(solver); 706 | } 707 | 708 | /* vc == vd */ 709 | { 710 | am_Constraint * c = am_newconstraint(solver, AM_REQUIRED); 711 | int ret = 0; 712 | ret |= am_addterm(c, vc, 1.0); 713 | ret |= am_setrelation(c, AM_EQUAL); 714 | ret |= am_addterm(c, vd, 1.0); 715 | ret |= am_add(c); 716 | assert(ret == AM_OK); 717 | am_dumpsolver(solver); 718 | } 719 | 720 | /* vd == va */ 721 | { 722 | am_Constraint * c = am_newconstraint(solver, AM_REQUIRED); 723 | int ret = 0; 724 | ret |= am_addterm(c, vd, 1.0); 725 | ret |= am_setrelation(c, AM_EQUAL); 726 | ret |= am_addterm(c, va, 1.0); 727 | ret |= am_add(c); 728 | assert(ret == AM_OK); /* asserts here */ 729 | am_dumpsolver(solver); 730 | } 731 | } 732 | 733 | static am_Float stored_val = 0.0; 734 | 735 | void example_callback(am_Solver *solver, am_Variable *variable, am_Float new_value, am_Float old_value) { 736 | printf("value: %f -> 
%f\n", old_value, new_value); 737 | stored_val = new_value; 738 | } 739 | 740 | void test_callback() { 741 | am_Solver *solver = am_newsolver(NULL, NULL, example_callback); 742 | 743 | am_Variable *va = am_newvariable(solver); 744 | { 745 | am_Constraint * c = am_newconstraint(solver, AM_REQUIRED); 746 | int ret = 0; 747 | ret |= am_addterm(c, va, 1.0); 748 | ret |= am_setrelation(c, AM_EQUAL); 749 | ret |= am_addconstant(c, 4.0); 750 | ret |= am_add(c); 751 | assert(ret == AM_OK); 752 | assert(stored_val == 0.0); 753 | am_updatevars(solver); 754 | assert(stored_val == 4.0); 755 | } 756 | } 757 | 758 | int main(void) { 759 | test_binarytree(); 760 | test_unbounded(); 761 | test_strength(); 762 | test_suggest(); 763 | test_cycling(); 764 | test_callback(); 765 | test_all(); 766 | return 0; 767 | } 768 | 769 | /* cc: flags='-ggdb -Wall -fprofile-arcs -ftest-coverage -O0 -Wextra -pedantic -std=c89' */ 770 | -------------------------------------------------------------------------------- /amoeba.h: -------------------------------------------------------------------------------- 1 | #ifndef amoeba_h 2 | #define amoeba_h 3 | 4 | #ifndef AM_NS_BEGIN 5 | # ifdef __cplusplus 6 | # define AM_NS_BEGIN extern "C" { 7 | # define AM_NS_END } 8 | # else 9 | # define AM_NS_BEGIN 10 | # define AM_NS_END 11 | # endif 12 | #endif /* AM_NS_BEGIN */ 13 | 14 | #ifndef AM_STATIC 15 | # ifdef __GNUC__ 16 | # define AM_STATIC static __attribute((unused)) 17 | # else 18 | # define AM_STATIC static 19 | # endif 20 | #endif 21 | 22 | #ifdef AM_STATIC_API 23 | # ifndef AM_IMPLEMENTATION 24 | # define AM_IMPLEMENTATION 25 | # endif 26 | # define AM_API AM_STATIC 27 | #endif 28 | 29 | #if !defined(AM_API) && defined(_WIN32) 30 | # ifdef AM_IMPLEMENTATION 31 | # define AM_API __declspec(dllexport) 32 | # else 33 | # define AM_API __declspec(dllimport) 34 | # endif 35 | #endif 36 | 37 | #ifndef AM_API 38 | # define AM_API extern 39 | #endif 40 | 41 | #define AM_OK (0) 42 | #define AM_FAILED (-1) 
43 | #define AM_UNSATISFIED (-2) 44 | #define AM_UNBOUND (-3) 45 | 46 | #define AM_LESSEQUAL (1) 47 | #define AM_EQUAL (2) 48 | #define AM_GREATEQUAL (3) 49 | 50 | #define AM_REQUIRED ((am_Float)1000000000) 51 | #define AM_STRONG ((am_Float)1000000) 52 | #define AM_MEDIUM ((am_Float)1000) 53 | #define AM_WEAK ((am_Float)1) 54 | 55 | #include 56 | 57 | 58 | AM_NS_BEGIN 59 | 60 | #ifdef AM_USE_FLOAT 61 | typedef float am_Float; 62 | #else 63 | typedef double am_Float; 64 | #endif 65 | 66 | typedef struct am_Solver am_Solver; 67 | typedef struct am_Variable am_Variable; 68 | typedef struct am_Constraint am_Constraint; 69 | 70 | typedef void *am_Allocf (void *ud, void *ptr, size_t nsize, size_t osize); 71 | typedef void am_VarCallback (am_Solver *solver, am_Variable *var, am_Float old_value, am_Float new_value); 72 | 73 | AM_API am_Solver *am_newsolver (am_Allocf *allocf, void *ud, am_VarCallback *callback); 74 | AM_API void am_resetsolver (am_Solver *solver, int clear_constraints); 75 | AM_API void am_delsolver (am_Solver *solver); 76 | 77 | AM_API void am_updatevars(am_Solver *solver); 78 | AM_API void am_autoupdate(am_Solver *solver, int auto_update); 79 | 80 | AM_API int am_hasedit (am_Variable *var); 81 | AM_API int am_hasconstraint (am_Constraint *cons); 82 | 83 | AM_API int am_add (am_Constraint *cons); 84 | AM_API void am_remove (am_Constraint *cons); 85 | 86 | AM_API int am_addedit (am_Variable *var, am_Float strength); 87 | AM_API void am_suggest (am_Variable *var, am_Float value); 88 | AM_API void am_deledit (am_Variable *var); 89 | 90 | AM_API am_Variable *am_newvariable (am_Solver *solver); 91 | AM_API void am_usevariable (am_Variable *var); 92 | AM_API void am_delvariable (am_Variable *var); 93 | AM_API int am_variableid (am_Variable *var); 94 | AM_API am_Float am_value (am_Variable *var); 95 | 96 | AM_API am_Constraint *am_newconstraint (am_Solver *solver, am_Float strength); 97 | AM_API am_Constraint *am_cloneconstraint (am_Constraint *other, am_Float 
strength); 98 | 99 | AM_API void am_resetconstraint (am_Constraint *cons); 100 | AM_API void am_delconstraint (am_Constraint *cons); 101 | 102 | AM_API int am_addterm (am_Constraint *cons, am_Variable *var, am_Float multiplier); 103 | AM_API int am_setrelation (am_Constraint *cons, int relation); 104 | AM_API int am_addconstant (am_Constraint *cons, am_Float constant); 105 | AM_API int am_setstrength (am_Constraint *cons, am_Float strength); 106 | 107 | AM_API int am_mergeconstraint (am_Constraint *cons, am_Constraint *other, am_Float multiplier); 108 | 109 | AM_NS_END 110 | 111 | 112 | #endif /* amoeba_h */ 113 | 114 | 115 | #if defined(AM_IMPLEMENTATION) && !defined(am_implemented) 116 | #define am_implemented 117 | 118 | 119 | #include 120 | #include 121 | #include 122 | #include 123 | 124 | #define AM_EXTERNAL (0) 125 | #define AM_SLACK (1) 126 | #define AM_ERROR (2) 127 | #define AM_DUMMY (3) 128 | 129 | #define am_isexternal(key) ((key).type == AM_EXTERNAL) 130 | #define am_isslack(key) ((key).type == AM_SLACK) 131 | #define am_iserror(key) ((key).type == AM_ERROR) 132 | #define am_isdummy(key) ((key).type == AM_DUMMY) 133 | #define am_ispivotable(key) (am_isslack(key) || am_iserror(key)) 134 | 135 | #define AM_POOLSIZE 4096 136 | #define AM_MIN_HASHSIZE 4 137 | #define AM_MAX_SIZET ((~(size_t)0)-100) 138 | 139 | #ifdef AM_USE_FLOAT 140 | # define AM_FLOAT_MAX FLT_MAX 141 | # define AM_FLOAT_EPS 1e-4f 142 | #else 143 | # define AM_FLOAT_MAX DBL_MAX 144 | # define AM_FLOAT_EPS 1e-6 145 | #endif 146 | 147 | AM_NS_BEGIN 148 | 149 | typedef struct am_Symbol { 150 | unsigned id : 30; 151 | unsigned type : 2; 152 | } am_Symbol; 153 | 154 | typedef struct am_MemPool { 155 | size_t size; 156 | void *freed; 157 | void *pages; 158 | } am_MemPool; 159 | 160 | typedef struct am_Entry { 161 | int next; 162 | am_Symbol key; 163 | } am_Entry; 164 | 165 | typedef struct am_Table { 166 | size_t size; 167 | size_t count; 168 | size_t entry_size; 169 | size_t lastfree; 170 | 
am_Entry *hash; 171 | } am_Table; 172 | 173 | typedef struct am_VarEntry { 174 | am_Entry entry; 175 | am_Variable *variable; 176 | } am_VarEntry; 177 | 178 | typedef struct am_ConsEntry { 179 | am_Entry entry; 180 | am_Constraint *constraint; 181 | } am_ConsEntry; 182 | 183 | typedef struct am_Term { 184 | am_Entry entry; 185 | am_Float multiplier; 186 | } am_Term; 187 | 188 | typedef struct am_Row { 189 | am_Entry entry; 190 | am_Symbol infeasible_next; 191 | am_Table terms; 192 | am_Float constant; 193 | } am_Row; 194 | 195 | struct am_Variable { 196 | am_Symbol sym; 197 | am_Symbol dirty_next; 198 | unsigned refcount; 199 | am_Solver *solver; 200 | am_Constraint *constraint; 201 | am_Float edit_value; 202 | am_Float value; 203 | }; 204 | 205 | struct am_Constraint { 206 | am_Row expression; 207 | am_Symbol marker; 208 | am_Symbol other; 209 | int relation; 210 | am_Solver *solver; 211 | am_Float strength; 212 | }; 213 | 214 | struct am_Solver { 215 | am_Allocf *allocf; 216 | void *ud; 217 | am_Row objective; 218 | am_Table vars; /* symbol -> VarEntry */ 219 | am_Table constraints; /* symbol -> ConsEntry */ 220 | am_Table rows; /* symbol -> Row */ 221 | am_MemPool varpool; 222 | am_MemPool conspool; 223 | unsigned symbol_count; 224 | unsigned constraint_count; 225 | unsigned auto_update; 226 | am_Symbol infeasible_rows; 227 | am_Symbol dirty_vars; 228 | am_VarCallback *var_callback; 229 | }; 230 | 231 | 232 | /* utils */ 233 | 234 | static am_Symbol am_newsymbol(am_Solver *solver, int type); 235 | 236 | static int am_approx(am_Float a, am_Float b) 237 | { return a > b ? 
a - b < AM_FLOAT_EPS : b - a < AM_FLOAT_EPS; } 238 | 239 | static int am_nearzero(am_Float a) 240 | { return am_approx(a, 0.0f); } 241 | 242 | static am_Symbol am_null() 243 | { am_Symbol null = { 0, 0 }; return null; } 244 | 245 | static void am_initsymbol(am_Solver *solver, am_Symbol *sym, int type) 246 | { if (sym->id == 0) *sym = am_newsymbol(solver, type); } 247 | 248 | static void am_initpool(am_MemPool *pool, size_t size) { 249 | pool->size = size; 250 | pool->freed = pool->pages = NULL; 251 | assert(size > sizeof(void*) && size < AM_POOLSIZE/4); 252 | } 253 | 254 | static void am_freepool(am_Solver *solver, am_MemPool *pool) { 255 | const size_t offset = AM_POOLSIZE - sizeof(void*); 256 | while (pool->pages != NULL) { 257 | void *next = *(void**)((char*)pool->pages + offset); 258 | solver->allocf(solver->ud, pool->pages, 0, AM_POOLSIZE); 259 | pool->pages = next; 260 | } 261 | am_initpool(pool, pool->size); 262 | } 263 | 264 | static void *am_alloc(am_Solver *solver, am_MemPool *pool) { 265 | void *obj = pool->freed; 266 | if (obj == NULL) { 267 | const size_t offset = AM_POOLSIZE - sizeof(void*); 268 | void *end, *newpage = solver->allocf(solver->ud, NULL, AM_POOLSIZE, 0); 269 | *(void**)((char*)newpage + offset) = pool->pages; 270 | pool->pages = newpage; 271 | end = (char*)newpage + (offset/pool->size-1)*pool->size; 272 | while (end != newpage) { 273 | *(void**)end = pool->freed; 274 | pool->freed = (void**)end; 275 | end = (char*)end - pool->size; 276 | } 277 | return end; 278 | } 279 | pool->freed = *(void**)obj; 280 | return obj; 281 | } 282 | 283 | static void am_free(am_MemPool *pool, void *obj) { 284 | *(void**)obj = pool->freed; 285 | pool->freed = obj; 286 | } 287 | 288 | static am_Symbol am_newsymbol(am_Solver *solver, int type) { 289 | am_Symbol sym; 290 | unsigned id = ++solver->symbol_count; 291 | if (id > 0x3FFFFFFF) id = solver->symbol_count = 1; 292 | assert(type >= AM_EXTERNAL && type <= AM_DUMMY); 293 | sym.id = id; 294 | sym.type = 
type; 295 | return sym; 296 | } 297 | 298 | 299 | /* hash table */ 300 | 301 | #define am_key(entry) (((am_Entry*)(entry))->key) 302 | 303 | #define am_offset(lhs, rhs) ((int)((char*)(lhs) - (char*)(rhs))) 304 | #define am_index(h, i) ((am_Entry*)((char*)(h) + (i))) 305 | 306 | static am_Entry *am_newkey(am_Solver *solver, am_Table *t, am_Symbol key); 307 | 308 | static void am_delkey(am_Table *t, am_Entry *entry) 309 | { entry->key = am_null(), --t->count; } 310 | 311 | static void am_inittable(am_Table *t, size_t entry_size) 312 | { memset(t, 0, sizeof(*t)), t->entry_size = entry_size; } 313 | 314 | static am_Entry *am_mainposition(const am_Table *t, am_Symbol key) 315 | { return am_index(t->hash, (key.id & (t->size - 1))*t->entry_size); } 316 | 317 | static void am_resettable(am_Table *t) 318 | { t->count = 0; memset(t->hash, 0, t->lastfree = t->size * t->entry_size); } 319 | 320 | static size_t am_hashsize(am_Table *t, size_t len) { 321 | size_t newsize = AM_MIN_HASHSIZE; 322 | const size_t max_size = (AM_MAX_SIZET / 2) / t->entry_size; 323 | while (newsize < max_size && newsize < len) 324 | newsize <<= 1; 325 | assert((newsize & (newsize - 1)) == 0); 326 | return newsize < len ? 
0 : newsize; 327 | } 328 | 329 | static void am_freetable(am_Solver *solver, am_Table *t) { 330 | size_t size = t->size*t->entry_size; 331 | if (size) solver->allocf(solver->ud, t->hash, 0, size); 332 | am_inittable(t, t->entry_size); 333 | } 334 | 335 | static size_t am_resizetable(am_Solver *solver, am_Table *t, size_t len) { 336 | size_t i, oldsize = t->size * t->entry_size; 337 | am_Table nt = *t; 338 | nt.size = am_hashsize(t, len); 339 | nt.lastfree = nt.size*nt.entry_size; 340 | nt.hash = (am_Entry*)solver->allocf(solver->ud, NULL, nt.lastfree, 0); 341 | memset(nt.hash, 0, nt.size*nt.entry_size); 342 | for (i = 0; i < oldsize; i += nt.entry_size) { 343 | am_Entry *e = am_index(t->hash, i); 344 | if (e->key.id != 0) { 345 | am_Entry *ne = am_newkey(solver, &nt, e->key); 346 | if (t->entry_size > sizeof(am_Entry)) 347 | memcpy(ne + 1, e + 1, t->entry_size-sizeof(am_Entry)); 348 | } 349 | } 350 | if (oldsize) solver->allocf(solver->ud, t->hash, 0, oldsize); 351 | *t = nt; 352 | return t->size; 353 | } 354 | 355 | static am_Entry *am_newkey(am_Solver *solver, am_Table *t, am_Symbol key) { 356 | if (t->size == 0) am_resizetable(solver, t, AM_MIN_HASHSIZE); 357 | for (;;) { 358 | am_Entry *mp = am_mainposition(t, key); 359 | if (mp->key.id != 0) { 360 | am_Entry *f = NULL, *othern; 361 | while (t->lastfree > 0) { 362 | am_Entry *e = am_index(t->hash, t->lastfree -= t->entry_size); 363 | if (e->key.id == 0 && e->next == 0) { f = e; break; } 364 | } 365 | if (!f) { am_resizetable(solver, t, t->count*2); continue; } 366 | assert(f->key.id == 0); 367 | othern = am_mainposition(t, mp->key); 368 | if (othern != mp) { 369 | am_Entry *next; 370 | while ((next = am_index(othern, othern->next)) != mp) 371 | othern = next; 372 | othern->next = am_offset(f, othern); 373 | memcpy(f, mp, t->entry_size); 374 | if (mp->next) f->next += am_offset(mp, f), mp->next = 0; 375 | } 376 | else { 377 | if (mp->next != 0) f->next = am_offset(mp, f) + mp->next; 378 | else assert(f->next == 
0); 379 | mp->next = am_offset(f, mp), mp = f; 380 | } 381 | } 382 | mp->key = key; 383 | return mp; 384 | } 385 | } 386 | 387 | static const am_Entry *am_gettable(const am_Table *t, am_Symbol key) { 388 | const am_Entry *e; 389 | if (t->size == 0 || key.id == 0) return NULL; 390 | e = am_mainposition(t, key); 391 | for (; e->key.id != key.id; e = am_index(e, e->next)) 392 | if (e->next == 0) return NULL; 393 | return e; 394 | } 395 | 396 | static am_Entry *am_settable(am_Solver *solver, am_Table *t, am_Symbol key) { 397 | am_Entry *e; 398 | assert(key.id != 0); 399 | if ((e = (am_Entry*)am_gettable(t, key)) != NULL) return e; 400 | e = am_newkey(solver, t, key); 401 | if (t->entry_size > sizeof(am_Entry)) 402 | memset(e + 1, 0, t->entry_size-sizeof(am_Entry)); 403 | ++t->count; 404 | return e; 405 | } 406 | 407 | static int am_nextentry(const am_Table *t, am_Entry **pentry) { 408 | size_t i = *pentry ? am_offset(*pentry, t->hash) + t->entry_size : 0; 409 | size_t size = t->size*t->entry_size; 410 | for (; i < size; i += t->entry_size) { 411 | am_Entry *e = am_index(t->hash, i); 412 | if (e->key.id != 0) { *pentry = e; return 1; } 413 | } 414 | *pentry = NULL; 415 | return 0; 416 | } 417 | 418 | 419 | /* expression (row) */ 420 | 421 | static int am_isconstant(am_Row *row) 422 | { return row->terms.count == 0; } 423 | 424 | static void am_freerow(am_Solver *solver, am_Row *row) 425 | { am_freetable(solver, &row->terms); } 426 | 427 | static void am_resetrow(am_Row *row) 428 | { row->constant = 0.0f; am_resettable(&row->terms); } 429 | 430 | static void am_initrow(am_Row *row) { 431 | am_key(row) = am_null(); 432 | row->infeasible_next = am_null(); 433 | row->constant = 0.0f; 434 | am_inittable(&row->terms, sizeof(am_Term)); 435 | } 436 | 437 | static void am_multiply(am_Row *row, am_Float multiplier) { 438 | am_Term *term = NULL; 439 | row->constant *= multiplier; 440 | while (am_nextentry(&row->terms, (am_Entry**)&term)) 441 | term->multiplier *= multiplier; 442 | 
} 443 | 444 | static void am_addvar(am_Solver *solver, am_Row *row, am_Symbol sym, am_Float value) { 445 | am_Term *term; 446 | if (sym.id == 0) return; 447 | if ((term = (am_Term*)am_gettable(&row->terms, sym)) == NULL) 448 | term = (am_Term*)am_settable(solver, &row->terms, sym); 449 | if (am_nearzero(term->multiplier += value)) 450 | am_delkey(&row->terms, &term->entry); 451 | } 452 | 453 | static void am_addrow(am_Solver *solver, am_Row *row, const am_Row *other, am_Float multiplier) { 454 | am_Term *term = NULL; 455 | row->constant += other->constant*multiplier; 456 | while (am_nextentry(&other->terms, (am_Entry**)&term)) 457 | am_addvar(solver, row, am_key(term), term->multiplier*multiplier); 458 | } 459 | 460 | static void am_solvefor(am_Solver *solver, am_Row *row, am_Symbol entry, am_Symbol exit) { 461 | am_Term *term = (am_Term*)am_gettable(&row->terms, entry); 462 | am_Float reciprocal = 1.0f / term->multiplier; 463 | assert(entry.id != exit.id && !am_nearzero(term->multiplier)); 464 | am_delkey(&row->terms, &term->entry); 465 | am_multiply(row, -reciprocal); 466 | if (exit.id != 0) am_addvar(solver, row, exit, reciprocal); 467 | } 468 | 469 | static void am_substitute(am_Solver *solver, am_Row *row, am_Symbol entry, const am_Row *other) { 470 | am_Term *term = (am_Term*)am_gettable(&row->terms, entry); 471 | if (!term) return; 472 | am_delkey(&row->terms, &term->entry); 473 | am_addrow(solver, row, other, term->multiplier); 474 | } 475 | 476 | 477 | /* variables & constraints */ 478 | 479 | AM_API int am_variableid(am_Variable *var) { return var ? var->sym.id : -1; } 480 | AM_API am_Float am_value(am_Variable *var) { return var ? 
var->value : 0.0f; } 481 | AM_API void am_usevariable(am_Variable *var) { if (var) ++var->refcount; } 482 | 483 | static am_Variable *am_sym2var(am_Solver *solver, am_Symbol sym) { 484 | am_VarEntry *ve = (am_VarEntry*)am_gettable(&solver->vars, sym); 485 | assert(ve != NULL); 486 | return ve->variable; 487 | } 488 | 489 | AM_API am_Variable *am_newvariable(am_Solver *solver) { 490 | am_Variable *var = (am_Variable*)am_alloc(solver, &solver->varpool); 491 | am_Symbol sym = am_newsymbol(solver, AM_EXTERNAL); 492 | am_VarEntry *ve = (am_VarEntry*)am_settable(solver, &solver->vars, sym); 493 | assert(ve->variable == NULL); 494 | memset(var, 0, sizeof(*var)); 495 | var->sym = sym; 496 | var->refcount = 1; 497 | var->solver = solver; 498 | ve->variable = var; 499 | return var; 500 | } 501 | 502 | AM_API void am_delvariable(am_Variable *var) { 503 | if (var && --var->refcount <= 0) { 504 | am_Solver *solver = var->solver; 505 | am_VarEntry *e = (am_VarEntry*)am_gettable(&solver->vars, var->sym); 506 | assert(e != NULL); 507 | am_delkey(&solver->vars, &e->entry); 508 | am_remove(var->constraint); 509 | am_free(&solver->varpool, var); 510 | } 511 | } 512 | 513 | AM_API am_Constraint *am_newconstraint(am_Solver *solver, am_Float strength) { 514 | am_Constraint *cons = (am_Constraint*)am_alloc(solver, &solver->conspool); 515 | memset(cons, 0, sizeof(*cons)); 516 | cons->solver = solver; 517 | cons->strength = am_nearzero(strength) ? AM_REQUIRED : strength; 518 | am_initrow(&cons->expression); 519 | am_key(cons).id = ++solver->constraint_count; 520 | am_key(cons).type = AM_EXTERNAL; 521 | ((am_ConsEntry*)am_settable(solver, &solver->constraints, 522 | am_key(cons)))->constraint = cons; 523 | return cons; 524 | } 525 | 526 | AM_API void am_delconstraint(am_Constraint *cons) { 527 | am_Solver *solver = cons ? 
cons->solver : NULL; 528 | am_Term *term = NULL; 529 | am_ConsEntry *ce; 530 | if (cons == NULL) return; 531 | am_remove(cons); 532 | ce = (am_ConsEntry*)am_gettable(&solver->constraints, am_key(cons)); 533 | assert(ce != NULL); 534 | am_delkey(&solver->constraints, &ce->entry); 535 | while (am_nextentry(&cons->expression.terms, (am_Entry**)&term)) 536 | am_delvariable(am_sym2var(solver, am_key(term))); 537 | am_freerow(solver, &cons->expression); 538 | am_free(&solver->conspool, cons); 539 | } 540 | 541 | AM_API am_Constraint *am_cloneconstraint(am_Constraint *other, am_Float strength) { 542 | am_Constraint *cons; 543 | if (other == NULL) return NULL; 544 | cons = am_newconstraint(other->solver, 545 | am_nearzero(strength) ? other->strength : strength); 546 | am_mergeconstraint(cons, other, 1.0f); 547 | cons->relation = other->relation; 548 | return cons; 549 | } 550 | 551 | AM_API int am_mergeconstraint(am_Constraint *cons, am_Constraint *other, am_Float multiplier) { 552 | am_Term *term = NULL; 553 | if (cons == NULL || other == NULL || cons->marker.id != 0 554 | || cons->solver != other->solver) return AM_FAILED; 555 | if (cons->relation == AM_GREATEQUAL) multiplier = -multiplier; 556 | cons->expression.constant += other->expression.constant*multiplier; 557 | while (am_nextentry(&other->expression.terms, (am_Entry**)&term)) { 558 | am_usevariable(am_sym2var(cons->solver, am_key(term))); 559 | am_addvar(cons->solver, &cons->expression, am_key(term), 560 | term->multiplier*multiplier); 561 | } 562 | return AM_OK; 563 | } 564 | 565 | AM_API void am_resetconstraint(am_Constraint *cons) { 566 | am_Term *term = NULL; 567 | if (cons == NULL) return; 568 | am_remove(cons); 569 | cons->relation = 0; 570 | while (am_nextentry(&cons->expression.terms, (am_Entry**)&term)) 571 | am_delvariable(am_sym2var(cons->solver, am_key(term))); 572 | am_resetrow(&cons->expression); 573 | } 574 | 575 | AM_API int am_addterm(am_Constraint *cons, am_Variable *var, am_Float multiplier) { 
576 | if (cons == NULL || var == NULL || cons->marker.id != 0 || 577 | cons->solver != var->solver) return AM_FAILED; 578 | assert(var->sym.id != 0); 579 | assert(var->solver == cons->solver); 580 | if (cons->relation == AM_GREATEQUAL) multiplier = -multiplier; 581 | am_addvar(cons->solver, &cons->expression, var->sym, multiplier); 582 | am_usevariable(var); 583 | return AM_OK; 584 | } 585 | 586 | AM_API int am_addconstant(am_Constraint *cons, am_Float constant) { 587 | if (cons == NULL || cons->marker.id != 0) return AM_FAILED; 588 | if (cons->relation == AM_GREATEQUAL) 589 | cons->expression.constant -= constant; 590 | else 591 | cons->expression.constant += constant; 592 | return AM_OK; 593 | } 594 | 595 | AM_API int am_setrelation(am_Constraint *cons, int relation) { 596 | assert(relation >= AM_LESSEQUAL && relation <= AM_GREATEQUAL); 597 | if (cons == NULL || cons->marker.id != 0 || cons->relation != 0) 598 | return AM_FAILED; 599 | if (relation != AM_GREATEQUAL) am_multiply(&cons->expression, -1.0f); 600 | cons->relation = relation; 601 | return AM_OK; 602 | } 603 | 604 | 605 | /* Cassowary algorithm */ 606 | 607 | AM_API int am_hasedit(am_Variable *var) 608 | { return var != NULL && var->constraint != NULL; } 609 | 610 | AM_API int am_hasconstraint(am_Constraint *cons) 611 | { return cons != NULL && cons->marker.id != 0; } 612 | 613 | AM_API void am_autoupdate(am_Solver *solver, int auto_update) 614 | { solver->auto_update = auto_update; } 615 | 616 | static void am_infeasible(am_Solver *solver, am_Row *row) { 617 | if (am_isdummy(row->infeasible_next)) return; 618 | row->infeasible_next.id = solver->infeasible_rows.id; 619 | row->infeasible_next.type = AM_DUMMY; 620 | solver->infeasible_rows = am_key(row); 621 | } 622 | 623 | static void am_markdirty(am_Solver *solver, am_Variable *var) { 624 | if (var->dirty_next.type == AM_DUMMY) return; 625 | var->dirty_next.id = solver->dirty_vars.id; 626 | var->dirty_next.type = AM_DUMMY; 627 | solver->dirty_vars = 
var->sym; 628 | } 629 | 630 | static void am_substitute_rows(am_Solver *solver, am_Symbol var, am_Row *expr) { 631 | am_Row *row = NULL; 632 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) { 633 | am_substitute(solver, row, var, expr); 634 | if (am_isexternal(am_key(row))) 635 | am_markdirty(solver, am_sym2var(solver, am_key(row))); 636 | else if (row->constant < 0.0f) 637 | am_infeasible(solver, row); 638 | } 639 | am_substitute(solver, &solver->objective, var, expr); 640 | } 641 | 642 | static int am_getrow(am_Solver *solver, am_Symbol sym, am_Row *dst) { 643 | am_Row *row = (am_Row*)am_gettable(&solver->rows, sym); 644 | am_key(dst) = am_null(); 645 | if (row == NULL) return AM_FAILED; 646 | am_delkey(&solver->rows, &row->entry); 647 | dst->constant = row->constant; 648 | dst->terms = row->terms; 649 | return AM_OK; 650 | } 651 | 652 | static int am_putrow(am_Solver *solver, am_Symbol sym, const am_Row *src) { 653 | am_Row *row = (am_Row*)am_settable(solver, &solver->rows, sym); 654 | row->constant = src->constant; 655 | row->terms = src->terms; 656 | return AM_OK; 657 | } 658 | 659 | static void am_mergerow(am_Solver *solver, am_Row *row, am_Symbol var, am_Float multiplier) { 660 | am_Row *oldrow = (am_Row*)am_gettable(&solver->rows, var); 661 | if (oldrow) am_addrow(solver, row, oldrow, multiplier); 662 | else am_addvar(solver, row, var, multiplier); 663 | } 664 | 665 | static int am_optimize(am_Solver *solver, am_Row *objective) { 666 | for (;;) { 667 | am_Symbol enter = am_null(), exit = am_null(); 668 | am_Float r, min_ratio = AM_FLOAT_MAX; 669 | am_Row tmp, *row = NULL; 670 | am_Term *term = NULL; 671 | 672 | assert(solver->infeasible_rows.id == 0); 673 | while (am_nextentry(&objective->terms, (am_Entry**)&term)) { 674 | if (!am_isdummy(am_key(term)) && term->multiplier < 0.0f) 675 | { enter = am_key(term); break; } 676 | } 677 | if (enter.id == 0) return AM_OK; 678 | 679 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) { 680 | term = 
(am_Term*)am_gettable(&row->terms, enter); 681 | if (term == NULL || !am_ispivotable(am_key(row)) 682 | || term->multiplier > 0.0f) continue; 683 | r = -row->constant / term->multiplier; 684 | if (r < min_ratio || (am_approx(r, min_ratio) 685 | && am_key(row).id < exit.id)) 686 | min_ratio = r, exit = am_key(row); 687 | } 688 | assert(exit.id != 0); 689 | if (exit.id == 0) return AM_FAILED; 690 | 691 | am_getrow(solver, exit, &tmp); 692 | am_solvefor(solver, &tmp, enter, exit); 693 | am_substitute_rows(solver, enter, &tmp); 694 | if (objective != &solver->objective) 695 | am_substitute(solver, objective, enter, &tmp); 696 | am_putrow(solver, enter, &tmp); 697 | } 698 | } 699 | 700 | static am_Row am_makerow(am_Solver *solver, am_Constraint *cons) { 701 | am_Term *term = NULL; 702 | am_Row row; 703 | am_initrow(&row); 704 | row.constant = cons->expression.constant; 705 | while (am_nextentry(&cons->expression.terms, (am_Entry**)&term)) { 706 | am_markdirty(solver, am_sym2var(solver, am_key(term))); 707 | am_mergerow(solver, &row, am_key(term), term->multiplier); 708 | } 709 | if (cons->relation != AM_EQUAL) { 710 | am_initsymbol(solver, &cons->marker, AM_SLACK); 711 | am_addvar(solver, &row, cons->marker, -1.0f); 712 | if (cons->strength < AM_REQUIRED) { 713 | am_initsymbol(solver, &cons->other, AM_ERROR); 714 | am_addvar(solver, &row, cons->other, 1.0f); 715 | am_addvar(solver, &solver->objective, cons->other, cons->strength); 716 | } 717 | } 718 | else if (cons->strength >= AM_REQUIRED) { 719 | am_initsymbol(solver, &cons->marker, AM_DUMMY); 720 | am_addvar(solver, &row, cons->marker, 1.0f); 721 | } 722 | else { 723 | am_initsymbol(solver, &cons->marker, AM_ERROR); 724 | am_initsymbol(solver, &cons->other, AM_ERROR); 725 | am_addvar(solver, &row, cons->marker, -1.0f); 726 | am_addvar(solver, &row, cons->other, 1.0f); 727 | am_addvar(solver, &solver->objective, cons->marker, cons->strength); 728 | am_addvar(solver, &solver->objective, cons->other, cons->strength); 
729 | } 730 | if (row.constant < 0.0f) am_multiply(&row, -1.0f); 731 | return row; 732 | } 733 | 734 | static void am_remove_errors(am_Solver *solver, am_Constraint *cons) { 735 | if (am_iserror(cons->marker)) 736 | am_mergerow(solver, &solver->objective, cons->marker, -cons->strength); 737 | if (am_iserror(cons->other)) 738 | am_mergerow(solver, &solver->objective, cons->other, -cons->strength); 739 | if (am_isconstant(&solver->objective)) 740 | solver->objective.constant = 0.0f; 741 | cons->marker = cons->other = am_null(); 742 | } 743 | 744 | static int am_add_with_artificial(am_Solver *solver, am_Row *row, am_Constraint *cons) { 745 | am_Symbol a = am_newsymbol(solver, AM_SLACK); 746 | am_Term *term = NULL; 747 | am_Row tmp; 748 | int ret; 749 | --solver->symbol_count; /* artificial variable will be removed */ 750 | am_initrow(&tmp); 751 | am_addrow(solver, &tmp, row, 1.0f); 752 | am_putrow(solver, a, row); 753 | am_initrow(row), row = NULL; /* row is useless */ 754 | am_optimize(solver, &tmp); 755 | ret = am_nearzero(tmp.constant) ? 
AM_OK : AM_UNBOUND; 756 | am_freerow(solver, &tmp); 757 | if (am_getrow(solver, a, &tmp) == AM_OK) { 758 | am_Symbol entry = am_null(); 759 | if (am_isconstant(&tmp)) { am_freerow(solver, &tmp); return ret; } 760 | while (am_nextentry(&tmp.terms, (am_Entry**)&term)) 761 | if (am_ispivotable(am_key(term))) { entry = am_key(term); break; } 762 | if (entry.id == 0) { am_freerow(solver, &tmp); return AM_UNBOUND; } 763 | am_solvefor(solver, &tmp, entry, a); 764 | am_substitute_rows(solver, entry, &tmp); 765 | am_putrow(solver, entry, &tmp); 766 | } 767 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) { 768 | term = (am_Term*)am_gettable(&row->terms, a); 769 | if (term) am_delkey(&row->terms, &term->entry); 770 | } 771 | term = (am_Term*)am_gettable(&solver->objective.terms, a); 772 | if (term) am_delkey(&solver->objective.terms, &term->entry); 773 | if (ret != AM_OK) am_remove(cons); 774 | return ret; 775 | } 776 | 777 | static int am_try_addrow(am_Solver *solver, am_Row *row, am_Constraint *cons) { 778 | am_Symbol subject = am_null(); 779 | am_Term *term = NULL; 780 | while (am_nextentry(&row->terms, (am_Entry**)&term)) 781 | if (am_isexternal(am_key(term))) { subject = am_key(term); break; } 782 | if (subject.id == 0 && am_ispivotable(cons->marker)) { 783 | am_Term *mterm = (am_Term*)am_gettable(&row->terms, cons->marker); 784 | if (mterm->multiplier < 0.0f) subject = cons->marker; 785 | } 786 | if (subject.id == 0 && am_ispivotable(cons->other)) { 787 | am_Term *mterm = (am_Term*)am_gettable(&row->terms, cons->other); 788 | if (mterm->multiplier < 0.0f) subject = cons->other; 789 | } 790 | if (subject.id == 0) { 791 | while (am_nextentry(&row->terms, (am_Entry**)&term)) 792 | if (!am_isdummy(am_key(term))) break; 793 | if (term == NULL) { 794 | if (am_nearzero(row->constant)) 795 | subject = cons->marker; 796 | else { 797 | am_freerow(solver, row); 798 | return AM_UNSATISFIED; 799 | } 800 | } 801 | } 802 | if (subject.id == 0) 803 | return 
am_add_with_artificial(solver, row, cons); 804 | am_solvefor(solver, row, subject, am_null()); 805 | am_substitute_rows(solver, subject, row); 806 | am_putrow(solver, subject, row); 807 | return AM_OK; 808 | } 809 | 810 | static am_Symbol am_get_leaving_row(am_Solver *solver, am_Symbol marker) { 811 | am_Symbol first = am_null(), second = am_null(), third = am_null(); 812 | am_Float r1 = AM_FLOAT_MAX, r2 = AM_FLOAT_MAX; 813 | am_Row *row = NULL; 814 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) { 815 | am_Term *term = (am_Term*)am_gettable(&row->terms, marker); 816 | if (term == NULL) continue; 817 | if (am_isexternal(am_key(row))) third = am_key(row); 818 | else if (term->multiplier < 0.0f) { 819 | am_Float r = -row->constant / term->multiplier; 820 | if (r < r1) r1 = r, first = am_key(row); 821 | } 822 | else { 823 | am_Float r = row->constant / term->multiplier; 824 | if (r < r2) r2 = r, second = am_key(row); 825 | } 826 | } 827 | return first.id ? first : second.id ? second : third; 828 | } 829 | 830 | static void am_delta_edit_constant(am_Solver *solver, am_Float delta, am_Constraint *cons) { 831 | am_Row *row; 832 | if ((row = (am_Row*)am_gettable(&solver->rows, cons->marker)) != NULL) 833 | { if ((row->constant -= delta) < 0.0f) am_infeasible(solver, row); return; } 834 | if ((row = (am_Row*)am_gettable(&solver->rows, cons->other)) != NULL) 835 | { if ((row->constant += delta) < 0.0f) am_infeasible(solver, row); return; } 836 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) { 837 | am_Term *term = (am_Term*)am_gettable(&row->terms, cons->marker); 838 | if (term == NULL) continue; 839 | row->constant += term->multiplier*delta; 840 | if (am_isexternal(am_key(row))) 841 | am_markdirty(solver, am_sym2var(solver, am_key(row))); 842 | else if (row->constant < 0.0f) 843 | am_infeasible(solver, row); 844 | } 845 | } 846 | 847 | static void am_dual_optimize(am_Solver *solver) { 848 | while (solver->infeasible_rows.id != 0) { 849 | am_Row tmp, *row = 
850 | (am_Row*)am_gettable(&solver->rows, solver->infeasible_rows); 851 | am_Symbol enter = am_null(), exit = am_key(row), curr; 852 | am_Term *objterm, *term = NULL; 853 | am_Float r, min_ratio = AM_FLOAT_MAX; 854 | solver->infeasible_rows = row->infeasible_next; 855 | row->infeasible_next = am_null(); 856 | if (row->constant >= 0.0f) continue; 857 | while (am_nextentry(&row->terms, (am_Entry**)&term)) { 858 | if (am_isdummy(curr = am_key(term)) || term->multiplier <= 0.0f) 859 | continue; 860 | objterm = (am_Term*)am_gettable(&solver->objective.terms, curr); 861 | r = objterm ? objterm->multiplier / term->multiplier : 0.0f; 862 | if (min_ratio > r) min_ratio = r, enter = curr; 863 | } 864 | assert(enter.id != 0); 865 | am_getrow(solver, exit, &tmp); 866 | am_solvefor(solver, &tmp, enter, exit); 867 | am_substitute_rows(solver, enter, &tmp); 868 | am_putrow(solver, enter, &tmp); 869 | } 870 | } 871 | 872 | static void *am_default_allocf(void *ud, void *ptr, size_t nsize, size_t osize) { 873 | void *newptr; 874 | (void)ud, (void)osize; 875 | if (nsize == 0) { free(ptr); return NULL; } 876 | newptr = realloc(ptr, nsize); 877 | if (newptr == NULL) abort(); 878 | return newptr; 879 | } 880 | 881 | AM_API am_Solver *am_newsolver(am_Allocf *allocf, void *ud, am_VarCallback *var_callback) { 882 | am_Solver *solver; 883 | if (allocf == NULL) allocf = am_default_allocf; 884 | if ((solver = (am_Solver*)allocf(ud, NULL, sizeof(am_Solver), 0)) == NULL) 885 | return NULL; 886 | memset(solver, 0, sizeof(*solver)); 887 | solver->allocf = allocf; 888 | solver->ud = ud; 889 | solver->var_callback = var_callback; 890 | am_initrow(&solver->objective); 891 | am_inittable(&solver->vars, sizeof(am_VarEntry)); 892 | am_inittable(&solver->constraints, sizeof(am_ConsEntry)); 893 | am_inittable(&solver->rows, sizeof(am_Row)); 894 | am_initpool(&solver->varpool, sizeof(am_Variable)); 895 | am_initpool(&solver->conspool, sizeof(am_Constraint)); 896 | return solver; 897 | } 898 | 899 | AM_API 
void am_delsolver(am_Solver *solver) { 900 | am_ConsEntry *ce = NULL; 901 | am_Row *row = NULL; 902 | while (am_nextentry(&solver->constraints, (am_Entry**)&ce)) 903 | am_freerow(solver, &ce->constraint->expression); 904 | while (am_nextentry(&solver->rows, (am_Entry**)&row)) 905 | am_freerow(solver, row); 906 | am_freerow(solver, &solver->objective); 907 | am_freetable(solver, &solver->vars); 908 | am_freetable(solver, &solver->constraints); 909 | am_freetable(solver, &solver->rows); 910 | am_freepool(solver, &solver->varpool); 911 | am_freepool(solver, &solver->conspool); 912 | solver->allocf(solver->ud, solver, 0, sizeof(*solver)); 913 | } 914 | 915 | AM_API void am_resetsolver(am_Solver *solver, int clear_constraints) { 916 | am_Entry *entry = NULL; 917 | if (!solver->auto_update) am_updatevars(solver); 918 | while (am_nextentry(&solver->vars, &entry)) { 919 | am_Constraint **cons = &((am_VarEntry*)entry)->variable->constraint; 920 | am_remove(*cons); 921 | *cons = NULL; 922 | } 923 | assert(am_nearzero(solver->objective.constant)); 924 | assert(solver->infeasible_rows.id == 0); 925 | assert(solver->dirty_vars.id == 0); 926 | if (!clear_constraints) return; 927 | am_resetrow(&solver->objective); 928 | while (am_nextentry(&solver->constraints, &entry)) { 929 | am_Constraint *cons = ((am_ConsEntry*)entry)->constraint; 930 | if (cons->marker.id == 0) continue; 931 | cons->marker = cons->other = am_null(); 932 | } 933 | while (am_nextentry(&solver->rows, &entry)) { 934 | am_delkey(&solver->rows, entry); 935 | am_freerow(solver, (am_Row*)entry); 936 | } 937 | } 938 | 939 | AM_API void am_updatevars(am_Solver *solver) { 940 | while (solver->dirty_vars.id != 0) { 941 | am_Variable *var = am_sym2var(solver, solver->dirty_vars); 942 | am_Row *row = (am_Row*)am_gettable(&solver->rows, var->sym); 943 | solver->dirty_vars = var->dirty_next; 944 | var->dirty_next = am_null(); 945 | am_Float old_value = var->value; 946 | am_Float new_value = row ? 
row->constant : 0.0f; 947 | var->value = new_value; 948 | if (solver->var_callback != NULL) { 949 | solver->var_callback(solver, var, new_value, old_value); 950 | } 951 | } 952 | } 953 | 954 | AM_API int am_add(am_Constraint *cons) { 955 | am_Solver *solver = cons ? cons->solver : NULL; 956 | int ret, oldsym = solver ? solver->symbol_count : 0; 957 | am_Row row; 958 | if (solver == NULL || cons->marker.id != 0) return AM_FAILED; 959 | row = am_makerow(solver, cons); 960 | if ((ret = am_try_addrow(solver, &row, cons)) != AM_OK) { 961 | am_remove_errors(solver, cons); 962 | solver->symbol_count = oldsym; 963 | } 964 | else { 965 | am_optimize(solver, &solver->objective); 966 | if (solver->auto_update) am_updatevars(solver); 967 | } 968 | return ret; 969 | } 970 | 971 | AM_API void am_remove(am_Constraint *cons) { 972 | am_Solver *solver; 973 | am_Symbol marker; 974 | am_Row tmp; 975 | if (cons == NULL || cons->marker.id == 0) return; 976 | solver = cons->solver, marker = cons->marker; 977 | am_remove_errors(solver, cons); 978 | if (am_getrow(solver, marker, &tmp) != AM_OK) { 979 | am_Symbol exit = am_get_leaving_row(solver, marker); 980 | assert(exit.id != 0); 981 | am_getrow(solver, exit, &tmp); 982 | am_solvefor(solver, &tmp, marker, exit); 983 | am_substitute_rows(solver, marker, &tmp); 984 | } 985 | am_freerow(solver, &tmp); 986 | am_optimize(solver, &solver->objective); 987 | if (solver->auto_update) am_updatevars(solver); 988 | } 989 | 990 | AM_API int am_setstrength(am_Constraint *cons, am_Float strength) { 991 | if (cons == NULL) return AM_FAILED; 992 | strength = am_nearzero(strength) ? 
AM_REQUIRED : strength; 993 | if (cons->strength == strength) return AM_OK; 994 | if (cons->strength >= AM_REQUIRED || strength >= AM_REQUIRED) 995 | { am_remove(cons), cons->strength = strength; return am_add(cons); } 996 | if (cons->marker.id != 0) { 997 | am_Solver *solver = cons->solver; 998 | am_Float diff = strength - cons->strength; 999 | am_mergerow(solver, &solver->objective, cons->marker, diff); 1000 | am_mergerow(solver, &solver->objective, cons->other, diff); 1001 | am_optimize(solver, &solver->objective); 1002 | if (solver->auto_update) am_updatevars(solver); 1003 | } 1004 | cons->strength = strength; 1005 | return AM_OK; 1006 | } 1007 | 1008 | AM_API int am_addedit(am_Variable *var, am_Float strength) { 1009 | am_Solver *solver = var ? var->solver : NULL; 1010 | am_Constraint *cons; 1011 | if (var == NULL || var->constraint != NULL) return AM_FAILED; 1012 | assert(var->sym.id != 0); 1013 | if (strength >= AM_STRONG) strength = AM_STRONG; 1014 | cons = am_newconstraint(solver, strength); 1015 | am_setrelation(cons, AM_EQUAL); 1016 | am_addterm(cons, var, 1.0f); /* var must have positive signture */ 1017 | am_addconstant(cons, -var->value); 1018 | if (am_add(cons) != AM_OK) assert(0); 1019 | var->constraint = cons; 1020 | var->edit_value = var->value; 1021 | return AM_OK; 1022 | } 1023 | 1024 | AM_API void am_deledit(am_Variable *var) { 1025 | if (var == NULL || var->constraint == NULL) return; 1026 | am_delconstraint(var->constraint); 1027 | var->constraint = NULL; 1028 | var->edit_value = 0.0f; 1029 | } 1030 | 1031 | AM_API void am_suggest(am_Variable *var, am_Float value) { 1032 | am_Solver *solver = var ? 
var->solver : NULL; 1033 | am_Float delta; 1034 | if (var == NULL) return; 1035 | if (var->constraint == NULL) { 1036 | am_addedit(var, AM_MEDIUM); 1037 | assert(var->constraint != NULL); 1038 | } 1039 | delta = value - var->edit_value; 1040 | var->edit_value = value; 1041 | am_delta_edit_constant(solver, delta, var->constraint); 1042 | am_dual_optimize(solver); 1043 | if (solver->auto_update) am_updatevars(solver); 1044 | } 1045 | 1046 | AM_NS_END 1047 | 1048 | 1049 | #endif /* AM_IMPLEMENTATION */ 1050 | 1051 | /* cc: flags+='-shared -O2 -DAM_IMPLEMENTATION -xc' 1052 | unixcc: output='amoeba.so' 1053 | win32cc: output='amoeba.dll' */ 1054 | 1055 | --------------------------------------------------------------------------------