├── mymod.lua
├── mymod_update.lua
├── reload.lua
└── test.lua

/mymod.lua:
--------------------------------------------------------------------------------
local mod = {}

local a = 1

local function foobar()
	return a
end

print("foobar:", foobar)

function mod.foo()
	return foobar
end

function mod.foo2()
	return foobar
end

function mod.foobar(x)
	a = x
end

local meta = {}

meta.__index = meta

function meta:show()
	print("OLD")
end

function mod.new()
	return setmetatable({}, meta)
end

return mod
--------------------------------------------------------------------------------

/mymod_update.lua:
--------------------------------------------------------------------------------
local debug = require "debug"

local mod = {}

local a

local function foobar()
	print "UPDATE"
	return a
end

print("update foobar:", foobar)

function mod.foo()
	return foobar()
end

function mod.foo2()
	return foobar
end

function mod.foobar(x)
	a = x
end

mod.getinfo = debug.getinfo

local meta = {}

meta.__index = meta

function meta:show()
	print("NEW")
end

function mod.new()
	return setmetatable({}, meta)
end

return mod
--------------------------------------------------------------------------------

/reload.lua:
--------------------------------------------------------------------------------
-- Hot reload: load the new version of a module inside a sandbox, match its
-- functions/tables against the already-loaded version, merge the new code into
-- the old objects and finally redirect every live reference from old functions
-- to new ones.
local reload = {}
local sandbox = {}

local table = table
local debug = debug

-- Iterate the dot-separated components of a dummy path ("a.b.c").
-- Any character except "." may appear in a component, so keys such as
-- "my_func" are kept whole (the old "[_%a]%w*" pattern split them at "_").
local function path_keys(path)
	return string.gmatch(path, "[^%.]+")
end

-- Index `root` along every component of `path`.
-- Returns the value reached and true, or nil and false when an intermediate
-- value is nil.
local function walk_path(root, path)
	local value = root
	for key in path_keys(path) do
		if value == nil then
			return nil, false
		end
		value = value[key]
	end
	return value, true
end

-- Split a module dummy key "[name].a.b" into "name" and ".a.b".
-- Dotted module names ("[pkg.sub].x") are accepted. Returns nil on a bad key.
local function split_module_key(key)
	return string.match(key, "^%[(.-)%](.*)$")
end

