├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── doc ├── main.tex └── smt_gen.pdf ├── inputs ├── xdp1_kern.desc ├── xdp1_kern.ins └── xdp1_kern.maps ├── main.cc ├── main.h ├── measure ├── README.md ├── benchmark_ebpf.cc ├── benchmark_ebpf.h ├── benchmark_header.h ├── benchmark_toy_isa.cc ├── benchmark_toy_isa.h ├── meas_mh_bhv.cc ├── meas_mh_bhv.h ├── meas_mh_bhv_figure.py ├── meas_mh_bhv_figure.sh ├── meas_mh_bhv_script.sh ├── meas_mh_bhv_test.cc ├── meas_solve_time_ebpf.cc ├── meas_time.cc └── meas_time_ebpf.cc ├── src ├── inout.cc ├── inout.h ├── inout_test.cc ├── isa │ ├── ebpf │ │ ├── bpf.h │ │ ├── canonicalize.cc │ │ ├── canonicalize.h │ │ ├── canonicalize_test.cc │ │ ├── inst.cc │ │ ├── inst.h │ │ ├── inst.runtime │ │ ├── inst_codegen.cc │ │ ├── inst_codegen.h │ │ ├── inst_codegen_test.cc │ │ ├── inst_cyclops.runtime │ │ ├── inst_test.cc │ │ ├── inst_var.cc │ │ ├── inst_var.h │ │ ├── win_select.cc │ │ ├── win_select.h │ │ └── win_select_test.cc │ ├── inst.cc │ ├── inst.h │ ├── inst_header.h │ ├── inst_header_basic.h │ ├── inst_var.cc │ ├── inst_var.h │ ├── inst_var_test.cc │ ├── prog.cc │ ├── prog.h │ ├── prog_test.cc │ ├── prog_test_ebpf.cc │ └── toy-isa │ │ ├── inst.cc │ │ ├── inst.h │ │ ├── inst_codegen.h │ │ ├── inst_codegen_test.cc │ │ ├── inst_test.cc │ │ ├── inst_var.cc │ │ └── inst_var.h ├── search │ ├── cost.cc │ ├── cost.h │ ├── cost_test.cc │ ├── cost_test_ebpf.cc │ ├── mh.cc │ ├── mh_prog.cc │ ├── mh_prog.h │ ├── mh_prog_test.cc │ ├── proposals.cc │ ├── proposals.h │ ├── proposals_test.cc │ ├── win_select.cc │ ├── win_select.h │ └── win_select_test_ebpf.cc ├── utils.cc ├── utils.h └── verify │ ├── cfg.cc │ ├── cfg.h │ ├── cfg_test.cc │ ├── cfg_test_ebpf.cc │ ├── smt_prog.cc │ ├── smt_prog.h │ ├── smt_prog_test.cc │ ├── smt_prog_test_ebpf.cc │ ├── smt_var_test.cc │ ├── validator.cc │ ├── validator.h │ ├── validator_test.cc │ ├── validator_test_ebpf.cc │ ├── z3client.cc │ └── z3client.h └── z3server.cc /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.out 2 | *.dSYM/ 3 | *.o 4 | *.txt 5 | .idea 6 | *.xlsx 7 | *.insns 8 | *.desc 9 | *.bpf_insns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Srinivas Narayana 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Superopt 2 | 3 | #### Installation: Linux (Ubuntu 18.04) and macOS (10.15.2) 4 | 5 | * Install `z3`; see https://github.com/Z3Prover/z3 for more details 6 | ``` 7 | git clone https://github.com/Z3Prover/z3.git 8 | cd z3 9 | git checkout 1c7d27bdf31ca038f7beee28c41aa7dbba1407dd 10 | python scripts/mk_make.py 11 | cd build 12 | make 13 | sudo make install 14 | ``` 15 | * Install `superopt`. Keep the superopt and z3 folders at the same directory level 16 | ``` 17 | cd ../../ 18 | git clone https://github.com/smartnic/superopt.git 19 | cd superopt 20 | make main_ebpf.out 21 | ``` 22 | TODO: add more instructions 23 | -------------------------------------------------------------------------------- /doc/smt_gen.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartnic/superopt/f50ee1f375329dbba105ecf72570ce75ae023812/doc/smt_gen.pdf -------------------------------------------------------------------------------- /inputs/xdp1_kern.desc: -------------------------------------------------------------------------------- 1 | { pgm_input_type = 2, } 2 | { max_pkt_sz = 256, } 3 | -------------------------------------------------------------------------------- /inputs/xdp1_kern.ins: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smartnic/superopt/f50ee1f375329dbba105ecf72570ce75ae023812/inputs/xdp1_kern.ins -------------------------------------------------------------------------------- /inputs/xdp1_kern.maps: -------------------------------------------------------------------------------- 1 | rxcnt { type = 6, key_size = 4, value_size = 4, max_entries = 256, fd = 0 } 2 | -------------------------------------------------------------------------------- /main.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | using namespace std; 6 | 7 | class input_paras { 8 | public: 9 | int niter; 10 | unsigned int k; 11 | int bm; 12 | bool bm_from_file; 13 | string bytecode; 14 | string map; 15 | string desc; 16 | double w_e; 17 | double w_p; 18 | bool meas_mode; 19 | string path_out; 20 | int st_ex; 21 | int st_eq; 22 | int st_avg; 23 | int st_perf; 24 | int st_when_to_restart; 25 | int st_when_to_restart_niter; 26 | int st_start_prog; 27 | vector restart_w_e_list; 28 | vector restart_w_p_list; 29 | int reset_win_niter; 30 | vector win_s_list; 31 | vector win_e_list; 32 | double p_inst_operand; 33 | double p_inst; 34 | double p_inst_as_nop; 35 | bool disable_prog_eq_cache; 36 | bool enable_prog_uneq_cache; 37 | bool is_win; 38 | int logger_level; 39 | int server_port; 40 | }; 41 | struct bpf_insn { 42 | uint8_t opcode; 43 | uint8_t dst_reg: 4; 44 | uint8_t src_reg: 4; 45 | short off; 46 | int imm; 47 | }; 48 | -------------------------------------------------------------------------------- /measure/README.md: -------------------------------------------------------------------------------- 1 | Measuring the solving time of the equivalence-check formula in the validator 2 | 3 | ##### How to use it 4 | * make 5 | ``` 6 | make z3server.out; make meas_solve_time_ebpf.out; 7 | ``` 8 | 9 | * use the tool to get the solving time 10 | ``` 11 | ./measure/meas_solve_time_ebpf.out [loop_times] 12 | ``` 13 | `loop_times` is the number of times the tool is run for each benchmark pair. 
The default value of `loop_times` is 1 14 | 15 | * how to read the output 16 | ``` 17 | ./measure/meas_solve_time_ebpf.out 2 18 | ``` 19 | output: 20 | ``` 21 | Original program is rcv-sock4 22 | starting p1 23 | validator is_smt_valid: 9.51441e+07 us 1 24 | validator is_smt_valid: 6.94754e+07 us 1 25 | starting p2 26 | validator is_smt_valid: 8.34699e+07 us 1 27 | validator is_smt_valid: 9.16541e+07 us 1 28 | ``` 29 | Here the two benchmark pairs are (rcv-sock4, p1) and (rcv-sock4, p2). The measurement was repeated twice. The first and second solving times for checking whether rcv-sock4 and p1 are equal are 9.51441e+07 and 6.94754e+07 us, respectively. 30 | 31 | * clean the environment 32 | ``` 33 | pkill z3server.out 34 | ``` 35 | -------------------------------------------------------------------------------- /measure/benchmark_ebpf.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "../src/isa/ebpf/inst.h" 5 | 6 | using namespace std; 7 | void init_benchmarks(inst** bm, vector &bm_optis_orig, int bm_id); 8 | void init_benchmark_from_file(inst** bm, const char* insn_file, const char* map_file, const char* desc_file); 9 | // N cannot be greater than 56 because of the limit of the combination function 10 | #undef N0 11 | #undef N1 12 | #undef N2 13 | #undef N3 14 | #undef N4 15 | #undef N5 16 | #undef N6 17 | #undef N7 18 | #undef N8 19 | #undef N9 20 | #undef N11 21 | #undef N12 22 | #define N0 7 23 | #define N1 7 24 | #define N2 16 25 | #define N3 91 26 | #define N4 7 27 | #define N5 7 28 | #define N6 7 29 | #define N7 7 30 | #define N8 24 31 | #define N9 7 32 | #define N10 13 33 | #define N11 24 34 | #define N12 61 35 | #define N13 36 36 | #define N14 24 37 | #define N15 18 38 | #define N16 18 39 | #define N17 26 40 | #define N18 24 41 | #define N19 57 42 | #define N20 38 43 | #define N21 38 44 | #define N22 41 45 | #define N23 43 46 | #define N24 22 47 | #define N25 35 48 | 49 | extern inst 
bm0[N0]; 50 | extern inst bm1[N1]; 51 | extern inst bm2[N2]; 52 | extern inst bm3[N3]; 53 | extern inst bm4[N4]; 54 | extern inst bm5[N5]; 55 | extern inst bm6[N6]; 56 | extern inst bm7[N7]; 57 | extern inst bm8[N8]; 58 | extern inst bm9[N9]; 59 | extern inst bm10[N10]; 60 | extern inst bm11[N11]; 61 | extern inst bm12[N12]; 62 | extern inst bm13[N13]; 63 | extern inst bm14[N14]; 64 | extern inst bm15[N15]; 65 | extern inst bm16[N16]; 66 | extern inst bm17[N17]; 67 | extern inst bm18[N18]; 68 | extern inst bm19[N19]; 69 | extern inst bm20[N20]; 70 | extern inst bm21[N21]; 71 | extern inst bm22[N22]; 72 | extern inst bm23[N23]; 73 | extern inst bm24[N24]; 74 | extern inst bm25[N25]; 75 | -------------------------------------------------------------------------------- /measure/benchmark_header.h: -------------------------------------------------------------------------------- 1 | #if ISA_TOY_ISA 2 | #include "benchmark_toy_isa.h" 3 | #elif ISA_EBPF 4 | #include "benchmark_ebpf.h" 5 | #endif 6 | -------------------------------------------------------------------------------- /measure/benchmark_toy_isa.cc: -------------------------------------------------------------------------------- 1 | #include "benchmark_toy_isa.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | using namespace std; 8 | 9 | ostream& operator<<(ostream& out, vector& v) { 10 | for (size_t i = 0; i < v.size(); i++) { 11 | out << v[i] << " "; 12 | } 13 | return out; 14 | } 15 | 16 | ostream& operator<<(ostream& out, vector >& v) { 17 | for (size_t i = 0; i < v.size(); i++) { 18 | out << i << ": " << v[i] << endl; 19 | } 20 | return out; 21 | } 22 | // output = max(input+4, 15) 23 | // perf_cost = 3 + 1 = 4 24 | inst bm0[N] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 25 | inst(ADDXY, 0, 2), /* add r0, r2 */ 26 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 27 | inst(JMPGT, 0, 3, 1), /* if r0 <= r3: */ 28 | inst(RETX, 3), /* ret r3 */ 29 | inst(RETX, 0), /* else ret r0 */ 30 | inst(), /* control never reaches here 
*/ 31 | }; 32 | // f(x) = max(2*x, x+4) 33 | // perf_cost = 3 + 1 = 4 34 | inst bm1[N] = {inst(ADDXY, 1, 0), 35 | inst(MOVXC, 2, 4), 36 | inst(ADDXY, 1, 2), // r1 = r0+4 37 | inst(ADDXY, 0, 0), // r0 += r0 38 | inst(MAXX, 1, 0), // r1 = max(r1, r0) 39 | inst(RETX, 1), 40 | inst(), 41 | }; 42 | // f(x) = 6*x 43 | // perf_cost = 4 + 0 = 4 44 | inst bm2[N] = {inst(MOVXC, 1, 0), 45 | inst(ADDXY, 1, 0), // r1 = 2*r0 46 | inst(ADDXY, 0, 1), 47 | inst(ADDXY, 0, 1), 48 | inst(ADDXY, 0, 1), 49 | inst(ADDXY, 0, 1), 50 | inst(ADDXY, 0, 1), 51 | }; 52 | inst bm_opti00[N] = {inst(MOVXC, 1, 4), 53 | inst(ADDXY, 0, 1), 54 | inst(MAXC, 0, 15), 55 | inst(), 56 | inst(), 57 | inst(), 58 | inst(), 59 | }; 60 | inst bm_opti01[N] = {inst(MAXC, 1, 4), 61 | inst(ADDXY, 0, 1), 62 | inst(MAXC, 0, 15), 63 | inst(), 64 | inst(), 65 | inst(), 66 | inst(), 67 | }; 68 | inst bm_opti02[N] = {inst(MAXC, 1, 4), 69 | inst(MAXC, 0, 11), 70 | inst(ADDXY, 0, 1), 71 | inst(), 72 | inst(), 73 | inst(), 74 | inst(), 75 | }; 76 | inst bm_opti03[N] = {inst(MOVXC, 1, 4), 77 | inst(MAXC, 0, 11), 78 | inst(ADDXY, 0, 1), 79 | inst(), 80 | inst(), 81 | inst(), 82 | inst(), 83 | }; 84 | inst bm_opti04[N] = {inst(MAXC, 0, 11), 85 | inst(MOVXC, 1, 4), 86 | inst(ADDXY, 0, 1), 87 | inst(), 88 | inst(), 89 | inst(), 90 | inst(), 91 | }; 92 | inst bm_opti05[N] = {inst(MAXC, 0, 11), 93 | inst(MAXC, 1, 4), 94 | inst(ADDXY, 0, 1), 95 | inst(), 96 | inst(), 97 | inst(), 98 | inst(), 99 | }; 100 | inst bm_opti10[N] = {inst(MOVXC, 1, 4), 101 | inst(MAXX, 1, 0), 102 | inst(ADDXY, 0, 1), 103 | inst(), 104 | inst(), 105 | inst(), 106 | inst(), 107 | }; 108 | inst bm_opti11[N] = {inst(MAXC, 1, 4), 109 | inst(MAXX, 1, 0), 110 | inst(ADDXY, 0, 1), 111 | inst(), 112 | inst(), 113 | inst(), 114 | inst(), 115 | }; 116 | inst bm_opti12[N] = {inst(ADDXY, 1, 0), 117 | inst(MAXC, 0, 4), 118 | inst(ADDXY, 0, 1), 119 | inst(), 120 | inst(), 121 | inst(), 122 | inst(), 123 | }; 124 | inst bm_opti13[N] = {inst(ADDXY, 1, 0), 125 | inst(MAXC, 
1, 4), 126 | inst(ADDXY, 0, 1), 127 | inst(), 128 | inst(), 129 | inst(), 130 | inst(), 131 | }; 132 | inst bm_opti20[N] = {inst(ADDXY, 0, 0), 133 | inst(ADDXY, 1, 0), 134 | inst(ADDXY, 0, 1), 135 | inst(ADDXY, 0, 1), 136 | inst(), 137 | inst(), 138 | inst(), 139 | }; 140 | inst bm_opti21[N] = {inst(ADDXY, 1, 0), 141 | inst(ADDXY, 1, 1), 142 | inst(ADDXY, 0, 1), 143 | inst(ADDXY, 0, 0), 144 | inst(), 145 | inst(), 146 | inst(), 147 | }; 148 | inst bm_opti22[N] = {inst(ADDXY, 1, 0), 149 | inst(ADDXY, 0, 1), 150 | inst(ADDXY, 0, 1), 151 | inst(ADDXY, 0, 0), 152 | inst(), 153 | inst(), 154 | inst(), 155 | }; 156 | inst bm_opti23[N] = {inst(ADDXY, 1, 0), 157 | inst(ADDXY, 1, 0), 158 | inst(ADDXY, 0, 1), 159 | inst(ADDXY, 0, 0), 160 | inst(), 161 | inst(), 162 | inst(), 163 | }; 164 | inst bm_opti24[N] = {inst(ADDXY, 1, 0), 165 | inst(ADDXY, 0, 0), 166 | inst(ADDXY, 0, 1), 167 | inst(ADDXY, 0, 0), 168 | inst(), 169 | inst(), 170 | inst(), 171 | }; 172 | inst bm_opti25[N] = {inst(ADDXY, 0, 0), 173 | inst(ADDXY, 1, 0), 174 | inst(ADDXY, 1, 0), 175 | inst(ADDXY, 0, 1), 176 | inst(), 177 | inst(), 178 | inst(), 179 | }; 180 | inst bm_opti26[N] = {inst(ADDXY, 0, 0), 181 | inst(ADDXY, 1, 0), 182 | inst(ADDXY, 0, 0), 183 | inst(ADDXY, 0, 1), 184 | inst(), 185 | inst(), 186 | inst(), 187 | }; 188 | inst bm_opti27[N] = {inst(ADDXY, 0, 0), 189 | inst(ADDXY, 1, 0), 190 | inst(ADDXY, 1, 1), 191 | inst(ADDXY, 0, 1), 192 | inst(), 193 | inst(), 194 | inst(), 195 | }; 196 | 197 | void init_benchmarks(inst** bm, vector &bm_optis_orig, int bm_id) { 198 | inst::max_prog_len = N; 199 | switch (bm_id) { 200 | case 0: 201 | *bm = bm0; 202 | bm_optis_orig.push_back(bm_opti00); 203 | bm_optis_orig.push_back(bm_opti01); 204 | bm_optis_orig.push_back(bm_opti02); 205 | bm_optis_orig.push_back(bm_opti03); 206 | bm_optis_orig.push_back(bm_opti04); 207 | bm_optis_orig.push_back(bm_opti05); 208 | return; 209 | case 1: 210 | *bm = bm1; 211 | bm_optis_orig.push_back(bm_opti10); 212 | 
bm_optis_orig.push_back(bm_opti11); 213 | bm_optis_orig.push_back(bm_opti12); 214 | bm_optis_orig.push_back(bm_opti13); 215 | return; 216 | case 2: 217 | *bm = bm2; 218 | bm_optis_orig.push_back(bm_opti20); 219 | bm_optis_orig.push_back(bm_opti21); 220 | bm_optis_orig.push_back(bm_opti22); 221 | bm_optis_orig.push_back(bm_opti23); 222 | bm_optis_orig.push_back(bm_opti24); 223 | bm_optis_orig.push_back(bm_opti25); 224 | bm_optis_orig.push_back(bm_opti26); 225 | bm_optis_orig.push_back(bm_opti27); 226 | return; 227 | default: 228 | cout << "ERROR: toy-isa bm_id " + to_string(bm_id) + " is out of range {0, 1, 2}" << endl; 229 | return; 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /measure/benchmark_toy_isa.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "../src/isa/toy-isa/inst.h" 6 | 7 | using namespace std; 8 | 9 | ostream& operator<<(ostream& out, vector& v); 10 | ostream& operator<<(ostream& out, vector >& v); 11 | 12 | // instruction_list set 13 | // N can not greater than 56 because of the limit of combination function 14 | #undef N 15 | #define N 7 16 | 17 | #undef NUM_ORIG 18 | #define NUM_ORIG 3 19 | extern inst bm0[N]; 20 | extern inst bm1[N]; 21 | extern inst bm2[N]; 22 | 23 | extern inst bm_opti00[N]; 24 | extern inst bm_opti01[N]; 25 | extern inst bm_opti02[N]; 26 | extern inst bm_opti03[N]; 27 | extern inst bm_opti04[N]; 28 | extern inst bm_opti05[N]; 29 | 30 | extern inst bm_opti10[N]; 31 | extern inst bm_opti11[N]; 32 | extern inst bm_opti12[N]; 33 | extern inst bm_opti13[N]; 34 | 35 | extern inst bm_opti20[N]; 36 | extern inst bm_opti21[N]; 37 | extern inst bm_opti22[N]; 38 | extern inst bm_opti23[N]; 39 | extern inst bm_opti24[N]; 40 | extern inst bm_opti25[N]; 41 | extern inst bm_opti26[N]; 42 | extern inst bm_opti27[N]; 43 | 44 | void init_benchmarks(inst** bm, vector &bm_optis_orig, int bm_id); 
45 | inline void init_benchmark_from_file(inst** bm, const char* insn_file, const char* map_file, const char* desc_file) {} 46 | -------------------------------------------------------------------------------- /measure/meas_mh_bhv.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "meas_mh_bhv.h" 5 | 6 | using namespace std; 7 | 8 | string FILE_RAW_DATA_PROGRAMS = "raw_data_programs"; 9 | string FILE_RAW_DATA_PROPOSALS = "raw_data_proposals"; 10 | string FILE_RAW_DATA_EXAMPLES = "raw_data_examples"; 11 | string FILE_RAW_DATA_OPTIMALS = "raw_data_optimals"; 12 | 13 | /* class meas_mh_data start */ 14 | meas_mh_data::meas_mh_data() {} 15 | 16 | meas_mh_data::~meas_mh_data() {} 17 | 18 | void meas_mh_data::insert_proposal(const prog &proposal, bool accepted) { 19 | if (_mode) { 20 | _proposals.push_back(make_pair(proposal, accepted)); 21 | } 22 | } 23 | 24 | void meas_mh_data::insert_program(unsigned int iter_num, const prog &program) { 25 | if (_mode) { 26 | _programs.push_back(make_pair(iter_num, program)); 27 | } 28 | } 29 | 30 | void meas_mh_data::insert_examples(unsigned int iter_num, const examples &exs) { 31 | if (_mode) { 32 | _examples.push_back(make_pair(iter_num, exs)); 33 | } 34 | } 35 | 36 | void meas_mh_data::insert_examples(unsigned int iter_num, const inout &exs) { 37 | if (_mode) { 38 | examples exs_new; 39 | exs_new.insert(exs); 40 | this->_examples.push_back(make_pair(iter_num, exs_new)); 41 | } 42 | } 43 | /* class meas_mh_data end */ 44 | 45 | string prog_rel_bv_to_str(int v) { 46 | string res = ""; 47 | for (int i = 0; i < inst::max_prog_len; i++) { 48 | int a = v % 2; 49 | res = to_string(a) + res; 50 | v = v >> 1; 51 | } 52 | return res; 53 | } 54 | 55 | string prog_abs_bv_to_str(vector& v) { 56 | string str = ""; 57 | for (size_t i = 0; i < v.size(); i++) 58 | str += bitset(v[i]).to_string(); 59 | return str; 60 | } 61 | 62 | // fmt: 63 | void 
store_proposals_to_file(string file_name, 64 | const meas_mh_data &d, 65 | const vector &optimals) { 66 | if (! d._mode) return; 67 | fstream fout; 68 | fout.open(file_name, ios::out | ios::trunc); 69 | fout << " " << endl; 70 | for (size_t i = 0; i < d._proposals.size(); i++) { 71 | prog p(d._proposals[i].first); 72 | vector bv; 73 | p.to_abs_bv(bv); 74 | fout << d._proposals[i].second << " " 75 | << p._error_cost << " " 76 | << p._perf_cost << " " << endl; 77 | // << prog_rel_bv_to_str(p.to_rel_bv(optimals)) << " " 78 | // << prog_abs_bv_to_str(bv) << endl; 79 | } 80 | fout.close(); 81 | } 82 | 83 | // fmt: 84 | void store_programs_to_file(string file_name, 85 | const meas_mh_data &d, 86 | const vector &optimals) { 87 | if (! d._mode) return; 88 | fstream fout; 89 | fout.open(file_name, ios::out | ios::trunc); 90 | fout << " " << endl; 91 | for (size_t i = 0; i < d._programs.size(); i++) { 92 | prog p(d._programs[i].second); 93 | vector bv; 94 | p.to_abs_bv(bv); 95 | fout << d._programs[i].first << " " 96 | << p._error_cost << " " 97 | << p._perf_cost << " " << endl; 98 | // << prog_rel_bv_to_str(p.to_rel_bv(optimals)) << " " 99 | // << prog_abs_bv_to_str(bv) << endl; 100 | } 101 | fout.close(); 102 | } 103 | 104 | // fmt: 105 | void store_examples_to_file(string file_name, 106 | const meas_mh_data &d) { 107 | if (! d._mode) return; 108 | fstream fout; 109 | fout.open(file_name, ios::out | ios::trunc); 110 | fout << " " << endl; 111 | for (size_t i = 0; i < d._examples.size(); i++) { 112 | fout << d._examples[i].first << " " 113 | << d._examples[i].second._exs << endl; 114 | } 115 | fout.close(); 116 | } 117 | 118 | void store_optimals_to_file(string file_name, 119 | const vector &optimals, 120 | bool measure_mode) { 121 | if (! 
measure_mode) return; 122 | fstream fout; 123 | fout.open(file_name, ios::out | ios::trunc); 124 | fout << "" << endl; 125 | for (size_t i = 0; i < optimals.size(); i++) { 126 | vector bv; 127 | optimals[i].to_abs_bv(bv); 128 | fout << prog_abs_bv_to_str(bv) << endl; 129 | } 130 | fout.close(); 131 | } 132 | 133 | void meas_store_raw_data(meas_mh_data &d, string meas_path_out, string suffix, 134 | int meas_bm, vector &bm_optimals) { 135 | string file_raw_data_programs = meas_path_out + FILE_RAW_DATA_PROGRAMS + suffix; 136 | string file_raw_data_proposals = meas_path_out + FILE_RAW_DATA_PROPOSALS + suffix; 137 | string file_raw_data_examples = meas_path_out + FILE_RAW_DATA_EXAMPLES + suffix; 138 | string file_raw_data_optimals = meas_path_out + FILE_RAW_DATA_OPTIMALS; 139 | file_raw_data_optimals += "_" + to_string(meas_bm) + ".txt"; 140 | store_proposals_to_file(file_raw_data_proposals, d, bm_optimals); 141 | store_programs_to_file(file_raw_data_programs, d, bm_optimals); 142 | store_examples_to_file(file_raw_data_examples, d); 143 | store_optimals_to_file(file_raw_data_optimals, bm_optimals, d._mode); 144 | } 145 | 146 | // return C_n^m, i.e. the number of ways to choose m items out of n, as a double 147 | // max n: 56 (checked by the assert below) 148 | double combination(unsigned int n, unsigned m) { 149 | assert(n <= 56); 150 | // utilize C_n^m = C_n^(n-m) to simplify the computation and try to avoid overflow 151 | if (m > (n - m)) m = n - m; // NOTE(review): n and m are unsigned, so n - m wraps around when m > n; presumably callers guarantee m <= n -- confirm 152 | double a = 1; // a accumulates the numerator n * (n-1) * ... * (n-m+1) 153 | for (unsigned int i = n; i > (n - m); i--) { 154 | a *= i; 155 | } 156 | double b = 1; // b accumulates the denominator m! 157 | for (unsigned int i = 1; i <= m; i++) { 158 | b *= i; 159 | } 160 | return (a / b); // NOTE(review): for large n (e.g. n = 56, m = 28) the numerator exceeds 2^53, so the quotient may not be an exact integer -- confirm callers tolerate rounding 161 | } 162 | 163 | // Generate all combinations that pick n distinct numbers from s to e (both inclusive) 164 | // row_s is the starting row in `res` that stores the combinations 165 | // e.g. 
s=1, e=3, n=2, row_s=0, res=[[1,2], [1,3], [2,3]] 166 | // steps: compute combinations recursively ranging from large to small, 167 | // while the real computation is from small to large, 168 | // that is, compute combinations in range [s+1:e] first, then [s:e] 169 | void gen_n_combinations(int n, int s, int e, 170 | int row_s, vector >& res) { 171 | if (n == 0) return; // nothing left to pick 172 | for (int i = s; i <= e - n + 1; i++) { // i: the first number of the combination; at most e - n + 1 so that n - 1 larger numbers remain 173 | double num_comb = combination(e - i, n - 1); // number of combinations starting with i: pick the other n - 1 numbers from the e - i numbers in [i+1, e] 174 | for (double j = row_s; j < row_s + num_comb; j++) // NOTE(review): j is a double used as a vector index; `res` must already contain these rows, since only push_back on existing rows is done here 175 | res[j].push_back(i); 176 | gen_n_combinations(n - 1, i + 1, e, row_s, res); // append the remaining n - 1 numbers to the same rows 177 | row_s += num_comb; // rows for the next value of i start after this group 178 | } 179 | } 180 | 181 | // Premise: the caller should ensure that the first real_length instructions in program p 182 | // are not NOP, while the remaining ones are NOP. 183 | // steps: 1. Set all instructions of this optimal program as NOP; 184 | // 2. Compute combinations for real instruction positions; 185 | // 3. replace NOP instructions with real instructions according to combinations. 186 | // e.g. 
if the optimal program has 2 real instructions and one combination is [2,3], 187 | // then the instructions at (0-based) positions 2 and 3 are replaced with the real instructions 188 | void gen_optis_for_prog(const prog& p, int len, 189 | vector& opti_set) { 190 | int n = p.num_real_instructions(); 191 | // C_len^n 192 | int num_opti = combination(len, n); 193 | vector > comb_set(num_opti); 194 | gen_n_combinations(n, 0, len - 1, 0, comb_set); 195 | opti_set.resize(num_opti, p); 196 | for (size_t i = 0; i < comb_set.size(); i++) { 197 | // set all instructions of this optimal program as NOP 198 | for (size_t j = 0; j < len; j++) 199 | opti_set[i].inst_list[j].set_as_nop_inst(); 200 | // replace some NOP instructions with real instructions 201 | // according to the combination value 202 | for (size_t j = 0; j < comb_set[i].size(); j++) { 203 | size_t pos = comb_set[i][j]; 204 | opti_set[i].inst_list[pos] = p.inst_list[j]; 205 | } 206 | } 207 | } 208 | 209 | void gen_optis_for_progs(const vector &bm_optis_orig, vector &bm_optimals) { 210 | for (size_t i = 0; i < bm_optis_orig.size(); i++) { 211 | prog bm_opti(bm_optis_orig[i]); 212 | // op_set: temporarily store optimals for one bm optimal program 213 | vector op_set; 214 | gen_optis_for_prog(bm_opti, inst::max_prog_len, op_set); 215 | for (size_t j = 0; j < op_set.size(); j++) 216 | bm_optimals.push_back(op_set[j]); 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /measure/meas_mh_bhv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "../src/utils.h" 6 | #include "../src/inout.h" 7 | #include "../src/isa/inst_header.h" 8 | #include "../src/isa/prog.h" 9 | 10 | using namespace std; 11 | 12 | /* Class meas_mh_data is used to store measurement data when mh_sampler is sampling. 13 | * It ONLY stores data when `_mode` is set as `true`. 
Currently it can store 14 | * three kinds of data, namely proposals, programs and examples; the details are in 15 | * the comments inside the class. 16 | */ 17 | class meas_mh_data { 18 | public: 19 | // true: measure; false: do not measure 20 | bool _mode; 21 | // (proposal program, accepted or rejected) 22 | vector > _proposals; 23 | // (iteration number, sampled program) 24 | vector > _programs; 25 | // (iteration number, new examples) 26 | vector > _examples; 27 | meas_mh_data(); 28 | ~meas_mh_data(); 29 | void insert_proposal(const prog &proposal, bool accepted); 30 | void insert_program(unsigned int iter_num, const prog &program); 31 | void insert_examples(unsigned int iter_num, const examples &exs); 32 | void insert_examples(unsigned int iter_num, const inout &exs); 33 | }; 34 | 35 | /* The following `store_[]_to_file` functions store raw data of various objects 36 | * into files. They ONLY work when `_mode` in class `meas_mh_data` is set as `true` 37 | */ 38 | void store_proposals_to_file(string file_name, 39 | const meas_mh_data &d, 40 | const vector &optimals); 41 | void store_programs_to_file(string file_name, 42 | const meas_mh_data &d, 43 | const vector &optimals); 44 | void store_examples_to_file(string file_name, 45 | const meas_mh_data &d); 46 | void store_optimals_to_file(string file_name, 47 | const vector &optimals, 48 | bool measure_mode); 49 | void meas_store_raw_data(meas_mh_data &d, string meas_path_out, string suffix, 50 | int meas_bm, vector &bm_optimals); 51 | 52 | void gen_optis_for_progs(const vector &bm_optis_orig, vector &bm_optimals); 53 | // for the unit tests 54 | double combination(unsigned int n, unsigned m); 55 | -------------------------------------------------------------------------------- /measure/meas_mh_bhv_figure.sh: -------------------------------------------------------------------------------- 1 | pythonfile="measure/meas_mh_bhv_figure.py" 2 | file_in="measure/" 3 | file_out=${file_in} 4 | python3 ${pythonfile} -n 100 
--fin_path=${file_in} --bm_ids="0" \ 5 | --w_list="1,0" --st_list="000" --best_perf_costs="4" --steady_start=10 \ 6 | --st_when_to_restart_list="0" --st_when_to_restart_niter_list="0" \ 7 | --st_start_prog="0" --p_list="0.333333,0.333333" 8 | -------------------------------------------------------------------------------- /measure/meas_mh_bhv_script.sh: -------------------------------------------------------------------------------- 1 | # For toy-isa 2 | ./main.out \ 3 | -m -n 5000 --path_out "measure/" --bm 0 \ 4 | --w_e 0.5 --w_p 1.5 \ 5 | --st_ex 0 --st_eq 0 --st_avg 0 \ 6 | --st_when_to_restart 1 --st_when_to_restart_niter 2000 \ 7 | --st_start_prog 0 \ 8 | --restart_w_e_list 0.5,1.5 --restart_w_p_list 1.5,0.5 \ 9 | --p_inst_operand 0.33 --p_inst 0.34 10 | 11 | # For ebpf 12 | ./main_ebpf.out \ 13 | -m -n 5000 --path_out "measure/" --bm 0 \ 14 | --w_e 0.5 --w_p 1.5 \ 15 | --st_ex 0 --st_eq 0 --st_avg 0 \ 16 | --st_when_to_restart 1 --st_when_to_restart_niter 2000 \ 17 | --st_start_prog 0 \ 18 | --restart_w_e_list 0.5,1.5 --restart_w_p_list 1.5,0.5 \ 19 | --p_inst_operand 0.33 --p_inst 0.34 20 | -------------------------------------------------------------------------------- /measure/meas_mh_bhv_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "meas_mh_bhv.h" 4 | #include "benchmark_toy_isa.h" 5 | #include "../src/utils.h" 6 | 7 | using namespace std; 8 | 9 | void read_data_from_file(string file_name, string str_to_print) { 10 | ifstream fin(file_name, ios::in); 11 | char line[256]; 12 | cout << str_to_print << endl; 13 | while (! 
fin.eof()) { 14 | fin.getline (line, 256); 15 | cout << line << endl; 16 | } 17 | fin.clear(); 18 | fin.close(); 19 | } 20 | 21 | void test1() { 22 | vector optimals; 23 | optimals.push_back(prog(bm_opti00)); 24 | optimals.push_back(prog(bm_opti01)); 25 | string file_name = "measure/test.txt"; 26 | meas_mh_data d; 27 | d._mode = true; 28 | 29 | store_optimals_to_file(file_name, optimals, d._mode); 30 | read_data_from_file(file_name, "Optimals:"); 31 | 32 | d.insert_proposal(prog(bm0), 1); 33 | d.insert_proposal(prog(bm1), 0); 34 | store_proposals_to_file(file_name, d, optimals); 35 | read_data_from_file(file_name, "Proposals:"); 36 | 37 | d.insert_program(0, prog(bm0)); 38 | d.insert_program(5, prog(bm1)); 39 | store_programs_to_file(file_name, d, optimals); 40 | read_data_from_file(file_name, "Programs:"); 41 | 42 | examples exs; 43 | inout_t input, output; 44 | input.init(); 45 | output.init(); 46 | input.reg = 5; 47 | output.reg = 10; 48 | exs.insert(inout(input, output)); 49 | d.insert_examples(0, exs); 50 | input.reg = 3; 51 | output.reg = 6; 52 | d.insert_examples(1, inout(input, output)); 53 | input.reg = 4; 54 | output.reg = 8; 55 | d.insert_examples(2, inout(input, output)); 56 | store_examples_to_file(file_name, d); 57 | read_data_from_file(file_name, "Examples:"); 58 | 59 | remove("measure/test.txt"); 60 | } 61 | 62 | void test2() { 63 | cout << "test combination" << endl; 64 | print_test_res(combination(10, 2) == 45, "1"); 65 | print_test_res(combination(10, 5) == 252, "2"); 66 | print_test_res(combination(20, 19) == 20, "3"); 67 | print_test_res(combination(50, 20) == 47129212243960, "4"); 68 | print_test_res(combination(50, 25) == 126410606437752, "5"); 69 | print_test_res(combination(56, 28) == 7648690600760440, "6"); 70 | } 71 | 72 | int main() { 73 | test1(); 74 | test2(); 75 | return 0; 76 | } 77 | -------------------------------------------------------------------------------- /measure/meas_time.cc: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "../src/utils.h" 6 | #include "../src/inout.h" 7 | #include "../src/isa/toy-isa/inst.h" 8 | #include "../src/isa/prog.h" 9 | #include "../src/verify/smt_prog.h" 10 | #include "../src/verify/validator.h" 11 | #include "../src/search/mh_prog.h" 12 | #include "benchmark_toy_isa.h" 13 | #include "z3++.h" 14 | 15 | using namespace std; 16 | 17 | #define measure_print(print, loop_times, t1, t2) \ 18 | cout << print << DUR(t1, t2) / loop_times << " us" << endl; 19 | 20 | #define time_measure(func_called, times, print) \ 21 | int loop_times = times; \ 22 | auto start = NOW; \ 23 | for (int i = 0; i < loop_times; i++) { \ 24 | func_called; \ 25 | } \ 26 | auto end = NOW; \ 27 | measure_print(print, times, start, end); 28 | 29 | void time_smt_prog() { 30 | smt_prog ps; 31 | time_measure(ps.gen_smt(i, bm0, inst::max_prog_len), 1000, 32 | "smt prog::gen_smt: "); 33 | } 34 | 35 | void time_validator_set_orig() { 36 | validator vld; 37 | time_measure(vld.set_orig(bm0, inst::max_prog_len), 1000, 38 | "validator::set_orig: "); 39 | } 40 | 41 | void time_validator_is_equal_to() { 42 | validator vld; 43 | vld.set_orig(bm0, inst::max_prog_len); 44 | time_measure(vld.is_equal_to(bm0, inst::max_prog_len, bm0, inst::max_prog_len), 100, 45 | "validator::is_equal_to: "); 46 | } 47 | 48 | void time_validator_is_smt_valid() { 49 | validator vld; 50 | vld.is_equal_to(bm0, inst::max_prog_len, bm0, inst::max_prog_len); 51 | z3::expr smt = vld._store_f; 52 | z3::model mdl(smt_c); 53 | time_measure(vld.is_smt_valid(smt, mdl), 100, 54 | "validator::is_smt_valid: "); 55 | } 56 | 57 | void time_validator_get_orig_output() { 58 | validator vld; 59 | vld.set_orig(bm0, inst::max_prog_len); 60 | time_measure(vld.get_orig_output(i, NUM_REGS, bm0->get_input_reg()), 100, 61 | "validator::get_orig_output: "); 62 | } 63 | 64 | void time_interpret() { 65 | 
prog_state ps; 66 | ps.init(); 67 | inout_t in, out; 68 | in.init(); 69 | out.init(); 70 | prog p(bm0); 71 | time_measure(p.interpret(out, ps, in), 10000, 72 | "interpret: "); 73 | } 74 | 75 | void time_cost_init() { 76 | double w_e = 1.0; 77 | double w_p = 0.0; 78 | vector input_regs = {10, 16, 11, 48, 1}; 79 | vector inputs(5); 80 | for (int i = 0; i < inputs.size(); i++) { 81 | inputs[i].init(); 82 | inputs[i].reg = input_regs[i]; 83 | } 84 | cost c; 85 | prog orig(bm0); 86 | time_measure(c.init(&orig, N, inputs, w_e, w_p), 100, 87 | "cost::init: "); 88 | } 89 | 90 | void time_cost_error_cost() { 91 | double w_e = 1.0; 92 | double w_p = 0.0; 93 | vector input_regs = {10, 16, 11, 48, 1}; 94 | vector inputs(5); 95 | for (int i = 0; i < inputs.size(); i++) { 96 | inputs[i].init(); 97 | inputs[i].reg = input_regs[i]; 98 | } 99 | cost c; 100 | prog orig(bm0); 101 | c.init(&orig, N, inputs, w_e, w_p); 102 | time_measure(c.error_cost(&orig, inst::max_prog_len, &orig, inst::max_prog_len); 103 | orig._error_cost = -1; 104 | orig._perf_cost = -1, 105 | 200, 106 | "cost::error_cost: " 107 | ); 108 | } 109 | 110 | void time_cost_perf_cost() { 111 | double w_e = 1.0; 112 | double w_p = 0.0; 113 | vector input_regs = {10, 16, 11, 48, 1}; 114 | vector inputs(5); 115 | for (int i = 0; i < inputs.size(); i++) { 116 | inputs[i].init(); 117 | inputs[i].reg = input_regs[i]; 118 | } 119 | cost c; 120 | prog orig(bm0); 121 | c.init(&orig, N, inputs, w_e, w_p); 122 | time_measure(c.perf_cost(&orig, inst::max_prog_len), 1000, 123 | "cost::perf_cost: "); 124 | } 125 | 126 | void time_mh_sampler() { 127 | int loop_times = 50; 128 | auto start = NOW; 129 | for (int i = 0; i < loop_times; i++) { 130 | int nrolls = 1000; 131 | double w_e = 0.45; 132 | double w_p = 1.55; 133 | vector inputs(30); 134 | for (int i = 0; i < inputs.size(); i++) { 135 | inputs[i].init(); 136 | } 137 | gen_random_input(inputs, 0, 50); 138 | mh_sampler mh; 139 | unordered_map > prog_freq; 140 | prog orig(bm0); 141 
| mh._cost.init(&orig, N, inputs, w_e, w_p); 142 | mh.mcmc_iter(nrolls, orig, prog_freq); 143 | } 144 | auto end = NOW; 145 | measure_print("mh_sampler: ", loop_times, start, end); 146 | } 147 | 148 | int main() { 149 | time_smt_prog(); 150 | time_validator_set_orig(); 151 | time_validator_is_equal_to(); 152 | time_validator_is_smt_valid(); 153 | // time_validator_get_orig_output(); 154 | time_interpret(); 155 | time_cost_init(); 156 | time_cost_error_cost(); 157 | time_cost_perf_cost(); 158 | time_mh_sampler(); 159 | return 0; 160 | } 161 | -------------------------------------------------------------------------------- /src/inout.cc: -------------------------------------------------------------------------------- 1 | #include "inout.h" 2 | 3 | using namespace std; 4 | 5 | inout::inout() { 6 | input.init(); 7 | output.init(); 8 | } 9 | 10 | inout::inout(const inout_t& in, const inout_t& out) { 11 | input.init(); 12 | output.init(); 13 | input = in; 14 | output = out; 15 | } 16 | 17 | void inout::set_in_out(const inout_t& _input, const inout_t& _output) { 18 | input.init(); 19 | output.init(); 20 | input = _input; 21 | output = _output; 22 | } 23 | 24 | void inout::operator=(const inout &rhs) { 25 | input = rhs.input; 26 | output = rhs.output; 27 | } 28 | 29 | void inout::clear() { 30 | input.clear(); 31 | output.clear(); 32 | } 33 | 34 | ostream& operator<< (ostream& out, const inout &_inout) { 35 | out << "input:" << _inout.input << " output:" << _inout.output; 36 | return out; 37 | } 38 | 39 | ostream& operator<< (ostream& out, const vector &_inout_vec) { 40 | for (size_t i = 0; i < _inout_vec.size(); i++) { 41 | out << _inout_vec[i] << " "; 42 | } 43 | return out; 44 | } 45 | 46 | examples::examples() {} 47 | 48 | examples::~examples() {} 49 | /* Append a copy of ex to _exs. The copy is built in a local before push_back because ex may refer to an element of _exs itself, and push_back can reallocate the vector and invalidate that reference. */ 50 | void examples::insert(const inout& ex) { 51 | inout io; 52 | io.set_in_out(ex.input, ex.output); 53 | _exs.push_back(io); 54 | } 55 | 56 | void examples::clear() { 57 | _exs.clear(); 58 | } 59 | 
-------------------------------------------------------------------------------- /src/inout.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "utils.h" 6 | #include "../src/isa/inst_header.h" 7 | 8 | using namespace std; 9 | 10 | /* A class representing one input-output example. Currently, very simple and 11 | assumes a single integer input and a single integer output. */ 12 | class inout { 13 | public: 14 | inout_t input; 15 | inout_t output; 16 | inout(); 17 | inout(const inout_t& in, const inout_t& out); 18 | void set_in_out(const inout_t& in, const inout_t& out); 19 | void clear(); 20 | void operator=(const inout &rhs); 21 | friend ostream& operator<< (ostream& out, const inout &_inout); 22 | friend ostream& operator<< (ostream& out, const vector &_inout_vec); 23 | }; 24 | 25 | /* Class examples is a set of inouts with different input values. */ 26 | class examples { 27 | public: 28 | vector _exs; 29 | examples(); 30 | ~examples(); 31 | void insert(const inout& ex); 32 | unsigned int size() {return _exs.size();} 33 | void clear(); 34 | }; 35 | -------------------------------------------------------------------------------- /src/inout_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "utils.h" 4 | #include "inout.h" 5 | 6 | using namespace std; 7 | 8 | void test1() { 9 | cout << "test 1 starts...\n"; 10 | inout_t in, out; 11 | in.init(); 12 | out.init(); 13 | in.reg = 1; 14 | out.reg = 15; 15 | inout io; 16 | io.set_in_out(in, out); 17 | examples ex_set; 18 | ex_set.insert(io); 19 | bool assert_res = true; 20 | if (ex_set._exs.size() == 1) { 21 | assert_res = (ex_set._exs[0].input.reg == 1) && \ 22 | (ex_set._exs[0].output.reg == 15); 23 | } else { 24 | assert_res = false; 25 | } 26 | print_test_res(assert_res, "examples::insert nonexistent value"); 27 | ex_set.insert(ex_set._exs[0]); 28 | if 
(ex_set._exs.size() == 1) { 29 | assert_res = (ex_set._exs[0].input.reg == 1) && \ 30 | (ex_set._exs[0].output.reg == 15); 31 | } else { 32 | assert_res = false; 33 | } 34 | print_test_res(!assert_res, "examples::insert existent value"); 35 | } 36 | 37 | int main() { 38 | test1(); 39 | return 0; 40 | } 41 | -------------------------------------------------------------------------------- /src/isa/ebpf/bpf.h: -------------------------------------------------------------------------------- 1 | /* copy from linux/bpf.h, linux/bpf_common.h 2 | * commit: 9e6c535c64adf6155e4a11fe8d63b384fe3452f8 3 | */ 4 | 5 | #pragma once 6 | 7 | /* Instruction classes */ 8 | #define BPF_CLASS(code) ((code) & 0x07) 9 | #define BPF_LD 0x00 10 | #define BPF_LDX 0x01 11 | #define BPF_ST 0x02 12 | #define BPF_STX 0x03 13 | #define BPF_ALU 0x04 14 | #define BPF_JMP 0x05 15 | #define BPF_RET 0x06 16 | #define BPF_MISC 0x07 17 | 18 | /* ld/ldx fields */ 19 | #define BPF_SIZE(code) ((code) & 0x18) 20 | #define BPF_W 0x00 /* 32-bit */ 21 | #define BPF_H 0x08 /* 16-bit */ 22 | #define BPF_B 0x10 /* 8-bit */ 23 | /* eBPF BPF_DW 0x18 64-bit */ 24 | #define BPF_MODE(code) ((code) & 0xe0) 25 | #define BPF_IMM 0x00 26 | #define BPF_ABS 0x20 27 | #define BPF_IND 0x40 28 | #define BPF_MEM 0x60 29 | #define BPF_LEN 0x80 30 | #define BPF_MSH 0xa0 31 | 32 | /* alu/jmp fields */ 33 | #define BPF_OP(code) ((code) & 0xf0) 34 | #define BPF_ADD 0x00 35 | #define BPF_SUB 0x10 36 | #define BPF_MUL 0x20 37 | #define BPF_DIV 0x30 38 | #define BPF_OR 0x40 39 | #define BPF_AND 0x50 40 | #define BPF_LSH 0x60 41 | #define BPF_RSH 0x70 42 | #define BPF_NEG 0x80 43 | #define BPF_MOD 0x90 44 | #define BPF_XOR 0xa0 45 | 46 | #define BPF_JA 0x00 47 | #define BPF_JEQ 0x10 48 | #define BPF_JGT 0x20 49 | #define BPF_JGE 0x30 50 | #define BPF_JSET 0x40 51 | #define BPF_SRC(code) ((code) & 0x08) 52 | #define BPF_K 0x00 53 | #define BPF_X 0x08 54 | 55 | /* instruction classes */ 56 | #define BPF_JMP32 0x06 /* jmp mode in word 
width */ 57 | #define BPF_ALU64 0x07 /* alu mode in double word width */ 58 | 59 | /* ld/ldx fields */ 60 | #define BPF_DW 0x18 /* double word (64-bit) */ 61 | #define BPF_XADD 0xc0 /* exclusive add */ 62 | 63 | /* alu/jmp fields */ 64 | #define BPF_MOV 0xb0 /* mov reg to reg */ 65 | #define BPF_ARSH 0xc0 /* sign extending arithmetic shift right */ 66 | 67 | /* change endianness of a register */ 68 | #define BPF_END 0xd0 /* flags for endianness conversion: */ 69 | #define BPF_TO_LE 0x00 /* convert to little-endian */ 70 | #define BPF_TO_BE 0x08 /* convert to big-endian */ 71 | #define BPF_FROM_LE BPF_TO_LE 72 | #define BPF_FROM_BE BPF_TO_BE 73 | 74 | /* jmp encodings */ 75 | #define BPF_JNE 0x50 /* jump != */ 76 | #define BPF_JLT 0xa0 /* LT is unsigned, '<' */ 77 | #define BPF_JLE 0xb0 /* LE is unsigned, '<=' */ 78 | #define BPF_JSGT 0x60 /* SGT is signed '>', GT in x86 */ 79 | #define BPF_JSGE 0x70 /* SGE is signed '>=', GE in x86 */ 80 | #define BPF_JSLT 0xc0 /* SLT is signed, '<' */ 81 | #define BPF_JSLE 0xd0 /* SLE is signed, '<=' */ 82 | #define BPF_CALL 0x80 /* function call */ 83 | #define BPF_EXIT 0x90 /* function return */ 84 | 85 | /* Register numbers */ 86 | enum { 87 | BPF_REG_0 = 0, 88 | BPF_REG_1, 89 | BPF_REG_2, 90 | BPF_REG_3, 91 | BPF_REG_4, 92 | BPF_REG_5, 93 | BPF_REG_6, 94 | BPF_REG_7, 95 | BPF_REG_8, 96 | BPF_REG_9, 97 | BPF_REG_10, 98 | __MAX_BPF_REG, 99 | }; 100 | 101 | #define __BPF_FUNC_MAPPER(FN) \ 102 | FN(unspec), \ 103 | FN(map_lookup_elem), \ 104 | FN(map_update_elem), \ 105 | FN(map_delete_elem), \ 106 | FN(probe_read), \ 107 | FN(ktime_get_ns), \ 108 | FN(trace_printk), \ 109 | FN(get_prandom_u32), \ 110 | FN(get_smp_processor_id), \ 111 | FN(skb_store_bytes), \ 112 | FN(l3_csum_replace), \ 113 | FN(l4_csum_replace), \ 114 | FN(tail_call), \ 115 | FN(clone_redirect), \ 116 | FN(get_current_pid_tgid), \ 117 | FN(get_current_uid_gid), \ 118 | FN(get_current_comm), \ 119 | FN(get_cgroup_classid), \ 120 | FN(skb_vlan_push), \ 121 | 
FN(skb_vlan_pop), \ 122 | FN(skb_get_tunnel_key), \ 123 | FN(skb_set_tunnel_key), \ 124 | FN(perf_event_read), \ 125 | FN(redirect), \ 126 | FN(get_route_realm), \ 127 | FN(perf_event_output), \ 128 | FN(skb_load_bytes), \ 129 | FN(get_stackid), \ 130 | FN(csum_diff), \ 131 | FN(skb_get_tunnel_opt), \ 132 | FN(skb_set_tunnel_opt), \ 133 | FN(skb_change_proto), \ 134 | FN(skb_change_type), \ 135 | FN(skb_under_cgroup), \ 136 | FN(get_hash_recalc), \ 137 | FN(get_current_task), \ 138 | FN(probe_write_user), \ 139 | FN(current_task_under_cgroup), \ 140 | FN(skb_change_tail), \ 141 | FN(skb_pull_data), \ 142 | FN(csum_update), \ 143 | FN(set_hash_invalid), \ 144 | FN(get_numa_node_id), \ 145 | FN(skb_change_head), \ 146 | FN(xdp_adjust_head), \ 147 | FN(probe_read_str), \ 148 | FN(get_socket_cookie), \ 149 | FN(get_socket_uid), \ 150 | FN(set_hash), \ 151 | FN(setsockopt), \ 152 | FN(skb_adjust_room), \ 153 | FN(redirect_map), \ 154 | FN(sk_redirect_map), \ 155 | FN(sock_map_update), \ 156 | FN(xdp_adjust_meta), \ 157 | FN(perf_event_read_value), \ 158 | FN(perf_prog_read_value), \ 159 | FN(getsockopt), \ 160 | FN(override_return), \ 161 | FN(sock_ops_cb_flags_set), \ 162 | FN(msg_redirect_map), \ 163 | FN(msg_apply_bytes), \ 164 | FN(msg_cork_bytes), \ 165 | FN(msg_pull_data), \ 166 | FN(bind), \ 167 | FN(xdp_adjust_tail), \ 168 | FN(skb_get_xfrm_state), \ 169 | FN(get_stack), \ 170 | FN(skb_load_bytes_relative), \ 171 | FN(fib_lookup), \ 172 | FN(sock_hash_update), \ 173 | FN(msg_redirect_hash), \ 174 | FN(sk_redirect_hash), \ 175 | FN(lwt_push_encap), \ 176 | FN(lwt_seg6_store_bytes), \ 177 | FN(lwt_seg6_adjust_srh), \ 178 | FN(lwt_seg6_action), \ 179 | FN(rc_repeat), \ 180 | FN(rc_keydown), \ 181 | FN(skb_cgroup_id), \ 182 | FN(get_current_cgroup_id), \ 183 | FN(get_local_storage), \ 184 | FN(sk_select_reuseport), \ 185 | FN(skb_ancestor_cgroup_id), \ 186 | FN(sk_lookup_tcp), \ 187 | FN(sk_lookup_udp), \ 188 | FN(sk_release), \ 189 | FN(map_push_elem), \ 190 | 
FN(map_pop_elem), \ 191 | FN(map_peek_elem), \ 192 | FN(msg_push_data), \ 193 | FN(msg_pop_data), \ 194 | FN(rc_pointer_rel), \ 195 | FN(spin_lock), \ 196 | FN(spin_unlock), \ 197 | FN(sk_fullsock), \ 198 | FN(tcp_sock), \ 199 | FN(skb_ecn_set_ce), \ 200 | FN(get_listener_sock), \ 201 | FN(skc_lookup_tcp), \ 202 | FN(tcp_check_syncookie), \ 203 | FN(sysctl_get_name), \ 204 | FN(sysctl_get_current_value), \ 205 | FN(sysctl_get_new_value), \ 206 | FN(sysctl_set_new_value), \ 207 | FN(strtol), \ 208 | FN(strtoul), \ 209 | FN(sk_storage_get), \ 210 | FN(sk_storage_delete), \ 211 | FN(send_signal), \ 212 | FN(tcp_gen_syncookie), \ 213 | FN(skb_output), \ 214 | FN(probe_read_user), \ 215 | FN(probe_read_kernel), \ 216 | FN(probe_read_user_str), \ 217 | FN(probe_read_kernel_str), \ 218 | FN(tcp_send_ack), \ 219 | FN(send_signal_thread), \ 220 | FN(jiffies64), 221 | 222 | /* integer value in 'imm' field of BPF_CALL instruction selects which helper 223 | * function eBPF program intends to call 224 | */ 225 | #define __BPF_ENUM_FN(x) BPF_FUNC_ ## x 226 | enum bpf_func_id { 227 | __BPF_FUNC_MAPPER(__BPF_ENUM_FN) 228 | __BPF_FUNC_MAX_ID, 229 | }; 230 | #undef __BPF_ENUM_FN 231 | 232 | enum bpf_map_type { 233 | BPF_MAP_TYPE_UNSPEC, 234 | BPF_MAP_TYPE_HASH, 235 | BPF_MAP_TYPE_ARRAY, 236 | BPF_MAP_TYPE_PROG_ARRAY, 237 | BPF_MAP_TYPE_PERF_EVENT_ARRAY, 238 | BPF_MAP_TYPE_PERCPU_HASH, 239 | BPF_MAP_TYPE_PERCPU_ARRAY, 240 | BPF_MAP_TYPE_STACK_TRACE, 241 | BPF_MAP_TYPE_CGROUP_ARRAY, 242 | BPF_MAP_TYPE_LRU_HASH, 243 | BPF_MAP_TYPE_LRU_PERCPU_HASH, 244 | BPF_MAP_TYPE_LPM_TRIE, 245 | BPF_MAP_TYPE_ARRAY_OF_MAPS, 246 | BPF_MAP_TYPE_HASH_OF_MAPS, 247 | BPF_MAP_TYPE_DEVMAP, 248 | BPF_MAP_TYPE_SOCKMAP, 249 | BPF_MAP_TYPE_CPUMAP, 250 | BPF_MAP_TYPE_XSKMAP, 251 | BPF_MAP_TYPE_SOCKHASH, 252 | BPF_MAP_TYPE_CGROUP_STORAGE, 253 | BPF_MAP_TYPE_REUSEPORT_SOCKARRAY, 254 | BPF_MAP_TYPE_PERCPU_CGROUP_STORAGE, 255 | BPF_MAP_TYPE_QUEUE, 256 | BPF_MAP_TYPE_STACK, 257 | BPF_MAP_TYPE_SK_STORAGE, 258 | 
BPF_MAP_TYPE_DEVMAP_HASH, 259 | BPF_MAP_TYPE_STRUCT_OPS, 260 | }; 261 | -------------------------------------------------------------------------------- /src/isa/ebpf/canonicalize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "../../../src/verify/cfg.h" 5 | #include "inst.h" 6 | 7 | using namespace std; 8 | 9 | void remove_nops(inst* program, int len); 10 | 11 | void canonicalize(inst* program, int len); 12 | void set_nops_as_JA0(inst* program, int len); 13 | /* Static information tracked for one instruction: possible register states, live variables and the minimum packet size. */ 14 | class inst_static_state { 15 | public: 16 | vector> reg_state; // all possible states of registers 17 | live_variables live_var; 18 | unsigned int min_pkt_sz; // minimum pkt size before executing the insn. 19 | 20 | inst_static_state(); 21 | void copy_reg_state(int dst_reg, int src_reg); 22 | void set_reg_state(int reg, int type, int off = 0); 23 | void set_reg_state(int reg, register_state rs); 24 | void insert_reg_state(inst_static_state& iss); 25 | void insert_live_reg(int reg); 26 | void insert_live_off(int type, int off); 27 | void insert_live_var(inst_static_state& iss); 28 | static void intersection_live_var(inst_static_state& iss, inst_static_state& iss1, inst_static_state& iss2); 29 | inst_static_state& operator=(const inst_static_state &rhs); 30 | friend ostream& operator<<(ostream& out, const inst_static_state& x); 31 | }; 32 | /* Static-analysis results for a whole program, passed to static_analysis(); static_state is indexed by instruction. */ 33 | class prog_static_state { 34 | public: 35 | vector static_state; 36 | vector block_static_state; 37 | graph g; 38 | vector dag; 39 | void clear() {static_state.clear(); g.clear(); dag.clear();}; /* NOTE(review): block_static_state is not cleared here -- confirm this is intentional */ 40 | }; 41 | 42 | void static_analysis(prog_static_state& pss, inst* program, int len); 43 | void set_up_smt_inout_orig(prog_static_state& pss, inst* program, int len, int win_start, int win_end); 44 | void set_up_smt_inout_win(smt_input& sin, smt_output& sout, prog_static_state& pss_orig, inst* program, int win_start, int win_end); 45 | // todo: move random input related functions to other 
place 46 | void gen_random_input(vector& inputs, int n, int64_t reg_min, int64_t reg_max); 47 | void gen_random_input_for_win(vector& inputs, int n, inst_static_state& iss, inst& insn, int win_start, int win_end); 48 | void static_safety_check_pgm(inst* program, int len); 49 | void static_safety_check_win(inst* win_prog, int win_start, int win_end, prog_static_state& pss_orig); 50 | 51 | // for unit tests 52 | void type_const_inference_pgm(prog_static_state& pss, inst* program, int len); 53 | -------------------------------------------------------------------------------- /src/isa/ebpf/inst.runtime: -------------------------------------------------------------------------------- 1 | /* make sure there is one space between `opcode` and `runtime`; machine: d6515 (ubuntu20) */ 2 | ADD32XC 0.366 3 | SUB32XC 0.356 4 | MUL32XC 1.025 5 | DIV32XC 4.776 6 | AND32XC 0.363 7 | OR32XC 0.37 8 | LSH32XC 0.329 9 | RSH32XC 0.365 10 | XOR32XC 0.36 11 | MOD32XC 5.128 12 | ARSH32XC 0.351 13 | NEG32XC 0.375 14 | ADD32XY 0.327 15 | SUB32XY 0.333 16 | MUL32XY 1.021 17 | DIV32XY 5.152 18 | AND32XY 0.33 19 | OR32XY 0.359 20 | LSH32XY 0.368 21 | RSH32XY 0.368 22 | XOR32XY 0.366 23 | MOD32XY 5.22 24 | ARSH32XY 0.369 25 | ADD64XC 0.369 26 | SUB64XC 0.369 27 | MUL64XC 1.135 28 | DIV64XC 5.182 29 | AND64XC 0.375 30 | OR64XC 0.37 31 | LSH64XC 0.384 32 | RSH64XC 0.38 33 | XOR64XC 0.37 34 | MOD64XC 5.239 35 | ARSH64XC 0.379 36 | NEG64XC 0.369 37 | ADD64XY 0.369 38 | SUB64XY 0.371 39 | MUL64XY 1.135 40 | DIV64XY 5.314 41 | AND64XY 0.375 42 | OR64XY 0.37 43 | LSH64XY 0.384 44 | RSH64XY 0.381 45 | XOR64XY 0.38 46 | MOD64XY 5.239 47 | ARSH64XY 0.379 48 | MOV64XC 0.095 49 | MOV64XY 0.079 50 | MOV32XY 0.085 51 | MOV32XC 0.096 52 | JEQXC 1.868 53 | JEQ32XC 1.882 54 | JNEXC 1.864 55 | JNEQ32XC 0.181 56 | JGTXC 0.181 57 | JSGTXC 1.849 58 | JGEXC 0.183 59 | JLTXC 1.893 60 | JLEXC 1.877 61 | JA 0 62 | JSETXC 0.212 63 | JEQXY 1.875 64 | JEQ32XY 1.832 65 | JNEXY 1.896 66 | JNEQ32XY 0.183 67 | JGTXY 0.191 68 | 
JSGTXY 1.834 69 | JGEXY 0.181 70 | JLTXY 1.887 71 | JLEXY 1.875 72 | JSETXY 0.201 73 | STXB 0.382 74 | STXH 0.369 75 | STXW 0.376 76 | STXDW 0.381 77 | STB 0.381 78 | STH 0.381 79 | STW 0.372 80 | STDW 0.376 81 | XADD64 6.616 82 | XADD32 5.514 83 | LDXB 0.179 84 | LDXH 0.181 85 | LDXW 0.18 86 | LDXDW 0.182 87 | LDXC 0.12 88 | LDDW 0.363 89 | LDABSB 4.81 90 | LDABSH 5.022 91 | LDABSW 4.836 92 | BPF_FUNC_map_lookup_elem 1.518 93 | BPF_FUNC_map_update_elem 19.03 94 | BPF_FUNC_map_delete_elem 2.12 95 | BPF_FUNC_get_prandom_u32 5.279 96 | BPF_FUNC_tail_call 1.85 97 | BE16 0.187 98 | BE32 0.177 99 | BE64 0.114 100 | LE16 0.284 101 | LE32 0.197 102 | LE64 0.189 -------------------------------------------------------------------------------- /src/isa/ebpf/inst_cyclops.runtime: -------------------------------------------------------------------------------- 1 | /* make sure there is one space between `opcode` and `runtime`; machine: cyclops */ 2 | ADD32XC 0.7 3 | SUB32XC 0.7 4 | MUL32XC 4.7 5 | DIV32XC 24.7 6 | AND32XC 0.7 7 | OR32XC 0.7 8 | LSH32XC 0.7 9 | RSH32XC 0.7 10 | XOR32XC 0.7 11 | MOD32XC 23.7 12 | NEG32XC 0.7 13 | ADD32XY 0.7 14 | SUB32XY 0.7 15 | MUL32XY 4.5 16 | DIV32XY 25.8 17 | AND32XY 0.7 18 | OR32XY 0.7 19 | LSH32XY 4 20 | RSH32XY 4.1 21 | XOR32XY 0.7 22 | MOD32XY 24.4 23 | ADD64XC 0.7 24 | SUB64XC 0.7 25 | MUL64XC 4.4 26 | DIV64XC 33.3 27 | AND64XC 0.7 28 | OR64XC 0.7 29 | LSH64XC 0.7 30 | RSH64XC 0.7 31 | XOR64XC 0.7 32 | MOD64XC 30.6 33 | ARSH64XC 0.7 34 | NEG64XC 0.7 35 | ADD64XY 0.7 36 | SUB64XY 0.7 37 | MUL64XY 4.8 38 | DIV64XY 34.7 39 | AND64XY 0.7 40 | OR64XY 0.7 41 | LSH64XY 4.1 42 | RSH64XY 4 43 | XOR64XY 0.7 44 | MOD64XY 31.5 45 | ARSH64XY 4 46 | MOV64XC 0.3 47 | MOV64XY 0.3 48 | MOV32XY 0.3 49 | JEQXC 2 50 | JNEXC 2.1 51 | JGTXC 0.7 52 | JGEXC 0.6 53 | JLTXC 1.2 54 | JLEXC 1.1 55 | JA 0.1 56 | JSETXC 0.6 57 | JEQXY 2.2 58 | JNEXY 2 59 | JGTXY 0.6 60 | JGEXY 0.6 61 | JLTXY 1.3 62 | JLEXY 1.1 63 | JSETXY 0.6 64 | STXB 0.8 65 | STXH 0.9 66 | STXW 
0.9 67 | STXDW 0.9 68 | STB 0.9 69 | STH 3.3 70 | STW 0.9 71 | STDW 0.9 72 | XADD64 5.7 73 | LDXB 0.5 74 | LDXH 0.5 75 | LDXW 0.8 76 | LDXDW 0.8 77 | LDDW 0.6 78 | LDABSB 5.6 79 | LDABSH 5 80 | LDABSW 4.7 81 | BPF_FUNC_map_lookup_elem 2.7 82 | BPF_FUNC_map_update_elem 75.7 83 | BPF_FUNC_map_delete_elem 44.3 84 | MOV32XC 0.2 85 | ARSH32XC 0.8 86 | ARSH32XY 3.6 87 | BPF_FUNC_get_prandom_u32 20.1 88 | BPF_FUNC_tail_call 2.3 89 | XADD32 17.8 90 | BE16 0.5 91 | BE32 0.5 92 | BE64 0.4 93 | LE16 0.8 94 | LE32 1 95 | LE64 1 96 | JSGTXY 2.4 97 | JSGTXC 2 98 | JEQ32XY 0.9 99 | JEQ32XC 0.9 100 | JNEQ32XY 0.8 101 | JNEQ32XC 0.9 -------------------------------------------------------------------------------- /src/isa/ebpf/win_select.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "win_select.h" 3 | 4 | int num_unimplemented = 0; 5 | unordered_set unimplemented_opcodes = {}; 6 | int num_call = 0; 7 | int num_multi_values = 0; 8 | int num_symbolic = 0; 9 | 10 | /* If the insn does not satisfy ISA window constraints, return false and 11 | this insn won't be selected in the windows. 12 | */ 13 | bool insn_satisfy_isa_win_constraints(const inst& insn, const inst_static_state& iss) { 14 | /* opcode is implemented */ 15 | if (! inst::is_valid_opcode(insn._opcode)) { 16 | num_unimplemented++; 17 | unimplemented_opcodes.insert(insn._opcode); 18 | return false; 19 | } 20 | 21 | /* not a function call 22 | (1. some helpers are not supported; 23 | 2. can hardly improve function call; 24 | 3. map update flag limitation: 25 | https://github.com/smartnic/superopt/commit/d1c11c5ebfbeda68fd4e9cbc14066bca8c0316f1) 26 | */ 27 | if (insn._opcode == CALL) { 28 | num_call++; 29 | return false; 30 | } 31 | 32 | /* memory access: 33 | 1. addr register is a pointer (pointer from addxy is not tracked in static analysis) 34 | 2. 
memory access offset is a constant (safety: not able to check out of bound) 35 | */ 36 | vector mem_acc_regs; 37 | insn.mem_access_regs(mem_acc_regs); 38 | if (mem_acc_regs.size() > 0) { 39 | for (int i = 0; i < mem_acc_regs.size(); i++) { 40 | int reg = mem_acc_regs[i]; 41 | const vector& reg_state = iss.reg_state[reg]; 42 | if (reg_state.size() != 1) { // constant offset check 43 | num_multi_values++; 44 | return false; 45 | } 46 | int reg_type = reg_state[0].type; 47 | if (! is_ptr(reg_type)) { // addr reg is a ptr check 48 | num_symbolic++; 49 | return false; 50 | } 51 | } 52 | } 53 | 54 | return true; 55 | } 56 | 57 | void reset_isa_win_constraints_statistics() { 58 | num_unimplemented = 0; 59 | unimplemented_opcodes = {}; 60 | num_call = 0; 61 | num_multi_values = 0; 62 | num_symbolic = 0; 63 | } 64 | 65 | void print_isa_win_constraints_statistics() { 66 | cout << "# unimplemented opcodes: " << num_unimplemented << " "; 67 | if (num_unimplemented != 0) { 68 | cout << "# opcodes:" << unimplemented_opcodes.size() << " "; 69 | cout << "opcodes:"; 70 | for (auto op : unimplemented_opcodes) cout << hex << "0x" << op << dec << " "; 71 | } 72 | cout << endl; 73 | cout << "# helper calls: " << num_call << endl; 74 | cout << "# mem_acc multi_values: " << num_multi_values << endl; 75 | cout << "# mem_acc num_symbolic: " << num_symbolic << endl; 76 | } 77 | -------------------------------------------------------------------------------- /src/isa/ebpf/win_select.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "inst.h" 4 | #include "canonicalize.h" 5 | 6 | using namespace std; 7 | 8 | bool insn_satisfy_isa_win_constraints(const inst& insn, const inst_static_state& iss); 9 | void reset_isa_win_constraints_statistics(); 10 | void print_isa_win_constraints_statistics(); 11 | -------------------------------------------------------------------------------- /src/isa/ebpf/win_select_test.cc: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include "win_select.h" 3 | 4 | using namespace std; 5 | /* Run static analysis on pgm and check insn_satisfy_isa_win_constraints() for every instruction: instructions whose index is in false_insns_exp are expected to return false, all others true. Prints the test result and, on failure, the ids of the mismatching instructions. */ 6 | void test_satisfy_constraints(inst* pgm, int len, 7 | unordered_set& false_insns_exp, 8 | string test_name) { 9 | prog_static_state pss; 10 | static_analysis(pss, pgm, len); 11 | bool test_res; 12 | vector failed_insns; 13 | for (int i = 0; i < len; i++) { 14 | bool res = insn_satisfy_isa_win_constraints(pgm[i], pss.static_state[i]); 15 | bool res_exp = (false_insns_exp.find(i) == false_insns_exp.end()); 16 | if (res != res_exp) failed_insns.push_back(i); 17 | } 18 | if (failed_insns.size() == 0) { 19 | print_test_res(true, test_name); 20 | } else { 21 | print_test_res(false, test_name); 22 | cout << "failed insn ids: "; 23 | for (int i = 0; i < failed_insns.size(); i++) { 24 | cout << failed_insns[i] << " "; /* print the recorded instruction id, not the loop index */ 25 | } 26 | cout << endl; 27 | } 28 | } 29 | 30 | void test1() { 31 | cout << "Test1: test insn_satisfy_isa_win_constraints" << endl; 32 | cout << "1. test function call" << endl; 33 | mem_t::_layout.clear(); 34 | mem_t::set_pgm_input_type(PGM_INPUT_pkt); 35 | mem_t::set_pkt_sz(32); 36 | mem_t::add_map(map_attr(16, 32, 16)); 37 | inst p1_1[] = {inst(STB, 10, -1, 0x1), 38 | INSN_LDMAPID(1, 0), 39 | inst(MOV64XY, 2, 10), 40 | inst(ADD64XC, 2, -1), 41 | inst(CALL, BPF_FUNC_map_lookup_elem), 42 | inst(JEQXC, 0, 0, 2), 43 | inst(LDXB, 0, 0, 0), 44 | inst(EXIT), 45 | inst(MOV64XC, 0, 0), 46 | inst(EXIT), 47 | }; 48 | const int len_p1_1 = sizeof(p1_1) / sizeof(inst); 49 | unordered_set false_insns_exp = {4}; 50 | test_satisfy_constraints(p1_1, len_p1_1, false_insns_exp, "1"); 51 | 52 | cout << "2. 
test memory access for PGM_INPUT_pkt" << endl; 53 | mem_t::_layout.clear(); 54 | mem_t::set_pgm_input_type(PGM_INPUT_pkt); 55 | mem_t::set_pkt_sz(32); 56 | inst p2_1[] = {inst(LDXB, 5, 1, 0), // r5 = *r1 57 | inst(STXB, 1, 2, 5), 58 | inst(ADD64XY, 1, 1), 59 | inst(STXB, 1, 1, 5), 60 | inst(MOV64XC, 0, 0), 61 | inst(EXIT), 62 | }; 63 | const int len_p2_1 = sizeof(p2_1) / sizeof(inst); 64 | false_insns_exp = {3}; 65 | test_satisfy_constraints(p2_1, len_p2_1, false_insns_exp, "1"); 66 | 67 | cout << "3. test symbolic memory access for PGM_INPUT_pkt_ptrs" << endl; 68 | mem_t::_layout.clear(); 69 | mem_t::set_pgm_input_type(PGM_INPUT_pkt_ptrs); 70 | mem_t::set_pkt_sz(32); 71 | inst p3_1[] = {inst(LDXW, 2, 1, 4), // r2: PTR_TO_PACKET_END 72 | inst(LDXW, 7, 1, 0), // r7: pkt_s 73 | inst(MOV64XY, 3, 7), // r3 = r7 + 4 74 | inst(ADD64XC, 3, 4), 75 | inst(MOV64XY, 0, 7), // r0 = r7 76 | inst(JGEXC, 3, 0xff, 1), 77 | inst(ADD64XC, 0, 4), // 6: r0 += 4 78 | inst(JGTXY, 3, 2, 2), // if r3 > r2, exit; 79 | inst(STB, 0, 0, 1), // 8: 80 | inst(MOV64XC, 0, 0), 81 | inst(EXIT), 82 | }; 83 | const int len_p3_1 = sizeof(p3_1) / sizeof(inst); 84 | false_insns_exp = {8}; 85 | test_satisfy_constraints(p3_1, len_p3_1, false_insns_exp, "1"); 86 | } 87 | 88 | int main() { 89 | try { 90 | test1(); 91 | } catch (const string err_msg) { 92 | cout << "NOT SUCCESS: " << err_msg << endl; 93 | } 94 | 95 | return 0; 96 | } 97 | -------------------------------------------------------------------------------- /src/isa/inst.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "inst.h" 4 | 5 | using namespace std; 6 | 7 | int inst_base::max_prog_len = 7; // the default value is set as 7 8 | 9 | void inst_base::to_abs_bv(vector& abs_vec) const { 10 | const int num_args = _args.size(); 11 | abs_vec.push_back(_opcode); 12 | for (int i = 0; i < num_args; i++) { 13 | abs_vec.push_back(_args[i]); 14 | } 15 | } 16 | 17 | int 
inst_base::get_operand(int op_index) const { 18 | assert(op_index < _args.size()); 19 | return _args[op_index]; 20 | } 21 | 22 | void inst_base::set_operand(int op_index, op_t op_value) { 23 | assert(op_index < _args.size()); 24 | _args[op_index] = op_value; 25 | } 26 | 27 | int inst_base::get_opcode() const { 28 | return _opcode; 29 | } 30 | 31 | int inst_base::get_opcode_by_idx(int idx) const { 32 | return idx; 33 | } 34 | 35 | void inst_base::set_opcode(int op_value) { 36 | _opcode = op_value; 37 | } 38 | 39 | size_t instHash::operator()(const inst_base &x) const { 40 | size_t res = hash()(x._opcode); 41 | for (int i = 0; i < x._args.size(); i++) { 42 | res ^= hash()(x._args[i]) << (i + 1); 43 | } 44 | return res; 45 | } 46 | -------------------------------------------------------------------------------- /src/isa/inst.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "../../src/utils.h" 6 | #include "../../src/isa/inst_var.h" 7 | 8 | #if ISA_TOY_ISA 9 | #include "../../src/isa/toy-isa/inst_var.h" 10 | #elif ISA_EBPF 11 | #include "../../src/isa/ebpf/inst_var.h" 12 | #endif 13 | 14 | using namespace std; 15 | 16 | enum ISA_TYPES { 17 | TOY_ISA = 0, 18 | EBPF, 19 | }; 20 | 21 | // Opcode types for instructions 22 | enum OPCODE_TYPES { 23 | OP_NOP = 0, 24 | OP_RET, 25 | OP_UNCOND_JMP, 26 | OP_COND_JMP, 27 | OP_OTHERS, 28 | OP_ST, 29 | OP_LD, 30 | OP_CALL, 31 | }; 32 | 33 | // Return opcode types for the end instruction of a program 34 | #define RET_C 0 // return immediate number 35 | #define RET_X 1 // return register 36 | 37 | class inst_base { 38 | public: 39 | static int max_prog_len; 40 | int _opcode; 41 | vector _args; 42 | inst_base() {} 43 | void to_abs_bv(vector& abs_vec) const; 44 | int get_operand(int op_index) const; 45 | void set_operand(int op_index, op_t op_value); 46 | int get_opcode() const; 47 | int get_opcode_by_idx(int idx) const; 48 | void 
set_opcode(int op_value); 49 | double get_runtime() const {return 1;} 50 | bool sample_unmodifiable() const {return false;} 51 | z3::expr smt_inst_safety_chk(smt_var& sv) const {return Z3_true;} 52 | void insert_opcodes_not_gen(unordered_set& opcode_set) const {} 53 | int num_insns() const {return 1;} 54 | 55 | /* Functions class inst should support */ 56 | // inst& operator=(const inst &rhs) 57 | bool operator==(const inst_base &x) const {RAISE_EXCEPTION("inst::operator==");} 58 | void print() const {RAISE_EXCEPTION("inst::print");} 59 | // get_canonical_reg_list returns the list of regs which can be modified by prog canonicalize 60 | vector get_canonical_reg_list() const {RAISE_EXCEPTION("inst::get_canonical_reg_list");} 61 | static vector get_isa_canonical_reg_list() {RAISE_EXCEPTION("inst::get_isa_canonical_reg_list");} 62 | string opcode_to_str(int) const {RAISE_EXCEPTION("inst::opcode_to_str");} 63 | op_t get_max_operand_val(int op_index, int inst_index = 0) const {RAISE_EXCEPTION("inst::get_max_operand_val");} 64 | op_t get_min_operand_val(int op_index, int inst_index = 0) const {RAISE_EXCEPTION("inst::get_min_operand_val");} 65 | int get_jmp_dis() const {RAISE_EXCEPTION("inst::get_jmp_dis");} 66 | // insert all jmp opcode in jmp_set, used by proposals.cc to 67 | // avoid jumps in the last line of the program 68 | void insert_jmp_opcodes(unordered_set& jmp_set) const {RAISE_EXCEPTION("inst::insert_jmp_opcodes");} 69 | int inst_output_opcode_type() const {RAISE_EXCEPTION("inst::inst_output_opcode_type");} 70 | int inst_output() const {RAISE_EXCEPTION("inst::inst_output");} 71 | bool is_reg(int op_index) const {RAISE_EXCEPTION("inst::is_reg");} 72 | // If ISA allows an implicit register, return the register, else return -1 73 | int implicit_ret_reg() const {RAISE_EXCEPTION("inst::implicit_ret_reg");} 74 | void set_as_nop_inst() {RAISE_EXCEPTION("inst::set_as_nop_inst");} 75 | unsigned int get_input_reg() const {RAISE_EXCEPTION("inst::get_input_reg");} 76 | int 
get_num_operands() const {RAISE_EXCEPTION("inst::get_num_operands");} 77 | int get_insn_num_regs() const {RAISE_EXCEPTION("inst::get_insn_num_regs");} 78 | int get_opcode_type() const {RAISE_EXCEPTION("inst::get_opcode_type");} 79 | // smt 80 | // return SMT for the given OP_OTHERS type instruction, other types return false 81 | z3::expr smt_inst(smt_var& sv, unsigned int block = 0) const {RAISE_EXCEPTION("inst::smt_inst");} 82 | // return SMT for the given OP_COND_JMP type instruction, other types return false 83 | z3::expr smt_inst_jmp(smt_var& sv) const {RAISE_EXCEPTION("inst::smt_inst_jmp");} 84 | static z3::expr smt_set_pre(z3::expr input, smt_var& sv) {RAISE_EXCEPTION("inst::smt_set_pre");} 85 | bool is_cfg_basic_block_end() const {RAISE_EXCEPTION("inst::is_cfg_basic_block_end");} 86 | }; 87 | 88 | struct instHash { 89 | size_t operator()(const inst_base &x) const; 90 | }; 91 | 92 | /* inst.cc should support */ 93 | // int num_real_instructions(const inst* program, int length); 94 | // void interpret(inout_t& output, inst* program, int length, prog_state &ps, inout_t& input); 95 | 96 | -------------------------------------------------------------------------------- /src/isa/inst_header.h: -------------------------------------------------------------------------------- 1 | #if ISA_TOY_ISA 2 | #include "../../src/isa/toy-isa/inst_var.h" 3 | #include "../../src/isa/toy-isa/inst.h" 4 | #elif ISA_EBPF 5 | #include "../../src/isa/ebpf/inst_var.h" 6 | #include "../../src/isa/ebpf/inst.h" 7 | #include "../../src/isa/ebpf/canonicalize.h" 8 | #endif 9 | -------------------------------------------------------------------------------- /src/isa/inst_header_basic.h: -------------------------------------------------------------------------------- 1 | #if ISA_TOY_ISA 2 | #include "../../src/isa/toy-isa/inst_var.h" 3 | #include "../../src/isa/toy-isa/inst.h" 4 | #elif ISA_EBPF 5 | #include "../../src/isa/ebpf/inst_var.h" 6 | #include "../../src/isa/ebpf/inst.h" 7 | #endif 8
| -------------------------------------------------------------------------------- /src/isa/inst_var.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "inst_var.h" 3 | 4 | using namespace std; 5 | 6 | z3::context smt_c; 7 | 8 | z3::expr string_to_expr(string s) { 9 | if (s == "true") { 10 | return smt_c.bool_val(true); 11 | } else if (s == "false") { 12 | return smt_c.bool_val(false); 13 | } 14 | return smt_c.bv_const(s.c_str(), NUM_REG_BITS); 15 | } 16 | 17 | z3::expr to_bool_expr(string s) { 18 | return smt_c.bool_const(s.c_str()); 19 | } 20 | 21 | z3::expr to_expr(int64_t x, unsigned sz) { 22 | return smt_c.bv_val(x, sz); 23 | } 24 | 25 | z3::expr to_expr(uint64_t x, unsigned sz) { 26 | return smt_c.bv_val(x, sz); 27 | } 28 | 29 | z3::expr to_expr(int32_t x, unsigned sz) { 30 | return smt_c.bv_val(x, sz); 31 | } 32 | 33 | z3::expr to_expr(string s, unsigned sz) { 34 | return smt_c.bv_const(s.c_str(), sz); 35 | } 36 | 37 | // use dfs 38 | bool dag::is_path_a2b(unsigned int a, unsigned int b) { 39 | if (a == b) return true; 40 | for (int i = 0; i < out_edges_list[a].size(); i++) { 41 | if (is_path_a2b(out_edges_list[a][i], b)) return true; 42 | } 43 | return false; 44 | } 45 | 46 | // check whether there is a way from a to b that does not go through c 47 | // use dfs 48 | bool dag::is_path_a2b_without_c(unsigned int a, unsigned int b, unsigned int c) { 49 | if (a == c) return false; // does not go through c 50 | if (a == b) return true; 51 | for (int i = 0; i < out_edges_list[a].size(); i++) { 52 | unsigned int node = out_edges_list[a][i]; 53 | if (is_path_a2b_without_c(node, b, c)) return true; 54 | } 55 | return false; 56 | } 57 | 58 | // 1. check whether there is a way from root to a that goes through b 59 | // since there is a way from root to b, just need to check whether 60 | // there is a way from b to a 61 | // 2. 
check whether there is a way from root to a that does not go through b 62 | int dag::is_b_on_root2a_path(unsigned int a, unsigned int b) { 63 | assert(a < out_edges_list.size()); 64 | assert(b < out_edges_list.size()); 65 | if (! is_path_a2b(b, a)) return INT_false; 66 | if (is_path_a2b_without_c(root, a, b)) return INT_uncertain; 67 | return INT_true; 68 | } 69 | 70 | ostream& operator<<(ostream& out, const dag& d) { 71 | out << "dag: " << endl 72 | << "root: " << d.root << endl; 73 | out << "edge: " << endl; 74 | for (int i = 0; i < d.out_edges_list.size(); i++) { 75 | out << " " << i << ", out: "; 76 | for (int j = 0; j < d.out_edges_list[i].size(); j++) { 77 | out << d.out_edges_list[i][j] << " "; 78 | } 79 | out << endl; 80 | } 81 | return out; 82 | } 83 | 84 | smt_var_base::smt_var_base() { 85 | path_cond_id = 0; 86 | } 87 | 88 | smt_var_base::smt_var_base(unsigned int prog_id, unsigned int node_id, unsigned int num_regs) { 89 | path_cond_id = 0; 90 | init(prog_id, node_id, num_regs); 91 | } 92 | 93 | smt_var_base::~smt_var_base() { 94 | } 95 | 96 | // reset register related variables 97 | // designed for different basic blocks of one program, 98 | void smt_var_base::set_new_node_id(unsigned int node_id, const vector& nodes_in, 99 | const vector& node_in_pc_list, 100 | const vector>& nodes_in_regs) { 101 | for (int i = 0; i < reg_cur_id.size(); i++) { 102 | reg_cur_id[i] = 0; 103 | } 104 | size_t pos = _name.find('_'); 105 | assert(pos != string::npos); // string::npos: flag of not found 106 | string prog_id_str = _name.substr(0, pos); 107 | _name = prog_id_str + "_" + to_string(node_id); 108 | string name_prefix = "r_" + _name + "_"; 109 | for (size_t i = 0; i < reg_var.size(); i++) { 110 | string name = name_prefix + to_string(i) + "_0"; 111 | reg_var[i] = string_to_expr(name); 112 | } 113 | } 114 | 115 | z3::expr smt_var_base::update_path_cond() { 116 | path_cond_id++; 117 | string name = "pc_" + _name + "_" + to_string(path_cond_id); 118 | return 
to_bool_expr(name); 119 | } 120 | 121 | z3::expr smt_var_base::update_reg_var(unsigned int reg_id) { 122 | reg_cur_id[reg_id]++; 123 | string name = "r_" + _name + "_" + to_string(reg_id) \ 124 | + "_" + to_string(reg_cur_id[reg_id]); 125 | reg_var[reg_id] = string_to_expr(name); 126 | return get_cur_reg_var(reg_id); 127 | } 128 | 129 | z3::expr smt_var_base::get_cur_reg_var(unsigned int reg_id) { 130 | return reg_var[reg_id]; 131 | } 132 | 133 | z3::expr smt_var_base::get_init_reg_var(unsigned int reg_id) { 134 | string name = "r_" + _name + "_" + to_string(reg_id) + "_0"; 135 | return string_to_expr(name); 136 | } 137 | 138 | void smt_var_base::init(unsigned int prog_id, unsigned int node_id, unsigned int num_regs, unsigned int n_blocks) { 139 | reg_cur_id.resize(num_regs, 0); 140 | _name = to_string(prog_id) + "_" + to_string(node_id); 141 | string name_prefix = "r_" + _name + "_"; 142 | for (size_t i = 0; i < num_regs; i++) { 143 | string name = name_prefix + to_string(i) + "_0"; 144 | reg_var.push_back(string_to_expr(name)); 145 | } 146 | } 147 | 148 | void smt_var_base::clear() { 149 | for (size_t i = 0; i < reg_var.size(); i++) { 150 | reg_cur_id[i] = 0; 151 | string name = "r_" + _name + "_" + to_string(i) + "_0"; 152 | reg_var[i] = string_to_expr(name); 153 | } 154 | } 155 | 156 | void prog_state_base::print() const { 157 | for (int i = 0; i < _regs.size(); i++) { 158 | cout << "Register " << i << " " << _regs[i] << endl; 159 | } 160 | }; 161 | 162 | void prog_state_base::clear() { 163 | _pc = 0; 164 | for (int i = 0; i < _regs.size(); i++) { 165 | _regs[i] = 0; 166 | } 167 | }; 168 | -------------------------------------------------------------------------------- /src/isa/inst_var.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "z3++.h" 6 | #include "../../src/utils.h" 7 | 8 | using namespace std; 9 | 10 | // For most applications this is sufficient. 
// Directed acyclic graph stored as per-node lists of outgoing edges.
// Node ids are the indices into out_edges_list.
class dag {
 private:
  // true iff there is a path from a to b
  bool is_path_a2b(unsigned int a, unsigned int b);
  // true iff there is a path from a to b that does not go through c
  bool is_path_a2b_without_c(unsigned int a, unsigned int b, unsigned int c);
 public:
  unsigned int root;
  // outgoing edges, list index: node id
  std::vector<std::vector<unsigned int>> out_edges_list;

  dag(unsigned int n_nodes = 1, unsigned int root_node = 0) {init(n_nodes, root_node);}

  // Reset the graph to `n_nodes` nodes with no edges and set the root.
  // assign() (rather than resize()) is used so that re-initialising an
  // existing graph discards the edges of the previous graph instead of
  // silently keeping them.
  void init(unsigned int n_nodes, unsigned int root_node) {
    out_edges_list.assign(n_nodes, std::vector<unsigned int>());
    root = root_node;
  }

  // Add the edge a -> b. Precondition: a < out_edges_list.size().
  void add_edge_a2b(unsigned int a, unsigned int b) {
    out_edges_list[a].push_back(b);
  }

  // Returns INT_false if there is no path from b to a, INT_uncertain if a can
  // also be reached from root without passing b, and INT_true otherwise.
  int is_b_on_root2a_path(unsigned int a, unsigned int b);

  friend std::ostream& operator<<(std::ostream& out, const dag& d);
};
Convert prog_id and node_id into _name, that is string([prog_id]_[node_id]) 64 | // 2. Initialize reg_val[i] = r_[_name]_0, i = 0, ..., num_regs 65 | smt_var_base(unsigned int prog_id, unsigned int node_id, unsigned int num_regs); 66 | ~smt_var_base(); 67 | void set_new_node_id(unsigned int node_id, const vector& nodes_in, 68 | const vector& node_in_pc_list, 69 | const vector>& nodes_in_regs); 70 | z3::expr update_path_cond(); 71 | // inital value for [versionId] is 0, and increases when updated 72 | z3::expr update_reg_var(unsigned int reg_id); 73 | z3::expr get_cur_reg_var(unsigned int reg_id); 74 | z3::expr get_init_reg_var(unsigned int reg_id); 75 | void init() {} 76 | void init(unsigned int prog_id, unsigned int node_id, unsigned int num_regs, unsigned int n_blocks = 1); 77 | void clear(); 78 | }; 79 | 80 | class prog_state_base { 81 | int _pc = 0; /* Assume only straight line code execution for now */ 82 | public: 83 | vector _regs; /* assume only registers for now */ 84 | void init() {} 85 | void print() const; 86 | void clear(); 87 | }; 88 | 89 | class inout_t_base { 90 | public: 91 | void clear() {RAISE_EXCEPTION("inout_t::clear()");} 92 | void init() {RAISE_EXCEPTION("inout_t::init()");} 93 | bool operator==(const inout_t_base &rhs) const {RAISE_EXCEPTION("inout_t::operator==");} 94 | friend ostream& operator<<(ostream& out, const inout_t_base& x) {RAISE_EXCEPTION("inout_t::operator<<");} 95 | }; 96 | 97 | // exposed APIs 98 | // void get_cmp_lists(vector& val_list1, vector& val_list2, 99 | // inout_t& output1, inout_t& output2); 100 | /* Generate the random inputs and store them in the input paramenter `inputs`. 101 | Parameters `reg_min` and `reg_max` are the minimum and maximum values of the input register. 102 | This limitation needs to be generalized later. 
103 | */ 104 | // void gen_random_input(vector& inputs, reg_t reg_min, reg_t reg_max); 105 | -------------------------------------------------------------------------------- /src/isa/inst_var_test.cc: -------------------------------------------------------------------------------- 1 | #include "inst_var.h" 2 | 3 | void test1() { 4 | cout << "Test 1: check class dag functions" << endl; 5 | cout << "check is_b_on_root2a_path()" << endl; 6 | unsigned int n_nodes = 6, root = 0; 7 | dag g1(n_nodes, root); 8 | vector> out_edges = {{1, 2}, {3}, {3, 5}, {4}, {}, {}}; 9 | for (int i = 0; i < out_edges.size(); i++) { 10 | for (int j = 0; j < out_edges[i].size(); j++) { 11 | g1.add_edge_a2b(i, out_edges[i][j]); 12 | } 13 | } 14 | int test_count = 0; 15 | // check root must be on the path of root to each node 16 | for (int i = 0; i < n_nodes; i++) { 17 | print_test_res(g1.is_b_on_root2a_path(i, root) == INT_true, to_string(++test_count)); 18 | } 19 | for (int i = 1; i < n_nodes; i++) { 20 | print_test_res(g1.is_b_on_root2a_path(root, i) == INT_false, to_string(++test_count)); 21 | } 22 | print_test_res(g1.is_b_on_root2a_path(3, 1) == INT_uncertain, to_string(++test_count)); 23 | print_test_res(g1.is_b_on_root2a_path(3, 2) == INT_uncertain, to_string(++test_count)); 24 | print_test_res(g1.is_b_on_root2a_path(3, 5) == INT_false, to_string(++test_count)); 25 | print_test_res(g1.is_b_on_root2a_path(5, 2) == INT_true, to_string(++test_count)); 26 | } 27 | 28 | int main() { 29 | test1(); 30 | return 0; 31 | } 32 | -------------------------------------------------------------------------------- /src/isa/prog.cc: -------------------------------------------------------------------------------- 1 | #include "prog.h" 2 | 3 | using namespace std; 4 | 5 | // TODO: find canonical way to invoke one constructor from another 6 | prog::prog(const prog& other) { 7 | inst_list = new inst[inst::max_prog_len]; 8 | for (int i = 0; i < inst::max_prog_len; i++) { 9 | inst_list[i] = other.inst_list[i]; 
10 | } 11 | freq_count = other.freq_count; 12 | _error_cost = other._error_cost; 13 | _perf_cost = other._perf_cost; 14 | } 15 | 16 | prog::prog(inst* instructions) { 17 | inst_list = new inst[inst::max_prog_len]; 18 | for (int i = 0; i < inst::max_prog_len; i++) { 19 | inst_list[i] = instructions[i]; 20 | } 21 | freq_count = 0; 22 | _error_cost = -1; 23 | _perf_cost = -1; 24 | } 25 | 26 | void prog::reset_vals() { 27 | freq_count = 0; 28 | _error_cost = -1; 29 | _perf_cost = -1; 30 | } 31 | 32 | prog::prog() { 33 | freq_count = 0; 34 | _error_cost = -1; 35 | _perf_cost = -1; 36 | } 37 | 38 | prog::~prog() { 39 | delete []inst_list; 40 | } 41 | 42 | void prog::print() const { 43 | for (int i = 0; i < inst::max_prog_len; i++) { 44 | cout << i << ": "; 45 | inst_list[i].print(); 46 | } 47 | cout << endl; 48 | } 49 | 50 | bool prog::operator==(const prog &x) const { 51 | for (int i = 0; i < inst::max_prog_len; i++) { 52 | if (! (inst_list[i] == x.inst_list[i])) return false; 53 | } 54 | return true; 55 | } 56 | 57 | void prog::set_vals(const prog &x) { 58 | freq_count = x.freq_count; 59 | _error_cost = x._error_cost; 60 | _perf_cost = x._perf_cost; 61 | } 62 | 63 | void prog::set_error_cost(double cost) { 64 | _error_cost = cost; 65 | } 66 | 67 | void prog::set_perf_cost(double cost) { 68 | _perf_cost = cost; 69 | } 70 | 71 | int prog::to_rel_bv(const prog &p) const { 72 | int bv = 0; 73 | for (int i = 0; i < inst::max_prog_len; i++) { 74 | if (inst_list[i] == p.inst_list[i]) { 75 | bv |= 1 << (inst::max_prog_len - 1 - i); 76 | } 77 | } 78 | return bv; 79 | } 80 | 81 | int prog::to_rel_bv(const vector &ps) const { 82 | int best = 0; 83 | int count = 0; 84 | for (int i = 0; i < ps.size(); i++) { 85 | int bv = to_rel_bv(ps[i]); 86 | int bv_count = pop_count_asm(bv); 87 | if (bv_count > count) { 88 | count = bv_count; 89 | best = bv; 90 | } 91 | } 92 | return best; 93 | } 94 | 95 | void prog::to_abs_bv(vector& bv) const { 96 | for (int i = 0; i < inst::max_prog_len; i++) 
{ 97 | inst_list[i].to_abs_bv(bv); 98 | } 99 | } 100 | 101 | bool prog::if_ret_exists(int start, int end) const { 102 | for (int i = start; i < end; i++) { 103 | if (inst_list[i].get_opcode_type() == OP_RET) { 104 | return true; 105 | } 106 | } 107 | return false; 108 | } 109 | 110 | // if reg 0 is NOT used but implicit RETX 0 instruction is needed, reg 0 cannot be used 111 | // case 1: no RETs instruction(test6 insts41) 112 | // case 2: has RETs instruction, but JMP makes implicit RETX 0 instruction needed(test6 insts42) 113 | void prog::update_map_if_implicit_ret_r0_needed(unordered_map &map_before_after) const { 114 | bool can_use_reg0 = true; 115 | // check whether there is RETs 116 | bool ret_exists = if_ret_exists(0, inst::max_prog_len); 117 | // step 1: check whether reg0 can be used 118 | if (! ret_exists) { 119 | // no RETs instruction 120 | can_use_reg0 = false; 121 | } else { 122 | // has RETs instruction, check jmp distance 123 | int start_index_chk_ret = 0; 124 | for (int i = 0; i < inst::max_prog_len; i++) { 125 | if ((inst_list[i].get_opcode_type() == OP_COND_JMP) && 126 | ((i + 1 + inst_list[i].get_jmp_dis()) > start_index_chk_ret)) { 127 | start_index_chk_ret = i + 1 + inst_list[i].get_jmp_dis(); 128 | } 129 | } 130 | ret_exists = if_ret_exists(start_index_chk_ret, inst::max_prog_len); 131 | if (! ret_exists) { 132 | can_use_reg0 = false; 133 | } 134 | } 135 | 136 | // step 2: if reg0 cannot be used, update the map_before_after 137 | if (! 
can_use_reg0) { 138 | for (auto it = map_before_after.begin(); it != map_before_after.end(); it++) { 139 | it->second++; 140 | } 141 | } 142 | } 143 | 144 | void prog::canonicalize() { 145 | unordered_map map_before_after; // key: reg_id before, val: reg_id after 146 | vector reg_list; 147 | // traverse all instructions and once there is a new reg_id(before), assign it a reg_id(after) 148 | // store reg_id(before) and reg_id(after) into map 149 | vector available_reg_list = inst::get_isa_canonical_reg_list(); 150 | int count = 0; 151 | for (int i = 0; i < inst::max_prog_len; i++) { 152 | reg_list = inst_list[i].get_canonical_reg_list(); 153 | for (size_t j = 0; j < reg_list.size(); j++) { 154 | int cur_reg = reg_list[j]; 155 | if (map_before_after.find(cur_reg) == map_before_after.end()) { 156 | assert(count < available_reg_list.size()); 157 | map_before_after[cur_reg] = available_reg_list[count]; 158 | count++; 159 | } 160 | } 161 | } 162 | if (map_before_after.size() == 0) return; 163 | 164 | // replace reg_ids(before) with reg_ids(after) for all instructions 165 | for (int i = 0; i < inst::max_prog_len; i++) { 166 | for (int j = 0; j < MAX_OP_LEN; j++) { 167 | if (inst_list[i].is_reg(j)) { 168 | auto it = map_before_after.find(inst_list[i].get_operand(j)); 169 | if (it != map_before_after.end()) 170 | inst_list[i].set_operand(j, it->second); 171 | } 172 | } 173 | } 174 | } 175 | 176 | int prog::num_real_instructions() const { 177 | return ::num_real_instructions(inst_list, inst::max_prog_len); 178 | } 179 | 180 | double prog::instructions_runtime() const { 181 | double runtime = 0; 182 | for (int i = 0; i < inst::max_prog_len; i++) { 183 | runtime += inst_list[i].get_runtime(); 184 | } 185 | return runtime; 186 | } 187 | 188 | double prog::instructions_runtime(int insn_s, int insn_e) const { 189 | double runtime = 0; 190 | for (int i = insn_s; i <= insn_e; i++) { 191 | runtime += inst_list[i].get_runtime(); 192 | } 193 | return runtime; 194 | } 195 | 196 | void 
prog::interpret(inout_t& output, prog_state &ps, const inout_t& input) const { 197 | return ::interpret(output, inst_list, inst::max_prog_len, ps, input); 198 | } 199 | 200 | size_t progHash::operator()(const prog &x) const { 201 | size_t hval = 0; 202 | for (int i = 0; i < inst::max_prog_len; i++) { 203 | hval = hval ^ (instHash()(x.inst_list[i]) << (i % 4)); 204 | } 205 | return hval; 206 | } 207 | 208 | top_k_progs::top_k_progs(unsigned int k_val) { 209 | assert(k_val > 0); 210 | k = k_val; 211 | cout << "[top_k_progs] set k = " << k << endl; 212 | max_perf_cost = numeric_limits::max(); 213 | max_perf_cost_id = -1; 214 | } 215 | 216 | top_k_progs::~top_k_progs() { 217 | progs.clear(); 218 | max_perf_cost = numeric_limits::max(); 219 | max_perf_cost_id = -1; 220 | } 221 | 222 | // check whether program p is in the progs 223 | bool top_k_progs::can_find(prog* p) { 224 | for (int i = 0; i < progs.size(); i++) { 225 | if (progs[i]->_perf_cost != p->_perf_cost) continue; 226 | // further check whether two programs are the same 227 | if (*(progs[i]) == *p) { 228 | return true; 229 | } 230 | } 231 | return false; 232 | } 233 | 234 | void top_k_progs::insert_without_check(prog* p) { 235 | if (logger.is_print_level(LOGGER_DEBUG)) { 236 | cout << "insert a new program in top_k_progs" << endl; 237 | } 238 | prog* p_copy = new prog(*p); 239 | if (progs.size() < k) progs.push_back(p_copy); 240 | else { 241 | delete progs[max_perf_cost_id]; 242 | progs[max_perf_cost_id] = p_copy; 243 | } 244 | 245 | max_perf_cost = progs[0]->_perf_cost; 246 | max_perf_cost_id = 0; 247 | for (int i = 1; i < progs.size(); i++) { 248 | if (progs[i]->_perf_cost > max_perf_cost) { 249 | max_perf_cost = progs[i]->_perf_cost; 250 | max_perf_cost_id = i; 251 | } 252 | } 253 | } 254 | 255 | void top_k_progs::insert(prog* p) { 256 | if (p->_error_cost != 0) return; 257 | if (progs.size() < k) { // check whether this program is in the progs 258 | if (can_find(p)) return; 259 | insert_without_check(p); 
260 | } else { 261 | assert(progs.size() != 0); 262 | if (p->_perf_cost >= max_perf_cost) return; 263 | if (can_find(p)) return; 264 | insert_without_check(p); 265 | } 266 | } 267 | 268 | void top_k_progs::clear() { 269 | for (int i = 0; i < progs.size(); i++) { 270 | delete progs[i]; // release each prog's space 271 | } 272 | progs.clear(); 273 | } 274 | 275 | bool top_k_progs_sort_function(prog* x, prog* y) { 276 | return (x->_perf_cost < y->_perf_cost); 277 | } 278 | 279 | void top_k_progs::sort() { 280 | if (progs.size() == 0) return; 281 | std::sort(progs.begin(), progs.end(), top_k_progs_sort_function); 282 | max_perf_cost = progs[0]->_perf_cost; 283 | max_perf_cost_id = 0; 284 | for (int i = 1; i < progs.size(); i++) { 285 | if (progs[i]->_perf_cost > max_perf_cost) { 286 | max_perf_cost = progs[i]->_perf_cost; 287 | max_perf_cost_id = i; 288 | } 289 | } 290 | } 291 | -------------------------------------------------------------------------------- /src/isa/prog.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../src/utils.h" 8 | #include "../../src/isa/inst_header.h" 9 | 10 | using namespace std; 11 | 12 | class prog { 13 | public: 14 | inst* inst_list; 15 | int freq_count; 16 | double _error_cost; 17 | double _perf_cost; 18 | prog(const prog& other); 19 | prog(inst* instructions); 20 | prog(); 21 | void print() const; 22 | ~prog(); 23 | bool operator==(const prog &x) const; 24 | void reset_vals(); 25 | void set_vals(const prog &x); 26 | void set_error_cost(double cost); 27 | void set_perf_cost(double cost); 28 | int to_rel_bv(const prog &p) const; 29 | int to_rel_bv(const vector &ps) const; 30 | void to_abs_bv(vector& bv) const; 31 | bool if_ret_exists(int start, int end) const; 32 | void update_map_if_implicit_ret_r0_needed(unordered_map &map_before_after) const; 33 | void canonicalize(); 34 | int num_real_instructions() const; 35 | 
double instructions_runtime() const; 36 | double instructions_runtime(int insn_s, int insn_e) const; 37 | void interpret(inout_t& output, prog_state &ps, const inout_t& input) const; 38 | }; 39 | 40 | struct progHash { 41 | size_t operator()(const prog &x) const; 42 | }; 43 | 44 | // top_k_progs: performance cost top k different programs with zero error cost 45 | // make sure k >= 1 46 | // assume k is a small number 47 | class top_k_progs { 48 | private: 49 | double max_perf_cost; 50 | int max_perf_cost_id; 51 | unsigned int k; 52 | bool can_find(prog* p); 53 | void insert_without_check(prog* p); 54 | public: 55 | // `greater` makes progs in descending order of keys 56 | // key: perf cost, value.first: prog hash value, value.second: prog pointer 57 | vector progs; 58 | top_k_progs(unsigned int k_val); 59 | ~top_k_progs(); 60 | void insert(prog* p); // insert p if p is one of top k 61 | void clear(); 62 | void sort(); 63 | }; 64 | -------------------------------------------------------------------------------- /src/isa/prog_test_ebpf.cc: -------------------------------------------------------------------------------- 1 | #include "../../src/utils.h" 2 | #include "prog.h" 3 | 4 | void test1() { 5 | // test r10 (value of frame pointer), r1 won't be modified in canonicalize() 6 | inst insts1[7] = {inst(STXW, 10, -4, 1), 7 | inst(LDXW, 0, 10, -4), 8 | inst(EXIT), 9 | inst(), 10 | inst(), 11 | inst(), 12 | inst(), 13 | }; 14 | prog p11(insts1); 15 | prog p12(insts1); 16 | p11.canonicalize(); 17 | print_test_res(p11 == p12, "canonicalize 1"); 18 | return; 19 | 20 | // test input r1 won't be modified 21 | inst insts21[7] = {inst(MOV64XC, 3, 0x1), 22 | inst(MOV64XC, 2, 0x2), 23 | inst(MOV64XY, 0, 1), 24 | inst(EXIT), 25 | inst(), 26 | inst(), 27 | inst(), 28 | }; 29 | inst insts22[7] = {inst(MOV64XC, 0, 0x1), 30 | inst(MOV64XC, 2, 0x2), 31 | inst(MOV64XY, 3, 1), 32 | inst(EXIT), 33 | inst(), 34 | inst(), 35 | inst(), 36 | }; 37 | prog p21(insts21); 38 | prog p22(insts22); 
39 | p21.canonicalize(); 40 | print_test_res(p21 == p22, "canonicalize 2"); 41 | 42 | // test when r1 is not used, r1 can be used in register renaming 43 | inst insts31[7] = {inst(MOV64XC, 3, 0x1), 44 | inst(MOV64XC, 2, 0x2), 45 | inst(MOV64XY, 0, 2), 46 | inst(EXIT), 47 | inst(), 48 | inst(), 49 | inst(), 50 | }; 51 | inst insts32[7] = {inst(MOV64XC, 0, 0x1), 52 | inst(MOV64XC, 1, 0x2), 53 | inst(MOV64XY, 2, 1), 54 | inst(EXIT), 55 | inst(), 56 | inst(), 57 | inst(), 58 | }; 59 | prog p31(insts31); 60 | prog p32(insts32); 61 | p31.canonicalize(); 62 | print_test_res(p31 == p32, "canonicalize 3"); 63 | } 64 | 65 | bool check_top_k_progs_res(vector perf_costs, top_k_progs& topk_progs) { 66 | if (perf_costs.size() != topk_progs.progs.size()) return false; 67 | topk_progs.sort(); 68 | for (int i = 0; i < topk_progs.progs.size(); i++) { 69 | if (topk_progs.progs[i]->_error_cost != 0) return false; 70 | if (topk_progs.progs[i]->_perf_cost != perf_costs[i]) return false; 71 | } 72 | return true; 73 | } 74 | 75 | void test2() { 76 | cout << "test2: test top_k_progs" << endl; 77 | top_k_progs topk_progs1(1), topk_progs2(3); 78 | inst p1[inst::max_prog_len]; 79 | prog* pgm = new prog(p1); 80 | pgm->_error_cost = 0; 81 | pgm->_perf_cost = 1.25; 82 | 83 | topk_progs1.insert(pgm); 84 | topk_progs2.insert(pgm); 85 | vector perf_costs_1, perf_costs_2; 86 | perf_costs_1 = {1.25}; 87 | perf_costs_2 = {1.25}; 88 | print_test_res(check_top_k_progs_res(perf_costs_1, topk_progs1), "1.1"); 89 | print_test_res(check_top_k_progs_res(perf_costs_2, topk_progs2), "1.2"); 90 | 91 | topk_progs1.insert(pgm); 92 | topk_progs2.insert(pgm); 93 | print_test_res(check_top_k_progs_res(perf_costs_1, topk_progs1), "2.1"); 94 | print_test_res(check_top_k_progs_res(perf_costs_2, topk_progs2), "2.2"); 95 | 96 | pgm->inst_list[0] = inst(MOV64XC, 0, 0); 97 | pgm->_error_cost = 0; 98 | pgm->_perf_cost = 1.25; 99 | topk_progs1.insert(pgm); 100 | topk_progs2.insert(pgm); 101 | perf_costs_1 = {1.25}; 102 | 
perf_costs_2 = {1.25, 1.25}; 103 | print_test_res(check_top_k_progs_res(perf_costs_1, topk_progs1), "3.1"); 104 | print_test_res(check_top_k_progs_res(perf_costs_2, topk_progs2), "3.2"); 105 | 106 | pgm->inst_list[1] = inst(MOV64XC, 0, 1); 107 | pgm->_error_cost = 1; 108 | pgm->_perf_cost = 1.25; 109 | topk_progs1.insert(pgm); 110 | topk_progs2.insert(pgm); 111 | perf_costs_1 = {1.25}; 112 | perf_costs_2 = {1.25, 1.25}; 113 | print_test_res(check_top_k_progs_res(perf_costs_1, topk_progs1), "4.1"); 114 | print_test_res(check_top_k_progs_res(perf_costs_2, topk_progs2), "4.2"); 115 | 116 | pgm->inst_list[2] = inst(MOV64XC, 1, 0); 117 | pgm->_error_cost = 0; 118 | pgm->_perf_cost = 1.37; 119 | topk_progs1.insert(pgm); 120 | topk_progs2.insert(pgm); 121 | perf_costs_1 = {1.25}; 122 | perf_costs_2 = {1.25, 1.25, 1.37}; 123 | print_test_res(check_top_k_progs_res(perf_costs_1, topk_progs1), "5.1"); 124 | print_test_res(check_top_k_progs_res(perf_costs_2, topk_progs2), "5.2"); 125 | 126 | pgm->inst_list[3] = inst(MOV64XC, 1, 0); 127 | pgm->_error_cost = 0; 128 | pgm->_perf_cost = 0.81; 129 | topk_progs1.insert(pgm); 130 | topk_progs2.insert(pgm); 131 | perf_costs_1 = {0.81}; 132 | perf_costs_2 = {0.81, 1.25, 1.25}; 133 | print_test_res(check_top_k_progs_res(perf_costs_1, topk_progs1), "6.1"); 134 | print_test_res(check_top_k_progs_res(perf_costs_2, topk_progs2), "6.2"); 135 | } 136 | 137 | int main() { 138 | test1(); 139 | test2(); 140 | return 0; 141 | } 142 | -------------------------------------------------------------------------------- /src/isa/toy-isa/inst.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "z3++.h" 6 | #include "../../../src/isa/inst_var.h" 7 | #include "../../../src/isa/inst.h" 8 | #include "inst_var.h" 9 | #include "inst_codegen.h" 10 | 11 | using namespace std; 12 | 13 | static constexpr int MAX_PROG_LEN = 7; 14 | // Max number of operands in one 
// Instruction opcodes of the toy ISA. The enumerator values are used as
// indices into the per-opcode tables below, so the order here and the order of
// the table rows must stay in sync.
enum OPCODES {
  NOP = 0,
  ADDXY,
  MOVXC,
  RETX,
  RETC,
  JMP,
  JMPEQ,
  JMPGT,
  JMPGE,
  JMPLT,
  JMPLE,
  MAXC,
  MAXX,
  NUM_INSTR, // Number of opcode types
};

