├── README.md
├── data
├── golden_result.txt
├── mlp_img.bin
├── mlp_instr.bin
└── mlp_param.bin
├── lab1
├── README.md
├── data
│ ├── feature.dat
│ ├── golden.dat
│ └── weight.dat
├── refcode
│ ├── conv3d.m
│ ├── convmxu.m
│ └── saveparam.m
├── run_hls.tcl
└── src
│ ├── mxu.cpp
│ ├── tb_mxu.cpp
│ └── tpu.h
├── lab2
├── README.md
├── run_hls.tcl
└── src
│ ├── relu_norm_pool.cpp
│ ├── tb_pool.cpp
│ └── tpu.h
├── pictures
├── cla_result.png
├── sim.png
└── syn.png
└── src
├── ctrl.cpp
├── mxu.cpp
├── norm_relu_pool.cpp
├── tb_tpu.cpp
├── tpu.cpp
└── tpu.h
/README.md:
--------------------------------------------------------------------------------
1 | # SimpleTPU
2 |
3 | A Tensor Processing Unit (TPU) is designed to accelerate matrix multiplication, especially for multilayer perceptrons (MLPs) and convolutional neural networks (CNNs).
4 | This implementation mainly follows Google TPU version 1, whose architecture is introduced in [https://arxiv.org/ftp/arxiv/papers/1704/1704.04760.pdf](https://arxiv.org/ftp/arxiv/papers/1704/1704.04760.pdf "In-Datacenter Performance Analysis of a Tensor Processing Unit").
5 |
6 | Implementing a TPU in a hardware description language (such as VHDL or Verilog HDL) would take a lot of time, even for a simplified design, so this project uses the Xilinx HLS toolkit instead.
7 |
8 | The plan is divided into three phases.
9 |
10 | - Phase 1: Complete the main computing modules, including
11 |   - Lab1: Systolic Array
12 |   - Lab2: ReLU, Normalization & Pooling
13 | - Phase 2: Finish the full design of SimpleTPU.
14 | - Phase 3: Test SimpleTPU on real networks, such as an MLP and a CNN.
15 |
16 | # Key Features
17 |
18 | The key features of SimpleTPU include:
19 | - Int8 multipliers & Int32 accumulators
20 | - VLIW-based instruction-level parallelism
21 | - Vector-architecture-based data parallelism
22 |
23 | Here are the operations that SimpleTPU supports.
24 |
25 | Operation | Support
26 | -|-
27 | Conv3d | in_channels: Resource Constrained<br>out_channels: Resource Constrained<br>kernel_size: Support<br>stride: Support<br>padding: Support<br>dilation: Support<br>groups: Architecture Constrained<br>bias: Support
28 | ConvTranspose3d | The same as above
29 | Maxpool2d | kernel_size: Support<br>stride: Support<br>padding: Support
30 | Avgpool2d | The same as above
31 | Relu | Only support Relu as nonlinear function
32 | BatchNorm2d | BatchNorm2d is merged into Conv or Pool at inference time
33 | Linear | Resource Constrained
34 | UpscalingNearest2D | Support (calling Avgpool2d multiple times.)
35 | UpscalingBilinear2D | Support (calling Avgpool2d multiple times.)
36 |
37 |
38 | # Performance
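39 | The MAC array in SimpleTPU is 32×32 and the clock frequency is 500 MHz (timing closure achieved on a Xilinx UltraScale+ FPGA, speed grade -2). Counting each multiply-accumulate as two operations, the peak throughput is:
40 | $$32 \times 32 \times 500\,\text{MHz} \times 2 = 1.024\,\text{TOPS} \approx 1\,\text{TOPS (int8)}$$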
41 |
42 | # Installation
43 | **env** :
44 | - Vivado HLS 2018.2
45 |
46 | **run** :
47 | - step1: `vivado_hls -f run_hls.tcl`
48 | - step2: launch Vivado HLS and open the project
49 | - step3: run C synthesis, C/RTL cosimulation, etc.
50 |
51 | **Synthesis Result**:
52 | ![Synthesis result](pictures/syn.png)
53 | **Simulation Result**:
54 | ![Simulation result](pictures/sim.png)
55 | # Examples
56 | ## 1. MLP
57 | The network structure of the MLP is defined as follows.
58 | ```
59 | class MLP(nn.Module):
60 |     def __init__(self):
61 |         super(MLP, self).__init__()
62 |         self.hidden = nn.Linear(784,64)
63 |         self.fc = nn.Linear(64,10)
64 |
65 |     def forward(self, x):
66 |         x = x.view(-1,784)
67 |         x = self.hidden(x)
68 |         x = self.fc(x)
69 |         return F.log_softmax(x, dim=1)
70 | ```
71 |
72 | Work efficiency of SimpleTPU is about 84%.
73 |
74 |
75 | |LOC| Layers | Nonlinear function | Weights | Batch Size | % of Deployed|
76 | |---|---|---|----|----|----|
77 | |10 | 2 FC | Relu | 5M | 512 | 16%|
78 |
79 | Classification result on MNIST.
80 |
81 | ![Classification result](pictures/cla_result.png)
82 | ## 2. CNN
83 | Because there is no compiler to generate the instructions, this part of the plan has been suspended.
84 | If you want to know how to compute a convolution using SimpleTPU, lab1 provides a simple example.
85 |
86 |
87 | # Related Links
88 | https://www.cnblogs.com/sea-wind/p/10993958.html
89 |
--------------------------------------------------------------------------------
/data/golden_result.txt:
--------------------------------------------------------------------------------
1 | 7
2 | 2
3 | 1
4 | 0
5 | 4
6 | 1
7 | 4
8 | 9
9 | 6
10 | 9
11 | 0
12 | 6
13 | 9
14 | 0
15 | 1
16 | 5
17 | 9
18 | 7
19 | 6
20 | 4
21 | 9
22 | 6
23 | 6
24 | 5
25 | 4
26 | 0
27 | 7
28 | 4
29 | 0
30 | 1
31 | 3
32 | 1
33 | 3
34 | 6
35 | 7
36 | 2
37 | 7
38 | 1
39 | 2
40 | 1
41 | 1
42 | 7
43 | 4
44 | 2
45 | 6
46 | 5
47 | 1
48 | 2
49 | 4
50 | 4
51 | 6
52 | 3
53 | 5
54 | 5
55 | 6
56 | 0
57 | 4
58 | 1
59 | 9
60 | 5
61 | 7
62 | 8
63 | 4
64 | 2
65 | 7
66 | 4
67 | 6
68 | 4
69 | 3
70 | 0
71 | 7
72 | 0
73 | 2
74 | 9
75 | 1
76 | 7
77 | 3
78 | 7
79 | 9
80 | 7
81 | 9
82 | 6
83 | 2
84 | 7
85 | 8
86 | 4
87 | 7
88 | 5
89 | 6
90 | 1
91 | 3
92 | 6
93 | 4
94 | 3
95 | 1
96 | 4
97 | 1
98 | 1
99 | 6
100 | 9
101 | 6
102 | 0
103 | 5
104 | 4
105 | 9
106 | 9
107 | 2
108 | 1
109 | 4
110 | 4
111 | 8
112 | 1
113 | 3
114 | 9
115 | 7
116 | 4
117 | 4
118 | 4
119 | 9
120 | 2
121 | 5
122 | 4
123 | 7
124 | 6
125 | 4
126 | 9
127 | 0
128 | 5
129 | 8
130 | 5
131 | 6
132 | 6
133 | 5
134 | 2
135 | 8
136 | 1
137 | 0
138 | 1
139 | 6
140 | 4
141 | 6
142 | 7
143 | 3
144 | 1
145 | 9
146 | 1
147 | 8
148 | 2
149 | 0
150 | 9
151 | 9
152 | 9
153 | 5
154 | 5
155 | 1
156 | 5
157 | 6
158 | 0
159 | 3
160 | 4
161 | 4
162 | 6
163 | 5
164 | 4
165 | 6
166 | 5
167 | 4
168 | 5
169 | 1
170 | 4
171 | 4
172 | 7
173 | 2
174 | 3
175 | 2
176 | 1
177 | 1
178 | 8
179 | 1
180 | 8
181 | 1
182 | 8
183 | 5
184 | 0
185 | 8
186 | 9
187 | 2
188 | 5
189 | 0
190 | 1
191 | 1
192 | 1
193 | 0
194 | 4
195 | 0
196 | 5
197 | 1
198 | 6
199 | 4
200 | 2
201 | 3
202 | 6
203 | 1
204 | 1
205 | 1
206 | 3
207 | 9
208 | 5
209 | 2
210 | 9
211 | 4
212 | 5
213 | 9
214 | 3
215 | 9
216 | 0
217 | 3
218 | 6
219 | 5
220 | 5
221 | 7
222 | 2
223 | 2
224 | 7
225 | 1
226 | 2
227 | 8
228 | 4
229 | 1
230 | 7
231 | 3
232 | 3
233 | 8
234 | 9
235 | 7
236 | 9
237 | 2
238 | 2
239 | 4
240 | 1
241 | 5
242 | 8
243 | 8
244 | 4
245 | 2
246 | 6
247 | 0
248 | 6
249 | 4
250 | 2
251 | 4
252 | 1
253 | 9
254 | 5
255 | 7
256 | 7
257 | 2
258 | 8
259 | 2
260 | 0
261 | 8
262 | 1
263 | 7
264 | 7
265 | 9
266 | 1
267 | 8
268 | 1
269 | 8
270 | 0
271 | 3
272 | 0
273 | 1
274 | 9
275 | 9
276 | 4
277 | 1
278 | 8
279 | 2
280 | 1
281 | 2
282 | 9
283 | 7
284 | 5
285 | 9
286 | 2
287 | 6
288 | 4
289 | 1
290 | 5
291 | 4
292 | 2
293 | 9
294 | 2
295 | 0
296 | 4
297 | 0
298 | 0
299 | 2
300 | 8
301 | 6
302 | 2
303 | 1
304 | 2
305 | 4
306 | 0
307 | 2
308 | 9
309 | 4
310 | 3
311 | 3
312 | 0
313 | 0
314 | 5
315 | 1
316 | 9
317 | 6
318 | 4
319 | 0
320 | 5
321 | 1
322 | 7
323 | 9
324 | 3
325 | 0
326 | 4
327 | 2
328 | 0
329 | 7
330 | 1
331 | 1
332 | 2
333 | 1
334 | 5
335 | 3
336 | 3
337 | 4
338 | 7
339 | 8
340 | 6
341 | 6
342 | 4
343 | 1
344 | 3
345 | 5
346 | 1
347 | 0
348 | 5
349 | 1
350 | 9
351 | 1
352 | 5
353 | 0
354 | 6
355 | 1
356 | 8
357 | 5
358 | 1
359 | 9
360 | 4
361 | 4
362 | 6
363 | 7
364 | 1
365 | 5
366 | 0
367 | 6
368 | 5
369 | 6
370 | 3
371 | 7
372 | 2
373 | 0
374 | 8
375 | 8
376 | 5
377 | 4
378 | 1
379 | 1
380 | 4
381 | 0
382 | 7
383 | 3
384 | 7
385 | 6
386 | 1
387 | 6
388 | 2
389 | 1
390 | 4
391 | 2
392 | 8
393 | 6
394 | 1
395 | 9
396 | 5
397 | 2
398 | 5
399 | 4
400 | 4
401 | 2
402 | 8
403 | 3
404 | 9
405 | 2
406 | 4
407 | 6
408 | 0
409 | 3
410 | 1
411 | 7
412 | 7
413 | 3
414 | 7
415 | 9
416 | 7
417 | 1
418 | 9
419 | 2
420 | 1
421 | 4
422 | 2
423 | 9
424 | 2
425 | 0
426 | 4
427 | 9
428 | 1
429 | 4
430 | 8
431 | 1
432 | 8
433 | 4
434 | 4
435 | 9
436 | 8
437 | 8
438 | 3
439 | 7
440 | 6
441 | 0
442 | 0
443 | 3
444 | 0
445 | 8
446 | 0
447 | 6
448 | 4
449 | 8
450 | 5
451 | 3
452 | 3
453 | 2
454 | 3
455 | 9
456 | 1
457 | 2
458 | 6
459 | 8
460 | 0
461 | 5
462 | 6
463 | 6
464 | 6
465 | 9
466 | 8
467 | 8
468 | 2
469 | 2
470 | 5
471 | 8
472 | 9
473 | 6
474 | 1
475 | 8
476 | 4
477 | 1
478 | 2
479 | 8
480 | 3
481 | 1
482 | 9
483 | 7
484 | 5
485 | 4
486 | 0
487 | 8
488 | 9
489 | 9
490 | 1
491 | 0
492 | 5
493 | 2
494 | 3
495 | 7
496 | 8
497 | 9
498 | 4
499 | 0
500 | 6
501 | 3
502 | 9
503 | 1
504 | 2
505 | 1
506 | 8
507 | 1
508 | 5
509 | 6
510 | 5
511 | 2
512 | 1
513 |
--------------------------------------------------------------------------------
/data/mlp_img.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/data/mlp_img.bin
--------------------------------------------------------------------------------
/data/mlp_instr.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/data/mlp_instr.bin
--------------------------------------------------------------------------------
/data/mlp_param.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/data/mlp_param.bin
--------------------------------------------------------------------------------
/lab1/README.md:
--------------------------------------------------------------------------------
1 | # Systolic Array
2 |
3 | A systolic array implemented on an FPGA using Xilinx HLS.
4 |
5 | ## 1.Env & Build
6 | **env** :
7 | - Vivado HLS 2018.2 or 2016.3; MATLAB 2014a (for the MATLAB reference code)
8 |
9 | **run** :
10 | - step1: `vivado_hls -f run_hls.tcl`
11 | - step2: launch Vivado HLS and open the project
12 | - step3: run C synthesis, C/RTL cosimulation, etc.
13 |
14 | ## 2. Related Links
15 | https://www.cnblogs.com/sea-wind/p/10995360.html
--------------------------------------------------------------------------------
/lab1/data/feature.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/lab1/data/feature.dat
--------------------------------------------------------------------------------
/lab1/data/golden.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/lab1/data/golden.dat
--------------------------------------------------------------------------------
/lab1/data/weight.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/lab1/data/weight.dat
--------------------------------------------------------------------------------
/lab1/refcode/conv3d.m:
--------------------------------------------------------------------------------
1 | % Reference script for lab1: builds random test data, stores it with saveparam, forms a 3x3 convolution from nine per-tap products, and writes a convn-based golden result to golden.dat.
2 | rng(0); % fixed seed, so the same data is generated on every run
3 | feature = randi([-128,127],14,14,32); % values in the int8 range
4 | weight = randi([-128,127],32,3,3,32); % values in the int8 range; convmxu pairs the first index with its output index k and the last index with the feature's third index
5 | bias = randi([-1024,1023],1,32);
6 | output = zeros(14,14,32); % replaced by out1 below
7 |
8 | saveparam(feature,weight,bias)
9 | % One convmxu call per kernel position (index1,index2); only the centre position (2,2) is given the bias, the other eight get zeros.
10 | out1 = convmxu(weight,feature,bias,2,2);
11 | out2 = convmxu(weight,feature,zeros(1,32),1,1);
12 | out3 = convmxu(weight,feature,zeros(1,32),1,2);
13 | out4 = convmxu(weight,feature,zeros(1,32),1,3);
14 | out5 = convmxu(weight,feature,zeros(1,32),2,1);
15 | out6 = convmxu(weight,feature,zeros(1,32),2,3);
16 | out7 = convmxu(weight,feature,zeros(1,32),3,1);
17 | out8 = convmxu(weight,feature,zeros(1,32),3,2);
18 | out9 = convmxu(weight,feature,zeros(1,32),3,3);
19 | % Sum the nine results, offsetting each by one row and/or column according to its kernel position; positions that would fall outside the 14x14 map are skipped.
20 | output = out1;
21 | output(2:end,2:end,:) = output(2:end,2:end,:) + out2(1:end-1,1:end-1,:);
22 | output(2:end,:,:) = output(2:end,:,:) + out3(1:end-1,:,:);
23 | output(2:end,1:end-1,:) = output(2:end,1:end-1,:) + out4(1:end-1,2:end,:);
24 | output(:,2:end,:) = output(:,2:end,:) + out5(:,1:end-1,:);
25 | output(:,1:end-1,:) = output(:,1:end-1,:) + out6(:,2:end,:);
26 | output(1:end-1,2:end,:) = output(1:end-1,2:end,:) + out7(2:end,1:end-1,:);
27 | output(1:end-1,:,:) = output(1:end-1,:,:) + out8(2:end,:,:);
28 | output(1:end-1,1:end-1,:) = output(1:end-1,1:end-1,:) + out9(2:end,2:end,:);
29 | % Golden result: for each k, convn of the feature with the k-th 3x3x32 kernel reversed along all three dimensions, plus bias(k). Slice 16 of the 'same' output is kept (NOTE(review): presumably the slice where the kernel overlaps all 32 feature channels - confirm).
30 | golden = zeros(14,14,32);
31 | for k = 1:32
32 | wk = reshape(weight(k,:,:,:),3,3,32);
33 | wk = wk(end:-1:1,end:-1:1,end:-1:1);
34 | tmp = convn(feature,wk,'same');
35 | golden(:,:,k) = tmp(:,:,16)+bias(k);
36 | end
37 | golden = int32(golden); % note: output (computed above) is not compared with golden anywhere in this script
38 | fid = fopen('golden.dat','wb'); % written as int32, one third-dimension vector per (i,j), with i as the outer loop
39 | for i=1:14
40 | for j=1:14
41 | fwrite(fid,golden(i,j,:),'int32');
42 | end
43 | end
44 | fclose(fid);
--------------------------------------------------------------------------------
/lab1/refcode/convmxu.m:
--------------------------------------------------------------------------------
1 | function [out1] = convmxu(weight,feature,bias,index1,index2)
2 | %CONVMXU Per-pixel channel contraction for one kernel position, plus bias.
3 | %   out1(i,j,k) = bias(k) + sum over c of weight(k,index1,index2,c)*feature(i,j,c)
4 | %
5 | %   weight  : 4-D array indexed (k, index1, index2, c); only the position
6 | %             (index1,index2) is read
7 | %   feature : 3-D array indexed (i, j, c)
8 | %   bias    : vector indexed by k
9 | %   out1    : size(feature,1) x size(feature,2) x size(weight,1) array
10 | %
11 | %   The loop bounds are taken from the inputs. The original version was fixed
12 | %   to a 14x14 map with 32 input and 32 output channels; for inputs of that
13 | %   size the result is unchanged, because the accumulation order is the same.
14 |
15 | numRows = size(feature,1);
16 | numCols = size(feature,2);
17 | numOut  = size(weight,1);
18 | numIn   = size(weight,4);
19 |
20 | out1 = zeros(numRows,numCols,numOut);
21 | for i = 1:numRows
22 |     for j = 1:numCols
23 |         for k = 1:numOut
24 |             for c = 1:numIn
25 |                 term = weight(k,index1,index2,c)*feature(i,j,c);
26 |                 if(c==1)
27 |                     % the first channel starts the sum from the bias
28 |                     out1(i,j,k) = bias(k) + term;
29 |                 else
30 |                     out1(i,j,k) = out1(i,j,k) + term;
31 |                 end
32 |             end
33 |         end
34 |     end
35 | end
36 |
37 | end
--------------------------------------------------------------------------------
/lab1/refcode/saveparam.m:
--------------------------------------------------------------------------------
1 | function [] = saveparam(feature,weight,bias)
2 | %SAVEPARAM Write the feature map to feature.dat and the weights and bias to weight.dat.
3 | % feature.dat: int8, one third-dimension vector per (i,j) for i,j = 1..14. weight.dat: for each of the nine kernel positions, 32 int8 weight writes followed by bias bytes (real bias for position (2,2), zeros for the others).
4 |
5 | feature = int8(feature);
6 | weight = int8(weight);
7 | bias = int32(bias);
8 | bias4 = bitand(bitshift(bias,-24),int32(255)); % bits 31..24 of each bias
9 | bias3 = bitand(bitshift(bias,-16),int32(255)); % bits 23..16
10 | bias2 = bitand(bitshift(bias,-8),int32(255)); % bits 15..8
11 | bias1 = bitand(bias,int32(255)); % bits 7..0
12 | fid = fopen('feature.dat','wb');
13 | for i=1:14
14 | for j=1:14
15 | fwrite(fid,feature(i,j,:),'int8');
16 | end
17 | end
18 | fclose(fid);
19 | % weight.dat starts with kernel position (2,2): one write per fourth-dimension index k, then the four bias byte planes, most significant first.
20 | fid = fopen('weight.dat','wb');
21 | for k=1:32
22 | fwrite(fid,weight(:,2,2,k),'int8');
23 | end
24 | fwrite(fid,uint8(bias4),'uint8');
25 | fwrite(fid,uint8(bias3),'uint8');
26 | fwrite(fid,uint8(bias2),'uint8');
27 | fwrite(fid,uint8(bias1),'uint8');
28 | for i=1:3 % remaining eight kernel positions, (2,2) is skipped
29 | for j=1:3
30 | for k=1:32
31 | if(~(i==2&&j==2))
32 | fwrite(fid,weight(:,i,j,k),'int8');
33 | end
34 | end
35 | if(~(i==2&&j==2))
36 | for k=1:32
37 | fwrite(fid,0,'int32'); % 32 int32 zeros (128 bytes) take the place of the bias bytes
38 | end
39 | end
40 | end
41 | end
42 | fclose(fid);
43 |
44 | end
45 |
46 |
--------------------------------------------------------------------------------
/lab1/run_hls.tcl:
--------------------------------------------------------------------------------
1 | open_project -reset mxu_conv_prj
2 | set_top MXU
3 | add_files src/tpu.h
4 | add_files src/mxu.cpp
5 | add_files -tb data/feature.dat
6 | add_files -tb data/golden.dat
7 | add_files -tb data/weight.dat
8 | add_files -tb src/tb_mxu.cpp
9 | # Solution setup: target part and clock. The period is given as 2.5 without a unit (NOTE(review): Vivado HLS is expected to read this as ns, i.e. 400 MHz - confirm against the tool documentation).
10 | open_solution -reset "solution1"
11 | set_part {xczu7cg-fbvb900-2-i} -tool vivado
12 | create_clock -period 2.5 -name default
13 | # Only C simulation is run here; synthesis and cosimulation are not started by this script.
14 | csim_design
15 | # Do not perform any other steps
16 | # - The basic project will be opened in the GUI
17 | exit
--------------------------------------------------------------------------------
/lab1/src/mxu.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "tpu.h"
3 |
4 | void SetWeight(WEIGHTDTYPE weight[512][MXU_COLNUM],WEIGHTDTYPE weightreg[MXU_ROWNUM+4][MXU_COLNUM],
5 | short weight_raddr, bool enable){
6 | if(!enable)
7 | return;
8 | for(short i=weight_raddr;i=0;k--){
56 | if(k>0)
57 | featreg[j][k] = featreg[j][k-1];
58 | else
59 | if(i=0;j--){
67 | for(int k=0;k biasreg;
69 | biasreg(31,24)=weightreg[MXU_ROWNUM+0][k];
70 | biasreg(23,16)=weightreg[MXU_ROWNUM+1][k];
71 | biasreg(15, 8)=weightreg[MXU_ROWNUM+2][k];
72 | biasreg( 7, 0)=weightreg[MXU_ROWNUM+3][k];
73 | if(j==0)
74 | psumreg[j][k] = featreg[j][k+j]*weightreg[j][k] + biasreg;
75 | else
76 | psumreg[j][k] = featreg[j][k+j]*weightreg[j][k] + psumreg[j-1][k];
77 | }
78 | }
79 | #pragma HLS DEPENDENCE variable=psum inter false
80 | #pragma HLS DEPENDENCE variable=psum intra false
81 | for(int j=0;j=j+MXU_ROWNUM-1&&ipsumpool[j])
48 | psumpool[j] = psumrelu[j];
49 | }
50 | else{
51 | psumpool[j] = psumpool[j] + psumrelu[j];
52 | }
53 | }
54 |
55 | if(pool_kw_cnt==param.pool_kw&&pool_kh_cnt==param.pool_kh){
56 | short ubuf_waddr = param.ubuf_waddr_start + ubuf_waddr_p1 + ubuf_waddr_p2 + ubuf_waddr_p3;
57 | if(ubuf_waddr_p1==param.ubuf_waddr_end1){
58 | if(ubuf_waddr_p2==param.ubuf_waddr_end2){
59 | ubuf_waddr_p2 = 0;
60 | ubuf_waddr_p3 = ubuf_waddr_p3 + param.ubuf_waddr_step3;
61 | }
62 | else{
63 | ubuf_waddr_p2 = ubuf_waddr_p2 + param.ubuf_waddr_step2;
64 | }
65 | }
66 | else{
67 | ubuf_waddr_p1 = ubuf_waddr_p1 + param.ubuf_waddr_step1;
68 | }
69 | for(int j=0;j>32;
73 | ap_int<8> res;
74 | if(tmpcut>127)
75 | res = 127;
76 | else if(tmpcut<-128)
77 | res = -128;
78 | else
79 | res = tmpcut;
80 | unified_buffer[ubuf_waddr][j] = res;
81 | }
82 | }
83 |
84 | if(pool_kw_cnt==param.pool_kw){
85 | pool_kw_cnt = 0;
86 | if(pool_kh_cnt==param.pool_kh){
87 | pool_kh_cnt = 0;
88 | if(pool_w_cnt==param.pool_w){
89 | pool_w_cnt = 0;
90 | pool_h_cnt = pool_h_cnt + param.pool_sh;
91 | }
92 | else{
93 | pool_w_cnt = pool_w_cnt + param.pool_sw;
94 | }
95 | }
96 | else{
97 | pool_kh_cnt = pool_kh_cnt + 1;
98 | }
99 | }
100 | else{
101 | pool_kw_cnt = pool_kw_cnt + 1;
102 | }
103 | }
104 | }
105 |
--------------------------------------------------------------------------------
/lab2/src/tb_pool.cpp:
--------------------------------------------------------------------------------
1 | #include "tpu.h"
2 | #include "stdio.h"
3 | #include "stdlib.h"
4 |
5 | int main(){
6 | PSUMDTYPE psum_buffer[512][MXU_COLNUM];
7 | FEATDTYPE unified_buffer[16384][MXU_ROWNUM];
8 | int norm_coef[MXU_COLNUM];
9 | RELPOOL_PARAM param;
10 | for(int i=0;i<14;i++){
11 | for(int j=0;j<14;j++){
12 | for(int c=0;c<32;c++){
13 | psum_buffer[i*14+j][c] = (i*14+j+c)*512;
14 | }
15 | }
16 | }
17 | for(int c=0;c<32;c++)
18 | norm_coef[c] = 1<<23;
19 |
20 | // no pooling
21 | param.isrelu = true;
22 | param.psum_raddr_start = 0;
23 | param.maxpool = true;
24 | param.pool_kw = 0;
25 | param.pool_kh = 0;
26 | param.pool_w = 14-1;
27 | param.pool_sw = 1;
28 | param.pool_sh = 1;
29 | param.pool_cnt = 14*14;
30 | param.pool_h_step = 14;
31 | param.ubuf_waddr_start = 0;
32 | param.ubuf_waddr_step1 = 1;
33 | param.ubuf_waddr_end1 = 14*14-1;
34 | relu_norm_pool(psum_buffer,unified_buffer,norm_coef,param);
35 |
36 | FEATDTYPE golden[14*14][MXU_ROWNUM];
37 | for(int i=0;i<14;i++){
38 | for(int j=0;j<14;j++){
39 | for (int k=0;k<32;k++){
40 | int tmp = psum_buffer[i*14+j][k]/512;
41 | tmp = tmp>127?127:tmp;
42 | tmp = tmp<-128?-128:tmp;
43 | golden[i*14+j][k] = tmp;
44 | }
45 | }
46 | }
47 | int err=0;
48 | for(int i=0;i<14*14;i++){
49 | for(int k=0;k<32;k++){
50 | if(golden[i][k]!=unified_buffer[i][k])
51 | err ++;
52 | }
53 | }
54 |
55 | // max pooling 2,2
56 | for(int c=0;c<32;c++)
57 | norm_coef[c] = 1<<23;
58 | param.isrelu = true;
59 | param.psum_raddr_start = 0;
60 | param.maxpool = true;
61 | param.pool_kw = 1;
62 | param.pool_kh = 1;
63 | param.pool_w = 12;
64 | param.pool_sw = 2;
65 | param.pool_sh = 2;
66 | param.pool_cnt = 14*14;
67 | param.pool_h_step = 14;
68 | param.ubuf_waddr_start = 0;
69 | param.ubuf_waddr_step1 = 1;
70 | param.ubuf_waddr_end1 = 7*7-1;
71 | relu_norm_pool(psum_buffer,unified_buffer,norm_coef,param);
72 |
73 | for(int i=0;i<7;i++){
74 | for(int j=0;j<7;j++){
75 | for (int k=0;k<32;k++){
76 | int tmp = -128;
77 | for(int i1=0;i1<2;i1++){
78 | for(int j1=0;j1<2;j1++){
79 | if(tmp127?127:tmp;
84 | tmp = tmp<-128?-128:tmp;
85 | golden[i*7+j][k] = tmp;
86 | }
87 | }
88 | }
89 | for(int i=0;i<7*7;i++){
90 | for(int k=0;k<32;k++){
91 | if(golden[i][k]!=unified_buffer[i][k])
92 | err ++;
93 | }
94 | }
95 |
96 | for(int c=0;c<32;c++)
97 | norm_coef[c] = 171196;
98 | // avg pooling 7,7
99 | param.isrelu = true;
100 | param.psum_raddr_start = 0;
101 | param.maxpool = false;
102 | param.pool_kw = 6;
103 | param.pool_kh = 6;
104 | param.pool_w = 7;
105 | param.pool_sw = 7;
106 | param.pool_sh = 7;
107 | param.pool_cnt = 14*14;
108 | param.pool_h_step = 14;
109 | param.ubuf_waddr_start = 0;
110 | param.ubuf_waddr_step1 = 1;
111 | param.ubuf_waddr_end1 = 14*14-1;
112 | relu_norm_pool(psum_buffer,unified_buffer,norm_coef,param);
113 |
114 | for(int i=0;i<2;i++){
115 | for(int j=0;j<2;j++){
116 | for (int k=0;k<32;k++){
117 | int tmp = 0;
118 | for(int i1=0;i1<7;i1++){
119 | for(int j1=0;j1<7;j1++){
120 | tmp += psum_buffer[(i*7+i1)*14+7*j+j1][k];
121 | }
122 | }
123 | tmp = (long(tmp)*long(171196))>>32;
124 | tmp = tmp>127?127:tmp;
125 | tmp = tmp<-128?-128:tmp;
126 | golden[i*2+j][k] = tmp;
127 | }
128 | }
129 | }
130 | for(int i=0;i<2*2;i++){
131 | for(int k=0;k<32;k++){
132 | if(golden[i][k]!=unified_buffer[i][k])
133 | err ++;
134 | }
135 | }
136 | return err;
137 | }
138 |
--------------------------------------------------------------------------------
/lab2/src/tpu.h:
--------------------------------------------------------------------------------
1 | #include "ap_int.h"
2 |
3 | #define MXU_COLNUM 32
4 | #define MXU_ROWNUM 32
5 | #define WEIGHTDTYPE char
6 | #define FEATDTYPE char
7 | #define PSUMDTYPE int
8 |
9 |
10 | struct MXU_PARAM{ // parameter bundle passed by value to MXU() (see the prototype in this header)
11 | bool isload;
12 | bool iscalc;
13 | bool isping;
14 | bool isfirstpsum;
15 | // NOTE(review): the fields below look like address-generation settings (start/step/end per loop level) for the weight buffer, unified buffer and psum buffer. This is inferred from the names only - confirm against mxu.cpp.
16 | short weight_raddr;
17 | short ubuf_raddr_start;
18 | short ubuf_raddr_step1;
19 | short ubuf_raddr_step2;
20 | short ubuf_raddr_step3;
21 | short ubuf_raddr_end1;
22 | short ubuf_raddr_end2;
23 | short ubuf_raddr_end3;
24 | short ubuf_raddr_num;
25 | short psum_start;
26 | short psum_step1;
27 | short psum_end1;
28 | short psum_step2;
29 | };
30 | struct RELPOOL_PARAM{ // parameter bundle passed by value to relu_norm_pool() (see the prototype in this header)
31 | bool isrelu; // presumably enables the ReLU stage - TODO confirm in relu_norm_pool.cpp
32 | short psum_raddr_start; // presumably the first psum_buffer row that is read - TODO confirm
33 | // Pooling window settings. In relu_norm_pool the kernel counters start at 0 and a window is complete when they equal pool_kw / pool_kh, so these hold the kernel size minus one (tb_pool.cpp uses 0 for no pooling, 1 for 2x2, 6 for 7x7).
34 | bool maxpool; // max pool or average pool
35 | char pool_kw;
36 | char pool_kh;
37 | char pool_w; // value of the column counter at which it wraps to 0 and the row counter advances by pool_sh
38 | char pool_sw; // column counter increment after each completed window
39 | char pool_sh; // row counter increment after the last window of a row
40 | short pool_cnt; // output_num*pool_kw*pool_kh
41 | short pool_h_step; // tb_pool.cpp sets this to 14, the row length of its test data; how it is used is not verified here - TODO confirm
42 | // Write address in relu_norm_pool = ubuf_waddr_start + p1 + p2 + p3: p1 advances by step1 until it equals end1, then p2 advances by step2 until it equals end2, then p2 resets to 0 and p3 advances by step3.
43 | short ubuf_waddr_start;
44 | short ubuf_waddr_step1;
45 | short ubuf_waddr_step2;
46 | short ubuf_waddr_step3;
47 | short ubuf_waddr_end1;
48 | short ubuf_waddr_end2;
49 | short ubuf_waddr_end3; // NOTE(review): no use of this field was found in the part of relu_norm_pool that was checked - confirm
50 | };
51 |
52 | void MXU(FEATDTYPE ubuf[16384][MXU_ROWNUM],WEIGHTDTYPE weight[512][MXU_COLNUM],
53 | PSUMDTYPE psum[512][MXU_COLNUM],MXU_PARAM mxuparam);
54 | void relu_norm_pool(PSUMDTYPE psum_buffer[512][MXU_COLNUM],FEATDTYPE unified_buffer[16384][MXU_ROWNUM],
55 | int norm_coef[MXU_COLNUM],RELPOOL_PARAM param);
56 |
--------------------------------------------------------------------------------
/pictures/cla_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/pictures/cla_result.png
--------------------------------------------------------------------------------
/pictures/sim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/pictures/sim.png
--------------------------------------------------------------------------------
/pictures/syn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cea-wind/SimpleTPU/b81f32e563ca3b16a9fd7044490235e1c5c8feda/pictures/syn.png
--------------------------------------------------------------------------------
/src/ctrl.cpp:
--------------------------------------------------------------------------------
1 | #include "tpu.h"
2 |
3 | void loadWeight(ap_uint<256> *ddr,WEIGHTDTYPE weight_buffer[512][MXU_COLNUM],
4 | unsigned offset,short addr, short len, bool enable){
5 | if(!enable)
6 | return;
7 | for(int i=0;i tmp = ddr[offset+i];
10 | for(int j=0;j<32;j++){
11 | weight_buffer[addr+i][j] = tmp(j*8+7,j*8);
12 | }
13 | }
14 | }
15 |
16 | void loadFeature(ap_uint<256> *ddr,FEATDTYPE unified_buffer[512][MXU_ROWNUM],
17 | unsigned offset,short addr, short len, bool enable){
18 | if(!enable)
19 | return;
20 | for(int i=0;i tmp = ddr[offset+i];
23 | for(int j=0;j<32;j++){
24 | unified_buffer[addr+i][j] = tmp(j*8+7,j*8);
25 | }
26 | }
27 | }
28 | void storeFeature(ap_uint<256> *ddr,FEATDTYPE unified_buffer[512][MXU_COLNUM],
29 | unsigned offset,short addr, short len, bool enable){
30 | if(!enable)
31 | return;
32 | for(int i=0;i tmp;
35 | for(int j=0;j<32;j++){
36 | tmp(j*8+7,j*8) = unified_buffer[addr+i][j];;
37 | }
38 | ddr[offset+i] = tmp;
39 | }
40 | }
41 | // instr: read 64-bit words from ddr, starting at offset, until a run word (bit 63 set) has been read.
42 | // Set word (bit 63 == 0): bits 52..48 select a group g; bits 15..0, 31..16 and 47..32 are written to reggroup[3g], reggroup[3g+1] and reggroup[3g+2].
43 | // Run word (bit 63 == 1): bits 55..48 are stored in runmode and the loop ends.
44 | // offset is incremented once for every word read; when enable is false nothing is read or written.
45 | // (original note: set instr. set register / run instr. run process / eop instr. end of process)
46 | void instr(ap_uint<64> *ddr,unsigned &offset,ap_int<16> reggroup[96],ap_int<8> &runmode,bool enable){
47 | #pragma HLS INTERFACE m_axi depth=8192 port=ddr
48 | #pragma HLS ARRAY_PARTITION variable=reggroup complete dim=1
49 | if(!enable)
50 | return;
51 | bool isRunInstr = false;
52 | while(!isRunInstr){ // no upper bound: keeps reading until a word with bit 63 set is found
53 | ap_uint<64> tmp = ddr[offset];
54 | offset++;
55 | if(tmp[63]==0){
56 | switch(tmp(52,48)){ // 5-bit group index, so all 32 values are covered by the cases below
57 | case( 0):reggroup[ 0] = tmp(15, 0);reggroup[ 1] = tmp(31,16);reggroup[ 2] = tmp(47,32);break;
58 | case( 1):reggroup[ 3] = tmp(15, 0);reggroup[ 4] = tmp(31,16);reggroup[ 5] = tmp(47,32);break;
59 | case( 2):reggroup[ 6] = tmp(15, 0);reggroup[ 7] = tmp(31,16);reggroup[ 8] = tmp(47,32);break;
60 | case( 3):reggroup[ 9] = tmp(15, 0);reggroup[10] = tmp(31,16);reggroup[11] = tmp(47,32);break;
61 | case( 4):reggroup[12] = tmp(15, 0);reggroup[13] = tmp(31,16);reggroup[14] = tmp(47,32);break;
62 | case( 5):reggroup[15] = tmp(15, 0);reggroup[16] = tmp(31,16);reggroup[17] = tmp(47,32);break;
63 | case( 6):reggroup[18] = tmp(15, 0);reggroup[19] = tmp(31,16);reggroup[20] = tmp(47,32);break;
64 | case( 7):reggroup[21] = tmp(15, 0);reggroup[22] = tmp(31,16);reggroup[23] = tmp(47,32);break;
65 | case( 8):reggroup[24] = tmp(15, 0);reggroup[25] = tmp(31,16);reggroup[26] = tmp(47,32);break;
66 | case( 9):reggroup[27] = tmp(15, 0);reggroup[28] = tmp(31,16);reggroup[29] = tmp(47,32);break;
67 | case(10):reggroup[30] = tmp(15, 0);reggroup[31] = tmp(31,16);reggroup[32] = tmp(47,32);break;
68 | case(11):reggroup[33] = tmp(15, 0);reggroup[34] = tmp(31,16);reggroup[35] = tmp(47,32);break;
69 | case(12):reggroup[36] = tmp(15, 0);reggroup[37] = tmp(31,16);reggroup[38] = tmp(47,32);break;
70 | case(13):reggroup[39] = tmp(15, 0);reggroup[40] = tmp(31,16);reggroup[41] = tmp(47,32);break;
71 | case(14):reggroup[42] = tmp(15, 0);reggroup[43] = tmp(31,16);reggroup[44] = tmp(47,32);break;
72 | case(15):reggroup[45] = tmp(15, 0);reggroup[46] = tmp(31,16);reggroup[47] = tmp(47,32);break;
73 | case(16):reggroup[48] = tmp(15, 0);reggroup[49] = tmp(31,16);reggroup[50] = tmp(47,32);break;
74 | case(17):reggroup[51] = tmp(15, 0);reggroup[52] = tmp(31,16);reggroup[53] = tmp(47,32);break;
75 | case(18):reggroup[54] = tmp(15, 0);reggroup[55] = tmp(31,16);reggroup[56] = tmp(47,32);break;
76 | case(19):reggroup[57] = tmp(15, 0);reggroup[58] = tmp(31,16);reggroup[59] = tmp(47,32);break;
77 | case(20):reggroup[60] = tmp(15, 0);reggroup[61] = tmp(31,16);reggroup[62] = tmp(47,32);break;
78 | case(21):reggroup[63] = tmp(15, 0);reggroup[64] = tmp(31,16);reggroup[65] = tmp(47,32);break;
79 | case(22):reggroup[66] = tmp(15, 0);reggroup[67] = tmp(31,16);reggroup[68] = tmp(47,32);break;
80 | case(23):reggroup[69] = tmp(15, 0);reggroup[70] = tmp(31,16);reggroup[71] = tmp(47,32);break;
81 | case(24):reggroup[72] = tmp(15, 0);reggroup[73] = tmp(31,16);reggroup[74] = tmp(47,32);break;
82 | case(25):reggroup[75] = tmp(15, 0);reggroup[76] = tmp(31,16);reggroup[77] = tmp(47,32);break;
83 | case(26):reggroup[78] = tmp(15, 0);reggroup[79] = tmp(31,16);reggroup[80] = tmp(47,32);break;
84 | case(27):reggroup[81] = tmp(15, 0);reggroup[82] = tmp(31,16);reggroup[83] = tmp(47,32);break;
85 | case(28):reggroup[84] = tmp(15, 0);reggroup[85] = tmp(31,16);reggroup[86] = tmp(47,32);break;
86 | case(29):reggroup[87] = tmp(15, 0);reggroup[88] = tmp(31,16);reggroup[89] = tmp(47,32);break;
87 | case(30):reggroup[90] = tmp(15, 0);reggroup[91] = tmp(31,16);reggroup[92] = tmp(47,32);break;
88 | case(31):reggroup[93] = tmp(15, 0);reggroup[94] = tmp(31,16);reggroup[95] = tmp(47,32);break;
89 | }
90 | }
91 | else{
92 | runmode = tmp(55,48);
93 | isRunInstr = true;
94 | }
95 | }
96 | }
97 |
98 | // config: unpack the 96-entry register group into the MXU, pooling and load/store parameter structs and the 32 norm_coef values.
99 | void config(ap_int<16> reggroup[96],MXU_PARAM &mxuparam,RELPOOL_PARAM &poolparam,LDST_PARAM &lsdtparam, ap_int<32> norm_coef[32]){
100 | #pragma HLS INLINE
101 | mxuparam.isload = reggroup[ 0].range(0,0); // the four flags are bits 0..3 of register 0
102 | mxuparam.iscalc = reggroup[ 0].range(1,1);
103 | mxuparam.isping = reggroup[ 0].range(2,2);
104 | mxuparam.isfirstpsum = reggroup[ 0].range(3,3);
105 | mxuparam.weight_raddr = reggroup[ 1]; // registers 1..13 map one-to-one onto the remaining MXU fields
106 | mxuparam.ubuf_raddr_start= reggroup[ 2];
107 | mxuparam.ubuf_raddr_step1= reggroup[ 3];
108 | mxuparam.ubuf_raddr_step2= reggroup[ 4];
109 | mxuparam.ubuf_raddr_step3= reggroup[ 5];
110 | mxuparam.ubuf_raddr_end1 = reggroup[ 6];
111 | mxuparam.ubuf_raddr_end2 = reggroup[ 7];
112 | mxuparam.ubuf_raddr_end3 = reggroup[ 8];
113 | mxuparam.ubuf_raddr_num = reggroup[ 9];
114 | mxuparam.psum_start = reggroup[10];
115 | mxuparam.psum_step1 = reggroup[11];
116 | mxuparam.psum_end1 = reggroup[12];
117 | mxuparam.psum_step2 = reggroup[13];
118 | // Pooling parameters: registers 14..16 are split into bit fields, registers 17..27 are copied whole.
119 | poolparam.isrelu = reggroup[14].range( 0,0);
120 | poolparam.maxpool = reggroup[14].range( 1,1);
121 | poolparam.avg_shift = reggroup[14].range( 7,4); // bits 3..2 of register 14 are not read here
122 | poolparam.pool_kw = reggroup[14].range(15,8);
123 | poolparam.pool_kh = reggroup[15].range( 7,0);
124 | poolparam.pool_w = reggroup[15].range(15,8);
125 | poolparam.pool_sw = reggroup[16].range( 7,0);
126 | poolparam.pool_sh = reggroup[16].range(15,8);
127 | poolparam.psum_raddr_start = reggroup[17];
128 | poolparam.pool_cnt = reggroup[18];
129 | poolparam.pool_h_step = reggroup[19];
130 | poolparam.avg_val = reggroup[20];
131 | poolparam.ubuf_waddr_start = reggroup[21];
132 | poolparam.ubuf_waddr_step1 = reggroup[22];
133 | poolparam.ubuf_waddr_step2 = reggroup[23];
134 | poolparam.ubuf_waddr_step3 = reggroup[24];
135 | poolparam.ubuf_waddr_end1 = reggroup[25];
136 | poolparam.ubuf_waddr_end2 = reggroup[26];
137 | poolparam.ubuf_waddr_end3 = reggroup[27];
138 | // Weight load parameters: registers 28..31; only these three LDST_PARAM fields are set here.
139 | lsdtparam.weight_addr = reggroup[28];
140 | lsdtparam.weight_ldlen = reggroup[29];
141 | ap_uint<32> tmp = (reggroup[31],reggroup[30]); // 32-bit value from two 16-bit registers; presumably ap_int concatenation with reggroup[31] as the upper half - TODO confirm
142 | lsdtparam.weight_offset = tmp;
143 | for(int i=0;i<32;i++){
144 | #pragma HLS UNROLL
145 | norm_coef[i] = (reggroup[33+2*i],reggroup[32+2*i]); // coefficient i comes from register pair 32+2i / 33+2i, i.e. registers 32..95
146 | }
147 | return;
148 | }
149 |
149 |
--------------------------------------------------------------------------------
/src/mxu.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "tpu.h"
3 |
4 | void SetWeight(WEIGHTDTYPE weight[512][MXU_COLNUM],WEIGHTDTYPE weightreg[MXU_ROWNUM+4][MXU_COLNUM],
5 | short weight_raddr, bool enable){
6 | if(!enable)
7 | return;
8 | for(short i=weight_raddr;i=0;k--){
56 | if(k>0)
57 | featreg[j][k] = featreg[j][k-1];
58 | else
59 | if(i=0;j--){
67 | for(int k=0;k biasreg;
69 | biasreg(31,24)=weightreg[MXU_ROWNUM+0][k];
70 | biasreg(23,16)=weightreg[MXU_ROWNUM+1][k];
71 | biasreg(15, 8)=weightreg[MXU_ROWNUM+2][k];
72 | biasreg( 7, 0)=weightreg[MXU_ROWNUM+3][k];
73 | if(j==0)
74 | psumreg[j][k] = featreg[j][k+j]*weightreg[j][k] + biasreg;
75 | else
76 | psumreg[j][k] = featreg[j][k+j]*weightreg[j][k] + psumreg[j-1][k];
77 | }
78 | }
79 | #pragma HLS DEPENDENCE variable=psum inter false
80 | #pragma HLS DEPENDENCE variable=psum intra false
81 | for(int j=0;j=j+MXU_ROWNUM-1&&i norm_coef[MXU_COLNUM],RELPOOL_PARAM param, bool enable){
6 | //#pragma HLS INTERFACE bram port=unified_buffer
7 | //#pragma HLS INTERFACE bram port=psum_buffer
8 | //#pragma HLS ARRAY_PARTITION variable=norm_coef complete dim=1
9 | //#pragma HLS ARRAY_PARTITION variable=unified_buffer complete dim=2
10 | //#pragma HLS ARRAY_PARTITION variable=psum_buffer complete dim=2
11 |
12 | PSUMDTYPE psumreg[MXU_COLNUM];
13 | PSUMDTYPE psumrelu[MXU_COLNUM];
14 | PSUMDTYPE psumpool[MXU_COLNUM];
15 | FEATDTYPE relu[MXU_COLNUM];
16 | short pool[MXU_COLNUM];
17 | #pragma HLS ARRAY_PARTITION variable=psumreg complete dim=1
18 | #pragma HLS ARRAY_PARTITION variable=psumsht complete dim=1
19 | #pragma HLS ARRAY_PARTITION variable=relu complete dim=1
20 | #pragma HLS ARRAY_PARTITION variable=pool complete dim=1
21 |
22 | char pool_kw_cnt = 0;
23 | char pool_kh_cnt = 0;
24 | char pool_w_cnt = 0;
25 | char pool_h_cnt = 0;
26 | short ubuf_waddr_p1=0;
27 | short ubuf_waddr_p2=0;
28 | short ubuf_waddr_p3=0;
29 |
30 | if(!enable)
31 | return;
32 | for(short i=0;ipsumpool[j])
48 | psumpool[j] = psumrelu[j];
49 | }
50 | else{
51 | psumpool[j] = psumpool[j] + psumrelu[j];
52 | }
53 | }
54 |
55 | if(pool_kw_cnt==param.pool_kw&&pool_kh_cnt==param.pool_kh){
56 | short ubuf_waddr = param.ubuf_waddr_start + ubuf_waddr_p1 + ubuf_waddr_p2 + ubuf_waddr_p3;
57 | if(ubuf_waddr_p1==param.ubuf_waddr_end1){
58 | if(ubuf_waddr_p2==param.ubuf_waddr_end2){
59 | ubuf_waddr_p2 = 0;
60 | ubuf_waddr_p3 = ubuf_waddr_p3 + param.ubuf_waddr_step3;
61 | }
62 | else{
63 | ubuf_waddr_p2 = ubuf_waddr_p2 + param.ubuf_waddr_step2;
64 | }
65 | }
66 | else{
67 | ubuf_waddr_p1 = ubuf_waddr_p1 + param.ubuf_waddr_step1;
68 | }
69 | for(int j=0;j>32;
73 | ap_int<8> res;
74 | if(tmpcut>127)
75 | res = 127;
76 | else if(tmpcut<-128)
77 | res = -128;
78 | else
79 | res = tmpcut;
80 | unified_buffer[ubuf_waddr][j] = res;
81 | }
82 | }
83 |
84 | if(pool_kw_cnt==param.pool_kw){
85 | pool_kw_cnt = 0;
86 | if(pool_kh_cnt==param.pool_kh){
87 | pool_kh_cnt = 0;
88 | if(pool_w_cnt==param.pool_w){
89 | pool_w_cnt = 0;
90 | pool_h_cnt = pool_h_cnt + param.pool_sh;
91 | }
92 | else{
93 | pool_w_cnt = pool_w_cnt + param.pool_sw;
94 | }
95 | }
96 | else{
97 | pool_kh_cnt = pool_kh_cnt + 1;
98 | }
99 | }
100 | else{
101 | pool_kw_cnt = pool_kw_cnt + 1;
102 | }
103 | }
104 | }
105 |
--------------------------------------------------------------------------------
/src/tb_tpu.cpp:
--------------------------------------------------------------------------------
1 | #include "tpu.h"
2 | #include "stdio.h"
3 | int main(){ // MLP testbench: load image/param/instruction files, run tpu(), compare per-sample argmax with golden labels; returns 0 on pass, 1 on any failure
4 |     const int INSTR_MAX = 3300; // capacity of ddr_instr, in 64-bit words
5 |     ap_uint<256> *ddr = (ap_uint<256> *)malloc(sizeof(ap_uint<256>)*(16384)); //512*25+72*25+72+512
6 |     ap_uint<64> *ddr_instr = (ap_uint<64> *)malloc(sizeof(ap_uint<64>)*INSTR_MAX);
7 |     FILE *f_img = fopen("mlp_img.bin","rb");
8 |     FILE *f_par = fopen("mlp_param.bin","rb");
9 |     FILE *f_ins = fopen("mlp_instr.bin","rb");
10 |     FILE *f_ref = fopen("golden_result.txt","r");
11 |     if(!ddr || !ddr_instr || !f_img || !f_par || !f_ins || !f_ref){
12 |         // Fail loudly instead of handing a NULL buffer or FILE* to fread/fscanf.
13 |         fprintf(stderr,"tb_tpu: allocation failed or an input file could not be opened\n");
14 |         return 1;
15 |     }
16 |     fread(ddr,32,25*512,f_img); // image file goes to the start of ddr
17 |     fclose(f_img);
18 |     fread(ddr+512*25,32,25*72+72,f_par); // parameter file goes directly after the image words
19 |     fclose(f_par);
20 |     // Copy instructions up to and including the first one with bit 55 set.
21 |     // Stop on a short read or a full buffer so a bad file cannot overrun ddr_instr.
22 |     int cnt = 0;
23 |     bool eop = false;
24 |     while(!eop && cnt<INSTR_MAX && fread(ddr_instr+cnt,8,1,f_ins)==1){
25 |         eop = ddr_instr[cnt].range(55,55)==1;
26 |         cnt++;
27 |     }
28 |     fclose(f_ins);
29 |     if(!eop){
30 |         fprintf(stderr,"tb_tpu: no terminating instruction found in mlp_instr.bin\n");
31 |         return 1;
32 |     }
33 |     tpu(ddr,ddr_instr);
34 |     int err = 0;
35 |     for(int i=0;i<512;i++){
36 |         ap_uint<256> val = ddr[512*25+72*25+72+i];
37 |         int maxcof = -255, idx = -1, ref = -1;
38 |         for(int j=0;j<16;j++){ // argmax over the 16 low bytes, each read as a signed 8-bit value
39 |             int cof = val(j*8+7,j*8);
40 |             if(cof>127) cof = cof-256;
41 |             if(cof>maxcof){ maxcof = cof; idx = j; }
42 |         }
43 |         if(fscanf(f_ref,"%d",&ref)!=1 || idx!=ref) err++; // a missing golden label counts as a mismatch
44 |     }
45 |     printf("tb_tpu: %d of 512 results differ from golden_result.txt\n",err);
46 |     fclose(f_ref);
47 |     free(ddr);
48 |     free(ddr_instr);
49 |     return err!=0; // an exit status keeps only 8 bits, so a raw count of 256 or 512 would read as success
50 | }
51 |
--------------------------------------------------------------------------------
/src/tpu.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "tpu.h"
3 |
4 | void ex_module(FEATDTYPE unified_buffer[16384][MXU_ROWNUM],WEIGHTDTYPE weight_buffer[512][MXU_COLNUM], // Calls MXU() then relu_norm_pool(), giving each a different one of two partial-sum buffers.
5 | ap_int<32> norm_coef[MXU_COLNUM],MXU_PARAM mxuparam,RELPOOL_PARAM poolparam,
6 | bool is_MXU,bool is_relu_norm_pool){ // both flags are forwarded unchanged as the 'enable' argument of the matching stage
7 | #pragma HLS INLINE off
8 | #pragma HLS DEPENDENCE variable=unified_buffer inter false
9 | #pragma HLS DEPENDENCE variable=unified_buffer intra false
10 | static PSUMDTYPE psum_buffer1[512][MXU_COLNUM]; // static, so contents survive from one ex_module() call to the next
11 | static PSUMDTYPE psum_buffer2[512][MXU_COLNUM];
12 | #pragma HLS ARRAY_PARTITION variable=psum_buffer1 complete dim=2
13 | #pragma HLS ARRAY_PARTITION variable=psum_buffer2 complete dim=2
14 | if((is_MXU&&mxuparam.psum_start<512) || (is_relu_norm_pool&&poolparam.psum_raddr_start>=512) ) // buffer choice: enabled MXU with psum_start<512, or enabled pooling with psum_raddr_start>=512
15 | {
16 | MXU(unified_buffer,weight_buffer,psum_buffer1,mxuparam,is_MXU); // MXU gets buffer1 ...
17 | relu_norm_pool(psum_buffer2,unified_buffer,norm_coef,poolparam,is_relu_norm_pool); // ... and pooling gets buffer2
18 | }
19 | else{
20 | MXU(unified_buffer,weight_buffer,psum_buffer2,mxuparam,is_MXU); // otherwise the two buffers swap roles
21 | relu_norm_pool(psum_buffer1,unified_buffer,norm_coef,poolparam,is_relu_norm_pool);
22 | }
23 |
24 | }
25 |
26 | void tpu(ap_uint<256> *ddr,ap_uint<64> *ddr_instr){ // Top level: load features from ddr, run instructions from ddr_instr until runmode bit 7 is set, then store 512 rows back to ddr.
27 | #pragma HLS INTERFACE m_axi depth=16384 port=ddr
28 | #pragma HLS INTERFACE m_axi depth=3300 port=ddr_instr
29 |
30 | static FEATDTYPE unified_buffer[16384][MXU_ROWNUM];
31 | #pragma HLS RESOURCE variable=unified_buffer core=RAM_S2P_BRAM
32 | static WEIGHTDTYPE weight_buffer[512][MXU_COLNUM];
33 | #pragma HLS RESOURCE variable=weight_buffer core=RAM_S2P_BRAM
34 | static ap_int<32> norm_coef[MXU_COLNUM];
35 | #pragma HLS ARRAY_PARTITION variable=unified_buffer complete dim=2
36 | #pragma HLS ARRAY_PARTITION variable=weight_buffer complete dim=2
37 | #pragma HLS ARRAY_PARTITION variable=norm_coef complete dim=0
38 |
39 | ap_int<16> reggroup[96]; // register file written by instr() and decoded by config()
40 | #pragma HLS ARRAY_PARTITION variable=reggroup complete dim=0
41 | MXU_PARAM mxuparam;
42 | RELPOOL_PARAM poolparam;
43 | LDST_PARAM ldstparam;
44 | unsigned instr_offset = 0; // passed to instr() by reference; presumably the read position in ddr_instr - confirm in ctrl.cpp
45 | bool is_load_weight;
46 | bool is_MXU;
47 | bool is_relu_norm_pool;
48 | // load img
49 | loadFeature(ddr,unified_buffer, 0,0, 512*25, true); // offset 0, addr 0, len 512*25
50 | bool eop = false; // NOTE(review): never read in this function; the loop tests runmode[7] instead
51 | ap_int<8> runmode = 0; //0 nop, bit[0] loadweight;bit[1] mxu; bit[2] pool; bit[7] eop;
52 | instr(ddr_instr,instr_offset,reggroup,runmode,true); // first fetch, before entering the loop
53 | while(runmode[7]==0)
54 | {
55 | #pragma HLS DEPENDENCE variable=unified_buffer inter false
56 | #pragma HLS DEPENDENCE variable=unified_buffer intra false
57 | #pragma HLS DEPENDENCE variable=weight_buffer inter false
58 | #pragma HLS DEPENDENCE variable=weight_buffer intra false
59 |
60 | config(reggroup,mxuparam,poolparam,ldstparam,norm_coef); // decode the current reggroup contents into the parameter structs and norm_coef
61 | is_load_weight = runmode[0]==1; // copy the mode bits now: the instr() call below receives runmode by reference
62 | is_MXU = runmode[1]==1;
63 | is_relu_norm_pool = runmode[2]==1;
64 | instr(ddr_instr,instr_offset,reggroup,runmode,true); // fetch for the next iteration; the calls below use the values decoded above
65 | loadWeight(ddr,weight_buffer,ldstparam.weight_offset,ldstparam.weight_addr,
66 | ldstparam.weight_ldlen,is_load_weight);
67 | ex_module(unified_buffer,weight_buffer,norm_coef,mxuparam,poolparam,is_MXU,is_relu_norm_pool);
68 | }
69 |
70 | storeFeature(ddr,unified_buffer, 512*25+72*25+72,14000, 512, true); // offset 512*25+72*25+72, addr 14000, len 512; tb_tpu.cpp reads results from the same ddr offset
71 | }
72 |
--------------------------------------------------------------------------------
/src/tpu.h:
--------------------------------------------------------------------------------
1 | #include "ap_int.h"
2 |
3 | #define MXU_COLNUM 32
4 | #define MXU_ROWNUM 32
5 | #define WEIGHTDTYPE char
6 | #define FEATDTYPE char
7 | #define PSUMDTYPE ap_int<32>
8 |
9 |
10 | struct MXU_PARAM{ // Parameter bundle passed by value to MXU(); filled in by config()
11 | bool isload;
12 | bool iscalc;
13 | bool isping;
14 | bool isfirstpsum;
15 |
16 | short weight_raddr;
17 | short ubuf_raddr_start;
18 | short ubuf_raddr_step1;
19 | short ubuf_raddr_step2;
20 | short ubuf_raddr_step3;
21 | short ubuf_raddr_end1;
22 | short ubuf_raddr_end2;
23 | short ubuf_raddr_end3;
24 | short ubuf_raddr_num;
25 | short psum_start; // ex_module() also tests psum_start<512 when choosing which partial-sum buffer MXU() gets
26 | short psum_step1;
27 | short psum_end1;
28 | short psum_step2;
29 | };
30 | struct RELPOOL_PARAM{ // Parameter bundle passed by value to relu_norm_pool(); filled in by config()
31 | bool isrelu;
32 | short psum_raddr_start; // ex_module() also tests psum_raddr_start>=512 when choosing the partial-sum buffer
33 |
34 | bool maxpool; // max pool or average pool
35 | char pool_kw; // compared with a counter that starts at 0 and steps by 1, so this is the last index, not the count
36 | char pool_kh; // same convention as pool_kw
37 | char pool_w; // value of the width counter at which it wraps to 0; that counter advances by pool_sw
38 | char pool_sw; // increment of the width counter
39 | char pool_sh; // increment of the height counter
40 | short pool_cnt; // output_num*pool_kw*pool_kh
41 | short pool_h_step;
42 |
43 | short avg_val;
44 | ap_uint<4> avg_shift;
45 |
46 | short ubuf_waddr_start; // base of the write address: start + p1 + p2 + p3
47 | short ubuf_waddr_step1; // increment of offset p1 while p1 != ubuf_waddr_end1
48 | short ubuf_waddr_step2; // increment of offset p2 once p1 has reached ubuf_waddr_end1
49 | short ubuf_waddr_step3; // increment of offset p3 when p2 reaches ubuf_waddr_end2 (p2 then resets to 0)
50 | short ubuf_waddr_end1;
51 | short ubuf_waddr_end2;
52 | short ubuf_waddr_end3; // NOTE(review): loaded by config() but not referenced in the write-address update of relu_norm_pool()
53 | };
54 |
55 | struct LDST_PARAM{ // Arguments that tpu() forwards to loadWeight(); filled in by config()
56 | unsigned weight_offset; // 32-bit value built by config() from two 16-bit registers: (reggroup[31],reggroup[30])
57 | short weight_addr; // config() copies reggroup[28]; passed as loadWeight()'s 'addr'
58 | short weight_ldlen; // config() copies reggroup[29]; passed as loadWeight()'s 'len'
59 | };
60 |
61 | void MXU(FEATDTYPE ubuf[16384][MXU_ROWNUM],WEIGHTDTYPE weight[512][MXU_COLNUM], // called from ex_module() with one of its two partial-sum buffers
62 | PSUMDTYPE psum[512][MXU_COLNUM],MXU_PARAM mxuparam, bool enable);
63 | void relu_norm_pool(PSUMDTYPE psum_buffer[512][MXU_COLNUM],FEATDTYPE unified_buffer[16384][MXU_ROWNUM], // returns immediately when enable is false
64 | ap_int<32> norm_coef[MXU_COLNUM],RELPOOL_PARAM param, bool enable);
65 | void loadWeight(ap_uint<256> *ddr,WEIGHTDTYPE weight_buffer[512][MXU_COLNUM],
66 | unsigned offset,short addr, short len, bool enable);
67 | void loadFeature(ap_uint<256> *ddr,FEATDTYPE unified_buffer[512][MXU_ROWNUM], // NOTE(review): declared with 512 rows, but tpu() passes its 16384-row buffer and len 512*25 - confirm the intended bound
68 | unsigned offset,short addr, short len, bool enable);
69 | void storeFeature(ap_uint<256> *ddr,FEATDTYPE unified_buffer[512][MXU_COLNUM], // NOTE(review): declared with 512 rows, but tpu() calls it with addr 14000 - confirm the intended bound
70 | unsigned offset,short addr, short len, bool enable);
71 | void instr(ap_uint<64> *ddr,unsigned &offset,ap_int<16> reggroup[96],ap_int<8> &runmode,bool enable); // offset and runmode are taken by reference
72 | void config(ap_int<16> reggroup[96],MXU_PARAM &mxuparam,RELPOOL_PARAM &poolparam, // decodes reggroup into the three parameter structs and norm_coef
73 | LDST_PARAM &lsdtparam, ap_int<32> norm_coef[32]);
74 | void tpu(ap_uint<256> *ddr,ap_uint<64> *ddr_instr); // top-level entry point, called by the testbench
75 |
--------------------------------------------------------------------------------