├── train_set.npy
├── boston_input.npy
├── train_labels.npy
├── boston_output.npy
├── boston_weights.npy
├── one_hot_label.npy
├── one_hot_train.npy
├── boston_output_raw.npy
├── boston_weights_raw.npy
├── old
│   ├── boston_input_raw.npy
│   ├── tf_test.py
│   ├── gen_mem.py
│   ├── relu_static.py
│   ├── relu_static_wrapped.py
│   ├── maxpool_naive.py
│   ├── norm_dynam.py
│   └── act_top.py
├── simplemult_hostmem.npy
├── simplemult_weights.npy
├── simple_nn.a
├── config.py
├── LICENSE.md
├── gen_one_hot.py
├── activate.py
├── simple_nn.py
├── checker.py
├── isa.py
├── decoder.py
├── simplemult.a
├── architecture.md
├── boston.a
├── boston.out
├── runtpu.py
├── sim.py
├── assembler.py
├── tpu.py
├── tf_nn.py
├── README.md
└── matrix.py

/train_set.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/train_set.npy
--------------------------------------------------------------------------------
/boston_input.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/boston_input.npy
--------------------------------------------------------------------------------
/train_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/train_labels.npy
--------------------------------------------------------------------------------
/boston_output.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/boston_output.npy
--------------------------------------------------------------------------------
/boston_weights.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/boston_weights.npy
--------------------------------------------------------------------------------
/one_hot_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/one_hot_label.npy
--------------------------------------------------------------------------------
/one_hot_train.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/one_hot_train.npy
--------------------------------------------------------------------------------
/boston_output_raw.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/boston_output_raw.npy
--------------------------------------------------------------------------------
/boston_weights_raw.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/boston_weights_raw.npy
--------------------------------------------------------------------------------
/old/boston_input_raw.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/old/boston_input_raw.npy
--------------------------------------------------------------------------------
/simplemult_hostmem.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/simplemult_hostmem.npy
--------------------------------------------------------------------------------
/simplemult_weights.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/simplemult_weights.npy
--------------------------------------------------------------------------------
/old/tf_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | inputs = tf.placeholder(np.float32, [2, 2])
5 | 
6 | m = tf.Variable(np.ones((2, 2))*32, dtype=tf.float32)
7 | 
8 | op = tf.matmul(inputs, m)
9 | 
10 | sess = tf.Session()
11 | init = tf.global_variables_initializer()
12 | sess.run(init)
13 | 
14 | input_val = [[1, 1], [2, 1]]
15 | 
16 | print('inputs: {}\nmatrix: {}'.format(input_val, sess.run(m)))
17 | print(sess.run(op, feed_dict={inputs: input_val}))
18 | 
--------------------------------------------------------------------------------
/simple_nn.a:
--------------------------------------------------------------------------------
1 | # Host mem: N x 8 input matrix
2 | # Weight mem: 8 x 8 weight matrix
3 | RHM 0, 0, 8 # read from host mem addr 0, to UB addr 0, for length N = 8
4 | RW 0 # read weights from dram addr 0 to FIFO
5 | MMC 0, 0, 8 # Do MM on UB addr 0, to accumulator addr 0, for length 8
6 | ACT.Q 0, 0, 8 # Do ACT sigmoid on accumulator addr 0, to UB addr 0, for length 8
7 | SYNC
8 | WHM 0, 0, 8 # write result from UB addr 0, to host mem addr 0, for length 8
9 | HLT
10 | 
--------------------------------------------------------------------------------
/old/gen_mem.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | 
4 | 
5 | args = None
6 | 
7 | def gen_mem(path, fill):
8 |     np.save(path, fill)
9 | 
10 | def parse_args():
11 |     global args
12 | 
13 |     parser = argparse.ArgumentParser()
14 | 
15 |     parser.add_argument('--path', action='store',
16 |                         help='path to source file.')
17 |     parser.add_argument('--shape', action='store', type=int, nargs='+',
18 |                         help = 'shape of matrix to generate.')
19 |     parser.add_argument('--debug', action='store_true',
20 |                         help='switch debug prints.')
21 |     args = parser.parse_args()
22 | 
23 | 
24 | if __name__ == '__main__':
25 |     parse_args()
26 |     mem = np.random.rand(*args.shape)
27 |     gen_mem(args.path, mem)
28 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Hardware configuration.
3 | ''' 4 | 5 | MATSIZE = 16 6 | 7 | HOST_ADDR_SIZE = 64 8 | UB_ADDR_SIZE = 12 9 | WEIGHT_DRAM_ADDR_SIZE = 40 10 | ACC_ADDR_SIZE = 16 11 | DWIDTH = 8 12 | INSTRUCTION_WIDTH = 14 * 8 13 | IMEM_ADDR_SIZE = 12 14 | 15 | # values = [host_addr_size, ub_addr_size, weight_dram_addr_size, acc_addr_size, mat_mul_size, data_width] 16 | # 17 | # def set_config(values): 18 | # keys = ['HOST_ADDR_SIZE', 'UB_ADDR_SIZE', 'WEIGHT_DRAM_ADDR_SIZE', 'ACC_ADDR_SIZE', 'MAT_MUL_SIZE', 'DATA_WIDTH'] 19 | # return dict(zip(keys, values)) 20 | # 21 | # config = { 22 | # 'HOST_ADDR_SIZE': host_addr_size, 23 | # 'UB_ADDR_SIZE': ub_addr_size, 24 | # 'WEIGHT_DRAM_ADDR_SIZE': weight_dram_addr_size, 25 | # 'ACC_ADDR_SIZE': acc_addr_size, 26 | # 'MAT_MUL_SIZE': mat_mul_size, 27 | # 'DATA_WIDTH': data_width, 28 | # } 29 | -------------------------------------------------------------------------------- /old/relu_static.py: -------------------------------------------------------------------------------- 1 | # Function: Relu and normalization 2 | # Comments: offset defined during design phase (not runtime) 3 | 4 | import pyrtl 5 | 6 | # relu and normalization 7 | def relu_nrml(din, offset=0): 8 | assert len(din) == 32 9 | assert offset <= 24 10 | dout = pyrtl.WireVector(32) 11 | with pyrtl.conditional_assignment: 12 | with din[-1] == 0: 13 | dout |= din 14 | with pyrtl.otherwise: 15 | dout |= 0 16 | return dout[24-offset:32-offset] 17 | 18 | # Test: collects only the 8 LSBs (after relu) 19 | relu_in = pyrtl.Register(bitwidth=32, name='din') 20 | relu_in.next <<= 300 21 | offset = 24 22 | dout = relu_nrml(relu_in, offset) 23 | relu_out = pyrtl.Register(bitwidth=8, name='dout') 24 | relu_out.next <<= dout 25 | 26 | # simulate the instantiated design for 15 cycles 27 | sim_trace = pyrtl.SimulationTrace() 28 | sim = pyrtl.Simulation(tracer=sim_trace) 29 | for cyle in range(35): 30 | sim.step({}) 31 | sim_trace.render_trace() 32 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Regents of the University of California All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | Neither the name of OpenTPU nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/gen_one_hot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from gen_mem import gen_mem
4 | 
5 | args = None
6 | 
7 | def gen_one_hot(lower=-5, upper=5, shape=(8, 8)):
8 |     #one_hot = np.random.randint(-5, 5, (8, 8), dtype=np.int8)
9 |     one_hot = np.random.randint(lower, upper, shape, dtype=np.int8)
10 |     # We either generate a square matrix for training or generate a vector for testing.
11 |     if shape[0] == shape[1]:
12 |         for i in range(shape[0]):
13 |             one_hot[i, i] = 64
14 |     else:
15 |         assert shape[1] == 1
16 |         for i in range(shape[0]):
17 |             one_hot[i, 0] = np.random.randint(lower, upper, dtype=np.int8)
18 |     return one_hot
19 | 
20 | def gen_nn(path, shape, lower=None, upper=None):
21 |     #nn = np.random.randint(lower, upper, shape, dtype=np.int8)
22 |     nn = gen_one_hot(lower, upper, shape)
23 |     print(nn)
24 |     gen_mem(path, nn)
25 | 
26 | def parse_args():
27 |     global args
28 | 
29 |     parser = argparse.ArgumentParser()
30 | 
31 |     parser.add_argument('--path', action='store',
32 |                         help='dest file.')
33 |     parser.add_argument('--shape', action='store', type=int, nargs='+',
34 |                         help = 'shape of matrix to generate.')
35 |     parser.add_argument('--debug', action='store_true',
36 |                         help='switch debug prints.')
37 |     parser.add_argument('--range', type=int, nargs=2,
38 |                         help='gen rand in [lower, upper)')
39 |     args = parser.parse_args()
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     parse_args()
44 |     gen_nn(args.path, args.shape, *args.range)
45 | 
--------------------------------------------------------------------------------
/old/relu_static_wrapped.py:
--------------------------------------------------------------------------------
1 | # Function: Relu and normalization.
Start and done signals included 2 | # Latency: 1cc 3 | # Comments: offset defined during design phase (not runtime) 4 | 5 | import pyrtl 6 | 7 | # relu and normalization 8 | def relu_nrml(din, offset): 9 | assert len(din) == 32 10 | assert offset <= 24 11 | dout = pyrtl.WireVector(32) 12 | dout_reg = pyrtl.Register(8) 13 | with pyrtl.conditional_assignment: 14 | with din[-1] == 0: 15 | dout |= din 16 | with pyrtl.otherwise: 17 | dout |= 0 18 | dout_reg.next <<= dout[24-offset:32-offset] 19 | return dout_reg 20 | 21 | def relu_top(din, start, offset=0): 22 | dout = [relu_nrml(din[i], offset) for i in range(len(din))] 23 | done = pyrtl.Register(1) 24 | done.next <<= start 25 | return done, dout 26 | 27 | # Test: collects only the 8 LSBs (after relu) 28 | relu_in = [] 29 | relu_in0 = pyrtl.Register(bitwidth=32, name='din0') 30 | relu_in0.next <<= 300 31 | relu_in.append(relu_in0) 32 | relu_in1 = pyrtl.Register(bitwidth=32, name='din1') 33 | relu_in1.next <<= 200 34 | relu_in.append(relu_in1) 35 | relu_in2 = pyrtl.Register(bitwidth=32, name='din2') 36 | relu_in2.next <<= 100 37 | relu_in.append(relu_in2) 38 | start = pyrtl.Register(bitwidth=1, name='start') 39 | start.next <<= 1 40 | offset = 24 41 | done, dout = relu_top(relu_in, start, offset) 42 | relu_out0 = pyrtl.Register(bitwidth=8, name='dout0') 43 | relu_out0.next <<= dout[0] 44 | relu_out1 = pyrtl.Register(bitwidth=8, name='dout1') 45 | relu_out1.next <<= dout[1] 46 | relu_out2 = pyrtl.Register(bitwidth=8, name='dout2') 47 | relu_out2.next <<= dout[2] 48 | relu_done = pyrtl.Register(bitwidth=1, name='done') 49 | relu_done.next <<= done 50 | 51 | # simulate the instantiated design for 15 cycles 52 | sim_trace = pyrtl.SimulationTrace() 53 | sim = pyrtl.Simulation(tracer=sim_trace) 54 | for cyle in range(35): 55 | sim.step({}) 56 | sim_trace.render_trace() 57 | -------------------------------------------------------------------------------- /activate.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | #import pyrtl 4 | from pyrtl import * 5 | 6 | def relu_vector(vec, offset): 7 | assert offset <= 24 8 | return concat_list([ select(d[-1], falsecase=d, truecase=Const(0, len(d)))[24-offset:32-offset] for d in vec ]) 9 | 10 | def sigmoid(x): 11 | rb = RomBlock(bitwidth=8, addrwidth=3, asynchronous=True, romdata={0: 128, 1: 187, 2: 225, 3: 243, 4: 251, 5: 254, 6: 255, 7: 255, 8: 255}) 12 | x_gt_7 = reduce(lambda x, y: x|y, x[3:]) # OR of bits 3 and up 13 | return select(x_gt_7, falsecase=rb[x[:3]], truecase=Const(255, bitwidth=8)) 14 | 15 | def sigmoid_vector(vec): 16 | return concat_list([ sigmoid(x) for x in vec ]) 17 | 18 | 19 | def act_top(start, start_addr, dest_addr, nvecs, func, accum_out): 20 | 21 | # func: 0 - nothing 22 | # 1 - ReLU 23 | # 2 - sigmoid 24 | 25 | busy = Register(1) 26 | accum_addr = Register(len(start_addr)) 27 | ub_waddr = Register(len(dest_addr)) 28 | N = Register(len(nvecs)) 29 | act_func = Register(len(func)) 30 | 31 | rtl_assert(~(start & busy), Exception("Dispatching new activate instruction while previous instruction is still running.")) 32 | 33 | with conditional_assignment: 34 | with start: # new instruction being dispatched 35 | accum_addr.next |= start_addr 36 | ub_waddr.next |= dest_addr 37 | N.next |= nvecs 38 | act_func.next |= func 39 | busy.next |= 1 40 | with busy: # Do activate on another vector this cycle 41 | accum_addr.next |= accum_addr + 1 42 | ub_waddr.next |= ub_waddr + 1 43 | N.next |= N - 1 44 | with N == 1: # this was 
the last vector
45 |                 busy.next |= 0
46 | 
47 |     invals = concat_list([ x[:8] for x in accum_out ])
48 |     act_out = mux(act_func, invals, relu_vector(accum_out, 24), sigmoid_vector(accum_out), invals)
49 |     #act_out = relu_vector(accum_out, 24)
50 |     ub_we = busy
51 | 
52 |     return accum_addr, ub_waddr, act_out, ub_we, busy
53 | 
--------------------------------------------------------------------------------
/old/maxpool_naive.py:
--------------------------------------------------------------------------------
1 | # We're not sorry for the TERRIBLE code!
2 | # We now know how to create lists containing Registers/WireVectors :)
3 | # The pipeline was also implemented traditionally.
4 | # Attention: parametric code!
5 | # Function: maxpooling
6 | # Design: fully-pipelined. Latency: ceil(log(n))
7 | # v2 it is, v3 is coming and it looks much better! We promise :)
8 | 
9 | import pyrtl
10 | 
11 | def bitCmp(din0, din1):
12 |     dout = pyrtl.WireVector(32)
13 |     dout_reg = pyrtl.Register(32)
14 |     with pyrtl.conditional_assignment:
15 |         with din0 >= din1:
16 |             dout |= din0
17 |         with pyrtl.otherwise:
18 |             dout |= din1
19 |     dout_reg.next <<= dout
20 |     return dout_reg
21 | 
22 | def maxpool(din):
23 |     if (len(din)==1):
24 |         return din[0]
25 |     elif (len(din)==2):
26 |         return bitCmp(din[0], din[1])
27 |     else:
28 |         left = maxpool(din[:len(din)//2])
29 |         right = maxpool(din[len(din)//2:])
30 |         return bitCmp(left, right)
31 | 
32 | 
33 | 
34 | # instantiate relu and set test inputs
35 | '''din = []
36 | din0 = pyrtl.Register(bitwidth=32, name='din0')
37 | din0.next <<= 10
38 | din1 = pyrtl.Register(bitwidth=32, name='din1')
39 | din1.next <<= 12
40 | din2 = pyrtl.Register(bitwidth=32, name='din2')
41 | din2.next <<= 10
42 | din3 = pyrtl.Register(bitwidth=32, name='din3')
43 | din3.next <<= 12
44 | din4 = pyrtl.Register(bitwidth=32, name='din4')
45 | din4.next <<= 127
46 | din5 = pyrtl.Register(bitwidth=32, name='din5')
47 | din5.next <<= 12
48 | din6 = pyrtl.Register(bitwidth=32, name='din6')
49 | din6.next <<= 10
50 | din7 = pyrtl.Register(bitwidth=32, name='din7')
51 | din7.next <<= 12
52 | din8 = pyrtl.Register(bitwidth=32, name='din8')
53 | din8.next <<= 10
54 | 
55 | 
56 | din.append(din0)
57 | din.append(din1)
58 | din.append(din2)
59 | din.append(din3)
60 | din.append(din4)
61 | din.append(din5)
62 | din.append(din6)
63 | din.append(din7)
64 | din.append(din8)
65 | 
66 | 
67 | #for i in range(9):
68 |     #din.append(pyrtl.Register(bitwidth=32, name='dins'))
69 | 
70 | 
71 | dout = maxpool(din)
72 | cmpr_out = pyrtl.Register(bitwidth=32, name='cmpr_out')
73 | cmpr_out.next <<= dout
74 | 
75 | # simulate the instantiated design for 15 cycles
76 | sim_trace = pyrtl.SimulationTrace()
77 | sim = pyrtl.Simulation(tracer=sim_trace)
78 | for cyle in range(35):
79 |     sim.step({})
80 | sim_trace.render_trace() '''
81 | 
--------------------------------------------------------------------------------
/simple_nn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pickle
4 | from gen_mem import gen_mem
5 | 
6 | # train code from http://iamtrask.github.io/2015/07/12/basic-python-network/
7 | 
8 | # sigmoid function
9 | def sigmoid(x,deriv=False):
10 |     if(deriv==True):
11 |         return x*(1-x)
12 |     return 1/(1+np.exp(-x))
13 | 
14 | def relu(X):
15 |     map_func = np.vectorize(lambda x: max(x, 0))
16 |     return map_func(X)
17 | 
18 | def train(path, lpath):
19 |     X = np.load(path)
20 |     y = np.load(lpath)
21 | 
22 |     print('X: {}'.format(X))
23 |     print('y: {}'.format(y))
24 |     # seed random numbers to make calculation
25 |     # deterministic (just a good practice)
26 |     np.random.seed(1)
27 | 
28 |     # initialize weights randomly with mean 0
29 |     syn0 = np.random.random((X.shape[-1], y.shape[-1]))
30 | 
31 |     nonlin = sigmoid
32 | 
33 |     for iter in range(10000):
34 | 
35 |         # forward propagation
36 |         l0 = X
37 |         l1 = nonlin(np.dot(l0,syn0))
38 | 
39 |         # how much did we miss?
40 |         l1_error = y - l1
41 | 
42 |         # multiply how much we missed by the
43 |         # slope of the sigmoid at the values in l1
44 |         l1_delta = l1_error * nonlin(l1)
45 | 
46 |         # update weights
47 |         syn0 += np.dot(l0.T,l1_delta)
48 |     print('syn0: {}'.format(syn0))
49 |     print('l1: {}'.format(l1))
50 |     with open('simple_nn_gt', 'wb') as f:
51 |         pickle.dump((l1, syn0), f)
52 |         f.close()
53 | 
54 |     syn0 = float2byte(syn0)
55 |     gen_mem('simple_nn_weight_dram', syn0)
56 | 
57 | args = None
58 | 
59 | def float2byte(mat):
60 |     pos_mat = np.vectorize(lambda x: np.abs(x))(mat)
61 |     max_w = np.amax(pos_mat)
62 |     mat = np.vectorize(lambda x: (127 * x/max_w).astype(np.int8))(mat)
63 |     return mat.reshape(1, 8, 8)
64 | 
65 | def parse_args():
66 |     global args
67 | 
68 |     parser = argparse.ArgumentParser()
69 | 
70 |     parser.add_argument('--path', action='store',
71 |                         help='path to dataset file.')
72 |     parser.add_argument('--label', action='store',
73 |                         help='path to the label file.')
74 |     parser.add_argument('--debug', action='store_true',
75 |                         help='switch debug prints.')
76 |     args = parser.parse_args()
77 | 
78 | 
79 | if __name__ == '__main__':
80 |     parse_args()
81 |     train(args.path, args.label)
82 | 
--------------------------------------------------------------------------------
/checker.py:
--------------------------------------------------------------------------------
1 | """ The checker assumes results are always written to hostmem
2 | consecutively starting from location 0.
3 | 
4 | If the result is shorter than HW width: (X is don't care)
5 | 
6 | --------HW WIDTH---------
7 | D D D D D D D X X X X X X
8 | D D D D D D D X X X X X X
9 | 
10 | else:
11 | 
12 | --------HW WIDTH---------
13 | D D D D D D D D D D D D D
14 | D D X X X X X X X X X X X
15 | D D D D D D D D D D D D D
16 | D D X X X X X X X X X X X
17 | 
18 | """
19 | 
20 | import argparse
21 | import numpy as np
22 | 
23 | args = None
24 | 
25 | def equal(a1, a2):
26 |     assert a1.shape == a2.shape, 'result file shape mismatch.'
27 |     if a1.dtype == np.int8:
28 |         a1 = a1.astype(np.uint8)
29 |     if a2.dtype == np.int8:
30 |         a2 = a2.astype(np.uint8)
31 |     for x, y in np.nditer([a1, a2]):
32 |         assert x == y, 'result value mismatch.'
33 | 
34 | def check(p1, p2, width=None):
35 |     r1 = np.load(p1)
36 |     r2 = np.load(p2)
37 |     if not width:
38 |         # Checking sim8 against hw8.
39 |         equal(r1, r2)
40 |     else:
41 |         # Checking gt32 against sim32.
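        # Worked example (HW width 4, logical row width 6): sim32 rows arrive as
        #   [D0 D1 D2 D3], [D4 D5 X X], ...
        # np.concatenate((r2[::2], r2[1::2]), axis=1) stitches each even/odd row
        # pair back into [D0 D1 D2 D3 D4 D5 X X] before trimming to r_width columns.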
42 |         #assert width == r2.shape[1]
43 |         r_width = r1.shape[1]
44 |         if r_width <= width:
45 |             r2 = r2[:, :r_width]
46 |             equal(r1, r2)
47 |         else:
48 |             r2 = np.concatenate((r2[::2], r2[1::2]), axis=1)
49 |             r2 = r2[:, :r_width]
50 |             equal(r1, r2)
51 | 
52 | 
53 | def parse_args():
54 |     global args
55 | 
56 |     parser = argparse.ArgumentParser()
57 | 
58 |     parser.add_argument('--width', action='store', type=int, default=16,
59 |                         help='HW WIDTH.')
60 |     parser.add_argument('--gt32', action='store', default='gt32.npy',
61 |                         help='path to f32 ground truth result.')
62 |     parser.add_argument('--sim32', action='store', default='sim32.npy',
63 |                         help='path to f32 simulator result.')
64 |     parser.add_argument('--sim8', action='store', default='sim8.npy',
65 |                         help='path to i8 simulator result.')
66 |     parser.add_argument('--hw8', action='store', default='hw8.npy',
67 |                         help='path to i8 hardware result.')
68 |     args = parser.parse_args()
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     parse_args()
73 |     print('HW width set to %d.' % args.width)
74 |     check(args.gt32, args.sim32, args.width)
75 |     print('32-bit passed.')
76 |     check(args.sim8, args.hw8)
77 |     print('8-bit passed.')
78 | 
--------------------------------------------------------------------------------
/isa.py:
--------------------------------------------------------------------------------
1 | """
2 | The assembly format for most instructions (RHM, WHM, MMC, ACT) is
3 |     INSTRUCTION SRC, DEST, LENGTH
4 | For RW, it is
5 |     RW SRC
6 | for HLT, it is
7 |     HLT
8 | 
9 | === Binary Encoding ====
10 | 
11 | | opcode | flags | length |  addr  | ub addr |
12 | |    1   |   1   |    1   |    8   |    3    |
13 | |13    13|12   12|11    11|10     3|2       0|
14 | 
15 | All numbers above are expressed in BYTES.
16 | The 'addr' field is used for host memory address (for RHM and WHM),
17 | weight DRAM address (for RW), and accumulator address (for MMC and ACT).
18 | For the latter two, the field is larger than necessary, and only the lower bits are used.
19 | 'ub addr' is always a Unified Buffer address.
20 | 'length' is the number of vectors to read/write/process.
21 | 
22 | FLAG field is r|r|r|f|f|o|c|s (MSB to LSB): r stands for reserved bit, s for switch bit,
23 | c for convolve bit, f for the two function select bits, and o for override bit.
24 | 
25 | """
26 | 
27 | ENDIANNESS = 'big'
28 | #ENDIANNESS = 'little'
29 | 
30 | INSTRUCTION_WIDTH_BYTES = 14
31 | 
32 | HOST_ADDR_SIZE = 8  # 64-bit addressing
33 | DRAM_ADDR_SIZE = 5  # 33-bit addressing (TPU has 8 GB on-chip DRAM)
34 | UB_ADDR_SIZE = 3  # 17-bit addressing for Unified Buffer
35 | ACC_ADDR_SIZE = 2  # 12-bit addressing for accumulator
36 | OP_SIZE = 1
37 | FLAGS_SIZE = 1
38 | ADDR_SIZE = 8
39 | UB_ADDR_SIZE = 3
40 | LEN_SIZE = 1
41 | 
42 | UBADDR_START = 0
43 | UBADDR_END = 3
44 | ADDR_START = 3
45 | ADDR_END = 11
46 | LEN_START = 11
47 | LEN_END = 12
48 | FLAGS_START = 12
49 | FLAGS_END = 13
50 | OP_START = 13
51 | OP_END = 14
52 | 
53 | # Map text opcode to instruction decomposition info.
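# For example, OPCODE2BIN['RHM'] == (0x6, HOST_ADDR_SIZE, UB_ADDR_SIZE, 1):
# opcode 0x6 with an 8-byte host-memory source address, a 3-byte UB destination
# address, and a 1-byte vector count.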
54 | # Str -> (opcode_value, src_len, dst_len, 3rd_len) 55 | OPCODE2BIN = { 56 | 'NOP': (0x0, 0, 0, 0), 57 | 'WHM': (0x1, UB_ADDR_SIZE, HOST_ADDR_SIZE, 1), 58 | 'RW': (0x2, DRAM_ADDR_SIZE, 0, 1), 59 | 'MMC': (0x3, UB_ADDR_SIZE, ACC_ADDR_SIZE, 1), 60 | 'ACT': (0x4, ACC_ADDR_SIZE, UB_ADDR_SIZE, 1), 61 | 'SYNC': (0x5, 0, 0, 0), 62 | 'RHM': (0x6, HOST_ADDR_SIZE, UB_ADDR_SIZE, 1), 63 | 'HLT': (0x7, 0, 0, 0), 64 | } 65 | 66 | BIN2OPCODE = {v[0]: k for k, v in OPCODE2BIN.items()} 67 | 68 | SWITCH_MASK = 0b00000001 69 | CONV_MASK = 0b00000010 70 | OVERWRITE_MASK = 0b00000100 # whether MMC should overwrite accumulator value or add to it 71 | ACT_FUNC_MASK = 0b00011000 # 0 for nothing; 1 for ReLU; 2 for sigmoid 72 | FUNC_RELU_MASK = 0b00001000 73 | FUNC_SIGMOID_MASK = 0b00010000 74 | 75 | SWITCH_BIT = 0 76 | OVERWRITE_BIT = 2 77 | ACT_FUNC_BITS = slice(3,5) 78 | FUNC_RELU_BIT = 3 79 | FUNC_SIGMOID_BIT = 4 80 | 81 | -------------------------------------------------------------------------------- /old/norm_dynam.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Function: Normalization 3 | Design: fully-pipelined (it may be better to replace some Regs with WireVectors 4 | Comments: the offset can change at runtime. To avoid the use of big MUXes we have used a barrel shifter. PyRTL rtllib contains a barrel_shifter, but the interpreter was returing us an error when trying to call it from the lib, so we copied the whole function here. 5 | ''' 6 | 7 | import pyrtl 8 | import math 9 | 10 | def barrel_shifter(shift_in, bit_in, direction, shift_dist, wrap_around=0): 11 | """ 12 | Create a barrel shifter that operates on data based on the wire width 13 | :param shift_in: the input wire 14 | :param bit_in: the 1-bit wire giving the value to shift in 15 | :param direction: a one bit WireVector representing shift direction 16 | (0 = shift down, 1 = shift up) 17 | :param shift_dist: WireVector representing offset to shift 18 | :param wrap_around: ****currently not implemented**** 19 | :return: shifted WireVector 20 | """ 21 | # Implement with logN stages pyrtl.muxing between shifted and un-shifted values 22 | 23 | val = shift_in 24 | append_val = bit_in 25 | log_length = int(math.log(len(shift_in)-1, 2)) # note the one offset 26 | 27 | if len(shift_dist) > log_length: 28 | print('Warning: for barrel shifter, the shift distance wirevector ' 29 | 'has bits that are not used in the barrel shifter') 30 | 31 | for i in range(min(len(shift_dist), log_length)): 32 | shift_amt = pow(2, i) # stages shift 1,2,4,8,... 33 | newval = pyrtl.select(direction, truecase=val[:-shift_amt], falsecase=val[shift_amt:]) 34 | newval = pyrtl.select(direction, truecase=pyrtl.concat(newval, append_val), 35 | falsecase=pyrtl.concat(append_val, newval)) # Build shifted value 36 | # pyrtl.mux shifted vs. unshifted by using i-th bit of shift amount signal 37 | val = pyrtl.select(shift_dist[i], truecase=newval, falsecase=val) 38 | append_val = pyrtl.concat(append_val, bit_in) 39 | 40 | return val 41 | 42 | # main normalization module 43 | def nrml(din, offset=24): 44 | zero = pyrtl.Const(0,1) 45 | one = pyrtl.Const(1,1) 46 | temp = pyrtl.Register(32, name='temp') 47 | temp.next <<= barrel_shifter(din, zero, zero, offset) 48 | dout = pyrtl.Register(8, name='dout') 49 | dout.next <<= temp[:8] 50 | return dout 51 | 52 | # test 53 | din = pyrtl.Register(bitwidth=32, name='din') 54 | din.next <<= 300 55 | offset = pyrtl. 
Register(bitwidth = 32, name= 'offset') 56 | offset.next <<= 5 57 | test_out = nrml(din,offset) 58 | 59 | # simulate the instantiated design for 15 cycles 60 | sim_trace = pyrtl.SimulationTrace() 61 | sim = pyrtl.Simulation(tracer=sim_trace) 62 | for cyle in range(35): 63 | sim.step({}) 64 | sim_trace.render_trace() 65 | 66 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | from pyrtl import * 2 | import config 3 | import isa 4 | 5 | DATASIZE = config.DWIDTH 6 | MATSIZE = config.MATSIZE 7 | ACCSIZE = config.ACC_ADDR_SIZE 8 | 9 | def decode(instruction): 10 | """ 11 | :param instruction: instruction + optional operands + flags 12 | """ 13 | 14 | accum_raddr = WireVector(ACCSIZE) 15 | accum_waddr = WireVector(ACCSIZE) 16 | accum_overwrite = WireVector(1) 17 | switch_weights = WireVector(1) 18 | weights_raddr = WireVector(config.WEIGHT_DRAM_ADDR_SIZE) # read address for weights DRAM 19 | weights_read = WireVector(1) # raised high to perform DRAM read 20 | 21 | ub_addr = WireVector(24) # goes to FSM 22 | ub_raddr = WireVector(config.UB_ADDR_SIZE) # goes to UB read addr port 23 | ub_waddr = WireVector(config.UB_ADDR_SIZE) 24 | 25 | whm_length = WireVector(8) 26 | rhm_length = WireVector(8) 27 | mmc_length = WireVector(16) 28 | act_length = WireVector(8) 29 | act_type = WireVector(2) 30 | 31 | rhm_addr = WireVector(config.HOST_ADDR_SIZE) 32 | whm_addr = WireVector(config.HOST_ADDR_SIZE) 33 | 34 | dispatch_mm = WireVector(1) 35 | dispatch_act = WireVector(1) 36 | dispatch_rhm = WireVector(1) 37 | dispatch_whm = WireVector(1) 38 | dispatch_halt = WireVector(1) 39 | 40 | # parse instruction 41 | op = instruction[ isa.OP_START*8 : isa.OP_END*8 ] 42 | #probe(op, "op") 43 | iflags = instruction[ isa.FLAGS_START*8 : isa.FLAGS_END*8 ] 44 | #probe(iflags, "flags") 45 | #probe(accum_overwrite, "decode_overwrite") 46 | ilength = instruction[ isa.LEN_START*8 : isa.LEN_END*8 ] 47 | memaddr = instruction[ isa.ADDR_START*8 : isa.ADDR_END*8 ] 48 | #probe(memaddr, "addr") 49 | ubaddr = instruction[ isa.UBADDR_START*8 : isa.UBADDR_END*8 ] 50 | #probe(ubaddr, "ubaddr") 51 | 52 | with conditional_assignment: 53 | with op == isa.OPCODE2BIN['NOP'][0]: 54 | pass 55 | with op == isa.OPCODE2BIN['WHM'][0]: 56 | dispatch_whm |= 1 57 | ub_raddr |= ubaddr 58 | whm_addr |= memaddr 59 | whm_length |= ilength 60 | with op == isa.OPCODE2BIN['RW'][0]: 61 | weights_raddr |= memaddr 62 | weights_read |= 1 63 | with op == isa.OPCODE2BIN['MMC'][0]: 64 | dispatch_mm |= 1 65 | ub_addr |= ubaddr 66 | accum_waddr |= memaddr 67 | mmc_length |= ilength 68 | accum_overwrite |= iflags[isa.OVERWRITE_BIT] 69 | switch_weights |= iflags[isa.SWITCH_BIT] 70 | # TODO: MMC may deal with convolution, set/clear that flag 71 | with op == isa.OPCODE2BIN['ACT'][0]: 72 | dispatch_act |= 1 73 | accum_raddr |= memaddr 74 | ub_waddr |= ubaddr 75 | act_length |= ilength 76 | act_type |= iflags[isa.ACT_FUNC_BITS] 77 | #probe(act_length, "act_length") 78 | #probe(act_type, "act_type") 79 | # TODO: ACT takes function select bits 80 | with op == isa.OPCODE2BIN['SYNC'][0]: 81 | pass 82 | with op == isa.OPCODE2BIN['RHM'][0]: 83 | dispatch_rhm |= 1 84 | rhm_addr |= memaddr 85 | ub_raddr |= ubaddr 86 | rhm_length |= ilength 87 | with op == isa.OPCODE2BIN['HLT'][0]: 88 | dispatch_halt |= 1 89 | 90 | #with otherwise: 91 | # print("otherwise") 92 | 93 | return dispatch_mm, dispatch_act, dispatch_rhm, dispatch_whm, dispatch_halt, ub_addr, 
ub_raddr, ub_waddr, rhm_addr, whm_addr, rhm_length, whm_length, mmc_length, act_length, act_type, accum_raddr, accum_waddr, accum_overwrite, switch_weights, weights_raddr, weights_read 94 | -------------------------------------------------------------------------------- /old/act_top.py: -------------------------------------------------------------------------------- 1 | # Function: Relu and normalization. 2 | # It reads from Accum Buffer and writes to the Unified Buffer. 3 | # Comments: offset defined during design phase (not runtime). 4 | # Code for "dynamic" normalization available as well. 5 | # Simulation seems to work well. Reordering may be REQUIRED though. 6 | 7 | #import pyrtl 8 | from pyrtl import * 9 | 10 | # relu and normalization 11 | def relu_elem(din, offset): 12 | assert offset <= 24 13 | dout = pyrtl.WireVector(32) 14 | dout_reg = pyrtl.Register(8) 15 | with pyrtl.conditional_assignment: 16 | with din[-1] == 0: 17 | dout |= din 18 | with pyrtl.otherwise: 19 | dout |= 0 20 | dout_reg.next <<= dout[24-offset:32-offset] 21 | return dout_reg 22 | 23 | # Latency: 1cc 24 | def relu_vector(din, offset=0): 25 | dout = [relu_elem(din[i], offset) for i in range(len(din))] 26 | return dout 27 | 28 | # Latency: N+1 cc 29 | def act_top(rd_addr, N, wr_addr): 30 | cntr = pyrtl.Register(2, name='counter') 31 | cntr_dl1 = pyrtl.Register(2) 32 | cntr_dl2 = pyrtl.Register(2) 33 | din_addr = pyrtl.Register(2, name='din_addr') 34 | dout_addr = pyrtl.Register(2, name='dout_addr') 35 | dout = pyrtl.Register(4*8, name='dout') 36 | din = [] 37 | din_addr.next |= rd_addr + cntr 38 | with pyrtl.conditional_assignment: 39 | with cntr < N-1: 40 | cntr.next |= cntr + 1 41 | with pyrtl.otherwise: 42 | cntr.next |= 0 43 | for i in range(4): 44 | din.append(mem_acum[din_addr][i*32:(i+1)*32-1]) 45 | # latency from read to write: 2cc 46 | cntr_dl1.next <<= cntr 47 | cntr_dl2.next <<=cntr_dl1 48 | dout_addr.next <<= wr_addr + cntr_dl2 49 | relu_out = relu_vector(din, 24) 50 | dout.next <<= pyrtl.concat_list(relu_out) 51 | #mem_ub[dout_addr] <<= pyrtl.concat_list(relu_out) 52 | return 1 53 | 54 | def relu_vector2(vec, offset): 55 | assert offset <= 24 56 | return [ select(d[-1], falsecase=d, truecase=Const(0))[24-offset:32-offset] for d in vec ] 57 | 58 | def act_top2(start, start_addr, dest_addr, nvecs, accum_out): 59 | 60 | busy = Register(1) 61 | accum_addr = Register(len(start_addr)) 62 | ub_waddr = Register(len(dest_addr)) 63 | N = Register(len(nvecs)) 64 | 65 | rtl_assert(~(start & busy), Exception("Dispatching new activate instruction while previous instruction is still running.")) 66 | 67 | with conditional_assignment: 68 | with start: # new instruction being dispatched 69 | accum_addr.next |= start_addr 70 | ub_waddr.next |= dest_addr 71 | N.next |= nvecs 72 | busy.next |= 1 73 | with busy: # Do activate on another vector this cycle 74 | accum_addr.next |= accum_addr + 1 75 | ub_waddr.next |= ub_waddr + 1 76 | N.next |= N - 1 77 | with N == 1: # this was the last vector 78 | busy.next |= 0 79 | 80 | act_out_list = relu_vector2(accum_out, 24) 81 | act_out = concat_list(act_out_list) 82 | ub_we = busy 83 | 84 | return accum_addr, ub_waddr, act_out, ub_we, busy 85 | 86 | def testact(): 87 | # Test 88 | mem_acum = pyrtl.MemBlock(bitwidth=4*32,addrwidth=2) 89 | mem_acum[0] <<= pyrtl.Const(0x00000001000000020000000300000004) 90 | mem_acum[1] <<= pyrtl.Const(0x00000005000000060000000700000008) 91 | mem_acum[2] <<= pyrtl.Const(0x000000090000000A0000000B0000000C) 92 | mem_acum[3] <<= 
pyrtl.Const(0x0000000D0000000E0000000F00000010) 93 | mem_ub = pyrtl.MemBlock(bitwidth=4*8,addrwidth=2) 94 | 95 | act_top(0,4,0) 96 | 97 | # simulate the instantiated design for 15 cycles 98 | sim_trace = pyrtl.SimulationTrace() 99 | sim = pyrtl.Simulation(tracer=sim_trace) 100 | for cyle in range(15): 101 | sim.step({}) 102 | sim_trace.render_trace() 103 | -------------------------------------------------------------------------------- /simplemult.a: -------------------------------------------------------------------------------- 1 | # Simple matrix multiplication. 2 | RHM 0, 0, 8 # Read from Host at addr 0, to UB addr 0 for 1 256B vector 3 | RW 0 4 | NOP 5 | NOP 6 | NOP 7 | NOP 8 | NOP 9 | NOP 10 | NOP 11 | NOP 12 | NOP 13 | NOP 14 | NOP 15 | NOP 16 | NOP 17 | NOP 18 | NOP 19 | NOP 20 | NOP 21 | NOP 22 | NOP 23 | NOP 24 | NOP 25 | NOP 26 | NOP 27 | NOP 28 | NOP 29 | NOP 30 | NOP 31 | NOP 32 | NOP 33 | NOP 34 | NOP 35 | NOP 36 | NOP 37 | NOP 38 | NOP 39 | NOP 40 | NOP 41 | NOP 42 | NOP 43 | NOP 44 | NOP 45 | NOP 46 | NOP 47 | NOP 48 | NOP 49 | NOP 50 | NOP 51 | NOP 52 | NOP 53 | NOP 54 | NOP 55 | NOP 56 | NOP 57 | NOP 58 | NOP 59 | NOP 60 | NOP 61 | NOP 62 | NOP 63 | NOP 64 | NOP 65 | NOP 66 | NOP 67 | NOP 68 | NOP 69 | NOP 70 | NOP 71 | NOP 72 | NOP 73 | NOP 74 | NOP 75 | NOP 76 | NOP 77 | NOP 78 | NOP 79 | NOP 80 | NOP 81 | NOP 82 | NOP 83 | NOP 84 | NOP 85 | NOP 86 | NOP 87 | NOP 88 | NOP 89 | NOP 90 | NOP 91 | NOP 92 | NOP 93 | NOP 94 | NOP 95 | NOP 96 | NOP 97 | NOP 98 | NOP 99 | NOP 100 | NOP 101 | NOP 102 | NOP 103 | NOP 104 | MMC.S 0, 0, 8 105 | NOP 106 | NOP 107 | NOP 108 | NOP 109 | NOP 110 | NOP 111 | NOP 112 | NOP 113 | NOP 114 | NOP 115 | NOP 116 | NOP 117 | NOP 118 | NOP 119 | NOP 120 | NOP 121 | NOP 122 | NOP 123 | NOP 124 | NOP 125 | NOP 126 | NOP 127 | NOP 128 | NOP 129 | NOP 130 | NOP 131 | NOP 132 | NOP 133 | NOP 134 | NOP 135 | NOP 136 | NOP 137 | NOP 138 | NOP 139 | NOP 140 | NOP 141 | NOP 142 | NOP 143 | NOP 144 | NOP 145 | NOP 146 | NOP 147 | NOP 148 | NOP 149 | NOP 150 | NOP 151 | NOP 152 | NOP 153 | NOP 154 | NOP 155 | NOP 156 | NOP 157 | NOP 158 | NOP 159 | NOP 160 | NOP 161 | NOP 162 | NOP 163 | NOP 164 | NOP 165 | NOP 166 | NOP 167 | NOP 168 | NOP 169 | NOP 170 | NOP 171 | NOP 172 | NOP 173 | NOP 174 | NOP 175 | NOP 176 | NOP 177 | NOP 178 | NOP 179 | NOP 180 | NOP 181 | NOP 182 | NOP 183 | NOP 184 | NOP 185 | NOP 186 | NOP 187 | NOP 188 | NOP 189 | NOP 190 | NOP 191 | NOP 192 | NOP 193 | NOP 194 | NOP 195 | NOP 196 | NOP 197 | NOP 198 | NOP 199 | NOP 200 | NOP 201 | NOP 202 | NOP 203 | NOP 204 | NOP 205 | NOP 206 | NOP 207 | NOP 208 | NOP 209 | NOP 210 | NOP 211 | NOP 212 | NOP 213 | NOP 214 | NOP 215 | NOP 216 | NOP 217 | NOP 218 | NOP 219 | NOP 220 | NOP 221 | NOP 222 | NOP 223 | NOP 224 | NOP 225 | NOP 226 | NOP 227 | NOP 228 | NOP 229 | NOP 230 | NOP 231 | NOP 232 | NOP 233 | NOP 234 | NOP 235 | NOP 236 | NOP 237 | NOP 238 | NOP 239 | NOP 240 | NOP 241 | NOP 242 | NOP 243 | NOP 244 | NOP 245 | NOP 246 | NOP 247 | NOP 248 | NOP 249 | NOP 250 | NOP 251 | NOP 252 | NOP 253 | NOP 254 | NOP 255 | NOP 256 | NOP 257 | NOP 258 | NOP 259 | NOP 260 | NOP 261 | NOP 262 | NOP 263 | NOP 264 | NOP 265 | NOP 266 | NOP 267 | NOP 268 | NOP 269 | NOP 270 | NOP 271 | NOP 272 | NOP 273 | NOP 274 | NOP 275 | NOP 276 | NOP 277 | NOP 278 | NOP 279 | NOP 280 | NOP 281 | NOP 282 | NOP 283 | NOP 284 | NOP 285 | NOP 286 | NOP 287 | NOP 288 | NOP 289 | NOP 290 | NOP 291 | NOP 292 | NOP 293 | NOP 294 | NOP 295 | NOP 296 | NOP 297 | NOP 298 | NOP 299 | NOP 300 | NOP 301 | NOP 302 | NOP 
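# NOP padding: OpenTPU has no dynamic scheduling, so the program itself covers the MMC latency (L+2N cycles for an NxN array, per architecture.md) before the ACT below may issue.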
303 | NOP 304 | NOP 305 | ACT.R 0, 0, 8 306 | NOP 307 | NOP 308 | NOP 309 | NOP 310 | NOP 311 | NOP 312 | NOP 313 | NOP 314 | NOP 315 | NOP 316 | NOP 317 | NOP 318 | NOP 319 | NOP 320 | NOP 321 | NOP 322 | NOP 323 | NOP 324 | NOP 325 | NOP 326 | NOP 327 | NOP 328 | NOP 329 | NOP 330 | NOP 331 | NOP 332 | NOP 333 | NOP 334 | NOP 335 | NOP 336 | NOP 337 | NOP 338 | NOP 339 | NOP 340 | NOP 341 | NOP 342 | NOP 343 | NOP 344 | NOP 345 | NOP 346 | NOP 347 | NOP 348 | NOP 349 | NOP 350 | NOP 351 | NOP 352 | NOP 353 | NOP 354 | NOP 355 | NOP 356 | NOP 357 | NOP 358 | NOP 359 | NOP 360 | NOP 361 | NOP 362 | NOP 363 | NOP 364 | NOP 365 | NOP 366 | NOP 367 | NOP 368 | NOP 369 | NOP 370 | NOP 371 | NOP 372 | NOP 373 | NOP 374 | NOP 375 | NOP 376 | NOP 377 | NOP 378 | NOP 379 | NOP 380 | NOP 381 | NOP 382 | NOP 383 | NOP 384 | NOP 385 | NOP 386 | NOP 387 | NOP 388 | NOP 389 | NOP 390 | NOP 391 | NOP 392 | NOP 393 | NOP 394 | NOP 395 | NOP 396 | NOP 397 | NOP 398 | NOP 399 | NOP 400 | NOP 401 | NOP 402 | NOP 403 | NOP 404 | NOP 405 | NOP 406 | WHM 0, 0, 8 407 | NOP 408 | NOP 409 | NOP 410 | NOP 411 | NOP 412 | NOP 413 | NOP 414 | NOP 415 | NOP 416 | NOP 417 | NOP 418 | NOP 419 | NOP 420 | NOP 421 | NOP 422 | NOP 423 | NOP 424 | NOP 425 | NOP 426 | NOP 427 | NOP 428 | NOP 429 | NOP 430 | NOP 431 | HLT 432 | 
--------------------------------------------------------------------------------
/architecture.md:
--------------------------------------------------------------------------------
1 | 
2 | ### Writing a Program
3 | OpenTPU uses no dynamic scheduling; all execution is fully deterministic* and the hardware relies on the compiler to correctly schedule operations and pad NOPs to handle delays. This OpenTPU release does not support "repeat" flags on instructions, so many NOPs are required to ensure correct execution.
4 | 
5 | *DRAM is a source of non-deterministic latency, discussed in the Memory Controller section of Microarchitecture.
6 | 
7 | 
8 | ### Latencies
9 | The following gives the hardware execution latency for each instruction on OpenTPU:
10 | 
11 | RHM - _M_ cycles for reading _M_ vectors
12 | 
13 | WHM - _M_ cycles for writing _M_ vectors
14 | 
15 | RW - _N*N_/64 cycles for _N_x_N_ MM Array for DRAM transfer, and up to 3 additional cycles to propagate through the FIFO
16 | 
17 | MMC - _L+2N_ cycles, for _N_x_N_ MM Array and _L_ vectors multiplied in the instruction
18 | 
19 | ACT - _L+1_ cycles, for _L_ vectors activated in the instruction
20 | 
21 | 
22 | ## Microarchitecture
23 | 
24 | ### Matrix Multiply (MM) Unit
25 | The computational core of the OpenTPU is a parametrizable array of 8-bit Multiply-Accumulate Units (MACs), each consisting of an 8-bit integer multiplier and an integer adder of between 16 and 32 bits*. Each MAC has two buffers storing 8-bit weights (the second buffer allows weight programming to happen in parallel). Input vectors enter the array from the left, with values advancing one unit to the right each cycle. Each unit multiplies the input value by the active weight, adds it to the value from the unit above, and passes the result to the unit below. Input vectors are fed diagonally so that values align correctly as partial sums flow down the array.
26 | 
27 | *The multipliers produce 16-bit outputs; as values move down the columns of the array, each add produces 1 extra bit. Width is capped at 32, creating the potential for uncaught overflow.
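Cycle-level timing aside, the arithmetic of one MMC pass through the array is an ordinary integer matrix multiply with widened accumulation, which is also how the functional simulator (sim.py) models it. A minimal NumPy sketch of that reference behavior (the function name and shapes here are illustrative, not taken from the RTL):

```python
import numpy as np

def mm_array_ref(vectors, weights):
    # One MMC pass through an NxN MAC array: int8 inputs times int8 weights,
    # accumulated at 32 bits (the adder-width cap described above).
    assert weights.shape[0] == weights.shape[1], "the MM Array is always square"
    return np.matmul(vectors.astype(np.int32), weights.astype(np.int32))

x = np.random.randint(-128, 128, size=(4, 8), dtype=np.int8)  # 4 input vectors
w = np.random.randint(-128, 128, size=(8, 8), dtype=np.int8)  # one 8x8 weight tile
print(mm_array_ref(x, w))  # matches the systolic result, ignoring timing
```

Overflow past 32 bits is the one place the hardware can diverge from this exact-integer model.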
28 | 
29 | 
30 | ### Accumulator Buffers
31 | Result vectors from the MM Array are written to a software-specified address in a set of accumulator buffers. Instructions indicate whether values should be added into the value already at the address or overwrite it. MM instructions read from the Unified Buffer (UB) and write to the accumulator buffers; activate instructions read from the accumulator buffers and write to the UB.
32 | 
33 | 
34 | ### Weight FIFO
35 | At scale (256x256 MACs), a full matrix of weights (a "tile") is 64KB; to avoid stalls while weights are moved from off-chip weight DRAM, a 4-entry FIFO is used to buffer tiles. It is assumed the connection to the weight DRAM is a standard DDR interface moving data in 64-byte chunks (memory controllers are currently emulated with no simulated delay, so one chunk arrives each cycle). When an MM instruction carries the "switch" flag, each MAC switches the active weight buffer as the first vector of the instruction propagates through the array. Once it reaches the end of the first row, the FIFO begins feeding new weight values into the free buffers of the array. New weight values are passed down through the array each cycle until each row reaches its destination.
36 | 
37 | 
38 | ### Systolic Setup
39 | Vectors are read all at once from the Unified Buffer, but must be fed diagonally into the MM Array. This is accomplished with a set of sequential buffers in a lower triangular configuration. The top value reaches the matrix immediately, the second after one cycle, the third after two, etc., so that each value reaches a MAC at the same time as the corresponding partial sum from the same source vector.
40 | 
41 | 
42 | ### Memory Controllers
43 | Currently, memory controllers are emulated and have no delay. The connection to Host Memory is currently the size of one vector. The connection to the Weight DRAM uses a standard width of 64 bytes.
44 | 
45 | Because the emulated controllers can return a new value each cycle, the OpenTPU hardware simulation currently has no non-deterministic delay. With a more accurate DRAM interface that may encounter dynamic delays, programs would need to either take care to schedule for the worst-case memory delay, or make use of another instruction to ensure memory operations complete before the values are required*.
46 | 
47 | *We note that the TPU "SYNC" instruction may fulfill this purpose, but is currently unimplemented on OpenTPU.
48 | 
49 | 
50 | ### Configuration
51 | Unified Buffer size, Accumulator Buffer size, and the size of the MM Array can all be specified in config.py. However, the MM Array must always be square, and vectors/weights are always composed of 8-bit integers.
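For reference, the config.py shipped with this release describes a 16x16 array; the per-constant comments below are editorial glosses inferred from how the constants are used elsewhere, not from the file itself:

```python
# config.py, as shipped
MATSIZE = 16                  # MM Array is MATSIZE x MATSIZE (always square)
DWIDTH = 8                    # vectors/weights are 8-bit integers
HOST_ADDR_SIZE = 64           # host memory address width, in bits
UB_ADDR_SIZE = 12             # Unified Buffer address width, in bits
WEIGHT_DRAM_ADDR_SIZE = 40    # weight DRAM address width, in bits
ACC_ADDR_SIZE = 16            # accumulator buffer address width, in bits
INSTRUCTION_WIDTH = 14 * 8    # instructions are 14 bytes wide
IMEM_ADDR_SIZE = 12           # instruction memory address width, in bits
```

Changing MATSIZE rescales both the array and the tile size the Weight FIFO must buffer.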
52 | 53 | 54 | -------------------------------------------------------------------------------- /boston.a: -------------------------------------------------------------------------------- 1 | # Host mem: N x 13 input matrix 2 | # Weight mem: 3 | # L1: 13x8 4 | # L2: 8x8 5 | # L3: 8x1 6 | RHM 0, 0, 10 # read from host mem addr 0, to UB addr 0, for length N = 10 7 | RW 0 # read weights from dram addr 0 to FIFO 8 | NOP 9 | NOP 10 | NOP 11 | NOP 12 | NOP 13 | NOP 14 | NOP 15 | NOP 16 | NOP 17 | NOP 18 | NOP 19 | NOP 20 | NOP 21 | NOP 22 | NOP 23 | NOP 24 | NOP 25 | NOP 26 | NOP 27 | NOP 28 | NOP 29 | NOP 30 | NOP 31 | NOP 32 | NOP 33 | NOP 34 | NOP 35 | NOP 36 | NOP 37 | NOP 38 | NOP 39 | NOP 40 | NOP 41 | NOP 42 | NOP 43 | NOP 44 | NOP 45 | NOP 46 | NOP 47 | NOP 48 | NOP 49 | RW 1 # read weights from dram addr 0 to FIFO 50 | NOP 51 | NOP 52 | NOP 53 | NOP 54 | NOP 55 | NOP 56 | NOP 57 | NOP 58 | NOP 59 | NOP 60 | NOP 61 | NOP 62 | NOP 63 | NOP 64 | NOP 65 | NOP 66 | NOP 67 | NOP 68 | NOP 69 | NOP 70 | NOP 71 | NOP 72 | NOP 73 | NOP 74 | NOP 75 | NOP 76 | NOP 77 | NOP 78 | NOP 79 | NOP 80 | NOP 81 | NOP 82 | NOP 83 | NOP 84 | NOP 85 | NOP 86 | NOP 87 | NOP 88 | NOP 89 | NOP 90 | NOP 91 | RW 2 # read weights from dram addr 0 to FIFO 92 | NOP 93 | NOP 94 | NOP 95 | NOP 96 | NOP 97 | NOP 98 | NOP 99 | NOP 100 | NOP 101 | NOP 102 | NOP 103 | NOP 104 | NOP 105 | NOP 106 | NOP 107 | NOP 108 | NOP 109 | NOP 110 | NOP 111 | NOP 112 | NOP 113 | NOP 114 | NOP 115 | NOP 116 | NOP 117 | NOP 118 | NOP 119 | NOP 120 | NOP 121 | NOP 122 | NOP 123 | NOP 124 | NOP 125 | NOP 126 | NOP 127 | NOP 128 | NOP 129 | NOP 130 | NOP 131 | NOP 132 | NOP 133 | MMC.SO 0, 0, 10 # Do MM on UB addr 0, to accumulator addr 0, for length 10 134 | NOP 135 | NOP 136 | NOP 137 | NOP 138 | NOP 139 | NOP 140 | NOP 141 | NOP 142 | NOP 143 | NOP 144 | NOP 145 | NOP 146 | NOP 147 | NOP 148 | NOP 149 | NOP 150 | NOP 151 | NOP 152 | NOP 153 | NOP 154 | NOP 155 | NOP 156 | NOP 157 | NOP 158 | NOP 159 | NOP 160 | NOP 161 | NOP 162 | NOP 163 | NOP 164 | NOP 165 | NOP 166 | NOP 167 | NOP 168 | NOP 169 | NOP 170 | NOP 171 | NOP 172 | NOP 173 | NOP 174 | NOP 175 | ACT.R 0, 0, 10 # Do ACT ReLU on accumulator addr 0, to UB addr 0, for length 10 176 | NOP 177 | NOP 178 | NOP 179 | NOP 180 | NOP 181 | NOP 182 | NOP 183 | NOP 184 | NOP 185 | NOP 186 | NOP 187 | NOP 188 | NOP 189 | NOP 190 | NOP 191 | NOP 192 | NOP 193 | NOP 194 | NOP 195 | NOP 196 | NOP 197 | NOP 198 | NOP 199 | NOP 200 | NOP 201 | NOP 202 | NOP 203 | NOP 204 | NOP 205 | NOP 206 | NOP 207 | NOP 208 | NOP 209 | NOP 210 | NOP 211 | NOP 212 | NOP 213 | NOP 214 | NOP 215 | NOP 216 | NOP 217 | MMC.SO 0, 0, 10 218 | NOP 219 | NOP 220 | NOP 221 | NOP 222 | NOP 223 | NOP 224 | NOP 225 | NOP 226 | NOP 227 | NOP 228 | NOP 229 | NOP 230 | NOP 231 | NOP 232 | NOP 233 | NOP 234 | NOP 235 | NOP 236 | NOP 237 | NOP 238 | NOP 239 | NOP 240 | NOP 241 | NOP 242 | NOP 243 | NOP 244 | NOP 245 | NOP 246 | NOP 247 | NOP 248 | NOP 249 | NOP 250 | NOP 251 | NOP 252 | NOP 253 | NOP 254 | NOP 255 | NOP 256 | NOP 257 | NOP 258 | NOP 259 | ACT.R 0, 0, 10 260 | NOP 261 | NOP 262 | NOP 263 | NOP 264 | NOP 265 | NOP 266 | NOP 267 | NOP 268 | NOP 269 | NOP 270 | NOP 271 | NOP 272 | NOP 273 | NOP 274 | NOP 275 | NOP 276 | NOP 277 | NOP 278 | NOP 279 | NOP 280 | NOP 281 | NOP 282 | NOP 283 | NOP 284 | NOP 285 | NOP 286 | NOP 287 | NOP 288 | NOP 289 | NOP 290 | NOP 291 | NOP 292 | NOP 293 | NOP 294 | NOP 295 | NOP 296 | NOP 297 | NOP 298 | NOP 299 | NOP 300 | NOP 301 | MMC.SO 0, 0, 10 302 | NOP 303 | NOP 304 | NOP 
305 | NOP 306 | NOP 307 | NOP 308 | NOP 309 | NOP 310 | NOP 311 | NOP 312 | NOP 313 | NOP 314 | NOP 315 | NOP 316 | NOP 317 | NOP 318 | NOP 319 | NOP 320 | NOP 321 | NOP 322 | NOP 323 | NOP 324 | NOP 325 | NOP 326 | NOP 327 | NOP 328 | NOP 329 | NOP 330 | NOP 331 | NOP 332 | NOP 333 | NOP 334 | NOP 335 | NOP 336 | NOP 337 | NOP 338 | NOP 339 | NOP 340 | NOP 341 | NOP 342 | NOP 343 | ACT.R 0, 0, 10 344 | NOP 345 | NOP 346 | NOP 347 | NOP 348 | NOP 349 | NOP 350 | NOP 351 | NOP 352 | NOP 353 | NOP 354 | NOP 355 | NOP 356 | NOP 357 | NOP 358 | NOP 359 | NOP 360 | NOP 361 | NOP 362 | NOP 363 | NOP 364 | NOP 365 | NOP 366 | NOP 367 | NOP 368 | NOP 369 | NOP 370 | NOP 371 | NOP 372 | NOP 373 | NOP 374 | NOP 375 | NOP 376 | NOP 377 | NOP 378 | NOP 379 | NOP 380 | NOP 381 | NOP 382 | NOP 383 | NOP 384 | NOP 385 | WHM 0, 0, 10 # write result from UB addr 0, to host mem addr 0, for length 10 386 | NOP 387 | NOP 388 | NOP 389 | NOP 390 | NOP 391 | NOP 392 | NOP 393 | NOP 394 | NOP 395 | NOP 396 | NOP 397 | NOP 398 | NOP 399 | NOP 400 | NOP 401 | NOP 402 | NOP 403 | NOP 404 | NOP 405 | NOP 406 | NOP 407 | NOP 408 | NOP 409 | NOP 410 | NOP 411 | NOP 412 | NOP 413 | NOP 414 | NOP 415 | NOP 416 | NOP 417 | NOP 418 | NOP 419 | NOP 420 | NOP 421 | NOP 422 | NOP 423 | NOP 424 | NOP 425 | NOP 426 | NOP 427 | HLT 428 | 
--------------------------------------------------------------------------------
/boston.out:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCSBarchlab/OpenTPU/HEAD/boston.out
--------------------------------------------------------------------------------
/runtpu.py:
--------------------------------------------------------------------------------
1 | from pyrtl import *
2 | import argparse
3 | import numpy as np
4 | 
5 | #set_debug_mode()
6 | 
7 | from tpu import *
8 | import config
9 | 
10 | import sys
11 | 
12 | parser = argparse.ArgumentParser(description="Run the PyRTL spec for the TPU on the indicated program.")
13 | parser.add_argument("prog", metavar="program.bin", help="A valid binary program for OpenTPU.")
14 | parser.add_argument("hostmem", metavar="HostMemoryArray", help="A file containing a numpy array containing the initial contents of host memory. Each row represents one vector.")
15 | parser.add_argument("weightsmem", metavar="WeightsMemoryArray", help="A file containing a numpy array containing the contents of the weights memory.
Each row represents one tile (the first row corresponds to the top row of the weights matrix).")
16 | 
17 | args = parser.parse_args()
18 | 
19 | 
20 | # Read the program and build an instruction list
21 | with open(args.prog, 'rb') as f:
22 |     ins = [x for x in f.read()]  # create byte list from input
23 | 
24 | instrs = []
25 | width = config.INSTRUCTION_WIDTH // 8
26 | # This assumes instructions are strictly byte-aligned
27 | 
28 | for i in range(int(len(ins)/width)):  # once per instruction
29 |     val = 0
30 |     for j in range(int(width)):  # for each byte
31 |         val = (val << 8) | ins.pop(0)
32 |     instrs.append(val)
33 | 
34 | #print(list(map(hex, instrs)))
35 | 
36 | def concat_vec(vec, bits=8):
37 |     t = 0
38 |     mask = int('1'*bits, 2)
39 |     for x in reversed(vec):
40 |         t = (t << bits) | (x & mask)
57 |     while value > 0:
58 |         vec.append(value & mask)
59 |         value = value >> 8
60 |     return list(reversed(vec))
61 | 
62 | def print_mem(mem):
63 |     ks = sorted(mem.keys())
64 |     for a in ks:
65 |         print(a, make_vec(mem[a]))
66 | 
67 | def print_weight_mem(mem, bits=8, size=8):
68 |     ks = sorted(mem.keys())
69 |     mask = int('1'*(size*bits), 2)
70 |     vecs = []
71 |     for a in ks:
72 |         vec = []
73 |         tile = mem[a]
74 |         while tile > 0:
75 |             vec.append(make_vec(tile & mask))
76 |             tile = tile >> (8*8)
77 |         if vec != []:
78 |             vecs.append(vec)
79 |     for a, vec in enumerate(vecs):
80 |         print(a, list(reversed(vec)))
81 | 
82 | # Read the dram files and build memory images
83 | hostarray = np.load(args.hostmem)
84 | #print(hostarray)
85 | #print(hostarray.shape)
86 | hostmem = { a : concat_vec(vec) for a,vec in enumerate(hostarray) }
87 | print("Host memory:")
88 | print_mem(hostmem)
89 | 
90 | 
91 | weightsarray = np.load(args.weightsmem)
92 | size = weightsarray.shape[-1]
93 | #print(weightsarray)
94 | #print(weightsarray.shape)
95 | weightsmem = { a : concat_tile(tile) for a,tile in enumerate(weightsarray) }
96 | #weightsmem = { a : concat_vec(vec) for a,vec in enumerate(weightsarray) }
97 | print("Weight memory:")
98 | print_weight_mem(weightsmem, size=size)
99 | #print(weightsmem)
100 | 
101 | '''
102 | Left-most element of each vector should be left-most in memory: use concat_list for each vector
103 | 
104 | For weights mem, first vector goes last; hardware already handles this by iterating from back to front over the tile.
105 | The first vector should be at the "front" of the tile.
106 | 
107 | For host mem, each vector goes at one address. First vector at address 0.
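
(Concretely, with the helpers above, concat_vec([1, 2, 3]) == 0x030201: the left-most
element of a vector lands in the least-significant byte of its memory word.)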
108 | ''' 109 | 110 | tilesize = config.MATSIZE * config.MATSIZE # number of weights in a tile 111 | nchunks = max(tilesize / 64, 1) # Number of DRAM transfers needed from Weight DRAM for one tile 112 | chunkmask = pow(2,64*8)-1 113 | def getchunkfromtile(tile, chunkn): 114 | #print("Get chunk: ", chunkn, nchunks, chunkmask, tile) 115 | #print((tile >> ((nchunks - chunkn - 1)*64*8)) & chunkmask) 116 | if chunkn >= nchunks: 117 | raise Exception("Reading more weights than are present in one tile?") 118 | return (tile >> int(((nchunks - chunkn - 1))*64*8)) & chunkmask 119 | 120 | # Run Simulation 121 | sim_trace = SimulationTrace() 122 | sim = FastSimulation(tracer=sim_trace, memory_value_map={ IMem : { a : v for a,v in enumerate(instrs)} }) 123 | 124 | din = { 125 | weights_dram_in : 0, 126 | weights_dram_valid : 0, 127 | hostmem_rdata : 0, 128 | } 129 | 130 | cycle = 0 131 | chunkaddr = nchunks 132 | sim.step(din) 133 | i = 0 134 | while True: 135 | i += 1 136 | # Halt signal 137 | if sim.inspect(halt): 138 | break 139 | 140 | d = din.copy() 141 | 142 | # Check if we're doing a Read Weights 143 | if chunkaddr < nchunks: 144 | #print("Sending weights from chunk {}: {}".format(chunkaddr, getchunkfromtile(weighttile, chunkaddr))) 145 | #print(getchunkfromtile(weighttile, chunkaddr)) 146 | #print(weighttile) 147 | d[weights_dram_in] = getchunkfromtile(weighttile, chunkaddr) 148 | d[weights_dram_valid] = 1 149 | chunkaddr += 1 150 | 151 | # Read host memory signal 152 | if sim.inspect(hostmem_re): 153 | raddr = sim.inspect(hostmem_raddr) 154 | if raddr in hostmem: 155 | d[hostmem_rdata] = hostmem[raddr] 156 | 157 | # Write host memory signal 158 | if sim.inspect(hostmem_we): 159 | waddr = sim.inspect(hostmem_waddr) 160 | wdata = sim.inspect(hostmem_wdata) 161 | hostmem[waddr] = wdata 162 | 163 | # Read weights memory signal 164 | if sim.inspect(weights_dram_read): 165 | weightaddr = sim.inspect(weights_dram_raddr) 166 | weighttile = weightsmem[weightaddr] 167 | chunkaddr = 0 168 | #print("Read Weights: addr {}".format(weightaddr)) 169 | #print(weighttile) 170 | 171 | sim.step(d) 172 | cycle += 1 173 | 174 | print("\n\n") 175 | print("Simulation terminated at cycle {}".format(cycle)) 176 | print("Final Host memory:") 177 | print_mem(hostmem) 178 | 179 | #sim_trace.render_trace() 180 | with open("trace.vcd", 'w') as f: 181 | sim_trace.print_vcd(f) 182 | -------------------------------------------------------------------------------- /sim.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import argparse 3 | import sys 4 | import numpy as np 5 | from collections import deque 6 | from math import exp 7 | 8 | import isa 9 | from config import MATSIZE as WIDTH 10 | 11 | args = None 12 | # width of the tile 13 | #WIDTH = 16 14 | 15 | 16 | class TPUSim(object): 17 | def __init__(self, program_filename, dram_filename, hostmem_filename): 18 | # TODO: switch b/w 32-bit float vs int 19 | self.program = open(program_filename, 'rb') 20 | self.weight_memory = np.load(dram_filename) 21 | self.host_memory = np.load(hostmem_filename) 22 | if not args.raw: 23 | assert self.weight_memory.dtype == np.int8, 'DRAM weight mem is not 8-bit ints' 24 | assert self.host_memory.dtype == np.int8, 'Hostmem not 8-bit ints' 25 | self.unified_buffer = (np.zeros((96000, WIDTH), dtype=np.float32) if args.raw else 26 | np.zeros((96000, WIDTH), dtype=np.int8)) 27 | self.accumulator = (np.zeros((4000, WIDTH), dtype=np.float32) if args.raw else 28 | np.zeros((4000, WIDTH), 
dtype=np.int32)) 29 | self.weight_fifo = deque() 30 | 31 | def run(self): 32 | # load program and execute instructions 33 | while True: 34 | instruction = self.decode() 35 | opcode, operands = instruction[0], instruction[1:] 36 | if opcode in ['RHM', 'WHM', 'RW']: 37 | self.memops(opcode, *operands) 38 | elif opcode == 'MMC': 39 | self.matrix_multiply_convolve(*operands) 40 | elif opcode == 'ACT': 41 | self.act(*operands) 42 | elif opcode == 'SYNC': 43 | pass 44 | elif opcode == 'NOP': 45 | pass 46 | elif opcode == 'HLT': 47 | print('H A L T') 48 | break 49 | else: 50 | raise Exception('WAT (╯°□°)╯︵ ┻━┻') 51 | 52 | # all done, exit 53 | savepath = 'sim32.npy' if args.raw else 'sim8.npy' 54 | np.save(savepath, self.host_memory) 55 | print(self.host_memory.astype('uint8')) 56 | self.program.close() 57 | 58 | print("""ALL DONE! 59 | (•_•) 60 | ( •_•)>⌐■-■ 61 | (⌐■_■)""") 62 | 63 | def decode(self): 64 | opcode = int.from_bytes(self.program.read(isa.OP_SIZE), byteorder='big') 65 | opcode = isa.BIN2OPCODE[opcode] 66 | 67 | flag = int.from_bytes(self.program.read(isa.FLAGS_SIZE), byteorder='big') 68 | length = int.from_bytes(self.program.read(isa.LEN_SIZE), byteorder='big') 69 | src_addr = int.from_bytes(self.program.read(isa.ADDR_SIZE), byteorder='big') 70 | dest_addr = int.from_bytes(self.program.read(isa.UB_ADDR_SIZE), byteorder='big') 71 | #print('{} decoding: len {}, flags {}, src {}, dst {}'.format(opcode, length, flag, src_addr, dest_addr)) 72 | return opcode, src_addr, dest_addr, length, flag 73 | 74 | # opcodes 75 | def act(self, src, dest, length, flag): 76 | print('ACTIVATE!') 77 | 78 | result = self.accumulator[src:src+length] 79 | if flag & isa.FUNC_RELU_MASK: 80 | print(' RELU!!!!') 81 | vfunc = np.vectorize(lambda x: 0 * x if x < 0. else x) 82 | elif flag & isa.FUNC_SIGMOID_MASK: 83 | print(' SIGMOID') 84 | vfunc = np.vectorize(lambda x: int(255./(1.+exp(-x)))) 85 | else: 86 | vfunc = np.vectorize(lambda x: x) 87 | #raise Exception('(╯°□°)╯︵ ┻━┻ bad activation function!') 88 | 89 | result = vfunc(result) 90 | 91 | # downsample/normalize if needed 92 | if not args.raw: 93 | result = [v & 0x000000FF for v in result] 94 | self.unified_buffer[dest:dest+length] = result 95 | 96 | def memops(self, opcode, src_addr, dest_addr, length, flag): 97 | print('Memory xfer! host: {} unified buffer: {}: length: {} (FLAGS? 
{})'.format( 98 | src_addr, dest_addr, length, flag 99 | )) 100 | 101 | if opcode == 'RHM': 102 | print(' read host memory to unified buffer') 103 | self.unified_buffer[dest_addr:dest_addr + length] = self.host_memory[src_addr:src_addr + length] 104 | elif opcode == 'WHM': 105 | print(' write unified buffer to host memory') 106 | self.host_memory[dest_addr:dest_addr + length] = self.unified_buffer[src_addr:src_addr + length] 107 | elif opcode == 'RW': 108 | print(' read weights from DRAM into MMU') 109 | self.weight_fifo.append(self.weight_memory[src_addr]) 110 | else: 111 | raise Exception('WAT (╯°□°)╯︵ ┻━┻') 112 | 113 | def matrix_multiply_convolve(self, ub_addr, accum_addr, size, flags): 114 | print('Matrix things....') 115 | print(' UB@{} + {} -> MMU -> accumulator@{} + {}'.format( 116 | ub_addr, size, accum_addr, size 117 | )) 118 | 119 | inp = self.unified_buffer[ub_addr: ub_addr + size] 120 | print('MMC input shape: {}'.format(inp.shape)) 121 | weight_mat = self.weight_fifo.popleft() 122 | print('MMC weight: {}'.format(weight_mat)) 123 | if not args.raw: 124 | inp = inp.astype(np.int32) 125 | weight_mat = weight_mat.astype(np.int32) 126 | out = np.matmul(inp, weight_mat) 127 | print('MMC output shape: {}'.format(out.shape)) 128 | overwrite = isa.OVERWRITE_MASK & flags 129 | if overwrite: 130 | self.accumulator[accum_addr:accum_addr + size] = out 131 | else: 132 | self.accumulator[accum_addr:accum_addr + size] += out 133 | 134 | def parse_args(): 135 | global args 136 | 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('program', action='store', 139 | help='Path to assembly program file.') 140 | parser.add_argument('host_file', action='store', 141 | help='Path to host file.') 142 | parser.add_argument('dram_file', action='store', 143 | help='Path to dram file.') 144 | parser.add_argument('--raw', action='store_true', default=False, 145 | help='Gen sim32.npy instead of sim8.npy.') 146 | args = parser.parse_args() 147 | 148 | if __name__ == '__main__': 149 | if len(sys.argv) < 4: 150 | print('Usage:', sys.argv[0], 'PROGRAM_BINARY DRAM_FILE HOST_FILE') 151 | sys.exit(0) 152 | 153 | parse_args() 154 | tpusim = TPUSim(args.program, args.dram_file, args.host_file) 155 | tpusim.run() 156 | -------------------------------------------------------------------------------- /assembler.py: -------------------------------------------------------------------------------- 1 | """ 2 | ===assembly==== 3 | 4 | Instruction has the following format: 5 | INST: OP, SRC, TAR, LEN (or FUNC), FLAG 6 | LENGTH: 1B, VAR, VAR, 1B, 1B 7 | 8 | OPCODE may define flags by using dot (.) separator following 9 | opcode string. 10 | 11 | For ACT instruction, function byte is defined using the following 12 | mapping: 13 | 0x0 -> ReLU 14 | 0x1 -> Sigmoid 15 | 0x2 -> MaxPooling 16 | 17 | Comments start with #. 18 | 19 | EXAMPLES: 20 | # example program 21 | RHM 1, 2, 3 # first instruction 22 | WHM 1, 2, 3 23 | RW 0xab 24 | MMC 100, 2, 3 25 | MMC.C 100, 2, 3 26 | ACT 0xab, 12, 1 27 | NOP 28 | HLT 29 | 30 | ===binary encoding==== 31 | 32 | INST is encoded in a little-endian format. 33 | OPCODE values are defined in OPCODE2BIN. 34 | FLAG field is r|r|r|r|r|o|s|c, r stands for reserved bit, s for switch bit, 35 | c for convolve bit, and o for override bit. 36 | 37 | SRC and TAR are addresses. They can be of variable length defined in 38 | global dict OPCODE2BIN. 
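A worked example (a sketch; the exact field offsets are defined in isa.py):
`RHM 0, 4, 8` packs the RHM opcode byte, a zero flag byte, length byte 8,
the 5-byte source address 0, and the 3-byte UB destination address 4, and
the assembler pads the encoding out to its fixed 14-byte instruction width.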
39 | 40 | SRC/TAR takes 5B for memory operations to support at least 8GB addressing, 41 | 3B for Unified Buffer addressing (96KB), 2B for accumulator buffer addressing 42 | (4K). 43 | 44 | """ 45 | 46 | import argparse 47 | import re 48 | from isa import * 49 | 50 | args = None 51 | 52 | TOP_LEVEL_SEP = re.compile(r'[a-zA-Z]+\s+') 53 | 54 | SUFFIX = '.out' 55 | 56 | #ENDIANNESS = 'little' 57 | ENDIANNESS = 'big' 58 | 59 | def DEBUG(string): 60 | if args.debug: 61 | print(string) 62 | else: 63 | return 64 | 65 | def putbytes(val, lo, hi): 66 | # Pack value 'val' into a byte range of lo..hi inclusive. 67 | val = int(val) 68 | lo = lo * 8 # convert bytes to bits 69 | hi = hi * 8 + 7 70 | if val > pow(2, hi-lo+1): 71 | raise Exception("Value {} too large for bit range {}-{}".format(val, lo, hi)) 72 | return val << lo 73 | 74 | def format_instr(op, flags, length, addr, ubaddr): 75 | return putbytes(op, OP_START, OP_END-1) |\ 76 | putbytes(flags, FLAGS_START, FLAGS_END-1) |\ 77 | putbytes(length, LEN_START, LEN_END-1) |\ 78 | putbytes(addr, ADDR_START, ADDR_END-1) |\ 79 | putbytes(ubaddr, UBADDR_START, UBADDR_END-1) 80 | 81 | def assemble(path, n): 82 | """ Translates an assembly code file into a binary. 83 | """ 84 | 85 | assert path 86 | with open(path, 'r') as code: 87 | lines = code.readlines() 88 | code.close() 89 | n = len(lines) if not n else n 90 | write_path = path[:path.rfind('.')] if path.rfind('.') > -1 else path 91 | bin_code = open(write_path + SUFFIX, 'wb') 92 | counter = 0 93 | for line in lines: 94 | line = line.partition('#')[0] 95 | if not line.strip(): 96 | continue 97 | counter += 1 98 | operands = TOP_LEVEL_SEP.split(line)[1] 99 | operands = [int(op.strip(), 0) for op in operands.split(',')] if operands else [] 100 | opcode = line.split()[0].strip() 101 | assert opcode 102 | comps = opcode.split('.') 103 | assert comps and len(comps) < 3 104 | if len(comps) == 1: 105 | opcode = comps[0] 106 | flags = '' 107 | else: 108 | opcode = comps[0] 109 | flags = comps[1] 110 | 111 | flag = 0 112 | if 'S' in flags: 113 | flag |= SWITCH_MASK 114 | if 'C' in flags: 115 | flag |= CONV_MASK 116 | if 'O' in flags: 117 | flag |= OVERWRITE_MASK 118 | if 'Q' in flags: 119 | flag |= FUNC_SIGMOID_MASK 120 | if 'R' in flags: 121 | flag |= FUNC_RELU_MASK 122 | 123 | # binary for flags 124 | bin_flags = flag.to_bytes(1, byteorder=ENDIANNESS) 125 | 126 | opcode, n_src, n_dst, n_len = OPCODE2BIN[opcode] 127 | 128 | if opcode == OPCODE2BIN['NOP'][0]: 129 | instr = format_instr(op=opcode, flags=0, length=0, addr=0, ubaddr=0) 130 | elif opcode == OPCODE2BIN['HLT'][0]: 131 | instr = format_instr(op=opcode, flags=0, length=0, addr=0, ubaddr=0) 132 | elif opcode == OPCODE2BIN['RW'][0]: 133 | # RW instruction only has only operand (weight DRAM address) 134 | instr = format_instr(op=opcode, flags=flag, length=0, addr=operands[0], ubaddr=0) 135 | elif (opcode == OPCODE2BIN['RHM'][0]) or (opcode == OPCODE2BIN['ACT'][0]): 136 | # RHM and ACT have UB-addr as their destination field 137 | instr = format_instr(op=opcode, flags=flag, length=operands[2], addr=operands[0], ubaddr=operands[1]) 138 | else: 139 | # WHM and MMC have UB-addr as their source field 140 | instr = format_instr(op=opcode, flags=flag, length=operands[2], addr=operands[1], ubaddr=operands[0]) 141 | 142 | bin_code.write(instr.to_bytes(14, byteorder=ENDIANNESS)) 143 | 144 | ''' 145 | # binary representation for opcode 146 | bin_opcode = opcode.to_bytes(1, byteorder=ENDIANNESS) 147 | 148 | # binary for oprands 149 | bin_operands = b'' 150 | if 
len(operands) == 0: 151 | bin_operands = b'' 152 | elif len(operands) == 1: 153 | bin_operands = operands[0].to_bytes(n_src, byteorder=ENDIANNESS) 154 | elif len(operands) == 3: 155 | bin_operands += operands[0].to_bytes(n_src, byteorder=ENDIANNESS) 156 | bin_operands += operands[1].to_bytes(n_tar, byteorder=ENDIANNESS) 157 | bin_operands += operands[2].to_bytes(n_3rd, byteorder=ENDIANNESS) 158 | 159 | # binary for instruction 160 | #bin_rep = bin_flags + bin_operands + bin_opcode 161 | bin_rep = bin_opcode + bin_operands + bin_flags 162 | 163 | if len(bin_rep) < INSTRUCTION_WIDTH_BYTES: 164 | x = 0 165 | zeros = x.to_bytes(INSTRUCTION_WIDTH_BYTES - len(bin_rep), byteorder=ENDIANNESS) 166 | #bin_rep = bin_flags + bin_operands + zeros + bin_opcode 167 | bin_rep = bin_opcode + bin_operands + zeros + bin_flags 168 | 169 | DEBUG(line[:-1]) 170 | DEBUG(bin_rep) 171 | 172 | # write to file 173 | bin_code.write(bin_rep) 174 | ''' 175 | 176 | if counter == n: 177 | break 178 | bin_code.close() 179 | 180 | 181 | def parse_args(): 182 | global args 183 | 184 | parser = argparse.ArgumentParser() 185 | 186 | parser.add_argument('path', action='store', 187 | help='path to source file.') 188 | parser.add_argument('--n', action='store', type=int, default=0, 189 | help='only parse first n lines of code, for dbg only.') 190 | parser.add_argument('--debug', action='store_true', 191 | help='switch debug prints.') 192 | args = parser.parse_args() 193 | 194 | 195 | if __name__ == '__main__': 196 | parse_args() 197 | assemble(args.path, args.n) 198 | -------------------------------------------------------------------------------- /tpu.py: -------------------------------------------------------------------------------- 1 | from pyrtl import * 2 | from pyrtl.analysis import area_estimation, TimingAnalysis 3 | 4 | from config import * 5 | from decoder import decode 6 | from matrix import MMU_top 7 | from activate import act_top 8 | 9 | ############################################################ 10 | # Control Signals 11 | ############################################################ 12 | 13 | accum_act_raddr = WireVector(ACC_ADDR_SIZE) # Activate unit read address for accumulator buffers 14 | weights_dram_in = Input(64*8, "weights_dram_in") # Input signal from weights DRAM controller 15 | weights_dram_valid = Input(1, "weights_dram_valid") # Valid bit for weights DRAM signal 16 | halt = Output(1) # When raised, stop simulation 17 | 18 | 19 | ############################################################ 20 | # Instruction Memory and PC 21 | ############################################################ 22 | 23 | IMem = MemBlock(bitwidth=INSTRUCTION_WIDTH, addrwidth=IMEM_ADDR_SIZE) 24 | pc = Register(IMEM_ADDR_SIZE) 25 | #probe(pc, 'pc') 26 | pc.incr = WireVector(1) 27 | with conditional_assignment: 28 | with pc.incr: 29 | pc.next |= pc + 1 30 | pc.incr <<= 1 # right now, increment the PC every cycle 31 | instr = IMem[pc] 32 | #probe(instr, "instr") 33 | 34 | ############################################################ 35 | # Unified Buffer 36 | ############################################################ 37 | 38 | UBuffer = MemBlock(bitwidth=MATSIZE*DWIDTH, addrwidth=UB_ADDR_SIZE, max_write_ports=2) 39 | 40 | # Address and data wires for MM read port 41 | ub_mm_raddr = WireVector(UBuffer.addrwidth) # MM UB read address 42 | UB2MM = UBuffer[ub_mm_raddr] 43 | 44 | ############################################################ 45 | # Decoder 46 | ############################################################ 47 | 48 | 
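# decode() cracks each instruction into per-class dispatch strobes (MMC, ACT,
# RHM, WHM, HLT), address/length fields for the UB, host memory, and
# accumulators, the ACT function select, the MMC overwrite/switch flags, and
# the weight-DRAM read controls consumed below.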
dispatch_mm, dispatch_act, dispatch_rhm, dispatch_whm, dispatch_halt, ub_start_addr, ub_dec_addr, ub_dest_addr, rhm_dec_addr, whm_dec_addr, rhm_length, whm_length, mmc_length, act_length, act_type, accum_raddr, accum_waddr, accum_overwrite, switch_weights, weights_raddr, weights_read = decode(instr)
49 | 
50 | halt <<= dispatch_halt
51 | 
52 | ############################################################
53 | # Matrix Multiply Unit
54 | ############################################################
55 | 
56 | ub_mm_raddr_sig, acc_out, mm_busy, mm_done = MMU_top(data_width=DWIDTH, matrix_size=MATSIZE, accum_size=ACC_ADDR_SIZE, ub_size=UB_ADDR_SIZE, start=dispatch_mm, start_addr=ub_start_addr, nvecs=mmc_length, dest_acc_addr=accum_waddr, overwrite=accum_overwrite, swap_weights=switch_weights, ub_rdata=UB2MM, accum_raddr=accum_act_raddr, weights_dram_in=weights_dram_in, weights_dram_valid=weights_dram_valid)
57 | 
58 | ub_mm_raddr <<= ub_mm_raddr_sig
59 | 
60 | ############################################################
61 | # Activate Unit
62 | ############################################################
63 | 
64 | accum_raddr_sig, ub_act_waddr, act_out, ub_act_we, act_busy = act_top(start=dispatch_act, start_addr=accum_raddr, dest_addr=ub_dest_addr, nvecs=act_length, func=act_type, accum_out=acc_out)
65 | accum_act_raddr <<= accum_raddr_sig
66 | 
67 | # Write the result of activate to the unified buffer
68 | with conditional_assignment:
69 |     with ub_act_we:
70 |         UBuffer[ub_act_waddr] |= act_out
71 | 
72 | #probe(ub_act_we, "ub_act_we")
73 | #probe(ub_act_waddr, "ub_act_waddr")
74 | #probe(act_out, "act_out")
75 | #probe(accum_raddr_sig, "accum_raddr")
76 | 
77 | ############################################################
78 | # Read/Write Host Memory
79 | ############################################################
80 | 
81 | hostmem_raddr = Output(HOST_ADDR_SIZE, "raddr")
82 | hostmem_rdata = Input(DWIDTH*MATSIZE)
83 | hostmem_re = Output(1, "hostmem_re")
84 | hostmem_waddr = Output(HOST_ADDR_SIZE)
85 | hostmem_wdata = Output(DWIDTH*MATSIZE)
86 | hostmem_we = Output(1)
87 | 
88 | # Write Host Memory control logic
89 | whm_N = Register(len(whm_length))
90 | whm_ub_raddr = Register(len(ub_dec_addr))
91 | whm_addr = Register(len(whm_dec_addr))
92 | whm_busy = Register(1)
93 | 
94 | ubuffer_out = UBuffer[whm_ub_raddr]
95 | 
96 | hostmem_waddr <<= whm_addr
97 | hostmem_wdata <<= ubuffer_out
98 | 
99 | with conditional_assignment:
100 |     with dispatch_whm:
101 |         whm_N.next |= whm_length
102 |         whm_ub_raddr.next |= ub_dec_addr
103 |         whm_addr.next |= whm_dec_addr
104 |         whm_busy.next |= 1
105 |     with whm_busy:
106 |         whm_N.next |= whm_N - 1
107 |         whm_ub_raddr.next |= whm_ub_raddr + 1
108 |         whm_addr.next |= whm_addr + 1
109 |         hostmem_we |= 1
110 |         with whm_N == 1:
111 |             whm_busy.next |= 0
112 | 
113 | 
114 | # Read Host Memory control logic
115 | #probe(rhm_length, "rhm_length")
116 | rhm_N = Register(len(rhm_length))
117 | rhm_addr = Register(len(rhm_dec_addr))
118 | rhm_busy = Register(1)
119 | rhm_ub_waddr = Register(len(ub_dec_addr))
120 | with conditional_assignment:
121 |     with dispatch_rhm:
122 |         rhm_N.next |= rhm_length
123 |         rhm_busy.next |= 1
124 |         hostmem_raddr |= rhm_dec_addr
125 |         hostmem_re |= 1
126 |         rhm_addr.next |= rhm_dec_addr + 1  # continue from the next host address
127 |         rhm_ub_waddr.next |= ub_dec_addr
128 |     with rhm_busy:
129 |         rhm_N.next |= rhm_N - 1
130 |         hostmem_raddr |= rhm_addr
131 |         hostmem_re |= 1
132 |         rhm_addr.next |= rhm_addr + 1
133 |         rhm_ub_waddr.next |= rhm_ub_waddr + 1
134 |         UBuffer[rhm_ub_waddr] |= hostmem_rdata
135 |         with rhm_N ==
1: 136 | rhm_busy.next |= 0 137 | 138 | ############################################################ 139 | # Weights Memory 140 | ############################################################ 141 | 142 | weights_dram_raddr = Output(WEIGHT_DRAM_ADDR_SIZE) 143 | weights_dram_read = Output(1) 144 | 145 | weights_dram_raddr <<= weights_raddr 146 | weights_dram_read <<= weights_read 147 | 148 | 149 | #probe(dispatch_mm, "dispatch_mm") 150 | #probe(dispatch_act, "dispatch_act") 151 | #probe(dispatch_rhm, "dispatch_rhm") 152 | #probe(dispatch_whm, "dispatch_whm") 153 | 154 | def run_synth(): 155 | print("logic = {:2f} mm^2, mem={:2f} mm^2".format(*area_estimation())) 156 | t = TimingAnalysis() 157 | print("Max freq = {} MHz".format(t.max_freq())) 158 | print("") 159 | print("Running synthesis...") 160 | synthesize() 161 | print("logic = {:2f} mm^2, mem={:2f} mm^2".format(*area_estimation())) 162 | t = TimingAnalysis() 163 | print("Max freq = {} MHz".format(t.max_freq())) 164 | print("") 165 | print("Running optimizations...") 166 | optimize() 167 | total = 0 168 | for gate in working_block(): 169 | if gate.op in ('s', 'c'): 170 | pass 171 | total += 1 172 | print("Gate total: " + str(total)) 173 | print("logic = {:2f} mm^2, mem={:2f} mm^2".format(*area_estimation())) 174 | t = TimingAnalysis() 175 | print("Max freq = {} MHz".format(t.max_freq())) 176 | 177 | #run_synth() 178 | -------------------------------------------------------------------------------- /tf_nn.py: -------------------------------------------------------------------------------- 1 | """ Using TF to train a DNN regressor for Boston Housing Data. 2 | 3 | ref: http://cs.smith.edu/dftwiki/images/b/bd/TFLinearRegression_BostonData.pdf 4 | """ 5 | 6 | import argparse 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.contrib import learn 10 | from sklearn.model_selection import train_test_split 11 | from sklearn import preprocessing 12 | from sklearn import metrics 13 | 14 | args = None 15 | 16 | def model(inputs, layers, act): 17 | [m1, m2, m3] = layers 18 | #y1 = tf.add(tf.matmul(inputs, m1), b1) 19 | y1 = tf.matmul(inputs, m1) 20 | y1_act = act(y1) 21 | 22 | y2 = tf.matmul(y1_act, m2) 23 | y2_act = act(y2) 24 | 25 | y3 = tf.matmul(y2_act, m3) 26 | #y3 = act(y3) 27 | 28 | #y4 = tf.add(tf.matmul(y3, m4), b4) 29 | #y4 = act(y4) 30 | 31 | #y_ret = tf.matmul(y4, m_out) + b_out 32 | return y3, y1, y1_act, y2, y2_act 33 | 34 | def main(): 35 | boston = learn.datasets.load_dataset('boston') 36 | x, y = boston.data, boston.target 37 | y.resize(y.size, 1) 38 | train_x, test_x, train_y, test_y = train_test_split( 39 | x, y, test_size = .2, random_state=int(np.random.rand(1))) 40 | print 'train: {}/{}, test: {}/{}'.format(len(train_x), len(x), len(test_x), len(x)) 41 | print 'dimension of data: {}'.format(x.shape) 42 | 43 | # scale the data to (0, 1). 
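# (More precisely, StandardScaler standardizes each feature to zero mean and
# unit variance rather than rescaling to (0, 1); note also that fit_transform
# is re-fit on test_x below instead of reusing the training-set statistics
# via transform().)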
44 | scaler = preprocessing.StandardScaler() 45 | train_x = scaler.fit_transform(train_x) 46 | test_x = scaler.fit_transform(test_x) 47 | num_features = train_x.shape[1] 48 | print 'num of features: {}'.format(num_features) 49 | 50 | with tf.name_scope('IO'): 51 | inputs = tf.placeholder(np.float32, [None, num_features], name='X') 52 | outputs = tf.placeholder(np.float32, [None, 1], name='Yhat') 53 | 54 | with tf.name_scope('LAYER'): 55 | # DNN architecture 56 | #layers = [num_features, 8, 8, 8, 8, 1] 57 | layers = [num_features, 8, 8, 1] 58 | # Weight matrices 59 | m1 = tf.Variable(tf.random_normal([layers[0], layers[1]], 60 | 0, .1, dtype=np.float32), name='m1') 61 | m2 = tf.Variable(tf.random_normal([layers[1], layers[2]], 62 | 0, .1, dtype=np.float32), name='m2') 63 | m3 = tf.Variable(tf.random_normal([layers[2], layers[3]], 64 | 0, .1, dtype=np.float32), name='m3') 65 | #m4 = tf.Variable(tf.random_normal([layers[3], layers[4]], 0, .1, dtype=tf.float32), name='m4') 66 | #m_out = tf.Variable(tf.random_normal([layers[4], layers[5]], 67 | # 0, .1, dtype=tf.float32), name='m_out') 68 | # Bias 69 | #b1 = tf.Variable(tf.random_normal([layers[1]], 0, .1, dtype=tf.float32), name='b1') 70 | #b2 = tf.Variable(tf.random_normal([layers[2]], 0, .1, dtype=tf.float32), name='b2') 71 | #b3 = tf.Variable(tf.random_normal([layers[3]], 0, .1, dtype=tf.float32), name='b3') 72 | #b4 = tf.Variable(tf.random_normal([layers[4]], 0, .1, dtype=tf.float32), name='b4') 73 | #b_out = tf.Variable(tf.random_normal([layers[5]], 0, .1, dtype=tf.float32), name='b_out') 74 | # Actication function 75 | #act = tf.nn.sigmoid 76 | act = tf.nn.relu 77 | 78 | with tf.name_scope('TRAIN'): 79 | learning_rate = .5 80 | y_out, y1, y1_act, y2, y2_act = model(inputs, [m1, m2, m3], act) 81 | 82 | cost_op = tf.reduce_mean(tf.pow(y_out - outputs, 2)) 83 | train_op = tf.train.AdagradOptimizer(learning_rate=learning_rate).minimize(cost_op) 84 | 85 | # Actual training 86 | epoch, last_cost, max_epochs, tolerance = 0, 0, 5000, 1e-6 87 | 88 | print 'Begin training...' 89 | sess = tf.Session() 90 | with sess.as_default(): 91 | init = tf.global_variables_initializer() 92 | sess.run(init) 93 | 94 | costs = [] 95 | epochs = [] 96 | 97 | #train_x = norm2byte(train_x).astype(np.float32) 98 | #train_y = norm2byte(train_y).astype(np.float32) 99 | 100 | while True: 101 | sess.run(train_op, feed_dict={inputs: train_x, outputs: train_y}) 102 | if epoch % 1000 == 0: 103 | cost = sess.run(cost_op, feed_dict={inputs: train_x, outputs: train_y}) 104 | costs.append(cost) 105 | epochs.append(epoch) 106 | 107 | print 'Epoch: {} -- Error: {}'.format(epoch, cost) 108 | 109 | if epoch > max_epochs: 110 | print 'Max # of iteration reached, stop.' 
111 | break 112 | last_cost = cost 113 | epoch += 1 114 | 115 | # Gen test sets 116 | test_x = test_x[:args.N] 117 | test_y = test_y[:args.N] 118 | 119 | # Quantize inputs/outputs/weights 120 | qtz_input = norm2byte(test_x) 121 | qtz_output = norm2byte(test_y) 122 | m1_val = sess.run(m1) 123 | qtz_m1 = norm2byte(m1_val) 124 | m2_val = sess.run(m2) 125 | qtz_m2 = norm2byte(m2_val) 126 | m3_val = sess.run(m3) 127 | qtz_m3 = norm2byte(m3_val) 128 | 129 | # Pad/Save inputs/weights 130 | HW_WIDTH = 16 131 | 132 | # input 133 | shape = qtz_input.shape 134 | pad_input = np.zeros((shape[0], HW_WIDTH), dtype=np.int8 if not args.raw else np.float32) 135 | pad_input[:shape[0], :shape[1]] = qtz_input if not args.raw else test_x 136 | print 'padded input: {}'.format(pad_input) 137 | np.save(args.save_input_path, pad_input) 138 | 139 | # weights 140 | shape = qtz_m1.shape 141 | pad_m1 = np.zeros((HW_WIDTH, HW_WIDTH), dtype=np.int8 if not args.raw else np.float32) 142 | pad_m1[:shape[0], :shape[1]] = qtz_m1 if not args.raw else m1_val 143 | pad_m1.reshape((1, HW_WIDTH, HW_WIDTH)) 144 | print 'padded m1: {}'.format(pad_m1) 145 | 146 | shape = qtz_m2.shape 147 | pad_m2 = np.zeros((HW_WIDTH, HW_WIDTH), dtype=np.int8 if not args.raw else np.float32) 148 | pad_m2[:shape[0], :shape[1]] = qtz_m2 if not args.raw else m2_val 149 | pad_m2.reshape((1, HW_WIDTH, HW_WIDTH)) 150 | print 'padded m2: {}'.format(pad_m2) 151 | 152 | shape = qtz_m3.shape 153 | pad_m3 = np.zeros((HW_WIDTH, HW_WIDTH), dtype=np.int8 if not args.raw else np.float32) 154 | pad_m3[:shape[0], :shape[1]] = qtz_m3 if not args.raw else m3_val 155 | pad_m3.reshape((1, HW_WIDTH, HW_WIDTH)) 156 | print 'padded m3: {}'.format(pad_m3) 157 | padded_weights = np.array((pad_m1, pad_m2, pad_m3)) 158 | print 'padded weights: {}'.format(padded_weights) 159 | np.save(args.save_weight_path, padded_weights) 160 | 161 | # Update weights 162 | if not args.raw: 163 | sess.run(m1.assign(qtz_m1)) 164 | sess.run(m2.assign(qtz_m2)) 165 | sess.run(m3.assign(qtz_m3)) 166 | 167 | # Test with 8b inputs/weights 168 | test_input = qtz_input if not args.raw else test_x 169 | test_output = qtz_output if not args.raw else test_y 170 | v = sess.run([y_out, y1, y1_act, y2, y2_act], 171 | feed_dict={inputs: test_input, outputs: test_output}) 172 | pred_y, y1, y1_act, y2, y2_act = v[0], v[1], v[2], v[3], v[4] 173 | pred_y = pred_y.astype(np.int8) if not args.raw else pred_y 174 | print y1.shape, y2.shape, pred_y.shape 175 | pred_out = np.array((pred_y.tolist(), y1.tolist(), 176 | y1_act.tolist(), y2.tolist(), y2_act.tolist())) 177 | np.save(args.save_output_path, pred_out) 178 | if args.raw: 179 | np.save('gt32', pred_y) 180 | print 'Prediction\nReal\tPredicted' 181 | for (y, y_hat) in zip(test_y, pred_y): 182 | print '{}\t{}'.format(y, y_hat) 183 | 184 | r2 = metrics.r2_score(test_y, pred_y) 185 | print 'R2: {}'.format(r2) 186 | 187 | def norm2byte(mat, shape=None): 188 | pos_mat = np.vectorize(lambda x: np.abs(x))(mat) 189 | max_w = np.amax(pos_mat) 190 | mat = np.vectorize(lambda x: (127 * x/max_w).astype(np.int8))(mat) 191 | return mat.reshape(shape) if shape else mat 192 | 193 | def parse_args(): 194 | global args 195 | 196 | parser = argparse.ArgumentParser() 197 | 198 | parser.add_argument('--save-weight-path', action='store', default='app_weight', 199 | help='path to save weights.') 200 | parser.add_argument('--save-input-path', action='store', default='app_in', 201 | help='path to save inputs.') 202 | parser.add_argument('--save-output-path', action='store', 
default='app_out',
203 |                         help='path to save predictions.')
204 |     parser.add_argument('--N', action='store', type=int,
205 |                         help='number of test cases.')
206 |     parser.add_argument('--raw', action='store_true', default=False,
207 |                         help='use float32 raw numbers.')
208 |     args = parser.parse_args()
209 | 
210 | if __name__ == '__main__':
211 |     parse_args()
212 |     main()
213 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UCSB ArchLab OpenTPU Project
2 | 
3 | OpenTPU is an open-source re-implementation of Google's Tensor Processing Unit (TPU) by the UC Santa Barbara ArchLab.
4 | 
5 | The TPU is Google's custom ASIC for accelerating the inference phase of neural network computations.
6 | 
7 | Our design is based on details from Google's paper "In-Datacenter Performance Analysis of a Tensor Processing Unit" (https://arxiv.org/abs/1704.04760), which appeared at ISCA 2017. However, no formal spec, interface, or ISA has yet been published for the TPU.
8 | 
9 | #### The OpenTPU is powered by PyRTL (http://ucsbarchlab.github.io/PyRTL/).
10 | 
11 | ## Requirements
12 | 
13 | - Python 3
14 | - PyRTL version >= 0.8.5
15 | - numpy
16 | 
17 | Both PyRTL and numpy can be installed with pip; e.g., `pip install pyrtl`.
18 | 
19 | ## How to Run
20 | 
21 | To run the simple matrix multiply test in both the hardware and functional simulators:
22 | 
23 | Make sure MATSIZE is set to 8 in config.py, then
24 | 
25 | ```
26 | python3 assembler.py simplemult.a
27 | python3 runtpu.py simplemult.out simplemult_hostmem.npy simplemult_weights.npy
28 | python3 sim.py simplemult.out simplemult_hostmem.npy simplemult_weights.npy
29 | ```
30 | 
31 | To run the Boston housing data regression test in both the hardware and functional simulators:
32 | 
33 | Make sure MATSIZE is set to 16 in config.py, then
34 | ```
35 | python3 assembler.py boston.a
36 | python3 runtpu.py boston.out boston_input.npy boston_weights.npy
37 | python3 sim.py boston.out boston_input.npy boston_weights.npy
38 | ```
39 | 
40 | 
41 | ### Hardware Simulation
42 | The executable hardware spec can be run using PyRTL's simulation features by running `runtpu.py`. The simulation expects as inputs a binary program and numpy array files containing the initial host memory and the weights.
43 | 
44 | Be aware that the size of the hardware Matrix Multiply unit is parametrizable --- double check `config.py` to make sure MATSIZE is what you expect.
45 | 
46 | ### Functional Simulation
47 | sim.py implements the functional simulator of OpenTPU. It reads three command-line arguments: the assembly program, the host memory file, and the weights file. Due to the different quantization mechanisms between high-level applications (written in TensorFlow) and OpenTPU, the simulator runs in two modes: 32b float mode and 8b int mode. The downsampling/quantization mechanism is consistent with the HW implementation of OpenTPU. It generates two sets of outputs, one 32b-float typed, the other 8b-int typed.
48 | 
49 | Example usage:
50 | 
51 |     python sim.py boston.out boston_input.npy boston_weights.npy
52 | 
53 | Numpy matrices (.npy files) can be generated by calling `numpy.save` on a numpy array.
54 | 
55 | checker.py implements a simple checking function to verify the results from the HW, the simulator, and the application. It checks the 32b-float application results against the 32b-float simulator results, then checks the 8b-int simulator results against the 8b-int HW results.
56 | 
57 | Example usage:
58 | 
59 |     python checker.py
60 | 
61 | 
62 | ## FAQs:
63 | 
64 | ### How big/efficient/fast is OpenTPU?
65 | As of the alpha release, we do not have hard synthesis figures for the full 256x256 OpenTPU.
66 | 
67 | ### What can OpenTPU do?
68 | The hardware prototype can currently handle matrix multiplies and activations for ReLU and sigmoid --- i.e., the inference phase of many neural network computations.
69 | 
70 | ### What features are missing?
71 | Convolution, pooling, and programmable normalization.
72 | 
73 | ### Does your design follow that of the TPU?
74 | We used high-level design details from the TPU paper to guide our design when possible. Thus, the major components of the chip are the same --- matrix multiply unit, unified buffer, activation unit, accumulator, weight FIFO, etc. Beyond that, the implementations may have many differences.
75 | 
76 | ### Does OpenTPU support all the same instructions as TPU?
77 | No. Currently, OpenTPU supports the RHM, WHM, RW, MMC, ACT, NOP, and HLT instructions (see the ISA section for details). The purpose, definition, and specification of other TPU instructions are absent from the published paper. Some instructions will likely be added to OpenTPU as we continue development (such as SYNC), but the final ISA will likely feature many differences without a published spec from Google to work from.
78 | 
79 | ### Is OpenTPU binary compatible with the TPU?
80 | No. There is no publicly available interface or spec for the TPU.
81 | 
82 | ### I'd like to do some analysis/extensions of OpenTPU, but I need Verilog. Do you have a Verilog version?
83 | PyRTL can output structural Verilog for the design, using the `OutputToVerilog` function.
84 | 
85 | ### I have suggestions, criticisms, and/or would like to contribute.
86 | That's not a question, but please get in touch! Email Deeksha (deeksha@cs.ucsb.edu) or Joseph (jmcmahan@cs.ucsb.edu).
87 | 
88 | ### I'm a Distinguished Hardware Engineer at Google and the Lead Architect of the TPU. I see many inefficiencies in your implementation.
89 | Hi Norm! Tim welcomes you to Santa Barbara to talk about all things TPU :)
90 | 
91 | 
92 | ## Software details
93 | 
94 | ### ISA
95 | 
96 | - RHM src, dst, N
97 | Read Host Memory.
98 | Read _N_ vectors from host memory beginning at address _src_ and save them in the UB (unified buffer) beginning at address _dst_.
99 | - WHM src, dst, N
100 | Write Host Memory.
101 | Write _N_ vectors from the UB beginning at address _src_ to host memory beginning at address _dst_.
102 | - RW addr
103 | Read Weights.
104 | Load the weights tile from the weights DRAM at address _addr_ into the on-chip FIFO.
105 | - MMC.{OS} src, dst, N
106 | Matrix Multiply/Convolution.
107 | Perform a matrix multiply operation on the _N_ vectors beginning at UB address _src_, storing the result in the accumulator buffers beginning at address _dst_. If the _O_ (overwrite) flag is specified, overwrite the contents of the accumulator buffers at the destination addresses; the default behavior is to add to the value there and store the new sum. If the _S_ (switch) flag is specified, switch to using the next tile of weights, which must already have been pre-loaded. The first `MMC` instruction in a program should always use the _S_ flag.
108 | - ACT.{RQ} src, dst, N
109 | Activate.
110 | Perform activation on _N_ vectors in the accumulator buffers starting at address _src_, storing the results in the UB beginning at address _dst_. The activation function is specified with a flag: _R_ for ReLU and _Q_ for sigmoid.
With no flag, values are passed through without activation. Normalization is programmable at synthesis time, but not at run time; by default, after activation the upper 24 bits are dropped from each value, producing an 8-bit integer.
111 | - NOP
112 | No op. Do nothing for one cycle.
113 | - HLT
114 | Halt. Stop simulation.
115 | 
116 | 
117 | ### Writing a Program
118 | OpenTPU uses no dynamic scheduling; all execution is fully deterministic*, and the hardware relies on the compiler to correctly schedule operations and pad NOPs to handle delays.
119 | This OpenTPU release does not support "repeat" flags on instructions, so many NOPs are required to ensure correct execution.
120 | 
121 | *DRAM is a source of non-deterministic latency, discussed in the Memory Controllers section of Microarchitecture.
122 | 
123 | ### Generating Data
124 | __Application__
125 | 
126 | 1. Simple one-hot 2-layer NN
127 | 
128 | gen_one_hot.py generates an 8b-int typed random square matrix as training data and a vector as labels; example usage:
129 | 
130 |     python gen_one_hot.py --path simple_train --shape 8 8 --range -5 5
131 |     python gen_one_hot.py --path simple_train_label --shape 8 1 --range 0 2
132 | 
133 | simple_nn.py trains a simple 2-layer NN on the given train/label dataset and writes the weights into a file; example usage (run the gen_one_hot example first to generate the files):
134 | 
135 |     python simple_nn.py --path simple_train.npy --label simple_train_label.npy
136 | 
137 | After running the above command, two files are generated: simple_nn_weight_dram.npy is the 8b-int typed weight DRAM that the OpenTPU operates on, and simple_nn_gt is the pickled ground-truth 32b-float results and weights. To run with OpenTPU, a test file must also be generated; example usage:
138 | 
139 |     python gen_one_hot.py --path simple_test --shape 100 8 --range 1 9
140 | 
141 | After this, simple_test.npy is generated; it should be used as the host memory by OpenTPU.
142 | 
143 | We also provide simple_nn.a -- the assembly program for this simple NN.
144 | 
145 | 2. TensorFlow DNN regression
146 | 
147 | Although applications written in any high-level NN framework can be used, here we use TensorFlow as an example.
148 | 
149 | tf_nn.py trains an MLP regressor on the Boston Housing Dataset (https://archive.ics.uci.edu/ml/datasets/housing). Example usage:
150 | 
151 |     python tf_nn.py --N 10 --save-input-path boston_input --save-weight-path boston_weights --save-output-path boston_output
152 |     python tf_nn.py --N 10 --save-input-path boston_input --save-weight-path boston_weights --save-output-path boston_output --raw
153 | 
154 | After running the above commands, four files are generated: gt32.npy holds the ground-truth prediction values, boston_input.npy holds the input test cases used as the host memory for OpenTPU, boston_output.npy holds all the intermediate output values, and boston_weights.npy holds the weight matrices used as the weight DRAM for OpenTPU.
155 | 
156 | Adding --raw to the command generates 32b-float typed files instead of 8b ints.
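For reference, the 8-bit quantization applied by tf_nn.py (and expected by the functional simulator) rescales each tensor symmetrically by 127 over its largest absolute value. A minimal sketch of the `norm2byte` scheme:

```python
import numpy as np

def norm2byte(mat):
    # Symmetric per-tensor quantization: the largest-magnitude entry maps
    # to +/-127; the conversion to int8 truncates toward zero.
    max_w = np.amax(np.abs(mat))
    return (127 * mat / max_w).astype(np.int8)

x = np.array([[0.5, -2.0], [1.0, 4.0]], dtype=np.float32)
print(norm2byte(x))  # the entry with |value| = 4.0 maps to 127
```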
157 | 
158 | 
159 | ### Latencies
160 | The following gives the hardware execution latency for each instruction on OpenTPU:
161 | 
162 | - RHM - _M_ cycles for reading _M_ vectors
163 | - WHM - _M_ cycles for writing _M_ vectors
164 | - RW - _N*N_/64 cycles for an _N_x_N_ MM Array for the DRAM transfer, plus up to 3 additional cycles to propagate through the FIFO
165 | - MMC - _L+2N_ cycles, for an _N_x_N_ MM Array and _L_ vectors multiplied in the instruction
166 | - ACT - _L+1_ cycles, for _L_ vectors activated in the instruction
167 | 
168 | 
169 | ## Microarchitecture
170 | 
171 | ### Matrix Multiply (MM) Unit
172 | The computational core of OpenTPU is a parametrizable array of 8-bit Multiply-Accumulate units (MACs), each consisting of an 8-bit integer multiplier and an integer adder of between 16 and 32 bits*.
173 | Each MAC has two buffers storing 8-bit weights (the second buffer allows weight programming to happen in parallel). Input vectors enter the array from the left, with values advancing one unit to the right each cycle.
174 | Each unit multiplies the input value by the active weight, adds it to the value from the unit above, and passes the result to the unit below.
175 | Input vectors are fed diagonally so that values align correctly as partial sums flow down the array.
176 | 
177 | *The multipliers produce 16-bit outputs; as values move down the columns of the array, each add produces 1 extra bit. Width is capped at 32, creating the potential for uncaught overflow.
178 | 
179 | 
180 | ### Accumulator Buffers
181 | Result vectors from the MM Array are written to a software-specified address in a set of accumulator buffers. Instructions indicate whether values should be added into the value already at the address or overwrite it.
182 | MM instructions read from the Unified Buffer (UB) and write to the accumulator buffers; activate instructions read from the accumulator buffers and write to the UB.
183 | 
184 | 
185 | ### Weight FIFO
186 | At scale (256x256 MACs), a full matrix of weights (a "tile") is 64KB; to avoid stalls while weights are moved from off-chip weight DRAM, a 4-entry FIFO is used to buffer tiles.
187 | It is assumed the connection to the weight DRAM is a standard DDR interface moving data in 64-byte chunks (memory controllers are currently emulated with no simulated delay, so one chunk arrives each cycle).
188 | When an MM instruction carries the "switch" flag, each MAC switches its active weight buffer as the first vector of the instruction propagates through the array. Once it reaches the end of the first row, the FIFO begins feeding new weight values into the free buffers of the array.
189 | New weight values are passed down through the array each cycle until each row reaches its destination.
190 | 
191 | 
192 | ### Systolic Setup
193 | Vectors are read all at once from the Unified Buffer, but must be fed diagonally into the MM Array. This is accomplished with a set of sequential buffers in a lower triangular configuration.
194 | The top value reaches the matrix immediately, the second after one cycle, the third after two, etc., so that each value reaches a MAC at the same time as the corresponding partial sum from the same source vector.
195 | 
196 | 
197 | ### Memory Controllers
198 | Currently, memory controllers are emulated and have no delay. The connection to Host Memory is currently the size of one vector. The connection to the Weight DRAM uses a standard width of 64 bytes.
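As a quick sanity check on these widths, the number of 64-byte transfers needed per weight tile (a sketch mirroring the chunking arithmetic in runtpu.py) is:

```python
def chunks_per_tile(matsize, ddr_bytes=64):
    # One int8 weight per MAC; tiles move from weight DRAM in 64-byte chunks.
    return max((matsize * matsize) // ddr_bytes, 1)

print(chunks_per_tile(8))    # 1 transfer for the simplemult test
print(chunks_per_tile(16))   # 4 transfers for the Boston test
print(chunks_per_tile(256))  # 1024 transfers (a 64KB tile) at full scale
```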
199 | 
200 | Because the emulated controllers can return a new value each cycle, the OpenTPU hardware simulation currently has no non-deterministic delay.
201 | With a more accurate DRAM interface that may encounter dynamic delays, programs would need to either take care to schedule for the worst-case memory delay, or make use of another instruction to ensure memory operations complete before the values are required*.
202 | 
203 | *We note that the TPU "SYNC" instruction may fulfill this purpose, but it is currently unimplemented on OpenTPU.
204 | 
205 | 
206 | ### Configuration
207 | Unified Buffer size, Accumulator Buffer size, and the size of the MM Array can all be specified in config.py. However, the MM Array must always be square, and vectors/weights are always composed of 8-bit integers.
208 | 
--------------------------------------------------------------------------------
/matrix.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from pyrtl import *
3 | from pyrtl import rtllib
4 | from pyrtl.rtllib import multipliers
5 | 
6 | #set_debug_mode()
7 | globali = 0  # To give unique numbers to each MAC
8 | def MAC(data_width, matrix_size, data_in, acc_in, switchw, weight_in, weight_we, weight_tag):
9 |     '''Multiply-Accumulate unit with programmable weight.
10 |     Inputs
11 |     data_in: The 8-bit activation value to multiply by weight.
12 |     acc_in: 32-bit value to accumulate with product.
13 |     switchw: Control signal; when 1, switch to using the other weight buffer.
14 |     weight_in: 8-bit value to write to the secondary weight buffer.
15 |     weight_we: When high, weights are being written; if the tag matches, store the weight.
16 |     Otherwise, pass it through with an incremented tag.
17 |     weight_tag: If equal to matrix_size-1 (255 at full 256x256 scale), the weight is for this row; store it.
18 | 
19 |     Outputs
20 |     out: Result of the multiply accumulate; moves one cell down to become acc_in.
21 |     data_reg: data_in, stored in a pipeline register for the cell to the right.
22 |     switch_reg: switchw, stored in a pipeline register for the cell to the right.
23 |     weight_reg: weight_in, stored in a pipeline register for the cell below.
24 |     weight_we_reg: weight_we, stored in a pipeline register for the cell below.
25 |     weight_tag_reg: weight_tag, incremented and stored in a pipeline register for the cell below.
26 |     '''
27 |     global globali
28 |     # Check lengths of inputs
29 |     if not (len(weight_in) == len(data_in) == data_width):
30 |         raise Exception("Expected 8-bit value in MAC.")
31 |     if not (len(switchw) == len(weight_we) == 1):
32 |         raise Exception("Expected 1-bit control signal in MAC.")
33 | 
34 |     # Should never switch weight buffers while they're changing
35 |     #rtl_assert(~(weight_we & switchw), Exception("Cannot switch weight values when they're being loaded!"))
36 | 
37 |     # Use two buffers to store the current weight and the next weight to use.
38 |     wbuf1, wbuf2 = Register(len(weight_in)), Register(len(weight_in))
39 | 
40 |     # Track which buffer is current and which is secondary.
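# current_buffer_reg flips whenever switchw is raised; XOR-ing it with switchw
# below lets the newly selected buffer take effect in the same cycle the
# switch signal arrives, one cycle before the register itself updates.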
41 | current_buffer_reg = Register(1) 42 | with conditional_assignment: 43 | with switchw: 44 | current_buffer_reg.next |= ~current_buffer_reg 45 | current_buffer = current_buffer_reg ^ switchw # reflects change in same cycle switchw goes high 46 | 47 | # When told, store a new weight value in the secondary buffer 48 | with conditional_assignment: 49 | with weight_we & (weight_tag == Const(matrix_size-1)): 50 | with current_buffer == 0: # If 0, wbuf1 is current; if 1, wbuf2 is current 51 | wbuf2.next |= weight_in 52 | with otherwise: 53 | wbuf1.next |= weight_in 54 | 55 | # Do the actual MAC operation 56 | weight = select(current_buffer, wbuf2, wbuf1) 57 | #probe(weight, "weight" + str(globali)) 58 | globali += 1 59 | #inlen = max(len(weight), len(data_in)) 60 | #product = weight.sign_extended(inlen*2) * data_in.sign_extended(inlen*2) 61 | #product = product[:inlen*2] 62 | product = helperfuncs.mult_signed(weight, data_in)[:32] 63 | #plen = len(weight) + len(data_in) 64 | #product = weight.sign_extended(plen) * data_in.sign_extended(plen) 65 | #product = product[:plen] 66 | l = max(len(product), len(acc_in)) + 1 67 | out = (product.sign_extended(l) + acc_in.sign_extended(l))[:-1] 68 | 69 | #product = rtllib.multipliers.signed_tree_multiplier(weight, data_in) 70 | #l = max(len(product), len(acc_in)) 71 | #out = product.sign_extended(l) + acc_in.sign_extended(l) 72 | 73 | if len(out) > 32: 74 | out = out[:32] 75 | 76 | # For values that need to be forward to the right/bottom, store in pipeline registers 77 | data_reg = Register(len(data_in)) # pipeline register, holds data value for cell to the right 78 | data_reg.next <<= data_in 79 | switch_reg = Register(1) # pipeline register, holds switch control signal for cell to the right 80 | switch_reg.next <<= switchw 81 | acc_reg = Register(len(out)) # output value for MAC below 82 | acc_reg.next <<= out 83 | weight_reg = Register(len(weight_in)) # pipeline register, holds weight input for cell below 84 | weight_reg.next <<= weight_in 85 | weight_we_reg = Register(1) # pipeline register, holds weight write enable signal for cell below 86 | weight_we_reg.next <<= weight_we 87 | weight_tag_reg = Register(len(weight_tag)) # pipeline register, holds weight tag for cell below 88 | weight_tag_reg.next <<= (weight_tag + 1)[:len(weight_tag)] # increment tag as it passes down rows 89 | 90 | return acc_reg, data_reg, switch_reg, weight_reg, weight_we_reg, weight_tag_reg 91 | 92 | 93 | def MMArray(data_width, matrix_size, data_in, new_weights, weights_in, weights_we): 94 | ''' 95 | data_in: 256-array of 8-bit activation values from systolic_setup buffer 96 | new_weights: 256-array of 1-bit control values indicating that new weight should be used 97 | weights_in: output of weight FIFO (8 x matsize x matsize bit wire) 98 | weights_we: 1-bit signal to begin writing new weights into the matrix 99 | ''' 100 | 101 | # For signals going to the right, store in a var; for signals going down, keep a list 102 | # For signals going down, keep a copy of inputs to top row to connect to later 103 | weights_in_top = [ WireVector(data_width) for i in range(matrix_size) ] # input weights to top row 104 | weights_in_last = [x for x in weights_in_top] 105 | weights_enable_top = [ WireVector(1) for i in range(matrix_size) ] # weight we to top row 106 | weights_enable = [x for x in weights_enable_top] 107 | weights_tag_top = [ WireVector(data_width) for i in range(matrix_size) ] # weight row tag to top row 108 | weights_tag = [x for x in weights_tag_top] 109 | data_out = 
[Const(0) for i in range(matrix_size)] # will hold output from final row 110 | # Build array of MACs 111 | for i in range(matrix_size): # for each row 112 | din = data_in[i] 113 | switchin = new_weights[i] 114 | #probe(switchin, "switch" + str(i)) 115 | for j in range(matrix_size): # for each column 116 | acc_out, din, switchin, newweight, newwe, newtag = MAC(data_width, matrix_size, din, data_out[j], switchin, weights_in_last[j], weights_enable[j], weights_tag[j]) 117 | #probe(data_out[j], "MACacc{}_{}".format(i, j)) 118 | #probe(acc_out, "MACout{}_{}".format(i, j)) 119 | #probe(din, "MACdata{}_{}".format(i, j)) 120 | weights_in_last[j] = newweight 121 | weights_enable[j] = newwe 122 | weights_tag[j] = newtag 123 | data_out[j] = acc_out 124 | 125 | # Handle weight reprogramming 126 | programming = Register(1) # when 1, we're in the process of loading new weights 127 | size = 1 128 | while pow(2, size) < matrix_size: 129 | size = size + 1 130 | progstep = Register(size) # 256 steps to program new weights (also serves as tag input) 131 | with conditional_assignment: 132 | with weights_we & (~programming): 133 | programming.next |= 1 134 | with programming & (progstep == matrix_size-1): 135 | programming.next |= 0 136 | with otherwise: 137 | pass 138 | with programming: # while programming, increment state each cycle 139 | progstep.next |= progstep + 1 140 | with otherwise: 141 | progstep.next |= Const(0) 142 | 143 | # Divide FIFO output into rows (each row datawidth x matrixsize bits) 144 | rowsize = data_width * matrix_size 145 | weight_arr = [ weights_in[i*rowsize : i*rowsize + rowsize] for i in range(matrix_size) ] 146 | # Mux the wire for this row 147 | current_weights_wire = mux(progstep, *weight_arr) 148 | # Split the wire into an array of 8-bit values 149 | current_weights = [ current_weights_wire[i*data_width:i*data_width+data_width] for i in reversed(range(matrix_size)) ] 150 | 151 | # Connect top row to input and control signals 152 | for i, win in enumerate(weights_in_top): 153 | # From the current 256-array, select the byte for this column 154 | win <<= current_weights[i] 155 | for we in weights_enable_top: 156 | # Whole row gets same signal: high when programming new weights 157 | we <<= programming 158 | for wt in weights_tag_top: 159 | # Tag is same for whole row; use state index (runs from 0 to 255) 160 | wt <<= progstep 161 | 162 | return [ x.sign_extended(32) for x in data_out ] 163 | 164 | 165 | def accum(size, data_in, waddr, wen, wclear, raddr, lastvec): 166 | '''A single 32-bit accumulator with 2^size 32-bit buffers. 167 | On wen, writes data_in to the specified address (waddr) if wclear is high; 168 | otherwise, it performs an accumulate at the specified address (buffer[waddr] += data_in). 169 | lastvec is a control signal indicating that the operation being stored now is the 170 | last vector of a matrix multiply instruction (at the final accumulator, this becomes 171 | a "done" signal). 
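    For example, two results written to the same waddr with wclear low leave
    buffer[waddr] holding their sum; a write with wclear high replaces the
    stored value instead.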
172 | ''' 173 | 174 | mem = MemBlock(bitwidth=32, addrwidth=size) 175 | 176 | # Writes 177 | with conditional_assignment: 178 | with wen: 179 | with wclear: 180 | mem[waddr] |= data_in 181 | with otherwise: 182 | mem[waddr] |= (data_in + mem[waddr])[:mem.bitwidth] 183 | 184 | # Read 185 | data_out = mem[raddr] 186 | 187 | # Pipeline registers 188 | waddrsave = Register(len(waddr)) 189 | waddrsave.next <<= waddr 190 | wensave = Register(1) 191 | wensave.next <<= wen 192 | wclearsave = Register(1) 193 | wclearsave.next <<= wclear 194 | lastsave = Register(1) 195 | lastsave.next <<= lastvec 196 | 197 | return data_out, waddrsave, wensave, wclearsave, lastsave 198 | 199 | def accumulators(accsize, datas_in, waddr, we, wclear, raddr, lastvec): 200 | ''' 201 | Produces array of accumulators of same dimension as datas_in. 202 | ''' 203 | 204 | #probe(we, "accum_wen") 205 | #probe(wclear, "accum_wclear") 206 | #probe(waddr, "accum_waddr") 207 | 208 | accout = [ None for i in range(len(datas_in)) ] 209 | waddrin = waddr 210 | wein = we 211 | wclearin = wclear 212 | lastvecin = lastvec 213 | for i,x in enumerate(datas_in): 214 | #probe(x, "acc_{}_in".format(i)) 215 | #probe(wein, "acc_{}_we".format(i)) 216 | #probe(waddrin, "acc_{}_waddr".format(i)) 217 | dout, waddrin, wein, wclearin, lastvecin = accum(accsize, x, waddrin, wein, wclearin, raddr, lastvecin) 218 | accout[i] = dout 219 | done = lastvecin 220 | 221 | return accout, done 222 | 223 | 224 | def FIFO(matsize, mem_data, mem_valid, advance_fifo): 225 | ''' 226 | matsize is the length of one row of the Matrix. 227 | mem_data is the connection from the DRAM controller, which is assumed to be 64 bytes wide. 228 | mem_valid is a one bit control signal from the controller indicating that the read completed and the current value is valid. 229 | advance_fifo signals to drop the tile at the end of the FIFO and advance everything forward. 
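    For MATSIZE=16 a tile is 16*16 = 256 bytes, i.e. four 64-byte transfers;
    at full 256x256 scale a tile is 64KB, or 1024 transfers.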
230 | 231 | Output 232 | tile, ready, full 233 | tile: entire tile at the front of the queue (8 x matsize x matsize bits) 234 | ready: the tile output is valid 235 | full: there is no room in the FIFO 236 | ''' 237 | 238 | #probe(mem_data, "fifo_dram_in") 239 | #probe(mem_valid, "fifo_dram_valid") 240 | #probe(advance_fifo, "weights_advance_fifo") 241 | 242 | # Make some size parameters, declare state register 243 | totalsize = matsize * matsize # total size of a tile in bytes 244 | tilesize = totalsize * 8 # total size of a tile in bits 245 | ddrwidth = int(len(mem_data)/8) # width from DDR in bytes (typically 64) 246 | size = 1 247 | while pow(2, size) < (totalsize/ddrwidth): # compute log of number of transfers required 248 | size = size + 1 249 | state = Register(size) # Number of reads to receive (each read is ddrwidth bytes) 250 | startup = Register(1) 251 | startup.next <<= 1 252 | 253 | #probe(state, "fifo_state") 254 | 255 | # Declare top row of buffer: need to write to it in ddrwidth-byte chunks 256 | topbuf = [ Register(ddrwidth*8) for i in range(max(1, int(totalsize/ddrwidth))) ] 257 | 258 | # Latch command to advance FIFO, since it may not complete immediately 259 | droptile = Register(1) 260 | clear_droptile = WireVector(1) 261 | with conditional_assignment: 262 | with advance_fifo: 263 | droptile.next |= 1 264 | with clear_droptile: 265 | droptile.next |= 0 266 | 267 | #probe(droptile, "fifo_droptile") 268 | #probe(clear_droptile, "fifo_clear_droptile") 269 | 270 | # When we get data from DRAM controller, write to next buffer space 271 | with conditional_assignment: 272 | with mem_valid: 273 | state.next |= state + 1 # state tracks which ddrwidth-byte chunk we're writing to 274 | for i, reg in enumerate(reversed(topbuf)): # enumerate a decoder for write-enable signals 275 | #probe(reg, "fifo_reg{}".format(i)) 276 | with state == Const(i, bitwidth=size): 277 | reg.next |= mem_data 278 | 279 | # Track when first buffer is filled and when data moves out of it 280 | full = Register(1) # goes high when last chunk of top buffer is filled 281 | cleartop = WireVector(1) 282 | with conditional_assignment: 283 | with mem_valid & (state == Const(len(topbuf)-1)): # writing the last buffer spot now 284 | full.next |= 1 285 | with cleartop: # advancing FIFO, so buffer becomes empty 286 | full.next |= 0 287 | 288 | # Build buffers for remainder of FIFO 289 | buf2, buf3, buf4 = Register(tilesize), Register(tilesize), Register(tilesize) 290 | #probe(concat_list(topbuf), "buf1") 291 | #probe(buf2, "buf2") 292 | #probe(buf3, "buf3") 293 | #probe(buf4, "buf4") 294 | #probe(full, "buf1_full") 295 | # If a given row is empty, track that so we can fill immediately 296 | empty2, empty3, empty4 = Register(1), Register(1), Register(1) 297 | #probe(empty2, "buf2_empty") 298 | #probe(empty3, "buf3_empty") 299 | #probe(empty4, "buf4_empty") 300 | 301 | # Handle moving data between the buffers 302 | with conditional_assignment: 303 | with ~startup: 304 | empty2.next |= 1 305 | empty3.next |= 1 306 | empty4.next |= 1 307 | with full & empty2: # First buffer is full, second is empty 308 | buf2.next |= concat_list(topbuf) # move data to second buffer 309 | cleartop |= 1 # empty the first buffer 310 | empty2.next |= 0 # mark the second buffer as non-empty 311 | with empty3 & ~empty2: # Third buffer is empty and second is full 312 | buf3.next |= buf2 313 | empty3.next |= 0 314 | empty2.next |= 1 315 | with empty4 & ~empty3: # Fourth buffer is empty and third is full 316 | buf4.next |= buf3 317 | empty4.next 
|= 0 318 | empty3.next |= 1 319 | with droptile: 320 | empty4.next |= 1 # mark fourth buffer as free; tiles will advance automatically 321 | clear_droptile |= 1 322 | 323 | ready = startup & (~empty4) & (~droptile) # there is data in final buffer and we're not about to change it 324 | 325 | return buf4, ready, full 326 | 327 | def systolic_setup(data_width, matsize, vec_in, waddr, valid, clearbit, lastvec, switch): 328 | '''Buffers vectors from the unified SRAM buffer so that they can be fed along diagonals to the 329 | Matrix Multiply array. 330 | 331 | matsize: row size of Matrix 332 | vec_in: row read from unified buffer 333 | waddr: the accumulator address this vector is bound for 334 | valid: this is a valid vector; write it when done 335 | clearbit: if 1, store result (default accumulate) 336 | lastvec: this is the last vector of a matrix 337 | switch: use the next weights tile beginning with this vector 338 | 339 | Output 340 | next_row: diagonal cross-cut of vectors to feed to MM array 341 | switchout: switch signals for MM array 342 | addrout: write address for first accumulator 343 | weout: write enable for first accumulator 344 | clearout: clear signal for first accumulator 345 | doneout: done signal for first accumulator 346 | ''' 347 | 348 | # Use a diagonal set of buffer so that when a vector is read from SRAM, it "falls" into 349 | # the correct diagonal pattern. 350 | # The last column of buffers need extra bits for control signals, which propagate down 351 | # and into the accumulators. 352 | 353 | addrreg = Register(len(waddr)) 354 | addrreg.next <<= waddr 355 | wereg = Register(1) 356 | wereg.next <<= valid 357 | clearreg = Register(1) 358 | clearreg.next <<= clearbit 359 | donereg = Register(1) 360 | donereg.next <<= lastvec 361 | topreg = Register(data_width) 362 | 363 | firstcolumn = [topreg,] + [ Register(data_width) for i in range(matsize-1) ] 364 | lastcolumn = [ None for i in range(matsize) ] 365 | lastcolumn[0] = topreg 366 | 367 | # Generate switch signals to matrix; propagate down 368 | switchout = [ None for i in range(matsize) ] 369 | switchout[0] = Register(1) 370 | switchout[0].next <<= switch 371 | for i in range(1, len(switchout)): 372 | switchout[i] = Register(1) 373 | switchout[i].next <<= switchout[i-1] 374 | 375 | # Generate control pipeline for address, clear, and done signals 376 | addrout = addrreg 377 | weout = wereg 378 | clearout = clearreg 379 | doneout = lastvec 380 | #probe(clearout, "sys_clear_in") 381 | # Need one extra cycle of delay for control signals before giving them to first accumulator 382 | # But we already did registers for first row, so cancels out 383 | for i in range(0, matsize): 384 | a = Register(len(addrout)) 385 | a.next <<= addrout 386 | addrout = a 387 | w = Register(1) 388 | w.next <<= weout 389 | weout = w 390 | c = Register(1) 391 | c.next <<= clearout 392 | clearout = c 393 | d = Register(1) 394 | d.next <<= doneout 395 | doneout = d 396 | #probe(clearout, "sys_clear_out") 397 | 398 | # Generate buffers in a diagonal pattern 399 | for row in range(1, matsize): # first row is done 400 | left = firstcolumn[row] 401 | lastcolumn[row] = left 402 | for column in range(0, row): # first column is done 403 | buf = Register(data_width) 404 | buf.next <<= left 405 | left = buf 406 | lastcolumn[row] = left # holds final column for output 407 | 408 | # Connect first column to input data 409 | datain = [ vec_in[i*data_width : i*data_width+data_width] for i in range(matsize) ] 410 | for din, reg in zip(datain, firstcolumn): 411 | 
reg.next <<= din 412 | 413 | 414 | return lastcolumn, switchout, addrout, weout, clearout, doneout 415 | 416 | 417 | def MMU(data_width, matrix_size, accum_size, vector_in, accum_raddr, accum_waddr, vec_valid, accum_overwrite, lastvec, switch_weights, ddr_data, ddr_valid): #, weights_in, weights_we): 418 | 419 | logn1 = 1 420 | while pow(2, logn1) < (matrix_size + 1): 421 | logn1 = logn1 + 1 422 | logn = 1 423 | while pow(2, logn) < (matrix_size): 424 | logn = logn + 1 425 | 426 | programming = Register(1) # if high, we're programming new weights now 427 | waiting = WireVector(1) # if high, a switch is underway and we're waiting 428 | 429 | weights_wait = Register(logn1, "weights_wait") # counts cycles since last weight push 430 | weights_count = Register(logn, "weights_count") # counts cycles of current weight push 431 | startup = Register(1) 432 | startup.next <<= 1 # 0 only in first cycle 433 | weights_we = WireVector(1) 434 | done_programming = WireVector(1) 435 | first_tile = Register(1) # Tracks if we've programmed the first tile yet 436 | 437 | #rtl_assert(~(switch_weights & (weights_wait != 0)), Exception("Weights are not ready to switch. Need a minimum of {} + 1 cycles since last switch.".format(matrix_size))) 438 | 439 | # FIFO 440 | weights_tile, tile_ready, full = FIFO(matsize=matrix_size, mem_data=ddr_data, mem_valid=ddr_valid, advance_fifo=done_programming) 441 | #probe(tile_ready, "tile_ready") 442 | #probe(weights_tile, "FIFO_weights_out") 443 | 444 | matin, switchout, addrout, weout, clearout, doneout = systolic_setup(data_width=data_width, matsize=matrix_size, vec_in=vector_in, waddr=accum_waddr, valid=vec_valid, clearbit=accum_overwrite, lastvec=lastvec, switch=switch_weights) 445 | 446 | mouts = MMArray(data_width=data_width, matrix_size=matrix_size, data_in=matin, new_weights=switchout, weights_in=weights_tile, weights_we=weights_we) 447 | 448 | accout, done = accumulators(accsize=accum_size, datas_in=mouts, waddr=addrout, we=weout, wclear=clearout, raddr=accum_raddr, lastvec=doneout) 449 | 450 | switchstart = switchout[0] 451 | totalwait = Const(matrix_size + 1) 452 | waiting <<= weights_wait != totalwait # if high, we have to wait 453 | 454 | #probe(waiting, "waiting") 455 | 456 | with conditional_assignment: 457 | with ~startup: # when we start, configure values to be ready to accept a new tile 458 | weights_wait.next |= totalwait 459 | with waiting: # need to wait for switch to finish propagating 460 | weights_wait.next |= weights_wait + 1 461 | with ~first_tile & tile_ready: # start programming the first tile 462 | weights_wait.next |= totalwait # we don't have to swait for a switch to clear 463 | programming.next |= 1 # begin programming weights 464 | weights_count.next |= 0 465 | first_tile.next |= 1 466 | with switchstart: # Weight switch initiated; begin waiting 467 | weights_wait.next |= 0 468 | programming.next |= 1 469 | weights_count.next |= 0 470 | with programming: # We're pushing new weights now 471 | with weights_count == Const(matrix_size-1): # We've reached the end 472 | programming.next |= 0 473 | done_programming |= 1 474 | with otherwise: # Still programming; increment count and keep write signal high 475 | weights_count.next |= weights_count + 1 476 | weights_we |= 1 477 | 478 | ''' 479 | with conditional_assignment: 480 | with startup == 0: # When we start, we're ready to push weights as soon as FIFO is ready 481 | weights_wait.next |= totalwait 482 | with switchout: # Got a switch signal; start wait count 483 | weights_wait.next |= 1 484 | with 
478 |     '''
479 |     with conditional_assignment:
480 |         with startup == 0:  # when we start, we're ready to push weights as soon as the FIFO is ready
481 |             weights_wait.next |= totalwait
482 |         with switchout:  # got a switch signal; start the wait count
483 |             weights_wait.next |= 1
484 |         with weights_wait != totalwait:  # count up, stalling once we reach totalwait
485 |             weights_wait.next |= weights_wait + 1
486 |         with weights_count != 0:  # if we've started programming new weights, reset
487 |             weights_wait.next |= 0
488 |         with otherwise:
489 |             pass
490 | 
491 |     with ~startup:
492 |         pass
493 |     with (weights_wait == totalwait) & tile_ready:  # ready to push new weights in
494 |         weights_count.next |= 1
495 |     with weights_count == Const(matrix_size):  # finished pushing new weights
496 |         done_programming |= 1
497 |         weights_count.next |= 0
498 |     with otherwise:  # we're pushing weights now; increment the count
499 |         weights_count.next |= weights_count + 1
500 |         weights_we |= 1
501 |     '''
502 | 
503 |     return accout, done
504 | 
505 | def MMU_top(data_width, matrix_size, accum_size, ub_size, start, start_addr, nvecs, dest_acc_addr, overwrite, swap_weights, ub_rdata, accum_raddr, weights_dram_in, weights_dram_valid):
506 |     '''
507 |     Issue control for the MMU: when start is high, reads nvecs vectors from the unified buffer beginning at start_addr, issuing one per cycle, with results accumulated starting at dest_acc_addr.
508 |     Outputs
509 |     ub_raddr: read address for unified buffer
510 |     '''
511 | 
512 |     #probe(ub_rdata, "ub_mm_rdata")
513 | 
514 |     accum_waddr = Register(accum_size)
515 |     vec_valid = WireVector(1)
516 |     overwrite_reg = Register(1)
517 |     last = WireVector(1)
518 |     swap_reg = Register(1)
519 | 
520 |     busy = Register(1)
521 |     N = Register(len(nvecs))
522 |     ub_raddr = Register(ub_size)
523 | 
524 |     rtl_assert(~(start & busy), Exception("Cannot dispatch new MM instruction while previous instruction is still being issued."))
525 | 
526 |     #probe(vec_valid, "MM_vec_valid_issue")
527 |     #probe(busy, "MM_busy")
528 | 
529 |     # Vector issue control logic
530 |     with conditional_assignment:
531 |         with start:  # new instruction being issued
532 |             accum_waddr.next |= dest_acc_addr
533 |             overwrite_reg.next |= overwrite
534 |             swap_reg.next |= swap_weights
535 |             busy.next |= 1
536 |             N.next |= nvecs
537 |             ub_raddr.next |= start_addr  # begin issuing next cycle
538 |         with busy:  # we're issuing a vector this cycle
539 |             vec_valid |= 1
540 |             swap_reg.next |= 0
541 |             N.next |= N - 1
542 |             with N == 1:  # this was the last vector
543 |                 last |= 1
544 |                 overwrite_reg.next |= 0
545 |                 busy.next |= 0
546 |             with otherwise:  # we're going to issue a vector next cycle as well
547 |                 ub_raddr.next |= ub_raddr + 1
548 |                 accum_waddr.next |= accum_waddr + 1
549 |                 last |= 0
550 | 
551 |     acc_out, done = MMU(data_width=data_width, matrix_size=matrix_size, accum_size=accum_size, vector_in=ub_rdata, accum_raddr=accum_raddr, accum_waddr=accum_waddr, vec_valid=vec_valid, accum_overwrite=overwrite_reg, lastvec=last, switch_weights=swap_reg, ddr_data=weights_dram_in, ddr_valid=weights_dram_valid)
552 | 
553 |     #probe(ub_raddr, "ub_mm_raddr")
554 | 
555 |     return ub_raddr, acc_out, busy, done
556 | 
557 | 
558 | 
559 | '''
560 | Do we need a full/stall signal from the matrix unit? We would need to stop the SRAM output from writing into the systolic setup.
561 | Yes: the MMU needs to track when both weight buffers are in use and emit such a signal.
562 | 
563 | The timing for weights programming is wonky right now. Both rtl_asserts fail, but the
564 | right answer comes out if you ignore them. It looks like the state machine that counts the time since the
565 | last weights programming keeps restarting, so the MMU thinks it's always programming weights.
566 | 
567 | Control signals propagating down systolic_setup to the accumulators:
568 | -Overwrite signal (default: accumulate)
569 | -New accumulator address value (default: add 1 to previous address)
570 | -Done signal?
571 | '''
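For reference, MMU_top's issue logic above amounts to a simple schedule: once start fires, one vector is read per cycle, with ub_raddr and accum_waddr stepping together until N runs out and last fires. A minimal software model of that schedule (model_issue is illustrative only, not part of the hardware description):

def model_issue(start_addr, dest_acc_addr, nvecs):
    # One (ub_raddr, accum_waddr, last) tuple per issue cycle.
    for i in range(nvecs):
        yield start_addr + i, dest_acc_addr + i, i == nvecs - 1

# e.g. list(model_issue(0, 0, 3)) == [(0, 0, False), (1, 1, False), (2, 2, True)]
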
572 | from functools import reduce  # needed under Python 3, where reduce lives in functools
573 | def testall(input_vectors, weights_vectors):
574 |     DATWIDTH = 8
575 |     MATSIZE = 4
576 |     ACCSIZE = 8
577 | 
578 |     L = len(input_vectors)
579 | 
580 |     ins = [probe(Input(DATWIDTH)) for i in range(MATSIZE)]
581 |     invec = concat_list(ins)
582 |     swap = Input(1, 'swap')
583 |     waddr = Input(8)
584 |     lastvec = Input(1)
585 |     valid = Input(1)
586 |     raddr = Input(8, "raddr")  # accumulator read address to read out answers
587 |     donesig = Output(1, "done")
588 | 
589 |     outs = [Output(32, name="out{}".format(str(i))) for i in range(MATSIZE)]
590 | 
591 |     #ws = [ Const(item, bitwidth=DATWIDTH) for sublist in weights_vectors for item in sublist ]  # flatten weight matrix
592 |     #ws = concat_list(ws)  # combine weights into a single wire
593 |     ws = [item for sublist in weights_vectors for item in sublist]  # flatten weight matrix
594 |     print(ws)
595 |     #ws = reduce(lambda x, y: (x << 8) + y, ws)  # "concat" weights into one integer
596 | 
597 |     weightsdata = Input(64*8)
598 |     weightsvalid = Input(1)
599 | 
600 |     accout, done = MMU(data_width=DATWIDTH, matrix_size=MATSIZE, accum_size=ACCSIZE, vector_in=invec, accum_raddr=raddr, accum_waddr=waddr, vec_valid=valid, accum_overwrite=Const(0), lastvec=lastvec, switch_weights=swap, ddr_data=weightsdata, ddr_valid=weightsvalid)
601 | 
602 |     donesig <<= done
603 |     for out, acc in zip(outs, accout):
604 |         out <<= acc
605 | 
606 |     sim_trace = SimulationTrace()
607 |     sim = FastSimulation(tracer=sim_trace)
608 | 
609 |     # make a default input dictionary
610 |     din = {swap: 0, waddr: 0, lastvec: 0, valid: 0, raddr: 0, weightsdata: 0, weightsvalid: 0}
611 |     din.update({ins[j]: 0 for j in range(MATSIZE)})
612 | 
613 |     # Give a cycle for startup
614 |     sim.step(din)
615 | 
616 |     # First, simulate a memory read to feed weights to the FIFO
617 |     chunk = 64*8  # size of one dram read
618 |     #ws = [ ws[i*chunk:i*chunk+chunk] for i in range(max(1, len(ws)//chunk)) ]  # divide weights into dram chunks
619 |     # divide weights into dram-transfer sized chunks
620 |     ws = reduce(lambda x, y: (x << 8) + y, ws)  # "concat" weights into one integer
621 |     ws = [(ws >> (64*8*i)) & (pow(2, 64*8) - 1) for i in range(max(1, len(weights_vectors)//64))]
622 |     print(ws)
623 |     for block in ws:
624 |         d = din.copy()
625 |         d.update({ins[j]: 0 for j in range(MATSIZE)})
626 |         d.update({weightsdata: block, weightsvalid: 1})
627 |         sim.step(d)
628 | 
629 |     # Wait until the FIFO is ready
630 |     for i in range(10):
631 |         sim.step(din)
632 | 
633 |     #din.update({ins[j]:0 for j in range(MATSIZE)})
634 | 
635 |     # Send signal to write weights
636 |     #d = din.copy()
637 |     #d[weights_we] = 1
638 |     #sim.step(d)
639 | 
640 |     # Wait 2*MATSIZE cycles for the weights to propagate
641 |     for i in range(MATSIZE*2):
642 |         sim.step(din)
643 | 
644 |     # Send the swap signal with the first row of input
645 |     d = din.copy()
646 |     d.update({ins[j]: input_vectors[0][j] for j in range(MATSIZE)})
647 |     d.update({swap: 1, valid: 1})
648 |     sim.step(d)
649 | 
650 |     # Send the rest of the input
651 |     for i in range(L-1):
652 |         d = din.copy()
653 |         d.update({ins[j]: input_vectors[i+1][j] for j in range(MATSIZE)})
654 |         d.update({waddr: i+1, lastvec: 1 if i == L-2 else 0, valid: 1})
655 |         sim.step(d)
656 | 
657 |     # Wait some cycles while it propagates
658 |     for i in range(L*2):
659 |         d = din.copy()
660 |         sim.step(d)
661 | 
662 |     # Read out the values
663 |     for i in range(L):
664 |         d = din.copy()
665 |         d[raddr] = i
666 |         sim.step(d)
667 | 
668 |     with open('trace.vcd', 'w') as f:
669 |         sim_trace.print_vcd(f)
670 | 
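The weight-feeding step in testall packs the flattened weight bytes into one large integer, then slices it into 64-byte (64*8-bit) DRAM-transfer blocks. A self-contained illustration of that packing, using the 4x4 tile from the __main__ block below (the assert is just a sanity check):

from functools import reduce

tile = [[2, 2, 8, 6], [10, 6, 2, 8], [10, 9, 8, 1], [1, 3, 6, 4]]
flat = [w for row in tile for w in row]              # 16 weight bytes
packed = reduce(lambda x, y: (x << 8) + y, flat)     # big-endian byte concatenation
chunk_bits = 64 * 8                                  # one DRAM read
nblocks = max(1, len(tile) // 64)
blocks = [(packed >> (chunk_bits * i)) & (2**chunk_bits - 1) for i in range(nblocks)]
assert blocks == [packed]                            # a 4x4 tile fits in a single transfer
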
671 | 
672 | if __name__ == "__main__":
673 |     #weights = [[1, 10, 10, 2], [3, 9, 6, 2], [6, 8, 2, 8], [4, 1, 8, 6]]  # transposed
674 |     #weights = [[4, 1, 8, 6], [6, 8, 2, 8], [3, 9, 6, 2], [1, 10, 10, 2]]  # transposed, reversed
675 |     #weights = [[1, 3, 6, 4], [10, 9, 8, 1], [10, 6, 2, 8], [2, 2, 8, 6]]
676 |     weights = [[2, 2, 8, 6], [10, 6, 2, 8], [10, 9, 8, 1], [1, 3, 6, 4]]  # reversed
677 | 
678 |     vectors = [[12, 7, 2, 6], [21, 21, 18, 8], [1, 4, 18, 11], [6, 3, 25, 15], [21, 12, 1, 15], [1, 6, 13, 8], [24, 25, 18, 1], [2, 5, 13, 6], [19, 3, 1, 17], [25, 10, 20, 10]]
679 | 
680 |     testall(vectors, weights)
681 | 
--------------------------------------------------------------------------------
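To check the accumulator values read out at the end of the trace, the expected results are the plain matrix product of the test vectors with the weight tile. A quick reference computation with numpy, assuming the logical (un-reversed) orientation from the commented-out line above; which physical layout the array expects is exactly what the reversed variants in __main__ probe:

import numpy as np

vectors = np.array([[12, 7, 2, 6], [21, 21, 18, 8], [1, 4, 18, 11], [6, 3, 25, 15],
                    [21, 12, 1, 15], [1, 6, 13, 8], [24, 25, 18, 1], [2, 5, 13, 6],
                    [19, 3, 1, 17], [25, 10, 20, 10]])
weights = np.array([[1, 3, 6, 4], [10, 9, 8, 1], [10, 6, 2, 8], [2, 2, 8, 6]])
print(vectors @ weights)   # one expected output row per input vector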