// Number of operands used by each opcode, indexed by opcode.
// Positional initialisers are used because C99 array designators
// (`[NOP] = 0`) are not part of ISO C++; the static_assert below catches a
// missing or extra row.
static constexpr int num_operands[] = {
  0,  // NOP
  2,  // ADDXY
  2,  // MOVXC
  1,  // RETX
  1,  // RETC
  1,  // JMP
  3,  // JMPEQ
  3,  // JMPGT
  3,  // JMPGE
  3,  // JMPLT
  3,  // JMPLE
  2,  // MAXC
  2,  // MAXX
};
static_assert(sizeof(num_operands) / sizeof(num_operands[0]) == NUM_INSTR,
              "num_operands needs exactly one entry per opcode");

// Number of register operands used by each opcode, indexed by opcode.
static constexpr int insn_num_regs[] = {
  0,  // NOP
  2,  // ADDXY
  1,  // MOVXC
  1,  // RETX
  0,  // RETC
  0,  // JMP
  2,  // JMPEQ
  2,  // JMPGT
  2,  // JMPGE
  2,  // JMPLT
  2,  // JMPLE
  1,  // MAXC
  2,  // MAXX
};
static_assert(sizeof(insn_num_regs) / sizeof(insn_num_regs[0]) == NUM_INSTR,
              "insn_num_regs needs exactly one entry per opcode");