-- Render path components t[first..#t] for messages. Components may be
-- strings, numbers or booleans; table.concat alone rejects booleans.
local function path_string(t, first, sep)
	local parts = {}
	for i = first or 1, #t do
		parts[#parts + 1] = tostring(t[i])
	end
	return table.concat(parts, sep or ".")
end

do -- sandbox begin

	-- Find a loader for `name` (with reload.postfix appended, if set) through
	-- package.searchers; raise the collected searcher messages if none matches.
	local function findloader(name)
		if reload.postfix then
			name = name .. reload.postfix
		end
		local msg = {}
		for _, loader in ipairs(package.searchers) do
			local f, extra = loader(name)
			local t = type(f)
			if t == "function" then
				return f, extra
			elseif t == "string" then
				table.insert(msg, f)
			end
		end
		error(string.format("module '%s' not found:%s", name, table.concat(msg)))
	end

	-- Build a metamethod that rejects an operation with a readable message.
	-- (Using `error` itself as the metamethod raised "bad argument #2 to
	-- 'error'" for __newindex and a bare table for __pairs.)
	local function forbid(message)
		return function()
			error(message, 2)
		end
	end

	local function clear_table(t)
		for k in pairs(t) do
			t[k] = nil
		end
	end

	-- Environment given to a module chunk while it is being loaded.
	local global_mt = {
		__newindex = forbid("can't assign a global while reloading"),
		__pairs = forbid("can't iterate globals while reloading"),
		__metatable = "SANDBOX",
	}
	local _LOADED_DUMMY = {}  -- module name -> dummy returned by sandbox.require
	local _LOADED = {}        -- module name -> { module = value, loader = chunk }
	local weak = { __mode = "kv" }
	-- Both caches map a path string to its dummy and the dummy back to the path.
	local dummy_cache
	local dummy_module_cache

	local module_dummy_mt = {
		__metatable = "MODULE",
		__newindex = forbid("can't modify a module while reloading"),
		__pairs = forbid("can't iterate a module while reloading"),
		__tostring = function(self) return dummy_module_cache[self] end,
	}

	-- Return the unique module dummy for `key` ("[name]" or "[name].a.b").
	local function intern_module_dummy(key)
		local obj = dummy_module_cache[key]
		if obj == nil then
			obj = setmetatable({}, module_dummy_mt)
			dummy_module_cache[key] = obj
			dummy_module_cache[obj] = key
		end
		return obj
	end

	local function make_dummy_module(name)
		return intern_module_dummy("[" .. name .. "]")
	end

	-- Indexing a module dummy records the access as a longer path.
	function module_dummy_mt:__index(k)
		assert(type(k) == "string", "module field is not string")
		return intern_module_dummy(dummy_module_cache[self] .. "." .. k)
	end

	local function make_sandbox()
		return setmetatable({}, global_mt)
	end

	-- `require` as seen by a chunk running in the sandbox. Modules that are not
	-- being reloaded resolve to dummies; others are loaded with a sandboxed _ENV.
	function sandbox.require(name)
		assert(type(name) == "string")
		if _LOADED_DUMMY[name] then
			return _LOADED_DUMMY[name]
		end
		local loader, arg = findloader(name)
		-- Only swap the environment of chunks whose first upvalue is _ENV.
		local sandboxed = debug.getupvalue(loader, 1) == "_ENV"
		if sandboxed then
			debug.setupvalue(loader, 1, make_sandbox())
		end
		local ret = loader(name, arg) or true
		_LOADED[name] = { module = ret }
		if sandboxed then
			-- Cut the new code off from any environment until reload restores it.
			debug.setupvalue(loader, 1, nil)
			_LOADED[name].loader = loader
		end
		_LOADED_DUMMY[name] = make_dummy_module(name)
		return _LOADED_DUMMY[name]
	end

	local global_dummy_mt = {
		__metatable = "GLOBAL",
		__tostring = function(self) return dummy_cache[self] end,
		__newindex = forbid("can't modify a global while reloading"),
		__pairs = forbid("can't iterate a global while reloading"),
	}

	-- Return the unique global dummy for path `k` ("name" or "name.a.b").
	local function make_dummy(k)
		local obj = dummy_cache[k]
		if obj == nil then
			obj = setmetatable({}, global_dummy_mt)
			dummy_cache[obj] = k
			dummy_cache[k] = obj
		end
		return obj
	end

	function global_dummy_mt:__index(k)
		assert(type(k) == "string", "Global name must be a string")
		return make_dummy(dummy_cache[self] .. "." .. k)
	end

	local _inext = ipairs {}

	-- the base lib function never return objects out of sandbox
	local safe_function = {
		require = sandbox.require, -- sandbox require
		pairs = pairs,             -- allow pairs during require
		next = next,
		ipairs = ipairs,
		_inext = _inext,
		print = print,             -- for debug
	}

	-- Safe functions handed to the sandbox, mapped to the value they stand for
	-- after reloading (the sandbox require becomes the real one).
	local safe_value = {}
	for _, f in pairs(safe_function) do
		safe_value[f] = f
	end
	safe_value[sandbox.require] = require

	function global_mt:__index(k)
		assert(type(k) == "string", "Global name must be a string")
		if safe_function[k] then
			return safe_function[k]
		else
			return make_dummy(k)
		end
	end

	local dummy_kind = { GLOBAL = true, MODULE = true, SANDBOX = true }

	-- True for placeholder objects made by the sandbox (protected metatable).
	local function has_dummy_metatable(v)
		return dummy_kind[getmetatable(v)] == true
	end

	-- True when `v` must be resolved instead of merged: sandbox placeholders and
	-- the safe functions. Plain strings are NOT dummies even though they carry
	-- the string metatable.
	function sandbox.isdummy(v)
		return safe_value[v] ~= nil or has_dummy_metatable(v)
	end

	-- Resolve `obj` against the live environment.
	-- Returns "ok", value | "unsolved", value (still another dummy) | "invalid".
	function sandbox.resolve(obj)
		local real = safe_value[obj]
		if real ~= nil then
			return "ok", real
		end
		local kind = getmetatable(obj)
		local value, ok
		if kind == "GLOBAL" then
			value, ok = walk_path(_G, dummy_cache[obj])
		elseif kind == "MODULE" then
			local name, rest = split_module_key(dummy_module_cache[obj])
			local mod = name and debug.getregistry()._LOADED[name]
			if mod == nil then
				return "invalid"
			end
			value, ok = walk_path(mod, rest)
		elseif kind == "SANDBOX" then
			-- a captured chunk environment: reload.reload gives chunks this _ENV
			return "ok", _ENV
		else
			return "invalid"
		end
		if not ok then
			return "invalid"
		end
		if has_dummy_metatable(value) then
			return "unsolved", value
		end
		return "ok", value
	end

	-- Real value behind a dummy; raises when it cannot be resolved.
	function sandbox.value(obj)
		local status, value = sandbox.resolve(obj)
		if status == "invalid" then
			error("Invalid object : " .. tostring(obj), 2)
		end
		return value
	end

	-- Reset the sandbox; every module name in `list` requires as a dummy.
	function sandbox.init(list)
		dummy_cache = setmetatable({}, weak)
		dummy_module_cache = setmetatable({}, weak)
		clear_table(_LOADED_DUMMY)
		clear_table(_LOADED)
		if list then
			for _, name in ipairs(list) do
				_LOADED_DUMMY[name] = make_dummy_module(name)
			end
		end
	end

	function sandbox.module(name)
		return _LOADED[name]
	end

	-- Drop every reference the sandbox holds once a reload is over.
	function sandbox.clear()
		dummy_cache = nil
		dummy_module_cache = nil
		clear_table(_LOADED_DUMMY)
		clear_table(_LOADED)
	end

end -- sandbox end

-- Names of every module currently in package.loaded.
function reload.list()
	local list = {}
	for k in pairs(debug.getregistry()._LOADED) do
		table.insert(list, k)
	end
	return list
end

-- Key types that can be recorded in a path and looked up again.
local accept_key_type = {
	number = true,
	string = true,
	boolean = true,
}

-- Flatten `value` into records { object, path... }, one per path that reaches
-- each function, table or dummy. A table field adds one path element (the
-- key); a function upvalue adds two (name, index). Dummies are not unfolded.
local function enum_object(value)
	local print = reload.print
	local all = {}
	local path = {}
	local objs = {}
	local function iterate(value)
		if sandbox.isdummy(value) then
			if print then print("ENUM", value, path_string(path)) end
			table.insert(all, { value, table.unpack(path) })
			return
		end
		local t = type(value)
		if t ~= "function" and t ~= "table" then
			return
		end
		if print then print("ENUM", value, path_string(path)) end
		table.insert(all, { value, table.unpack(path) })
		if objs[value] then
			-- already unfold
			return
		end
		objs[value] = true
		local depth = #path + 1
		if t == "function" then
			local i = 1
			while true do
				local name, v = debug.getupvalue(value, i)
				if name == nil or name == "" then
					break
				end
				if not name:find("^[_%w]") then
					error("Invalid upvalue : " .. path_string(path))
				end
				local vt = type(v)
				if vt == "function" or vt == "table" then
					path[depth] = name
					path[depth + 1] = i
					iterate(v)
					path[depth] = nil
					path[depth + 1] = nil
				end
				i = i + 1
			end
		else -- table
			for k, v in pairs(value) do
				if not accept_key_type[type(k)] then
					error("Invalid key : " .. tostring(k) .. " " .. path_string(path))
				end
				path[depth] = k
				iterate(v)
				path[depth] = nil
			end
		end
	end
	iterate(value)
	return all
end

-- Follow a path recorded by enum_object starting at `mod` (the old module).
-- Upvalues are matched by name; the recorded index (`id`) is skipped.
local function find_object(mod, name, id, ...)
	if mod == nil or name == nil then
		return mod
	end
	local t = type(mod)
	if t == "table" then
		return find_object(mod[name], id, ...)
	end
	assert(t == "function")
	local i = 1
	while true do
		local n, value = debug.getupvalue(mod, i)
		if n == nil or n == "" then
			return
		end
		if n == name then
			return find_object(value, ...)
		end
		i = i + 1
	end
end

-- Pair every new object with its old counterpart in `map` (new -> old, or
-- false when no path reaches an old object); dummies go to `globals`.
local function match_objects(objects, old_module, map, globals)
	local print = reload.print
	for _, item in ipairs(objects) do
		local obj = item[1]
		if sandbox.isdummy(obj) then
			table.insert(globals, item)
		else
			local ok, old_one = pcall(find_object, old_module, table.unpack(item, 2))
			if not ok then
				error("type mismatch : " .. path_string(item, 2, ","))
			end
			if old_one == nil then
				-- Keep it as "unmatched" unless another path already matched it;
				-- patch_funcs still has to join the upvalues it shares.
				if map[obj] == nil then
					map[obj] = false
				end
			else
				if type(old_one) ~= type(obj) then
					error("Not a table : " .. path_string(item, 2, ","))
				end
				if map[obj] and map[obj] ~= old_one then
					error("Ambiguity table : " .. path_string(item, 2, ","))
				end
				map[obj] = old_one
			end
			if print then print("MATCH", old_one, table.unpack(item, 2)) end
		end
	end
end

-- Index of the upvalue called `name` in `func`, or nil.
local function find_upvalue(func, name)
	if not func then
		return
	end
	local i = 1
	while true do
		local n = debug.getupvalue(func, i)
		if n == nil or n == "" then
			return
		end
		if n == name then
			return i
		end
		i = i + 1
	end
end

-- For every matched function, record which old upvalue each new upvalue
-- (keyed by debug.upvalueid) must be joined to.
local function match_upvalues(map, upvalues)
	for new_one, old_one in pairs(map) do
		if type(new_one) == "function" then
			local i = 1
			while true do
				local name = debug.getupvalue(new_one, i)
				if name == nil or name == "" then
					break
				end
				local old_index = find_upvalue(old_one, name)
				local id = debug.upvalueid(new_one, i)
				if not upvalues[id] and old_index then
					upvalues[id] = {
						func = old_one,
						index = old_index,
						oldid = debug.upvalueid(old_one, old_index),
					}
				elseif old_index then
					local oldid = debug.upvalueid(old_one, old_index)
					if oldid ~= upvalues[id].oldid then
						error(string.format("Ambiguity upvalue : %s .%s", tostring(new_one), name))
					end
				end
				i = i + 1
			end
		end
	end
end

-- Load every module of `list` in the sandbox and match it against the old one.
local function reload_list(list)
	local _LOADED = debug.getregistry()._LOADED
	local all = {}
	for _, mod in ipairs(list) do
		sandbox.require(mod)
		local m = sandbox.module(mod)
		local objs = enum_object(m.module)
		local old_module = _LOADED[mod]
		local result = {
			globals = {},
			map = {},
			upvalues = {},
			old_module = old_module,
			module = m,
			objects = objs,
		}
		all[mod] = result
		match_objects(objs, old_module, result.map, result.globals) -- find match table/func between old module and new one
		match_upvalues(result.map, result.upvalues) -- find match func's upvalues
	end
	return all
end

-- Store `v` at the end of a path recorded by enum_object, starting from `mod`.
-- Returns true when the value was stored.
local function set_object(v, mod, name, tmore, fmore, ...)
	local t = type(mod)
	if t == "table" then
		if not tmore then -- no more
			mod[name] = v
			return true
		end
		return set_object(v, mod[name], tmore, fmore, ...)
	elseif t ~= "function" then
		return false
	end
	local i = 1
	while true do
		local n, value = debug.getupvalue(mod, i)
		if n == nil or n == "" then
			return false
		end
		if n == name then
			if not fmore then
				debug.setupvalue(mod, i, v)
				return true
			end
			return set_object(v, value, fmore, ...) -- skip tmore (id)
		end
		i = i + 1
	end
end

-- Join the upvalues of every new function to the matching old upvalues.
local function patch_funcs(upvalues, map)
	local print = reload.print
	for value in pairs(map) do
		if type(value) == "function" then
			local i = 1
			while true do
				local name = debug.getupvalue(value, i)
				if name == nil or name == "" then
					break
				end
				local uv = upvalues[debug.upvalueid(value, i)]
				if uv then
					if print then print("JOIN", value, name) end
					debug.upvaluejoin(value, i, uv.func, uv.index)
				end
				i = i + 1
			end
		end
	end
end

-- Merge the new objects into the old module (or install new modules).
local function merge_objects(all)
	local _LOADED = debug.getregistry()._LOADED
	local print = reload.print
	for mod_name, data in pairs(all) do
		if data.old_module then
			local map = data.map
			patch_funcs(data.upvalues, map)
			for new_one, old_one in pairs(map) do
				if type(new_one) == "table" and old_one then
					-- merge new_one into old_one
					if print then print("COPY", old_one) end
					for k, v in pairs(new_one) do
						if type(v) ~= "table" or -- copy values not a table
							sandbox.isdummy(v) or -- copy dummy
							old_one[k] == nil then -- copy new object
							old_one[k] = v
						end
					end
				end
			end
			for _, item in ipairs(data.objects) do
				local v = item[1]
				if not sandbox.isdummy(v) and not map[v] then
					-- insert new object
					local ok = set_object(v, data.old_module, table.unpack(item, 2))
					if print then print("MOVE", mod_name, path_string(item, 2), ok) end
				end
			end
		else
			_LOADED[mod_name] = data.module.module
		end
	end
end

-- One pass over the pending dummies: write every resolvable one into the
-- merged module, drop invalid ones, keep the rest for the next pass.
-- Returns how many were resolved in this pass.
local function solve_globals(all)
	local _LOADED = debug.getregistry()._LOADED
	local print = reload.print
	local solved = 0
	for mod_name, data in pairs(all) do
		for gk, item in pairs(data.globals) do
			local path = tostring(item[1])
			local status, value = sandbox.resolve(item[1])
			if status == "invalid" then
				if print then print("GLOBAL INVALID", path) end
				data.globals[gk] = nil
			elseif status == "ok" then
				solved = solved + 1
				if print then print("GLOBAL", path, value) end
				set_object(value, _LOADED[mod_name], table.unpack(item, 2))
				data.globals[gk] = nil
			end
			-- "unsolved": it still points at another dummy; retry next pass
		end
	end
	return solved
end

-- Replace every reachable reference to a key of `map` (old function) by its
-- value (new function): stack locals, upvalues, table keys/values, metatables
-- and user values, starting from the registry and the running thread.
local function update_funcs(map)
	local root = debug.getregistry()
	local co = coroutine.running()
	local exclude = { [map] = true, [co] = true }
	local getmetatable = debug.getmetatable
	local getinfo = debug.getinfo
	local getlocal = debug.getlocal
	local setlocal = debug.setlocal
	local getupvalue = debug.getupvalue
	local setupvalue = debug.setupvalue
	local getuservalue = debug.getuservalue
	local setuservalue = debug.setuservalue
	local type = type
	local next = next
	local rawget = rawget
	local rawset = rawset

	exclude[exclude] = true

	local update_funcs_

	-- Patch every frame of `thread` from stack `level` upward. getinfo/getlocal
	-- are called directly from here so that, for the running thread, the level
	-- numbering matches the caller's expectation.
	local function update_funcs_frame(thread, level)
		while true do
			local info = getinfo(thread, level, "f")
			if info == nil then
				return
			end
			update_funcs_(info.func)
			-- dir = 1 visits locals 1, 2, ...; dir = -1 visits varargs -1, -2, ...
			for dir = 1, -1, -2 do
				local i = dir
				while true do
					local name, v = getlocal(thread, level, i)
					if name == nil then
						break
					end
					local nv = map[v]
					if nv then
						setlocal(thread, level, i, nv)
						update_funcs_(nv)
					else
						update_funcs_(v)
					end
					i = i + dir
				end
			end
			level = level + 1
		end
	end

	function update_funcs_(root) -- local function
		if exclude[root] then
			return
		end
		local t = type(root)
		if t == "table" then
			exclude[root] = true
			local mt = getmetatable(root)
			if mt then update_funcs_(mt) end
			local rekey
			for k, v in next, root do
				local nv = map[v]
				if nv then
					rawset(root, k, nv)
					update_funcs_(nv)
				else
					update_funcs_(v)
				end
				local nk = map[k]
				if nk then
					rekey = rekey or {}
					rekey[k] = nk
				else
					update_funcs_(k)
				end
			end
			if rekey then
				-- move entries to their new keys after the traversal (adding keys
				-- during `next` is not allowed); raw access avoids metamethods
				for old_key, new_key in next, rekey do
					local value = rawget(root, old_key)
					rawset(root, old_key, nil)
					rawset(root, new_key, value)
					update_funcs_(new_key)
				end
			end
		elseif t == "userdata" then
			exclude[root] = true
			local mt = getmetatable(root)
			if mt then update_funcs_(mt) end
			local uv = getuservalue(root)
			if uv then
				local nv = map[uv]
				if nv then
					setuservalue(root, nv)
					update_funcs_(nv)
				else
					update_funcs_(uv)
				end
			end
		elseif t == "thread" then
			exclude[root] = true
			-- another coroutine: its frames start at level 0
			update_funcs_frame(root, 0)
		elseif t == "function" then
			exclude[root] = true
			local i = 1
			while true do
				local name, v = getupvalue(root, i)
				if name == nil then
					break
				end
				local nv = map[v]
				if nv then
					setupvalue(root, i, nv)
					update_funcs_(nv)
				else
					update_funcs_(v)
				end
				i = i + 1
			end
		end
	end

	-- nil, number, boolean, string, thread, function, lightuserdata may have
	-- metatable. table.pack keeps the leading nil a table constructor drops.
	local samples = table.pack(nil, 0, true, "", co, update_funcs, debug.upvalueid(update_funcs, 1))
	for i = 1, samples.n do
		local mt = getmetatable(samples[i])
		if mt then update_funcs_(mt) end
	end

	-- level 3 skips getinfo, update_funcs_frame and update_funcs itself
	update_funcs_frame(co, 3)
	update_funcs_(root)
end

-- Reload the modules named in `list`.
-- Returns true on success, or false plus a traceback when loading/matching
-- the new code fails (the running program is left untouched in that case).
function reload.reload(list)
	local print = reload.print
	local _LOADED = debug.getregistry()._LOADED
	local need_reload = {}
	for _, mod in ipairs(list) do
		need_reload[mod] = true
	end
	local keep = {}
	for k in pairs(_LOADED) do
		if not need_reload[k] then
			table.insert(keep, k)
		end
	end
	sandbox.init(keep) -- modules not being reloaded require as dummies

	local ok, result = xpcall(reload_list, debug.traceback, list)
	if not ok then
		sandbox.clear()
		if print then print("ERROR", result) end
		return ok, result
	end

	merge_objects(result)

	-- give the new chunks (and functions still sharing their _ENV) the real env
	for _, data in pairs(result) do
		if data.module.loader then
			debug.setupvalue(data.module.loader, 1, _ENV)
		end
	end

	repeat
		local n = solve_globals(result)
	until n == 0

	-- old function -> new function; unmatched (false) and identical pairs are
	-- skipped so update_funcs never replaces `false` values or swaps a key
	-- onto itself
	local func_map = {}
	for _, data in pairs(result) do
		for new_one, old_one in pairs(data.map) do
			if type(new_one) == "function" and old_one and old_one ~= new_one then
				func_map[old_one] = new_one
			end
		end
	end
	result = nil
	sandbox.clear()

	update_funcs(func_map)

	return true
end

return reload
--------------------------------------------------------------------------------

/test.lua:
--------------------------------------------------------------------------------
local reload = require "reload"
reload.postfix = "_update" -- for test

local mymod = require "mymod"

function reload.print(...)
	print(" DEBUG:", ...)
end

mymod.foobar(42)

local tmp = {}
local foo = mymod.foo2()
tmp[foo] = foo
print("FOO before", foo)

local obj = mymod.new()

obj:show()

local function test()
	print("BEFORE update foo", foo)
	assert(reload.reload { "mymod" })
	print("AFTER update foo", foo)
end

test()
foo()

print("FOO after", foo)
assert(tmp[foo] == foo)
-- the "[debug].getinfo" dummy must have been resolved to the real function
assert(mymod.getinfo == debug.getinfo)

obj:show()
--------------------------------------------------------------------------------