├── README.md
├── parallel.py
├── split.py
├── normalize.py
├── cache_write.py
├── cache_read.py
├── prefetch.py
├── rfactor.py
├── unroll.py
├── fuse.py
├── set_scope.py
├── tile.py
├── reorder.py
├── pragma.py
├── bind.py
├── storage_align.py
├── vectorize.py
├── compute_at.py
├── set_store_predicate.py
├── compute_root.py
├── create_group.py
├── compute_inline.py
└── tensorize.py

/README.md:
--------------------------------------------------------------------------------
# tvm.schedule
Examples for the TVM schedule API.

Run any example with: python <name>.py

Each script prints the lowered IR before and after applying one schedule
primitive, separated by a "cutting line".
--------------------------------------------------------------------------------
/parallel.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
m = 1024

A = tvm.placeholder((n, m), name='A')
l = tvm.reduce_axis((0, m), name='l')

B = tvm.compute((n,), lambda i: tvm.sum(A[i, l], axis=l), name='B')

s = tvm.create_schedule(B.op)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Mark the reduction axis for parallel execution.
s[B].parallel(B.op.reduce_axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/split.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), name='k')

B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Split the reduction axis into an outer loop and an inner loop of length 32.
ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/normalize.py:
--------------------------------------------------------------------------------
import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
k = tvm.reduce_axis((10, n), 'k')
C = tvm.compute((1,), lambda _: tvm.sum(A[k] * B[k], axis=k), name='C')

s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")
# normalize() rebases the reduction axis so that it starts from zero.
s = s.normalize()
print(tvm.lower(s, [A, B, C], simple_mode=True))
--------------------------------------------------------------------------------
/cache_write.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
dtype = "float32"
A = tvm.placeholder((n, n), dtype=dtype, name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Create a cache stage that accumulates B in "local" scope before
# writing the result back to B's buffer.
BW = s.cache_write(B, "local")

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
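Note (a sketch, not one of the repo's scripts): cache_write on its own leaves the
new "B.local" stage at root, so the lowered IR still materializes the whole
intermediate buffer. It is usually paired with compute_at so that each row's
partial sum is accumulated in the local buffer and copied out once per row.
A minimal version, assuming the same TVM 0.x API used throughout this repo:

import tvm

n = 1024
A = tvm.placeholder((n, n), name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)
BW = s.cache_write(B, "local")
# Anchor the local accumulation inside B's (only) spatial loop.
s[BW].compute_at(s[B], s[B].op.axis[0])

print(tvm.lower(s, [A, B], simple_mode=True))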
/cache_read.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
dtype = "float32"
A = tvm.placeholder((n, n), dtype=dtype, name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Stage reads of A into a "shared" cache used by the consumer B.
AA = s.cache_read(A, "shared", [B])

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/prefetch.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
dtype = "float32"
k = tvm.reduce_axis((0, n), name='k')
A = tvm.placeholder((n, n), dtype=dtype, name='A')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Prefetch A one iteration ahead along the reduction axis.
s[B].prefetch(A, s[B].op.reduce_axis[0], 1)
print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/rfactor.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
k = tvm.reduce_axis((0, n), name='k')

A = tvm.placeholder((n,), name='A')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)
ko, ki = s[B].split(s[B].op.reduce_axis[0], 32)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Factor the reduction: BR holds partial sums over ki, and B becomes a
# reduction over the remaining axis.
BR = s.rfactor(B, ki)

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/unroll.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')
C = tvm.compute((n, n), lambda i, j: A[i, j] + B[i, j], name='C')

s = tvm.create_schedule(C.op)

xo, xi = s[C].split(s[C].op.axis[0], factor=4)

print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")

# Fully unroll the inner loop of length 4.
s[C].unroll(xi)

print(tvm.lower(s, [A, B, C], simple_mode=True))
--------------------------------------------------------------------------------
/fuse.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), name='k')

B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)

ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Fuse the two loops produced by the split back into a single loop.
s[B].fuse(ko, ki)

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/set_scope.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
dtype = "float32"
A = tvm.placeholder((n, n), dtype=dtype, name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
C = tvm.compute((n,), lambda i: B[i] + 10, name='C')

s = tvm.create_schedule(C.op)

print(tvm.lower(s, [A, C], simple_mode=True))
print("---------cutting line---------")

# Place the intermediate B in "shared" memory instead of "global".
s[B].set_scope('shared')

print(tvm.lower(s, [A, C], simple_mode=True))
--------------------------------------------------------------------------------
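Note (a sketch, not one of the repo's scripts): set_scope only changes where the
intermediate buffer lives. On a GPU it is normally combined with thread binding
and compute_at so that the shared allocation is attached under a block. A
minimal version, assuming the same TVM 0.x API; the split factor 64 is an
arbitrary choice:

import tvm

n = 1024
A = tvm.placeholder((n, n), name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
C = tvm.compute((n,), lambda i: B[i] + 10, name='C')

s = tvm.create_schedule(C.op)
s[B].set_scope('shared')

bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
# Compute B's slice per block, so the shared buffer appears under
# blockIdx.x in the lowered IR.
s[B].compute_at(s[C], bx)

print(tvm.lower(s, [A, C], simple_mode=True))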
/tile.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')
K = tvm.reduce_axis((0, n), name='K')
C = tvm.compute((n, n), lambda i, j: tvm.sum(A[i, K] * B[K, j], axis=K), name='C')

s = tvm.create_schedule(C.op)

print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")

# Tile the two spatial axes into 32x32 blocks; this is equivalent to
# splitting both axes and reordering to (xo, yo, xi, yi).
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 32, 32)

print(tvm.lower(s, [A, B, C], simple_mode=True))
--------------------------------------------------------------------------------
/reorder.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')
C = tvm.compute((n, n), lambda i, j: A[i, j] + B[i, j], name='C')

s = tvm.create_schedule(C.op)

xo, xi = s[C].split(s[C].op.axis[0], factor=32)
yo, yi = s[C].split(s[C].op.axis[1], factor=32)

print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")

# Change the loop nest order from (xo, xi, yo, yi) to (xo, yo, yi, xi).
s[C].reorder(xo, yo, yi, xi)

print(tvm.lower(s, [A, B, C], simple_mode=True))
--------------------------------------------------------------------------------
/pragma.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
m = 1024
A = tvm.placeholder((n, m), name='A')
l = tvm.reduce_axis((0, m), name='l')

B = tvm.compute((n,), lambda i: tvm.sum(A[i, l], axis=l), name='B')

s = tvm.create_schedule(B.op)

ko, ki = s[B].split(B.op.reduce_axis[0], factor=4)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Attach an "unroll" pragma to the inner loop; the actual unrolling is
# done by a later lowering pass.
s[B].pragma(ki, "unroll")

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/bind.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), name='k')

B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)

ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Bind the outer loop to GPU blocks and the inner loop to GPU threads.
s[B].bind(ko, tvm.thread_axis("blockIdx.x"))
s[B].bind(ki, tvm.thread_axis("threadIdx.x"))

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/storage_align.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
factor = 100
offset = 8
dtype = "float32"
A = tvm.placeholder((n, n), dtype=dtype, name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)
AA = s.cache_read(A, "shared", [B])

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Pad the stride of AA's first axis so that stride % factor == offset,
# which helps avoid shared-memory bank conflicts.
s[AA].storage_align(AA.op.axis[0], factor, offset)

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
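Note (a sketch, not one of the repo's scripts): to see what storage_align does
to the buffer layout, here is the stride arithmetic in plain Python. This is my
reading of the "stride % factor == offset" rule, and the helper name is made up
for illustration:

def aligned_stride(extent, factor, offset):
    # Smallest stride >= extent satisfying stride % factor == offset.
    k = max((extent - offset + factor - 1) // factor, 0)
    return k * factor + offset

# With the values used in storage_align.py, each row of the shared
# cache would be padded from 1024 to 1108 elements.
print(aligned_stride(1024, 100, 8))  # 1108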
/vectorize.py:
--------------------------------------------------------------------------------
import tvm

M = 1024
N = 1024
A = tvm.placeholder((M, N), name='A')
B = tvm.placeholder((M, N), name='B')
C = tvm.compute(
    (M, N),
    lambda x, y: A[x, y] + B[x, y],
    name='C')

s = tvm.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 32, 32)

print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")

# Turn the innermost loop of length 32 into vector operations.
s[C].vectorize(yi)

print(tvm.lower(s, [A, B, C], simple_mode=True))
--------------------------------------------------------------------------------
/compute_at.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), 'k')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
BF = s.rfactor(B, ki)

tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Move the computation of the partial sums BF under B's reduction loop.
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/set_store_predicate.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), 'k')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)

ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
tx = tvm.thread_axis("threadIdx.x")
# Bind the remaining reduction axis to threads; without this bind,
# tx.var would be left undefined in the lowered IR.
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Only thread 0 writes the final result back to B.
s[B].set_store_predicate(tx.var.equal(0))

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
/compute_root.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), 'k')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)

ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
BF = s.rfactor(B, ki)

tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])

print(tvm.lower(s, [A, B], simple_mode=True))
print("---------cutting line---------")

# Undo the compute_at above: move BF back to the root of the schedule.
s[BF].compute_root()

print(tvm.lower(s, [A, B], simple_mode=True))
--------------------------------------------------------------------------------
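Note: compute_at.py, set_store_predicate.py, and compute_root.py are all
fragments of the same cross-thread reduction pattern. Put together, the full
recipe looks like this (a sketch assuming the same TVM 0.x API; the split
factor 16 is an arbitrary choice):

import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), 'k')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)
# Split the reduction and factor out partial sums, one per ki value.
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
# Each thread computes one partial sum ...
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
# ... and only thread 0 stores the combined result.
s[B].set_store_predicate(tx.var.equal(0))

print(tvm.lower(s, [A, B], simple_mode=True))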
/create_group.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
k = tvm.reduce_axis((0, n), name='k')

A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')

D = tvm.compute((n, n), lambda i, j: A[i, j] + B[i, j], name='D')
E = tvm.compute((n, n), lambda i, j: D[i, j] + B[i, j], name='E')
F = tvm.compute((n,), lambda i: tvm.sum(E[i, k], axis=k), name='F')

s = tvm.create_schedule(F.op)

print(tvm.lower(s, [A, B, F], simple_mode=True))
print("---------cutting line---------")

# Group the stages from the inputs up to E, then attach the whole group
# under F's reduction loop with a single compute_at.
g = s.create_group(outputs=E, inputs=[A, B], include_inputs=True)
g.compute_at(s[F], F.op.reduce_axis[0])

print(tvm.lower(s, [A, B, F], simple_mode=True))
--------------------------------------------------------------------------------
/compute_inline.py:
--------------------------------------------------------------------------------
import tvm

n = 1024
k = 3
pad = 2
A = tvm.placeholder((n, n), name='A')
W = tvm.placeholder((k, k), name='W')
m = (n - k + 2 * pad) + 1
# Zero-pad A by `pad` on each side.
Apad = tvm.compute((n + 2 * pad, n + 2 * pad),
                   lambda yy, xx: tvm.if_then_else(
                       tvm.all(yy >= pad, yy < pad + n, xx >= pad, xx < pad + n),
                       A[yy - pad, xx - pad], tvm.const(0., "float32")),
                   name='Apad')

ry = tvm.reduce_axis((0, k), name='ry')
rx = tvm.reduce_axis((0, k), name='rx')

# 2D convolution of the padded input with a k x k kernel.
B = tvm.compute((m, m),
                lambda yy, xx:
                tvm.sum(Apad[yy + ry, xx + rx] * W[ry, rx],
                        axis=[ry, rx]),
                name='B')

s = tvm.create_schedule(B.op)

print(tvm.lower(s, [A, W, B], simple_mode=True))
print("---------cutting line---------")

# Inline the padding stage into B, removing the Apad buffer.
s[Apad].compute_inline()

print(tvm.lower(s, [A, W, B], simple_mode=True))
--------------------------------------------------------------------------------
/tensorize.py:
--------------------------------------------------------------------------------
import tvm

N, M, L = 1024, 512, 64
A = tvm.placeholder((N, L), name='A')
B = tvm.placeholder((M, L), name='B')
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k), name='C')
s = tvm.create_schedule(C.op)

def intrin_gemv(m, l):
    # Declare the computation the intrinsic replaces: c = b . a for an
    # (m, l) matrix b and a length-l vector a.
    a = tvm.placeholder((l,), name='a')
    b = tvm.placeholder((m, l), name='b')
    k = tvm.reduce_axis((0, l), name='k')
    c = tvm.compute((m,), lambda i: tvm.sum(a[k] * b[i, k], axis=k), name='c')
    Abuf = tvm.decl_buffer(a.shape, a.dtype, name='A', offset_factor=1, strides=[1])
    Bbuf = tvm.decl_buffer(b.shape, b.dtype, name='B', offset_factor=1, strides=[tvm.var("s1"), 1])
    Cbuf = tvm.decl_buffer(c.shape, c.dtype, name='C', offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        # Emit a call to the external gemv_update kernel in place of the
        # matched loop nest.
        ib = tvm.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        ib.emit(tvm.call_extern("int32", "gemv_update", cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"), m, l, bb.strides[0]))
        return ib.get()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Abuf, b: Bbuf, c: Cbuf})

factor = 16
x, y = C.op.axis
z, = C.op.reduce_axis
yo, yi = s[C].split(y, factor=factor)
s[C].reorder(x, yo, yi, z)

gemv = intrin_gemv(factor, L)

print(tvm.lower(s, [A, B, C], simple_mode=True))
print("---------cutting line---------")

# Replace the (yi, z) loop nest with the gemv intrinsic.
s[C].tensorize(yi, gemv)

print(tvm.lower(s, [A, B, C], simple_mode=True))
--------------------------------------------------------------------------------
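Note (a sketch, not one of the repo's scripts): all the examples above stop at
tvm.lower. To actually compile and run a schedule on the CPU, something like
the following can be appended; it reuses the vectorize.py schedule, and the
target "llvm" and function name "vecadd" are arbitrary choices, assuming the
same TVM 0.x API:

import tvm
import numpy as np

M, N = 1024, 1024
A = tvm.placeholder((M, N), name='A')
B = tvm.placeholder((M, N), name='B')
C = tvm.compute((M, N), lambda x, y: A[x, y] + B[x, y], name='C')

s = tvm.create_schedule(C.op)
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], 32, 32)
s[C].vectorize(yi)

# Compile for the host CPU and run against numpy for a correctness check.
func = tvm.build(s, [A, B, C], target='llvm', name='vecadd')
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(M, N)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(M, N)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((M, N), dtype=C.dtype), ctx)
func(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
print("ok")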