#define TRDOP(x) (x << 10) 101 | #define JMP_OPS (FSTOP(OP_REG) | SNDOP(OP_REG) | TRDOP(OP_OFF)) 102 | #define UNUSED_OPS (FSTOP(OP_UNUSED) | SNDOP(OP_UNUSED) | TRDOP(OP_UNUSED)) 103 | static constexpr int optable[NUM_INSTR] = { 104 | [NOP] = UNUSED_OPS, 105 | [ADDXY] = FSTOP(OP_REG) | SNDOP(OP_REG) | TRDOP(OP_UNUSED), 106 | [MOVXC] = FSTOP(OP_REG) | SNDOP(OP_IMM) | TRDOP(OP_UNUSED), 107 | [RETX] = FSTOP(OP_REG) | SNDOP(OP_UNUSED) | TRDOP(OP_UNUSED), 108 | [RETC] = FSTOP(OP_IMM) | SNDOP(OP_UNUSED) | TRDOP(OP_UNUSED), 109 | [JMP] = FSTOP(OP_OFF) | SNDOP(OP_UNUSED) | TRDOP(OP_UNUSED), 110 | [JMPEQ] = JMP_OPS, 111 | [JMPGT] = JMP_OPS, 112 | [JMPGE] = JMP_OPS, 113 | [JMPLT] = JMP_OPS, 114 | [JMPLE] = JMP_OPS, 115 | [MAXC] = FSTOP(OP_REG) | SNDOP(OP_IMM) | TRDOP(OP_UNUSED), 116 | [MAXX] = FSTOP(OP_REG) | SNDOP(OP_REG) | TRDOP(OP_UNUSED), 117 | }; 118 | #undef FSTOP 119 | #undef SNDOP 120 | #undef TRDOP 121 | #undef JMP_OPS 122 | #undef UNUSED_OPS 123 | 124 | class inst: public inst_base { 125 | public: 126 | inst(int opcode = NOP, int arg1 = 0, int arg2 = 0, int arg3 = 0) { 127 | _args.resize(MAX_OP_LEN); 128 | _opcode = opcode; 129 | _args[0] = arg1; 130 | _args[1] = arg2; 131 | _args[2] = arg3; 132 | } 133 | inst& operator=(const inst &rhs); 134 | bool operator==(const inst &x) const; 135 | string opcode_to_str(int) const; 136 | void print() const; 137 | int get_max_operand_val(int op_index, int inst_index = 0) const; 138 | int get_min_operand_val(int op_index, int inst_index = 0) const; 139 | int get_jmp_dis() const; 140 | vector get_canonical_reg_list() const; 141 | static vector get_isa_canonical_reg_list(); 142 | void insert_jmp_opcodes(unordered_set& jmp_set) const; 143 | int inst_output_opcode_type() const; 144 | int inst_output() const; 145 | bool is_real_inst() const; 146 | bool is_reg(int op_index) const; 147 | int implicit_ret_reg() const; 148 | void set_as_nop_inst(); 149 | // return the register for storing the given input 150 | unsigned int 
get_input_reg() const {return 0;} 151 | int get_num_operands() const {return num_operands[_opcode];} 152 | int get_insn_num_regs() const {return insn_num_regs[_opcode];} 153 | int get_opcode_type() const {return opcode_type[_opcode];} 154 | // smt 155 | z3::expr smt_inst(smt_var& sv, unsigned int block = 0) const; 156 | z3::expr smt_inst_jmp(smt_var& sv) const; 157 | z3::expr smt_inst_end(smt_var & sv) const; 158 | static z3::expr smt_set_pre(z3::expr input, smt_var& sv); 159 | 160 | bool is_cfg_basic_block_end() const; 161 | bool is_pgm_end() const; 162 | }; 163 | 164 | int num_real_instructions(const inst* program, int length); 165 | void interpret(inout_t& output, inst* program, int length, prog_state &ps, const inout_t& input); 166 | 167 | inline int opcode_2_idx(int opcode) {return opcode;} 168 | -------------------------------------------------------------------------------- /src/isa/toy-isa/inst_codegen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "z3++.h" 4 | #include "../../../src/utils.h" 5 | #include "../../../src/isa/inst_var.h" 6 | #include "inst_var.h" 7 | 8 | using namespace std; 9 | 10 | /* APIs exposed to the externals start */ 11 | // return (out = in) 12 | inline int compute_mov(int in, int out = 0); 13 | // return (out = in1 + in2) 14 | inline int compute_add(int in1, int in2, int out = 0); 15 | // return (out = max(in1, in2)) 16 | inline int compute_max(int in1, int in2, int out = 0); 17 | // return (out == in) 18 | inline z3::expr compute_mov(z3::expr in, z3::expr out); 19 | // return (out == in1 + in2) 20 | inline z3::expr compute_add(z3::expr in1, z3::expr in2, z3::expr out); 21 | // return (out == max(in1, in2)) 22 | inline z3::expr compute_max(z3::expr in1, z3::expr in2, z3::expr out); 23 | /* APIs exposed to the externals end */ 24 | 25 | /* Inputs in, out must be side-effect-free expressions. 
*/ 26 | #undef MOV_EXPR 27 | #define MOV_EXPR(in, out) (out EQ in) 28 | /* Inputs in1, in2, out must be side-effect-free expressions. */ 29 | #undef ADD_EXPR 30 | #define ADD_EXPR(in1, in2, out) (out EQ in1 + in2) 31 | 32 | /* Predicate expressions capture instructions like MAX which have different 33 | * results on a register based on the evaluation of a predicate. */ 34 | /* Inputs out, pred_if, pred_else must be side-effect-free. */ 35 | #undef PRED_BINARY_EXPR 36 | #define PRED_BINARY_EXPR(out, pred_if, ret_if, ret_else) ({ \ 37 | IF_PRED_ACTION(pred_if, ret_if, out) \ 38 | CONNECTIFELSE \ 39 | ELSE_PRED_ACTION(pred_if, ret_else, out); \ 40 | }) 41 | 42 | /* Inputs in1, in2, out must be side-effect-free. */ 43 | #undef MAX_EXPR 44 | #define MAX_EXPR(in1, in2, out) (PRED_BINARY_EXPR(out, in1 > in2, in1, in2)) 45 | 46 | /* Macros for interpreter start */ 47 | // Operator macros in experssion macros for interpreter start 48 | #undef EQ 49 | #define EQ = 50 | #undef IF_PRED_ACTION 51 | #define IF_PRED_ACTION(pred, expr, var) if(pred) var EQ expr 52 | #undef CONNECTIFELSE 53 | #define CONNECTIFELSE ; 54 | #undef ELSE_PRED_ACTION 55 | #define ELSE_PRED_ACTION(pred, expr, var) else var EQ expr 56 | // Operator macros in experssion macros for interpreter end 57 | 58 | // Functions for interpreter start 59 | #undef COMPUTE_UNARY 60 | #define COMPUTE_UNARY(func_name, operation, para1_t, para2_t, ret_t) \ 61 | inline ret_t compute_##func_name(para1_t in, para2_t out) { \ 62 | operation(in, out); \ 63 | return out; \ 64 | } 65 | 66 | #undef COMPUTE_BINARY 67 | #define COMPUTE_BINARY(func_name, operation, para1_t, para2_t, para3_t, ret_t) \ 68 | inline ret_t compute_##func_name(para1_t in1, para2_t in2, para3_t out) { \ 69 | operation(in1, in2, out); \ 70 | return out; \ 71 | } 72 | 73 | COMPUTE_UNARY(mov, MOV_EXPR, int, int, int) 74 | COMPUTE_BINARY(add, ADD_EXPR, int, int, int, int) 75 | COMPUTE_BINARY(max, MAX_EXPR, int, int, int, int) 76 | // Functions for interpreter 
end 77 | /* Macros for interpreter end */ 78 | 79 | /* Macros for validator start */ 80 | // Operator macros in experssion macros for validator start 81 | #undef EQ 82 | #define EQ == 83 | #undef IF_PRED_ACTION 84 | #define IF_PRED_ACTION(pred, expr, var) ((pred) && (var EQ expr)) 85 | #undef CONNECTIFELSE 86 | #define CONNECTIFELSE || 87 | #undef ELSE_PRED_ACTION 88 | #define ELSE_PRED_ACTION(pred, expr, var) (!(pred) && (var EQ expr)) 89 | // Operator macros in experssion macros for validator end 90 | 91 | // Functions for validator start 92 | #undef PREDICATE_UNARY 93 | #define PREDICATE_UNARY(func_name, operation) \ 94 | inline z3::expr predicate_##func_name(z3::expr in, z3::expr out) { \ 95 | return operation(in, out); \ 96 | } 97 | #undef PREDICATE_BINARY 98 | #define PREDICATE_BINARY(func_name, operation) \ 99 | inline z3::expr predicate_##func_name(z3::expr in1, z3::expr in2, z3::expr out) { \ 100 | return operation(in1, in2, out); \ 101 | } 102 | 103 | PREDICATE_UNARY(mov, MOV_EXPR) 104 | PREDICATE_BINARY(add, ADD_EXPR) 105 | PREDICATE_BINARY(max, MAX_EXPR) 106 | 107 | // Functions for validator en 108 | /* Macros for validator end */ 109 | inline z3::expr smt_pgm_eq_chk(smt_var& sv1, smt_var& sv2) { 110 | return (sv1.ret_val == sv2.ret_val); 111 | } 112 | 113 | inline z3::expr smt_pgm_set_same_input(smt_var& sv1, smt_var& sv2) { 114 | return Z3_true; 115 | } 116 | 117 | inline void counterex_2_input_mem(inout_t& input, z3::model& mdl, 118 | smt_var& sv1, smt_var& sv2) {} 119 | -------------------------------------------------------------------------------- /src/isa/toy-isa/inst_codegen_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../../src/utils.h" 3 | #include "inst_codegen.h" 4 | 5 | using namespace std; 6 | 7 | z3::context ctx; 8 | 9 | void test1() { 10 | cout << "Test 1" << endl; 11 | int a = 4, b = 5, c = 10; 12 | z3::expr x = ctx.int_val(a); 13 | z3::expr y = ctx.int_val(b); 
14 | z3::expr z = ctx.int_val(c); 15 | 16 | // check add 17 | print_test_res(compute_add(a, b, c) == (a + b), "compute_add"); 18 | z3::expr expected = (z == x + y); 19 | print_test_res(predicate_add(x, y, z) == expected, "predicate_add"); 20 | 21 | // check mov 22 | print_test_res(compute_mov(a, b) == a, "compute_mov"); 23 | expected = (y == x); 24 | print_test_res(predicate_mov(x, y) == expected, "predicate_mov"); 25 | 26 | // check max 27 | print_test_res(compute_max(a, b, c) == max(a, b), "compute_max"); 28 | expected = ((x > a) && (z == x)) || ((x <= a) && (z == a)); 29 | print_test_res(predicate_max(x, y, z) == expected, "predicate_max"); 30 | } 31 | 32 | int main() { 33 | test1(); 34 | 35 | return 0; 36 | } 37 | -------------------------------------------------------------------------------- /src/isa/toy-isa/inst_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "../../../src/utils.h" 4 | #include "inst.h" 5 | 6 | /* r0 contains the input */ 7 | inst instructions[6] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 8 | inst(ADDXY, 0, 2), /* add r0, r2 */ 9 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 10 | inst(JMPGT, 0, 3, 1), /* if r0 <= r3: */ 11 | inst(RETX, 3), /* ret r3 */ 12 | inst(RETX, 0), /* else ret r0 */ 13 | }; 14 | 15 | inst instructions2[4] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 16 | inst(ADDXY, 0, 2), /* add r0, r2 */ 17 | inst(MAXC, 0, 15), /* max r0, 15 */ 18 | inst(RETX, 0), /* ret r0 */ 19 | }; 20 | 21 | inst instructions3[2] = {inst(NOP), /* test no-op */ 22 | inst(RETX, 0), /* ret r0 */ 23 | }; 24 | 25 | /* test unconditional jmp */ 26 | inst instructions4[3] = {inst(JMP, 1), 27 | inst(ADDXY, 0, 0), 28 | inst(RETX, 0), 29 | }; 30 | 31 | void test1(int input_reg) { 32 | prog_state ps; 33 | inout_t input, output, expected; 34 | input.init(); 35 | output.init(); 36 | expected.init(); 37 | cout << "Test 1: full interpretation check" << endl; 38 | 39 | input.reg = input_reg; 40 | 
expected.reg = max(input_reg + 4, 15); 41 | interpret(output, instructions, 6, ps, input); 42 | print_test_res(output == expected, "interpret program 1"); 43 | 44 | input.reg = input_reg; 45 | expected.reg = max(input_reg + 4, 15); 46 | interpret(output, instructions2, 4, ps, input); 47 | print_test_res(output == expected, "interpret program 2"); 48 | 49 | input.reg = input_reg; 50 | expected.reg = input_reg; 51 | interpret(output, instructions3, 2, ps, input); 52 | print_test_res(output == input, "interpret program 3"); 53 | 54 | input.reg = input_reg; 55 | expected.reg = input_reg; 56 | interpret(output, instructions4, 3, ps, input); 57 | print_test_res(output == input, "interpret program 4"); 58 | } 59 | 60 | void test2() { 61 | cout << "Test 2" << endl; 62 | inst x = inst(MOVXC, 2, 4); 63 | inst y = inst(MOVXC, 2, 4); 64 | inst z = inst(MOVXC, 2, 3); 65 | inst w = inst(RETX, 3); 66 | 67 | cout << "Instruction operator== check" << endl; 68 | print_test_res((x == y) == true, "operator== 1"); 69 | print_test_res((inst(RETX, 3) == inst(RETC, 3)) == false, "operator== 2"); 70 | print_test_res((inst(RETX, 3) == inst(RETX, 2)) == false, "operator== 3"); 71 | print_test_res((inst(RETX, 3) == inst(RETX, 3)) == true, "operator== 4"); 72 | 73 | cout << "Instruction hash value check" << endl; 74 | print_test_res(instHash()(x) == 22, "hash value 1"); 75 | print_test_res(instHash()(y) == 22, "hash value 2"); 76 | print_test_res(instHash()(z) == 10, "hash value 3"); 77 | print_test_res(instHash()(w) == 5, "hash value 4"); 78 | } 79 | 80 | void test3() { 81 | cout << "Test 3" << endl; 82 | string expected_bv_str = string("00010000100010000000") + 83 | string("00001000000001000000") + 84 | string("00010000110111100000") + 85 | string("00111000000001100001") + 86 | string("00011000110000000000") + 87 | string("00011000000000000000"); 88 | string bv_str = ""; 89 | for (int i = 0; i < 6; i++) { 90 | inst x = instructions[i]; 91 | vector abs_bv; 92 | x.to_abs_bv(abs_bv); 93 | for 
(int j = 0; j < abs_bv.size(); j++) { 94 | bv_str += bitset(abs_bv[j]).to_string(); 95 | } 96 | } 97 | print_test_res(bv_str == expected_bv_str, "inst_to_abs_bv"); 98 | } 99 | 100 | int main(int argc, char *argv[]) { 101 | /* Add the notion of program input */ 102 | int input = 10; 103 | if (argc > 1) { 104 | input = atoi(argv[1]); 105 | } 106 | 107 | test1(input); 108 | test2(); 109 | test3(); 110 | 111 | return 0; 112 | } 113 | -------------------------------------------------------------------------------- /src/isa/toy-isa/inst_var.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "inst_var.h" 4 | 5 | using namespace std; 6 | 7 | default_random_engine gen_toy_isa_inst_var; 8 | uniform_real_distribution unidist_toy_isa_inst_var(0.0, 1.0); 9 | 10 | void update_ps_by_input(prog_state& ps, const inout_t& input) { 11 | ps._regs[0] = input.reg; 12 | } 13 | 14 | void update_output_by_ps(inout_t& output, const prog_state& ps) { 15 | output.reg = ps._regs[0]; 16 | } 17 | 18 | void get_cmp_lists(vector& val_list1, vector& val_list2, 19 | inout_t& output1, inout_t& output2) { 20 | val_list1.resize(1); 21 | val_list2.resize(1); 22 | val_list1[0] = output1.reg; 23 | val_list2[0] = output2.reg; 24 | } 25 | 26 | void gen_random_input(vector& inputs, int reg_min, int reg_max) { 27 | unordered_set input_set; 28 | for (size_t i = 0; i < inputs.size();) { 29 | reg_t input = reg_min + (reg_max - reg_min) * 30 | unidist_toy_isa_inst_var(gen_toy_isa_inst_var); 31 | if (input_set.find(input) == input_set.end()) { 32 | input_set.insert(input); 33 | inputs[i].reg = input; 34 | i++; 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/isa/toy-isa/inst_var.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../../../src/isa/inst_var.h" 4 | 5 | using namespace std; 6 | 7 | static constexpr int 
NUM_REGS = 4; 8 | 9 | class smt_var: public smt_var_base { 10 | public: 11 | z3::expr ret_val = to_expr("ret_val"); 12 | smt_var(): smt_var_base() {} 13 | smt_var(unsigned int prog_id, unsigned int node_id, unsigned int num_regs) 14 | : smt_var_base(prog_id, node_id, num_regs) { 15 | ret_val = to_expr("ret_val_" + to_string(prog_id)); 16 | }; 17 | ~smt_var() {}; 18 | void init(unsigned int prog_id, unsigned int node_id, unsigned int num_regs, unsigned int n_blocks = 1) { 19 | smt_var_base::init(prog_id, node_id, num_regs); 20 | ret_val = to_expr("ret_val_" + to_string(prog_id)); 21 | } 22 | }; 23 | 24 | class smt_var_bl { 25 | public: 26 | void store_state_before_smt_block(smt_var& sv) {} 27 | z3::expr gen_smt_after_smt_block(smt_var& sv, z3::expr& pc) {return Z3_true;} 28 | }; 29 | 30 | class prog_state: public prog_state_base { 31 | public: 32 | prog_state() {_regs.resize(NUM_REGS, 0);} 33 | }; 34 | 35 | class inout_t: public inout_t_base { 36 | public: 37 | int reg; 38 | void clear() {reg = 0;} 39 | void init() {} 40 | bool operator==(const inout_t &rhs) const {return (reg == rhs.reg);} 41 | friend ostream& operator<<(ostream& out, const inout_t& x) { 42 | out << x.reg; 43 | return out; 44 | } 45 | }; 46 | 47 | void get_cmp_lists(vector& val_list1, vector& val_list2, 48 | inout_t& output1, inout_t& output2); 49 | void gen_random_input(vector& inputs, int reg_min, int reg_max); 50 | -------------------------------------------------------------------------------- /src/search/cost.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "../../src/utils.h" 5 | #include "../../src/inout.h" 6 | #include "../../src/isa/inst_header.h" 7 | #include "../../src/isa/prog.h" 8 | #include "../../src/verify/validator.h" 9 | 10 | using namespace std; 11 | extern int dur_sum; 12 | extern int dur_sum_long; 13 | extern int n_sum_long; 14 | 15 | #define ERROR_COST_MAX 100000 16 | 17 | #define 
ERROR_COST_STRATEGY_ABS 0 18 | #define ERROR_COST_STRATEGY_POP 1 19 | #define ERROR_COST_STRATEGY_EQ1 0 20 | #define ERROR_COST_STRATEGY_EQ2 1 21 | #define ERROR_COST_STRATEGY_NAVG 0 22 | #define ERROR_COST_STRATEGY_AVG 1 23 | 24 | #define PERF_COST_STRATEGY_LEN 0 // length of programs 25 | #define PERF_COST_STRATEGY_RUNTIME 1 // runtime of programs 26 | 27 | class cost { 28 | private: 29 | // perf_cost_base_win is a cache of perf cost of all program instrcutions except window instructions 30 | // perf_csot = perf_cost_base_win + perf_cost of window 31 | double _perf_cost_base_win = 0; // -1 means not set 32 | int _num_real_orig; 33 | double get_ex_error_cost(inout_t& output1, inout_t& output2); 34 | int get_avg_value(int ex_set_size); 35 | double get_final_error_cost(double exs_cost, int is_equal, 36 | int ex_set_size, int num_successful_ex, 37 | int avg_value); 38 | double get_ex_error_cost_from_val_lists_abs(vector& val_list1, vector& val_list2); 39 | double get_ex_error_cost_from_val_lists_pop(vector& val_list1, vector& val_list2); 40 | void set_perf_cost_base_win(prog* orig, int len, int win_start, int win_end); 41 | public: 42 | validator _vld; 43 | examples _examples; 44 | bool _meas_new_counterex_gened; 45 | double _w_e = 0.5; 46 | double _w_p = 0.5; 47 | int _strategy_ex = 0; 48 | int _strategy_eq = 0; 49 | int _strategy_avg = 0; 50 | int _strategy_perf = 0; 51 | cost(); 52 | ~cost(); 53 | void init(prog* orig, int len, const vector &input, 54 | double w_e = 0.5, double w_p = 0.5, 55 | int strategy_ex = 0, int strategy_eq = 0, 56 | int strategy_avg = 0, int strategy_perf = 0, 57 | bool enable_prog_eq_cache = true, 58 | bool enable_prog_uneq_cache = false, 59 | bool is_win = false); 60 | void set_examples(const vector &input, prog* orig); 61 | void set_orig(prog* orig, int len, int win_start = 0, int win_end = inst::max_prog_len); 62 | double error_cost(prog* orig, int len1, prog* synth, int len2); 63 | double perf_cost(prog* synth, int len, bool set_win = 
false); 64 | double total_prog_cost(prog* orig, int len1, prog* synth, int len2); 65 | }; 66 | 67 | unsigned int pop_count_outputs(int64_t output1, int64_t output2); 68 | -------------------------------------------------------------------------------- /src/search/cost_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "cost.h" 4 | 5 | using namespace std; 6 | 7 | inst instructions[7] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 8 | inst(ADDXY, 0, 2), /* add r0, r2 */ 9 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 10 | inst(JMPGT, 0, 3, 1), /* if r0 <= r3: */ 11 | inst(RETX, 3), /* ret r3 */ 12 | inst(RETX, 0), /* else ret r0 */ 13 | inst(), /* nop */ 14 | }; 15 | 16 | inst instructions2[7] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 17 | inst(ADDXY, 0, 2), /* add r0, r2 */ 18 | inst(MAXC, 0, 15), /* max r0, 15 */ 19 | inst(RETX, 0), /* ret r0 */ 20 | inst(), /* nop */ 21 | inst(), /* nop */ 22 | inst(), /* nop */ 23 | }; 24 | 25 | inst instructions3[7] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 26 | inst(ADDXY, 0, 2), /* add r0, r2 */ 27 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 28 | inst(JMPEQ, 0, 3, 1), /* if r0 != r3: */ 29 | inst(RETX, 3), /* ret r3 */ 30 | inst(RETX, 0), /* else ret r0 */ 31 | inst(), /* nop */ 32 | }; 33 | 34 | inst instructions4[7] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 35 | inst(ADDXY, 0, 2), /* add r0, r2 */ 36 | inst(MOVXC, 3, 16), /* mov r3, 16 */ 37 | inst(JMPGT, 0, 3, 1), /* if r0 <= r3: */ 38 | inst(RETC, 15), /* ret 15 */ 39 | inst(RETX, 0), /* else ret r0 */ 40 | inst(), /* nop */ 41 | }; 42 | 43 | vector ex_set(2); 44 | 45 | void test1() { 46 | #define NUM_INTS 6 47 | cout << "test 1: pop_count_asm starts...\n"; 48 | unsigned int ints_list[NUM_INTS] = {0, 1, 5, 7, 63, 114}; 49 | int truth[NUM_INTS] = {0, 1, 2, 3, 6, 4}; 50 | for (int i = 0; i < NUM_INTS; i++) { 51 | print_test_res(pop_count_asm(ints_list[i]) == truth[i], 52 | to_string(i + 1)); 53 | } 54 | } 55 | 
56 | void test2() { 57 | cout << "test 2: error_cost check starts...\n"; 58 | prog orig(instructions); 59 | cost c; 60 | c.set_orig(&orig, 7); 61 | c._examples.clear(); 62 | for (size_t i = 0; i < ex_set.size(); i++) { 63 | c._examples.insert(ex_set[i]); 64 | } 65 | int err_cost = c.error_cost(&orig, 7, &orig, 7); 66 | print_test_res(err_cost == 0, "1"); 67 | prog synth1(instructions2); 68 | err_cost = c.error_cost(&orig, 7, &synth1, 7); 69 | print_test_res(err_cost == 0, "2"); 70 | prog synth2(instructions3); 71 | err_cost = c.error_cost(&orig, 7, &synth2, 7); 72 | print_test_res(err_cost == 6, "3"); 73 | prog synth3(instructions4); 74 | err_cost = c.error_cost(&orig, 7, &synth3, 7); 75 | print_test_res(err_cost == 2, "4"); 76 | } 77 | 78 | int main() { 79 | inout_t x, y; 80 | x.init(); 81 | y.init(); 82 | x.reg = 10; 83 | y.reg = 15; 84 | ex_set[0].set_in_out(x, y); 85 | x.reg = 16; 86 | y.reg = 20; 87 | ex_set[1].set_in_out(x, y); 88 | test1(); 89 | test2(); 90 | return 0; 91 | } 92 | -------------------------------------------------------------------------------- /src/search/cost_test_ebpf.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "../../measure/benchmark_ebpf.h" 4 | #include "cost.h" 5 | 6 | using namespace std; 7 | 8 | double get_error_cost(inst* p1, inst* p2, int win_start, int win_end) { 9 | cost c; 10 | c._vld._is_win = true; 11 | smt_var::is_win = true; 12 | prog prog1(p1), prog2(p2); 13 | 14 | inout_t::start_insn = win_start; 15 | inout_t::end_insn = win_end; 16 | static_safety_check_pgm(prog1.inst_list, inst::max_prog_len); 17 | c.set_orig(&prog1, inst::max_prog_len, win_start, win_end); 18 | prog_static_state pss; 19 | static_analysis(pss, p1, inst::max_prog_len); 20 | int num_examples = 30; 21 | vector examples; 22 | gen_random_input_for_win(examples, num_examples, 23 | pss.static_state[win_start], p1[win_start], 24 | win_start, win_end); 25 | 
c.set_examples(examples, &prog1); 26 | return c.error_cost(&prog1, inst::max_prog_len, &prog2, inst::max_prog_len); 27 | } 28 | 29 | void test1() { 30 | cout << "Test1: test error cost" << endl; 31 | inst p1[N3], p2[N3]; 32 | for (int i = 0; i < N3; i++) p1[i] = bm3[i]; 33 | for (int i = 0; i < N3; i++) p2[i] = bm3[i]; 34 | p2[5] = inst(MOV32XY, 1, 0); 35 | p2[6] = inst(); 36 | p2[7] = inst(); 37 | int win_start = 5, win_end = 7; 38 | mem_t::_layout.clear(); 39 | inst::max_prog_len = N3; 40 | mem_t::set_pgm_input_type(PGM_INPUT_pkt); 41 | mem_t::set_pkt_sz(128); 42 | mem_t::add_map(map_attr(128, 64, 91)); 43 | mem_t::add_map(map_attr(96, 96, 91)); 44 | mem_t::add_map(map_attr(64, 128, 91)); 45 | mem_t::_layout._n_randoms_u32 = 1; 46 | smt_var::init_static_variables(); 47 | 48 | print_test_res(get_error_cost(p1, p2, win_start, win_end) == 0, "rcv_sock4 1"); 49 | mem_t::_layout.clear(); 50 | 51 | // xdp_exception 52 | const int xdp_exp_len = N16; 53 | inst::max_prog_len = xdp_exp_len; 54 | mem_t::set_pgm_input_type(PGM_INPUT_pkt); 55 | mem_t::set_pkt_sz(32); 56 | mem_t::add_map(map_attr(32, 64, N16)); 57 | inst xdp_exp[xdp_exp_len]; 58 | inst xdp_exp_1[xdp_exp_len]; 59 | for (int i = 0; i < xdp_exp_len; i++) xdp_exp[i] = bm16[i]; 60 | for (int i = 0; i < xdp_exp_len; i++) xdp_exp_1[i] = xdp_exp[i]; 61 | win_start = 12; 62 | win_end = 14; 63 | xdp_exp_1[12] = inst(); 64 | xdp_exp_1[13] = inst(); 65 | xdp_exp_1[14] = inst(XADD64, 0, 0, 1); 66 | print_test_res(get_error_cost(xdp_exp, xdp_exp_1, win_start, win_end) == 0, "xdp_exception 1"); 67 | mem_t::_layout.clear(); 68 | 69 | // xdp_pktcntr, bm24 70 | const int xdp_pkt_len = N24; 71 | inst::max_prog_len = xdp_pkt_len; 72 | mem_t::set_pgm_input_type(PGM_INPUT_pkt); 73 | mem_t::set_pkt_sz(68); 74 | mem_t::add_map(map_attr(32, 32, N24)); 75 | mem_t::add_map(map_attr(32, 64, N24)); 76 | inst xdp_pkt[xdp_pkt_len]; 77 | inst xdp_pkt_1[xdp_pkt_len]; 78 | for (int i = 0; i < xdp_pkt_len; i++) xdp_pkt[i] = bm24[i]; 79 | for 
(int i = 0; i < xdp_pkt_len; i++) xdp_pkt_1[i] = xdp_pkt[i]; 80 | win_start = 17; 81 | win_end = 19; 82 | xdp_pkt_1[17] = inst(); 83 | xdp_pkt_1[18] = inst(MOV32XC, 1, 1); 84 | xdp_pkt_1[19] = inst(XADD64, 0, 0, 1); 85 | print_test_res(get_error_cost(xdp_pkt, xdp_pkt_1, win_start, win_end) == 0, "xdp_pktcntr 1"); 86 | mem_t::_layout.clear(); 87 | 88 | // test PGM_INPUT_pkt_ptrs 89 | const int p3_len = 8; 90 | inst::max_prog_len = p3_len; 91 | mem_t::set_pgm_input_type(PGM_INPUT_pkt_ptrs); 92 | mem_t::set_pkt_sz(32); 93 | mem_t::add_map(map_attr(32, 64, p3_len)); 94 | inst p3[] = {inst(LDXW, 3, 1, 4), 95 | inst(LDXW, 2, 1, 0), 96 | inst(LDXB, 4, 2, 12), // insn 2 97 | inst(LDXB, 5, 2, 13), 98 | inst(LSH64XC, 5, 8), 99 | inst(OR64XY, 5, 4), // insn 5 100 | inst(MOV64XY, 0, 5), 101 | inst(EXIT), 102 | }; 103 | inst p3_1[] = {inst(LDXW, 3, 1, 4), 104 | inst(LDXW, 2, 1, 0), 105 | inst(LDXB, 4, 2, 12), 106 | inst(LDXB, 5, 2, 13), 107 | inst(LSH64XC, 5, 8), 108 | inst(LDXH, 5, 2, 12), 109 | inst(MOV64XY, 0, 5), 110 | inst(EXIT), 111 | }; 112 | win_start = 2; 113 | win_end = 5; 114 | print_test_res(get_error_cost(p3, p3_1, win_start, win_end) == 0, "PGM_INPUT_pkt_ptrs 1"); 115 | mem_t::_layout.clear(); 116 | 117 | 118 | inst p4[] = {inst(LDXW, 2, 1, 4), 119 | inst(LDXW, 8, 1, 0), 120 | inst(MOV64XC, 1, 0), 121 | inst(STXW, 10, -4, 1), 122 | inst(STXW, 10, -8, 1), 123 | inst(MOV64XC, 7, 1), 124 | inst(MOV64XY, 1, 8), 125 | inst(ADD64XC, 1, 14), 126 | inst(JGTXY, 1, 2, 32), 127 | inst(MOV64XY, 2, 10), 128 | inst(ADD64XC, 2, -4), 129 | INSN_LDMAPID(1, 0), 130 | inst(NOP), 131 | inst(CALL, 1), 132 | inst(MOV64XY, 6, 0), 133 | inst(JEQXC, 6, 0, 25), 134 | inst(MOV64XY, 2, 10), 135 | inst(ADD64XC, 2, -8), 136 | INSN_LDMAPID(1, 1), 137 | inst(NOP), 138 | inst(CALL, 1), 139 | inst(JEQXC, 0, 0, 3), 140 | inst(LDXDW, 1, 0, 0), 141 | inst(ADD64XC, 1, 1), 142 | inst(STXDW, 0, 0, 1), 143 | inst(LDXH, 1, 8, 0), 144 | inst(LDXH, 2, 8, 6), 145 | inst(STXH, 8, 0, 2), 146 | inst(LDXH, 2, 8, 
8), 147 | inst(LDXH, 3, 8, 2), 148 | inst(STXH, 8, 8, 3), 149 | inst(STXH, 8, 2, 2), 150 | inst(LDXH, 2, 8, 10), 151 | inst(LDXH, 3, 8, 4), 152 | inst(STXH, 8, 10, 3), 153 | inst(STXH, 8, 6, 1), 154 | inst(STXH, 8, 4, 2), 155 | inst(LDXW, 1, 6, 0), 156 | inst(MOV64XC, 2, 0), 157 | inst(CALL, 23), 158 | inst(MOV64XY, 7, 0), 159 | inst(MOV64XY, 0, 7), 160 | inst(EXIT), 161 | }; 162 | 163 | const int p4_len = sizeof(p4) / sizeof(inst); 164 | inst::max_prog_len = p4_len; 165 | mem_t::set_pgm_input_type(PGM_INPUT_pkt_ptrs); 166 | mem_t::set_pkt_sz(64); 167 | mem_t::add_map(map_attr(32, 32, p4_len)); 168 | mem_t::add_map(map_attr(32, 64, p4_len)); 169 | inst p4_1[p4_len]; 170 | for (int i = 0; i < p4_len; i++) p4_1[i] = p4[i]; 171 | win_start = 2; 172 | win_end = 4; 173 | p4_1[2] = inst(); 174 | p4_1[3] = inst(); 175 | p4_1[4] = inst(STDW, 10, -8, 0); 176 | print_test_res(get_error_cost(p4, p4_1, win_start, win_end) == 0, "PGM_INPUT_pkt_ptrs 2"); 177 | mem_t::_layout.clear(); 178 | 179 | inst p5[] = {inst(MOV64XY, 6, 1), 180 | inst(MOV64XC, 1, 0), 181 | inst(CALL, BPF_FUNC_get_prandom_u32), 182 | inst(MOV64XC, 2, 0), 183 | inst(STXDW, 10, -8, 2), 184 | inst(LDXDW, 0, 10, -8), 185 | inst(EXIT), 186 | }; 187 | inst p5_1[] = {inst(MOV64XY, 6, 1), 188 | inst(MOV64XC, 1, 0), 189 | inst(CALL, BPF_FUNC_get_prandom_u32), 190 | inst(MOV64XC, 2, 0), 191 | inst(STXDW, 10, -8, 1), // src_reg from 2 to 1 192 | inst(LDXDW, 0, 10, -8), 193 | inst(EXIT), 194 | }; 195 | const int p5_len = sizeof(p5) / sizeof(inst); 196 | inst::max_prog_len = p5_len; 197 | mem_t::set_pgm_input_type(PGM_INPUT_pkt); 198 | mem_t::set_pkt_sz(32); 199 | win_start = 3; 200 | win_end = 5; 201 | double err = get_error_cost(p5, p5_1, win_start, win_end); 202 | print_test_res(get_error_cost(p5, p5_1, win_start, win_end) == ERROR_COST_MAX, "p5"); 203 | 204 | mem_t::_layout.clear(); 205 | } 206 | 207 | int main() { 208 | test1(); 209 | kill_server(); 210 | return 0; 211 | } 212 | 
// Shared random source for the proposal and the accept/reject draw.
std::default_random_engine gen;
std::uniform_real_distribution<double> unidist(0.0, 1.0);

/* Function that provides the (unnormalized) probability density given a
 * support variable x: a triangle centered at 15 that reaches 0 at 10 and 20
 * and is 0 everywhere outside that interval. */
double pi(double x) {
  const double center = 15.0;
  const double width = 5.0;
  if (x > center)
    return std::max(0.0, center + width - x);
  else
    return std::max(0.0, width - center + x);
}

/* Generate a new sample from the proposal distribution: a normal distribution
 * centered at the current sample x with standard deviation 2. */
double generate_y(double x) {
  std::normal_distribution<double> dist(x, 2.0);
  return dist(gen);
}

/* Compute the acceptance probability of moving from x to y.
 *
 * Uses the simplified form min(1, pi(y) / pi(x)), valid because the proposal
 * distribution is symmetric.
 *
 * When pi(x) is 0 (x lies outside the support, which includes the boundary
 * value 10.0) the ratio is undefined. The move is then always accepted so the
 * chain can walk towards the support. This is the value the unguarded
 * division produced through inf/NaN handling in min(); the guard makes it
 * explicit and independent of floating-point corner cases.
 */
double alpha(double x, double y) {
  const double density_x = pi(x);
  if (density_x <= 0.0)
    return 1.0;
  return std::min(1.0, pi(y) / density_x);
}

/* Get the next MCMC sample: propose y from x and keep it with probability
 * alpha(x, y); otherwise stay at x. */
double mh_next(double x) {
  const double y = generate_y(x);
  const double uni_sample = unidist(gen);
  if (uni_sample < alpha(x, y))
    return y;
  else
    return x;
}
"../../src/isa/prog.h" 8 | #include "proposals.h" 9 | #include "cost.h" 10 | #include "../../measure/meas_mh_bhv.h" 11 | 12 | using namespace std; 13 | 14 | // 1. when to restart strategy 15 | #define MH_SAMPLER_ST_WHEN_TO_RESTART_NO_RESTART 0 16 | #define MH_SAMPLER_ST_WHEN_TO_RESTART_MAX_ITER 1 17 | // 2. start prog strategy 18 | #define MH_SAMPLER_ST_NEXT_START_PROG_ORIG 0 19 | #define MH_SAMPLER_ST_NEXT_START_PROG_ALL_INSTS 1 20 | #define MH_SAMPLER_ST_NEXT_START_PROG_K_CONT_INSTS 2 21 | 22 | class mh_sampler_next_win { 23 | public: 24 | unsigned int _st_next_win; 25 | unsigned int _max_num_iter; 26 | vector _win_s_list; 27 | vector _win_e_list; 28 | unsigned int _cur_win; 29 | mh_sampler_next_win(); 30 | void set_win_lists(const vector& win_s_list, const vector& win_e_list); 31 | void set_max_num_iter(unsigned int max_num_iter); 32 | bool whether_to_reset(unsigned int iter_num); 33 | pair update_and_get_next_win(); 34 | }; 35 | 36 | class mh_sampler_restart { 37 | public: 38 | unsigned int _st_when_to_start; 39 | // restart every `_max_num_iter` iterations 40 | unsigned int _max_num_iter; 41 | unsigned int _st_next_start_prog; 42 | vector _w_e_list; 43 | vector _w_p_list; 44 | size_t _cur_w_pointer; 45 | mh_sampler_restart(); 46 | ~mh_sampler_restart(); 47 | void set_st_when_to_restart(unsigned int st, unsigned int max_num_iter = 0); 48 | void set_st_next_start_prog(unsigned int st); 49 | void set_we_wp_list(const vector &w_e_list, const vector &w_p_list); 50 | bool whether_to_restart(unsigned int iter_num); 51 | prog* next_start_prog(prog* curr); 52 | pair next_start_we_wp(); 53 | }; 54 | 55 | /* The main function of class mh_sampler_next_proposal is to 56 | * generate next proposal program according to the probability of different methods. 57 | * 58 | * Next proposal program can be generate by function next_proposal(.), noting that 59 | * when generate next proposal program, the sum of all probabilities is assumed as 1. 
60 | * 61 | * Three methods are supported now, that is, 62 | * modify random instrution operand, instruction and two continuous instructions. 63 | * 64 | * The probabilities three methods can be set by set_probability(.). 65 | */ 66 | class mh_sampler_next_proposal { 67 | public: 68 | // `_thr_*` variables are used as thresholds when using uniform sample to 69 | // randomly choose different proposal generating methods. View more details in next_proposal(.) 70 | // 1. threshold mod_random_inst_operand is the probablity of mod_random_inst_operand 71 | // 2. threshold mod_random_inst is sum of the probablities of mod_random_inst_operand and mod_random_inst 72 | // 3. threshold mod_random_inst_as_nop is the probablity of mod_random_inst_operand, 73 | // mod_random_inst and mod_random_inst_as_nop 74 | double _thr_mod_random_inst_operand; 75 | double _thr_mod_random_inst; 76 | double _thr_mod_random_inst_as_nop; 77 | int _win_start, _win_end; 78 | mh_sampler_next_proposal(); 79 | ~mh_sampler_next_proposal(); 80 | void set_probability(double p_mod_random_inst_operand, 81 | double p_mod_random_inst, 82 | double p_mod_random_inst_as_nop); 83 | void set_win(int start, int end); 84 | prog* next_proposal(prog* curr); 85 | }; 86 | 87 | /* Class mh_sampler can be used to generate a chain of sampled programs for a 88 | * given program. The sampled probability mainly depends on cost function, 89 | * start program, moves. 90 | * Example to use a mh_sampler: 91 | * mh_sampler mh; // define a `mh_sampler` variable 92 | * mh._restart.set_st_*(.) // [optional] set different mh sampler strategies, 93 | * // the default values are used without setting 94 | * mh._next_proposal.set_probability(.) 
// [optional] set probabilities of different moves 95 | * // the default values are used without setting 96 | * mh.turn_on_measure(); // [optional] turn on measure mode if measurement needed 97 | * mh._cost.init(.); // initialize the parameters of cost function 98 | * // view `cost.h` for more details 99 | * mh.mcmc_iter(.); // sample programs 100 | * store_*_to_file(.) // [optional] call data store function(s) to store 101 | * // measurement data; view `meas_mh_data.h` for more details 102 | * mh.turn_off_measure();// [optional] if turn on measure mode, should turn off 103 | */ 104 | class mh_sampler { 105 | private: 106 | double cost_to_pi(double cost); 107 | void print_restart_info(int iter_num, const prog &restart, double w_e, double w_p); 108 | public: 109 | mh_sampler_next_win _next_win; 110 | mh_sampler_restart _restart; 111 | mh_sampler_next_proposal _next_proposal; 112 | meas_mh_data _meas_data; 113 | cost _cost; 114 | double _base = 2; 115 | mh_sampler(); 116 | ~mh_sampler(); 117 | double alpha(prog* curr, prog* next, prog* orig); 118 | prog* mh_next(prog* curr, prog* orig); 119 | void turn_on_measure(); 120 | void turn_off_measure(); 121 | void mcmc_iter(top_k_progs& topk_progs, int niter, prog* orig, bool is_win = false); 122 | }; 123 | -------------------------------------------------------------------------------- /src/search/mh_prog_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "../../src/utils.h" 7 | #include "mh_prog.h" 8 | 9 | using namespace std; 10 | 11 | #define N 7 12 | inst instructions[N] = {inst(MOVXC, 2, 4), /* mov r2, 4 */ 13 | inst(ADDXY, 0, 2), /* add r0, r2 */ 14 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 15 | inst(JMPGT, 0, 3, 1), /* if r0 <= r3: */ 16 | inst(RETX, 3), /* ret r3 */ 17 | inst(RETX, 0), /* else ret r0 */ 18 | inst(), /* control never reaches here */ 19 | }; 20 | 21 | vector inputs; 22 | 23 | void 
mh_sampler_res_print(int nrolls, 24 | unordered_map > prog_freq) { 25 | // Get the best program(s) 26 | int max = 0; 27 | int concurrent_max = 0; 28 | prog *best; 29 | int nprogs = 0; 30 | for (std::pair > element : prog_freq) { 31 | vector pl = element.second; // list of progs with the same hash 32 | for (auto p : pl) { 33 | nprogs++; 34 | if (p->freq_count > max) { 35 | concurrent_max = 1; 36 | best = p; 37 | max = p->freq_count; 38 | } else if (p->freq_count == max) { 39 | concurrent_max++; 40 | } 41 | } 42 | } 43 | cout << "number of unique hashes observed: " << prog_freq.size() << endl; 44 | cout << "number of unique programs observed: " << nprogs << endl; 45 | cout << "Number of concurrently best programs:" << concurrent_max << endl; 46 | cout << "One of the best programs: " << endl; 47 | cout << "Observed frequency " << max << " out of " << nrolls << endl; 48 | best->print(); 49 | } 50 | 51 | void test1(int nrolls, double w_e, double w_p) { 52 | mh_sampler mh; 53 | mh._restart.set_st_when_to_restart(MH_SAMPLER_ST_WHEN_TO_RESTART_MAX_ITER, 5); 54 | mh._restart.set_st_next_start_prog(MH_SAMPLER_ST_NEXT_START_PROG_K_CONT_INSTS); 55 | mh._restart.set_we_wp_list(vector {0.5, 1.5}, vector {1.5, 0.5}); 56 | mh._next_proposal.set_probability(0.3, 0.5); 57 | std::unordered_map > prog_freq; 58 | prog orig(instructions); 59 | mh._cost.init(&orig, N, inputs, w_e, w_p); 60 | mh.mcmc_iter(nrolls, orig, prog_freq); 61 | mh_sampler_res_print(nrolls, prog_freq); 62 | } 63 | 64 | int main(int argc, char* argv[]) { 65 | int nrolls = 10; 66 | double w_e = 1.0; 67 | double w_p = 0.0; 68 | if (argc > 1) { 69 | nrolls = atoi(argv[1]); 70 | if (argc > 3) { 71 | w_e = atof(argv[2]); 72 | w_p = atof(argv[3]); 73 | } 74 | } 75 | vector inputs(30); 76 | for (int i = 0; i < inputs.size(); i++) { 77 | inputs[i].init(); 78 | } 79 | gen_random_input(inputs, 0, 50); 80 | test1(nrolls, w_e, w_p); 81 | return 0; 82 | } 83 | 
// Random source shared by every sampler below. The engine is default
// constructed (never explicitly seeded here).
std::default_random_engine gen;
std::uniform_real_distribution<double> unidist(0.0, 1.0);

/* Return a uniformly random integer from start to end inclusive.
 * When the range is empty (end < start) no valid value exists; the function
 * then returns without retrying.
 */
int sample_int(int start, int end) {
  end++;  // from here on `end` is the exclusive upper bound
  int val;
  do {
    val = start + static_cast<int>(unidist(gen) * static_cast<double>(end - start));
    // Defensive retry in case the scaled draw lands exactly on the exclusive
    // bound; skipped for empty ranges so the loop always terminates.
  } while (val == end && end > start);
  return val;
}

/* Return a uniformly random integer from 0 to limit inclusive */
int sample_int(int limit) {
  return sample_int(0, limit);
}

/* Return a uniformly random integer from 0 to limit inclusive, with the
 * exceptions of `excepts`.
 *
 * Entries of `excepts` outside [0, limit] cannot be drawn anyway, so they are
 * ignored instead of shrinking the sampling range. At least one value of
 * [0, limit] must remain available (checked with an assertion).
 */
int sample_int_with_exceptions(int limit, std::unordered_set<int> &excepts) {
  // Ordered, in-range copy of the exclusions: the shifting pass below needs
  // to visit them in ascending order.
  std::set<int> excluded;
  for (int e : excepts) {
    if (e >= 0 && e <= limit) excluded.insert(e);
  }
  const int num_candidates = limit + 1 - static_cast<int>(excluded.size());
  assert(num_candidates > 0);
  // Draw a rank among the remaining candidates, then map the rank to a value
  // by stepping over every excluded value that is not greater than it.
  int val = sample_int(num_candidates - 1);
  for (int e : excluded) {
    if (e <= val) val++;
  }
  return val;
}

/* Return a uniformly random integer from start to end inclusive, with the
 * exception of `except`. If the range holds a single value, that value is
 * returned even when it equals `except`.
 */
int sample_int_with_exception(int start, int end, int except) {
  end++;  // exclusive upper bound
  int val;
  do {
    val = start + static_cast<int>(unidist(gen) * static_cast<double>(end - start));
    // Redraw on the excluded value; only possible when more than one
    // candidate exists, otherwise the loop could never finish.
  } while ((val == end || val == except) && ((end - start) > 1));
  return val;
}
*/ 58 | int sample_int_with_exception(int limit, int except) { 59 | return sample_int_with_exception(0, limit, except); 60 | } 61 | 62 | // sample with exception `old_opvalue` 63 | int get_new_operand(int sel_inst_index, const inst& sel_inst, int op_to_change, int old_opvalue) { 64 | int max_opvalue = sel_inst.get_max_operand_val(op_to_change, sel_inst_index); 65 | int min_opvalue = sel_inst.get_min_operand_val(op_to_change, sel_inst_index); 66 | // TODO: is it wise to sample with exception? 67 | int new_opvalue = sample_int_with_exception(min_opvalue, max_opvalue, old_opvalue); 68 | return new_opvalue; 69 | } 70 | 71 | // sample without exception 72 | int get_new_operand(int sel_inst_index, const inst& sel_inst, int op_to_change) { 73 | int max_opvalue = sel_inst.get_max_operand_val(op_to_change, sel_inst_index); 74 | int min_opvalue = sel_inst.get_min_operand_val(op_to_change, sel_inst_index); 75 | int new_opvalue = sample_int(min_opvalue, max_opvalue); 76 | return new_opvalue; 77 | } 78 | 79 | void mod_operand(const prog &orig, prog* synth, int sel_inst_index, int op_to_change) { 80 | assert (op_to_change < MAX_OP_LEN); 81 | assert(sel_inst_index < inst::max_prog_len); 82 | // First make a fresh copy of the program. 
83 | inst* sel_inst = &synth->inst_list[sel_inst_index]; 84 | if (sel_inst->sample_unmodifiable()) return; 85 | int old_opvalue = sel_inst->get_operand(op_to_change); 86 | int new_opvalue = get_new_operand(sel_inst_index, *sel_inst, op_to_change, old_opvalue); 87 | sel_inst->set_operand(op_to_change, new_opvalue); 88 | } 89 | 90 | void mod_random_operand(const prog &orig, prog* synth, int inst_index) { 91 | int num = orig.inst_list[inst_index].get_num_operands(); 92 | if (num == 0) return; 93 | int op_to_change = sample_int(num - 1); 94 | mod_operand(orig, synth, inst_index, op_to_change); 95 | } 96 | 97 | prog* mod_random_inst_operand(const prog &orig, int win_start, int win_end) { 98 | assert(win_end < inst::max_prog_len); 99 | // TODO: remove instructions whithout valid operands, such as NOP, EXIT 100 | int inst_index = sample_int(win_start, win_end); 101 | prog* synth = new prog(orig); 102 | synth->reset_vals(); 103 | mod_random_operand(orig, synth, inst_index); 104 | if (synth->inst_list[inst_index] == orig.inst_list[inst_index]) { 105 | synth->set_vals(orig); 106 | } 107 | return synth; 108 | } 109 | 110 | /* randomly choose a possible opcode to replace the memory old opcode 111 | */ 112 | void mod_mem_inst_opcode(prog *orig, unsigned int sel_inst_index) { 113 | // 1. check whether it is a memory inst 114 | if (! orig->inst_list[sel_inst_index].is_mem_inst()) return; 115 | // 2. get number of possible opcodes 116 | inst* sel_inst = &orig->inst_list[sel_inst_index]; 117 | int old_opcode = sel_inst->get_opcode(); 118 | int old_opcode_sample_mem_idx = sel_inst->sample_mem_idx(old_opcode); 119 | 120 | int except = {old_opcode_sample_mem_idx}; 121 | int num = sel_inst->num_sample_mem_opcodes(); 122 | int new_mem_opcode_index = sample_int_with_exception(0, num - 1, except); // [0, num) 123 | int new_mem_opcode = sel_inst->get_mem_opcode_by_sample_idx(new_mem_opcode_index); 124 | // 3. 
modify opcode 125 | sel_inst->set_opcode(new_mem_opcode); 126 | sel_inst->set_unused_operands_default_vals(); 127 | } 128 | 129 | void mod_select_inst(prog *orig, unsigned int sel_inst_index) { 130 | assert(sel_inst_index < inst::max_prog_len); 131 | // TODO: is it wise to sample with exception? 132 | inst* sel_inst = &orig->inst_list[sel_inst_index]; 133 | if (sel_inst->sample_unmodifiable()) return; 134 | int old_opcode = sel_inst->get_opcode(); 135 | if (sel_inst->is_mem_inst()) { 136 | // 50% use the same modification as other opcodes, 50% use memory specific modification 137 | int num_types = 2; 138 | int type = sample_int(num_types - 1); // [0, 1] 139 | if (type == 0) { 140 | mod_mem_inst_opcode(orig, sel_inst_index); 141 | return; 142 | } 143 | } 144 | 145 | // exceptions set is used to avoid jumps in the last line of the program 146 | unordered_set exceptions; 147 | if (sel_inst_index == inst::max_prog_len - 1) { 148 | exceptions = {opcode_2_idx(old_opcode)}; 149 | sel_inst->insert_jmp_opcodes(exceptions); 150 | } else { 151 | exceptions = {opcode_2_idx(old_opcode)}; 152 | } 153 | // if window program eq check is used, set jmp opcodes as exceptions, 154 | // since window program eq check cannot deal with jmp opcodes and exit opcodes 155 | if (smt_var::is_win) { 156 | sel_inst->insert_jmp_opcodes(exceptions); 157 | sel_inst->insert_exit_opcodes(exceptions); 158 | } 159 | sel_inst->insert_opcodes_not_gen(exceptions); 160 | int new_opcode_idx = sample_int_with_exceptions(NUM_INSTR - 1, exceptions); 161 | int new_opcode = sel_inst->get_opcode_by_idx(new_opcode_idx); 162 | sel_inst->set_as_nop_inst(); 163 | sel_inst->set_opcode(new_opcode); 164 | for (int i = 0; i < sel_inst->get_num_operands(); i++) { 165 | int new_opvalue = get_new_operand(sel_inst_index, *sel_inst, i); 166 | sel_inst->set_operand(i, new_opvalue); 167 | } 168 | } 169 | 170 | prog* mod_random_inst(const prog &orig, int win_start, int win_end) { 171 | assert(win_end < inst::max_prog_len); 172 | 
// First make a copy of the old program 173 | prog* synth = new prog(orig); 174 | synth->reset_vals(); 175 | int inst_index = sample_int(win_start, win_end); 176 | mod_select_inst(synth, inst_index); 177 | if (synth->inst_list[inst_index] == orig.inst_list[inst_index]) { 178 | synth->set_vals(orig); 179 | } 180 | return synth; 181 | } 182 | 183 | prog* mod_random_k_cont_insts(const prog &orig, unsigned int k, int win_start, int win_end) { 184 | assert(win_end < inst::max_prog_len); 185 | // If k is too big, modify all instructions of the program window 186 | if (win_start + k - 1 > win_end) k = win_end - win_start + 1; 187 | // First make a copy of the old program 188 | prog* synth = new prog(orig); 189 | synth->reset_vals(); 190 | // Select a random start instruction 191 | int start_inst_index = sample_int(win_start, win_end - k + 1); 192 | for (int i = start_inst_index; i < start_inst_index + k; i++) { 193 | mod_select_inst(synth, i); 194 | } 195 | bool is_same_pgm = true; 196 | for (int i = start_inst_index; i < start_inst_index + k; i++) { 197 | if (!(synth->inst_list[i] == orig.inst_list[i])) { 198 | is_same_pgm = false; 199 | break; 200 | } 201 | } 202 | if (is_same_pgm) synth->set_vals(orig); 203 | return synth; 204 | } 205 | 206 | prog* mod_random_cont_insts(const prog &orig, int win_start, int win_end) { 207 | assert(win_end < inst::max_prog_len); 208 | int start_k_value = 2; // at least change two instructions 209 | int max_len = win_end - win_start + 1; 210 | int k = sample_int(start_k_value, max_len); 211 | return mod_random_k_cont_insts(orig, k); 212 | } 213 | 214 | prog* mod_random_inst_as_nop(const prog &orig, int win_start, int win_end) { 215 | assert(win_end < inst::max_prog_len); 216 | int inst_index = sample_int(win_start, win_end); 217 | prog* synth = new prog(orig); 218 | synth->reset_vals(); 219 | synth->inst_list[inst_index].set_as_nop_inst(); 220 | if (synth->inst_list[inst_index] == orig.inst_list[inst_index]) { 221 | synth->set_vals(orig); 
222 | } 223 | return synth; 224 | } 225 | -------------------------------------------------------------------------------- /src/search/proposals.h: -------------------------------------------------------------------------------- 1 | #include "../../src/utils.h" 2 | #include "../../src/isa/prog.h" 3 | #include "../../src/isa/inst_header.h" 4 | 5 | using namespace std; 6 | 7 | // a modification in program window [win_start, win_end] 8 | prog* mod_random_inst_operand(const prog &program, 9 | int win_start = 0, int win_end = inst::max_prog_len - 1); 10 | prog* mod_random_inst(const prog &program, 11 | int win_start = 0, int win_end = inst::max_prog_len - 1); 12 | prog* mod_random_k_cont_insts(const prog &program, unsigned int k, 13 | int win_start = 0, int win_end = inst::max_prog_len - 1); 14 | prog* mod_random_cont_insts(const prog &program, 15 | int win_start = 0, int win_end = inst::max_prog_len - 1); 16 | prog* mod_random_inst_opcode_width(const prog &program, 17 | int win_start = 0, int win_end = inst::max_prog_len - 1); 18 | prog* mod_random_inst_as_nop(const prog &program, 19 | int win_start = 0, int win_end = inst::max_prog_len - 1); 20 | -------------------------------------------------------------------------------- /src/search/proposals_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "../../src/isa/toy-isa/inst.h" 4 | #include "proposals.h" 5 | 6 | int test1(int input) { 7 | cout << "Test 1" << endl; 8 | #define N 7 9 | inst instructions[N] = {inst(MOVXC, 1, input), /* mov r1, input */ 10 | inst(MOVXC, 2, 4), /* mov r2, 4 */ 11 | inst(ADDXY, 1, 2), /* add r1, r2 */ 12 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 13 | inst(JMPGT, 1, 3, 1), /* if r1 <= r3: */ 14 | inst(RETX, 3), /* ret r3 */ 15 | inst(RETX, 1), /* else ret r1 */ 16 | }; 17 | prog p1(instructions); 18 | p1.print(); 19 | prog* p[6]; 20 | p[0] = &p1; 21 | for (int i = 1; i < 6; i++) { 22 | p[i] = 
mod_random_inst_operand(*p[i - 1]); 23 | p[i]->print(); 24 | } 25 | for (int i = 1; i < 6; i++) { 26 | delete p[i]; 27 | } 28 | return 0; 29 | } 30 | 31 | int test2(int input) { 32 | cout << "Test 2" << endl; 33 | #define N 7 34 | inst instructions[N] = {inst(MOVXC, 1, input), /* mov r1, input */ 35 | inst(MOVXC, 2, 4), /* mov r2, 4 */ 36 | inst(ADDXY, 1, 2), /* add r1, r2 */ 37 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 38 | inst(JMPGT, 1, 3, 1), /* if r1 <= r3: */ 39 | inst(RETX, 3), /* ret r3 */ 40 | inst(RETX, 1), /* else ret r1 */ 41 | }; 42 | prog p1(instructions); 43 | p1.print(); 44 | prog* p[6]; 45 | p[0] = &p1; 46 | for (int i = 1; i < 6; i++) { 47 | p[i] = mod_random_inst(*p[i - 1]); 48 | cout << "Transformed program after " << i << " proposals:" << endl; 49 | p[i]->print(); 50 | } 51 | bool assert_res = true; 52 | for (int i = 1; i < 6; i++) { 53 | for (int j = 0; j < N; j++) { 54 | inst ins = p[i]->inst_list[j]; 55 | for (int k = ins.get_num_operands(); k < MAX_OP_LEN; k++) { 56 | bool res = (ins.get_operand(k) == 0); 57 | if (! 
res) { 58 | assert_res = false; 59 | cout << "unused " << k << "th operand in "; 60 | ins.print(); 61 | cout << "is not 0, but " << ins.get_operand(k) << endl; 62 | } 63 | } 64 | } 65 | } 66 | print_test_res(assert_res, "set unused operands as 0"); 67 | 68 | for (int i = 1; i < 6; i++) { 69 | delete p[i]; 70 | } 71 | return 0; 72 | } 73 | 74 | int test3(int input) { 75 | cout << "Test 3" << endl; 76 | #define N 7 77 | inst instructions[N] = {inst(MOVXC, 1, input), /* mov r1, input */ 78 | inst(MOVXC, 2, 4), /* mov r2, 4 */ 79 | inst(ADDXY, 1, 2), /* add r1, r2 */ 80 | inst(MOVXC, 3, 15), /* mov r3, 15 */ 81 | inst(JMPGT, 1, 3, 1), /* if r1 <= r3: */ 82 | inst(RETX, 3), /* ret r3 */ 83 | inst(RETX, 1), /* else ret r1 */ 84 | }; 85 | prog p1(instructions); 86 | p1.print(); 87 | prog* p[6]; 88 | p[0] = &p1; 89 | for (int i = 1; i < 6; i++) { 90 | p[i] = mod_random_k_cont_insts(*p[i - 1], i); 91 | cout << "Transformed program after " << i << " proposals:" << endl; 92 | cout << "(" << i << " continuous instrcution(s) is(are) changed." << ")" << endl; 93 | p[i]->print(); 94 | } 95 | for (int i = 1; i < 6; i++) { 96 | delete p[i]; 97 | } 98 | return 0; 99 | } 100 | 101 | int main(int argc, char *argv[]) { 102 | int input = 10; 103 | if (argc > 1) { 104 | input = atoi(argv[1]); 105 | } 106 | test1(input); 107 | test2(input); 108 | test3(input); 109 | } 110 | -------------------------------------------------------------------------------- /src/search/win_select.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "win_select.h" 3 | 4 | using namespace std; 5 | int num_multi_insns = 0; 6 | int num_jmp = 0; 7 | int num_ret = 0; 8 | 9 | #define WIN_SZ_max 4 10 | 11 | /* If the insn does not satisfy window constraints, return false. 
12 | */ 13 | bool insn_satisfy_general_win_constraints(inst* insn, int num_insns) { 14 | assert(num_insns >= 1); 15 | 16 | /* opcode that contains multiple contiguous insns (proposal sampling not 17 | support multiple contiguous insns) 18 | */ 19 | if (num_insns > 1) { 20 | num_multi_insns += num_insns; 21 | return false; 22 | } 23 | 24 | /* no jmp and return insns */ 25 | int op_type = insn[0].get_opcode_type(); 26 | if ((op_type == OP_UNCOND_JMP) || 27 | (op_type == OP_COND_JMP)) { 28 | num_jmp++; 29 | return false; 30 | } else if (op_type == OP_RET) { 31 | num_ret++; 32 | return false; 33 | } 34 | 35 | return true; 36 | } 37 | 38 | /* Generate windows according to valid insns and general constraints 39 | such as in one basic block. 40 | */ 41 | void gen_wins_by_general_constraints(vector>& wins, 42 | prog_static_state& pss, 43 | vector& insns_valid) { 44 | wins.clear(); 45 | const vector& dag = pss.dag; 46 | const graph& g = pss.g; 47 | for (int i = 0; i < dag.size(); i++) { 48 | unsigned int block = dag[i]; 49 | unsigned int block_s = g.nodes[block]._start; 50 | unsigned int block_e = g.nodes[block]._end; 51 | int win_s = block_s; 52 | bool set_win_s = false; 53 | for (int j = block_s; j <= block_e; j++) { 54 | // update win_s as the first valid insn from j to block_e 55 | if (! set_win_s) { 56 | for (int k = j; k <= block_e; k++) { 57 | if (! insns_valid[k]) continue; 58 | win_s = k; 59 | j = k; 60 | set_win_s = true; 61 | break; 62 | } 63 | // check whether there is win_s in this block, 64 | // if not, break in order to go to the next block 65 | if (! set_win_s) break; 66 | } 67 | 68 | // 2 cases that can set the window 69 | // 1. when reach block end 70 | // 2. next insn is not valid 71 | if ((j == block_e) || 72 | ((j < block_e) && (! 
insns_valid[j + 1]))) { 73 | // push (win_s, j) into wins 74 | wins.push_back(pair {win_s, j}); 75 | set_win_s = false; 76 | } 77 | } 78 | } 79 | } 80 | 81 | void reset_win_constraints_statistics() { 82 | reset_isa_win_constraints_statistics(); 83 | num_multi_insns = 0; 84 | num_jmp = 0; 85 | num_ret = 0; 86 | } 87 | 88 | void print_win_constraints_statistics(const vector>& wins) { 89 | print_isa_win_constraints_statistics(); 90 | cout << "# multi_insns: " << num_multi_insns << endl; 91 | cout << "# jmp: " << num_jmp << endl; 92 | cout << "# ret: " << num_ret << endl; 93 | map mp; 94 | for (int i = 0; i < wins.size(); i++) { 95 | int win_len = wins[i].second - wins[i].first + 1; 96 | auto it = mp.find(win_len); 97 | if (it == mp.end()) { 98 | mp[win_len] = 1; 99 | } else { 100 | it->second++; 101 | } 102 | } 103 | cout << "windows length and frequency: "; 104 | int sum = 0; 105 | for (auto it : mp) { 106 | cout << it.first << ":" << it.second << " "; 107 | sum += it.first * it.second; 108 | } 109 | cout << endl; 110 | cout << "# winodws:" << wins.size() << " # insns: " << sum << endl; 111 | } 112 | 113 | void gen_wins(vector>& wins, inst* pgm, int len, prog_static_state& pss) { 114 | reset_win_constraints_statistics(); 115 | 116 | vector insns_valid(len); 117 | for (int i = 0; i < len; i++) insns_valid[i] = true; 118 | 119 | for (int i = 0; i < len; i++) { 120 | int num_insns = pgm[i].num_insns(); 121 | bool satisfied = insn_satisfy_general_win_constraints(&pgm[i], num_insns); 122 | if (! satisfied) { 123 | for (int j = i; j < i + num_insns; j++) { 124 | insns_valid[j] = false; 125 | } 126 | } 127 | i += num_insns - 1; 128 | } 129 | 130 | for (int i = 0; i < len; i++) { 131 | if (! insns_valid[i]) continue; 132 | bool satisfied = insn_satisfy_isa_win_constraints(pgm[i], pss.static_state[i]); 133 | if (! 
satisfied) insns_valid[i] = false; 134 | } 135 | 136 | gen_wins_by_general_constraints(wins, pss, insns_valid); 137 | 138 | // if (logger.is_print_level(LOGGER_DEBUG)) { 139 | // print_win_constraints_statistics(wins); 140 | // cout << "insns not in windows" << ": "; 141 | // int sum = 0; 142 | // for (int i = 0; i < len; i++) { 143 | // if (! insns_valid[i]) { 144 | // cout << i << " "; 145 | // sum++; 146 | // } 147 | // } 148 | // cout << endl; 149 | // cout << "sum: " << sum << endl; 150 | // } 151 | } 152 | 153 | // convert one window into windows according to some rules, eg remove small windows 154 | void optimize_one_win(vector>& wins_after, const pair& win_before) { 155 | wins_after.clear(); 156 | int win_s = win_before.first; 157 | int win_e = win_before.second; 158 | int win_sz = win_e - win_s + 1; 159 | // 1. remove windows with size 1 160 | if (win_sz == 1) { 161 | return; 162 | } 163 | // 2. split big windows into smaller ones 164 | if (win_sz > WIN_SZ_max) { 165 | for (int i = win_s; i <= win_e + 1 - WIN_SZ_max; i += WIN_SZ_max) { 166 | wins_after.push_back(pair {i, i + WIN_SZ_max - 1}); 167 | } 168 | int remainder = win_sz % WIN_SZ_max; 169 | if (remainder != 0) { 170 | wins_after.push_back(pair {win_e + 1 - WIN_SZ_max, win_e}); 171 | } 172 | } else { 173 | wins_after.push_back(pair {win_s, win_e}); 174 | } 175 | } 176 | 177 | // optimize windows to make windows easier to be optimized in a given time 178 | void optimize_wins(vector>& wins) { 179 | vector> wins_temp(wins.size()); 180 | for (int i = 0; i < wins.size(); i++) { 181 | wins_temp[i] = wins[i]; 182 | } 183 | 184 | wins.clear(); 185 | for (int i = 0; i < wins_temp.size(); i++) { 186 | vector> wins_after; 187 | optimize_one_win(wins_after, wins_temp[i]); 188 | for (int j = 0; j < wins_after.size(); j++) { 189 | wins.push_back(wins_after[j]); 190 | } 191 | } 192 | if (logger.is_print_level(LOGGER_DEBUG)) { 193 | print_win_constraints_statistics(wins); 194 | } 195 | } 196 | 
/* Check whether `wins` holds exactly the expected windows, in any order.
 *
 * The i-th expected window is [win_s_expected[i], win_e_expected[i]].
 * Returns false when the three containers differ in size. Every expected
 * window can be matched by at most one actual window, so duplicated windows
 * must be duplicated in the expectation as well.
 */
