├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
├── src
├── haoda
│ ├── backend
│ │ └── xilinx.py
│ ├── ir
│ │ ├── __init__.py
│ │ ├── arithmetic
│ │ │ ├── __init__.py
│ │ │ └── base.py
│ │ └── visitor.py
│ └── util.py
├── soda
│ ├── codegen
│ │ └── xilinx
│ │ │ ├── header.py
│ │ │ ├── hls_kernel.py
│ │ │ ├── host.py
│ │ │ ├── opencl.py
│ │ │ └── rtl_kernel.py
│ ├── core.py
│ ├── dataflow.py
│ ├── grammar.py
│ ├── mutator.py
│ ├── util.py
│ └── visitor.py
└── sodac
└── tests
├── src
├── blur.soda
├── denoise2d.soda
├── denoise3d.soda
├── heat3d.soda
├── jacobi2d.soda
├── jacobi3d.soda
├── seidel2d.soda
└── sobel2d.soda
└── test-compilation.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | sdaccel_profile_summary.csv
2 | sdaccel_profile_summary.html
3 | .Xil/
4 | __pycache__
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 UCLA-VAST
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repo is archived. For latest version, please see https://github.com/UCLA-VAST/soda.
2 |
3 | # SODA Compiler
4 | Stencil with Optimized Dataflow Architecture Compiler
5 |
6 | ## Publication
7 |
8 | + Yuze Chi, Jason Cong, Peng Wei, Peipei Zhou. [SODA: Stencil with Optimized Dataflow Architecture](https://doi.org/10.1145/3240765.3240850). In *ICCAD*, 2018. (Best Paper Candidate) [[PDF]](https://about.blaok.me/pub/iccad18.pdf) [[Slides]](https://about.blaok.me/pub/iccad18.slides.pdf)
9 |
10 | ## SODA DSL Example
11 |
12 | # comments start with hashtag(#)
13 |
14 | kernel: blur # the kernel name, will be used as the kernel name in HLS
15 | burst width: 512 # DRAM burst I/O width in bits, for Xilinx platform by default it's 512
16 | unroll factor: 16 # how many pixels are generated per cycle
17 |
18 | # specify the dram bank, type, name, and dimension of the input tile
19 | # the last dimension is not needed and a placeholder '*' must be given
20 | # dram bank is optional
21 | # multiple inputs can be specified but 1 and only 1 must specify the dimensions
22 | input dram 1 uint16: input(2000, *)
23 |
24 | # specify an intermediate stage of computation, may appear 0 or more times
25 | local uint16: blur_x(0, 0) = (input(0, 0) + input(0, 1) + input(0, 2)) / 3
26 |
27 | # specify the output
28 | # dram bank is optional
29 | output dram 1 uint16: blur_y(0, 0) = (blur_x(0, 0) + blur_x(1, 0) + blur_x(2, 0)) / 3
30 |
31 | # how many times the whole computation is repeated (only works if input matches output)
32 | iterate: 1
33 |
34 | ## Getting Started
35 |
36 | ### Prerequisites
37 |
38 | + Python 3.5+ and corresponding `pip`
39 | + SDAccel 2018.3 (earlier versions might work but won't be supported)
40 |
41 | How to install Python 3.5+ on Ubuntu 16.04+ and CentOS 7?
42 |
43 | #### Ubuntu 16.04+
44 | ```bash
45 | sudo apt install python3 python3-pip
46 | ```
47 |
48 | #### CentOS 7
49 | ```bash
50 | sudo yum install python36 python36-pip
51 | sudo alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 100
52 | ```
53 |
54 |
55 |
56 | ### Clone the Repo
57 | git clone https://github.com/UCLA-VAST/soda-compiler.git
58 | cd soda-compiler
59 | python3 -m pip install --user -r requirements.txt
60 |
61 | ### Parameter Setup
62 | app=blur
63 | platform=xilinx_u200_xdma_201830_1
64 | # The following can be set via sourcing /path/to/xilinx/sdx/settings64.sh
65 | XILINX_SDX=/path/to/xilinx/sdx
66 | XILINX_VIVADO=/path/to/xilinx/vivado
67 |
68 | ### Generate HLS Kernel Code
69 | src/sodac tests/src/${app}.soda --xocl-kernel ${app}_kernel.cpp
70 |
71 | ### Generate OpenCL Host Code
72 | src/sodac tests/src/${app}.soda --xocl-header ${app}.h
73 | src/sodac tests/src/${app}.soda --xocl-host ${app}.cpp
74 |
75 | ### Create Testbench
76 | cat >${app}_run.cpp <
78 | #include
79 |
80 | #include "${app}.h"
81 |
82 | int ${app}_test(const char* xclbin, const int dims[4]);
83 | int main(int argc, char **argv) {
84 | if (argc != 4) {
85 | fprintf(stderr, "Usage: \n %s \n", argv[0]);
86 | return 1;
87 | }
88 | int dims[4] = {atoi(argv[2]), atoi(argv[3]), 0, 0};
89 | return ${app}_test(argv[1], dims);
90 | }
91 | EOF
92 |
93 | ### Compile OpenCL Host Executable
94 | # Please set TILE_SIZE_DIM_0 and UNROLL_FACTOR macros to match the kernel.
95 | g++ -std=c++11 -I${XILINX_SDX}/runtime/include -I${XILINX_VIVADO}/include ${app}.cpp ${app}_run.cpp -o ${app} \
96 | -lxilinxopencl -DTILE_SIZE_DIM_0=2000 -DUNROLL_FACTOR=2 -fopenmp -Wno-deprecated-declarations -Wall
97 |
98 | ### Create Emulation Config
99 | emconfigutil -f ${platform}
100 |
101 | ### Software Emulation
102 |
103 | #### Compile for Software Emulation
104 | xocc -t sw_emu -f ${platform} --kernel ${app}_kernel --xp prop:kernel.${app}_kernel.kernel_flags="-std=c++0x" \
105 | -c ${app}_kernel.cpp -o ${app}.sw_emu.xo
106 |
107 | #### Link for Software Emulation
108 | xocc -t sw_emu -f ${platform} -l ${app}.sw_emu.xo -o ${app}.sw_emu.xclbin
109 |
110 | #### Run Software Emulation
111 | XCL_EMULATION_MODE=sw_emu ./${app} ${app}.sw_emu.xclbin 2000 100
112 |
113 | ### High-Level Synthesis
114 | xocc -t hw -f ${platform} --kernel ${app}_kernel --xp prop:kernel.${app}_kernel.kernel_flags="-std=c++0x" \
115 | -c ${app}_kernel.cpp -o ${app}.hw.xo
116 |
117 | ### Hardware Emulation
118 |
119 | #### Link for Hardware Emulation
120 | xocc -t hw_emu -f ${platform} -l ${app}.hw.xo -o ${app}.hw_emu.xclbin
121 |
122 | #### Run Hardware Emulation
123 | # By default, kernel ports are connected via DRAM bank 1 on the xilinx_u200_xdma_201830_1 platform.
124 | DRAM_IN=1 DRAM_OUT=1 XCL_EMULATION_MODE=hw_emu ./${app} ${app}.hw_emu.xclbin 2000 10
125 |
126 | ### Hardware Deployment
127 |
128 | #### Logic Synthesis, Place, and Route
129 | xocc -t hw -f ${platform} -l ${app}.hw.xo -o ${app}.hw.xclbin
130 |
131 | #### Run Bitstream on FPGA
132 | # By default, kernel ports are connected via DRAM bank 1 on the xilinx_u200_xdma_201830_1 platform.
133 | DRAM_IN=1 DRAM_OUT=1 ./${app} ${app}.hw.xclbin 2000 1000
134 |
135 | ## Code Snippet Example
136 |
137 | ### Source Code
138 |
139 | kernel: jacobi2d
140 | burst width: 512
141 | unroll factor: 2
142 | input float: t1(2000, *)
143 | output float: t0(0, 0) = (t1(0, 1) + t1(1, 0) + t1(0, 0) + t1(0, -1) + t1(-1, 0)) * 0.2f
144 | iterate: 1
145 |
146 | ### HLS Kernel Code
147 | Each function in the below code snippets is synthesized into an RTL module.
148 | Their arguments are all `hls::stream` FIFOs; without unrolling, a simple line-buffer pipeline is generated, producing 1 pixel per cycle.
149 | With unrolling, a SODA microarchitecture pipeline is generated, producing 2 pixels per cycle.
150 |
151 | #### Without Unrolling (`--unroll-factor=1`)
152 |
153 | #pragma HLS dataflow
154 | Module1Func(
155 | /*output*/ &from_t1_offset_0_to_t1_offset_1999,
156 | /*output*/ &from_t1_offset_0_to_t0_pe_0,
157 | /* input*/ &from_super_source_to_t1_offset_0);
158 | Module2Func(
159 | /*output*/ &from_t1_offset_1999_to_t1_offset_2000,
160 | /*output*/ &from_t1_offset_1999_to_t0_pe_0,
161 | /* input*/ &from_t1_offset_0_to_t1_offset_1999);
162 | Module3Func(
163 | /*output*/ &from_t1_offset_2000_to_t1_offset_2001,
164 | /*output*/ &from_t1_offset_2000_to_t0_pe_0,
165 | /* input*/ &from_t1_offset_1999_to_t1_offset_2000);
166 | Module3Func(
167 | /*output*/ &from_t1_offset_2001_to_t1_offset_4000,
168 | /*output*/ &from_t1_offset_2001_to_t0_pe_0,
169 | /* input*/ &from_t1_offset_2000_to_t1_offset_2001);
170 | Module4Func(
171 | /*output*/ &from_t1_offset_4000_to_t0_pe_0,
172 | /* input*/ &from_t1_offset_2001_to_t1_offset_4000);
173 | Module5Func(
174 | /*output*/ &from_t0_pe_0_to_super_sink,
175 | /* input*/ &from_t1_offset_0_to_t0_pe_0,
176 | /* input*/ &from_t1_offset_1999_to_t0_pe_0,
177 | /* input*/ &from_t1_offset_2000_to_t0_pe_0,
178 | /* input*/ &from_t1_offset_4000_to_t0_pe_0,
179 | /* input*/ &from_t1_offset_2001_to_t0_pe_0);
180 |
181 | In the above code snippet, `Module1Func` to `Module4Func` are forwarding modules; they constitute the data-reuse line buffer.
182 | The line buffer size is approximately two lines of pixels, i.e. 4000 pixels.
183 | `Module5Func` is a computing module; it implements the computation kernel.
184 | The whole design is fully pipelined; however, with only 1 computing module, it can only produce 1 pixel per cycle.
185 |
186 | #### Unroll 2 Times (`--unroll-factor=2`)
187 |
188 | #pragma HLS dataflow
189 | Module1Func(
190 | /*output*/ &from_t1_offset_1_to_t1_offset_1999,
191 | /*output*/ &from_t1_offset_1_to_t0_pe_0,
192 | /* input*/ &from_super_source_to_t1_offset_1);
193 | Module1Func(
194 | /*output*/ &from_t1_offset_0_to_t1_offset_2000,
195 | /*output*/ &from_t1_offset_0_to_t0_pe_1,
196 | /* input*/ &from_super_source_to_t1_offset_0);
197 | Module2Func(
198 | /*output*/ &from_t1_offset_1999_to_t1_offset_2001,
199 | /*output*/ &from_t1_offset_1999_to_t0_pe_1,
200 | /* input*/ &from_t1_offset_1_to_t1_offset_1999);
201 | Module3Func(
202 | /*output*/ &from_t1_offset_2000_to_t1_offset_2002,
203 | /*output*/ &from_t1_offset_2000_to_t0_pe_1,
204 | /*output*/ &from_t1_offset_2000_to_t0_pe_0,
205 | /* input*/ &from_t1_offset_0_to_t1_offset_2000);
206 | Module4Func(
207 | /*output*/ &from_t1_offset_2001_to_t1_offset_4001,
208 | /*output*/ &from_t1_offset_2001_to_t0_pe_1,
209 | /*output*/ &from_t1_offset_2001_to_t0_pe_0,
210 | /* input*/ &from_t1_offset_1999_to_t1_offset_2001);
211 | Module5Func(
212 | /*output*/ &from_t1_offset_2002_to_t1_offset_4000,
213 | /*output*/ &from_t1_offset_2002_to_t0_pe_0,
214 | /* input*/ &from_t1_offset_2000_to_t1_offset_2002);
215 | Module6Func(
216 | /*output*/ &from_t1_offset_4001_to_t0_pe_0,
217 | /* input*/ &from_t1_offset_2001_to_t1_offset_4001);
218 | Module7Func(
219 | /*output*/ &from_t0_pe_0_to_super_sink,
220 | /* input*/ &from_t1_offset_1_to_t0_pe_0,
221 | /* input*/ &from_t1_offset_2000_to_t0_pe_0,
222 | /* input*/ &from_t1_offset_2001_to_t0_pe_0,
223 | /* input*/ &from_t1_offset_4001_to_t0_pe_0,
224 | /* input*/ &from_t1_offset_2002_to_t0_pe_0);
225 | Module8Func(
226 | /*output*/ &from_t1_offset_4000_to_t0_pe_1,
227 | /* input*/ &from_t1_offset_2002_to_t1_offset_4000);
228 | Module7Func(
229 | /*output*/ &from_t0_pe_1_to_super_sink,
230 | /* input*/ &from_t1_offset_0_to_t0_pe_1,
231 | /* input*/ &from_t1_offset_1999_to_t0_pe_1,
232 | /* input*/ &from_t1_offset_2000_to_t0_pe_1,
233 | /* input*/ &from_t1_offset_4000_to_t0_pe_1,
234 | /* input*/ &from_t1_offset_2001_to_t0_pe_1);
235 |
236 | In the above code snippet, `Module1Func` to `Module6Func` and `Module8Func` are forwarding modules; they constitute the reuse buffers of the SODA microarchitecture.
237 | Although unrolled, the reuse buffer size is still approximately two lines of pixels, i.e. 4000 pixels.
238 | `Module7Func` is a computing module; it is instantiated twice.
239 | The whole design is fully pipelined and can produce 2 pixels per cycle.
240 | In general, the unroll factor can be set to any number that satisfies the throughput requirement.
241 |
242 | ## Design Considerations
243 |
244 | + `kernel`, `burst width`, `unroll factor`, `input`, `output`, and `iterate` keywords are mandatory
245 | + For non-iterative stencil, `unroll factor` shall be determined by the DRAM bandwidth, i.e. saturate the external bandwidth, since the resource is usually not the bottleneck
246 | + For iterative stencil, prefer to use more PEs in a single iteration rather than implement more iterations
247 | + Note that `2.0` will be a `double` number. To generate `float`, use `2.0f`. This may help reduce DSP usage
248 | + SODA is tiling-based and the size of the tile is specified in the `input` keyword. The last dimension is a placeholder because it is not needed in the reuse buffer generation
249 |
250 | ## Projects Using SODA
251 |
252 | + Yi-Hsiang Lai, Yuze Chi, Yuwei Hu, Jie Wang, Cody Hao Yu, Yuan Zhou, Jason Cong, Zhiru Zhang. [HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Reconfigurable Computing](https://doi.org/10.1145/3289602.3293910). In *FPGA*, 2019. (Best Paper Candidate) [[PDF]](https://about.blaok.me/pub/fpga19-heterocl.pdf) [[Slides]](https://about.blaok.me/pub/fpga19-heterocl.slides.pdf)
253 | + Yuze Chi, Young-kyu Choi, Jason Cong, Jie Wang. [Rapid Cycle-Accurate Simulator for High-Level Synthesis](https://doi.org/10.1145/3289602.3293918). In *FPGA*, 2019. [[PDF]](https://about.blaok.me/pub/fpga19-flash.pdf) [[Slides]](https://about.blaok.me/pub/fpga19-flash.slides.pdf)
254 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | textx
2 | cached_property
3 |
--------------------------------------------------------------------------------
/src/haoda/backend/xilinx.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import collections
3 | import logging
4 | import os
5 | import subprocess
6 | import tarfile
7 | import tempfile
8 | import xml.etree.ElementTree as ET
9 | import zipfile
10 |
11 | from haoda import util
12 |
13 | _logger = logging.getLogger().getChild(__name__)
14 |
class Vivado(subprocess.Popen):
  """Run vivado in batch mode on a generated tcl script.

  The tcl commands are written to a script file inside a scratch
  directory; both are removed when the context manager exits.

  Args:
    commands: string of tcl commands
    args: sequence of arguments
  """
  def __init__(self, commands, *args):
    self.cwd = tempfile.TemporaryDirectory(prefix='vivado-')
    self.tcl_file = open(os.path.join(self.cwd.name, 'tcl'), mode='w+')
    self.tcl_file.write(commands)
    self.tcl_file.flush()  # vivado reads the script by name; flush first
    super().__init__(
        ['vivado', '-mode', 'batch', '-source', self.tcl_file.name,
         '-nojournal', '-nolog', '-tclargs'] + list(args),
        cwd=self.cwd.name, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

  def __exit__(self, *exc_info):
    # Let Popen.__exit__ wait for the process before deleting its inputs.
    super().__exit__(*exc_info)
    self.tcl_file.close()
    self.cwd.cleanup()
36 |
class VivadoHls(subprocess.Popen):
  """Call vivado_hls with the given tcl commands.

  The tcl commands are written to a temporary script file inside a
  temporary working directory; both are cleaned up on context exit.

  Args:
    commands: string of tcl commands
  """
  def __init__(self, commands):
    self.cwd = tempfile.TemporaryDirectory(prefix='vivado-hls-')
    self.tcl_file = open(os.path.join(self.cwd.name, 'tcl'), mode='w+')
    self.tcl_file.write(commands)
    self.tcl_file.flush()  # ensure vivado_hls sees the complete script
    cmd_args = ['vivado_hls', '-f', self.tcl_file.name, '-l', '/dev/null']
    pipe_args = {'stdout' : subprocess.PIPE, 'stderr' : subprocess.PIPE}
    super().__init__(cmd_args, cwd=self.cwd.name, **pipe_args)

  def __exit__(self, *args):
    # Wait for the subprocess (done by Popen.__exit__) before removing the
    # script file and the temporary directory it lives in.
    super().__exit__(*args)
    self.tcl_file.close()
    self.cwd.cleanup()
56 |
57 | PACKAGEXO_COMMANDS = r'''
58 | set tmp_ip_dir "{tmpdir}/tmp_ip_dir"
59 | set tmp_project "{tmpdir}/tmp_project"
60 |
61 | create_project -force kernel_pack ${{tmp_project}}
62 | add_files -norecurse [glob {hdl_dir}/*.v]
63 | foreach tcl_file [glob -nocomplain {hdl_dir}/*.tcl] {{
64 | source ${{tcl_file}}
65 | }}
66 | update_compile_order -fileset sources_1
67 | update_compile_order -fileset sim_1
68 | ipx::package_project -root_dir ${{tmp_ip_dir}} -vendor xilinx.com -library RTLKernel -taxonomy /KernelIP -import_files -set_current false
69 | ipx::unload_core ${{tmp_ip_dir}}/component.xml
70 | ipx::edit_ip_in_project -upgrade true -name tmp_edit_project -directory ${{tmp_ip_dir}} ${{tmp_ip_dir}}/component.xml
71 | set_property core_revision 2 [ipx::current_core]
72 | foreach up [ipx::get_user_parameters] {{
73 | ipx::remove_user_parameter [get_property NAME ${{up}}] [ipx::current_core]
74 | }}
75 | set_property sdx_kernel true [ipx::current_core]
76 | set_property sdx_kernel_type rtl [ipx::current_core]
77 | ipx::create_xgui_files [ipx::current_core]
78 | {bus_ifaces}
79 | ipx::associate_bus_interfaces -busif s_axi_control -clock ap_clk [ipx::current_core]
80 | set_property xpm_libraries {{XPM_CDC XPM_MEMORY XPM_FIFO}} [ipx::current_core]
81 | set_property supported_families {{ }} [ipx::current_core]
82 | set_property auto_family_support_level level_2 [ipx::current_core]
83 | ipx::update_checksums [ipx::current_core]
84 | ipx::save_core [ipx::current_core]
85 | close_project -delete
86 |
87 | package_xo -force -xo_path "{xo_file}" -kernel_name {top_name} -ip_directory ${{tmp_ip_dir}} -kernel_xml {kernel_xml}{cpp_kernels}
88 | '''
89 |
class PackageXo(Vivado):
  """Packages the given files into a Xilinx hardware object.

  Fills in PACKAGEXO_COMMANDS and hands the resulting tcl script to the
  Vivado base class, which spawns vivado in batch mode.

  Args:
    xo_file: name of the generated xo file.
    top_name: top-level module name.
    kernel_xml: xml description of the kernel.
    hdl_dir: directory of all HDL files.
    m_axi_names: variable names connected to the m_axi bus.
    cpp_kernels: sequence of file names of C++ kernels.
  """
  def __init__(self, xo_file, top_name, kernel_xml, hdl_dir, m_axi_names,
               cpp_kernels=()):
    self.tmpdir = tempfile.TemporaryDirectory(prefix='package-xo-')
    # Log the HDL files being packed, but only walk the tree if INFO-level
    # logging is actually enabled.
    if _logger.isEnabledFor(logging.INFO):
      for _, _, files in os.walk(hdl_dir):
        for filename in files:
          _logger.info('packing: %s', filename)
    kwargs = {
        'top_name' : top_name,
        'kernel_xml' : kernel_xml,
        'hdl_dir' : hdl_dir,
        'xo_file' : xo_file,
        # One ipx::associate_bus_interfaces line per m_axi bundle.
        'bus_ifaces' : '\n'.join(map(
            'ipx::associate_bus_interfaces -busif m_axi_{} -clock ap_clk '
            '[ipx::current_core]'.format, m_axi_names)),
        'tmpdir' : self.tmpdir.name,
        # Optional -kernel_files flags appended to the package_xo command.
        'cpp_kernels' : ''.join(map(' -kernel_files {}'.format, cpp_kernels))
    }
    super().__init__(PACKAGEXO_COMMANDS.format(**kwargs))

  def __exit__(self, *args):
    # Vivado.__exit__ waits for the process before we drop the temp dir.
    super().__exit__(*args)
    self.tmpdir.cleanup()
124 |
# Tcl script template driving vivado_hls through csynth_design; formatted
# with keys: project_dir, project_name, solution_name, top_name,
# add_kernels, part_num, clock_period (filled in by RunHls).
HLS_COMMANDS = r'''
cd "{project_dir}"
open_project "{project_name}"
set_top {top_name}
{add_kernels}
open_solution "{solution_name}"
set_part {{{part_num}}}
create_clock -period {clock_period} -name default
config_compile -name_max_length 253
config_interface -m_axi_addr64
config_rtl -disable_start_propagation
csynth_design
exit
'''
139 |
class RunHls(VivadoHls):
  """Runs Vivado HLS for the given kernels and generate HDL files

  Args:
    tarfileobj: file object that will contain the reports and HDL files.
    kernel_files: file names of the kernels.
    top_name: top-level module name.
    clock_period: target clock period.
    part_num: target part number.
  """
  def __init__(self, tarfileobj, kernel_files, top_name, clock_period,
               part_num):
    self.project_dir = tempfile.TemporaryDirectory(prefix='hls-')
    self.project_name = 'project'
    self.solution_name = 'solution'
    self.tarfileobj = tarfileobj
    kwargs = {
        'project_dir' : self.project_dir.name,
        'project_name' : self.project_name,
        'solution_name' : self.solution_name,
        'top_name' : top_name,
        # One add_files line per kernel source file.
        'add_kernels' : '\n'.join(map(
            'add_files "{}" -cflags "-std=c++11"'.format, kernel_files)),
        'part_num' : part_num,
        'clock_period' : clock_period
    }
    super().__init__(HLS_COMMANDS.format(**kwargs))

  def __exit__(self, *args):
    # VivadoHls.__exit__ waits for vivado_hls to finish; only on success
    # (returncode 0) are the synthesis reports, generated Verilog, and the
    # solution log archived into the caller-supplied tar file. On failure
    # nothing is written to tarfileobj.
    super().__exit__(*args)
    if self.returncode == 0:
      with tarfile.open(mode='w', fileobj=self.tarfileobj) as tar:
        solution_dir = os.path.join(self.project_dir.name, self.project_name,
                                    self.solution_name)
        tar.add(os.path.join(solution_dir, 'syn/report'), arcname='report')
        tar.add(os.path.join(solution_dir, 'syn/verilog'), arcname='hdl')
        tar.add(os.path.join(solution_dir, self.solution_name + '.log'),
                arcname=self.solution_name + '.log')
    self.project_dir.cleanup()
# XML namespace map used when querying Xilinx platform metadata.
XILINX_XML_NS = {'xd' : 'http://www.xilinx.com/xd'}

def get_device_info(platform_path):
  """Extract device part number and target frequency from SDAccel platform.

  Currently only support 5.x platforms.

  Reads <device>.dsa inside the platform's hw directory (a zip archive) and
  parses the .hpfm metadata it contains.

  Returns:
    dict with keys 'clock_period' (period of system clock id 0) and
    'part_num' (device name attribute).
  """
  device_name = os.path.basename(platform_path)
  with zipfile.ZipFile(os.path.join(
      platform_path, 'hw', device_name + '.dsa')) as platform:
    with platform.open(device_name + '.hpfm') as metadata:
      platform_info = ET.parse(metadata).find('./xd:component/xd:platformInfo',
                                              XILINX_XML_NS)
      return {
          'clock_period' : platform_info.find(
              "./xd:systemClocks/xd:clock/[@xd:id='0']", XILINX_XML_NS).attrib[
                  '{{{xd}}}period'.format(**XILINX_XML_NS)],
          'part_num' : platform_info.find(
              'xd:deviceInfo', XILINX_XML_NS).attrib[
                  '{{{xd}}}name'.format(**XILINX_XML_NS)]
      }
201 |
202 | KERNEL_XML_TEMPLATE = r'''
203 |
204 |
205 |
206 | {m_axi_ports}
207 |
208 |
209 | {args}
210 |
211 |
212 |
213 | '''
214 |
215 | PORT_TEMPLATE = r'''
216 |
217 | '''
218 |
219 | ARG_TEMPLATE = r'''
220 |
221 | '''
222 |
def print_kernel_xml(top_name, ports, kernel_xml):
  """Generate kernel.xml file.

  Args:
    top_name: name of the top-level kernel function.
    ports: sequence of (port_name, bundle_name, haoda_type, _) of m_axi ports
    kernel_xml: file object to write to.
  """
  m_axi_ports = ''
  args = ''
  # NOTE(review): 0x10 looks like the first free argument-register offset of
  # the kernel's s_axi_control map -- confirm against the RTL kernel spec.
  offset = 0x10
  arg_id = 0
  # Every m_axi pointer argument (and the trailing scalar) occupies 8 bytes.
  # Hoisted out of the loop so that an empty `ports` sequence no longer
  # raises UnboundLocalError at the coalesced_data_num argument below.
  size = host_size = 8
  bundle_set = set()  # bundles that already have a port entry emitted
  for port_name, bundle_name, haoda_type, _ in ports:
    if bundle_name not in bundle_set:
      m_axi_ports += PORT_TEMPLATE.format(
          name=bundle_name,
          width=util.get_width_in_bits(haoda_type)).rstrip('\n')
      bundle_set.add(bundle_name)
    args += ARG_TEMPLATE.format(
        name=port_name, addr_qualifier=1, arg_id=arg_id,
        port_name='m_axi_' + bundle_name, c_type=util.get_c_type(haoda_type),
        size=size, offset=offset, host_size=host_size).rstrip('\n')
    offset += size + 4  # 4-byte gap between consecutive 8-byte registers
    arg_id += 1
  # Trailing scalar argument carried over s_axi_control.
  args += ARG_TEMPLATE.format(
      name='coalesced_data_num', addr_qualifier=0, arg_id=arg_id,
      port_name='s_axi_control', c_type='uint64_t', size=size, offset=offset,
      host_size=host_size).rstrip('\n')
  kernel_xml.write(KERNEL_XML_TEMPLATE.format(
      top_name=top_name, m_axi_ports=m_axi_ports, args=args))
255 |
256 | BRAM_FIFO_TEMPLATE = r'''
257 | `timescale 1ns/1ps
258 |
259 | module {name}_w{width}_d{depth}_A
260 | #(parameter
261 | MEM_STYLE = "block",
262 | DATA_WIDTH = {width},
263 | ADDR_WIDTH = {addr_width},
264 | DEPTH = {depth}
265 | )
266 | (
267 | // system signal
268 | input wire clk,
269 | input wire reset,
270 |
271 | // write
272 | output wire if_full_n,
273 | input wire if_write_ce,
274 | input wire if_write,
275 | input wire [DATA_WIDTH-1:0] if_din,
276 |
277 | // read
278 | output wire if_empty_n,
279 | input wire if_read_ce,
280 | input wire if_read,
281 | output wire [DATA_WIDTH-1:0] if_dout
282 | );
283 | //------------------------Parameter----------------------
284 |
285 | //------------------------Local signal-------------------
286 | (* ram_style = MEM_STYLE *)
287 | reg [DATA_WIDTH-1:0] mem[0:DEPTH-1];
288 | reg [DATA_WIDTH-1:0] q_buf = 1'b0;
289 | reg [ADDR_WIDTH-1:0] waddr = 1'b0;
290 | reg [ADDR_WIDTH-1:0] raddr = 1'b0;
291 | wire [ADDR_WIDTH-1:0] wnext;
292 | wire [ADDR_WIDTH-1:0] rnext;
293 | wire push;
294 | wire pop;
295 | reg [ADDR_WIDTH-1:0] usedw = 1'b0;
296 | reg full_n = 1'b1;
297 | reg empty_n = 1'b0;
298 | reg [DATA_WIDTH-1:0] q_tmp = 1'b0;
299 | reg show_ahead = 1'b0;
300 | reg [DATA_WIDTH-1:0] dout_buf = 1'b0;
301 | reg dout_valid = 1'b0;
302 |
303 |
304 | //------------------------Instantiation------------------
305 |
306 | //------------------------Task and function--------------
307 |
308 | //------------------------Body---------------------------
309 | assign if_full_n = full_n;
310 | assign if_empty_n = dout_valid;
311 | assign if_dout = dout_buf;
312 | assign push = full_n & if_write_ce & if_write;
313 | assign pop = empty_n & if_read_ce & (~dout_valid | if_read);
314 | assign wnext = !push ? waddr :
315 | (waddr == DEPTH - 1) ? 1'b0 :
316 | waddr + 1'b1;
317 | assign rnext = !pop ? raddr :
318 | (raddr == DEPTH - 1) ? 1'b0 :
319 | raddr + 1'b1;
320 |
321 | // waddr
322 | always @(posedge clk) begin
323 | if (reset == 1'b1)
324 | waddr <= 1'b0;
325 | else
326 | waddr <= wnext;
327 | end
328 |
329 | // raddr
330 | always @(posedge clk) begin
331 | if (reset == 1'b1)
332 | raddr <= 1'b0;
333 | else
334 | raddr <= rnext;
335 | end
336 |
337 | // usedw
338 | always @(posedge clk) begin
339 | if (reset == 1'b1)
340 | usedw <= 1'b0;
341 | else if (push & ~pop)
342 | usedw <= usedw + 1'b1;
343 | else if (~push & pop)
344 | usedw <= usedw - 1'b1;
345 | end
346 |
347 | // full_n
348 | always @(posedge clk) begin
349 | if (reset == 1'b1)
350 | full_n <= 1'b1;
351 | else if (push & ~pop)
352 | full_n <= (usedw != DEPTH - 1);
353 | else if (~push & pop)
354 | full_n <= 1'b1;
355 | end
356 |
357 | // empty_n
358 | always @(posedge clk) begin
359 | if (reset == 1'b1)
360 | empty_n <= 1'b0;
361 | else if (push & ~pop)
362 | empty_n <= 1'b1;
363 | else if (~push & pop)
364 | empty_n <= (usedw != 1'b1);
365 | end
366 |
367 | // mem
368 | always @(posedge clk) begin
369 | if (push)
370 | mem[waddr] <= if_din;
371 | end
372 |
373 | // q_buf
374 | always @(posedge clk) begin
375 | q_buf <= mem[rnext];
376 | end
377 |
378 | // q_tmp
379 | always @(posedge clk) begin
380 | if (reset == 1'b1)
381 | q_tmp <= 1'b0;
382 | else if (push)
383 | q_tmp <= if_din;
384 | end
385 |
386 | // show_ahead
387 | always @(posedge clk) begin
388 | if (reset == 1'b1)
389 | show_ahead <= 1'b0;
390 | else if (push && usedw == pop)
391 | show_ahead <= 1'b1;
392 | else
393 | show_ahead <= 1'b0;
394 | end
395 |
396 | // dout_buf
397 | always @(posedge clk) begin
398 | if (reset == 1'b1)
399 | dout_buf <= 1'b0;
400 | else if (pop)
401 | dout_buf <= show_ahead? q_tmp : q_buf;
402 | end
403 |
404 | // dout_valid
405 | always @(posedge clk) begin
406 | if (reset == 1'b1)
407 | dout_valid <= 1'b0;
408 | else if (pop)
409 | dout_valid <= 1'b1;
410 | else if (if_read_ce & if_read)
411 | dout_valid <= 1'b0;
412 | end
413 |
414 | endmodule
415 | '''
416 |
417 | SRL_FIFO_TEMPLATE = r'''
418 | // ==============================================================
419 | // File generated by Vivado(TM) HLS - High-Level Synthesis from C, C++ and SystemC
420 | // Version: 2018.2
421 | // Copyright (C) 1986-2018 Xilinx, Inc. All Rights Reserved.
422 | //
423 | // ==============================================================
424 |
425 |
426 | `timescale 1 ns / 1 ps
427 |
428 | module {name}_w{width}_d{depth}_A_shiftReg (
429 | clk,
430 | data,
431 | ce,
432 | a,
433 | q);
434 |
435 | parameter DATA_WIDTH = 32'd{width};
436 | parameter ADDR_WIDTH = 32'd{addr_width};
437 | parameter DEPTH = {depth_width}'d{depth};
438 |
439 | input clk;
440 | input [DATA_WIDTH-1:0] data;
441 | input ce;
442 | input [ADDR_WIDTH-1:0] a;
443 | output [DATA_WIDTH-1:0] q;
444 |
445 | reg[DATA_WIDTH-1:0] SRL_SIG [0:DEPTH-1];
446 | integer i;
447 |
448 | always @ (posedge clk)
449 | begin
450 | if (ce)
451 | begin
452 | for (i=0;i threshold:
629 | self.bram_fifo_module(width, depth)
630 | else:
631 | self.srl_fifo_module(width, depth)
632 |
633 | def bram_fifo_module(self, width, depth, name='fifo'):
634 | """Generate BRAM FIFO with the given parameters.
635 |
636 | Generate a BRAM FIFO module named {name}_w{width}_d{depth}_A.
637 |
638 | Args:
639 | printer: VerilogPrinter to print to.
640 | width: FIFO width
641 | depth: FIFO depth
642 | name: Optionally give the fifo a name prefix, default to 'fifo'.
643 | """
644 | self._out.write(BRAM_FIFO_TEMPLATE.format(
645 | width=width, depth=depth, name=name,
646 | addr_width=(depth - 1).bit_length()))
647 |
648 | def srl_fifo_module(self, width, depth, name='fifo'):
649 | """Generate SRL FIFO with the given parameters.
650 |
651 | Generate a SRL FIFO module named {name}_w{width}_d{depth}_A.
652 |
653 | Args:
654 | printer: VerilogPrinter to print to.
655 | width: FIFO width
656 | depth: FIFO depth
657 | name: Optionally give the fifo a name prefix, default to 'fifo'.
658 | """
659 | addr_width = (depth - 1).bit_length()
660 | self._out.write(SRL_FIFO_TEMPLATE.format(
661 | width=width, depth=depth, name=name, addr_width=addr_width,
662 | depth_width=addr_width + 1))
663 |
--------------------------------------------------------------------------------
/src/haoda/ir/__init__.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import logging
4 | import math
5 |
6 | import cached_property
7 |
8 | from haoda import util
9 | from haoda.ir import visitor
10 |
11 | _logger = logging.getLogger().getChild(__name__)
12 |
13 | GRAMMAR = r'''
14 | Bin: /0[Bb][01]+([Uu][Ll][Ll]?|[Ll]?[Ll]?[Uu]?)/;
15 | Dec: /\d+([Uu][Ll][Ll]?|[Ll]?[Ll]?[Uu]?)/;
16 | Oct: /0[0-7]+([Uu][Ll][Ll]?|[Ll]?[Ll]?[Uu]?)/;
17 | Hex: /0[Xx][0-9a-fA-F]+([Uu][Ll][Ll]?|[Ll]?[Ll]?[Uu]?)/;
18 | Int: ('+'|'-')?(Hex|Bin|Oct|Dec);
19 | Float: /(((\d*\.\d+|\d+\.)([+-]?[Ee]\d+)?)|(\d+[+-]?[Ee]\d+))[FfLl]?/;
20 | Num: Float|Int;
21 |
22 | Type: FixedType | FloatType;
23 | FixedType: /u?int[1-9]\d*(_[1-9]\d*)?/;
24 | FloatType: /float[1-9]\d*(_[1-9]\d*)?/ | 'float' | 'double' | 'half';
25 |
26 | Let: (haoda_type=Type)? name=ID '=' expr=Expr;
27 | Ref: name=ID '(' idx=INT (',' idx=INT)* ')' ('~' lat=Int)?;
28 |
29 | Expr: operand=LogicAnd (operator=LogicOrOp operand=LogicAnd)*;
30 | LogicOrOp: '||';
31 |
32 | LogicAnd: operand=BinaryOr (operator=LogicAndOp operand=BinaryOr)*;
33 | LogicAndOp: '&&';
34 |
35 | BinaryOr: operand=Xor (operator=BinaryOrOp operand=Xor)*;
36 | BinaryOrOp: '|';
37 |
38 | Xor: operand=BinaryAnd (operator=XorOp operand=BinaryAnd)*;
39 | XorOp: '^';
40 |
41 | BinaryAnd: operand=EqCmp (operator=BinaryAndOp operand=EqCmp)*;
42 | BinaryAndOp: '&';
43 |
44 | EqCmp: operand=LtCmp (operator=EqCmpOp operand=LtCmp)*;
45 | EqCmpOp: '=='|'!=';
46 |
47 | LtCmp: operand=AddSub (operator=LtCmpOp operand=AddSub)*;
48 | LtCmpOp: '<='|'>='|'<'|'>';
49 |
50 | AddSub: operand=MulDiv (operator=AddSubOp operand=MulDiv)*;
51 | AddSubOp: '+'|'-';
52 |
53 | MulDiv: operand=Unary (operator=MulDivOp operand=Unary)*;
54 | MulDivOp: '*'|'/'|'%';
55 |
56 | Unary: (operator=UnaryOp)* operand=Operand;
57 | UnaryOp: '+'|'-'|'~'|'!';
58 |
59 | Operand: cast=Cast | call=Call | ref=Ref | num=Num | var=Var | '(' expr=Expr ')';
60 | Cast: haoda_type=Type '(' expr=Expr ')';
61 | Call: name=FuncName '(' arg=Expr (',' arg=Expr)* ')';
62 | Var: name=ID ('[' idx=Int ']')*;
63 |
64 | '''
65 |
class Node():
  """A immutable, hashable IR node.

  Subclasses declare their fields via SCALAR_ATTRS (single values) and
  LINEAR_ATTRS (sequences, stored as tuples); __init__, __hash__, and __eq__
  are all driven by those declarations.
  """
  SCALAR_ATTRS = ()
  LINEAR_ATTRS = ()

  @property
  def ATTRS(self):
    # All declared attribute names, scalars first.
    return self.SCALAR_ATTRS + self.LINEAR_ATTRS

  def __init__(self, **kwargs):
    # Every declared attribute must be supplied as a keyword argument;
    # linear attributes are normalized to tuples.
    for attr in self.SCALAR_ATTRS:
      setattr(self, attr, kwargs.pop(attr))
    for attr in self.LINEAR_ATTRS:
      setattr(self, attr, tuple(kwargs.pop(attr)))

  def __hash__(self):
    return hash((tuple(getattr(self, _) for _ in self.SCALAR_ATTRS),
                 tuple(tuple(getattr(self, _)) for _ in self.LINEAR_ATTRS)))

  def __eq__(self, other):
    # Structural equality over declared attributes; note the node's concrete
    # class is not compared, only its attributes.
    return all(hasattr(other, attr) and
               getattr(self, attr) == getattr(other, attr)
               for attr in self.ATTRS)

  @property
  def c_type(self):
    return util.get_c_type(self.haoda_type)

  @property
  def width_in_bits(self):
    return util.get_width_in_bits(self.haoda_type)

  def visit(self, callback, args=None, pre_recursion=None, post_recursion=None):
    """A general-purpose, flexible, and powerful visitor.

    The args parameter will be passed to the callback callable so that it may
    read or write any information from or to the caller.

    A copy of self will be made and passed to the callback to avoid destructive
    access.

    If a new object is returned by the callback, it will be returned directly
    without recursion.

    If the same object is returned by the callback, if any attribute is
    changed, it will not be recursively visited. If an attribute is unchanged,
    it will be recursively visited.
    """

    def callback_wrapper(callback, obj, args):
      # A None callback or a None return means "keep obj unchanged".
      if callback is None:
        return obj
      result = callback(obj, args)
      if result is not None:
        return result
      return obj

    self_copy = copy.copy(self)
    obj = callback_wrapper(callback, self_copy, args)
    if obj is not self_copy:
      # The callback replaced this node entirely; do not recurse into it.
      return obj
    self_copy = callback_wrapper(pre_recursion, copy.copy(self), args)
    # Recursively visit all child Nodes on a fresh copy of self.
    scalar_attrs = {attr: getattr(self_copy, attr).visit(
                              callback, args, pre_recursion, post_recursion)
                          if isinstance(getattr(self_copy, attr), Node)
                          else getattr(self_copy, attr)
                    for attr in self_copy.SCALAR_ATTRS}
    linear_attrs = {attr: tuple(_.visit(
                                    callback, args, pre_recursion, post_recursion)
                                if isinstance(_, Node) else _
                                for _ in getattr(self_copy, attr))
                    for attr in self_copy.LINEAR_ATTRS}

    # Splice recursed children back in, but only for attributes the callback
    # left untouched (identity-compared against the original).
    for attr in self.SCALAR_ATTRS:
      # old attribute may not exist in mutated object
      if not hasattr(obj, attr):
        continue
      if getattr(obj, attr) is getattr(self, attr):
        if isinstance(getattr(obj, attr), Node):
          setattr(obj, attr, scalar_attrs[attr])
    for attr in self.LINEAR_ATTRS:
      # old attribute may not exist in mutated object
      if not hasattr(obj, attr):
        continue
      setattr(obj, attr, tuple(
          c if a is b and isinstance(a, Node) else a
          for a, b, c in zip(getattr(obj, attr), getattr(self, attr),
                             linear_attrs[attr])))
    return callback_wrapper(post_recursion, obj, args)
156 |
class Let(Node):
  """A let binding: an optionally-typed name bound to an expression."""
  SCALAR_ATTRS = 'haoda_type', 'name', 'expr'

  def __str__(self):
    binding = '{} = {}'.format(self.name, unparenthesize(self.expr))
    if self.haoda_type is None:
      return binding
    return '{} {}'.format(self.haoda_type, binding)

  @property
  def haoda_type(self):
    # Fall back to the bound expression's type when no explicit type is given.
    if self._haoda_type is not None:
      return self._haoda_type
    return self.expr.haoda_type

  @haoda_type.setter
  def haoda_type(self, val):
    # Node.__init__ assigns through this setter via setattr().
    self._haoda_type = val

  @property
  def c_expr(self):
    return 'const {} {} = {};'.format(self.c_type, self.name,
                                      unparenthesize(self.expr.c_expr))
180 |
class Ref(Node):
  """An element reference to a tensor, e.g. ``foo(0, 1) ~2``."""
  SCALAR_ATTRS = 'name', 'lat'
  LINEAR_ATTRS = ('idx',)

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.idx = tuple(self.idx)
    # Freshly parsed refs have no type until propagation runs.
    if not hasattr(self, 'haoda_type'):
      self.haoda_type = None
    # self.lat will be defined in super().__init__(**kwargs)
    # pylint: disable=access-member-before-definition
    if isinstance(self.lat, str):
      # The parser may deliver the latency as a C-style literal string.
      self.lat = str2int(self.lat)

  def __str__(self):
    text = '{}({})'.format(self.name, ', '.join(str(i) for i in self.idx))
    if self.lat is not None:
      text += ' ~{}'.format(self.lat)
    return text
199 |
class BinaryOp(Node):
  """A chain of operands joined by same-precedence binary operators."""
  LINEAR_ATTRS = 'operand', 'operator'

  def __str__(self):
    pieces = [str(self.operand[0])]
    for op, rhs in zip(self.operator, self.operand[1:]):
      pieces.append('{} {}'.format(op, rhs))
    text = ' '.join(pieces)
    if self.singleton:
      return text
    return parenthesize(text)

  @property
  def haoda_type(self):
    # TODO: derive from all operands
    return self.operand[0].haoda_type

  @property
  def c_expr(self):
    pieces = [self.operand[0].c_expr]
    for op, rhs in zip(self.operator, self.operand[1:]):
      pieces.append('{} {}'.format(op, rhs.c_expr))
    text = ' '.join(pieces)
    if self.singleton:
      return text
    return parenthesize(text)

  @property
  def singleton(self) -> bool:
    # A single operand means no operator and thus no need for parentheses.
    return len(self.operand) == 1
227 |
# Concrete BinaryOp nodes, one per precedence tier of GRAMMAR (loosest binding
# first: '||', '&&', '|', '^', '&', ==/!=, </<=/>/>=, +/-, */ '/' /%). They add
# no behavior; the distinct classes preserve which grammar rule produced the
# node.
class Expr(BinaryOp):
  pass

class LogicAnd(BinaryOp):
  pass

class BinaryOr(BinaryOp):
  pass

class Xor(BinaryOp):
  pass

class BinaryAnd(BinaryOp):
  pass

class EqCmp(BinaryOp):
  pass

class LtCmp(BinaryOp):
  pass

class AddSub(BinaryOp):
  pass

class MulDiv(BinaryOp):
  pass
254 |
class Unary(Node):
  """A stack of unary operators applied to a single operand."""
  SCALAR_ATTRS = ('operand',)
  LINEAR_ATTRS = ('operator',)

  def __str__(self):
    prefix = ''.join(self.operator)
    return prefix + str(self.operand)

  @property
  def haoda_type(self):
    # Unary operators do not change the operand's type.
    return self.operand.haoda_type

  @property
  def c_expr(self):
    prefix = ''.join(self.operator)
    return prefix + self.operand.c_expr
268 |
class Operand(Node):
  """A leaf-level operand; exactly one of its attributes is non-None."""
  SCALAR_ATTRS = 'cast', 'call', 'ref', 'num', 'var', 'expr'
  def __str__(self):
    for attr in ('cast', 'call', 'ref', 'num', 'var'):
      if getattr(self, attr) is not None:
        return str(getattr(self, attr))
    # pylint: disable=useless-else-on-loop
    else:
      # None of the leaf attributes is set: this is a parenthesized Expr.
      return parenthesize(self.expr)

  @property
  def c_expr(self):
    for attr in ('cast', 'call', 'ref', 'num', 'var'):
      attr = getattr(self, attr)
      if attr is not None:
        # num is a plain string and has no c_expr; fall back to str().
        if hasattr(attr, 'c_expr'):
          return attr.c_expr
        return str(attr)
    # pylint: disable=useless-else-on-loop
    else:
      return parenthesize(self.expr.c_expr)

  @property
  def haoda_type(self):
    for attr in self.ATTRS:
      val = getattr(self, attr)
      if val is not None:
        if hasattr(val, 'haoda_type'):
          return val.haoda_type
        if attr == 'num':
          # val is the literal text; infer its type from C-style suffixes.
          if 'u' in val.lower():
            if 'll' in val.lower():
              return 'uint64'
            return 'uint32'
          if 'll' in val.lower():
            return 'int64'
          # NOTE(review): GRAMMAR's Float rule allows only a single [FfLl]
          # suffix, so 'fl' can never match a valid literal; presumably this
          # was meant to detect an 'l' (long double) suffix — verify.
          if 'fl' in val.lower():
            return 'double'
          if 'f' in val.lower() or 'e' in val.lower():
            return 'float'
          if '.' in val:
            return 'double'
          return 'int32'
        # Attribute set but untyped (e.g. an untyped ref/var).
        return None
    raise util.InternalError('undefined Operand')
314 |
class Cast(Node):
  """An explicit type conversion applied to an expression."""
  SCALAR_ATTRS = 'haoda_type', 'expr'

  def __str__(self):
    return '%s%s' % (self.haoda_type, parenthesize(self.expr))

  @property
  def c_expr(self):
    wrapped = parenthesize(self.expr.c_expr)
    return 'static_cast<{} >{}'.format(self.c_type, wrapped)
324 |
class Call(Node):
  """A function call with one or more arguments."""
  SCALAR_ATTRS = ('name',)
  LINEAR_ATTRS = ('arg',)

  def __str__(self):
    arg_list = ', '.join(str(arg) for arg in self.arg)
    return '{}({})'.format(self.name, arg_list)

  @property
  def haoda_type(self):
    # select(cond, a, b) yields the type of its value argument, not cond's.
    if self.name in ('select',):
      return self.arg[1].haoda_type
    return self.arg[0].haoda_type

  @property
  def c_expr(self):
    arg_list = ', '.join(arg.c_expr for arg in self.arg)
    return '{}({})'.format(self.name, arg_list)
340 |
class Var(Node):
  """A scalar or subscripted variable reference, e.g. ``i`` or ``a[0]``."""
  SCALAR_ATTRS = ('name',)
  LINEAR_ATTRS = ('idx',)

  def __str__(self):
    subscripts = ''.join('[{}]'.format(i) for i in self.idx)
    return self.name + subscripts

  @property
  def c_expr(self):
    subscripts = ''.join('[{}]'.format(i) for i in self.idx)
    return self.name + subscripts
350 |
class FIFO(Node):
  """A reference to another node in a haoda.ir.Expr.

  This is used to represent a read/write from/to a Module in an output's Expr.
  It replaces Ref in haoda.ir, which is used to represent an element
  reference to a tensor.

  Attributes:
    read_module: Module reading from this FIFO.
    read_lat: int, at what cycle of a pipelined loop it is being read.
    write_module: Module writing to this FIFO.
    write_lat: int, at what cycle of a pipelined loop it is being written.
    depth: int, FIFO depth.
  """
  IMMUTABLE_ATTRS = 'read_module', 'write_module'
  SCALAR_ATTRS = 'read_module', 'read_lat', 'write_module', 'write_lat', 'depth'

  def __init__(self, write_module, read_module,
               depth=None, write_lat=None, read_lat=None):
    super().__init__(write_module=write_module, read_module=read_module,
                     depth=depth, write_lat=write_lat, read_lat=read_lat)

  def __repr__(self):
    # Bug fix: depth defaults to None and %d raised TypeError in that case;
    # %s renders ints identically and is None-safe.
    return 'fifo[%s]: %s%s => %s%s' % (self.depth, repr(self.write_module),
        '' if self.write_lat is None else ' ~%s'%self.write_lat,
        repr(self.read_module),
        '' if self.read_lat is None else ' ~%s'%self.read_lat)

  def __hash__(self):
    # Identity is determined solely by the two endpoint modules so that
    # depth/latency can be updated without changing dict keys.
    return hash(tuple(getattr(self, _) for _ in self.IMMUTABLE_ATTRS))

  def __eq__(self, other):
    return all(getattr(self, _) == getattr(other, _)
               for _ in type(self).IMMUTABLE_ATTRS)

  @property
  def edge(self):
    # (source, destination) pair in the dataflow graph.
    return self.write_module, self.read_module

  @property
  def haoda_type(self):
    # The FIFO carries whatever its producer's expression computes; the
    # exprs lookup works because __hash__/__eq__ use only the endpoints.
    return self.write_module.exprs[self].haoda_type

  @property
  def c_expr(self):
    return 'from_{}_to_{}'.format(self.write_module.name, self.read_module.name)
396 |
class Module():
  """A node in the dataflow graph.

  This is the base class for a dataflow module. It defines the parent (input)
  nodes, children (output) nodes, output expressions, input schedules, and
  output schedules. It also has a name to help identify itself.

  Attributes:
    parents: Set of parent (input) Module.
    children: Set of child (output) Module.
    lets: List of haoda.ir.Let expressions.
    exprs: Dict of {FIFO: haoda.ir.Expr}, stores an output's expression.
  """
  def __init__(self):
    """Initializes attributes into empty list or dict.
    """
    self.parents = []
    self.children = []
    self.lets = []
    self.exprs = collections.OrderedDict()

  @property
  def name(self):
    # Module does not override __hash__, so hash(self) is id-based: unique
    # per live instance but not stable across runs.
    # NOTE(review): hash() may be negative, yielding e.g. 'module_-42';
    # confirm downstream consumers accept that.
    return 'module_%u' % hash(self)

  @property
  def fifos(self):
    # Output FIFOs in insertion order.
    return tuple(self.exprs.keys())

  @property
  def fifo_dict(self):
    # Maps (this module, destination module) to the connecting FIFO.
    return {(self, fifo.read_module): fifo for fifo in self.exprs}

  def fifo(self, dst_node):
    """Returns the FIFO connecting this module to dst_node."""
    return self.fifo_dict[(self, dst_node)]

  def get_latency(self, dst_node):
    """Returns the write latency towards dst_node, defaulting to 0."""
    return self.fifo(dst_node).write_lat or 0

  def visit_loads(self, callback, args=None):
    """Visits all load expressions (lets and output exprs).

    Operates on a shallow copy; self's lets/exprs are left untouched and the
    rewritten ones are attached to the returned copy.
    """
    obj = copy.copy(self)
    obj.lets = tuple(_.visit(callback, args) for _ in self.lets)
    obj.exprs = collections.OrderedDict()
    for fifo in self.exprs:
      obj.exprs[fifo] = self.exprs[fifo].visit(callback, args)
    return obj

  @property
  def dram_reads(self):
    return self._interfaces['dram_reads']

  @property
  def dram_writes(self):
    return self._interfaces['dram_writes']

  @property
  def input_fifos(self):
    return self._interfaces['input_fifos']

  @property
  def output_fifos(self):
    return self._interfaces['output_fifos']

  @cached_property.cached_property
  def _interfaces(self):
    # find dram reads
    reads_in_lets = tuple(_.expr for _ in self.lets)
    reads_in_exprs = tuple(self.exprs.values())
    # Deduplicate by (variable, bank) while preserving discovery order.
    dram_reads = collections.OrderedDict()
    for dram_ref in visitor.get_dram_refs(reads_in_lets + reads_in_exprs):
      for bank in dram_ref.dram:
        dram_reads[(dram_ref.var, bank)] = (dram_ref, bank)
    dram_reads = tuple(dram_reads.values())

    # find dram writes
    # a Let whose name is not a plain str is a write target (e.g. a DRAM ref)
    writes_in_lets = tuple(_.name for _ in self.lets
                           if not isinstance(_.name, str))
    dram_writes = collections.OrderedDict()
    for dram_ref in visitor.get_dram_refs(writes_in_lets):
      for bank in dram_ref.dram:
        dram_writes[(dram_ref.var, bank)] = (dram_ref, bank)
    dram_writes = tuple(dram_writes.values())

    output_fifos = tuple(_.c_expr for _ in self.exprs)
    input_fifos = tuple(_.c_expr for _ in visitor.get_read_fifo_set(self))


    return {
        'dram_writes' : dram_writes,
        'output_fifos' : output_fifos,
        'input_fifos' : input_fifos,
        'dram_reads' : dram_reads
    }

  def __str__(self):
    return '%s @ 0x%x: %s' % (type(self).__name__, id(self),
                              self.__dict__)

  def __repr__(self):
    return '%s @ 0x%x' % (type(self).__name__, id(self))

  def add_child(self, child):
    """Add a child (low level).

    This method only handles children and parents field; lets and exprs are
    not updated.

    Arguments:
      child: Module, child being added
    """
    if child not in self.children:
      self.children.append(child)
    if self not in child.parents:
      child.parents.append(self)

  def bfs_node_gen(self):
    """BFS over descendant nodes.

    This method is a BFS traversal generator over all descendant nodes.
    """
    node_queue = collections.deque([self])
    seen_nodes = {self}
    while node_queue:
      node = node_queue.popleft()
      yield node
      for child in node.children:
        if child not in seen_nodes:
          node_queue.append(child)
          seen_nodes.add(child)

  def dfs_node_gen(self):
    """DFS over descendant nodes.

    This method is a DFS traversal generator over all descendant nodes.
    """
    node_stack = [self]
    seen_nodes = {self}
    while node_stack:
      node = node_stack.pop()
      yield node
      for child in node.children:
        if child not in seen_nodes:
          node_stack.append(child)
          seen_nodes.add(child)

  def tpo_node_gen(self):
    """Traverse descendant nodes in topological order.

    This method is a generator that traverses all descendant nodes in
    topological order.
    """
    # Kahn's algorithm: repeatedly emit a node whose remaining in-degree is 0.
    # NOTE(review): in-degree counts *all* parents, including ones outside
    # this subgraph, so nodes reachable from elsewhere may never be emitted.
    nodes = collections.OrderedDict()
    for node in self.bfs_node_gen():
      nodes[node] = len(node.parents)
    while nodes:
      for node in nodes:
        if nodes[node] == 0:
          yield node
          for child in node.children:
            nodes[child] -= 1
          del nodes[node]
          break
      else:
        # No zero-in-degree node left: remaining nodes are cyclic or blocked.
        return

  def bfs_edge_gen(self):
    """BFS over descendant edges.

    This method is a BFS traversal generator over all descendant edges.
    """
    node_queue = collections.deque([self])
    seen_nodes = {self}
    while node_queue:
      node = node_queue.popleft()
      for child in node.children:
        yield node, child
        if child not in seen_nodes:
          node_queue.append(child)
          seen_nodes.add(child)

  def dfs_edge_gen(self):
    """DFS over descendant edges.

    This method is a DFS traversal generator over all descendant edges.
    """
    node_stack = [self]
    seen_nodes = {self}
    while node_stack:
      node = node_stack.pop()
      for child in node.children:
        yield node, child
        if child not in seen_nodes:
          node_stack.append(child)
          seen_nodes.add(child)

  def get_descendants(self):
    """Get all descendant nodes.

    This method returns all descendant nodes as a set.

    Returns:
      Set of descendant Module.
    """
    return {self}.union(*map(Module.get_descendants, self.children))

  def get_connections(self):
    """Get all descendant edges.

    This method returns all descendant edges as a set.

    Returns:
      Set of descendant (src Module, dst Module) tuple.
    """
    return ({(self, child) for child in self.children}
            .union(*map(Module.get_connections, self.children)))
612 |
613 |
class DelayedRef(Node):
  """A delayed FIFO reference.

  Realized in generated code as a ring buffer of `delay` elements with a
  single wrapping pointer (see the c_* properties).

  Attributes:
    delay: int
    ref: FIFO
  """
  SCALAR_ATTRS = ('delay', 'ref')
  @property
  def haoda_type(self):
    return self.ref.haoda_type

  def __str__(self):
    return '%s delayed %d' % (self.ref, self.delay)

  def __repr__(self):
    return str(self)

  def __hash__(self):
    return hash((self.delay, self.ref))

  def __eq__(self, other):
    return all(getattr(self, attr) == getattr(other, attr)
               for attr in ('delay', 'ref'))

  @property
  def buf_name(self):
    # Name of the ring-buffer array in generated code.
    return '{ref.c_expr}_delayed_{delay}_buf'.format(**self.__dict__)

  @property
  def ptr(self):
    # Name of the ring-buffer pointer variable.
    return '{ref.c_expr}_delayed_{delay}_ptr'.format(**self.__dict__)

  @property
  def ptr_type(self):
    # The pointer counts 0..delay-1, which needs floor(log2(delay))+1 bits.
    return 'uint%d' % int(math.log2(self.delay)+1)

  @property
  def c_expr(self):
    return '{ref.c_expr}_delayed_{delay}'.format(**self.__dict__)

  @property
  def c_ptr_type(self):
    return util.get_c_type(self.ptr_type)

  @property
  def c_ptr_decl(self):
    return '{} {} = 0;'.format(self.c_ptr_type, self.ptr)

  @property
  def c_buf_ref(self):
    return '{}[{}]'.format(self.buf_name, self.ptr)

  @property
  def c_buf_decl(self):
    return '{} {}[{}];'.format(self.c_type, self.buf_name, self.delay)

  @property
  def c_buf_load(self):
    # Read the delayed value out of the ring buffer.
    return '{} = {};'.format(self.c_expr, self.c_buf_ref)

  @property
  def c_buf_store(self):
    # Overwrite the just-read slot with the newest value.
    return '{} = {};'.format(self.c_buf_ref, self.ref.ref_name)

  @property
  def c_next_ptr_expr(self):
    # Advance the pointer, wrapping back to 0 after delay-1.
    return '{ptr} < {depth} ? {c_ptr_type}({ptr}+1) : {c_ptr_type}(0)'.format(
        ptr=self.ptr, c_ptr_type=self.c_ptr_type, depth=self.delay-1)
683 |
class FIFORef(Node):
  """A FIFO reference.

  Attributes:
    fifo: FIFO it is linked to
    lat: int, at what cycle of a pipelined loop it is being referenced.
    ref_id: int, reference id in the current scope
  Properties:
    c_type: str
    c_expr: str
    haoda_type: str
    ld_name: str
    st_name: str
    ref_name: str
  """
  # NOTE(review): st_name is listed above but no st_name property is defined
  # on this class; ST_PREFIX is used directly elsewhere (see
  # ModuleTrait._interfaces) — confirm whether the property was dropped.
  SCALAR_ATTRS = ('fifo', 'lat', 'ref_id')
  LD_PREFIX = 'fifo_ld_'
  ST_PREFIX = 'fifo_st_'
  REF_PREFIX = 'fifo_ref_'
  def __str__(self):
    return '<%s fifo_ref_%d%s>' % (self.haoda_type, self.ref_id,
                                   '@%s'%self.lat if self.lat else '')

  def __repr__(self):
    return str(self)

  def __hash__(self):
    # Identity is (lat, ref_id) only; the linked FIFO is not part of it.
    return hash((self.lat, self.ref_id))

  def __eq__(self, other):
    return all(getattr(self, attr) == getattr(other, attr)
               for attr in ('lat', 'ref_id'))

  @property
  def haoda_type(self):
    return self.fifo.haoda_type

  @property
  def ld_name(self):
    # ref_id comes from the instance dict, LD_PREFIX from the class dict.
    return '{LD_PREFIX}{ref_id}'.format(**self.__dict__, **type(self).__dict__)

  @property
  def ref_name(self):
    return '{REF_PREFIX}{ref_id}'.format(**self.__dict__, **type(self).__dict__)

  @property
  def c_expr(self):
    return self.ref_name
732 |
class DRAMRef(Node):
  """A DRAM reference.

  Attributes:
    haoda_type: str
    dram: [int], DRAM id it is accessing
    var: str, variable name it is accessing
    offset: int
  """
  SCALAR_ATTRS = 'haoda_type', 'dram', 'var', 'offset'
  def __str__(self):
    # Bug fix: the format string used to be the bare literal 'dram', which has
    # no placeholders — every DRAMRef (and thus c_expr) rendered as the same
    # string 'dram', losing bank/var/offset information.
    return 'dram<bank {} {}@{}>'.format(util.lst2str(self.dram),
                                        self.var, self.offset)

  def __repr__(self):
    return str(self)

  def __hash__(self):
    return hash((self.dram, self.offset))

  def __eq__(self, other):
    # Identity is (dram banks, offset) only; var/haoda_type are ignored.
    return all(getattr(self, attr) == getattr(other, attr)
               for attr in ('dram', 'offset'))

  @property
  def c_expr(self):
    return str(self)

  def dram_buf_name(self, bank):
    """Returns the generated-code buffer name for the given DRAM bank."""
    assert bank in self.dram, 'unexpected bank {}'.format(bank)
    return 'dram_{}_bank_{}_buf'.format(self.var, bank)

  def dram_fifo_name(self, bank):
    """Returns the generated-code FIFO name for the given DRAM bank."""
    assert bank in self.dram, 'unexpected bank {}'.format(bank)
    return 'dram_{}_bank_{}_fifo'.format(self.var, bank)
767 |
class ModuleTrait(Node):
  """A immutable, hashable trait of a dataflow module.

  Attributes:
    lets: tuple of lets
    exprs: tuple of exprs
    template_types: tuple of template types (TODO)
    template_ints: tuple of template ints (TODO)

  Properties:
    loads: tuple of FIFORefs
  """
  LINEAR_ATTRS = ('lets', 'exprs', 'template_types', 'template_ints')

  def __init__(self, node):
    # Rewrites every FIFO read in node into a FIFORef with a sequential
    # ref_id, so that structurally identical modules yield equal traits
    # regardless of which concrete FIFOs they read.
    def mutate(obj, loads):
      if isinstance(obj, FIFO):
        if loads:
          if obj not in loads:
            # Next id is one past the most recently assigned id.
            load_id = next(reversed(loads.values())).ref_id+1
          else:
            # Same FIFO read again: reuse its existing FIFORef.
            return loads[obj]
        else:
          load_id = 0
        fifo_ref = FIFORef(fifo=obj, lat=obj.read_lat, ref_id=load_id)
        loads[obj] = fifo_ref
        return fifo_ref
      return obj
    loads = collections.OrderedDict()
    node = node.visit_loads(mutate, loads)
    self.loads = tuple(loads.values())
    super().__init__(lets=tuple(node.lets), exprs=tuple(node.exprs.values()),
                     template_types=tuple(), template_ints=tuple())
    _logger.debug('Signature: %s', self)

  def __repr__(self):
    return '%s(loads: %s, lets: %s, exprs: %s)' % (
        type(self).__name__,
        util.idx2str(self.loads),
        util.idx2str(self.lets),
        util.idx2str(self.exprs))

  @property
  def dram_reads(self):
    return self._interfaces['dram_reads']

  @property
  def dram_writes(self):
    return self._interfaces['dram_writes']

  @property
  def input_fifos(self):
    return self._interfaces['input_fifos']

  @property
  def output_fifos(self):
    return self._interfaces['output_fifos']

  @cached_property.cached_property
  def _interfaces(self):
    # find dram reads
    reads_in_lets = tuple(_.expr for _ in self.lets)
    reads_in_exprs = tuple(self.exprs)
    # Deduplicate by (variable, bank) while preserving discovery order.
    dram_reads = collections.OrderedDict()
    for dram_ref in visitor.get_dram_refs(reads_in_lets + reads_in_exprs):
      for bank in dram_ref.dram:
        dram_reads[(dram_ref.var, bank)] = (dram_ref, bank)
    dram_reads = tuple(dram_reads.values())

    # find dram writes
    # a Let whose name is not a plain str is a write target (e.g. a DRAM ref)
    writes_in_lets = tuple(_.name for _ in self.lets
                           if not isinstance(_.name, str))
    dram_writes = collections.OrderedDict()
    for dram_ref in visitor.get_dram_refs(writes_in_lets):
      for bank in dram_ref.dram:
        dram_writes[(dram_ref.var, bank)] = (dram_ref, bank)
    dram_writes = tuple(dram_writes.values())

    # Unlike Module, trait-level FIFO names are positional (st/ld prefixes).
    output_fifos = tuple('{}{}'.format(FIFORef.ST_PREFIX, idx)
                         for idx, expr in enumerate(self.exprs))
    input_fifos = tuple(_.ld_name for _ in self.loads)

    return {
        'dram_writes' : dram_writes,
        'output_fifos' : output_fifos,
        'input_fifos' : input_fifos,
        'dram_reads' : dram_reads
    }
856 |
def make_var(val):
  """Makes a literal (index-free) Var named val."""
  return Var(idx=(), name=val)
860 |
def str2int(s, none_val=None):
  """Parses a C-style integer literal into a Python int.

  Accepts an optional sign; hex (0x/0X), binary (0b/0B), octal (leading 0),
  or decimal digits; and trailing U/L suffixes in any case.

  Args:
    s: Literal string to parse, or None.
    none_val: Value to return when s is None.

  Returns:
    The parsed int, or none_val if s is None.
  """
  if s is None:
    return none_val
  while s[-1] in 'UuLl':
    s = s[:-1]
  # Strip an optional sign first so radix prefixes of signed literals
  # (e.g. '-0x10' or '-010') are still recognized; previously such input
  # was mis-parsed or raised ValueError.
  sign = 1
  if s[0] in '+-':
    if s[0] == '-':
      sign = -1
    s = s[1:]
  if s[0:2] == '0x' or s[0:2] == '0X':
    return sign * int(s, 16)
  if s[0:2] == '0b' or s[0:2] == '0B':
    return sign * int(s, 2)
  if s[0] == '0':
    return sign * int(s, 8)
  return sign * int(s)
873 |
def _wrapped_in_matching_parens(expr_str: str) -> bool:
  """Returns True if expr_str is enclosed in one matching pair of parens."""
  if not (expr_str.startswith('(') and expr_str.endswith(')')):
    return False
  depth = 0
  for pos, char in enumerate(expr_str):
    if char == '(':
      depth += 1
    elif char == ')':
      depth -= 1
      if depth == 0:
        # The leading '(' wraps the whole string only if it closes at the end.
        return pos == len(expr_str) - 1
  return False

def parenthesize(expr) -> str:
  """Wraps str(expr) in exactly one pair of parentheses."""
  return '({})'.format(unparenthesize(expr))

def unparenthesize(expr) -> str:
  """Removes all redundant outer parentheses from str(expr).

  Bug fix: only strips a leading '(' and trailing ')' when they form a
  matching pair; previously '(a) + (b)' was mangled into 'a) + (b'.
  """
  expr_str = str(expr)
  while _wrapped_in_matching_parens(expr_str):
    expr_str = expr_str[1:-1]
  return expr_str
882 |
def get_result_type(operand1, operand2, operator):
  """Returns the first (widest) C type name matching either operand.

  Candidates are checked from double/float down through 64/32/16/8-bit
  signed-then-unsigned integers; operator is only used in the error message.

  Raises:
    util.SemanticError: If neither operand matches a known C type.
  """
  candidates = ['double', 'float']
  for width in (64, 32, 16, 8):
    candidates.append('int%d_t' % width)
    candidates.append('uint%d_t' % width)
  for candidate in candidates:
    if candidate in (operand1, operand2):
      return candidate
  raise util.SemanticError('cannot parse type: %s %s %s' %
                           (operand1, operator, operand2))
890 |
--------------------------------------------------------------------------------
/src/haoda/ir/arithmetic/__init__.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 |
4 | from haoda.ir.arithmetic import base
5 |
6 | _logger = logging.getLogger().getChild(__name__)
7 |
def simplify(expr):
  """Simplifies expressions.

  Args:
    expr: A haoda.ir.Node or a sequence of haoda.ir.Node.

  Returns:
    Simplified haoda.ir.Node or sequence.
  """
  # collections.Iterable was deprecated in Python 3.3 and removed in 3.10;
  # the abc submodule is imported locally to keep this fix self-contained.
  import collections.abc

  if expr is None:
    _logger.debug('None expr, no simplification.')
    return expr

  passes = base.compose(
      base.flatten,
      base.print_tree)

  if isinstance(expr, collections.abc.Iterable):
    # Preserve the container type (e.g. tuple in, tuple out).
    return type(expr)(map(passes, expr))

  return passes(expr)
30 |
--------------------------------------------------------------------------------
/src/haoda/ir/arithmetic/base.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 |
4 | from haoda import ir
5 | from haoda import util
6 |
7 | _logger = logging.getLogger().getChild(__name__)
8 |
def compose(*funcs):
  """Composes functions. The first function in funcs are invoked the first.
  """
  def composed(value):
    # Thread value through every function, left to right.
    for func in funcs:
      value = func(value)
    return value
  return composed
15 |
def flatten(node: ir.Node) -> ir.Node:
  """Flattens an node if possible.

  Flattens an node if it is:
  + a singleton BinaryOp; or
  + a compound BinaryOp with reduction operators; or
  + a compound Operand; or
  + a Unary with an identity operator.

  An Operand is a compound Operand if and only if its attr is a ir.Node.

  A Unary has identity operator if and only if all its operators are '+' or '-',
  and the number of '-' is even; or all of its operators are '!' and the number
  of '!' is even.

  Args:
    node: ir.Node to flatten.

  Returns:
    node: flattened ir.Node.

  Raises:
    util.InternalError: if Operand is undefined.
  """

  def visitor(node, args=None):
    if isinstance(node, ir.BinaryOp):

      # Flatten singleton BinaryOp
      if len(node.operand) == 1:
        return flatten(node.operand[0])

      # Flatten BinaryOp with reduction operators
      new_operator, new_operand = [], []
      # Pair each operand with the operator on its left; None for the first.
      for child_operator, child_operand in zip((None, *node.operator),
                                               node.operand):
        if child_operator is not None:
          new_operator.append(child_operator)
        # The first operator can always be flattened if two operations has the
        # same type.
        # Only associative/commutative operators are merged; '-' and '/'
        # children are kept nested (except in the leading position).
        if child_operator in (None, '||', '&&', *'|&+*') and \
            type(child_operand) is type(node):
          new_operator.extend(child_operand.operator)
          new_operand.extend(child_operand.operand)
        else:
          new_operand.append(child_operand)
      # At least 1 operand is flattened.
      if len(new_operand) > len(node.operand):
        return flatten(type(node)(operator=new_operator, operand=new_operand))

    # Flatten compound Operand
    if isinstance(node, ir.Operand):
      for attr in node.ATTRS:
        val = getattr(node, attr)
        if val is not None:
          # Replace the wrapper Operand with its sole ir.Node child.
          if isinstance(val, ir.Node):
            return flatten(val)
          break
      else:
        raise util.InternalError('undefined Operand')

    # Flatten identity unary operators
    if isinstance(node, ir.Unary):
      minus_count = node.operator.count('-')
      if minus_count % 2 == 0:
        plus_count = node.operator.count('+')
        # e.g. '+-+-x' reduces to 'x'.
        if plus_count + minus_count == len(node.operator):
          return flatten(node.operand)
      not_count = node.operator.count('!')
      # e.g. '!!x' reduces to 'x'.
      if not_count % 2 == 0 and not_count == len(node.operator):
        return flatten(node.operand)

    return node

  if not isinstance(node, ir.Node):
    return node

  return node.visit(visitor)
94 |
def print_tree(node, printer=_logger.debug):
  """Prints the node type as a tree.

  Args:
    node: ir.Node to print.
    printer: Unary callable used to emit each line; defaults to
      _logger.debug.

  Returns:
    node: Input ir.Node as-is.
  """

  def pre_recursion(node, args):
    # args[0] tracks the current tree depth for indentation.
    args[0] += 1

  def post_recursion(node, args):
    args[0] -= 1

  def visitor(node, args):
    printer('%s+-%s: %s' % (' ' * args[0], type(node).__name__, node))

  if not isinstance(node, ir.Node):
    return node

  printer('root')
  return node.visit(visitor, args=[1], pre_recursion=pre_recursion,
                    post_recursion=post_recursion)
121 |
def propagate_type(node, symbol_table):
  """Fills in missing haoda_type on Ref/Var nodes from symbol_table.

  Args:
    node: ir.Node to propagate types into.
    symbol_table: Mapping from name to haoda type.

  Returns:
    The visited ir.Node.
  """
  def fill_in(obj, table):
    if obj.haoda_type is None:
      # Only named references can be resolved from the symbol table.
      if isinstance(obj, (ir.Ref, ir.Var)):
        obj.haoda_type = table[obj.name]
    return obj
  return node.visit(fill_in, symbol_table)
129 |
--------------------------------------------------------------------------------
/src/haoda/ir/visitor.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | from haoda import ir
4 |
def get_dram_refs(obj):
  """Get all DRAM references as a tuple.

  Args:
    obj: A haoda.ir.Node object or an Iterable of haoda.ir.Node objects.

  Returns:
    A tuple of all DRAM references.

  Raises:
    TypeError: If obj is not an IR node or a sequence.
  """
  # collections.Iterable was deprecated in Python 3.3 and removed in 3.10;
  # the abc submodule is imported locally to keep this fix self-contained.
  import collections.abc

  def visitor(obj, args):
    # Collect DRAMRef nodes; returning obj leaves the tree unmodified.
    if isinstance(obj, ir.DRAMRef):
      args.append(obj)
    return obj
  if isinstance(obj, collections.abc.Iterable):
    return sum(map(get_dram_refs, obj), ())
  dram_refs = []
  if isinstance(obj, ir.Node):
    obj.visit(visitor, dram_refs)
  else:
    raise TypeError('argument is not an IR node or a sequence')
  return tuple(dram_refs)
29 |
def get_read_fifo_set(module):
  """Get all read FIFOs as a tuple. Each FIFO only appears once.

  Args:
    module: A haoda.ir.Module object.

  Returns:
    A tuple of all FIFOs that are read in the module.

  Raises:
    TypeError: If argument is not a module.
  """
  if not isinstance(module, ir.Module):
    raise TypeError('argument is not a module')

  def collect(obj, seen):
    # Record each FIFO once; OrderedDict keys preserve first-read order.
    if isinstance(obj, ir.FIFO):
      seen[obj] = None
    return obj

  fifo_loads = collections.OrderedDict()
  module.visit_loads(collect, fifo_loads)
  return tuple(fifo_loads)
52 |
--------------------------------------------------------------------------------
/src/haoda/util.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import logging
3 | import signal
4 |
# constants
COORDS_TILED = 'xyzw'    # coordinate names in the tiled space
COORDS_IN_TILE = 'ijkl'  # coordinate names within a tile
COORDS_IN_ORIG = 'pqrs'  # coordinate names in the original space
# Bit widths of the named floating-point types.
TYPE_WIDTH = {
    'float': 32,
    'double': 64,
    'half': 16
}
MAX_DRAM_BANK = 4

_logger = logging.getLogger().getChild(__name__)
17 |
class InternalError(Exception):
  """Raised when the code generator reaches an inconsistent internal state."""
  pass
20 |
class SemanticError(Exception):
  """Semantic error in the input program; not raised in this file (presumably
  raised by the front end) — verify at call sites."""
  pass
23 |
class SemanticWarn(Exception):
  """Semantic warning for the input program; not raised in this file
  (presumably used by the front end) — verify at call sites."""
  pass
26 |
class Printer():
  """Emits properly indented source code lines to an output stream.

  Tracks the current indentation level, a stack of comments attached to
  open scopes, and a counter used to generate unique variable names.
  """

  def __init__(self, out):
    # out: writable file-like object that receives the generated code.
    self._out = out
    self._indent = 0  # current indentation level
    self._assign = 0  # counter backing new_var()/last_var()
    self._comments = []  # stack of comments for open scopes
    self._tab = 2  # spaces per indentation level

  def println(self, line='', indent=-1):
    """Writes one line at the given indent level (current level if negative).

    An empty line is written without indentation.
    """
    if indent < 0:
      indent = self._indent
    if line:
      self._out.write('%s%s\n' % (' '*indent*self._tab, line))
    else:
      self._out.write('\n')

  def do_indent(self):
    """Increases the indent level by one."""
    self._indent += 1

  def un_indent(self):
    """Decreases the indent level by one."""
    self._indent -= 1

  def do_scope(self, comment=''):
    """Opens a '{' scope; comment is remembered for the matching un_scope."""
    self.println('{')
    self.do_indent()
    self._comments.append(comment)

  def un_scope(self, comment='', suffix=''):
    """Closes the current scope, annotating '}' with a trailing comment.

    An explicit comment overrides the one recorded by do_scope.
    """
    self.un_indent()
    popped_comment = self._comments.pop()
    if comment:
      self.println('}%s // %s' % (suffix, comment))
    else:
      if popped_comment:
        self.println('}%s // %s' % (suffix, popped_comment))
      else:
        self.println('}%s' % suffix)

  def new_var(self):
    """Allocates and returns a fresh 'assign_N' variable name."""
    self._assign += 1
    return self.last_var()

  def last_var(self, offset=-1):
    """Returns a previously allocated variable name; offset=-1 is the newest."""
    return 'assign_%d' % (self._assign+offset)

  def print_func(self, name, params, suffix='', align=80):
    """Prints 'name(param, ...)suffix', wrapping parameters at column align."""
    lines = [name+'(']
    for param in params:
      # Wrap before the line would exceed the align column, accounting for
      # the current indentation and the continuation indent.
      if ((self._indent + min(1, len(lines)-1))*self._tab+
          len(lines[-1])+len(param+', ')) > align:
        lines.append(param+', ')
      else:
        lines[-1] += param+', '
    if lines[-1][-2:] == ', ':
      lines[-1] = lines[-1][:-2]+')'+suffix
    line = lines.pop(0)
    self.println(line)
    if lines:
      self.do_indent()
      for line in lines:
        self.println(line)
      self.un_indent()

  @contextlib.contextmanager
  def for_(self, *args):
    """Context manager emitting a C++ for loop around the yielded body.

    Accepts 3 args (init; cond; step) or 2 args (range-based for).
    """
    if len(args) == 3:
      self.println('for ({}; {}; {}) {{'.format(*args))
    elif len(args) == 2:
      self.println('for ({} : {}) {{'.format(*args))
    else:
      raise InternalError('for_ takes 2 or 3 arguments')
    self.do_indent()
    yield
    self.un_indent()
    self.println('}')

  @contextlib.contextmanager
  def do_while(self, cond):
    """Context manager emitting a do-while loop around the yielded body."""
    self.println('do {')
    self.do_indent()
    yield
    self.un_indent()
    self.println('}} while ({});'.format(cond))

  @contextlib.contextmanager
  def if_(self, cond):
    """Context manager emitting an if statement around the yielded body."""
    self.println('if ({}) {{'.format(cond))
    self.do_indent()
    yield
    self.un_indent()
    self.println('}')

  @contextlib.contextmanager
  def elif_(self, cond):
    """Context manager emitting an '} else if (cond) {' branch."""
    self.un_indent()
    self.println('}} else if ({}) {{'.format(cond))
    self.do_indent()
    yield

  @contextlib.contextmanager
  def else_(self):
    """Context manager emitting an '} else {' branch."""
    self.un_indent()
    self.println('} else {')
    self.do_indent()
    yield
132 |
def print_define(printer, var, val):
  """Emits a guarded #define so an existing definition is not overridden."""
  for directive in ('#ifndef %s' % var,
                    '#define %s %d' % (var, val),
                    '#endif//%s' % var):
    printer.println(directive)
137 |
def print_guard(printer, var, val):
  """Emits preprocessor checks that fail the build if var is defined with a
  value other than val."""
  directives = ('#ifdef %s' % var,
                '#if %s != %d' % (var, val),
                '#error %s != %d' % (var, val),
                '#endif//%s != %d' % (var, val),
                '#endif//%s' % var)
  for directive in directives:
    printer.println(directive)
144 |
def get_c_type(haoda_type):
  """Maps a HAODA type name to its C/C++ (HLS) counterpart.

  Fixed-width integers map to stdint names, float32/float64 to float/double,
  other int/uint widths to ap_int/ap_uint, anything else passes through.
  """
  if haoda_type is None:
    return None
  stdint_types = {'uint8', 'uint16', 'uint32', 'uint64',
                  'int8', 'int16', 'int32', 'int64'}
  if haoda_type in stdint_types:
    return haoda_type + '_t'
  floating = {'float32': 'float', 'float64': 'double'}
  if haoda_type in floating:
    return floating[haoda_type]
  for token in ('int', 'uint'):
    if haoda_type.startswith(token):
      return 'ap_{}<{}>'.format(token, haoda_type.replace(token, ''))
  return haoda_type
160 |
def get_haoda_type(c_type):
  """Strips the '_t' suffix from a C type name, if present."""
  if c_type.endswith('_t'):
    return c_type[:-2]
  return c_type
163 |
def get_width_in_bits(haoda_type):
  """Returns the width of a HAODA type in bits.

  Args:
    haoda_type: Either a type name string (a key of TYPE_WIDTH, or a name
      such as 'uint32' / 'int8' / 'float64' whose width follows the prefix),
      or an object exposing a haoda_type attribute.

  Returns:
    int, the width of the type in bits.

  Raises:
    InternalError: If the width cannot be determined.
  """
  if isinstance(haoda_type, str):
    if haoda_type in TYPE_WIDTH:
      return TYPE_WIDTH[haoda_type]
    for prefix in 'uint', 'int', 'float':
      if haoda_type.startswith(prefix):
        # Slice off the prefix instead of str.lstrip: lstrip removes a
        # *character set*, not a prefix, and would silently mangle any
        # name whose width digits were missing or preceded by one of the
        # prefix's letters.
        return int(haoda_type[len(prefix):].split('_')[0])
  else:
    if hasattr(haoda_type, 'haoda_type'):
      return get_width_in_bits(haoda_type.haoda_type)
  raise InternalError('unknown haoda type: %s' % haoda_type)
175 |
def get_width_in_bytes(haoda_type):
  """Returns the width of a HAODA type in bytes, rounded up to whole bytes."""
  bits = get_width_in_bits(haoda_type)
  return (bits + 7) // 8
178 |
def is_float(haoda_type):
  """Returns True if the HAODA type is a floating-point type."""
  if haoda_type in ('half', 'double'):
    return True
  return haoda_type.startswith('float')
181 |
def idx2str(idx):
  """Formats an index sequence as a parenthesized list: '(i, j, ...)'."""
  return '({})'.format(', '.join(str(i) for i in idx))
184 |
def lst2str(idx):
  """Formats an index sequence as a bracketed list: '[i, j, ...]'."""
  return '[{}]'.format(', '.join(str(i) for i in idx))
187 |
def get_module_name(module_id):
  """Returns the lower-case instance name for the given module id."""
  return 'module_{}'.format(module_id)
190 |
def get_func_name(module_id):
  """Returns the CamelCase function name for the given module id."""
  return 'Module{}Func'.format(module_id)
193 |
# PEP 8 (E731): named functions instead of lambda assignments — better
# tracebacks and docstrings, same call interface.
def get_port_name(name, bank):
  """Returns the kernel port name for the given variable name and DRAM bank."""
  return 'bank_{}_{}'.format(bank, name)

def get_port_buf_name(name, bank):
  """Returns the name of the stream buffering the given port."""
  return 'bank_{}_{}_buf'.format(bank, name)
def get_bundle_name(name, bank):
  """Returns the m_axi bundle name for a type name and DRAM bank.

  Angle brackets in the type name are sanitized so that the result is a
  valid identifier.
  """
  sanitized = name.replace('<', '_').replace('>', '')
  return '{}_bank_{}'.format(sanitized, bank)
198 |
def pause_for_debugging():
  """Blocks until Ctrl-C when DEBUG logging is enabled.

  Gives the user a chance to attach a debugger or inspect intermediate
  files; a no-op unless the logger is at DEBUG level.
  """
  if not _logger.isEnabledFor(logging.DEBUG):
    return
  try:
    _logger.debug('pausing for debugging... send Ctrl-C to resume')
    signal.pause()
  except KeyboardInterrupt:
    pass
206 |
--------------------------------------------------------------------------------
/src/soda/codegen/xilinx/header.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from haoda import util
4 |
5 | logger = logging.getLogger().getChild(__name__)
6 |
def print_code(stencil, header_file):
  """Prints the Halide-style host header for the generated accelerator.

  Args:
    stencil: Stencil object; app_name and the input/output/param name lists
      are used to build the entry-point declaration.
    header_file: Writable text file object receiving the header code.
  """
  # Lazy %-style logging arguments instead of eager string interpolation.
  logger.info('generate host header code as %s', header_file.name)
  printer = util.Printer(header_file)
  println = printer.println
  do_indent = printer.do_indent
  un_indent = printer.un_indent
  println('#ifndef HALIDE_%s_H_' % stencil.app_name.upper())
  println('#define HALIDE_%s_H_' % stencil.app_name.upper())
  println()

  println('#ifndef HALIDE_ATTRIBUTE_ALIGN')
  do_indent()
  println('#ifdef _MSC_VER')
  do_indent()
  println('#define HALIDE_ATTRIBUTE_ALIGN(x) __declspec(align(x))')
  un_indent()
  println('#else')
  do_indent()
  println('#define HALIDE_ATTRIBUTE_ALIGN(x) __attribute__((aligned(x)))')
  un_indent()
  println('#endif')
  un_indent()
  println('#endif//HALIDE_ATTRIBUTE_ALIGN')
  println()

  println('#ifndef BUFFER_T_DEFINED')
  println('#define BUFFER_T_DEFINED')
  # NOTE(review): these two #include directives carry no header name; the
  # literals appear to have lost their '<...>' text (likely <stdint.h> and
  # <stdbool.h>) — verify against the upstream source.
  println('#include')
  println('#include')
  println('typedef struct buffer_t {')
  do_indent()
  println('uint64_t dev;')
  println('uint8_t* host;')
  println('int32_t extent[4];')
  println('int32_t stride[4];')
  println('int32_t min[4];')
  println('int32_t elem_size;')
  println('HALIDE_ATTRIBUTE_ALIGN(1) bool host_dirty;')
  println('HALIDE_ATTRIBUTE_ALIGN(1) bool dev_dirty;')
  println('HALIDE_ATTRIBUTE_ALIGN(1) uint8_t _padding[10 - sizeof(void *)];')
  un_indent()
  println('} buffer_t;')
  println('#endif//BUFFER_T_DEFINED')
  println()

  println('#ifndef HALIDE_FUNCTION_ATTRS')
  println('#define HALIDE_FUNCTION_ATTRS')
  println('#endif//HALIDE_FUNCTION_ATTRS')
  println()

  # Entry point takes one buffer_t per tensor plus the xclbin path.
  tensors = stencil.input_names + stencil.output_names + stencil.param_names
  println('int {}({}const char* xclbin) HALIDE_FUNCTION_ATTRS;'.format(
      stencil.app_name,
      ''.join(map('buffer_t *var_{}_buffer, '.format, tensors))))
  println()

  println('#endif//HALIDE_%s_H_' % stencil.app_name.upper())
  println()
65 |
--------------------------------------------------------------------------------
/src/soda/codegen/xilinx/hls_kernel.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import functools
3 | import logging
4 | import operator
5 |
6 | from haoda import ir
7 | from haoda import util
8 | from haoda.ir import visitor
9 |
10 | _logger = logging.getLogger().getChild(__name__)
11 |
def _print_interface(printer, kernel_name, inputs, outputs, super_source):
  """Prints the top-level module for the given arguments.

  Prints the top-level interfaces and sub-module instances with proper interface
  pragmas, hls::stream declarations and references, and module function calls.
  Currently only streaming applications are supported.

  Args:
    printer: Printer to which the code is emitted.
    kernel_name: str, name of the kernel.
    inputs: Sequence of (name, c_type, bank, depth) tuples, specifies the m_axi
      input interfaces.
    outputs: Sequence of (name, c_type, bank, depth) tuples, specifies the m_axi
      output interfaces.
    super_source: SuperSourceNode of a DAG of HAODA nodes.
  """
  println = printer.println
  do_indent = printer.do_indent
  un_indent = printer.un_indent
  do_scope = printer.do_scope
  un_scope = printer.un_scope

  get_bundle_name = util.get_bundle_name
  get_port_name = util.get_port_name
  get_port_buf_name = util.get_port_buf_name

  # kernel signature: one pointer per (name, bank) port plus the data count
  println('extern "C"')
  println('{')
  println()
  println('void %s(' % kernel_name)
  do_indent()
  for name, c_type, bank, _ in outputs + inputs:
    println('{}* {},'.format(c_type, get_port_name(name, bank)))
  println('uint64_t coalesced_data_num)')
  un_indent()
  do_scope()

  # one m_axi bundle per (name, bank) port
  for name, c_type, bank, depth in outputs + inputs:
    println('#pragma HLS interface m_axi port={} offset=slave depth={} bundle={'
            '}'.format(get_port_name(name, bank), depth,
                       get_bundle_name(name, bank)), 0)

  println()
  # control registers for every argument plus the return value
  for name, _, bank, _ in outputs + inputs:
    println('#pragma HLS interface s_axilite port={} bundle=control'.format(
        get_port_name(name, bank)), 0)

  println('#pragma HLS interface s_axilite port=coalesced_data_num '
          'bundle=control', 0)
  println('#pragma HLS interface s_axilite port=return bundle=control', 0)
  println()

  # port buf declarations
  for name, c_type, bank, _ in inputs + outputs:
    println('hls::stream> {name}("{name}");'.format(
        name=get_port_buf_name(name, bank), c_type=c_type))
    # port buf depths
    println('#pragma HLS stream variable={} depth=32'.format(
        get_port_buf_name(name, bank)), 0)
    println('#pragma HLS data_pack variable={}'.format(
        get_port_buf_name(name, bank)), indent=0)
  println()

  # internal fifos (tpo_node_gen presumably yields nodes in topological
  # order — TODO confirm)
  for node in super_source.tpo_node_gen():
    for fifo in node.fifos:
      println('hls::stream> {1}("{1}");'.format(fifo.c_type,
                                                fifo.c_expr))
      # depth: at least the FIFO's own depth, and at least 512 bits' worth
      # of elements of this type
      println('#pragma HLS stream variable={} depth={}'.format(
          fifo.c_expr,
          max(fifo.depth, 512 // util.get_width_in_bits(fifo.haoda_type))), 0)
      println('#pragma HLS data_pack variable={}'.format(fifo.c_expr),
              indent=0)

  println()

  # all module instances run concurrently under the dataflow pragma
  println('#pragma HLS dataflow', 0)
  for name, _, bank, _ in inputs:
    println('BurstRead(&{}, {}, coalesced_data_num);'.format(
        get_port_buf_name(name, bank), get_port_name(name, bank)))

  for node in super_source.tpo_node_gen():
    module_trait_id = super_source.module_table[node][1]
    _print_module_func_call(printer, node, module_trait_id)

  for name, _, bank, _ in outputs:
    println('BurstWrite({}, &{}, coalesced_data_num);'.format(
        get_port_name(name, bank), get_port_buf_name(name, bank)))

  un_scope()
  println()
  println('}//extern "C"')
104 |
105 | def print_header(printer):
106 | println = printer.println
107 | for header in ['float', 'math', 'stdbool', 'stddef', 'stdint', 'stdio',
108 | 'string', 'ap_int', 'hls_stream']:
109 | println('#include<%s.h>' % header)
110 | println()
111 |
112 | def _print_burst_read(printer):
113 | println = printer.println
114 | do_scope = printer.do_scope
115 | un_scope = printer.un_scope
116 | println('void BurstRead(hls::stream>>* to, ap_uint<'
117 | 'BURST_WIDTH>* from, uint64_t data_num)')
118 | do_scope()
119 | println('load_epoch:', 0)
120 | println('for (uint64_t epoch = 0; epoch < data_num;)')
121 | do_scope()
122 | println('#pragma HLS pipeline II=1', 0)
123 | println('const uint64_t next_epoch = epoch + 1;')
124 | println('WriteData(to, from[epoch], next_epoch < data_num);')
125 | println('epoch = next_epoch;')
126 | un_scope()
127 | un_scope()
128 |
def _print_burst_write(printer):
  # Emits the BurstWrite helper that drains elements from a stream and
  # writes them back to DRAM, one burst per loop epoch.
  # NOTE(review): the signature string below appears corrupted — the
  # literal on the first println line is unterminated and has likely lost
  # '<...>' text (cf. the intact BurstRead signature above); verify
  # against the upstream source before relying on this function.
  println = printer.println
  do_scope = printer.do_scope
  un_scope = printer.un_scope
  println('void BurstWrite(ap_uint* to, hls::stream>
          >* from, uint64_t data_num)')
  do_scope()
  println('store_epoch:', 0)
  println('for (uint64_t epoch = 0; epoch < data_num; ++epoch)')
  do_scope()
  println('#pragma HLS pipeline II=1', 0)
  println('ap_uint buf;')
  println('ReadData(&buf, from);')
  println('to[epoch] = buf;')
  un_scope()
  un_scope()
145 |
def print_code(stencil, output_file):
  """Prints the complete Vivado HLS C++ kernel for the given stencil.

  Emits headers, preprocessor defines/guards, the shared helper functions,
  one function per module trait, and finally the top-level interface.

  Args:
    stencil: Stencil object describing the SODA application.
    output_file: Writable text file object receiving the kernel code.
  """
  # Lazy %-style logging arguments instead of eager string interpolation.
  _logger.info('generate kernel code as %s', output_file.name)
  printer = util.Printer(output_file)

  print_header(printer)

  printer.println()

  util.print_define(printer, 'BURST_WIDTH', stencil.burst_width)
  printer.println()

  # guards fail the build if the pre-defined macros disagree with this run
  util.print_guard(printer, 'UNROLL_FACTOR', stencil.unroll_factor)
  for i in range(len(stencil.tile_size)-1):
    util.print_guard(printer, 'TILE_SIZE_DIM_%d' % i, stencil.tile_size[i])
  util.print_guard(printer, 'BURST_WIDTH', stencil.burst_width)
  printer.println()

  # shared helpers used by every module
  _print_data_struct(printer)
  _print_reinterpret(printer)
  _print_read_data(printer)
  _print_write_data(printer)

  _print_burst_read(printer)
  _print_burst_write(printer)

  for module_trait_id, module_trait in enumerate(stencil.module_traits):
    _print_module_definition(printer, module_trait, module_trait_id,
                             burst_width=stencil.burst_width)

  # collect (name, c_type, bank, depth) tuples for the top-level interface
  outputs = []
  inputs = []
  for stmt in stencil.output_stmts:
    for bank in sorted(stmt.dram):
      outputs.append((stmt.name, 'ap_uint<%d>' % stencil.burst_width, bank,
                      65536))
  for stmt in stencil.input_stmts:
    for bank in sorted(stmt.dram):
      inputs.append((stmt.name, 'ap_uint<%d>' % stencil.burst_width, bank,
                     65536))
  for stmt in stencil.param_stmts:
    inputs.append(('var_%s' % stmt.name, stmt.type, 0,
                   functools.reduce(operator.mul, stmt.size)))
  _print_interface(printer, stencil.app_name + '_kernel', inputs, outputs,
                   stencil.dataflow_super_source)
190 |
def _print_module_func_call(printer, node, module_trait_id, **kwargs):
  """Prints a call statement instantiating one module.

  Args:
    printer: Printer to which the code is emitted.
    node: HAODA dataflow node whose DRAM and FIFO connections become the
      call arguments.
    module_trait_id: int, id of the module trait (function) being called.
    **kwargs: Unused; accepted for interface compatibility.
  """
  # (removed unused local alias of printer.println)
  func_name = util.get_func_name(module_trait_id)

  # argument order mirrors _print_module_definition: stores/outputs first,
  # then inputs, with DRAM fifos at the outer ends
  dram_reads = tuple(
      '/* input*/ &' + util.get_port_buf_name(dram_ref.var, bank)
      for dram_ref, bank in node.dram_reads)
  dram_writes = tuple(
      '/*output*/ &' + util.get_port_buf_name(dram_ref.var, bank)
      for dram_ref, bank in node.dram_writes)
  output_fifos = tuple('/*output*/ &' + _ for _ in node.output_fifos)
  input_fifos = tuple('/* input*/ &' + _ for _ in node.input_fifos)
  params = dram_writes + output_fifos + input_fifos + dram_reads

  printer.print_func(func_name, params, suffix=';', align=0)
207 |
208 | # pylint: disable=too-many-branches,too-many-statements
def _print_module_definition(printer, module_trait, module_trait_id, **kwargs):
  """Prints the function definition for one module trait.

  Emits a free-running function that reads its input FIFOs (and, for an
  unpacking module, DRAM burst FIFOs), evaluates the trait's lets and
  exprs, and writes its output FIFOs (and, for a packing module, DRAM
  burst FIFOs).  A module may read DRAM or write DRAM, but not both.

  Args:
    printer: Printer to which the code is emitted.
    module_trait: the module trait to print; its loads, lets and exprs
      define the function body.
    module_trait_id: int, used to derive the function name.
    **kwargs: must contain burst_width (popped) when the trait accesses
      DRAM.
  """
  println = printer.println
  do_scope = printer.do_scope
  un_scope = printer.un_scope
  func_name = util.get_func_name(module_trait_id)
  func_lower_name = util.get_module_name(module_trait_id)
  # initiation interval of the pipelined loop; raised to the coalescing
  # factor when DRAM bursts are consumed over multiple cycles
  ii = 1

  def get_delays(obj, delays):
    # collect DelayedRef nodes (inter-iteration state)
    if isinstance(obj, ir.DelayedRef):
      delays.append(obj)
    return obj
  delays = []
  for let in module_trait.lets:
    let.visit(get_delays, delays)
  for expr in module_trait.exprs:
    expr.visit(get_delays, delays)
  _logger.debug('delays: %s', delays)

  fifo_loads = tuple('/* input*/ hls::stream>* {}'.format(
      _.c_type, _.ld_name) for _ in module_trait.loads)
  fifo_stores = tuple('/*output*/ hls::stream>* {}{}'.format(
      expr.c_type, ir.FIFORef.ST_PREFIX, idx)
                      for idx, expr in enumerate(module_trait.exprs))

  # look for DRAM access
  reads_in_lets = tuple(_.expr for _ in module_trait.lets)
  writes_in_lets = tuple(_.name for _ in module_trait.lets
                         if not isinstance(_.name, str))
  reads_in_exprs = module_trait.exprs
  dram_reads = visitor.get_dram_refs(reads_in_lets + reads_in_exprs)
  dram_writes = visitor.get_dram_refs(writes_in_lets)
  dram_read_map = collections.OrderedDict()
  dram_write_map = collections.OrderedDict()
  all_dram_reads = ()
  num_bank_map = {}
  if dram_reads: # this is an unpacking module
    assert not dram_writes, 'cannot read and write DRAM in the same module'
    # group DRAM reads as var -> bank tuple -> refs ordered by offset
    for dram_read in dram_reads:
      dram_read_map.setdefault(dram_read.var,
                               collections.OrderedDict()).setdefault(
                                   dram_read.dram, []).append(dram_read)
    _logger.debug('dram read map: %s', dram_read_map)
    burst_width = kwargs.pop('burst_width')
    for var in dram_read_map:
      for dram in dram_read_map[var]:
        # number of elements per cycle
        batch_size = len(dram_read_map[var][dram])
        dram_read_map[var][dram] = collections.OrderedDict(
            (_.offset, _) for _ in dram_read_map[var][dram])
        dram_reads = dram_read_map[var][dram]
        num_banks = len(next(iter(dram_reads.values())).dram)
        if var in num_bank_map:
          assert num_bank_map[var] == num_banks, 'inconsistent num banks'
        else:
          num_bank_map[var] = num_banks
        _logger.debug('dram reads: %s', dram_reads)
        # offsets must form a contiguous 0..batch_size-1 range
        assert tuple(sorted(dram_reads.keys())) == tuple(range(batch_size)), \
            'unexpected DRAM accesses pattern %s' % dram_reads
        batch_width = sum(util.get_width_in_bits(_.haoda_type)
                          for _ in dram_reads.values())
        del dram_reads
        if burst_width * num_banks >= batch_width:
          assert burst_width * num_banks % batch_width == 0, \
              'cannot process such a burst'
          # a single burst consumed in multiple cycles
          coalescing_factor = burst_width * num_banks // batch_width
          ii = coalescing_factor
        else:
          assert batch_width * num_banks % burst_width == 0, \
              'cannot process such a burst'
          # multiple bursts consumed in a single cycle
          # reassemble_factor = batch_width // (burst_width * num_banks)
          raise util.InternalError('cannot process such a burst yet')
      dram_reads = tuple(next(iter(_.values()))
                         for _ in dram_read_map[var].values())
      all_dram_reads += dram_reads
      fifo_loads += tuple(
          '/* input*/ hls::stream>>* '
          '{bank_name}'.format(
              burst_width=burst_width, bank_name=_.dram_fifo_name(bank))
          for _ in dram_reads for bank in _.dram)
  elif dram_writes: # this is a packing module
    # group DRAM writes as var -> bank tuple -> refs ordered by offset
    for dram_write in dram_writes:
      dram_write_map.setdefault(dram_write.var,
                                collections.OrderedDict()).setdefault(
                                    dram_write.dram, []).append(dram_write)
    _logger.debug('dram write map: %s', dram_write_map)
    burst_width = kwargs.pop('burst_width')
    for var in dram_write_map:
      for dram in dram_write_map[var]:
        # number of elements per cycle
        batch_size = len(dram_write_map[var][dram])
        dram_write_map[var][dram] = collections.OrderedDict(
            (_.offset, _) for _ in dram_write_map[var][dram])
        dram_writes = dram_write_map[var][dram]
        num_banks = len(next(iter(dram_writes.values())).dram)
        if var in num_bank_map:
          assert num_bank_map[var] == num_banks, 'inconsistent num banks'
        else:
          num_bank_map[var] = num_banks
        _logger.debug('dram writes: %s', dram_writes)
        assert tuple(sorted(dram_writes.keys())) == tuple(range(batch_size)), \
            'unexpected DRAM accesses pattern %s' % dram_writes
        batch_width = sum(util.get_width_in_bits(_.haoda_type)
                          for _ in dram_writes.values())
        del dram_writes
        if burst_width * num_banks >= batch_width:
          assert burst_width * num_banks % batch_width == 0, \
              'cannot process such a burst'
          # a single burst consumed in multiple cycles
          coalescing_factor = burst_width * num_banks // batch_width
          ii = coalescing_factor
        else:
          assert batch_width * num_banks % burst_width == 0, \
              'cannot process such a burst'
          # multiple bursts consumed in a single cycle
          # reassemble_factor = batch_width // (burst_width * num_banks)
          raise util.InternalError('cannot process such a burst yet')
      dram_writes = tuple(next(iter(_.values()))
                          for _ in dram_write_map[var].values())
      fifo_stores += tuple(
          '/*output*/ hls::stream>>* '
          '{bank_name}'.format(
              burst_width=burst_width, bank_name=_.dram_fifo_name(bank))
          for _ in dram_writes for bank in _.dram)

  # print function
  printer.print_func('void {func_name}'.format(**locals()),
                     fifo_stores+fifo_loads, align=0)
  do_scope(func_name)

  # data_pack pragmas for every stream argument
  for dram_ref, bank in module_trait.dram_writes:
    println('#pragma HLS data_pack variable = {}'.format(
        dram_ref.dram_fifo_name(bank)), 0)
  for arg in module_trait.output_fifos:
    println('#pragma HLS data_pack variable = %s' % arg, 0)
  for arg in module_trait.input_fifos:
    println('#pragma HLS data_pack variable = %s' % arg, 0)
  for dram_ref, bank in module_trait.dram_reads:
    println('#pragma HLS data_pack variable = {}'.format(
        dram_ref.dram_fifo_name(bank)), 0)

  # print inter-iteration declarations
  for delay in delays:
    println(delay.c_buf_decl)
    println(delay.c_ptr_decl)

  # print loop
  println('{}_epoch:'.format(func_lower_name), indent=0)
  println('for (bool enable = true; enable;)')
  do_scope('for {}_epoch'.format(func_lower_name))
  println('#pragma HLS pipeline II=%d' % ii, 0)
  for delay in delays:
    println('#pragma HLS dependence variable=%s inter false' %
            delay.buf_name, 0)

  # print emptyness tests
  println('if (%s)' % (' && '.join(
      '!{fifo}->empty()'.format(fifo=fifo)
      for fifo in tuple(_.ld_name for _ in module_trait.loads) +
                  tuple(_.dram_fifo_name(bank)
                        for _ in all_dram_reads for bank in _.dram))))
  do_scope('if not empty')

  # print intra-iteration declarations
  for fifo_in in module_trait.loads:
    println('{fifo_in.c_type} {fifo_in.ref_name};'.format(**locals()))
  for var in dram_read_map:
    for dram in (next(iter(_.values())) for _ in dram_read_map[var].values()):
      for bank in dram.dram:
        println('ap_uint<{}> {};'.format(burst_width, dram.dram_buf_name(bank)))
  for var in dram_write_map:
    for dram in (next(iter(_.values())) for _ in dram_write_map[var].values()):
      for bank in dram.dram:
        println('ap_uint<{}> {};'.format(burst_width, dram.dram_buf_name(bank)))

  # print enable conditions
  if not dram_write_map:
    for fifo_in in module_trait.loads:
      println('const bool {fifo_in.ref_name}_enable = '
              'ReadData(&{fifo_in.ref_name}, {fifo_in.ld_name});'.format(**locals()))
    for dram in all_dram_reads:
      for bank in dram.dram:
        println('const bool {dram_buf_name}_enable = '
                'ReadData(&{dram_buf_name}, {dram_fifo_name});'.format(
                    dram_buf_name=dram.dram_buf_name(bank),
                    dram_fifo_name=dram.dram_fifo_name(bank)))
  if not dram_write_map:
    println('const bool enabled = %s;' % (
        ' && '.join(tuple('{_.ref_name}_enable'.format(_=_)
                          for _ in module_trait.loads) +
                    tuple('{}_enable'.format(_.dram_buf_name(bank))
                          for _ in all_dram_reads for bank in _.dram))))
    println('enable = enabled;')

  # print delays (if any)
  for delay in delays:
    println('const {} {};'.format(delay.c_type, delay.c_buf_load))

  # print lets
  def mutate_dram_ref_for_writes(obj, kwargs):
    # replace a DRAMRef with a bit-slice of the coalescing buffer
    if isinstance(obj, ir.DRAMRef):
      coalescing_idx = kwargs.pop('coalescing_idx')
      unroll_factor = kwargs.pop('unroll_factor')
      type_width = util.get_width_in_bits(obj.haoda_type)
      elem_idx = coalescing_idx * unroll_factor + obj.offset
      num_banks = num_bank_map[obj.var]
      bank = obj.dram[elem_idx % num_banks]
      lsb = (elem_idx // num_banks) * type_width
      msb = lsb + type_width - 1
      return ir.Var(name='{}({msb}, {lsb})'.format(
          obj.dram_buf_name(bank), msb=msb, lsb=lsb), idx=())
    return obj

  # mutate dram ref for writes
  if dram_write_map:
    for coalescing_idx in range(coalescing_factor):
      for fifo_in in module_trait.loads:
        if coalescing_idx == coalescing_factor - 1:
          prefix = 'const bool {fifo_in.ref_name}_enable = '.format(
              fifo_in=fifo_in)
        else:
          prefix = ''
        println('{prefix}ReadData(&{fifo_in.ref_name},'
                ' {fifo_in.ld_name});'.format(fifo_in=fifo_in, prefix=prefix))
      if coalescing_idx == coalescing_factor - 1:
        println('const bool enabled = %s;' % (
            ' && '.join(tuple('{_.ref_name}_enable'.format(_=_)
                              for _ in module_trait.loads) +
                        tuple('{}_enable'.format(_.dram_buf_name(bank))
                              for _ in dram_reads for bank in _.dram))))
        println('enable = enabled;')
      for idx, let in enumerate(module_trait.lets):
        let = let.visit(mutate_dram_ref_for_writes, {
            'coalescing_idx': coalescing_idx, 'unroll_factor': len(
                dram_write_map[let.name.var][let.name.dram])})
        println('{} = Reinterpret>({});'.format(
            let.name, let.expr.c_expr,
            width=util.get_width_in_bits(let.expr.haoda_type)))
    for var in dram_write_map:
      for dram in (next(iter(_.values()))
                   for _ in dram_write_map[var].values()):
        for bank in dram.dram:
          println('WriteData({}, {}, enabled);'.format(
              dram.dram_fifo_name(bank), dram.dram_buf_name(bank)))
  else:
    for let in module_trait.lets:
      println(let.c_expr)

  def mutate_dram_ref_for_reads(obj, kwargs):
    # replace a DRAMRef with a Reinterpret of the coalescing buffer slice
    if isinstance(obj, ir.DRAMRef):
      coalescing_idx = kwargs.pop('coalescing_idx')
      unroll_factor = kwargs.pop('unroll_factor')
      type_width = util.get_width_in_bits(obj.haoda_type)
      elem_idx = coalescing_idx * unroll_factor + obj.offset
      num_banks = num_bank_map[obj.var]
      # NOTE(review): uses 'expr' from the enclosing loop rather than 'obj';
      # the writes-mutator above uses obj.dram — verify this is intentional.
      bank = expr.dram[elem_idx % num_banks]
      lsb = (elem_idx // num_banks) * type_width
      msb = lsb + type_width - 1
      return ir.Var(
          name='Reinterpret<{c_type}>(static_cast>('
               '{dram_buf_name}({msb}, {lsb})))'.format(
                   c_type=obj.c_type, dram_buf_name=obj.dram_buf_name(bank),
                   msb=msb, lsb=lsb, width=msb-lsb+1), idx=())
    return obj

  # mutate dram ref for reads
  if dram_read_map:
    for coalescing_idx in range(coalescing_factor):
      for idx, expr in enumerate(module_trait.exprs):
        println('WriteData({}{}, {}, {});'.format(
            ir.FIFORef.ST_PREFIX, idx,
            expr.visit(mutate_dram_ref_for_reads, {
                'coalescing_idx': coalescing_idx, 'unroll_factor': len(
                    dram_read_map[expr.var][expr.dram])}).c_expr,
            'true' if coalescing_idx < coalescing_factor - 1 else 'enabled'))
  else:
    for idx, expr in enumerate(module_trait.exprs):
      println('WriteData({}{}, {}({}), enabled);'.format(
          ir.FIFORef.ST_PREFIX, idx, expr.c_type, expr.c_expr))

  # advance inter-iteration state
  for delay in delays:
    println(delay.c_buf_store)
    println('{} = {};'.format(delay.ptr, delay.c_next_ptr_expr))

  un_scope()
  un_scope()
  un_scope()
  _logger.debug('printing: %s', module_trait)
499 |
500 | def _print_data_struct(printer):
501 | println = printer.println
502 | println('template struct Data')
503 | printer.do_scope()
504 | println('T data;')
505 | println('bool ctrl;')
506 | printer.un_scope(suffix=';')
507 |
508 | def _print_reinterpret(printer):
509 | println = printer.println
510 | println('template')
511 | println('inline To Reinterpret(const From& val)')
512 | printer.do_scope()
513 | println('return reinterpret_cast(val);')
514 | printer.un_scope()
515 |
516 | def _print_read_data(printer):
517 | println = printer.println
518 | println('template inline bool ReadData'
519 | '(T* data, hls::stream>* from)')
520 | printer.do_scope()
521 | println('#pragma HLS inline', indent=0)
522 | println('const Data& tmp = from->read();')
523 | println('*data = tmp.data;')
524 | println('return tmp.ctrl;')
525 | printer.un_scope()
526 |
527 | def _print_write_data(printer):
528 | println = printer.println
529 | println('template inline void WriteData'
530 | '(hls::stream>* to, const T& data, bool ctrl)')
531 | printer.do_scope()
532 | println('#pragma HLS inline', indent=0)
533 | println('Data tmp;')
534 | println('tmp.data = data;')
535 | println('tmp.ctrl = ctrl;')
536 | println('to->write(tmp);')
537 | printer.un_scope()
538 |
--------------------------------------------------------------------------------
/src/soda/codegen/xilinx/opencl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import sys
4 | import tempfile
5 |
6 | from soda.codegen.xilinx import header
7 | from soda.codegen.xilinx import host
8 | from soda.codegen.xilinx import hls_kernel as kernel
9 | from soda.codegen.xilinx import rtl_kernel
10 |
def add_arguments(parser):
  """Registers the Xilinx OpenCL code-generation options on parser.

  Args:
    parser: argparse.ArgumentParser (or compatible) to extend.
  """
  parser.add_argument(
      '--xocl', type=str, dest='output_dir', metavar='dir', nargs='?',
      const='',
      # fixed missing space: 'are' and 'used' were fused across the
      # implicit string concatenation
      help='directory to generate kernel and host code; default names are '
      'used; default to the current working directory; may be overridden by '
      '--xocl-header, --xocl-host, or --xocl-kernel')
  # the three per-file overrides share the same shape
  for flag, dest, help_text in (
      ('--xocl-header', 'header_file',
       'host C++ header code; overrides --xocl'),
      ('--xocl-host', 'host_file',
       'host C++ source code for the Xilinx OpenCL flow; overrides --xocl'),
      ('--xocl-kernel', 'kernel_file',
       'Vivado HLS C++ kernel code for the Xilinx OpenCL flow; overrides '
       '--xocl')):
    parser.add_argument(flag, type=str, dest=dest, metavar='file',
                        help=help_text)
  parser.add_argument(
      '--xocl-platform', type=str, dest='xocl_platform', metavar='dir',
      help='SDAccel platform directory of the Xilinx OpenCL flow')
  parser.add_argument('--xocl-hw-xo', type=str, dest='xo_file',
                      metavar='file',
                      help='hardware object file for the Xilinx OpenCL flow')
32 |
def print_code(stencil, args):
  """Generates kernel, host, header, and xo files for the Xilinx OpenCL flow.

  Each requested file is first rendered into a temporary file so nothing is
  left behind if generation fails, then copied to its destination; a file
  name of '-' sends the result to stdout.

  Args:
    stencil: Stencil object describing the SODA application.
    args: Parsed command-line arguments; uses kernel_file, host_file,
      header_file, xo_file, xocl_platform, output_dir, dram_in, dram_out.
  """
  def dump(print_func, file_name, binary=False):
    # Render through a temporary file, then copy to the destination.
    with tempfile.TemporaryFile(mode='w+b' if binary else 'w+') as tmp:
      print_func(tmp)
      tmp.seek(0)
      if file_name == '-':
        # binary payloads must bypass the text wrapper of stdout
        # (fixes copying a binary temp file into the text-mode sys.stdout)
        shutil.copyfileobj(tmp, sys.stdout.buffer if binary else sys.stdout)
      else:
        with open(file_name, 'wb' if binary else 'w') as out_file:
          shutil.copyfileobj(tmp, out_file)

  if args.kernel_file is not None:
    dump(lambda tmp: kernel.print_code(stencil, tmp), args.kernel_file)

  if args.host_file is not None:
    dump(lambda tmp: host.print_code(stencil, tmp), args.host_file)

  if args.header_file is not None:
    dump(lambda tmp: header.print_code(stencil, tmp), args.header_file)

  if args.xo_file is not None:
    dump(lambda tmp: rtl_kernel.print_code(stencil, tmp,
                                           platform=args.xocl_platform),
         args.xo_file, binary=True)

  # --xocl: generate any file not explicitly requested above into
  # output_dir under a default name (explicitly requested files are
  # regenerated to their given names, matching the original behavior)
  if args.output_dir is not None and (args.kernel_file is None or
                                      args.host_file is None or
                                      args.header_file is None):
    if args.kernel_file is None:
      dram_in = args.dram_in if args.dram_in else '_'
      dram_out = args.dram_out if args.dram_out else '_'
      kernel_file_name = os.path.join(
          args.output_dir, '%s_kernel-tile%s-unroll%d-ddr%s.cpp' % (
              stencil.app_name,
              'x'.join('%d'%x for x in stencil.tile_size[:-1]),
              stencil.unroll_factor, dram_in + '-' + dram_out))
    else:
      kernel_file_name = args.kernel_file
    dump(lambda tmp: kernel.print_code(stencil, tmp), kernel_file_name)
    if args.host_file is None:
      host_file_name = os.path.join(args.output_dir, stencil.app_name + '.cpp')
    else:
      host_file_name = args.host_file
    # fixed: the host code used to be copied into kernel_file (a stale or
    # undefined handle) instead of the host file
    dump(lambda tmp: host.print_code(stencil, tmp), host_file_name)
    if args.header_file is None:
      header_file_name = os.path.join(args.output_dir, stencil.app_name + '.h')
    else:
      header_file_name = args.header_file
    dump(lambda tmp: header.print_code(stencil, tmp), header_file_name)
110 |
--------------------------------------------------------------------------------
/src/soda/codegen/xilinx/rtl_kernel.py:
--------------------------------------------------------------------------------
import collections
import concurrent
import concurrent.futures
import logging
import os
import shutil
import sys
import tarfile
import tempfile

from haoda import util
from haoda.backend import xilinx as backend
from soda.codegen.xilinx import hls_kernel
13 |
14 | _logger = logging.getLogger().getChild(__name__)
15 |
def print_code(stencil, xo_file, platform=None, jobs=os.cpu_count()):
  """Generate hardware object file for the given Stencil.

  Working `vivado` and `vivado_hls` is required in the PATH.

  The flow is: emit HLS C++ for the dataflow interface and for every module
  trait, synthesize each of them in parallel with Vivado HLS, stitch the
  resulting HDL together with a generated top-level `Dataflow.v`, and
  package everything into an .xo archive that is copied to `xo_file`.

  Args:
    stencil: Stencil object to generate from.
    xo_file: file object to write to.
    platform: path to the SDAccel platform directory.
    jobs: maximum number of jobs running in parallel.
  """

  # Collect the m_axi bundle names plus the (port, bundle, type, buffer)
  # tuples for all DRAM banks of the output and input statements.
  m_axi_names = []
  m_axi_bundles = []
  inputs = []
  outputs = []
  for stmt in stencil.output_stmts + stencil.input_stmts:
    for bank in stmt.dram:
      haoda_type = 'uint%d' % stencil.burst_width
      bundle_name = util.get_bundle_name(stmt.name, bank)
      m_axi_names.append(bundle_name)
      m_axi_bundles.append((bundle_name, haoda_type))

  for stmt in stencil.output_stmts:
    for bank in stmt.dram:
      haoda_type = 'uint%d' % stencil.burst_width
      bundle_name = util.get_bundle_name(stmt.name, bank)
      outputs.append((util.get_port_name(stmt.name, bank), bundle_name,
                      haoda_type, util.get_port_buf_name(stmt.name, bank)))
  for stmt in stencil.input_stmts:
    for bank in stmt.dram:
      haoda_type = 'uint%d' % stencil.burst_width
      bundle_name = util.get_bundle_name(stmt.name, bank)
      inputs.append((util.get_port_name(stmt.name, bank), bundle_name,
                     haoda_type, util.get_port_buf_name(stmt.name, bank)))

  top_name = stencil.app_name + '_kernel'

  # Resolve the SDAccel platform directory: explicit argument first, then
  # /opt/xilinx/platforms/$XDEVICE, then $XILINX_SDX/platforms/$XDEVICE.
  if 'XDEVICE' in os.environ:
    xdevice = os.environ['XDEVICE'].replace(':', '_').replace('.', '_')
    if platform is None or not os.path.exists(platform):
      platform = os.path.join('/opt/xilinx/platforms', xdevice)
    if platform is None or not os.path.exists(platform):
      if 'XILINX_SDX' in os.environ:
        platform = os.path.join(os.environ['XILINX_SDX'], 'platforms', xdevice)
  if platform is None or not os.path.exists(platform):
    raise ValueError('Cannot determine platform from environment.')
  device_info = backend.get_device_info(platform)

  with tempfile.TemporaryDirectory(prefix='sodac-xrtl-') as tmpdir:
    # HLS C++ stub exposing the top-level dataflow interface.
    dataflow_kernel = os.path.join(tmpdir, 'dataflow_kernel.cpp')
    with open(dataflow_kernel, 'w') as dataflow_kernel_obj:
      print_dataflow_hls_interface(
          util.Printer(dataflow_kernel_obj), top_name, inputs, outputs)

    kernel_xml = os.path.join(tmpdir, 'kernel.xml')
    with open(kernel_xml, 'w') as kernel_xml_obj:
      backend.print_kernel_xml(top_name, outputs + inputs, kernel_xml_obj)

    # HLS C++ for all module traits of the stencil.
    kernel_file = os.path.join(tmpdir, 'kernel.cpp')
    with open(kernel_file, 'w') as kernel_fileobj:
      hls_kernel.print_code(stencil, kernel_fileobj)

    # Synthesize every module trait plus the top-level interface in
    # parallel; each job unpacks its HDL output into tmpdir/hdl.
    super_source = stencil.dataflow_super_source
    with concurrent.futures.ThreadPoolExecutor(max_workers=jobs) as executor:
      threads = []
      for module_id in range(len(super_source.module_traits)):
        threads.append(executor.submit(
            synthesis_module, tmpdir, [kernel_file],
            util.get_func_name(module_id), device_info))
      threads.append(executor.submit(
          synthesis_module, tmpdir, [dataflow_kernel], top_name, device_info))
      for future in concurrent.futures.as_completed(threads):
        returncode, stdout, stderr = future.result()
        # Failures are logged at error level; success output at debug.
        log_func = _logger.error if returncode != 0 else _logger.debug
        if stdout:
          log_func(stdout.decode())
        if stderr:
          log_func(stderr.decode())
        if returncode != 0:
          util.pause_for_debugging()
          sys.exit(returncode)

    # NOTE(review): assumes at least one synthesis job succeeded and
    # created tmpdir/hdl (see synthesis_module's tar extraction).
    hdl_dir = os.path.join(tmpdir, 'hdl')
    with open(os.path.join(hdl_dir, 'Dataflow.v'), mode='w') as dataflow_v:
      print_top_module(backend.VerilogPrinter(dataflow_v),
                       stencil.dataflow_super_source, inputs, outputs)

    util.pause_for_debugging()

    # Package the HDL into an .xo and stream it to the caller's file object.
    xo_filename = os.path.join(tmpdir, stencil.app_name + '.xo')
    with backend.PackageXo(xo_filename, top_name, kernel_xml, hdl_dir,
                           m_axi_names, [dataflow_kernel]) as proc:
      stdout, stderr = proc.communicate()
    log_func = _logger.error if proc.returncode != 0 else _logger.debug
    log_func(stdout.decode())
    log_func(stderr.decode())
    with open(xo_filename, mode='rb') as xo_fileobj:
      shutil.copyfileobj(xo_fileobj, xo_file)
115 |
def synthesis_module(tmpdir, kernel_files, module_name, device_info):
  """Run HLS synthesis for one module of the given kernel files.

  The HLS backend streams its results back as a tar archive; on success the
  members under `hdl` are extracted into `tmpdir`.

  Returns:
    (returncode, stdout, stderr) results of the subprocess.
  """
  with tempfile.TemporaryFile(mode='w+b') as hls_output:
    hls = backend.RunHls(hls_output, kernel_files, module_name,
                         device_info['clock_period'],
                         device_info['part_num'])
    with hls as proc:
      stdout, stderr = proc.communicate()
    returncode = proc.returncode
    if returncode == 0:
      hls_output.seek(0)
      with tarfile.open(mode='r', fileobj=hls_output) as tar:
        # Only the generated HDL is of interest downstream.
        hdl_members = [member for member in tar.getmembers()
                       if member.name.startswith('hdl')]
        tar.extractall(tmpdir, hdl_members)
    return returncode, stdout, stderr
133 |
# Canonical suffixes of the FIFO handshake ports generated by Vivado HLS.
FIFO_PORT_SUFFIXES = {
    'data_in': '_din',
    'not_full': '_full_n',
    'write_enable': '_write',
    'data_out': '_dout',
    'not_empty': '_empty_n',
    'read_enable': '_read',
    'not_block': '_blk_n',
}
142 |
143 |
def print_top_module(printer, super_source, inputs, outputs):
  """Print the top-level `Dataflow` Verilog module.

  Declares the ap_ctrl handshake and per-port FIFO signals, instantiates
  one FIFO per inter-module edge and one instance per dataflow module, and
  finally emits the FIFO module definitions themselves.

  Args:
    printer: backend.VerilogPrinter to write to.
    super_source: dataflow super source with the module graph.
    inputs: list of (port_name, bundle_name, haoda_type, buf_name) tuples.
    outputs: list of (port_name, bundle_name, haoda_type, buf_name) tuples.
  """
  println = printer.println
  println('`timescale 1 ns / 1 ps')
  # Port list: ap_ctrl signals plus the FIFO handshake ports of every
  # output (write side) and input (read side).
  args = ['ap_clk', 'ap_rst', 'ap_start', 'ap_done', 'ap_continue', 'ap_idle',
          'ap_ready']
  for port_name, _, _, _ in outputs:
    args.append('{}_V_V{data_in}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{not_full}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{write_enable}'.format(port_name, **FIFO_PORT_SUFFIXES))
  for port_name, _, _, _ in inputs:
    args.append('{}_V_V{data_out}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{not_empty}'.format(port_name, **FIFO_PORT_SUFFIXES))
    args.append('{}_V_V{read_enable}'.format(port_name, **FIFO_PORT_SUFFIXES))
  printer.module('Dataflow', args)
  println()

  # Direction declarations for the ports listed above.
  input_args = 'ap_clk', 'ap_rst', 'ap_start', 'ap_continue'
  output_args = 'ap_done', 'ap_idle', 'ap_ready'

  for arg in input_args:
    println('input %s;' % arg)
  for arg in output_args:
    println('output %s;' % arg)
  for port_name, _, haoda_type, _ in outputs:
    kwargs = dict(port_name=port_name, **FIFO_PORT_SUFFIXES)
    println('output [{}:0] {port_name}_V_V{data_in};'.format(
        util.get_width_in_bits(haoda_type) - 1, **kwargs))
    println('input {port_name}_V_V{not_full};'.format(**kwargs))
    println('output {port_name}_V_V{write_enable};'.format(**kwargs))
  for port_name, _, haoda_type, _ in inputs:
    kwargs = dict(port_name=port_name, **FIFO_PORT_SUFFIXES)
    println('input [{}:0] {port_name}_V_V{data_out};'.format(
        util.get_width_in_bits(haoda_type) - 1, **kwargs))
    println('input {port_name}_V_V{not_empty};'.format(**kwargs))
    println('output {port_name}_V_V{read_enable};'.format(**kwargs))
  println()

  # ap_ctrl state registers with their reset values.
  println("reg ap_done = 1'b0;")
  println("reg ap_idle = 1'b1;")
  println("reg ap_ready = 1'b0;")

  for port_name, _, haoda_type, _ in outputs:
    kwargs = dict(port_name=port_name, **FIFO_PORT_SUFFIXES)
    # NOTE(review): this reg name lacks the `_V_V` infix used by every
    # other reference to this port — possibly an intentionally unused
    # placeholder, possibly a latent bug; confirm against generated RTL.
    println('reg [{}:0] {port_name}{data_in};'.format(
        util.get_width_in_bits(haoda_type) - 1, **kwargs))
    println('wire {port_name}_V_V{write_enable};'.format(**kwargs))
  for port_name, _, haoda_type, _ in inputs:
    println('wire {}_V_V{read_enable};'.format(port_name, **FIFO_PORT_SUFFIXES))
  # Active-high reset for the fifo/module instances (ap_rst is already
  # active-high here; the name mirrors HLS conventions).
  println('reg ap_rst_n_inv;')
  with printer.always('*'):
    println('ap_rst_n_inv = ap_rst;')
  println()

  with printer.always('posedge ap_clk'):
    with printer.if_('ap_rst'):
      println("ap_done <= 1'b0;")
      println("ap_idle <= 1'b1;")
      println("ap_ready <= 1'b0;")
      printer.else_()
      println('ap_idle <= ~ap_start;')

  # "not blocked" flags mirror the full/empty handshakes combinationally.
  for port_name, _, _, _ in outputs:
    println('reg {}_V_V{not_block};'.format(port_name, **FIFO_PORT_SUFFIXES))
  for port_name, _, _, _ in inputs:
    println('reg {}_V_V{not_block};'.format(port_name, **FIFO_PORT_SUFFIXES))

  with printer.always('*'):
    for port_name, _, _, _ in outputs:
      println('{port_name}_V_V{not_block} = {port_name}_V_V{not_full};'.format(
          port_name=port_name, **FIFO_PORT_SUFFIXES))
    for port_name, _, _, _ in inputs:
      println('{port_name}_V_V{not_block} = {port_name}_V_V{not_empty};'.format(
          port_name=port_name, **FIFO_PORT_SUFFIXES))
  println()

  # One FIFO instance per inter-module edge, in topological order.
  for module in super_source.tpo_node_gen():
    for fifo in module.fifos:
      kwargs = {
        'name' : fifo.c_expr,
        'msb' : fifo.width_in_bits - 1,
        **FIFO_PORT_SUFFIXES
      }
      println('wire [{msb}:0] {name}{data_in};'.format(**kwargs))
      println('wire {name}{not_full};'.format(**kwargs))
      println('wire {name}{write_enable};'.format(**kwargs))
      println('wire [{msb}:0] {name}{data_out};'.format(**kwargs))
      println('wire {name}{not_empty};'.format(**kwargs))
      println('wire {name}{read_enable};'.format(**kwargs))
      println()

      args = collections.OrderedDict((
          ('clk', 'ap_clk'),
          ('reset', 'ap_rst_n_inv'),
          ('if_read_ce', "1'b1"),
          ('if_write_ce', "1'b1"),
          ('if{data_in}'.format(**kwargs),
           '{name}{data_in}'.format(**kwargs)),
          ('if{not_full}'.format(**kwargs),
           '{name}{not_full}'.format(**kwargs)),
          ('if{write_enable}'.format(**kwargs),
           '{name}{write_enable}'.format(**kwargs)),
          ('if{data_out}'.format(**kwargs),
           '{name}{data_out}'.format(**kwargs)),
          ('if{not_empty}'.format(**kwargs),
           '{name}{not_empty}'.format(**kwargs)),
          ('if{read_enable}'.format(**kwargs),
           '{name}{read_enable}'.format(**kwargs))
      ))
      # Depth is padded by 2 to absorb the FIFO's own pipeline latency.
      printer.module_instance('fifo_w{width}_d{depth}_A'.format(
          width=fifo.width_in_bits, depth=fifo.depth+2), fifo.c_expr, args)
      println()

  # One instance per dataflow module, wiring DRAM ports to the top-level
  # `_V_V` FIFO signals and inter-module ports to the internal FIFOs.
  for module in super_source.tpo_node_gen():
    module_trait, module_trait_id = super_source.module_table[module]
    args = collections.OrderedDict((('ap_clk', 'ap_clk'),
                                    ('ap_rst', 'ap_rst_n_inv'),
                                    ('ap_start', "1'b1")))
    for dram_ref, bank in module.dram_writes:
      kwargs = dict(port=dram_ref.dram_fifo_name(bank),
                    fifo=util.get_port_name(dram_ref.var, bank),
                    **FIFO_PORT_SUFFIXES)
      args['{port}_V{data_in}'.format(**kwargs)] = \
          '{fifo}_V_V{data_in}'.format(**kwargs)
      args['{port}_V{not_full}'.format(**kwargs)] = \
          '{fifo}_V_V{not_full}'.format(**kwargs)
      args['{port}_V{write_enable}'.format(**kwargs)] = \
          '{fifo}_V_V{write_enable}'.format(**kwargs)
    for port, fifo in zip(module_trait.output_fifos, module.output_fifos):
      kwargs = dict(port=port, fifo=fifo, **FIFO_PORT_SUFFIXES)
      args['{port}_V{data_in}'.format(**kwargs)] = \
          '{fifo}{data_in}'.format(**kwargs)
      args['{port}_V{not_full}'.format(**kwargs)] = \
          '{fifo}{not_full}'.format(**kwargs)
      args['{port}_V{write_enable}'.format(**kwargs)] = \
          '{fifo}{write_enable}'.format(**kwargs)
    for port, fifo in zip(module_trait.input_fifos, module.input_fifos):
      kwargs = dict(port=port, fifo=fifo, **FIFO_PORT_SUFFIXES)
      # The read data is concatenated with a constant valid bit.
      args['{port}_V{data_out}'.format(**kwargs)] = \
          "{{1'b1, {fifo}{data_out}}}".format(**kwargs)
      args['{port}_V{not_empty}'.format(**kwargs)] = \
          '{fifo}{not_empty}'.format(**kwargs)
      args['{port}_V{read_enable}'.format(**kwargs)] = \
          '{fifo}{read_enable}'.format(**kwargs)
    for dram_ref, bank in module.dram_reads:
      kwargs = dict(port=dram_ref.dram_fifo_name(bank),
                    fifo=util.get_port_name(dram_ref.var, bank),
                    **FIFO_PORT_SUFFIXES)
      args['{port}_V{data_out}'.format(**kwargs)] = \
          "{{1'b1, {fifo}_V_V{data_out}}}".format(**kwargs)
      args['{port}_V{not_empty}'.format(**kwargs)] = \
          '{fifo}_V_V{not_empty}'.format(**kwargs)
      args['{port}_V{read_enable}'.format(**kwargs)] = \
          '{fifo}_V_V{read_enable}'.format(**kwargs)
    printer.module_instance(util.get_func_name(module_trait_id), module.name,
                            args)
    println()
  printer.endmodule()

  # Emit one FIFO module definition per distinct (width, depth) pair.
  fifos = set()
  for module in super_source.tpo_node_gen():
    for fifo in module.fifos:
      fifos.add((fifo.width_in_bits, fifo.depth + 2))
  for fifo in fifos:
    printer.fifo_module(*fifo)
308 |
def print_dataflow_hls_interface(printer, top_name, inputs, outputs):
  """Print the HLS C++ dataflow interface stub for the top-level kernel.

  Emits, via `printer`:
    * templated BurstRead/BurstWrite helpers that move data between the
      DRAM (m_axi) pointers and HLS streams,
    * a `Dataflow` placeholder body that touches every stream once, and
    * the top-level function `top_name` with its m_axi/s_axilite pragmas.

  Args:
    printer: util.Printer to write the C++ code to.
    top_name: str, name of the top-level kernel function.
    inputs: list of (port_name, bundle_name, haoda_type, buf_name) tuples.
    outputs: list of (port_name, bundle_name, haoda_type, buf_name) tuples.
  """
  println = printer.println
  do_scope = printer.do_scope
  un_scope = printer.un_scope
  m_axi_ports = outputs + inputs
  print_func = printer.print_func

  # NOTE(review): the angle-bracketed parts of the literals below were
  # missing from the original source (apparently eaten by an HTML-tag
  # stripper: `#include `, `template`, `hls::stream>* to`, `ap_uint* from`).
  # They are reconstructed here so that the emitted C++ compiles; confirm
  # the include set against the upstream repository.
  println('#include <stddef.h>')
  println('#include <stdint.h>')
  println('#include <ap_int.h>')
  println('#include <hls_stream.h>')

  println('template<int burst_width>')
  print_func('void BurstRead', [
      'hls::stream<ap_uint<burst_width>>* to', 'ap_uint<burst_width>* from',
      'uint64_t data_num'], align=0)
  do_scope()
  println('load_epoch:', 0)
  with printer.for_('uint64_t epoch = 0', 'epoch < data_num', '++epoch'):
    println('#pragma HLS pipeline II=1', 0)
    println('to->write(from[epoch]);')
  un_scope()

  println('template<int burst_width>')
  print_func('void BurstWrite', [
      'ap_uint<burst_width>* to', 'hls::stream<ap_uint<burst_width>>* from',
      'uint64_t data_num'], align=0)
  do_scope()
  println('store_epoch:', 0)
  with printer.for_('uint64_t epoch = 0', 'epoch < data_num', '++epoch'):
    println('#pragma HLS pipeline II=1', 0)
    println('to[epoch] = from->read();')
  un_scope()

  # Placeholder Dataflow body: read each input once and write one default
  # value per output, so HLS sees every stream being used.
  params = ['hls::stream<{}>* {}'.format(util.get_c_type(haoda_type), name)
            for name, _, haoda_type, _ in m_axi_ports]
  print_func('void Dataflow', params, align=0)
  do_scope()
  for name, _, haoda_type, _ in inputs:
    println('volatile {c_type} {name}_read;'.format(
        c_type=util.get_c_type(haoda_type), name=name))
  for name, _, haoda_type, _ in inputs:
    println('{name}_read = {name}->read();'.format(name=name))
  for name, _, haoda_type, _ in outputs:
    println(
        '{name}->write({c_type}());'.format(
            c_type=util.get_c_type(haoda_type), name=name))
  un_scope()

  # Top-level kernel: BurstRead -> Dataflow -> BurstWrite connected by
  # depth-32 streams.
  params = ['{}* {}'.format(util.get_c_type(haoda_type), name)
            for name, _, haoda_type, _ in m_axi_ports]
  params.append('uint64_t coalesced_data_num')
  print_func('void %s' % top_name, params, align=0)
  do_scope()

  println('#pragma HLS dataflow', 0)

  for port_name, bundle_name, _, _ in m_axi_ports:
    println('#pragma HLS interface m_axi port={} offset=slave bundle={}'.format(
        port_name, bundle_name), 0)
  for port_name, _, _, _ in m_axi_ports:
    println('#pragma HLS interface s_axilite port={} bundle=control'.format(
        port_name), 0)
  println('#pragma HLS interface s_axilite port=coalesced_data_num '
          'bundle=control', 0)
  println('#pragma HLS interface s_axilite port=return bundle=control', 0)
  println()

  for _, _, haoda_type, name in m_axi_ports:
    println('hls::stream<{c_type}> {name}("{name}");'.format(
        c_type=util.get_c_type(haoda_type), name=name))
    println('#pragma HLS stream variable={name} depth=32'.format(name=name), 0)

  for port_name, _, _, buf_name in inputs:
    print_func('BurstRead', [
        '&{name}'.format(name=buf_name), '{name}'.format(name=port_name),
        'coalesced_data_num'], suffix=';', align=0)

  params = ['&{}'.format(name) for _, _, _, name in m_axi_ports]
  printer.print_func('Dataflow', params, suffix=';', align=0)

  for port_name, _, _, buf_name in outputs:
    print_func('BurstWrite', [
        '{name}'.format(name=port_name), '&{name}'.format(name=buf_name),
        'coalesced_data_num'],
               suffix=';', align=0)

  un_scope()
399 |
--------------------------------------------------------------------------------
/src/soda/core.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import itertools
4 | import logging
5 | import operator
6 |
7 | import cached_property
8 |
9 | from haoda import ir
10 | from haoda import util
11 | from haoda.ir import arithmetic
12 | from soda import dataflow
13 | from soda import grammar
14 | from soda import util as soda_util
15 | from soda import visitor
16 | from soda import mutator
17 |
18 | _logger = logging.getLogger().getChild(__name__)
19 |
class Tensor():
  """A tensor that corresponds to an input, local, or output.

  This class is used in the high-level DAG for stencil dependency analysis.
  Each tensor either is an input tensor, or has at least 1 parent tensor, which
  will be used to generate this tensor. Meanwhile, each tensor either is an
  output tensor, or has at least 1 child tensor, which will be computed using
  this tensor.

  Attributes:
    haoda_type: str, type of the tensor element.
    parents: Dict from str of name of Tensor to Tensor.
    children: Dict from str of name of Tensor to Tensor.
    st_ref: Ref, name, index, and latency stored.
    offset: int, shift offset in terms of data elements
    lets: Lets of computation.
    expr: Expr of computation.
    ld_refs: Dict from str of name to dict of Ref loaded.
    ld_delays: Dict from str of name to extra delay of the input.

  Property:
    name: str, unique in each SODA program.
    st_offset: int, stencil offset in terms of data elements.
    st_idx, Tuple of int, the index referenced by its parent stage.
    ld_indices: Dict from str of name to dict of accessed indices of the input.
    ld_offsets: Dict from str of name to dict of offsets of the input.
  """
  def __init__(self, stmt, tile_size):
    # Local/output tensors carry a stored Ref plus their computation;
    # input tensors only carry a name.
    self.haoda_type = stmt.haoda_type
    self._tile_size = tile_size
    if isinstance(stmt, grammar.LocalStmtOrOutputStmt):
      self.st_ref = copy.copy(stmt.ref)
      self.st_ref.parent = self
      self.lets = stmt.let
      self.expr = stmt.expr
    elif isinstance(stmt, grammar.InputStmt):
      self._name = stmt.name
      self.st_ref = None
      self.lets = []
      self.expr = None
    else:
      raise util.InternalError('cannot initialize a Tensor from %s' %
                               type(stmt))
    _logger.debug('tensor initialized from stmt `%s`', stmt)
    # pylint: disable=protected-access
    _logger.debug('  at tx position %d', stmt._tx_position)

    # these fields are to be set externally
    self.st_delay = 0
    self.parents = collections.OrderedDict()
    self.children = collections.OrderedDict()
    self.ld_refs = collections.OrderedDict()
    self.ld_delays = collections.OrderedDict()

  @property
  def name(self):
    # Local/output tensors are named by their stored Ref; inputs by _name.
    if self.st_ref is not None:
      return self.st_ref.name
    return self._name

  @property
  def st_idx(self):
    # Input tensors have no stored Ref; their index is all zeros.
    if self.st_ref is not None:
      return self.st_ref.idx
    return (0,)*len(self._tile_size)

  @property
  def st_offset(self):
    """Stencil offset in data elements: serialized st_idx plus st_delay."""
    return soda_util.serialize(self.st_idx, self._tile_size) + self.st_delay

  @cached_property.cached_property
  def ld_indices(self):
    """Dict from input name to OrderedDict mapping accessed idx to Ref."""
    return collections.OrderedDict(
        (name, collections.OrderedDict((ref.idx, ref) for ref in refs))
        for name, refs in self.ld_refs.items())

  @cached_property.cached_property
  def ld_offsets(self):
    """Dict from input name to OrderedDict mapping serialized offset to Ref."""
    return collections.OrderedDict(
        (name, collections.OrderedDict(
            (soda_util.serialize(ref.idx, self._tile_size), ref) for ref in refs))
        for name, refs in self.ld_refs.items())

  @property
  def c_type(self):
    return util.get_c_type(self.haoda_type)

  def propagate_type(self):
    """Fill in missing haoda types of Vars from this tensor's let bindings."""
    if self.expr is None:
      return

    var_types = {}
    # pylint: disable=access-member-before-definition
    for let in self.lets:
      var_types[let.name] = let.haoda_type

    def visit_haoda_type(obj, args):
      # Only untyped Vars are touched; everything else passes through.
      if obj.haoda_type is None:
        if isinstance(obj, ir.Var):
          obj.haoda_type = var_types[obj.name]
      return obj

    self.lets = tuple(_.visit(visit_haoda_type) for _ in self.lets)
    self.expr = self.expr.visit(visit_haoda_type)
    self.st_ref = self.st_ref.visit(visit_haoda_type)

  def mutate(self, callback, args=None):
    """Rewrite lets, expr, and st_ref in place via the visitor callback."""
    self.lets = tuple(_.visit(callback, args) for _ in self.lets)
    self.expr = self.expr.visit(callback, args)
    self.st_ref = self.st_ref.visit(callback, args)

  def visit_loads(self, callback, args=None):
    """Visit the load side (lets and expr) only; st_ref is not visited."""
    for let in self.lets:
      let.visit(callback, args)
    self.expr.visit(callback, args)

  def __str__(self):
    return '''Tensor
  {haoda_type}: {name} = {expr}
  store: {st_ref} with delay {st_delay}
  parents: {parents}
  children: {children}'''.format(
    name=self.name, haoda_type=self.haoda_type, expr=self.expr,
    parents=util.idx2str(self.parents), children=util.idx2str(self.children),
    st_ref=str(self.st_ref), st_delay=self.st_delay)

  def is_output(self):
    return len(self.children) == 0

  def is_input(self):
    return len(self.parents) == 0

  def is_producer(self):
    # Producers are all tensors that feed some child, i.e. non-outputs.
    return not self.is_output()

  def is_consumer(self):
    # Consumers are all tensors computed from parents, i.e. non-inputs.
    return not self.is_input()
158 | class Stencil():
159 | """
160 | Attributes:
161 | iterate: int, number of iteration to implement.
162 | burst_width: int, width of bits for DRAM burst access.
163 | app_name: str, application's name.
164 | tile_size: List of int.
165 | unroll_factor: int.
166 | dim: int.
167 | param_stmts: List of ParamStmt.
168 | input_stmts: List of InputStmt.
169 | local_stmts: List of LocalStmt.
170 | output_stmts: List of OutputStmt.
171 |
172 | Cached properties:
173 | tensors: Dict from str of name to Tensor.
174 | input_names: Tuple of str, names of input tensors.
175 | param_names: Tuple of str, names of param tensors.
176 | local_names: Tuple of str, names of local tensors.
177 | output_names: Tuple of str, names of output tensors.
178 | """
  def __init__(self, **kwargs):
    """Initializes a Stencil from parsed keyword arguments.

    Pops its configuration out of `kwargs`, applies optional DRAM bank
    mappings, validates iterative stencils, simplifies all computation
    expressions, and eagerly triggers the cached tensor/dataflow
    properties so errors surface at construction time.
    """
    self.iterate = kwargs.pop('iterate')
    if self.iterate < 1:
      raise util.SemanticError('cannot iterate %d times' % self.iterate)
    # platform determined
    self.burst_width = kwargs.pop('burst_width')
    # application determined
    self.app_name = kwargs.pop('app_name')
    # parameters that can be explored
    self.tile_size = tuple(kwargs.pop('tile_size'))
    self.unroll_factor = kwargs.pop('unroll_factor')
    # stage-independent
    self.dim = kwargs.pop('dim')
    self.param_stmts = kwargs.pop('param_stmts')
    # stage-specific
    self.input_stmts = kwargs.pop('input_stmts')
    self.local_stmts = kwargs.pop('local_stmts')
    self.output_stmts = kwargs.pop('output_stmts')

    # dram_in syntax: either a plain bank list 'b0.b1' applied to all
    # inputs, or per-variable maps 'var:b0.b1^var2:b2'.
    if 'dram_in' in kwargs:
      dram_in = kwargs.pop('dram_in')
      if dram_in is not None:
        if ':' in dram_in:
          input_stmt_map = {_.name : _ for _ in self.input_stmts}
          for dram_map in dram_in.split('^'):
            var_name, bank_list = dram_map.split(':')
            if var_name not in input_stmt_map:
              raise util.SemanticError('no input named `{}`'.format(var_name))
            input_stmt_map[var_name].dram = tuple(map(int,
                                                      bank_list.split('.')))
        else:
          for input_stmt in self.input_stmts:
            input_stmt.dram = tuple(map(int, dram_in.split('.')))

    # NOTE(review): dram_out splits per-variable maps on ',' whereas
    # dram_in splits on '^' — confirm whether this asymmetry is intended.
    if 'dram_out' in kwargs:
      dram_out = kwargs.pop('dram_out')
      if dram_out is not None:
        if ':' in dram_out:
          output_stmt_map = {_.name : _ for _ in self.output_stmts}
          for dram_map in dram_out.split(','):
            var_name, bank_list = dram_map.split(':')
            if var_name not in output_stmt_map:
              raise util.SemanticError('no output named `{}`'.format(var_name))
            output_stmt_map[var_name].dram = tuple(map(int,
                                                       bank_list.split('.')))
        else:
          for output_stmt in self.output_stmts:
            output_stmt.dram = tuple(map(int, dram_out.split('.')))

    # Iterative stencils feed outputs back as inputs, so counts and types
    # must match pairwise.
    if self.iterate > 1:
      if len(self.input_stmts) != len(self.output_stmts):
        raise util.SemanticError(
            'number of input tensors must be the same as output if iterate > 1 '
            'times, currently there are %d input(s) but %d output(s)' %
            (len(self.input_stmts), len(self.output_stmts)))
      if self.input_types != self.output_types:
        raise util.SemanticError(
            'input must have the same type(s) as output if iterate > 1 '
            'times, current input has type %s but output has type %s' %
            (util.lst2str(self.input_types), util.lst2str(self.output_types)))
      _logger.debug('pipeline %d iterations of [%s] -> [%s]' % (self.iterate,
        ', '.join('%s: %s' % (stmt.haoda_type, stmt.name)
                  for stmt in self.input_stmts),
        ', '.join('%s: %s' % (stmt.haoda_type, stmt.name)
                  for stmt in self.output_stmts)))

    # Arithmetic simplification of every computed statement.
    for stmt in itertools.chain(self.local_stmts, self.output_stmts):
      _logger.debug('simplify %s', stmt.name)
      stmt.expr = arithmetic.simplify(stmt.expr)
      stmt.let = arithmetic.simplify(stmt.let)

    # soda frontend successfully parsed
    # triggers cached property
    # replicate tensors for iterative stencil
    # pylint: disable=pointless-statement
    self.tensors
    _logger.debug('producer tensors: [%s]',
                  ', '.join(tensor.name for tensor in self.producer_tensors))
    _logger.debug('consumer tensors: [%s]',
                  ', '.join(tensor.name for tensor in self.consumer_tensors))

    # TODO: build Ref table and Var table
    # generate reuse buffers and get haoda nodes
    # pylint: disable=pointless-statement
    self.dataflow_super_source
    _logger.debug('dataflow: %s', self.dataflow_super_source)

    _logger.debug('module table: %s', dict(self.module_table))
    _logger.debug('module traits: %s', self.module_traits)
268 |
269 | @cached_property.cached_property
270 | def dataflow_super_source(self):
271 | return dataflow.create_dataflow_graph(self)
272 |
273 | @property
274 | def module_table(self):
275 | return self.dataflow_super_source.module_table
276 |
277 | @property
278 | def module_traits(self):
279 | return self.dataflow_super_source.module_traits
280 |
281 | @cached_property.cached_property
282 | def input_types(self):
283 | return tuple(tensor.haoda_type for tensor in self.input_stmts)
284 |
285 | @cached_property.cached_property
286 | def param_types(self):
287 | return tuple(tensor.haoda_type for tensor in self.param_stmts)
288 |
289 | @cached_property.cached_property
290 | def local_types(self):
291 | return tuple(tensor.haoda_type for tensor in self.local_stmts)
292 |
293 | @cached_property.cached_property
294 | def output_types(self):
295 | return tuple(tensor.haoda_type for tensor in self.output_stmts)
296 |
297 | @cached_property.cached_property
298 | def input_names(self):
299 | return tuple(stmt.name for stmt in self.input_stmts)
300 |
301 | @cached_property.cached_property
302 | def param_names(self):
303 | return tuple(stmt.name for stmt in self.param_stmts)
304 |
305 | @cached_property.cached_property
306 | def local_names(self):
307 | return tuple(stmt.name for stmt in self.local_stmts)
308 |
309 | @cached_property.cached_property
310 | def output_names(self):
311 | return tuple(stmt.name for stmt in self.output_stmts)
312 |
313 | @cached_property.cached_property
314 | def symbol_table(self):
315 | """Constructs a mapping from a tensor's name to its type.
316 |
317 | Returns:
318 | tensor_types: dict from name (str) to haoda_type (str).
319 | """
320 | tensor_types = {}
321 | for name, haoda_type in zip(self.input_names, self.input_types):
322 | tensor_types[name] = haoda_type
323 | for name, haoda_type in zip(self.local_names, self.local_types):
324 | tensor_types[name] = haoda_type
325 | for name, haoda_type in zip(self.output_names, self.output_types):
326 | tensor_types[name] = haoda_type
327 | return tensor_types
328 |
329 | @cached_property.cached_property
330 | def tensors(self):
331 | """Constructs high-level DAG and creates the tensors.
332 |
333 | Returns:
334 | An collections.OrderedDict mapping a tensor's name to the tensor.
335 | """
336 | # TODO: check for name conflicts
337 | tensor_map = collections.OrderedDict()
338 | for stmt in self.input_stmts:
339 | tensor = Tensor(stmt, self.tile_size)
340 | tensor_map[stmt.name] = tensor
341 |
342 | def name_in_iter(name, iteration):
343 | if name in self.input_names:
344 | if iteration > 0:
345 | return name+'_iter%d' % iteration
346 | return name
347 | if name in self.output_names:
348 | if iteration < self.iterate-1:
349 | return (self.input_names[self.output_names.index(name)]+
350 | '_iter%d' % (iteration+1))
351 | return name
352 | if name in self.local_names:
353 | if iteration > 0:
354 | return name+'_iter%d' % iteration
355 | return name
356 | if name in self.param_names:
357 | return name
358 | raise util.InternalError('unknown name: %s' % name)
359 |
360 | for iteration in range(self.iterate):
361 | _logger.debug('iterate %s', iteration)
362 | _logger.debug('map: %s', self.symbol_table)
363 | def mutate_name_callback(obj, mutated):
364 | if isinstance(obj, ir.Ref):
365 | obj.haoda_type = self.symbol_table[obj.name]
366 | # pylint: disable=cell-var-from-loop
367 | obj.name = name_in_iter(obj.name, iteration)
368 | return obj
369 | tensors = []
370 | for stmt in itertools.chain(self.local_stmts, self.output_stmts):
371 | tensor = Tensor(stmt.visit(mutate_name_callback), self.tile_size)
372 | loads = visitor.get_load_tuple(tensor)
373 | norm_idx = tuple(min(load.idx[d] for load in loads
374 | if load.name not in self.param_names)
375 | for d in range(self.dim))
376 | if any(norm_idx):
377 | _logger.debug('normalize index of %s: (%s)',
378 | tensor.name, ', '.join(map(str, norm_idx)))
379 | mutator.shift(tensor, norm_idx, excluded=self.param_names)
380 | tensor_map[tensor.name] = tensor
381 | tensors.append(tensor)
382 |
383 | for tensor in tensors:
384 | _logger.debug('%s', tensor)
385 |
386 | for tensor in tensors:
387 | tensor.propagate_type()
388 | loads = visitor.get_load_dict(tensor)
389 | for parent_name, ld_refs in loads.items():
390 | ld_refs = sorted(ld_refs, key=lambda ref: soda_util.serialize(
391 | ref.idx, self.tile_size))
392 | parent_tensor = tensor_map[parent_name]
393 | parent_tensor.children[tensor.name] = tensor
394 | tensor.parents[parent_name] = parent_tensor
395 | tensor.ld_refs[parent_name] = ld_refs
396 |
397 | # high-level DAG construction finished
398 | for tensor in tensor_map.values():
399 | if tensor.name in self.input_names:
400 | _logger.debug(': %s', tensor)
401 | elif tensor.name in self.output_names:
402 | _logger.debug('