# Repository layout:
#   ├── README.md
#   └── drdr.rb
#
# /README.md
#   # drdr
#
# /drdr.rb
#
# A tiny dependency-graph runner.  Tasks are declared inside a `drdr { ... }`
# block, wired together with `|` (feed outputs into inputs) and `+` (group
# tasks side by side), and executed on threads once their inputs are done.

require 'thread'

# Error raised by the graph itself (cyclic dependencies, failed commands, ...).
class DRError < RuntimeError
end

# A single node of the graph: a block plus its upstream/downstream tasks.
class DRTask
  attr_accessor :inputs, :outputs
  # NOTE: `run` is intentionally not an attr_reader; the `run` method below
  # executes the task.  The "already started" flag lives in @run.
  attr_reader :name, :tid, :ckpt, :done, :result

  # name  - label used in logs (the graph passes the tid when none is given).
  # tid   - numeric id assigned by the owning graph.
  # ckpt  - optional path of a checkpoint file holding the marshalled result.
  # proc  - the task body; it receives the results of the input tasks.
  def initialize(name=nil, tid=0, ckpt: nil, &proc)
    @inputs = []
    @outputs = []
    @proc = proc
    @tid = tid
    @name = name
    @ckpt = ckpt
    @run = false
    @done = false
  end

  # Connects every task on the left to every task on the right and returns
  # the right-hand side so that chains like `a | b | c` work.
  def |(other)
    tasks.each do |src|
      other.tasks.each do |dst|
        dst.inputs << src
        src.outputs << dst
      end
    end
    other
  end

  # Groups tasks so that they can be piped from/to together.
  def +(other)
    DRTaskGroup.new([*tasks, *other.tasks])
  end

  def tasks
    [self]
  end

  def to_s
    label = @tid != @name ? "(#{@name})" : ""
    "Task#{@tid}#{label}"
  end

  def inspect
    "#{to_s}: #{inputs.map(&:name).join(' ')}"
  end

  # True when the task has not been started and all its inputs are done.
  def can_run
    !@run && @inputs.all?(&:done)
  end

  # Loads the result from the checkpoint file when there is one.  Returns
  # true when the task was satisfied from the checkpoint, false otherwise.
  def maybe_skip
    return false unless @ckpt && File.exist?(@ckpt)

    # NOTE(review): Marshal.load must only be used on trusted files; the
    # checkpoint is assumed to have been written by `finish` below.
    # Binary mode: Marshal data is not text.
    @result = Marshal.load(File.binread(@ckpt))
    @done = true
    true
  end

  def start
    @run = true
  end

  # Records the result and, when a checkpoint path was given, persists it.
  # Raises whatever Marshal.dump / the file write raises.
  def finish(result)
    # Dump before touching the file so that an unmarshalable result does not
    # leave a truncated checkpoint behind.
    File.binwrite(@ckpt, Marshal.dump(result)) if @ckpt

    @done = true
    @result = result
  end

  # Runs the task body inside a nested graph (so the body may declare tasks
  # itself), passing the input results as block arguments.  Expects
  # Thread.current[:kwargs] to have been set by the launching graph.
  def run
    ins = @inputs.map(&:result)
    task_proc = @proc
    drdr(**Thread.current[:kwargs]) { instance_exec(*ins, &task_proc) }
  end
end

# Several tasks treated as one operand of `|` / `+`.
# Only @tasks is set here: DRTask#initialize is not called, so the other
# task attributes of a group are nil.
class DRTaskGroup < DRTask
  attr_accessor :tasks

  def initialize(tasks)
    @tasks = tasks
  end

  def inspect
    "TaskGroup(#{tasks.map(&:name).join(' ')})"
  end
end