bool are_wins_equal(std::vector<std::pair<int, int>>& wins,
                    std::vector<int>& win_s_expected,
                    std::vector<int>& win_e_expected) {
  const size_t num_wins = wins.size();
  if (win_s_expected.size() != num_wins || win_e_expected.size() != num_wins) {
    return false;
  }
  // matched[j] is set once expected window j has been paired with a window
  std::vector<bool> matched(num_wins, false);
  for (size_t i = 0; i < num_wins; i++) {
    bool found = false;
    for (size_t j = 0; j < num_wins; j++) {
      if (matched[j]) continue;
      if ((wins[i].first == win_s_expected[j]) &&
          (wins[i].second == win_e_expected[j])) {
        matched[j] = true;
        found = true;
        break;
      }
    }
    if (! found) return false;
  }
  return true;
}
exit, return r0 */ 83 | }; 84 | win_s_expected = {0, 3, 6}; 85 | win_e_expected = {0, 3, 6}; 86 | check_gen_wins(p2, sizeof(p2) / sizeof(inst), win_s_expected, win_e_expected, "2"); 87 | win_s_expected = {}; 88 | win_e_expected = {}; 89 | check_opt_wins(p2, sizeof(p2) / sizeof(inst), win_s_expected, win_e_expected, "2"); 90 | 91 | inst p3[] = {inst(JA, 0), 92 | inst(MOV64XC, 1, 0), 93 | inst(NOP), 94 | inst(MOV64XC, 2, 0), 95 | inst(EXIT), 96 | }; 97 | win_s_expected = {1}; 98 | win_e_expected = {3}; 99 | check_gen_wins(p3, sizeof(p3) / sizeof(inst), win_s_expected, win_e_expected, "3"); 100 | check_opt_wins(p3, sizeof(p3) / sizeof(inst), win_s_expected, win_e_expected, "3"); 101 | 102 | inst p4[] = {inst(MOV64XC, 0, 0), 103 | inst(JEQXC, 0, 0, 3), 104 | inst(MOV64XC, 0, 0), 105 | inst(MOV64XC, 0, 0), 106 | inst(MOV64XC, 0, 0), 107 | inst(EXIT), 108 | }; 109 | win_s_expected = {0, 2}; 110 | win_e_expected = {0, 4}; 111 | check_gen_wins(p4, sizeof(p4) / sizeof(inst), win_s_expected, win_e_expected, "4"); 112 | win_s_expected = {2}; 113 | win_e_expected = {4}; 114 | check_opt_wins(p4, sizeof(p4) / sizeof(inst), win_s_expected, win_e_expected, "4"); 115 | 116 | inst p5[] = {inst(MOV64XC, 0, 0), // 0: 117 | inst(), 118 | inst(), 119 | inst(), // 3: 120 | inst(), 121 | inst(JEQXC, 0, 0, 4), 122 | inst(), // 6: 123 | inst(), 124 | inst(), // 8: 125 | inst(JA, 10), 126 | inst(), // 10: 127 | inst(), 128 | inst(), 129 | inst(), 130 | inst(), 131 | inst(), 132 | inst(), 133 | inst(), 134 | inst(), 135 | inst(), // 19 136 | inst(EXIT), 137 | }; 138 | win_s_expected = {0, 6, 10}; 139 | win_e_expected = {4, 8, 19}; 140 | check_gen_wins(p5, sizeof(p5) / sizeof(inst), win_s_expected, win_e_expected, "5"); 141 | win_s_expected = {0, 1, 6, 10, 14, 16}; 142 | win_e_expected = {3, 4, 8, 13, 17, 19}; 143 | check_opt_wins(p5, sizeof(p5) / sizeof(inst), win_s_expected, win_e_expected, "5"); 144 | 145 | cout << "Test 1.2: test the opcode with has multiple insns" << endl; 146 | // test 
/* Split `s` at every occurrence of the delimiter `c` and append the pieces
 * to `v` (existing content of `v` is kept).
 *
 * Empty pieces at the start or between two delimiters are kept; an empty
 * piece after the last delimiter is dropped. An empty `s` appends nothing.
 * An empty delimiter cannot split anything: a non-empty `s` is appended as a
 * single piece.
 */
void split_string(const std::string& s, std::vector<std::string>& v, const std::string& c) {
  if (c.empty()) {
    // find("") matches at every position without advancing, which would
    // never terminate below.
    if (!s.empty()) v.push_back(s);
    return;
  }
  std::string::size_type piece_start = 0;
  for (std::string::size_type hit = s.find(c);
       hit != std::string::npos;
       hit = s.find(c, piece_start)) {
    v.push_back(s.substr(piece_start, hit - piece_start));
    piece_start = hit + c.size();
  }
  if (piece_start != s.length())
    v.push_back(s.substr(piece_start));
}


/* Return the number of bits set in `x`.
 * Uses the compiler builtin where available (no dependency on a particular
 * instruction set), and a portable loop otherwise. */
unsigned int pop_count_asm(unsigned int x) {
#if defined(__GNUC__) || defined(__clang__)
  return static_cast<unsigned int>(__builtin_popcount(x));
#else
  unsigned int count = 0;
  while (x != 0) {
    x &= x - 1;  // clear the lowest set bit
    ++count;
  }
  return count;
#endif
}

/* Return true when the lowest-addressed byte of an int holds its least
 * significant byte, i.e. the host is little endian. */
bool is_little_endian() {
  const int probe = 1;
  const char* first_byte = reinterpret_cast<const char*>(&probe);
  return *first_byte == 1;
}

// convert uint8_t vector to hex string: two lowercase hex digits per byte,
// in vector order, no separator
// e.g. addr[2] = {0x1, 0xff}, hex string: "01ff"
std::string uint8_t_vec_2_hex_str(const std::vector<uint8_t>& a) {
  std::stringstream ss;
  ss << std::hex << std::setfill('0');

  for (uint8_t byte : a) {
    // widen to int so the byte is printed as a number, not as a character;
    // setw applies to the next insertion only, hence repeated per byte
    ss << std::setw(2) << static_cast<int>(byte);
  }

  return ss.str();
}
chrono::steady_clock::now() 33 | #define DUR(t1, t2) chrono::duration (t2 - t1).count() 34 | 35 | #define H32(v) (0xffffffff00000000 & (v)) 36 | #define H48(v) (0xffffffffffff0000 & (v)) 37 | #define L5(v) (0x000000000000001f & (v)) 38 | #define L6(v) (0x000000000000003f & (v)) 39 | #define L16(v) (0x000000000000ffff & (v)) 40 | #define L32(v) (0x00000000ffffffff & (v)) 41 | 42 | #define RAISE_EXCEPTION(x) {\ 43 | string err_msg = string(x) + string(" has not been implemented"); \ 44 | cerr << err_msg << endl;\ 45 | throw (err_msg); \ 46 | } 47 | 48 | void print_test_res(bool res, string test_name); 49 | ostream& operator<<(ostream& out, const vector& vec); 50 | void split_string(const string& s, vector& v, const string& c); 51 | unsigned int pop_count_asm(unsigned int x); 52 | bool is_little_endian(); 53 | // convert uint8_t vector to hex string 54 | // e.g. addr[2] = {0x1, 0xff}, hex string: "01ff" 55 | string uint8_t_vec_2_hex_str(const vector& a); 56 | 57 | enum LOGGER_LEVEL { 58 | LOGGER_ERROR = 0, 59 | LOGGER_DEBUG, 60 | }; 61 | 62 | class logger_class { 63 | private: 64 | int least_print_level = LOGGER_ERROR; 65 | public: 66 | void set_least_print_level(int level) {least_print_level = level;} 67 | bool is_print_level(int level) {return (level <= least_print_level);} 68 | }; 69 | 70 | extern logger_class logger; 71 | -------------------------------------------------------------------------------- /src/verify/cfg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../src/utils.h" 8 | #include "../../src/isa/inst_header_basic.h" 9 | 10 | using namespace std; 11 | 12 | typedef unordered_map unsigned_map; 13 | 14 | class node { 15 | private: 16 | public: 17 | unsigned int _start = 0; // start instruction ID 18 | unsigned int _end = 0; // end instruction ID 19 | node(unsigned int start, unsigned int end); 20 | ~node(); 21 | string to_str(); 22 | 
friend ostream& operator<<(ostream& out, const node& n); 23 | }; 24 | 25 | class graph { 26 | private: 27 | size_t get_end_inst_id(inst* inst_lst, size_t start, size_t end); 28 | void insert_node_start(int cur_index, int d, int length, set& node_starts); 29 | void gen_node_starts(inst* inst_lst, int length, set& node_starts); 30 | void gen_node_ends(inst* inst_lst, int length, set& node_starts, vector& node_ends); 31 | void gen_all_nodes_graph(vector& gnodes, set& node_starts, vector& node_ends); 32 | void gen_all_edges_graph(vector >& gnodes_out, vector& gnodes, inst* inst_lst); 33 | void gen_id_map(unsigned_map& id_map, vector& gnodes); 34 | void add_node(node& nd, unsigned int& added); 35 | void dfs(size_t cur_gnode_id, vector& gnodes, vector >& gnodes_out, \ 36 | vector& added, vector& visited, vector& finished); 37 | void init(); 38 | public: 39 | vector nodes; 40 | vector > nodes_in; 41 | vector > nodes_out; 42 | graph(); 43 | graph(inst* inst_lst, int length); 44 | ~graph(); 45 | void gen_graph(inst* inst_lst, int length); 46 | string graph_to_str() const; 47 | void clear() {nodes.clear(); nodes_in.clear(); nodes_out.clear();}; 48 | friend ostream& operator<<(ostream& out, const graph& g); 49 | }; 50 | 51 | void topo_sort_for_graph(vector& nodes, const graph& g); 52 | -------------------------------------------------------------------------------- /src/verify/cfg_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "../../src/isa/toy-isa/inst.h" 4 | #include "cfg.h" 5 | 6 | using namespace std; 7 | 8 | void test1() { 9 | cout << "Test1: cfg check" << endl; 10 | inst instructions1[6] = {inst(JMPGT, 0, 2, 1), // if r0 <= r2: 11 | inst(ADDXY, 0, 1), // add r0, r1 12 | inst(MOVXC, 2, 15), // mov r2, 15 13 | inst(JMPGT, 0, 2, 1), // if r0 <= r2: 14 | inst(RETX, 2), // ret r2 15 | inst(RETX, 0), // else ret r0 16 | }; 17 | string expected; 18 | graph g1(instructions1, 
6); 19 | expected = "nodes:0,0 1,1 2,3 4,4 5,5 edges: 0:;1,2, 1:0,;2, 2:1,0,;3,4, 3:2,; 4:2,;"; 20 | print_test_res(g1.graph_to_str() == expected, "program 1"); 21 | 22 | inst instructions2[6] = {inst(MOVXC, 1, 4), // mov r0, 4 23 | inst(JMPGT, 0, 2, 3), // if r0 <= r2: 24 | inst(MOVXC, 2, 15), // mov r2, 15 25 | inst(JMPGT, 0, 2, 1), // if r0 <= r2: 26 | inst(RETX, 2), // ret r2 27 | inst(RETX, 0), // else ret r0 28 | }; 29 | graph g2(instructions2, 6); 30 | expected = "nodes:0,1 2,3 4,4 5,5 edges: 0:;1,3, 1:0,;2,3, 2:1,; 3:1,0,;"; 31 | print_test_res(g2.graph_to_str() == expected, "program 2"); 32 | 33 | inst instructions3[7] = {inst(MOVXC, 1, 4), // mov r1, 4 34 | inst(JMPGT, 0, 2, 3), // if r0 <= r3: 35 | inst(RETX, 0), // else ret r0 36 | inst(MOVXC, 2, 15), // mov r2, 15 37 | inst(JMPGT, 0, 2, 1), // if r0 <= r2: 38 | inst(RETX, 2), // ret r2 39 | inst(RETX, 0), // else ret r0 40 | }; 41 | graph g3(instructions3, 7); 42 | expected = "nodes:0,1 2,2 5,5 edges: 0:;1,2, 1:0,; 2:0,;"; 43 | print_test_res(g3.graph_to_str() == expected, "program 3"); 44 | 45 | inst instructions4[4] = {inst(JMPGT, 0, 2, 1), // 0 JMP to inst 2 46 | inst(RETX, 2), // 1 END 47 | inst(JMPGT, 0, 2, -2), // 2 JMP to inst 1 48 | inst(RETX, 2), // 3 END 49 | }; 50 | graph g4(instructions4, 4); 51 | expected = "nodes:0,0 1,1 2,2 3,3 edges: 0:;1,2, 1:0,2,; 2:0,;3,1, 3:2,;"; 52 | print_test_res(g4.graph_to_str() == expected, "program 4"); 53 | } 54 | 55 | void test2() { 56 | cout << "Test2: illegal input cfg check" << endl; 57 | string expected; 58 | string actual; 59 | // test illegal input with loop 60 | expected = "illegal input: loop from node 1[2:3] to node 0[0:1]"; 61 | actual = ""; 62 | inst instructions1[6] = {inst(MOVXC, 1, 4), // 0 mov r0, 4 63 | inst(JMPGT, 0, 2, 3), // 1 if r0 <= r2: 64 | inst(MOVXC, 2, 15), // 2 mov r2, 15 65 | inst(JMPGT, 0, 2, -4), // 3 loop from here to instruction 0 66 | inst(RETX, 2), // 4 ret r2 67 | inst(RETX, 0), // 5 else ret r0 68 | }; 69 | try { 70 | 
graph g1(instructions1, 6); 71 | } catch (const string err_msg) { 72 | actual = err_msg; 73 | } 74 | print_test_res(actual == expected, "program 1"); 75 | 76 | // test illegal input with loop 77 | expected = "illegal input: loop from node 2[2:2] to node 3[3:3]"; 78 | actual = ""; 79 | inst instructions2[5] = {inst(JMPGT, 0, 2, 2), // 0 JMP to inst 3 80 | inst(RETX, 2), // 1 END 81 | inst(NOP), // 2 82 | inst(JMPGT, 0, 2, -2), // 3 JMP to inst 2, cause the loop from inst 2 to inst 3 83 | inst(RETX, 0), // 4 END 84 | }; 85 | try { 86 | graph g2(instructions2, 5); 87 | } catch (const string err_msg) { 88 | actual = err_msg; 89 | } 90 | print_test_res(actual == expected, "program 2"); 91 | 92 | // test illegal input: goes to an invalid instruction 93 | expected = "illegal input: instruction 3 goes to an invalid instruction 4"; 94 | actual = ""; 95 | inst instructions3[4] = {inst(JMPGT, 0, 2, 2), // 0 JMP to inst 3 96 | inst(RETX, 2), // 1 END 97 | inst(RETX, 0), // 2 END 98 | inst(JMPGT, 0, 2, -2), // 3 JMP to inst 2. 
illegal: no jump will go to 4 99 | }; 100 | try { 101 | graph g3(instructions3, 4); 102 | } catch (const string err_msg) { 103 | actual = err_msg; 104 | } 105 | print_test_res(actual == expected, "program 3"); 106 | 107 | // test illegal input: goes to an invalid instruction 108 | expected = "illegal input: instruction 0 goes to an invalid instruction 2"; 109 | actual = ""; 110 | inst instructions4[2] = {inst(JMPGT, 0, 2, 1), // 0 JMP to inst 2 -> illegal 111 | inst(RETX, 2), // 1 END 112 | }; 113 | try { 114 | graph g4(instructions4, 2); 115 | } catch (const string err_msg) { 116 | actual = err_msg; 117 | } 118 | 119 | print_test_res(actual == expected, "program 4"); 120 | 121 | // test illegal input: goes to an invalid instruction 122 | expected = "illegal input: instruction 0 goes to an invalid instruction -1"; 123 | actual = ""; 124 | inst instructions5[2] = {inst(JMPGT, 0, 2, -2), // 0 JMP to inst -1 -> illegal 125 | inst(RETX, 2), // 1 END 126 | }; 127 | try { 128 | graph g5(instructions5, 2); 129 | } catch (const string err_msg) { 130 | actual = err_msg; 131 | } 132 | print_test_res(actual == expected, "program 5"); 133 | 134 | // loop caused by unconditional jmp 135 | expected = "illegal input: loop from node 1[2:2] to node 0[0:0]"; 136 | actual = ""; 137 | inst instructions6[3] = {inst(JMP, 1), 138 | inst(RETX, 0), 139 | inst(JMP, -3), 140 | }; 141 | try { 142 | graph g6(instructions6, 3); 143 | } catch (const string err_msg) { 144 | actual = err_msg; 145 | } 146 | print_test_res(actual == expected, "program 6"); 147 | 148 | // test (jmp -1) loop 149 | expected = "illegal input: loop from node 0[0:0] to node 0[0:0]"; 150 | actual = ""; 151 | inst instructions7[1] = {inst(JMP, -1)}; 152 | try { 153 | graph g7(instructions7, 1); 154 | } catch (const string err_msg) { 155 | actual = err_msg; 156 | } 157 | print_test_res(actual == expected, "program 7"); 158 | } 159 | 160 | void test3() { 161 | cout << "Test 3: special cases check" << endl; 162 | string 
expected; 163 | // instruction JMP logic when jmp distance is 0 164 | inst instructions1[7] = {inst(MOVXC, 2, 4), 165 | inst(ADDXY, 0, 2), 166 | inst(MOVXC, 3, 15), 167 | inst(JMPLE, 2, 0, 0), // expect: this JMPLE will not cause other blocks 168 | inst(), 169 | inst(MAXX, 0, 3), 170 | inst(), 171 | }; 172 | graph g1(instructions1, 7); 173 | expected = "nodes:0,3 4,6 edges: 0:;1,1, 1:0,0,;"; 174 | print_test_res(g1.graph_to_str() == expected, "case 1"); 175 | 176 | // block ending up with NOP will connect to block starting from the instruction following NOP 177 | inst instructions2[4] = {inst(JMPEQ, 1, 3, 1), 178 | inst(), 179 | inst(MOVXC, 2, 4), 180 | inst(), 181 | }; 182 | graph g2(instructions2, 4); 183 | expected = "nodes:0,0 1,1 2,3 edges: 0:;1,2, 1:0,;2, 2:1,0,;"; 184 | print_test_res(g2.graph_to_str() == expected, "case 2"); 185 | } 186 | 187 | /* test unconditional jmp */ 188 | void test4() { 189 | cout << "Test 4: unconditional jmp check" << endl; 190 | string expected; 191 | inst instructions1[3] = {inst(JMP, 1), 192 | inst(ADDXY, 0, 0), 193 | inst(RETX, 0), 194 | }; 195 | graph g1(instructions1, 3); 196 | expected = "nodes:0,0 2,2 edges: 0:;1, 1:0,;"; 197 | print_test_res(g1.graph_to_str() == expected, "program 1"); 198 | 199 | 200 | inst instructions2[3] = {inst(JMP, 1), 201 | inst(RETX, 0), 202 | inst(JMP, -2), 203 | }; 204 | graph g2(instructions2, 3); 205 | expected = "nodes:0,0 2,2 1,1 edges: 0:;1, 1:0,;2, 2:1,;"; 206 | print_test_res(g2.graph_to_str() == expected, "program 2"); 207 | } 208 | 209 | int main () { 210 | test1(); 211 | test2(); 212 | test3(); 213 | test4(); 214 | 215 | return 0; 216 | } 217 | -------------------------------------------------------------------------------- /src/verify/cfg_test_ebpf.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "../../src/isa/ebpf/inst.h" 4 | #include "cfg.h" 5 | 6 | using namespace std; 7 | 8 | inst 
instructions1[8] = {inst(MOV32XC, 0, -1), /* r0 = 0xffffffff */ 9 | inst(JGTXC, 0, 0, 1), /* if r0 <= 0, ret r0 = 0xffffffff */ 10 | inst(EXIT), 11 | inst(MOV64XC, 1, -1), /* else r1 = 0xffffffffffffffff */ 12 | inst(JGTXY, 1, 0, 1), /* if r1 <= r0, ret r0 = 0xffffffff */ 13 | inst(EXIT), 14 | inst(MOV64XC, 0, 0), /* else r0 = 0 */ 15 | inst(EXIT), /* exit, return r0 */ 16 | }; 17 | 18 | inst instructions2[9] = {inst(MOV32XC, 0, -1), /* r0 = 0x00000000ffffffff */ 19 | inst(ADD64XC, 0, 0x1), /* r0 = 0x0000000100000000 */ 20 | inst(MOV64XC, 1, 0x0), /* r1 = 0 */ 21 | inst(JEQXC, 0, 0, 4), /* if r0 == 0, ret r0 = 0x100000000 */ 22 | inst(MOV64XC, 0, -1), /* else r0 = 0xffffffffffffffff */ 23 | inst(JEQXC, 0, 0xffffffff, 1),/* if r0 == -1, ret r0 = 0 */ 24 | inst(JA, 1), /* else ret r0 = 0xffffffffffffffff */ 25 | inst(MOV64XC, 0, 0), 26 | inst(EXIT), 27 | }; 28 | 29 | // test jmp and st/ld 30 | inst instructions3[6] = {inst(STXB, 10, -1, 1), 31 | inst(JEQXC, 1, 0x12, 2), 32 | inst(MOV64XC, 1, 0x12), 33 | inst(STXW, 10, -1, 1), 34 | inst(LDXB, 0, 10, -1), 35 | inst(EXIT), 36 | }; 37 | 38 | // test tail call 39 | inst instructions4[6] = {inst(MOV64XY, 1, 1), 40 | INSN_LDMAPID(2, 0), 41 | inst(MOV64XC, 3, 1), 42 | inst(CALL, BPF_FUNC_tail_call), 43 | inst(MOV64XC, 0, 0xff), 44 | inst(EXIT), 45 | }; 46 | 47 | void test1() { 48 | string expected; 49 | graph g1(instructions1, 8); 50 | expected = "nodes:0,1 2,2 3,4 5,5 6,7 edges: 0:;1,2, 1:0,; 2:0,;3,4, 3:2,; 4:2,;"; 51 | print_test_res(g1.graph_to_str() == expected, "program 1"); 52 | 53 | graph g2(instructions2, 9); 54 | expected = "nodes:0,3 4,5 6,6 8,8 7,7 edges: 0:;1,3, 1:0,;2,4, 2:1,;3, 3:2,4,0,; 4:1,;3,"; 55 | print_test_res(g2.graph_to_str() == expected, "program 2"); 56 | 57 | graph g3(instructions3, 6); 58 | expected = "nodes:0,1 2,3 4,5 edges: 0:;1,2, 1:0,;2, 2:1,0,;"; 59 | print_test_res(g3.graph_to_str() == expected, "program 3"); 60 | 61 | graph g4(instructions4, 6); 62 | expected = "nodes:0,3 edges: 0:;"; 63 | 
print_test_res(g4.graph_to_str() == expected, "program 4"); 64 | } 65 | 66 | int main() { 67 | test1(); 68 | 69 | return 0; 70 | } 71 | -------------------------------------------------------------------------------- /src/verify/smt_prog.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "z3++.h" 5 | #include "../../src/utils.h" 6 | #include "cfg.h" 7 | #include "../../src/isa/inst_header.h" 8 | 9 | using namespace z3; 10 | 11 | /* smt_prog algorithm document: https://github.com/ngsrinivas/superopt/tree/master/doc */ 12 | 13 | ostream& operator<< (ostream& out, vector& _expr_vec); 14 | ostream& operator<< (ostream& out, vector >& _expr_vec); 15 | bool is_smt_valid(expr smt); 16 | 17 | class smt_prog { 18 | private: 19 | // post_reg_val[i] is post register values of basic block i, 20 | // which are initial values for NEXT basic blocks 21 | vector > post_reg_val; 22 | void init_pgm_dag(unsigned int root_node); 23 | // return the SMT for the given program without branch and loop 24 | void smt_block(expr& smt_b, expr& smt_sc, inst* program, int start, int end, smt_var& sv, size_t cur_bid); 25 | void init(unsigned int num_regs); 26 | // void topo_sort_dfs(size_t cur_bid, vector& blocks, vector& finished); 27 | void gen_block_prog_logic(expr& e, expr& f_mem, expr& f_sc, smt_var& sv, size_t cur_bid, inst* inst_lst); 28 | void store_post_reg_val(smt_var& sv, size_t cur_bid, unsigned int num_regs); 29 | void add_path_cond(expr p_con, size_t cur_bid, size_t next_bId); 30 | void gen_post_path_con(smt_var& sv, size_t cur_bid, inst& inst_end); 31 | void get_init_val(expr& f_iv, smt_var& sv, size_t in_bid, unsigned int num_regs); 32 | expr smt_end_block_inst(size_t cur_bid, inst& inst_end, unsigned int prog_id); 33 | void gen_block_c_in(expr& c_in, size_t cur_bid); 34 | public: 35 | // `public` for unit test check 36 | smt_var sv; 37 | // program logic 38 | expr pl = string_to_expr("true"); 39 | // 
program's safety check expression 40 | expr p_sc = Z3_true; 41 | // store path_con, reg_iv, bl, post, g 42 | // 1. path_con[i] stores pre path condition formulas of basic block i 43 | // There is a corresponding relationship between path_con and g.nodesIn 44 | // more specifically, path_con[i][j] stores the all pre path condition formulae from basic block g.nodesIn[i][j] to i 45 | vector > path_con; 46 | // 2. reg_iv[i][j] stores pre register initial value formula 47 | // that values from the last node(g.nodes[i][j]) are fed to the node(i) 48 | vector > reg_iv; 49 | // 3. bl[i] stores block logic formula of basic block i 50 | // more specifically, bl[i] = instLogic_i_0 && instLogic_i_1 && ... && instLogic_i_n 51 | vector bl; 52 | // 4. post[i] store post logic formula (output formula) for the end basic block i 53 | vector post; 54 | // control flow graph 55 | graph g; 56 | smt_prog(); 57 | ~smt_prog(); 58 | // Return the program logic FOL formula 'PL' including basic program logic 59 | // and the formula of capturing the output of the program in the variable output[prog_id] 60 | expr gen_smt(unsigned int prog_id, inst* inst_lst, int length, bool is_win = false, int win_start = 0, int win_end = inst::max_prog_len); 61 | }; 62 | -------------------------------------------------------------------------------- /src/verify/smt_prog_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "../../src/isa/inst_header.h" 4 | #include "smt_prog.h" 5 | 6 | using namespace z3; 7 | 8 | #define v(x) string_to_expr(x) 9 | 10 | // basic block test 11 | void test1() { 12 | std::cout << "test 1: basic block check starts...\n"; 13 | inst p[5] = {inst(MOVXC, 1, 10), // 0 14 | inst(JMPLT, 0, 1, 1), // 1 15 | inst(RETX, 1), // 2 16 | inst(MAXC, 0, 15), // 3 17 | inst(RETX, 0), // 4 18 | }; 19 | smt_prog ps; 20 | unsigned int prog_id = 0; 21 | expr pl = ps.gen_smt(prog_id, p, 5); 22 | // test 
block 2[3:4] 23 | std::cout << "test 1.1: check basic block 2[3:4]\n"; 24 | // fmt: r_[prog_id]_[block_id]_[reg_id]_[version_id] 25 | expr prePC2 = (v("r_0_0_0_0") < v("r_0_0_1_1")); 26 | expr preIV2 = (v("r_0_0_0_0") == v("r_0_2_0_0") && \ 27 | v("r_0_0_1_1") == v("r_0_2_1_0") && \ 28 | v("r_0_0_2_0") == v("r_0_2_2_0") && \ 29 | v("r_0_0_3_0") == v("r_0_2_3_0") 30 | ); 31 | expr bl2 = (implies(v("r_0_2_0_0") > 15, v("r_0_2_0_1") == v("r_0_2_0_0")) && \ 32 | implies(v("r_0_2_0_0") <= 15, v("r_0_2_0_1") == 15) && 33 | (v("ret_val_" + to_string(prog_id)) == v("r_0_2_0_1")) 34 | ); 35 | print_test_res(is_smt_valid(prePC2 == ps.path_con[2][0]), "pre path condition"); 36 | print_test_res(is_smt_valid(preIV2 == ps.reg_iv[2][0]), "pre register initial values"); 37 | print_test_res(is_smt_valid(bl2 == ps.bl[2]), "basic block logic"); 38 | 39 | std::cout << "\ntest1.2: check basic block 2[2:3]\n"; 40 | inst p1[7] = {inst(JMPLT, 0, 1, 3), // 0 [0:0] 41 | inst(MOVXC, 0, 1), // 1 [1:1] 42 | inst(ADDXY, 0, 0), // 2 [2:3] 43 | inst(RETX, 0), // 3 44 | inst(ADDXY, 0, 0), // 4 [4:5] 45 | inst(JMPLT, 0, 1, -4), // 5 46 | inst(RETX, 0), // 6 [6:6] 47 | }; 48 | prog_id = 1; 49 | ps.gen_smt(prog_id, p1, 7); 50 | // blocks: 0[0:0] 1[1:1] 2[2:3] 3[4:5] 4[6:6] 51 | // case0: 0 -> 1 -> 2; case1: 0 -> 3 -> 2 52 | // fmt: r_[prog_id]_[block_id]_[reg_id]_[version_id] 53 | expr pre_pc2_0 = !(v("r_1_0_0_0") < v("r_1_0_1_0")); 54 | expr pre_pc2_1 = ((v("r_1_0_0_0") < v("r_1_0_1_0")) && \ 55 | (v("r_1_3_0_1") < v("r_1_3_1_0")) 56 | ); 57 | expr pre_iv2_0 = (v("r_1_1_0_1") == v("r_1_2_0_0") && \ 58 | v("r_1_1_1_0") == v("r_1_2_1_0") && \ 59 | v("r_1_1_2_0") == v("r_1_2_2_0") && \ 60 | v("r_1_1_3_0") == v("r_1_2_3_0") 61 | ); 62 | expr pre_iv2_1 = (v("r_1_3_0_1") == v("r_1_2_0_0") && \ 63 | v("r_1_3_1_0") == v("r_1_2_1_0") && \ 64 | v("r_1_3_2_0") == v("r_1_2_2_0") && \ 65 | v("r_1_3_3_0") == v("r_1_2_3_0") 66 | ); 67 | bl2 = (v("r_1_2_0_1") == v("r_1_2_0_0") + v("r_1_2_0_0")) && 68 | 
(v("ret_val_" + to_string(prog_id)) == v("r_1_2_0_1")); 69 | print_test_res(is_smt_valid(pre_pc2_0 == ps.path_con[2][0]), "pre path condition 0"); 70 | print_test_res(is_smt_valid(pre_pc2_1 == ps.path_con[2][1]), "pre path condition 1"); 71 | print_test_res(is_smt_valid(pre_iv2_0 == ps.reg_iv[2][0]), "pre register initial values 0"); 72 | print_test_res(is_smt_valid(pre_iv2_1 == ps.reg_iv[2][1]), "pre register initial values 1"); 73 | print_test_res(is_smt_valid(bl2 == ps.bl[2]), "basic block logic"); 74 | 75 | std::cout << "\ntest1.3: check program-end basic block 0[0:0] without RET instructions\n"; 76 | inst p2[1] = {inst(ADDXY, 0, 0), 77 | }; 78 | prog_id = 2; 79 | ps.gen_smt(prog_id, p2, 1); 80 | // fmt: r_[prog_id]_[block_id]_[reg_id]_[version_id] 81 | expr bl = (v("r_2_0_0_1") == v("r_2_0_0_0") + v("r_2_0_0_0")) && 82 | (v("ret_val_" + to_string(prog_id)) == v("r_2_0_0_1")); 83 | print_test_res(is_smt_valid(bl == ps.bl[0]), "basic block logic"); 84 | } 85 | 86 | void test2() { 87 | std::cout << "\ntest2.1: check single instruction logic\n"; 88 | // check instrcution MAXX logic 89 | // case1: inst(MAXX, 0, 0); case2: inst(MAXX, 0, 1) 90 | inst p[1] = {inst(MAXX, 0, 0)}; 91 | smt_prog ps; 92 | unsigned int prog_id = 0; 93 | ps.gen_smt(prog_id, p, 1); 94 | expr bl_expected = (v("r_0_0_0_1") == v("r_0_0_0_0")) && 95 | (v("ret_val_0") == v("r_0_0_0_1")); 96 | bool assert_res = is_smt_valid(bl_expected == ps.bl[0]); 97 | 98 | inst p1[1] = {inst(MAXX, 0, 1)}; 99 | ps.gen_smt(prog_id, p1, 1); 100 | bl_expected = (v("r_0_0_0_0") >= v("r_0_0_1_0") && (v("r_0_0_0_1") == v("r_0_0_0_0"))) || 101 | (v("r_0_0_0_0") < v("r_0_0_1_0") && (v("r_0_0_0_1") == v("r_0_0_1_0"))); 102 | bl_expected = bl_expected && (v("ret_val_0") == v("r_0_0_0_1")); 103 | assert_res = assert_res && is_smt_valid(bl_expected == ps.bl[0]); 104 | print_test_res(assert_res, "instruction MAXX logic"); 105 | 106 | // check instruction JMP logic when jmp distance is 0 107 | inst p2[2] = {inst(JMPEQ, 2, 0, 
0), 108 | inst(ADDXY, 0, 1), 109 | }; 110 | expr pl = ps.gen_smt(prog_id, p2, 2); 111 | expr pl_expected = (v("r_0_1_0_0") == v("r_0_0_0_0")) && 112 | (v("r_0_1_1_0") == v("r_0_0_1_0")) && 113 | (v("r_0_1_2_0") == v("r_0_0_2_0")) && 114 | (v("r_0_1_3_0") == v("r_0_0_3_0")) && 115 | (v("r_0_1_0_1") == v("r_0_1_0_0") + v("r_0_1_1_0")) && 116 | (v("ret_val_0") == v("r_0_1_0_1")); 117 | print_test_res(is_smt_valid(pl_expected == pl), "instruction JMP logic when jmp distance is 0"); 118 | } 119 | 120 | void test3() { 121 | std::cout << "\ntest3: check unconditional jmp program\n"; 122 | inst p1[4] = {inst(JMP, 1), 123 | inst(ADDXY, 0, 0), 124 | inst(ADDXY, 0, 0), 125 | inst(RETX, 0), 126 | }; 127 | int prog_id = 1; 128 | smt_prog ps; 129 | ps.gen_smt(prog_id, p1, 4); 130 | expr pre_iv1_1 = (v("r_1_1_0_0") == v("r_1_0_0_0") && \ 131 | v("r_1_1_1_0") == v("r_1_0_1_0") && \ 132 | v("r_1_1_2_0") == v("r_1_0_2_0") && \ 133 | v("r_1_1_3_0") == v("r_1_0_3_0") 134 | ); 135 | expr bl1_1 = (v("r_1_1_0_1") == v("r_1_1_0_0") + v("r_1_1_0_0")); 136 | expr post1 = v("ret_val_" + to_string(prog_id)) == v("r_1_1_0_1"); 137 | expr pl1 = pre_iv1_1 && bl1_1 && post1; 138 | print_test_res(is_smt_valid(pl1 == ps.pl), "unconditional jmp"); 139 | } 140 | 141 | int main() { 142 | test1(); 143 | test2(); 144 | test3(); 145 | return 0; 146 | } 147 | -------------------------------------------------------------------------------- /src/verify/smt_prog_test_ebpf.cc: -------------------------------------------------------------------------------- 1 | #include "../../src/utils.h" 2 | #include "../../src/isa/inst_header.h" 3 | #include "smt_prog.h" 4 | 5 | using namespace z3; 6 | 7 | #define v(x) string_to_expr(x) 8 | 9 | void test1() { 10 | mem_t::_layout.clear(); 11 | // branch test for st/ld 12 | inst p1[6] = {inst(STXB, 10, -1, 1), 13 | inst(JEQXC, 1, 0x12, 2), 14 | inst(MOV64XC, 1, 0x12), 15 | inst(STXB, 10, -1, 1), 16 | inst(LDXB, 0, 10, -1), 17 | inst(EXIT), 18 | }; 19 | smt_prog ps; 20 | 
unsigned int prog_id = 0; 21 | expr pl = ps.gen_smt(prog_id, p1, 6); 22 | // graph info 23 | // nodes: 24 | // 0[0:1] 1[2:3] 2[4:5] 25 | // edges: 26 | // 0 in:4294967295 out:1 2 // 4294967295 means -1 27 | // 1 in:0 out:2 28 | // 2 in:1 0 out: 29 | smt_wt s0, s1, s21, s22; 30 | s0.add(v("r_0_0_10_0") + to_expr(-1), v("r_0_0_1_0").extract(7, 0)); 31 | bool res = (s0 == ps.post_sv[0][0].mem_var._mem_table._wt); 32 | s1 = s0; 33 | s1.add(v("r_0_1_10_0") + to_expr(-1), v("r_0_1_1_1").extract(7, 0)); 34 | res = res && (s1 == ps.post_sv[1][0].mem_var._mem_table._wt); 35 | s21 = s1; 36 | s22 = s0; 37 | res = res && (s21 == ps.post_sv[2][0].mem_var._mem_table._wt) && (s22 == ps.post_sv[2][1].mem_var._mem_table._wt); 38 | print_test_res(res, "post stack write table 1"); 39 | 40 | // test jmp 0 41 | inst p2[5] = {inst(STXB, 10, -1, 1), 42 | inst(JEQXY, 0, 1, 0), 43 | inst(STXB, 10, -1, 1), 44 | inst(LDXB, 0, 10, -1), 45 | inst(EXIT), 46 | }; 47 | pl = ps.gen_smt(prog_id, p2, 5); 48 | // nodes: 49 | // 0[0:1] 1[2:4] 50 | // edges: 51 | // 0 in:4294967295 out:1 1 // 4294967295 means -1 52 | // 1 in:0 0 out: 53 | smt_wt s; 54 | s.add(v("r_0_0_10_0") + to_expr(-1), v("r_0_0_1_0").extract(7, 0)); 55 | res = (s == ps.post_sv[0][0].mem_var._mem_table._wt); 56 | s.add(v("r_0_1_10_0") + to_expr(-1), v("r_0_1_1_0").extract(7, 0)); 57 | res = res && 58 | (s == ps.post_sv[1][0].mem_var._mem_table._wt) && 59 | (s == ps.post_sv[1][1].mem_var._mem_table._wt); 60 | print_test_res(res, "post stack write table 2"); 61 | 62 | // test jmp 0 with other jmps 63 | inst p3[8] = {inst(STXB, 10, -1, 1), 64 | inst(JEQXY, 1, 2, 2), // jmp case 1, r1 == r2 65 | inst(STXB, 10, -1, 2), 66 | inst(JEQXY, 1, 3, 2), // jmp case 2, r1 == r3 67 | inst(STXB, 10, -1, 3), 68 | inst(JEQXY, 1, 4, 0), // jmp case 3, r1 == r4 69 | inst(LDXB, 0, 10, -1), 70 | inst(EXIT), 71 | }; 72 | pl = ps.gen_smt(prog_id, p3, 8); 73 | // nodes: 74 | // 0[0:1] 1[2:3] 2[4:5] 3[6:7] 75 | // edges: 76 | // 0 in:4294967295 out:1 2 77 | 
// 1 in:0 out:2 3 78 | // 2 in:1 0 out:3 3 79 | // 3 in:2 2 1 out: 80 | // test the post stack write table of basic block 3 81 | vector w_addr = {v("r_0_0_10_0") + to_expr(-1), // write in basic block 0 82 | v("r_0_1_10_0") + to_expr(-1), // write in basic block 1 83 | v("r_0_2_10_0") + to_expr(-1), // write in basic block 2 84 | }; 85 | vector w_val = {v("r_0_0_1_0").extract(7, 0), 86 | v("r_0_1_2_0").extract(7, 0), 87 | v("r_0_2_3_0").extract(7, 0), 88 | }; 89 | // for basic block 3, paths: 90 | // case1: from 2: 0 -> 1 -> 2 -> 3, 0 -> 2 -> 3, 91 | // case2: from 2: 0 -> 1 -> 2 -> 3, 0 -> 2 -> 3, 92 | // case3: from 1: 0 -> 1 -> 3 93 | // case1 and 2 are the same, can be tested by the same stack write table 94 | s.clear(); 95 | // 0 -> 1 -> 2 -> 3 96 | s.add(w_addr[0], w_val[0]); s.add(w_addr[1], w_val[1]); s.add(w_addr[2], w_val[2]); 97 | res = (s == ps.post_sv[3][0].mem_var._mem_table._wt) && (s == ps.post_sv[3][2].mem_var._mem_table._wt); 98 | // 0 -> 2 -> 3 99 | s.clear(); 100 | s.add(w_addr[0], w_val[0]); s.add(w_addr[2], w_val[2]); 101 | res = res && (s == ps.post_sv[3][1].mem_var._mem_table._wt) && (s == ps.post_sv[3][3].mem_var._mem_table._wt); 102 | // 0 -> 1 -> 3 103 | s.clear(); 104 | s.add(w_addr[0], w_val[0]); s.add(w_addr[1], w_val[1]); 105 | res = res && (s == ps.post_sv[3][4].mem_var._mem_table._wt); 106 | print_test_res(res, "post stack write table 3"); 107 | } 108 | 109 | int main() { 110 | test1(); 111 | 112 | return 0; 113 | } 114 | -------------------------------------------------------------------------------- /src/verify/smt_var_test.cc: -------------------------------------------------------------------------------- 1 | #include "smt_var.h" 2 | #include "../../src/utils.h" 3 | 4 | using namespace std; 5 | 6 | void test1() { 7 | smt_wt s1; 8 | s1.add(string_to_expr("a"), string_to_expr("b")); 9 | smt_wt s2 = s1; 10 | z3::expr f = (s1.addr[0] == s2.addr[0]) && (s1.val[0] == s2.val[0]); 11 | print_test_res(f.simplify() == 
string_to_expr("true"), "smt_wt="); 12 | 13 | // test smt_wt == 14 | smt_wt s3 = s2; 15 | smt_wt s4 = s2; 16 | print_test_res(s1 == s2, "smt_wt == 1"); 17 | s1.add(string_to_expr("a1"), string_to_expr("b1")); 18 | print_test_res(!(s1 == s2), "smt_wt == 2"); 19 | s2.add(string_to_expr("a1"), string_to_expr("b1")); 20 | print_test_res(s1 == s2, "smt_wt == 3"); 21 | s3.add(string_to_expr("a2"), string_to_expr("b1")); 22 | print_test_res(!(s1 == s3), "smt_wt == 4"); 23 | s4.add(string_to_expr("a1"), string_to_expr("b2")); 24 | print_test_res(!(s1 == s4), "smt_wt == 5"); 25 | } 26 | 27 | int main() { 28 | test1(); 29 | 30 | return 0; 31 | } 32 | -------------------------------------------------------------------------------- /src/verify/validator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "z3++.h" 4 | #include "../../src/utils.h" 5 | #include "../../src/inout.h" 6 | #include "../../src/isa/inst_header.h" 7 | #include "../../src/isa/prog.h" 8 | #include "smt_prog.h" 9 | #include "z3client.h" 10 | 11 | using namespace z3; 12 | 13 | #define ILLEGAL_CEX -2 // program is illegal and has a counterexample 14 | enum COUNTEREX_TYPE { 15 | COUNTEREX_eq_check = 0, 16 | COUNTEREX_safety_check, 17 | }; 18 | /* Validator algorithm document: https://github.com/ngsrinivas/superopt/tree/master/doc */ 19 | 20 | /* Class validator supports two functions now: equivalence check and output computation. 21 | * Function 1: equivalence check: check whether a synthesis program is equal to the original program/function. 22 | * Steps to use 23 | * step 1. set the original: 24 | * a. this step will compute and store the program-level pre-condition and program logic of the original 25 | * into `_pre_orig` and `_pl_orig`; 26 | * b. there are two ways to set the original: set in constructors or function `set_orig`; 27 | * c. function parameters can be either a program or a function. 28 | * step 2. 
call function `is_equal_to` to check whether a synthesis program is equal to the original one 29 | * a. this step will compute the program-level pre-condition and program logic of the synthesis program and 30 | * the post condition of the synthesis and original program. Then generate the SMT to check equivalence. 31 | * b. return value: 32 | * 1 (equal); 33 | * 0 (unequal) if unequal, a counter-example will be generated and stored in `_last_counterex`; 34 | * -1 (synthesis is illegal, e.g., program with loop or goes to invalid instructions). 35 | * c. parameters: a synthesis program 36 | * step 3. if the return value in step 2 is `0`, a counter-example can be extracted from `_last_counterex` 37 | * 38 | * Function 2: output computation: given input value, compute output value for the original program/function 39 | * Steps to use: 40 | * step 1. set the original: the same as step 1 in the previous equivalence check function. 41 | * step 2. call function `get_orig_output` to get the output value for the given input value passed as a parameter 42 | */ 43 | #define VLD_PROG_ID_ORIG 0 44 | #define VLD_PROG_ID_SYNTH 1 45 | 46 | class validator { 47 | private: 48 | bool is_in_prog_cache(prog& pgm, unordered_map >& prog_cache, bool print = false); 49 | void insert_into_prog_cache(prog& pgm, unordered_map >& prog_cache); 50 | public: 51 | static bool enable_z3server; 52 | // pre_: input formula of program: setting register 0 in basic block 0 as input[prog_id] 53 | // or the input variable of FOL formula as input[prog_id] 54 | expr _pre_orig = string_to_expr("true"); 55 | expr _pl_orig = string_to_expr("true"); 56 | smt_var _post_sv_orig; 57 | // last counterexample 58 | inout _last_counterex; 59 | // the cache of programs that are equal to the original program 60 | unordered_map > _prog_eq_cache; 61 | bool _enable_prog_eq_cache = true; 62 | unordered_map > _prog_uneq_cache; 63 | bool _enable_prog_uneq_cache = false; 64 | bool _is_win = false; 65 | int _win_start, _win_end; 66 | smt_input 
_smt_input_orig; 67 | prog_static_state _pss_orig; 68 | // mem_t _last_counterex_mem; 69 | /* store variables start */ 70 | // ps_: program logic formula, including basic program logic 71 | // and the formula of capturing the output of the program in the variable output[prog_id] 72 | smt_prog _store_ps_orig; 73 | // two program's output formula of setting outputs of two programs are equal, 74 | // i.e., output[VLD_PROG_ID_ORIG] == output[VLD_PROG_ID_SYNTH] 75 | expr _store_post = string_to_expr("true"); 76 | // f = pre^pre2^p1^p2 -> post 77 | expr _store_f = string_to_expr("true"); 78 | /* store variables end */ 79 | /* counter variables */ 80 | unsigned int _count_is_equal_to = 0; 81 | unsigned int _count_throw_err = 0; 82 | unsigned int _count_prog_eq_cache = 0; 83 | unsigned int _count_solve_safety = 0; 84 | // a counter of calling is_smt_valid for solving equivalence check 85 | unsigned int _count_solve_eq = 0; 86 | /* counter variables end */ 87 | validator(); 88 | validator(inst* orig, int length, bool is_win = false, int win_start = 0, int win_end = inst::max_prog_len); 89 | validator(expr fx, expr input, expr output); 90 | ~validator(); 91 | // calculate and store pre_orig, ps_orign 92 | void set_orig(inst* orig, int length, int win_start = 0, int win_end = inst::max_prog_len); 93 | // fx is the original FOL formula, input/output is the input/output variable of fx 94 | void set_orig(expr fx, expr input, expr output); 95 | // check whether synth is equal to orig 96 | // return 0: not equal; return 1: equal; return -1: synth is illegal 97 | int is_equal_to(inst* orig, int length_orig, inst* synth, int length_syn); 98 | // given input and register to store the input, return the output of the original 99 | reg_t get_orig_output(reg_t input, unsigned int num_regs, unsigned int input_reg); 100 | // move from `private` to `public` for testing time 101 | int is_smt_valid(expr& smt, model& mdl); 102 | void gen_counterex(inst* orig, int length, model& m, smt_var& 
post_sv_synth, smt_input& sin_synth, int counterex_type); 103 | // set register 0 in basic block 0 as input[prog_id] 104 | void smt_pre(expr& pre, unsigned int prog_id, unsigned int num_regs, unsigned int input_reg, smt_input& sin, smt_var& sv); 105 | // set the input variable of FOL formula as input[prog_id] 106 | void smt_pre(expr& pre, expr e); 107 | // setting outputs of two programs are equal 108 | void smt_post(expr& pst, unsigned int prog_id1, unsigned int prog_id2, smt_var& post_sv_synth); 109 | int safety_check(inst* orig, int len, expr& pre, expr& pl, expr& p_sc, smt_var& sv, smt_input& sin); 110 | void print_counters() const; 111 | }; 112 | -------------------------------------------------------------------------------- /src/verify/validator_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../src/utils.h" 3 | #include "../../src/isa/inst_header.h" 4 | #include "validator.h" 5 | 6 | using namespace z3; 7 | 8 | void test1() { 9 | std::cout << "test 1: no branch program equivalence check starts...\n"; 10 | // instructions1 == instructions2 == instructions3 != instructions4 11 | inst instructions1[6] = {inst(MOVXC, 1, 4), /* mov r1, 4 */ 12 | inst(ADDXY, 0, 1), /* add r0, r1 */ 13 | inst(MOVXC, 2, 15), /* mov r2, 15 */ 14 | inst(MAXC, 0, 15), /* max r0, 15 */ 15 | inst(MAXX, 0, 1), /* max r0, r1 */ 16 | inst(RETX, 0), 17 | }; 18 | 19 | inst instructions2[7] = {inst(MOVXC, 1, 4), /* mov r1, 4 */ 20 | inst(MOVXC, 2, 10), /* mov r2, 10 */ 21 | inst(ADDXY, 0, 1), /* add r0, r1 */ 22 | inst(MOVXC, 2, 15), /* mov r2, 15 */ 23 | inst(MAXC, 0, 15), /* max r0, 15 */ 24 | inst(MAXX, 0, 1), /* max r0, r1 */ 25 | inst(RETX, 0), 26 | }; 27 | 28 | inst instructions3[5] = {inst(MOVXC, 1, 4), /* mov r1, 4 */ 29 | inst(ADDXY, 0, 1), /* add r0, r1 */ 30 | inst(MOVXC, 2, 15), /* mov r2, 15 */ 31 | inst(MAXC, 0, 15), /* max r0, 15 */ 32 | inst(MAXX, 0, 2), /* max r0, r2 */ 33 | }; // default: ret 0 34 | 35 | 
inst instructions4[6] = {inst(MOVXC, 1, 4), /* mov r1, 4 */ 36 | inst(ADDXY, 0, 1), /* add r0, r1 */ 37 | inst(MOVXC, 2, 15), /* mov r2, 15 */ 38 | inst(MAXC, 0, -1), /* max r0, 15 */ 39 | inst(MAXX, 0, 3), /* max r0, r3 */ 40 | inst(RETX, 0), 41 | }; 42 | validator vld(instructions1, 6); 43 | 44 | print_test_res(vld.is_equal_to(instructions1, 6, instructions2, 7), "instructions1 == instructions2"); 45 | print_test_res(vld.is_equal_to(instructions1, 6, instructions3, 5), "instructions1 == instructions3"); 46 | print_test_res(!vld.is_equal_to(instructions1, 6, instructions4, 6), "instructions1 != instructions4"); 47 | } 48 | 49 | void test2() { 50 | validator vld; 51 | std::cout << "\ntest 2: branch program equivalence check starts...\n"; 52 | // instructions1 == instructions2 53 | inst instructions1[3] = {inst(JMPGT, 0, 2, 1), // if r0 <= r2: 54 | inst(RETX, 0), // ret r0 55 | inst(RETX, 2), // ret r2; 56 | }; 57 | inst instructions2[3] = {inst(JMPLT, 0, 2, 1), // if r0 >= r2 58 | inst(RETX, 2), // ret r2 59 | inst(RETX, 0), // ret r0 60 | }; 61 | vld.set_orig(instructions1, 3); 62 | print_test_res(vld.is_equal_to(instructions1, 3, instructions2, 3), "instructions1 == instructions2"); 63 | 64 | // instructions3 == instructions4 != instructions5 65 | inst instructions3[3] = {inst(JMPGT, 0, 2, 1), // return max(r0, r2) 66 | inst(RETX, 2), 67 | inst(RETX, 0), 68 | }; 69 | inst instructions4[2] = {inst(MAXX, 0, 2), // return r0=max(r0, r2) 70 | inst(RETX, 0), 71 | }; 72 | inst instructions5[3] = {inst(JMPGT, 2, 0, 1), // return min(r0, r2) 73 | inst(RETX, 2), 74 | inst(RETX, 0), 75 | }; 76 | vld.set_orig(instructions3, 3); 77 | 78 | print_test_res(vld.is_equal_to(instructions3, 3, instructions4, 2), "instructions3 == instructions4"); 79 | print_test_res(!vld.is_equal_to(instructions3, 3, instructions5, 3), "instructions3 != instructions5"); 80 | 81 | // f(x) = max(x, r1, r2, 10) 82 | // p11 == p12 83 | inst p11[5] = {inst(MAXX, 0, 1), 84 | inst(MAXX, 0, 2), 85 | 
inst(MOVXC, 1, 10), 86 | inst(MAXX, 0, 1), 87 | inst(RETX, 0), 88 | }; 89 | inst p12[11] = {inst(JMPGT, 0, 1, 2), // skip r0 <- r1, if r0 > r1 90 | inst(MOVXC, 0, 0), 91 | inst(ADDXY, 0, 1), 92 | inst(JMPGT, 0, 2, 2), // skip r0 <- r2, if r0 > r2 93 | inst(MOVXC, 0, 0), 94 | inst(ADDXY, 0, 2), 95 | inst(MOVXC, 1, 10), // r1 <- 10 96 | inst(JMPGT, 0, 1, 2), // skip r0 <- r1, if r0 > r1 97 | inst(MOVXC, 0, 0), 98 | inst(ADDXY, 0, 1), 99 | inst(RETX, 0), // ret r0 100 | }; 101 | vld.set_orig(p11, 5); 102 | print_test_res(vld.is_equal_to(p11, 5, p12, 11), "f(x)_p1 == f(x)_p2"); 103 | 104 | // check unconditonal jmp 105 | // p13 != p11, p14 == p15 == p11 106 | inst p13[6] = {inst(JMP, 3), 107 | inst(MAXX, 0, 1), 108 | inst(MAXX, 0, 2), 109 | inst(MOVXC, 1, 10), 110 | inst(MAXX, 0, 1), 111 | inst(RETX, 0), 112 | }; 113 | inst p14[6] = {inst(JMP, 0), 114 | inst(MAXX, 0, 1), 115 | inst(MAXX, 0, 2), 116 | inst(MOVXC, 1, 10), 117 | inst(MAXX, 0, 1), 118 | inst(RETX, 0), 119 | }; 120 | inst p15[7] = {inst(JMP, 3), 121 | inst(MOVXC, 1, 10), 122 | inst(MAXX, 0, 1), 123 | inst(RETX, 0), 124 | inst(MAXX, 0, 1), 125 | inst(MAXX, 0, 2), 126 | inst(JMP, -6), 127 | }; 128 | print_test_res(!vld.is_equal_to(p11, 5, p13, 6), "unconditonal jmp 1"); 129 | print_test_res(vld.is_equal_to(p11, 5, p14, 6), "unconditonal jmp 2"); 130 | print_test_res(vld.is_equal_to(p11, 5, p15, 7), "unconditonal jmp 3"); 131 | } 132 | 133 | // // fx == program_fx test 134 | // void test3() { 135 | // std::cout << "\ntest 3 starts...\n"; 136 | // expr x = string_to_expr("x"); 137 | // expr y = string_to_expr("y"); 138 | // expr fx = implies(x > 10, y == x) && implies(x <= 10, y == 10); 139 | // inst p_fx[4] = {inst(MOVXC, 1, 10), 140 | // inst(JMPLT, 0, 1, 1), 141 | // inst(RETX, 0), 142 | // inst(RETX, 1), 143 | // }; 144 | // validator vld(fx, x, y); 145 | // print_test_res(vld.is_equal_to(p_fx, 4), "Program_f(x) == (f(x)=max(x, 10))"); 146 | // } 147 | 148 | void test4() { 149 | std::cout << "\ntest4: check 
counterexample generation\n"; 150 | // orig: output = max(input, 11); 151 | // synth: output = max(input, 10); 152 | // counterexample: input <= 10, output = 11 153 | inst orig[3] = {inst(MOVXC, 2, 11), /* mov r2, 11 */ 154 | inst(MAXX, 0, 2), /* max r0, r2 */ 155 | inst(RETX, 0), 156 | }; 157 | inst synth[3] = {inst(MOVXC, 2, 10), /* mov r2, 10 */ 158 | inst(MAXX, 0, 2), /* max r0, r2 */ 159 | inst(RETX, 0), 160 | }; 161 | validator vld(orig, 3); 162 | inout counterex; 163 | inout_t input, output; 164 | input.reg = 0; 165 | output.reg = 0; 166 | counterex.set_in_out(input, output); 167 | if (!vld.is_equal_to(orig, 3, synth, 3)) { 168 | counterex = vld._last_counterex; 169 | } 170 | print_test_res((counterex.input.reg <= 10) && (counterex.output.reg == 11), 171 | "counterexample generation"); 172 | } 173 | 174 | void test5() { 175 | std::cout << "\ntest5: check get_orig_output\n"; 176 | // orig: output = max(input, 11); 177 | inst orig[3] = {inst(MOVXC, 2, 11), /* mov r2, 11 */ 178 | inst(MAXX, 0, 2), /* max r0, r2 */ 179 | inst(RETX, 0), 180 | }; 181 | validator vld(orig, 3); 182 | vector ex_set = {1, 10, 11, 12, 20}; 183 | vector expected = {11, 11, 11, 12, 20}; 184 | for (size_t i = 0; i < ex_set.size(); i++) { 185 | int output = vld.get_orig_output(ex_set[i], NUM_REGS, orig->get_input_reg()); 186 | print_test_res(output == expected[i], to_string(i)); 187 | } 188 | } 189 | 190 | int main(int argc, char *argv[]) { 191 | test1(); // no branch 192 | test2(); // with branch 193 | // test3(); 194 | test4(); 195 | // test5(); 196 | return 0; 197 | } 198 | -------------------------------------------------------------------------------- /src/verify/z3client.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "z3client.h" 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include "../../src/utils.h" 
16 | 17 | using namespace std; 18 | 19 | #define FORMULA_SHM_KEY 224 20 | #define RESULTS_SHM_KEY 46 21 | #define FORMULA_SIZE_BYTES (1 << 22) 22 | #define RESULT_SIZE_BYTES (1 << 22) 23 | #define SOLVER_RESPAWN_THRESOLD 1000 24 | 25 | 26 | int SERVER_PORT = 8002; /* default port */ 27 | z3::context c; 28 | pid_t child_pid_1 = -1; 29 | pid_t child_pid_2 = -1; 30 | pid_t pid; 31 | int nsolve1 = 0; 32 | int nsolve2 = 0; 33 | 34 | char form_buffer[FORMULA_SIZE_BYTES + 1] = {0}; 35 | char res_buffer[RESULT_SIZE_BYTES + 1] = {0}; 36 | 37 | // server_id starts from 0 38 | int get_server_port(int server_id) { 39 | assert(server_id >= 0); 40 | return (SERVER_PORT + server_id); 41 | } 42 | 43 | int spawn_server(int port) { 44 | // cout << "Hello before sleep\n"; 45 | // sleep(10); 46 | // cout << "after sleep\n"; 47 | cout << "Spawining Server with port: " << port << "\n"; 48 | pid = fork(); 49 | if (pid == -1) { 50 | cout << "Fork error occurred. Can't spawn a z3 solver server."; 51 | return -1; 52 | } else if (pid == 0) { /* in the child process; exec to z3server */ 53 | std::string NEWPORT = std::to_string(port); 54 | char *argv_list[] = {(char *)"./z3server.out ", const_cast(NEWPORT.c_str()), (char *)NULL}; 55 | execv("./z3server.out", argv_list); 56 | exit(-1); /* never supposed to get here until the exec fails. */ 57 | } else { 58 | /* in the parent process; record and return the child pid for later. 
*/ 59 | return pid; 60 | } 61 | } 62 | 63 | void kill_server_by_pid(pid_t id) { 64 | if (id <= 0) return; 65 | string cmd = "kill -9 " + to_string(id); 66 | int status = system(cmd.c_str()); 67 | if ((status != -1) && WIFEXITED(status) && (WEXITSTATUS(status) == 0)) { 68 | cout << "kill the z3 solver server " << id << " successfully" << endl; 69 | } else { 70 | cout << "kill the z3 solver server " << id << " failed" << endl; 71 | } 72 | } 73 | 74 | void kill_server() { 75 | kill_server_by_pid(pid); 76 | kill_server_by_pid(child_pid_1); 77 | kill_server_by_pid(child_pid_2); 78 | } 79 | 80 | int create_and_connect_socket(int port) { 81 | int sock = 0; 82 | struct sockaddr_in serv_addr; 83 | if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { 84 | perror("z3client: socket creation failed"); 85 | return -1; 86 | } 87 | 88 | serv_addr.sin_family = AF_INET; 89 | serv_addr.sin_port = htons(port); 90 | if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0) { 91 | perror("z3client: Invalid localhost network address"); 92 | return -1; 93 | } 94 | 95 | if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) 96 | < 0) { 97 | perror("z3client: connect() to z3server failed"); 98 | return -1; 99 | } 100 | return sock; 101 | } 102 | 103 | /* Send the formula to the server */ 104 | void send_formula(int sock, string formula) { 105 | int nchars; 106 | //cout << "z3client: Sending formula to server...\n"; 107 | nchars = std::min(FORMULA_SIZE_BYTES, (int)formula.length()); 108 | strncpy(form_buffer, formula.c_str(), nchars); 109 | form_buffer[nchars] = '\0'; 110 | send(sock, form_buffer, nchars + 1, 0); 111 | } 112 | 113 | /* Send the formula to the server */ 114 | void read_from_solver(int sock) { 115 | int nread, total_read; 116 | total_read = 0; 117 | do { 118 | nread = read(sock, res_buffer + total_read, RESULT_SIZE_BYTES - total_read); 119 | total_read += nread; 120 | } while (res_buffer[total_read - 1] != '\0' && 121 | total_read < RESULT_SIZE_BYTES); 122 | 
if (total_read >= RESULT_SIZE_BYTES) 123 | cout << "Exhausted result read buffer\n"; 124 | close(sock); 125 | } 126 | 127 | /* Poll Server Status non-blocking */ 128 | int poll_servers(int sock, int timeout) { 129 | fd_set fds; 130 | FD_ZERO (&fds); 131 | FD_SET (sock, &fds); 132 | struct timeval tv1 = {timeout, 0}; 133 | int readSockets = select (FD_SETSIZE, &fds, NULL, NULL, &tv1); 134 | return FD_ISSET (sock, &fds); 135 | } 136 | string write_problem_to_z3server(string formula) { 137 | // cout << "z3client: Received a formula to solve\n"; 138 | 139 | /* Server One */ 140 | bool no_child_now = child_pid_1 <= 0; 141 | bool time_to_respawn = (! no_child_now) && 142 | nsolve1 > 0 && nsolve1 % SOLVER_RESPAWN_THRESOLD == 0; 143 | if (no_child_now || time_to_respawn) { 144 | if (time_to_respawn) /* kill the existing server. */ 145 | kill(child_pid_1, SIGKILL); 146 | 147 | child_pid_1 = spawn_server(get_server_port(0)); 148 | if (child_pid_1 <= 0) { /* unsuccessful spawn */ 149 | cout << "z3client: spawning server 1 failed\n"; 150 | return ""; 151 | } 152 | sleep(2); /* letting socket listen to be setup */ 153 | } 154 | /* Server Two */ 155 | no_child_now = child_pid_2 <= 0; 156 | time_to_respawn = (! no_child_now) && 157 | nsolve2 > 0 && nsolve2 % SOLVER_RESPAWN_THRESOLD == 0; 158 | if (no_child_now || time_to_respawn) { 159 | if (time_to_respawn) /* kill the existing server. 
*/ 160 | kill(child_pid_2, SIGKILL); 161 | 162 | child_pid_2 = spawn_server(get_server_port(1)); 163 | if (child_pid_2 <= 0) { /* unsuccessful spawn */ 164 | cout << "z3client: spawning server 1 failed\n"; 165 | return ""; 166 | } 167 | sleep(2); /* letting socket listen to be setup */ 168 | } 169 | 170 | /* Make connection request to server */ 171 | //cout << "Connecting Server 1\n"; 172 | int sock1 = create_and_connect_socket(get_server_port(0)); 173 | //cout << "Connecting Server 2\n"; 174 | int sock2 = create_and_connect_socket(get_server_port(1)); 175 | if (sock1 == -1 || sock2 == -1) { /* socket creation error */ 176 | return ""; 177 | } 178 | 179 | send_formula(sock1, formula); 180 | send_formula(sock2, formula); 181 | 182 | /* Block until one socket returns data */ 183 | fd_set fds; 184 | FD_ZERO (&fds); 185 | FD_SET (sock1, &fds); 186 | FD_SET (sock2, &fds); 187 | struct timeval tv = {86400, 0}; /* set timeout for 24 hours */ 188 | int readSockets = select (FD_SETSIZE, &fds, NULL, NULL, &tv); 189 | if (readSockets < 0) { 190 | perror("z3client: neither server returned anything"); 191 | return ""; 192 | } 193 | if (readSockets == 0) { 194 | perror("z3client: timeout"); 195 | return ""; 196 | } 197 | int server1_read = FD_ISSET (sock1, &fds); 198 | int server2_read = FD_ISSET (sock2, &fds); 199 | int status; 200 | if (server1_read > 0 && server2_read > 0) { /* both sockets are readable */ 201 | // cout << "z3Client: both servers returned\n"; 202 | read_from_solver(sock1); 203 | read_from_solver(sock2); 204 | nsolve1++; 205 | nsolve2++; 206 | } else if (server1_read > 0 && server2_read == 0) { /* socket 1 is readable */ 207 | read_from_solver(sock1); 208 | server2_read = poll_servers(sock2, 2); 209 | if (server2_read > 0) { 210 | // cout << "z3client: both servers returned\n"; 211 | read_from_solver(sock2); 212 | nsolve2++; 213 | } else { 214 | cout << "z3client: Only server 1 returned. 
Killing server 2\n"; 215 | kill(child_pid_2, SIGKILL); 216 | waitpid(child_pid_2, &status, 0); 217 | child_pid_2 = spawn_server(get_server_port(1)); 218 | } 219 | nsolve1++; 220 | } else if (server1_read == 0 && server2_read > 0) { /* socket 2 is readable */ 221 | read_from_solver(sock2); 222 | server1_read = poll_servers(sock1, 2); 223 | if (server1_read > 0) { 224 | // cout << "z3client: both servers returned\n"; 225 | read_from_solver(sock1); 226 | nsolve1++; 227 | } else { 228 | cout << "z3client: Only server 2 returned. Killing server 1\n"; 229 | kill(child_pid_1, SIGKILL); 230 | waitpid(child_pid_1, &status, 0); 231 | child_pid_1 = spawn_server(get_server_port(0)); 232 | } 233 | nsolve2++; 234 | } 235 | /* Read back solver results. */ 236 | // cout << "z3client: Waiting for solver results from server...\n"; 237 | return string(res_buffer); 238 | } 239 | -------------------------------------------------------------------------------- /src/verify/z3client.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "z3++.h" 4 | using namespace std; 5 | extern int SERVER_PORT; 6 | int spawn_server(); 7 | void kill_server(); 8 | string write_problem_to_z3server(string formula); 9 | -------------------------------------------------------------------------------- /z3server.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "z3++.h" 9 | #include 10 | 11 | #define FORMULA_SHM_KEY 224 12 | #define RESULTS_SHM_KEY 46 13 | #define FORMULA_SIZE_BYTES (1 << 22) 14 | #define RESULT_SIZE_BYTES (1 << 22) 15 | 16 | using namespace std; 17 | 18 | z3::context c; 19 | int read_problem_from_z3client(int PORT); 20 | 21 | char buffer[FORMULA_SIZE_BYTES + 1] = {0}; 22 | char res_buffer[RESULT_SIZE_BYTES + 1] = {0}; 23 | 24 | string run_solver(char* formula) { 25 | z3::tactic t = 
z3::tactic(c, "bv"); 26 | z3::solver s = t.mk_solver(); 27 | 28 | Z3_set_ast_print_mode(s.ctx(), Z3_PRINT_SMTLIB2_COMPLIANT); 29 | string res; 30 | s.from_string(formula); 31 | // cout << "Running the solver..." << endl; 32 | switch (s.check()) { 33 | case z3::unsat: { 34 | return "unsat"; 35 | } 36 | case z3::sat: { 37 | ostringstream strm; 38 | z3::model mdl = s.get_model(); 39 | strm << mdl; 40 | res = strm.str(); 41 | return res; 42 | } 43 | case z3::unknown: { 44 | return "unknown"; 45 | } 46 | } 47 | } 48 | void set_seed() { 49 | srand (time(NULL)); 50 | int iSecret = rand() % ((int) pow(2, 8)) + 1; 51 | // iSecret = rand() % 4 + 1; 52 | cout << "z3server: seed = " << iSecret << endl; 53 | z3::set_param("sls.random_seed", iSecret); 54 | z3::set_param("smt.random_seed", iSecret); 55 | z3::set_param("sat.random_seed", iSecret); 56 | z3::set_param("fp.spacer.random_seed", iSecret); 57 | } 58 | 59 | int read_problem_from_z3client(int PORT) { 60 | int server_fd, acc_socket, nread, total_read, nchars; 61 | int opt = 1; 62 | struct sockaddr_in address; 63 | int addrlen = sizeof(address); 64 | string result; 65 | set_seed(); 66 | if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) { 67 | perror("z3server: socket creation failed"); 68 | exit(EXIT_FAILURE); 69 | } 70 | 71 | if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEPORT, 72 | (char*)&opt, sizeof(opt)) < 0) { 73 | perror("z3server: setsockopt to reuse addr/port failed"); 74 | exit(EXIT_FAILURE); 75 | } 76 | 77 | address.sin_family = AF_INET; 78 | address.sin_addr.s_addr = INADDR_ANY; 79 | address.sin_port = htons(PORT); 80 | 81 | if (::bind(server_fd, (struct sockaddr *)&address, 82 | sizeof(address)) < 0) { 83 | perror("z3server: socket bind to local address/port failed"); 84 | exit(EXIT_FAILURE); 85 | } 86 | 87 | if (listen(server_fd, 1) < 0) { 88 | perror("z3server: can't listen to bound socket"); 89 | exit(EXIT_FAILURE); 90 | } 91 | 92 | /* Main server + solver loop */ 93 | while ((acc_socket = 
accept(server_fd, (struct sockaddr *)&address, 94 | (socklen_t*)&addrlen)) ) { 95 | if (acc_socket < 0) { 96 | perror("z3server: failed to accept incoming connection"); 97 | exit(EXIT_FAILURE); 98 | } 99 | 100 | //cout << "z3server: Received a new connection. Reading formula on port: " << PORT << endl; 101 | /* Read the full formula into buffer. */ 102 | total_read = 0; 103 | do { 104 | nread = read(acc_socket, buffer + total_read, FORMULA_SIZE_BYTES - total_read); 105 | total_read += nread; 106 | } while (buffer[total_read - 1] != '\0' && 107 | total_read < FORMULA_SIZE_BYTES); 108 | if (total_read >= FORMULA_SIZE_BYTES) 109 | cout << "Exhausted formula read buffer\n"; 110 | 111 | //cout << "z3server: Recieved Formula from client on port: " << PORT << endl; 112 | 113 | /* Run the solver. */ 114 | result = run_solver(buffer); 115 | nchars = min((int)result.length(), RESULT_SIZE_BYTES); 116 | strncpy(res_buffer, result.c_str(), nchars); 117 | res_buffer[nchars] = '\0'; 118 | 119 | //cout << "z3server: Sending formula to Client...\n"; 120 | /* Send result. */ 121 | send(acc_socket, res_buffer, nchars + 1, 0); 122 | close(acc_socket); 123 | } 124 | return 0; 125 | } 126 | 127 | int main(int argc, char *argv[]) { 128 | if (argc != 2) { 129 | cout << "No port argument" << endl; 130 | return 1; 131 | } 132 | int PORT = std::stoi(argv[1]); 133 | cout << "z3server: Port is " << argv[1] << endl; 134 | /* Receive a z3 smtlib2 formula in a shared memory segment, and 135 | return sat or unsat in another one. */ 136 | read_problem_from_z3client(PORT); 137 | return 0; 138 | } 139 | --------------------------------------------------------------------------------