# Builds the graph by evaluating the given block with the graph as `self`
# and executes it with `run`.
class DRGraph
  def initialize(log: STDERR, seq: false, &proc)
    # Parameters which need to be propagated to the next subgraph.
    @kwargs = {
      :log => log,
      :seq => seq,
    }

    @proc = proc
    @tasks = {}
    @tid = 0
    @thid = 0
    @log = log
    @seq = seq
    @debug = false
    @exception = nil

    @threads = {}
    @mu = Mutex.new
    @cond = ConditionVariable.new

    # Must come last: the block calls `task`, which needs the state above.
    @results = *instance_eval(&@proc)
  end

  # Executes the graph.  Returns the value(s) of the block given to the
  # constructor, with tasks replaced by their results; a single value is
  # returned unwrapped.  Re-raises the exception of a failed task.
  def run
    if @tasks.empty?
      @log << "DR: No task in the graph\n" unless Thread.current[:is_sub]
      return @results.size == 1 ? @results[0] : @results
    end

    sub = Thread.current[:is_sub]
    @log << "DR: execute #{sub}graph with #{@tasks.size} tasks\n"
    STDERR.puts "DR: About to execute a graph:\n#{inspect}\n\n" if @debug

    analyze

    run_loop
    @threads.each_value(&:join)

    results = @results.map { |r| r.is_a?(DRTask) ? r.result : r }
    results.size == 1 ? results[0] : results
  end

  # Depth-first walk from a goal towards its inputs.
  # seen   - tasks on the current path (used to detect cycles).
  # ntasks - tid => task for every task that actually has to run.
  # Raises DRError when a task is reached again while still on the path.
  def traverse(task, seen, ntasks)
    raise DRError, "Cyclic dependency detected" if seen[task]
    # Already handled through another path or goal.  Without this a
    # checkpointed task shared by two consumers was reported as a cycle.
    return if task.done || ntasks.key?(task.tid)

    if task.maybe_skip
      @log << "DR: there is a ckpt #{task.ckpt} for #{task}\n"
      return
    end

    seen[task] = true
    ntasks[task.tid] = task
    task.inputs.each { |input| traverse(input, seen, ntasks) }
    seen[task] = false
  end

  # Finds the goals (tasks without outputs), checks for cycles and drops the
  # tasks which do not need to run thanks to checkpoints.
  def analyze
    goals = @tasks.each_value.select { |task| task.outputs.empty? }
    raise DRError, "Cyclic dependency detected" if goals.empty?

    ntasks = {}
    goals.each { |goal| traverse(goal, {}, ntasks) }

    if @tasks.size != ntasks.size
      diff = @tasks.size - ntasks.size
      @log << "DR: #{diff} tasks were skipped thanks to ckpts\n"
      @tasks = ntasks
    end
  end

  # Launches runnable tasks and sleeps until a worker signals, until no
  # worker is left.  Raises the exception recorded by a failed worker.
  def run_loop
    loop do
      @mu.synchronize do
        launch_tasks
        return if @threads.empty?

        @cond.wait(@mu)
        raise @exception if @exception
      end
    end
  end

  # Starts a thread for each runnable task (at most one at a time when `seq`
  # is set).  Must be called with @mu held.
  def launch_tasks
    @tasks.each_value do |task|
      # Sequential mode: never start a task while another one is running,
      # even if we were woken up without a task having finished.
      break if @seq && !@threads.empty?
      next unless task.can_run

      @log << "DR: start #{task}\n"
      STDERR.puts "DR: start #{task.inspect}" if @debug
      task.start
      thid = @thid += 1
      @threads[thid] = Thread.start do
        Thread.current[:kwargs] = @kwargs
        run_task(task, thid)
      end
    end
  end

  # Thread body: runs the task and reports completion or failure to
  # `run_loop`.  Failures while storing the result (e.g. an unmarshalable
  # checkpoint value) are reported too; otherwise the worker would die
  # without signalling and `run_loop` would wait forever.
  def run_task(task, thid)
    Thread.current[:is_sub] = 'sub'
    result = task.run

    @mu.synchronize do
      STDERR.puts "DR: finish (#{result}) #{task.inspect}" if @debug
      task.finish(result)
      @threads.delete(thid)
      @cond.signal
    end
  rescue => e
    @mu.synchronize do
      @exception = e
      @cond.signal
    end
  end

  # Declares a task in this graph.  `name` defaults to the task id; the
  # remaining keyword arguments are forwarded to DRTask.new (e.g. ckpt:).
  def task(name=nil, **kwargs, &proc)
    @mu.synchronize do
      @tid += 1
      name ||= @tid
      task = @tasks[@tid] = DRTask.new(name, @tid, **kwargs, &proc)
      @cond.signal
      task
    end
  end

  # Backslash-escapes backslashes and single quotes.  The block form is
  # required: in a replacement *string* `\\` collapses to one backslash and
  # `\'` means "text after the match", so the string form mangled the input.
  def shell_escape(s)
    s.to_s.gsub(/[\\']/) { |ch| "\\#{ch}" }
  end

  # Declares a task which runs an external command via IO.popen.
  # args   - command (String or Array) handed to IO.popen.
  # name   - task name; defaults to the quoted command line.
  # stdout - when truthy the command inherits our stdout and the task result
  #          is the exit status; otherwise the result is the captured output.
  # The (single, optional) input of the task is written to the command's
  # stdin.  Raises DRError when the command fails or gets several inputs.
  def cmd(args, name=nil, stdout: nil, **kwargs)
    name ||= "'#{[*args].map { |a| shell_escape(a) }.join(' ')}'"
    task(name, **kwargs) do |*ins|
      if ins.size > 1
        raise DRError, "`cmd` takes only a single input but comes #{ins}"
      end
      instr = ins.size == 1 ? ins[0].to_s : ''

      # The block form closes the pipe (and sets $?) even when writing or
      # reading raises.
      if stdout
        IO.popen(args, 'w:binary') { |pipe| pipe.print instr }
        result = $?.exitstatus
      else
        result = IO.popen(args, 'r+:binary') do |pipe|
          pipe.print instr
          pipe.close_write
          pipe.read
        end
      end

      unless $?.success?
        raise DRError, "cmd #{name} failed (status=#{$?.exitstatus})"
      end
      result
    end
  end

  def show
    puts inspect
  end

  def inspect
    @tasks.each_value.map(&:inspect).join("\n")
  end

  # Turns on verbose diagnostics on STDERR.
  def debug
    @debug = true
  end
end


# Builds a graph from the block and runs it; see DRGraph#run for the result.
def drdr(log: STDERR, seq: false, &proc)
  DRGraph.new(log: log, seq: seq, &proc).run
end


if $0 == __FILE__
  require 'test/unit'
  require 'fileutils'
  require 'tmpdir'

  class DrdrTest < Test::Unit::TestCase
    def setup
      @testdir = "#{Dir.tmpdir}/drdr_test"
      # rm_rf: the directory does not exist on the very first run.
      FileUtils.rm_rf(@testdir)
      FileUtils.mkdir_p(@testdir)
      Dir.chdir(@testdir)
    end

    def test_drdr
      result = drdr {
        task{ 42 } | task{|x|x / 2} + task{|x|x * 2} | task{|x, y|x + y}
      }
      assert_equal((42/2)+(42*2), result)
    end

    def test_access_local
      x = nil
      y = nil
      drdr {
        task{ x = 42 }
        task{ y = 99 }
      }
      assert_equal 42, x
      assert_equal 99, y
    end

    def test_access_local2
      x = nil
      y = nil
      drdr {
        task{ x = 42 } + task{ y = 99 }
      }
      assert_equal 42, x
      assert_equal 99, y
    end

    def test_add_task
      x = 0
      drdr {
        task{
          1.upto(10){|i|
            task{ x += i }
          }
        }
      }
      assert_equal 55, x
    end

    class TestError < RuntimeError
    end

    class ShouldntHappen < RuntimeError
    end

    def test_raise
      assert_raise TestError do
        drdr {
          task{ raise TestError.new } | task{ raise ShouldntHappen.new }
        }
      end
    end

    def test_log
      log = +''
      drdr(log: log) {
        task('hoge'){} | task('fuga'){}
      }
      assert_match(/hoge.*fuga/m, log)
    end

    def test_cmd
      assert_equal "fxo\n", drdr {
        cmd(%W(echo foo)) | cmd(%W(sed s/o/x/))
      }
    end

    def test_cmd_fail
      assert_raise DRError do
        drdr {
          cmd("false") | task{ raise ShouldntHappen.new }
        }
      end
    end

    def test_shell_escape
      graph = DRGraph.new(log: +'') {}
      assert_equal "it\\'s", graph.shell_escape("it's")
      assert_equal "a\\\\b", graph.shell_escape("a\\b")
    end

    def test_cyclic
      assert_raise DRError do
        drdr {
          a = task{}
          b = task{}
          a | b
          b | a
        }
      end
    end

    def test_cyclic2
      assert_raise DRError do
        drdr {
          a = task{}
          b = task{}
          c = task{}
          a | b | c
          b | a
        }
      end
    end

    def test_ckpt
      assert_equal "foo\n", drdr {
        cmd("echo foo", ckpt: "foo") | task{|i|i}
      }
      assert_true File.exist?("foo")

      assert_equal "foo\nbar", drdr {
        task(ckpt: "foo"){ raise ShouldntHappen.new } | task{|x|x + "bar"}
      }
    end

    def test_ckpt_shared_by_two_tasks
      2.times do
        assert_equal 5, drdr {
          a = task(ckpt: "shared"){ 1 }
          a | task{|x|x + 1} + task{|x|x + 2} | task{|x, y|x + y}
        }
      end
    end

    def test_variable
      assert_equal ["foo", "barbaz", 42], drdr {
        foo = task{"foo"}
        barbaz = task{"bar"} | task{|x|x+"baz"}
        [foo, barbaz, 42]
      }
    end

    def test_nest_drdr
      assert_equal "foo", drdr {
        task { drdr { task { "foo" } } }
      }
      assert_equal "foobar", drdr {
        task { drdr { task { "foo" } } } | task{|x|x + "bar"}
      }
    end

    def test_nest_task
      assert_equal "foo", drdr {
        task { task { "foo" } }
      }
      assert_equal "foobar", drdr {
        task { task { "foo" } } | task{|x|x + "bar"}
      }
    end

    def test_nest_task_graph_id
      x = nil
      y = nil
      z = nil
      drdr {
        x = object_id
        task {
          y = object_id
          task {
            z = object_id
          }
        }
      }
      assert_not_equal x, y
      assert_not_equal x, z
      assert_not_equal y, z
    end

    def test_empty_drdr
      drdr {}
    end

    def test_seq
      assert_raise TestError do
        drdr(seq: true) {
          task{
            10.times{ Thread.pass }
            raise TestError.new
          }
          task{ raise ShouldntHappen.new }
        }
      end
    end

    def test_seq_nested
      assert_raise TestError do
        drdr(seq: true) {
          # One more nest. `seq` must be propagated.
          task {
            task{
              10.times{ Thread.pass }
              raise TestError.new
            }
            task{ raise ShouldntHappen.new }
          }
        }
      end
    end

    def test_drdr_no_task
      assert_equal "foo", drdr { "foo" }
    end
  end
end