├── .github └── workflows │ └── pythonapp.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── MANIFEST.in ├── README.md ├── appveyor.yml ├── chipc ├── __init__.py ├── alu.g4 ├── compiler.py ├── iterative_solver.py ├── lib │ ├── LICENSE.txt │ └── antlr-4.7.2-complete.jar ├── mode.py ├── sketch_code_generator.py ├── sketch_stateful_alu_visitor.py ├── sketch_stateless_alu_visitor.py ├── sketch_utils.py ├── templates │ ├── code_generator.j2 │ ├── mux.j2 │ ├── muxes_and_alus.j2 │ ├── opt_verify.j2 │ └── router_data_path_sketch.j2 ├── utils.py └── z3_utils.py ├── example_alus ├── stateful_alus │ ├── if_else_raw.alu │ ├── nested_ifs.alu │ ├── pair.alu │ ├── pred_raw.alu │ ├── raw.alu │ └── sub.alu └── stateless_alus │ ├── stateless_alu.alu │ ├── stateless_alu_arith.alu │ ├── stateless_alu_arith_rel.alu │ ├── stateless_alu_arith_rel_cond.alu │ └── stateless_alu_arith_rel_cond_bool.alu ├── example_specs ├── blue_decrease.sk ├── blue_increase.sk ├── learn_filter_modified_for_test.sk ├── marple_new_flow.sk ├── marple_tcp_nmo.sk ├── rcp.sk ├── sampling.sk ├── sampling_revised.sk ├── simple.sk ├── simple2.sk ├── simplest.sk ├── simplest2.sk ├── simplified_hull.sk ├── snap_heavy_hitter.sk ├── test.sk └── times_two.sk ├── requirements-dev.txt ├── setup.py └── tests ├── __init__.py ├── data ├── hello.dag ├── hello.smt ├── sampling.dag ├── sampling.sk ├── sampling.smt └── simple_raw_2_2_codegen_iteration_1.sk ├── test_iterative_solver.py ├── test_utils.py └── test_z3_utils.py /.github/workflows/pythonapp.yml: -------------------------------------------------------------------------------- 1 | name: Python application 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v1 12 | - name: Set up Python 3.7 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.7 16 | - name: Install dependencies 17 | run: | 18 | old_dir=`pwd` 19 | set -e 20 | sudo apt-get update 21 | sudo apt-get 
install -y bison python3-pip flex 22 | cd /usr/local/lib 23 | sudo wget https://www.antlr.org/download/antlr-4.7.2-complete.jar 24 | export CLASSPATH=".:/usr/local/lib/antlr-4.7.2-complete.jar:$CLASSPATH" 25 | antlr4='java -jar /usr/local/lib/antlr-4.7.2-complete.jar' 26 | cd ~ 27 | wget https://people.csail.mit.edu/asolar/sketch-1.7.5.tar.gz 28 | tar xvzf sketch-1.7.5.tar.gz 29 | cd sketch-1.7.5 30 | cd sketch-backend 31 | chmod +x ./configure 32 | ./configure 33 | make 34 | cd .. 35 | cd sketch-frontend 36 | chmod +x ./sketch 37 | ./sketch test/sk/seq/miniTest1.sk 38 | export PATH="$PATH:`pwd`" 39 | export SKETCH_HOME="`pwd`/runtime" 40 | cd $old_dir 41 | python -m pip install --upgrade pip 42 | pip install -e . 43 | iterative_solver example_specs/simple.sk example_alus/stateful_alus/raw.alu example_alus/stateless_alus/stateless_alu.alu 2 2 "0,1,2,3" 10 --hole-elimination 44 | iterative_solver example_specs/simple.sk example_alus/stateful_alus/raw.alu example_alus/stateless_alus/stateless_alu.alu 2 2 "0,1,2,3" 10 --parallel-sketch --hole-elimination 45 | python3 -m unittest 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.pickle 4 | *.success 5 | *.errors 6 | *.sk 7 | !example_specs/*.sk 8 | *.class 9 | *.java 10 | *.interp 11 | *.tokens 12 | aluLexer.py 13 | aluParser.py 14 | aluVisitor.py 15 | aluListener.py 16 | *.egg-info 17 | *.smt2 18 | *.dag 19 | *_output.txt 20 | .vscode/ 21 | 22 | !tests/data/*.sk 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.1.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - 
id: check-docstring-first 10 | - id: check-yaml 11 | - id: double-quote-string-fixer 12 | - id: name-tests-test 13 | args: [--django] 14 | - repo: https://gitlab.com/pycqa/flake8 15 | rev: 3.7.7 16 | hooks: 17 | - id: flake8 18 | - repo: https://github.com/pre-commit/mirrors-autopep8 19 | rev: v1.4.3 20 | hooks: 21 | - id: autopep8 22 | - repo: https://github.com/asottile/reorder_python_imports 23 | rev: v1.4.0 24 | hooks: 25 | - id: reorder-python-imports 26 | - repo: meta 27 | hooks: 28 | - id: check-hooks-apply 29 | - id: check-useless-excludes 30 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.5.4/envs/chipmunk 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include chipc/alu.g4 2 | include chipc/templates/*.j2 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chipmunk 2 | 3 | [![Build status](https://ci.appveyor.com/api/projects/status/060fwhaq3vfvt22n/branch/master?svg=true)](https://ci.appveyor.com/project/anirudhSK/chipmunk-hhg5f/branch/master) 4 | 5 | ## Installation 6 | - Install [sketch](https://people.csail.mit.edu/asolar/sketch-1.7.5.tar.gz) 7 | - Install Java. This repo uses [antlr](https://www.antlr.org/) to generate 8 | parser and lexer. 9 | - `pip3 install -r requirements-dev.txt -e . && pre-commit install` (if you want to make changes to 10 | this repo), 11 | - `pip3 install .` (if you want to simply use chipmunk.). 12 | - Add sudo if you want to install system wide. 13 | 14 | ## How to 15 | 16 | ### Develop 17 | 18 | If you have installed it as above, first re-install via following command. 
19 | 20 | ```shell 21 | pip3 install -r requirements-dev.txt -e . 22 | pre-commit install 23 | ``` 24 | Note the `-e` flag in the install command. It installs this package in 25 | development mode, which simply links the actual chipc directory into your Python's 26 | site-packages directory. 27 | 28 | 1. Make changes to python code 29 | 2. Consider adding tests, and run the tests with `python3 -m unittest` 30 | 3. Run your desired script, e.g. `python3 chipc/iterative_solver.py ...` 31 | 32 | This way you don't have to keep installing and uninstalling whenever you make a 33 | change and test. However, you still have to run via `python3 chipc/iterative_solver.py` 34 | instead of using the installed binary. 35 | 36 | Also consider using [venv](https://docs.python.org/3/library/venv.html), 37 | [virtualenv](https://virtualenv.pypa.io/en/latest/) or 38 | [pipenv](https://pipenv.readthedocs.io/en/latest/) to create an isolated Python 39 | development environment. 40 | 41 | ### Iterative solver 42 | ```shell 43 | iterative_solver example_specs/simple.sk example_alus/stateful_alus/raw.alu example_alus/stateless_alus/stateless_alu.alu 2 2 "0,1,2,3" 10 --hole-elimination 44 | ``` 45 | 46 | ```shell 47 | iterative_solver example_specs/simple.sk example_alus/stateful_alus/raw.alu example_alus/stateless_alus/stateless_alu.alu 2 2 "0,1,2,3" 10 --parallel --parallel-sketch --hole-elimination 48 | ``` 49 | 50 | ### Test 51 | 52 | Run: 53 | 54 | You need to run setup in editable mode (with -e) to generate lexer and parser 55 | in this directory. 56 | ```shell 57 | pip3 install -r requirements-dev.txt -e . 58 | python3 -m unittest 59 | ``` 60 | 61 | If you want to add a test, add a new file in the [tests](tests/) directory or add 62 | test cases to an existing `test_*.py` file.
63 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | version: 1.0.{build} 2 | image: Ubuntu1604 3 | stack: jdk 8 4 | install: 5 | - sh: >- 6 | old_dir=`pwd` 7 | 8 | set -e 9 | 10 | sudo apt-get update 11 | 12 | sudo apt-get install -y bison python3-pip flex 13 | 14 | cd /usr/local/lib 15 | 16 | sudo wget https://www.antlr.org/download/antlr-4.7.2-complete.jar 17 | 18 | export CLASSPATH=".:/usr/local/lib/antlr-4.7.2-complete.jar:$CLASSPATH" 19 | 20 | antlr4='java -jar /usr/local/lib/antlr-4.7.2-complete.jar' 21 | 22 | cd ~ 23 | 24 | wget https://people.csail.mit.edu/asolar/sketch-1.7.5.tar.gz 25 | 26 | tar xvzf sketch-1.7.5.tar.gz 27 | 28 | cd sketch-1.7.5 29 | 30 | cd sketch-backend 31 | 32 | chmod +x ./configure 33 | 34 | ./configure 35 | 36 | make 37 | 38 | cd .. 39 | 40 | cd sketch-frontend 41 | 42 | chmod +x ./sketch 43 | 44 | ./sketch test/sk/seq/miniTest1.sk 45 | 46 | export PATH="$PATH:`pwd`" 47 | 48 | export SKETCH_HOME="`pwd`/runtime" 49 | 50 | cd $old_dir 51 | 52 | build_script: 53 | - sh: >- 54 | set -e 55 | 56 | sudo pip3 install . 
57 | 58 | $antlr4 chipc/alu.g4 -Dlanguage=Python3 -visitor -package chipc 59 | 60 | iterative_solver example_specs/simple.sk example_alus/stateful_alus/raw.alu example_alus/stateless_alus/stateless_alu.alu 2 2 "0,1,2,3" 10 --hole-elimination 61 | 62 | iterative_solver example_specs/simple.sk example_alus/stateful_alus/raw.alu example_alus/stateless_alus/stateless_alu.alu 2 2 "0,1,2,3" 10 --parallel --parallel-sketch --hole-elimination 63 | 64 | python3 -m unittest 65 | 66 | echo "********ALL DONE*******" 67 | -------------------------------------------------------------------------------- /chipc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chipmunk-project/chipmunk/a86eb9030c7accfd552dc4c899fa247193588251/chipc/__init__.py -------------------------------------------------------------------------------- /chipc/alu.g4: -------------------------------------------------------------------------------- 1 | grammar alu; 2 | 3 | // Hide whitespace, but don't skip it 4 | WS : [ \n\t\r]+ -> channel(HIDDEN); 5 | LINE_COMMENT : '//' ~[\r\n]* -> skip; 6 | // Keywords 7 | RELOP : 'rel_op'; // <, >, <=, >=, ==, != Captures everything from slide 14 of salu.pdf 8 | BOOLOP : 'bool_op'; // !, &&, || and combinations of these (best guess for how update_lo/hi_1/2_predicate works 9 | ARITHOP : 'arith_op'; // Captures +/- used in slide 14 of salu.pdf 10 | COMPUTEALU : 'compute_alu'; // Captures everything from slide 15 of salu.pdf 11 | // TODO: Instead of having numbered MUX, maybe implement a support one mux that 12 | // can take variable number of arguments. 
13 | MUX5 : 'Mux5'; // 5-to-1 mux 14 | MUX4 : 'Mux4'; 15 | MUX3 : 'Mux3'; // 3-to-1 mux 16 | MUX2 : 'Mux2'; // 2-to-1 mux 17 | OPT : 'Opt'; // Pick either the argument or 0 18 | CONSTANT : 'C()'; // Return a finite constant 19 | TRUE : 'True'; // Guard corresponding to "always update" 20 | IF : 'if'; 21 | ELSE : 'else'; 22 | ELIF : 'elif'; 23 | RETURN : 'return'; 24 | EQUAL : '=='; 25 | GREATER : '>'; 26 | LESS : '<'; 27 | GREATER_OR_EQUAL : '>='; 28 | LESS_OR_EQUAL : '<='; 29 | NOT_EQUAL : '!='; 30 | OR : '||'; 31 | AND : '&&'; 32 | NOT : '!'; 33 | QUESTION : '?'; 34 | ASSERT_FALSE : 'assert(false);'; 35 | 36 | // Identifiers 37 | ID : ('a'..'z' | 'A'..'Z') ('a'..'z' | 'A'..'Z' | '_' | '0'..'9')*; 38 | 39 | // Numerical constant 40 | NUM : ('0'..'9') | (('1'..'9')('0'..'9')+); 41 | 42 | 43 | // alias id to state_var and packet_field 44 | state_var : ID; 45 | temp_var : ID; 46 | packet_field : ID; 47 | // alias id to hole variables 48 | hole_var : ID; 49 | 50 | // Determines whether the ALU is stateless or stateful 51 | stateless : 'stateless'; 52 | stateful : 'stateful'; 53 | state_indicator : 'type' ':' stateless 54 | | 'type' ':' stateful; 55 | 56 | state_var_def : 'state' 'variables' ':' '{' state_var_seq '}'; 57 | 58 | state_var_seq : /* epsilon */ 59 | | state_vars 60 | ; 61 | 62 | state_vars : state_var #SingleStateVar 63 | | state_var ',' state_vars #MultipleStateVars 64 | ; 65 | 66 | hole_def : 'hole' 'variables' ':' '{' hole_seq '}'; 67 | 68 | hole_seq : /* epsilon */ 69 | | hole_vars 70 | ; 71 | 72 | hole_vars : hole_var #SingleHoleVar 73 | | hole_var ',' hole_vars #MultipleHoleVars 74 | ; 75 | 76 | packet_field_def : 'packet' 'fields' ':' '{' packet_field_seq '}'; 77 | 78 | packet_field_seq : /* epsilon */ 79 | | packet_fields 80 | ; 81 | 82 | packet_fields : packet_field #SinglePacketField 83 | | packet_field ',' packet_fields #MultiplePacketFields 84 | ; 85 | 86 | // alu_body 87 | alu_body : statement+; 88 | 89 | condition_block : '(' expr ')' '{' 
alu_body '}'; 90 | 91 | statement : variable '=' expr ';' #StmtUpdateExpr 92 | | 'int ' temp_var '=' expr ';' #StmtUpdateTempInt 93 | | 'bit ' temp_var '=' expr ';' #StmtUpdateTempBit 94 | // NOTE: Having multiple return statements between a pair of curly 95 | // braces is syntactically correct, but such program might not make 96 | // sense for us. 97 | // TODO: Modify the generator to catch multiple return statements 98 | // and output errors early on. 99 | | return_statement #StmtReturn 100 | | IF condition_block (ELIF condition_block)* (ELSE '{' else_body = alu_body '}')? #StmtIfElseIfElse 101 | | ASSERT_FALSE #AssertFalse 102 | ; 103 | 104 | return_statement : RETURN expr ';'; 105 | 106 | variable : ID ; 107 | expr : variable #Var 108 | | expr op=('+'|'-'|'*'|'/') expr #ExprWithOp 109 | | '(' expr ')' #ExprWithParen 110 | | NUM #Num 111 | | expr EQUAL expr #Equals 112 | | expr GREATER expr #Greater 113 | | expr GREATER_OR_EQUAL expr #GreaterEqual 114 | | expr LESS expr #Less 115 | | expr LESS_OR_EQUAL expr #LessEqual 116 | | expr NOT_EQUAL expr #NotEqual 117 | | expr AND expr #And 118 | | expr OR expr #Or 119 | | NOT expr #NOT 120 | | TRUE #True 121 | | expr '?' expr ':' expr #Ternary 122 | // Currently, we use below rules only from stateful ALUs. 
123 | | MUX2 '(' expr ',' expr ')' #Mux2 124 | | MUX3 '(' expr ',' expr ',' NUM ')' #Mux3WithNum 125 | | MUX3 '(' expr ',' expr ',' expr ')' #Mux3 126 | | MUX4 '(' expr ',' expr ',' expr ',' expr ')' #Mux4 127 | | MUX5 '(' expr ',' expr ',' expr ',' expr ',' expr ')' #Mux5 128 | | OPT '(' expr ')' #Opt 129 | | CONSTANT #Constant 130 | | ARITHOP '(' expr ',' expr ')' # ArithOp 131 | | COMPUTEALU '(' expr ',' expr ')' # ComputeAlu 132 | | RELOP '(' expr ',' expr ')' #RelOp 133 | | BOOLOP '(' expr ',' expr ')' #BoolOp 134 | ; 135 | 136 | alu: state_indicator state_var_def hole_def packet_field_def alu_body; 137 | -------------------------------------------------------------------------------- /chipc/compiler.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures as cf 2 | import itertools 3 | import os 4 | import signal 5 | from collections import OrderedDict 6 | from os import path 7 | from pathlib import Path 8 | 9 | import psutil 10 | from jinja2 import Environment 11 | from jinja2 import FileSystemLoader 12 | from jinja2 import StrictUndefined 13 | 14 | from chipc import sketch_utils 15 | from chipc import z3_utils 16 | from chipc.mode import Mode 17 | from chipc.sketch_code_generator import SketchCodeGenerator 18 | from chipc.utils import get_hole_bit_width 19 | from chipc.utils import get_hole_value_assignments 20 | from chipc.utils import get_num_pkt_fields 21 | from chipc.utils import get_state_group_info 22 | 23 | 24 | def kill_child_processes(parent_pid, sig=signal.SIGTERM): 25 | try: 26 | parent = psutil.Process(parent_pid) 27 | except psutil.NoSuchProcess: 28 | return 29 | children = parent.children(recursive=True) 30 | for process in children: 31 | try: 32 | process.send_signal(sig) 33 | print('send_signal killed a child process', process) 34 | except psutil.NoSuchProcess as e: 35 | print("send_signal didn't have any effect because process didn't" 36 | 'exist') 37 | print(e) 38 | 39 | 40 | class 
Compiler: 41 | def __init__(self, spec_filename, stateful_alu_filename, 42 | stateless_alu_filename, num_pipeline_stages, 43 | num_alus_per_stage, sketch_name, parallel_sketch, 44 | constant_set, synthesized_allocation=False, 45 | output_packet_fields=[], 46 | output_state_groups=[], 47 | input_packet_fields=[]): 48 | self.spec_filename = spec_filename 49 | self.stateful_alu_filename = stateful_alu_filename 50 | self.stateless_alu_filename = stateless_alu_filename 51 | self.num_pipeline_stages = num_pipeline_stages 52 | self.num_alus_per_stage = num_alus_per_stage 53 | self.sketch_name = sketch_name 54 | self.parallel_sketch = parallel_sketch 55 | self.constant_set = constant_set 56 | self.synthesized_allocation = synthesized_allocation 57 | 58 | program_content = Path(spec_filename).read_text() 59 | self.num_fields_in_prog = get_num_pkt_fields(program_content) 60 | self.num_state_groups = len(get_state_group_info(program_content)) 61 | 62 | if not input_packet_fields: 63 | assert self.num_fields_in_prog <= num_alus_per_stage, ( 64 | 'Number of fields in program %d is greater than number of ' 65 | 'alus per stage %d. Try increasing ' 66 | 'number of alus per stage.' % ( 67 | self.num_fields_in_prog, num_alus_per_stage)) 68 | else: 69 | assert len(input_packet_fields) <= num_alus_per_stage, ( 70 | 'Number of input fields in program %d is' 71 | 'greater than number of alus per stage %d. Try increasing ' 72 | 'number of alus per stage.' % ( 73 | len(input_packet_fields), num_alus_per_stage)) 74 | # Guarantee that # of output_packet_fields is less than or equal 75 | # to the num_alus_per_stage 76 | if output_packet_fields is not None: 77 | assert len(output_packet_fields) <= num_alus_per_stage, ( 78 | 'Number of checked fields in program %d is ' 79 | 'greater than number of alus per stage %d. ' 80 | 'Try increasing number of alus per stage.' 
% ( 81 | len(output_packet_fields), num_alus_per_stage)) 82 | 83 | # Initialize jinja2 environment for templates 84 | self.jinja2_env = Environment( 85 | loader=FileSystemLoader( 86 | [path.join(path.dirname(__file__), './templates'), 87 | path.join(os.getcwd(), 88 | stateless_alu_filename[ 89 | :stateless_alu_filename.rfind('/')]), 90 | '.', '/']), 91 | undefined=StrictUndefined, 92 | trim_blocks=True, 93 | lstrip_blocks=True) 94 | 95 | if not output_packet_fields and not output_state_groups: 96 | output_packet_fields = list(range(self.num_fields_in_prog)) 97 | output_state_groups = list(range(self.num_state_groups)) 98 | elif not output_packet_fields and output_state_groups: 99 | output_packet_fields = [] 100 | elif output_packet_fields and not output_state_groups: 101 | output_state_groups = [] 102 | 103 | # Differentiate between using default pkt input vs. specify pkt input 104 | if not input_packet_fields: 105 | input_packet_fields = list(range(self.num_fields_in_prog)) 106 | 107 | # Create an object for sketch generation 108 | self.sketch_code_generator = SketchCodeGenerator( 109 | sketch_name=sketch_name, 110 | num_pipeline_stages=num_pipeline_stages, 111 | num_alus_per_stage=num_alus_per_stage, 112 | num_phv_containers=num_alus_per_stage, 113 | num_state_groups=self.num_state_groups, 114 | num_fields_in_prog=self.num_fields_in_prog, 115 | output_packet_fields=output_packet_fields, 116 | output_state_groups=output_state_groups, 117 | jinja2_env=self.jinja2_env, 118 | stateful_alu_filename=stateful_alu_filename, 119 | stateless_alu_filename=stateless_alu_filename, 120 | constant_set=constant_set, 121 | synthesized_allocation=synthesized_allocation, 122 | input_packet_fields=input_packet_fields) 123 | 124 | def update_constants_for_synthesis(self, constant_set): 125 | # Join the values in constant_set to get constant_array in sketch. 
126 | new_constant_set_str = '{' + ','.join(constant_set) + '}' 127 | 128 | # Use string format to create constant_arr_def_ 129 | self.sketch_code_generator.constant_arr_def_ = \ 130 | 'int[{}]'.format(str(len(constant_set))) + \ 131 | 'constant_vector = {};\n\n'.format(new_constant_set_str) 132 | self.sketch_code_generator.constant_arr_size_ = get_hole_bit_width( 133 | len(constant_set)) 134 | 135 | def single_codegen_run(self, compiler_input): 136 | additional_constraints = compiler_input[0] 137 | additional_testcases = compiler_input[1] 138 | sketch_file_name = compiler_input[2] 139 | 140 | """Codegeneration""" 141 | codegen_code = self.sketch_code_generator.generate_sketch( 142 | spec_filename=self.spec_filename, 143 | mode=Mode.CODEGEN, 144 | synthesized_allocation=self.synthesized_allocation, 145 | additional_constraints=additional_constraints, 146 | additional_testcases=additional_testcases) 147 | 148 | # Create file and write sketch_harness into it. 149 | with open(sketch_file_name, 'w') as sketch_file: 150 | sketch_file.write(codegen_code) 151 | 152 | # Call sketch on it 153 | print('Total number of hole bits is', 154 | self.sketch_code_generator.total_hole_bits_) 155 | print('Sketch file is', sketch_file_name) 156 | assert (self.parallel_sketch in [True, False]) 157 | (ret_code, output) = sketch_utils.synthesize( 158 | sketch_file_name, 159 | bnd_inbits=2, 160 | slv_seed=1, 161 | slv_parallel=self.parallel_sketch) 162 | 163 | # Store sketch output 164 | with open(sketch_file_name[:sketch_file_name.find('.sk')] + 165 | '_output.txt', 'w') as output_file: 166 | output_file.write(output) 167 | if (ret_code == 0): 168 | holes_to_values = get_hole_value_assignments( 169 | self.sketch_code_generator.hole_names_, output) 170 | else: 171 | holes_to_values = OrderedDict() 172 | return (ret_code, output, holes_to_values) 173 | 174 | def serial_codegen(self, iter_cnt=1, additional_constraints=[], 175 | additional_testcases=''): 176 | return 
self.single_codegen_run((additional_constraints, 177 | additional_testcases, 178 | self.sketch_name + 179 | '_codegen_iteration_' + 180 | str(iter_cnt) + '.sk')) 181 | 182 | def parallel_codegen(self, 183 | additional_constraints=[], 184 | additional_testcases=''): 185 | # For each state_group, pick a pipeline_stage exhaustively. 186 | # Note that some of these assignments might be infeasible, but that's 187 | # OK. Sketch will reject these anyway. 188 | count = 0 189 | compiler_output = None 190 | compiler_inputs = [] 191 | for assignment in itertools.product(list( 192 | range(self.num_pipeline_stages)), 193 | repeat=self.num_state_groups): 194 | constraint_list = additional_constraints.copy() 195 | count = count + 1 196 | print('Now in assignment # ', count, ' assignment is ', assignment) 197 | for state_group in range(self.num_state_groups): 198 | assigned_stage = assignment[state_group] 199 | for stage in range(self.num_pipeline_stages): 200 | if (stage == assigned_stage): 201 | constraint_list += [ 202 | self.sketch_name + '_salu_config_' + 203 | str(stage) + '_' + str(state_group) + ' == 1' 204 | ] 205 | else: 206 | constraint_list += [ 207 | self.sketch_name + '_salu_config_' + 208 | str(stage) + '_' + str(state_group) + ' == 0' 209 | ] 210 | compiler_inputs += [ 211 | (constraint_list, additional_testcases, 212 | self.sketch_name + '_' + str(count) + '_codegen.sk') 213 | ] 214 | 215 | with cf.ProcessPoolExecutor(max_workers=count) as executor: 216 | futures = [] 217 | for compiler_input in compiler_inputs: 218 | futures.append( 219 | executor.submit(self.single_codegen_run, compiler_input)) 220 | 221 | for f in cf.as_completed(futures): 222 | compiler_output = f.result() 223 | if (compiler_output[0] == 0): 224 | print('Success') 225 | # TODO: Figure out the right way to do this in the future. 
226 | executor.shutdown(wait=False) 227 | kill_child_processes(os.getpid()) 228 | return compiler_output 229 | else: 230 | print('One run failed, waiting for others.') 231 | return compiler_output 232 | 233 | def verify(self, hole_assignments, input_bits, iter_cnt=1): 234 | """Verify hole value assignments for the sketch with a specific input 235 | bit lengths with z3. 236 | 237 | Returns: 238 | A tuple of two dicts from string to ints, where the first one 239 | represents counterexamples for packet variables and the second for 240 | state group variables. 241 | If the hole value assignments work for the input_bits, returns 242 | a tuple of two empty dicts. 243 | """ 244 | # Check all holes have values. 245 | for hole in self.sketch_code_generator.hole_names_: 246 | assert hole in hole_assignments 247 | 248 | # Generate a sketch file to verify the hole value assignments with 249 | # the specified input bit lengths. 250 | sketch_to_verify = self.sketch_code_generator.generate_sketch( 251 | spec_filename=self.spec_filename, 252 | mode=Mode.VERIFY, 253 | synthesized_allocation=self.synthesized_allocation, 254 | hole_assignments=hole_assignments 255 | ) 256 | 257 | # Write sketch to a file. 
258 | file_basename = self.sketch_name + '_verify_iter_' + str(iter_cnt) 259 | sketch_filename = file_basename + '.sk' 260 | Path(sketch_filename).write_text(sketch_to_verify) 261 | 262 | sketch_ir = sketch_utils.generate_ir(sketch_filename) 263 | 264 | z3_formula = z3_utils.get_z3_formula(sketch_ir, input_bits) 265 | 266 | return z3_utils.generate_counterexamples(z3_formula) 267 | -------------------------------------------------------------------------------- /chipc/iterative_solver.py: -------------------------------------------------------------------------------- 1 | """Repeated Solver""" 2 | import argparse 3 | import sys 4 | from pathlib import Path 5 | 6 | from ordered_set import OrderedSet 7 | 8 | from chipc.compiler import Compiler 9 | from chipc.utils import compilation_failure 10 | from chipc.utils import compilation_success 11 | from chipc.utils import get_num_pkt_fields 12 | from chipc.utils import get_state_group_info 13 | 14 | 15 | def generate_hole_elimination_assert(hole_assignments): 16 | """Given hole value assignments, {'n_0: 'v_0', 'n_1': 'v_1', ... }, which 17 | failed to verify for larger input bit ranges, generates a single element 18 | string list representing the negation of holes all equal to the values. 19 | This is then passed to sketch file to avoid this specific combination of 20 | hole value assignments.""" 21 | if len(hole_assignments) == 0: 22 | return [] 23 | 24 | # The ! is to ensure a hole combination isn't present. 25 | hole_elimination_string = '!(' 26 | # To make it easier to test, sort the hole value assignments using the hole 27 | # names. 
28 | for idx, (hole, value) in enumerate( 29 | sorted(hole_assignments.items(), key=lambda x: x[0])): 30 | hole_elimination_string += '(' + hole + ' == ' + value + ')' 31 | if idx != len(hole_assignments) - 1: 32 | hole_elimination_string += ' && ' 33 | hole_elimination_string += ')' 34 | return [hole_elimination_string] 35 | 36 | 37 | def set_default_values(pkt_fields, state_vars, num_fields_in_prog, 38 | state_group_info): 39 | """Check if all packet fields and state variables exist in counterexample 40 | dictionaries, pkt_fields and state_vars. If not, set those missing to 0 41 | since they don't really matter. 42 | """ 43 | for i in range(int(num_fields_in_prog)): 44 | field_name = 'pkt_' + str(i) 45 | if field_name not in pkt_fields: 46 | print('Setting value 0 for', field_name) 47 | pkt_fields[field_name] = 0 48 | for group_idx, vs in state_group_info.items(): 49 | for var_idx in vs: 50 | state_var_name = 'state_group_' + group_idx + '_state_' + var_idx 51 | if state_var_name not in state_vars: 52 | print('Setting value 0 for', state_var_name) 53 | state_vars[state_var_name] = 0 54 | return (pkt_fields, state_vars) 55 | 56 | 57 | def generate_counterexample_asserts(pkt_fields, state_vars, num_fields_in_prog, 58 | state_group_info, count, 59 | output_packet_fields, 60 | state_group_to_check, group_size): 61 | counterexample_defs = '' 62 | counterexample_asserts = '' 63 | 64 | counterexample_defs += '|StateAndPacket| x_' + str( 65 | count) + ' = |StateAndPacket|(\n' 66 | for field_name, value in pkt_fields.items(): 67 | counterexample_defs += field_name + ' = ' + str( 68 | value) + ',\n' 69 | 70 | for i, (state_var_name, value) in enumerate(state_vars.items()): 71 | counterexample_defs += state_var_name + ' = ' + str( 72 | value) 73 | if i < len(state_vars) - 1: 74 | counterexample_defs += ',\n' 75 | else: 76 | counterexample_defs += ');\n' 77 | 78 | if output_packet_fields is None and state_group_to_check is None: 79 | counterexample_asserts += 'assert 
(pipeline(' + 'x_' + str( 80 | count) + ')' + ' == ' + 'program(' + 'x_' + str( 81 | count) + '));\n' 82 | elif output_packet_fields is not None and state_group_to_check is None: 83 | # If our spec only cares about specific packet fields, 84 | # the counterexample generated should also care about 85 | # the same packet fields 86 | # For example, we only care about pkt_0 then the counterexample 87 | # assert should be assert(pipeline(x_1).pkt_0 == program(x_1).pkt_0) 88 | 89 | # Same case for stateful_groups except we should care about the size 90 | # of state group 91 | # For example, we only care about the state_group_0 with the size 92 | # of this group to be 2, then the counterexample assert should be 93 | # assert(pipeline(x_1).state_group_0_state_0 == 94 | # program(x_1).state_group_0_state_0) 95 | # assert(pipeline(x_1).state_group_0_state_1 == 96 | # program(x_1).state_group_0_state_1) 97 | for i in range(len(output_packet_fields)): 98 | counterexample_asserts += 'assert (pipeline(' + 'x_' + str( 99 | count) + ').pkt_' + str(output_packet_fields[i]) + \ 100 | ' == ' + 'program(' + 'x_' + str( 101 | count) + ').pkt_' + str(output_packet_fields[i]) + ');\n' 102 | elif output_packet_fields is not None and state_group_to_check is not None: 103 | # counterexample for packet fields 104 | for i in range(len(output_packet_fields)): 105 | counterexample_asserts += 'assert (pipeline(' + 'x_' + str( 106 | count) + ').pkt_' + str(output_packet_fields[i]) + \ 107 | ' == ' + 'program(' + 'x_' + str( 108 | count) + ').pkt_' + str(output_packet_fields[i]) + ');\n' 109 | # counterexample for stateful groups 110 | for i in range(len(state_group_to_check)): 111 | for j in range(group_size): 112 | counterexample_asserts += 'assert (pipeline(' + 'x_' + str( 113 | count) + ').state_group_' + \ 114 | str(state_group_to_check[i]) + \ 115 | '_state_' + str(j) + ' == ' + 'program(' + 'x_' + str( 116 | count) + ').state_group_' + \ 117 | str(state_group_to_check[i]) + '_state_' + 
str(j) + ');\n' 118 | else: 119 | assert output_packet_fields is None 120 | assert state_group_to_check is not None 121 | for i in range(len(state_group_to_check)): 122 | for j in range(group_size): 123 | counterexample_asserts += 'assert (pipeline(' + 'x_' + str( 124 | count) + ').state_group_' + \ 125 | str(state_group_to_check[i]) + \ 126 | '_state_' + str(j) + ' == ' + 'program(' + 'x_' + str( 127 | count) + ').state_group_' + str(state_group_to_check[i]) +\ 128 | '_state_' + str(j) + ');\n' 129 | 130 | return counterexample_defs + counterexample_asserts 131 | 132 | 133 | def main(argv): 134 | parser = argparse.ArgumentParser(description='Iterative solver.') 135 | parser.add_argument( 136 | 'spec_filename', help='Program specification in .sk file') 137 | parser.add_argument('stateful_alu_filename', 138 | help='Stateful ALU file to use.') 139 | parser.add_argument( 140 | 'stateless_alu_filename', help='Stateless ALU file to use.') 141 | parser.add_argument( 142 | 'num_pipeline_stages', type=int, help='Number of pipeline stages') 143 | parser.add_argument( 144 | 'num_alus_per_stage', 145 | type=int, 146 | help='Number of stateless/stateful ALUs per stage') 147 | parser.add_argument( 148 | 'constant_set', 149 | type=str, 150 | help='The content in the constant_set\ 151 | and the format will be like 0,1,2,3\ 152 | and we will calculate the number of\ 153 | comma to get the size of it') 154 | parser.add_argument( 155 | 'max_input_bit', 156 | type=int, 157 | help='The maximum input value in bits') 158 | parser.add_argument( 159 | '--pkt-fields', 160 | type=int, 161 | nargs='+', 162 | help='Packet fields to check correctness') 163 | parser.add_argument( 164 | '--state-groups', 165 | type=int, 166 | nargs='+', 167 | help='State groups to check correctness') 168 | parser.add_argument( 169 | '--input-packet', 170 | type=int, 171 | nargs='+', 172 | help='This is intended to provide user with choice\ 173 | to pick up the packet fields that will \ 174 | influence the packet 
fields/states we\ 175 | want to check correctness for to feed into chipmunk.\ 176 | For example, in example_specs/blue_decrease.sk, \ 177 | packet field pkt_1 only depends on pkt_0, \ 178 | so we can specify --input-packet=0 when we \ 179 | set - -pkt-fields=1.') 180 | parser.add_argument( 181 | '-p', 182 | '--parallel', 183 | action='store_true', 184 | help='Whether to run multiple smaller sketches in parallel by\ 185 | setting salu_config variables explicitly.') 186 | parser.add_argument( 187 | '--parallel-sketch', 188 | action='store_true', 189 | help='Whether sketch process internally uses parallelism') 190 | parser.add_argument( 191 | '--hole-elimination', 192 | action='store_true', 193 | help='If set, add addtional assert statements to sketch, so that we \ 194 | would not see the same combination of hole value assignments.' 195 | ) 196 | parser.add_argument( 197 | '--synthesized-allocation', 198 | action='store_true', 199 | help='If set let sketch allocate state variables otherwise \ 200 | use canonical allocation, i.e, first state variable assigned \ 201 | to first phv container.' 
202 | ) 203 | 204 | args = parser.parse_args(argv[1:]) 205 | # Use program_content to store the program file text rather than using it 206 | # twice 207 | program_content = Path(args.spec_filename).read_text() 208 | num_fields_in_prog = get_num_pkt_fields(program_content) 209 | 210 | # Get the state vars information 211 | # TODO: add the max_input_bit into sketch_name 212 | state_group_info = get_state_group_info(program_content) 213 | sketch_name = args.spec_filename.split('/')[-1].split('.')[0] + \ 214 | '_' + args.stateful_alu_filename.split('/')[-1].split('.')[0] + \ 215 | '_' + args.stateless_alu_filename.split('/')[-1].split('.')[0] + \ 216 | '_' + str(args.num_pipeline_stages) + \ 217 | '_' + str(args.num_alus_per_stage) 218 | 219 | # Get how many members in each state group 220 | # state_group_info is OrderedDict which stores the num of state group and 221 | # the how many stateful vars in each state group 222 | # state_group_info[item] specifies the 223 | # num of stateful vars in each state group 224 | # In our example, each state group has the same size, so we only pick up 225 | # one of them to get the group size 226 | # For example, 227 | # if state_group_info = OrderedDict([('0', OrderedSet(['0', '1']))]) 228 | # state_group_info[item] = OrderedSet(['0', '1']) 229 | # group_size = 2 230 | for item in state_group_info: 231 | group_size = len(state_group_info[item]) 232 | break 233 | 234 | # Use OrderedSet here for deterministic compilation results. We can also 235 | # use built-in dict() for Python versions 3.6 and later, as it's inherently 236 | # ordered. 
237 | constant_set = OrderedSet(args.constant_set.split(',')) 238 | 239 | compiler = Compiler(args.spec_filename, args.stateful_alu_filename, 240 | args.stateless_alu_filename, 241 | args.num_pipeline_stages, args.num_alus_per_stage, 242 | sketch_name, args.parallel_sketch, 243 | constant_set, 244 | args.synthesized_allocation, 245 | args.pkt_fields, args.state_groups, 246 | args.input_packet) 247 | # Repeatedly run synthesis at 2 bits and verification using all valid ints 248 | # until either verification succeeds or synthesis fails at 2 bits. Note 249 | # that the verification with all ints, might not work because sketch only 250 | # considers positive integers. 251 | # Synthesis is much faster at a smaller bit width, while verification needs 252 | # to run at a larger bit width for soundness. 253 | count = 1 254 | hole_elimination_assert = [] 255 | additional_testcases = '' 256 | sol_verify_bit = args.max_input_bit 257 | while 1: 258 | print('Iteration #' + str(count)) 259 | (synthesis_ret_code, output, hole_assignments) = \ 260 | compiler.parallel_codegen( 261 | additional_constraints=hole_elimination_assert, 262 | additional_testcases=additional_testcases) \ 263 | if args.parallel else \ 264 | compiler.serial_codegen( 265 | iter_cnt=count, 266 | additional_constraints=hole_elimination_assert, 267 | additional_testcases=additional_testcases) 268 | 269 | if synthesis_ret_code != 0: 270 | compilation_failure(sketch_name, output) 271 | return 1 272 | 273 | print('Synthesis succeeded with 2 bits, proceeding to verification.') 274 | pkt_fields, state_vars = compiler.verify( 275 | hole_assignments, sol_verify_bit, iter_cnt=count 276 | ) 277 | 278 | if len(pkt_fields) == 0 and len(state_vars) == 0: 279 | compilation_success(sketch_name, hole_assignments, output) 280 | return 0 281 | 282 | print('Verification failed.') 283 | 284 | # NOTE(taegyunkim): There is no harm in using both hole elimination 285 | # asserts and counterexamples. 
We want to compare using only hole 286 | # elimination asserts and only counterexamples. 287 | if args.hole_elimination: 288 | hole_elimination_assert += generate_hole_elimination_assert( 289 | hole_assignments) 290 | print(hole_elimination_assert) 291 | else: 292 | print('Use returned counterexamples', pkt_fields, state_vars) 293 | 294 | # compiler.constant_set will be in the form "0,1,2,3" 295 | 296 | # Get the value of counterexample and add them into constant_set 297 | for _, value in pkt_fields.items(): 298 | value_str = str(value) 299 | constant_set.add(value_str) 300 | for _, value in state_vars.items(): 301 | value_str = str(value) 302 | constant_set.add(value_str) 303 | 304 | # Print the updated constant_array just for debugging 305 | print('updated constant array', constant_set) 306 | 307 | # Add constant set to compiler for next synthesis. 308 | compiler.update_constants_for_synthesis(constant_set) 309 | 310 | pkt_fields, state_vars = set_default_values( 311 | pkt_fields, state_vars, num_fields_in_prog, state_group_info 312 | ) 313 | 314 | additional_testcases += generate_counterexample_asserts( 315 | pkt_fields, state_vars, num_fields_in_prog, state_group_info, 316 | count, args.pkt_fields, args.state_groups, group_size) 317 | 318 | count += 1 319 | 320 | 321 | def run_main(): 322 | sys.exit(main(sys.argv)) 323 | 324 | 325 | if __name__ == '__main__': 326 | run_main() 327 | -------------------------------------------------------------------------------- /chipc/lib/LICENSE.txt: -------------------------------------------------------------------------------- 1 | [The "BSD 3-clause license"] 2 | Copyright (c) 2012-2017 The ANTLR Project. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 
10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software 15 | without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 18 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 19 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 20 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 21 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 22 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 26 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | ===== 29 | 30 | MIT License for codepointat.js from https://git.io/codepointat 31 | MIT License for fromcodepoint.js from https://git.io/vDW1m 32 | 33 | Copyright Mathias Bynens 34 | 35 | Permission is hereby granted, free of charge, to any person obtaining 36 | a copy of this software and associated documentation files (the 37 | "Software"), to deal in the Software without restriction, including 38 | without limitation the rights to use, copy, modify, merge, publish, 39 | distribute, sublicense, and/or sell copies of the Software, and to 40 | permit persons to whom the Software is furnished to do so, subject to 41 | the following conditions: 42 | 43 | The above copyright notice and this permission notice shall be 44 | included in all copies or substantial portions of the Software. 
45 | 46 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 47 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 48 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 49 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 50 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 51 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 52 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 53 | -------------------------------------------------------------------------------- /chipc/lib/antlr-4.7.2-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chipmunk-project/chipmunk/a86eb9030c7accfd552dc4c899fa247193588251/chipc/lib/antlr-4.7.2-complete.jar -------------------------------------------------------------------------------- /chipc/mode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Mode(Enum): 5 | SOL_VERIFY = 2 # Deprecated, no longer in use. 
6 | CODEGEN = 3 7 | VERIFY = 4 8 | 9 | def is_SOL_VERIFY(self): 10 | return self.name == 'SOL_VERIFY' 11 | 12 | def is_CODEGEN(self): 13 | return self.name == 'CODEGEN' 14 | 15 | def is_VERIFY(self): 16 | return self.name == 'VERIFY' 17 | -------------------------------------------------------------------------------- /chipc/sketch_code_generator.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from pathlib import Path 3 | 4 | from antlr4 import CommonTokenStream 5 | from antlr4 import FileStream 6 | 7 | from chipc.aluLexer import aluLexer 8 | from chipc.aluParser import aluParser 9 | from chipc.mode import Mode 10 | from chipc.sketch_stateful_alu_visitor import SketchStatefulAluVisitor 11 | from chipc.sketch_stateless_alu_visitor import SketchStatelessAluVisitor 12 | from chipc.utils import get_hole_bit_width 13 | 14 | 15 | class Hole: 16 | def __init__(self, hole_name, max_value): 17 | self.name = hole_name 18 | self.max = max_value 19 | 20 | 21 | def add_prefix_suffix(text, prefix_string, suffix_string): 22 | return prefix_string + str(text) + suffix_string 23 | 24 | 25 | # Sketch Generator class 26 | class SketchCodeGenerator: 27 | def __init__(self, sketch_name, num_phv_containers, num_state_groups, 28 | num_alus_per_stage, num_pipeline_stages, num_fields_in_prog, 29 | output_packet_fields, output_state_groups, 30 | jinja2_env, stateful_alu_filename, 31 | stateless_alu_filename, constant_set, 32 | synthesized_allocation, input_packet_fields): 33 | self.sketch_name_ = sketch_name 34 | self.total_hole_bits_ = 0 35 | self.hole_names_ = [] 36 | self.hole_preamble_ = '' 37 | self.hole_arguments_ = [] 38 | self.holes_ = [] 39 | self.asserts_ = '' 40 | self.constraints_ = [] 41 | self.num_phv_containers_ = num_phv_containers 42 | self.num_pipeline_stages_ = num_pipeline_stages 43 | self.num_state_groups_ = num_state_groups 44 | self.num_alus_per_stage_ = num_alus_per_stage 45 | 
self.num_fields_in_prog_ = num_fields_in_prog 46 | self.output_packet_fields_ = output_packet_fields 47 | self.output_state_groups_ = output_state_groups 48 | self.jinja2_env_ = jinja2_env 49 | self.jinja2_env_.filters['add_prefix_suffix'] = add_prefix_suffix 50 | self.stateful_alu_filename_ = stateful_alu_filename 51 | self.stateless_alu_filename_ = stateless_alu_filename 52 | # self.constant_arr_def_ will be the form like 53 | # int constant_vector[4] = {0,1,2,3}; 54 | 55 | constant_set_str = '{' + ','.join(constant_set) + '}' 56 | self.constant_arr_def_ = 'int[' + \ 57 | str(len(constant_set)) + \ 58 | '] constant_vector = ' + \ 59 | constant_set_str + \ 60 | ';\n\n' 61 | self.constant_arr_size_ = get_hole_bit_width(len(constant_set)) 62 | self.num_operands_to_stateful_alu_ = 0 63 | self.num_state_slots_ = 0 64 | self.synthesized_allocation_ = synthesized_allocation 65 | self.input_packet_fields_ = input_packet_fields 66 | 67 | def reset_holes_and_asserts(self): 68 | self.total_hole_bits_ = 0 69 | self.hole_names_ = [] 70 | self.hole_preamble_ = '' 71 | self.hole_arguments_ = [] 72 | self.holes_ = [] 73 | self.asserts_ = '' 74 | self.constraints_ = [] 75 | 76 | # Write all holes to a single hole string for ease of debugging 77 | def add_hole(self, hole_name, hole_bit_width): 78 | assert (hole_bit_width >= 0) 79 | self.hole_names_ += [hole_name] 80 | self.hole_preamble_ += 'int ' + hole_name + '= ??(' + str( 81 | hole_bit_width) + ');\n' 82 | self.total_hole_bits_ += hole_bit_width 83 | self.hole_arguments_ += ['int ' + hole_name] 84 | self.holes_ += [Hole(hole_name, 2**hole_bit_width - 1)] 85 | 86 | # Write several holes from a dictionary (new_holes) into self.holes_ 87 | def add_holes(self, new_holes): 88 | for hole in sorted(new_holes): 89 | self.add_hole(hole, new_holes[hole]) 90 | 91 | def add_assert(self, assert_predicate): 92 | self.asserts_ += 'assert(' + assert_predicate + ');\n' 93 | self.constraints_ += [assert_predicate] 94 | 95 | # Generate Sketch 
code for a simple stateless alu (+,-,*,/) 96 | def generate_stateless_alu(self, alu_name, potential_operands): 97 | # Grab the stateless alu file name by using 98 | input_stream = FileStream(self.stateless_alu_filename_) 99 | lexer = aluLexer(input_stream) 100 | stream = CommonTokenStream(lexer) 101 | parser = aluParser(stream) 102 | tree = parser.alu() 103 | 104 | sketch_stateless_alu_visitor = \ 105 | SketchStatelessAluVisitor( 106 | self.stateless_alu_filename_, self.sketch_name_ + '_' + 107 | alu_name, potential_operands, self.generate_mux, 108 | self.constant_arr_size_) 109 | sketch_stateless_alu_visitor.visit(tree) 110 | self.add_holes(sketch_stateless_alu_visitor.global_holes) 111 | self.stateless_alu_hole_arguments_ = [ 112 | x for x in sorted( 113 | sketch_stateless_alu_visitor.stateless_alu_args 114 | )] 115 | 116 | self.num_stateless_muxes_ = \ 117 | len(sketch_stateless_alu_visitor.packet_fields) 118 | 119 | return (sketch_stateless_alu_visitor.helper_function_strings + 120 | sketch_stateless_alu_visitor.main_function) 121 | 122 | # Generate Sketch code for a simple stateful alu (+,-,*,/) 123 | # Takes one state and one packet operand (or immediate operand) as inputs 124 | # Updates the state in place and returns the old value of the state 125 | def generate_stateful_alu(self, alu_name): 126 | input_stream = FileStream(self.stateful_alu_filename_) 127 | lexer = aluLexer(input_stream) 128 | stream = CommonTokenStream(lexer) 129 | parser = aluParser(stream) 130 | tree = parser.alu() 131 | sketch_stateful_alu_visitor = SketchStatefulAluVisitor( 132 | self.sketch_name_ + '_' + alu_name, 133 | self.constant_arr_size_) 134 | sketch_stateful_alu_visitor.visit(tree) 135 | self.add_holes(sketch_stateful_alu_visitor.global_holes) 136 | self.stateful_alu_hole_arguments_ = [ 137 | x for x in sorted(sketch_stateful_alu_visitor.alu_args) 138 | ] 139 | self.num_operands_to_stateful_alu_ = len( 140 | sketch_stateful_alu_visitor.packet_fields) 141 | 
self.num_state_slots_ = len(sketch_stateful_alu_visitor.state_vars) 142 | 143 | return (sketch_stateful_alu_visitor.helper_function_strings + 144 | sketch_stateful_alu_visitor.main_function) 145 | 146 | # This allocator is only used for synthesized allocation 147 | # for stateless_vars 148 | def generate_pkt_field_allocator(self): 149 | for j in range(self.num_phv_containers_): 150 | for k in range(self.num_fields_in_prog_): 151 | self.add_hole( 152 | 'phv_config_' + str(k) + '_' + 153 | str(j), 1) 154 | # add assert for phv_config 155 | # assert sum(field) phv_config_{field}_{container} <= 1 156 | for j in range(self.num_phv_containers_): 157 | assert_predicate = '(' 158 | for k in range(self.num_fields_in_prog_): 159 | assert_predicate += ' phv_config_' + \ 160 | str(k) + '_' + str(j) + '+' 161 | assert_predicate += '0) <= 1' 162 | self.add_assert(assert_predicate) 163 | # assert sum(container) phv_config_{field}_{container} == 1 164 | for k in range(self.num_fields_in_prog_): 165 | assert_predicate = '(' 166 | for j in range(self.num_phv_containers_): 167 | assert_predicate += ' phv_config_' + \ 168 | str(k) + '_' + str(j) + '+' 169 | assert_predicate += '0) == 1' 170 | self.add_assert(assert_predicate) 171 | 172 | # This allocator is only used for synthesized allocation 173 | # for stateful_vars 174 | def generate_state_allocator_synthesized(self): 175 | # stateful_var_allocation_group_1_0_2 means 176 | # group 1 has been allocate to stateful_alu No.2 177 | # in stage 0 178 | for i in range(self.num_state_groups_): 179 | for j in range(self.num_pipeline_stages_): 180 | for k in range(self.num_phv_containers_): 181 | # Add hole_def for stateful_var_allocation_group_ 182 | self.add_hole( 183 | self.sketch_name_ + '_' + 'salu_config_' + str(i) + 184 | '_' + str(j) + '_' + str(k), 1) 185 | 186 | # add assert for stateful_var_allocation_group_ 187 | # any particular group can only be allocated to at most one 188 | # stateful_alu 189 | for i in 
range(self.num_state_groups_): 190 | assert_predicate = '(' 191 | for j in range(self.num_pipeline_stages_): 192 | for k in range(self.num_phv_containers_): 193 | assert_predicate += self.sketch_name_ + '_' + \ 194 | 'salu_config_' + str(i) + '_' + str(j) + '_' + \ 195 | str(k) + '+' 196 | assert_predicate += '0) <= 1' 197 | self.add_assert(assert_predicate) 198 | 199 | # any stateful_alu can only be used by at most one stateful_group 200 | for j in range(self.num_pipeline_stages_): 201 | for k in range(self.num_phv_containers_): 202 | assert_predicate = '(' 203 | for i in range(self.num_state_groups_): 204 | assert_predicate += self.sketch_name_ + '_' + \ 205 | 'salu_config_' + str(i) + '_' + str(j) + '_' + \ 206 | str(k) + '+' 207 | assert_predicate += '0) <= 1' 208 | self.add_assert(assert_predicate) 209 | 210 | def generate_state_allocator_canonicalized(self): 211 | for i in range(self.num_pipeline_stages_): 212 | for l in range(self.num_state_groups_): 213 | self.add_hole( 214 | self.sketch_name_ + '_' + 'salu_config_' + str(i) + '_' + 215 | str(l), 1) 216 | 217 | for i in range(self.num_pipeline_stages_): 218 | assert_predicate = '(' 219 | for l in range(self.num_state_groups_): 220 | assert_predicate += self.sketch_name_ + '_' + \ 221 | 'salu_config_' + str(i) + '_' + str(l) + ' + ' 222 | assert_predicate += '0) <= ' + str(self.num_alus_per_stage_) 223 | self.add_assert(assert_predicate) 224 | 225 | for l in range(self.num_state_groups_): 226 | assert_predicate = '(' 227 | for i in range(self.num_pipeline_stages_): 228 | assert_predicate += self.sketch_name_ + '_' + \ 229 | 'salu_config_' + str(i) + '_' + str(l) + ' + ' 230 | assert_predicate += '0) <= 1' 231 | self.add_assert(assert_predicate) 232 | 233 | # Sketch code for an n-to-1 mux 234 | def generate_mux(self, n, mux_name): 235 | assert (n >= 1) 236 | num_bits = get_hole_bit_width(n) 237 | operand_mux_template = self.jinja2_env_.get_template('mux.j2') 238 | mux_code = operand_mux_template.render( 239 | 
mux_name=mux_name, 240 | operand_list=['input' + str(i) for i in range(0, n)], 241 | arg_list=['int input' + str(i) for i in range(0, n)], 242 | num_operands=n) 243 | self.add_hole(mux_name + '_ctrl', num_bits) 244 | return mux_code 245 | 246 | # Stateful operand muxes, stateless ones are part of generate_stateless_alu 247 | def generate_stateful_operand_muxes(self): 248 | ret = '' 249 | # Generate one mux for inputs: num_phv_containers+1 to 1. The +1 is to 250 | # support constant/immediate operands. 251 | assert (self.num_operands_to_stateful_alu_ > 0) 252 | for i in range(self.num_pipeline_stages_): 253 | # TODO: merge these two into a function later 254 | if self.synthesized_allocation_: 255 | for l in range(self.num_phv_containers_): 256 | for k in range(self.num_operands_to_stateful_alu_): 257 | ret += self.generate_mux( 258 | self.num_phv_containers_, 259 | self.sketch_name_ + '_stateful_alu_' + str(i) + 260 | '_' + str(l) + '_' + 'operand_mux_' + 261 | str(k)) + '\n' 262 | else: 263 | for l in range(self.num_state_groups_): 264 | for k in range(self.num_operands_to_stateful_alu_): 265 | ret += self.generate_mux( 266 | self.num_phv_containers_, 267 | self.sketch_name_ + '_stateful_alu_' + str(i) + 268 | '_' + str(l) + '_' + 'operand_mux_' + str(k)) +\ 269 | '\n' 270 | return ret 271 | 272 | # Output muxes to pick between stateful ALUs and stateless ALU 273 | def generate_output_muxes(self): 274 | # Note: We are generating a mux that takes as input all virtual 275 | # stateful ALUs + corresponding stateless ALU The number of virtual 276 | # stateful ALUs is more or less than the physical stateful ALUs because 277 | # it equals the number of state variables in the original spec, but 278 | # this doesn't affect correctness because we enforce that the total 279 | # number of active virtual stateful ALUs is within the physical limit. 
280 | # It also doesn't affect the correctness of modeling the output mux 281 | # because the virtual output mux setting can be translated into the 282 | # physical output mux setting during post processing. 283 | ret = '' 284 | for i in range(self.num_pipeline_stages_): 285 | for k in range(self.num_phv_containers_): 286 | # synthesized_allocation we give num_phv_containers virtual 287 | # stateful alus per stage 288 | if self.synthesized_allocation_: 289 | ret += self.generate_mux( 290 | self.num_phv_containers_ * self.num_state_slots_ 291 | + 1, 292 | self.sketch_name_ + '_output_mux_phv_' + 293 | str(i) + '_' + str(k)) + '\n' 294 | else: 295 | ret += self.generate_mux( 296 | self.num_state_groups_ * self.num_state_slots_ + 1, 297 | self.sketch_name_ + '_output_mux_phv_' + 298 | str(i) + '_' + str(k)) + '\n' 299 | return ret 300 | 301 | def generate_alus(self): 302 | # Generate sketch code for alus and immediate operands in each stage 303 | ret = '' 304 | for i in range(self.num_pipeline_stages_): 305 | for j in range(self.num_alus_per_stage_): 306 | ret += self.generate_stateless_alu( 307 | 'stateless_alu_' + str(i) + '_' + str(j), [ 308 | 'input' + str(k) 309 | for k in range(0, self.num_phv_containers_) 310 | ]) + '\n' 311 | if self.synthesized_allocation_: 312 | for l in range(self.num_phv_containers_): 313 | ret += self.generate_stateful_alu('stateful_alu_' + 314 | str(i) + 315 | '_' + str(l)) + '\n' 316 | else: 317 | for l in range(self.num_state_groups_): 318 | ret += self.generate_stateful_alu('stateful_alu_' + str(i) 319 | + '_' + str(l)) + '\n' 320 | return ret 321 | 322 | def generate_sketch(self, spec_filename, mode, synthesized_allocation, 323 | additional_constraints=[], 324 | hole_assignments=OrderedDict(), 325 | additional_testcases=''): 326 | self.reset_holes_and_asserts() 327 | assert(mode in [Mode.CODEGEN, Mode.VERIFY]) 328 | template = self.jinja2_env_.get_template('code_generator.j2') 329 | 330 | # Create stateless and stateful ALUs, operand 
muxes for stateful ALUs, 331 | # and output muxes. 332 | alu_definitions = self.generate_alus() 333 | stateful_operand_mux_definitions = ( 334 | self.generate_stateful_operand_muxes()) 335 | output_mux_definitions = self.generate_output_muxes() 336 | 337 | # Create allocator to ensure each state var is assigned to exactly 338 | # stateful ALU and vice versa. 339 | if self.synthesized_allocation_: 340 | self.generate_pkt_field_allocator() 341 | self.generate_state_allocator_synthesized() 342 | else: 343 | self.generate_state_allocator_canonicalized() 344 | 345 | return template.render( 346 | mode=mode, 347 | synthesized_allocation=synthesized_allocation, 348 | sketch_name=self.sketch_name_, 349 | spec_filename=spec_filename, 350 | num_pipeline_stages=self.num_pipeline_stages_, 351 | num_alus_per_stage=self.num_alus_per_stage_, 352 | num_phv_containers=self.num_phv_containers_, 353 | # Add constant_arr_def to hole_definitions 354 | hole_definitions=self.constant_arr_def_ + self.hole_preamble_, 355 | stateful_operand_mux_definitions=stateful_operand_mux_definitions, 356 | num_stateless_muxes=self.num_stateless_muxes_, 357 | output_mux_definitions=output_mux_definitions, 358 | alu_definitions=alu_definitions, 359 | num_fields_in_prog=self.num_fields_in_prog_, 360 | output_packet_fields=self.output_packet_fields_, 361 | output_state_groups=self.output_state_groups_, 362 | num_state_groups=self.num_state_groups_, 363 | spec_as_sketch=Path(spec_filename).read_text(), 364 | all_assertions=self.asserts_, 365 | hole_arguments=self.hole_arguments_, 366 | stateful_alu_hole_arguments=self.stateful_alu_hole_arguments_, 367 | num_operands_to_stateful_alu=self.num_operands_to_stateful_alu_, 368 | num_state_slots=self.num_state_slots_, 369 | additional_constraints='\n'.join( 370 | ['assert(' + str(x) + ');' for x in additional_constraints]), 371 | # Add constant_arr_def to hole_assignments 372 | hole_assignments=self.constant_arr_def_ + '\n'.join( 373 | ['int ' + str(hole) + ' = ' 
+ str(value) + ';' 374 | for hole, value in hole_assignments.items()]), 375 | additional_testcases=additional_testcases, 376 | input_packet_fields=self.input_packet_fields_) 377 | -------------------------------------------------------------------------------- /chipc/sketch_stateful_alu_visitor.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from textwrap import dedent 3 | 4 | from overrides import overrides 5 | 6 | from chipc.aluParser import aluParser 7 | from chipc.aluVisitor import aluVisitor 8 | 9 | 10 | class SketchStatefulAluVisitor(aluVisitor): 11 | def __init__(self, alu_name, constant_arr_size): 12 | self.alu_name = alu_name 13 | self.constant_arr_size = constant_arr_size 14 | self.mux5_count = 0 15 | self.mux4_count = 0 16 | self.mux3_count = 0 17 | self.mux2_count = 0 18 | self.rel_op_count = 0 19 | self.arith_op_count = 0 20 | self.opt_count = 0 21 | self.constant_count = 0 22 | self.compute_alu_count = 0 23 | self.bool_op_count = 0 24 | self.helper_function_strings = '\n\n\n' 25 | self.alu_args = OrderedDict() 26 | self.global_holes = OrderedDict() 27 | self.main_function = '' 28 | self.packet_fields = [] 29 | self.state_vars = [] 30 | 31 | # Copied From Taegyun's code 32 | def add_hole(self, hole_name, hole_width): 33 | prefixed_hole = self.alu_name + '_' + hole_name 34 | assert (prefixed_hole + '_global' not in self.global_holes) 35 | self.global_holes[prefixed_hole + '_global'] = hole_width 36 | assert (hole_name not in self.alu_args) 37 | self.alu_args[hole_name] = hole_width 38 | 39 | @overrides 40 | def visitAlu(self, ctx): 41 | self.main_function += ('int ' + self.alu_name + 42 | '(ref | StateGroup | state_group, ') 43 | 44 | self.visit(ctx.getChild(0, aluParser.Packet_field_defContext)) 45 | 46 | self.visit(ctx.getChild(0, aluParser.State_var_defContext)) 47 | 48 | self.main_function += \ 49 | ', %s) {\n' 50 | 51 | assert len(self.state_vars) > 0 52 | for idx, 
state_var in enumerate(self.state_vars): 53 | self.main_function += '\nint ' + state_var + \ 54 | ' = state_group.state_' + str(idx) + ';' 55 | 56 | self.visit(ctx.getChild(0, aluParser.Alu_bodyContext)) 57 | self.main_function += '\n\n}' 58 | argument_string = ','.join( 59 | ['int ' + hole for hole in sorted(self.alu_args)]) 60 | self.main_function = self.main_function % argument_string 61 | 62 | @overrides 63 | def visitState_var_def(self, ctx): 64 | self.visitChildren(ctx) 65 | 66 | @overrides 67 | def visitState_var_seq(self, ctx): 68 | if ctx.getChildCount() > 0: 69 | self.visitChildren(ctx) 70 | 71 | @overrides 72 | def visitSingleStateVar(self, ctx): 73 | self.visit(ctx.getChild(0, aluParser.State_varContext)) 74 | 75 | @overrides 76 | def visitMultipleStateVars(self, ctx): 77 | self.visitChildren(ctx) 78 | 79 | @overrides 80 | def visitState_var(self, ctx): 81 | state_var_name = ctx.getText() 82 | self.state_vars.append(state_var_name) 83 | 84 | @overrides 85 | def visitPacket_field_def(self, ctx): 86 | self.visit(ctx.getChild(0, aluParser.Packet_field_seqContext)) 87 | 88 | @overrides 89 | def visitPacket_field_seq(self, ctx): 90 | if ctx.getChildCount() > 0: 91 | self.visitChildren(ctx) 92 | 93 | @overrides 94 | def visitSinglePacketField(self, ctx): 95 | self.visitChildren(ctx) 96 | 97 | @overrides 98 | def visitMultiplePacketFields(self, ctx): 99 | self.visit(ctx.getChild(0, aluParser.Packet_fieldContext)) 100 | self.main_function += ', ' 101 | self.visit(ctx.getChild(0, aluParser.Packet_fieldsContext)) 102 | 103 | @overrides 104 | def visitPacket_field(self, ctx): 105 | packet_field_name = ctx.getText() 106 | self.main_function += 'int ' 107 | self.main_function += packet_field_name 108 | self.packet_fields.append(packet_field_name) 109 | 110 | @overrides 111 | def visitVariable(self, ctx): 112 | self.main_function += ctx.getText() 113 | 114 | @overrides 115 | def visitNum(self, ctx): 116 | self.main_function += ctx.getText() 117 | 118 | @overrides 119 
| def visitTemp_var(self, ctx): 120 | self.main_function += ctx.getText() 121 | 122 | @overrides 123 | def visitState_indicator(self, ctx): 124 | pass 125 | 126 | @overrides 127 | def visitReturn_statement(self, ctx): 128 | for idx, state_var in enumerate(self.state_vars): 129 | self.main_function += '\nstate_group.state_' + str( 130 | idx) + ' = ' + state_var + ';' 131 | 132 | self.main_function += 'return ' 133 | self.visit(ctx.getChild(1)) 134 | self.main_function += ';' 135 | 136 | @overrides 137 | def visitCondition_block(self, ctx): 138 | self.main_function += ' (' 139 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 140 | self.main_function += ')\n {' 141 | self.visit(ctx.getChild(0, aluParser.Alu_bodyContext)) 142 | self.main_function += '\n}\n' 143 | 144 | @overrides 145 | def visitStmtIfElseIfElse(self, ctx): 146 | condition_blocks = ctx.condition_block() 147 | 148 | for i, block in enumerate(condition_blocks): 149 | if i != 0: 150 | self.main_function += 'else ' 151 | self.main_function += 'if' 152 | self.visit(block) 153 | 154 | if ctx.else_body is not None: 155 | self.main_function += 'else {\n' 156 | self.visit(ctx.else_body) 157 | self.main_function += '\n}\n' 158 | 159 | @overrides 160 | def visitAnd(self, ctx): 161 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 162 | self.main_function += ' && ' 163 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 164 | 165 | # TODO: Implement (!A) && B, we only support A && (!B), and !B && A. 166 | @overrides 167 | def visitNOT(self, ctx): 168 | self.main_function += ' !' 
169 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 170 | 171 | @overrides 172 | def visitOr(self, ctx): 173 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 174 | self.main_function += ' || ' 175 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 176 | 177 | @overrides 178 | def visitStmtUpdateExpr(self, ctx): 179 | assert ctx.getChild(ctx.getChildCount() - 1).getText() == ';', \ 180 | 'Every update must end with a semicolon.' 181 | self.visit(ctx.getChild(0, aluParser.VariableContext)) 182 | self.main_function += ' = ' 183 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 184 | self.main_function += ';' 185 | 186 | @overrides 187 | def visitStmtUpdateTempInt(self, ctx): 188 | assert ctx.getChild(ctx.getChildCount() - 1).getText() == ';', \ 189 | 'Every update must end with a semicolon.' 190 | self.main_function += ctx.getChild(0).getText() 191 | self.visit(ctx.getChild(0, aluParser.Temp_varContext)) 192 | self.main_function += '=' 193 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 194 | self.main_function += ';' 195 | 196 | @overrides 197 | def visitStmtUpdateTempBit(self, ctx): 198 | assert ctx.getChild(ctx.getChildCount() - 1).getText() == ';', \ 199 | 'Every update must end with a semicolon.' 
200 | self.main_function += ctx.getChild(0).getText() + ' '  # separate type keyword from the temp name (else 'bitx=') 201 | self.visit(ctx.getChild(0, aluParser.Temp_varContext)) 202 | self.main_function += '=' 203 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 204 | self.main_function += ';' 205 | 206 | @overrides 207 | def visitAssertFalse(self, ctx): 208 | self.main_function += ctx.getChild(0).getText() 209 | 210 | @overrides 211 | def visitExprWithOp(self, ctx): 212 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 213 | self.main_function += ctx.getChild(1).getText() 214 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 215 | 216 | @overrides 217 | def visitExprWithParen(self, ctx): 218 | self.main_function += '(' 219 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 220 | self.main_function += ')' 221 | 222 | @overrides 223 | def visitMux5(self, ctx): 224 | self.main_function += self.alu_name + '_' + 'Mux5_' + str( 225 | self.mux5_count) + '(' 226 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 227 | self.main_function += ',' 228 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 229 | self.main_function += ',' 230 | self.visit(ctx.getChild(2, aluParser.ExprContext)) 231 | self.main_function += ',' 232 | self.visit(ctx.getChild(3, aluParser.ExprContext)) 233 | self.main_function += ',' 234 | self.visit(ctx.getChild(4, aluParser.ExprContext)) 235 | self.main_function += ',' + 'Mux5_' + str(self.mux5_count) + ')' 236 | self.generateMux5() 237 | self.mux5_count += 1 238 | 239 | @overrides 240 | def visitMux4(self, ctx): 241 | self.main_function += self.alu_name + '_' + 'Mux4_' + str( 242 | self.mux4_count) + '(' 243 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 244 | self.main_function += ',' 245 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 246 | self.main_function += ',' 247 | self.visit(ctx.getChild(2, aluParser.ExprContext)) 248 | self.main_function += ',' 249 | self.visit(ctx.getChild(3, aluParser.ExprContext)) 250 | self.main_function += ',' + 'Mux4_' + str(self.mux4_count) + ')' 251
| self.generateMux4() 252 | self.mux4_count += 1 253 | 254 | @overrides 255 | def visitMux3(self, ctx): 256 | self.main_function += self.alu_name + '_' + 'Mux3_' + str( 257 | self.mux3_count) + '(' 258 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 259 | self.main_function += ',' 260 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 261 | self.main_function += ',' 262 | self.visit(ctx.getChild(2, aluParser.ExprContext)) 263 | self.main_function += ',' + 'Mux3_' + str(self.mux3_count) + ')' 264 | self.generateMux3() 265 | self.mux3_count += 1 266 | 267 | @overrides 268 | def visitMux3WithNum(self, ctx): 269 | self.main_function += self.alu_name + '_Mux3_' + str( 270 | self.mux3_count) + '(' 271 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 272 | self.main_function += ',' 273 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 274 | self.main_function += ',' + 'Mux3_' + str(self.mux3_count) + ')' 275 | # Here it's the child with index 6. The grammar parse for this 276 | # expression as whole is following, NUM '(' expr ',' expr ',' NUM ')' 277 | # Where NUM is not considered as an expr. Consider parsing NUM as expr 278 | # so we could simply do ctx.getChild(2, stateful_aluParser.ExprContext) 279 | # below. 
280 | self.generateMux3WithNum(ctx.getChild(6).getText()) 281 | self.mux3_count += 1 282 | 283 | @overrides 284 | def visitMux2(self, ctx): 285 | self.main_function += self.alu_name + '_' + 'Mux2_' + str( 286 | self.mux2_count) + '(' 287 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 288 | self.main_function += ',' 289 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 290 | self.main_function += ',' + 'Mux2_' + str(self.mux2_count) + ')' 291 | self.generateMux2() 292 | self.mux2_count += 1 293 | 294 | @overrides 295 | def visitRelOp(self, ctx): 296 | self.main_function += self.alu_name + '_' + 'rel_op_' + str( 297 | self.rel_op_count) + '(' 298 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 299 | self.main_function += ',' 300 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 301 | self.main_function += ',' + 'rel_op_' + \ 302 | str(self.rel_op_count) + ') == 1' 303 | self.generateRelOp() 304 | self.rel_op_count += 1 305 | 306 | @overrides 307 | def visitBoolOp(self, ctx): 308 | self.main_function += self.alu_name + '_' + 'bool_op_' + str( 309 | self.bool_op_count) + '(' 310 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 311 | self.main_function += ',' 312 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 313 | self.main_function += ',' + 'bool_op_' + \ 314 | str(self.bool_op_count) + ') == 1' 315 | self.generateBoolOp() 316 | self.bool_op_count += 1 317 | 318 | @overrides 319 | def visitArithOp(self, ctx): 320 | self.main_function += self.alu_name + '_' + 'arith_op_' + str( 321 | self.arith_op_count) + '(' 322 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 323 | self.main_function += ',' 324 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 325 | self.main_function += ',' + 'arith_op_' + \ 326 | str(self.arith_op_count) + ')' 327 | self.generateArithOp() 328 | self.arith_op_count += 1 329 | 330 | @overrides 331 | def visitOpt(self, ctx): 332 | self.main_function += self.alu_name + '_' + 'Opt_' + str( 333 | self.opt_count) + '(' 
334 | self.visitChildren(ctx) 335 | self.main_function += ',' + 'Opt_' + str(self.opt_count) + ')' 336 | self.generateOpt() 337 | self.opt_count += 1 338 | 339 | @overrides 340 | def visitConstant(self, ctx): 341 | self.main_function += self.alu_name + '_' + 'C_' + str( 342 | self.constant_count) + '(' 343 | self.main_function += 'const_' + str(self.constant_count) + ')' 344 | self.generateConstant() 345 | self.constant_count += 1 346 | 347 | @overrides 348 | def visitComputeAlu(self, ctx): 349 | self.main_function += self.alu_name + '_' + 'compute_alu_' + str( 350 | self.compute_alu_count) + '(' 351 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 352 | self.main_function += ',' 353 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 354 | self.main_function += ',' + 'compute_alu_' + \ 355 | str(self.compute_alu_count) + ')' 356 | self.generateComputeAlu() 357 | self.compute_alu_count += 1 358 | 359 | def generateMux5(self): 360 | function_str = """\ 361 | int {alu_name}_Mux5_{mux5_count}(int op1, int op2, int op3, int op4, int op5, 362 | int opcode) {{ 363 | if (opcode == 0) return op1; 364 | else if (opcode == 1) return op2; 365 | else if (opcode == 2) return op3; 366 | else if (opcode == 3) return op4; 367 | else return op5; 368 | }} 369 | """ 370 | self.helper_function_strings += dedent( 371 | function_str.format( 372 | alu_name=self.alu_name, 373 | mux5_count=str(self.mux5_count))) 374 | self.add_hole('Mux5_' + str(self.mux5_count), 3) 375 | 376 | def generateMux4(self): 377 | function_str = """\ 378 | int {alu_name}_Mux4_{mux4_count}(int op1, int op2, int op3, int op4, 379 | int opcode) {{ 380 | if (opcode == 0) return op1; 381 | else if (opcode == 1) return op2; 382 | else if (opcode == 2) return op3; 383 | else return op4; 384 | }} 385 | """ 386 | self.helper_function_strings += dedent( 387 | function_str.format( 388 | alu_name=self.alu_name, 389 | mux4_count=str(self.mux4_count))) 390 | self.add_hole('Mux4_' + str(self.mux4_count), 2) 391 | 392 | def 
generateMux3(self): 393 | self.helper_function_strings += 'int ' + self.alu_name + '_' + \ 394 | 'Mux3_' + str(self.mux3_count) + \ 395 | """(int op1, int op2, int op3, int choice) { 396 | if (choice == 0) return op1; 397 | else if (choice == 1) return op2; 398 | else return op3; 399 | } \n\n""" 400 | self.add_hole('Mux3_' + str(self.mux3_count), 2) 401 | 402 | def generateMux3WithNum(self, num): 403 | # NOTE: To escape curly brace, use double curly brace. 404 | function_str = """\ 405 | int {0}_Mux3_{1}(int op1, int op2, int choice) {{ 406 | if (choice == 0) return op1; 407 | else if (choice == 1) return op2; 408 | else return {2}; 409 | }}\n 410 | """ 411 | self.helper_function_strings += dedent( 412 | function_str.format(self.alu_name, str(self.mux3_count), 413 | num)) 414 | # Add two bit width hole, to express 3 possible values for choice in 415 | # the above code. 416 | self.add_hole('Mux3_' + str(self.mux3_count), 2) 417 | 418 | def generateMux2(self): 419 | self.helper_function_strings += 'int ' + self.alu_name + '_' + \ 420 | 'Mux2_' + str(self.mux2_count) + \ 421 | """(int op1, int op2, int choice) { 422 | if (choice == 0) return op1; 423 | else return op2; 424 | } \n\n""" 425 | self.add_hole('Mux2_' + str(self.mux2_count), 1) 426 | # TODO: return the member of the vector 427 | 428 | def generateConstant(self): 429 | self.helper_function_strings += 'int ' + self.alu_name + '_' + \ 430 | 'C_' + str(self.constant_count) + """(int const) { 431 | return constant_vector[const]; 432 | }\n\n""" 433 | self.add_hole('const_' + str(self.constant_count), 434 | self.constant_arr_size) 435 | 436 | def generateRelOp(self): 437 | self.helper_function_strings += 'int ' + self.alu_name + '_' + \ 438 | 'rel_op_' + str(self.rel_op_count) + \ 439 | """(int operand1, int operand2, int opcode) { 440 | if (opcode == 0) { 441 | return (operand1 != operand2) ? 1 : 0; 442 | } else if (opcode == 1) { 443 | return (operand1 < operand2) ? 
1 : 0; 444 | } else if (opcode == 2) { 445 | return (operand1 > operand2) ? 1 : 0; 446 | } else { 447 | return (operand1 == operand2) ? 1 : 0; 448 | } 449 | } \n\n""" 450 | self.add_hole('rel_op_' + str(self.rel_op_count), 2) 451 | 452 | def generateBoolOp(self): 453 | function_str = """\ 454 | bit {alu_name}_bool_op_{bool_op_count} (bit op1, bit op2, int opcode) {{ 455 | if (opcode == 0) {{ 456 | return false; 457 | }} else if (opcode == 1) {{ 458 | return ~(op1 || op2); 459 | }} else if (opcode == 2) {{ 460 | return (~op1) && op2; 461 | }} else if (opcode == 3) {{ 462 | return ~op1; 463 | }} else if (opcode == 4) {{ 464 | return op1 && (~op2); 465 | }} else if (opcode == 5) {{ 466 | return ~op2; 467 | }} else if (opcode == 6) {{ 468 | return op1 ^ op2; 469 | }} else if (opcode == 7) {{ 470 | return ~(op1 && op2); 471 | }} else if (opcode == 8) {{ 472 | return op1 && op2; 473 | }} else if (opcode == 9) {{ 474 | return ~(op1 ^ op2); 475 | }} else if (opcode == 10) {{ 476 | return op2; 477 | }} else if (opcode == 11) {{ 478 | return (~op1) || op2; 479 | }} else if (opcode == 12) {{ 480 | return op1; 481 | }} else if (opcode == 13) {{ 482 | return op1 || (~op2); 483 | }} else if (opcode == 14) {{ 484 | return op1 || op2; 485 | }} else {{ 486 | return true; 487 | }} 488 | }}\n 489 | """ 490 | self.helper_function_strings += dedent( 491 | function_str.format( 492 | alu_name=self.alu_name, 493 | bool_op_count=self.bool_op_count 494 | ) 495 | ) 496 | self.add_hole('bool_op_' + str(self.bool_op_count), 4) 497 | 498 | def generateArithOp(self): 499 | self.helper_function_strings += 'int ' + self.alu_name + '_' + \ 500 | 'arith_op_' + str(self.arith_op_count) + \ 501 | """(int operand1, int operand2, int opcode) { 502 | if (opcode == 0) { 503 | return operand1 + operand2; 504 | } else { 505 | return operand1 - operand2; 506 | } 507 | }\n\n""" 508 | self.add_hole('arith_op_' + str(self.arith_op_count), 1) 509 | 510 | def generateOpt(self): 511 | self.helper_function_strings 
+= 'int ' + self.alu_name + '_' + \ 512 | 'Opt_' + str(self.opt_count) + """(int op1, int enable) { 513 | if (enable != 0) return 0; 514 | return op1; 515 | } \n\n""" 516 | self.add_hole('Opt_' + str(self.opt_count), 1) 517 | 518 | def generateComputeAlu(self): 519 | function_str = """\ 520 | int {alu_name}_compute_alu_{compute_alu_count}(int op1, int op2, int opcode) {{ 521 | if (opcode == 0) {{ 522 | return op1 + op2; 523 | }} else if (opcode == 1) {{ 524 | return op1 - op2; 525 | }} else if (opcode == 2) {{ 526 | return op2 - op1; 527 | }} else if (opcode == 3) {{ 528 | return op2; 529 | }} else if (opcode == 4) {{ 530 | return op1; 531 | }} else if (opcode == 5) {{ 532 | return 0; 533 | }} else {{ 534 | return 1; 535 | }} 536 | }}\n""" 537 | self.helper_function_strings += dedent( 538 | function_str.format(alu_name=self.alu_name, 539 | compute_alu_count=str(self.compute_alu_count))) 540 | 541 | self.add_hole('compute_alu_' + str(self.compute_alu_count), 5) 542 | -------------------------------------------------------------------------------- /chipc/sketch_stateless_alu_visitor.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import OrderedDict 3 | 4 | from overrides import overrides 5 | 6 | from chipc.aluParser import aluParser 7 | from chipc.aluVisitor import aluVisitor 8 | from chipc.utils import get_hole_bit_width 9 | 10 | 11 | class SketchStatelessAluVisitor (aluVisitor): 12 | def __init__(self, alu_filename, alu_name, potential_operands, 13 | generate_stateless_mux, constant_arr_size): 14 | # TODO: alu_filename is need to open the alu template file and get the 15 | # maximum possible value of opcode as specified in there. However, this 16 | # can also be handled by the parser and visitor. Remove this argument 17 | # by doing so. 18 | # NOTE: alu_name is stateless_alu prefixed and postfixed with the 19 | # original spec name and pipeline stage, and alu number within that 20 | # stage. 
See find_opcode_bits() 21 | self.alu_filename = alu_filename 22 | self.alu_name = alu_name 23 | self.potential_operands = potential_operands 24 | self.generate_stateless_mux = generate_stateless_mux 25 | self.constant_arr_size = constant_arr_size 26 | self.helper_function_strings = '\n\n\n' 27 | self.global_holes = OrderedDict() 28 | self.stateless_alu_args = OrderedDict() 29 | self.main_function = '' 30 | self.packet_fields = [] 31 | 32 | self.opcode_bits = self.find_opcode_bits() 33 | 34 | def add_hole(self, hole_name, hole_width): 35 | prefixed_hole = self.alu_name + '_' + hole_name 36 | 37 | if 'immediate_operand' in hole_name: 38 | hole_width = self.constant_arr_size 39 | # Validate both registries before mutating either, so a failure leaves no partial state. 40 | assert prefixed_hole not in self.global_holes, prefixed_hole + \ 41 | ' already in global holes' 42 | # stateless_alu_args is keyed by the '_hole_local' name, so that is the key to check. 43 | local_hole = hole_name + '_hole_local' 44 | assert local_hole not in self.stateless_alu_args, local_hole + \ 45 | ' already in stateless ALU arguments' 46 | self.global_holes[prefixed_hole] = hole_width 47 | 48 | self.stateless_alu_args[local_hole] = hole_width 49 | 50 | # Calculates number of bits to set opcode hole to 51 | def find_opcode_bits(self): 52 | with open(self.alu_filename) as f: 53 | first_line = f.readline() 54 | prog = re.compile(r'// Max value of opcode is (\d+)') 55 | result = int(prog.match(first_line).groups(0)[0]) 56 | return get_hole_bit_width(result + 1) 57 | 58 | # Generates the mux ctrl parameters 59 | def write_mux_inputs(self): 60 | for mux_index in range(len(self.packet_fields)): 61 | self.main_function += ', ' 62 | self.main_function += 'int ' + 'operand_mux_' + str(mux_index) + \ 63 | '_ctrl_hole_local' 64 | 65 | # Reassigns hole variable parameters to inputs specified 66 | # in ALU file 67 | def write_temp_hole_vars(self): 68 | for arg in self.stateless_alu_args: 69 | temp_var = arg[:arg.index('_hole_local')] 70 | # Follow the format like this 71 | # int opcode = opcode_hole_local; 72 | # int immediate_operand = \ 73 | #
constant_vector[immediate_operand_hole_local]; 74 | if 'immediate_operand' in temp_var: 75 | self.main_function += '\tint ' + temp_var + \ 76 | ' = ' + 'constant_vector[' + arg + '];\n' 77 | else: 78 | self.main_function += '\tint ' + temp_var + ' = ' + arg + ';\n' 79 | 80 | # Generates code that calls muxes from within stateless ALU 81 | def write_mux_call(self): 82 | mux_index = 0 83 | mux_input_str = '' 84 | for operand in self.potential_operands: 85 | mux_input_str += operand+',' 86 | for i, p in enumerate(self.packet_fields): 87 | assert(i == mux_index) 88 | mux_ctrl = 'operand_mux_' + str(mux_index) + '_ctrl_hole_local' 89 | self.main_function += '\tint ' + p + ' = ' + \ 90 | self.alu_name + '_operand_mux_' + \ 91 | str(mux_index) + '(' + mux_input_str + \ 92 | mux_ctrl + ');\n' 93 | full_name = self.alu_name + '_operand_mux_' + str(mux_index) 94 | self.helper_function_strings += \ 95 | self.generate_stateless_mux( 96 | len(self.potential_operands), full_name) 97 | 98 | mux_index += 1 99 | 100 | @overrides 101 | def visitAlu(self, ctx): 102 | self.visit(ctx.getChild(0, aluParser.State_indicatorContext)) 103 | self.visit(ctx.getChild(0, aluParser.State_var_defContext)) 104 | 105 | self.main_function += 'int ' + self.alu_name + '(' 106 | # Takes in all phv_containers as parameters 107 | for p in self.potential_operands: 108 | self.main_function += 'int ' + p + ',' 109 | # Records packet fields being used (results from the muxes) 110 | self.visit(ctx.getChild(0, aluParser.Packet_field_defContext)) 111 | 112 | # Adds hole variables to parameters 113 | self.visit(ctx.getChild(0, aluParser.Hole_defContext)) 114 | self.write_mux_inputs() 115 | self.main_function += ' %s){\n' 116 | self.write_temp_hole_vars() 117 | self.write_mux_call() 118 | self.visit(ctx.getChild(0, aluParser.Alu_bodyContext)) 119 | self.main_function += '\n}' 120 | 121 | # For additional hole variables (relops, opts, etc) 122 | argument_string = '' 123 | if len(self.stateless_alu_args) > 2: 124 | 
125 | argument_string = ',' + ','.join( 126 | ['int ' + hole for hole in sorted(self.stateless_alu_args)]) 127 | 128 | self.main_function = self.main_function % argument_string 129 | 130 | if self.main_function[-1] == ',': 131 | self.main_function = self.main_function[:-1] 132 | 133 | @overrides 134 | def visitState_indicator(self, ctx): 135 | try: 136 | assert ctx.getChildCount() == 3, 'Error: invalid state' + \ 137 | ' indicator argument provided for type. Insert ' + \ 138 | '\'stateful\' or \'stateless\'' 139 | 140 | assert ctx.getChild(2).getText() == 'stateless', 'Error: ' + \ 141 | 'type is declared as ' + ctx.getChild(2).getText() + \ 142 | ' and not \'stateless\' for stateless ALU ' 143 | 144 | except AssertionError: 145 | raise 146 | 147 | @overrides 148 | def visitHole_def(self, ctx): 149 | self.visit(ctx.getChild(0, aluParser.Hole_seqContext)) 150 | 151 | @overrides 152 | def visitHole_seq(self, ctx): 153 | if ctx.getChildCount() > 0: 154 | self.visitChildren(ctx) 155 | 156 | @overrides 157 | def visitSingleHoleVar(self, ctx): 158 | self.visitChildren(ctx) 159 | 160 | @overrides 161 | def visitMultipleHoleVars(self, ctx): 162 | self.visit(ctx.getChild(0, aluParser.Hole_varContext)) 163 | self.main_function += ', ' 164 | self.visit(ctx.getChild(0, aluParser.Hole_varsContext)) 165 | 166 | @overrides 167 | def visitHole_var(self, ctx): 168 | var_name = ctx.getText() 169 | 170 | num_bits = 2 171 | if var_name == 'opcode': 172 | num_bits = self.opcode_bits 173 | self.add_hole(var_name, num_bits) 174 | self.main_function += 'int ' + var_name + '_hole_local' 175 | 176 | @overrides 177 | def visitPacket_field_def(self, ctx): 178 | self.visit(ctx.getChild(0, aluParser.Packet_field_seqContext)) 179 | 180 | @overrides 181 | def visitPacket_field_seq(self, ctx): 182 | if ctx.getChildCount() > 0: 183 | self.visitChildren(ctx) 184 | 185 | @overrides 186 | def visitSinglePacketField(self, ctx): 187 | self.visitChildren(ctx) 188 | 189 | @overrides 190 | def
visitMultiplePacketFields(self, ctx): 191 | self.visit(ctx.getChild(0, aluParser.Packet_fieldContext)) 192 | self.visit(ctx.getChild(0, aluParser.Packet_fieldsContext)) 193 | 194 | @overrides 195 | def visitPacket_field(self, ctx): 196 | packet_field_name = ctx.getText() 197 | self.packet_fields.append(packet_field_name) 198 | 199 | @overrides 200 | def visitVar(self, ctx): 201 | self.main_function += ctx.getText() 202 | 203 | @overrides 204 | def visitCondition_block(self, ctx): 205 | self.main_function += ' (' 206 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 207 | self.main_function += ')\n {' 208 | self.visit(ctx.getChild(0, aluParser.Alu_bodyContext)) 209 | self.main_function += '\n}\n' 210 | 211 | @overrides 212 | def visitStmtIfElseIfElse(self, ctx): 213 | condition_blocks = ctx.condition_block() 214 | 215 | for i, block in enumerate(condition_blocks): 216 | if i != 0: 217 | self.main_function += 'else ' 218 | self.main_function += 'if' 219 | self.visit(block) 220 | 221 | if ctx.else_body is not None: 222 | self.main_function += 'else {\n' 223 | self.visit(ctx.else_body) 224 | self.main_function += '\n}\n' 225 | 226 | @overrides 227 | def visitStmtUpdateExpr(self, ctx): 228 | assert ctx.getChild(ctx.getChildCount() - 1).getText() == ';', \ 229 | 'Every update must end with a semicolon.' 230 | self.visit(ctx.getChild(0, aluParser.State_varContext)) 231 | self.main_function += ' = ' 232 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 233 | self.main_function += ';' 234 | 235 | @overrides 236 | def visitStmtUpdateTempInt(self, ctx): 237 | assert ctx.getChild(ctx.getChildCount() - 1).getText() == ';', \ 238 | 'Every update must end with a semicolon.' 
239 | self.main_function += ctx.getChild(0).getText() + ' '  # separate type keyword from the temp name 240 | self.main_function += ctx.getChild(0, aluParser.Temp_varContext).getText()  # no visitTemp_var override here, so visit() would emit nothing 241 | self.main_function += '=' 242 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 243 | self.main_function += ';' 244 | 245 | @overrides 246 | def visitStmtUpdateTempBit(self, ctx): 247 | assert ctx.getChild(ctx.getChildCount() - 1).getText() == ';', \ 248 | 'Every update must end with a semicolon.' 249 | self.main_function += ctx.getChild(0).getText() + ' '  # separate type keyword from the temp name 250 | self.main_function += ctx.getChild(0, aluParser.Temp_varContext).getText()  # no visitTemp_var override here, so visit() would emit nothing 251 | self.main_function += '=' 252 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 253 | self.main_function += ';' 254 | 255 | @overrides 256 | def visitReturn_statement(self, ctx): 257 | self.main_function += '\t\treturn ' 258 | self.visit(ctx.getChild(1)) 259 | self.main_function += ';' 260 | 261 | @overrides 262 | def visitEquals(self, ctx): 263 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 264 | self.main_function += '==' 265 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 266 | 267 | @overrides 268 | def visitGreater(self, ctx): 269 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 270 | self.main_function += '>' 271 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 272 | 273 | @overrides 274 | def visitGreaterEqual(self, ctx): 275 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 276 | self.main_function += '>=' 277 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 278 | 279 | @overrides 280 | def visitLess(self, ctx): 281 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 282 | self.main_function += '<' 283 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 284 | 285 | @overrides 286 | def visitLessEqual(self, ctx): 287 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 288 | self.main_function += '<=' 289 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 290 | 291 | @overrides 292 | def visitOr(self, ctx): 293 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 294 | self.main_function +=
'||' 295 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 296 | 297 | @overrides 298 | def visitAnd(self, ctx): 299 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 300 | self.main_function += '&&' 301 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 302 | 303 | @overrides 304 | def visitNotEqual(self, ctx): 305 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 306 | self.main_function += '!=' 307 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 308 | 309 | @overrides 310 | def visitNum(self, ctx): 311 | self.main_function += ctx.getText() 312 | 313 | @overrides 314 | def visitTernary(self, ctx): 315 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 316 | self.main_function += '?' 317 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 318 | self.main_function += ':' 319 | self.visit(ctx.getChild(2, aluParser.ExprContext)) 320 | 321 | @overrides 322 | def visitTrue(self, ctx): 323 | self.main_function += 'true' 324 | 325 | @overrides 326 | def visitExprWithParen(self, ctx): 327 | self.main_function += ctx.getChild(0).getText() 328 | self.visit(ctx.getChild(1)) 329 | 330 | self.main_function += ctx.getChild(2).getText() 331 | 332 | @overrides 333 | def visitExprWithOp(self, ctx): 334 | self.visit(ctx.getChild(0, aluParser.ExprContext)) 335 | self.main_function += ctx.getChild(1).getText() 336 | self.visit(ctx.getChild(1, aluParser.ExprContext)) 337 | 338 | @overrides 339 | def visitMux2(self, ctx): 340 | assert False, 'Unexpected keyword Mux2 in stateless ALU.' 341 | 342 | @overrides 343 | def visitMux3(self, ctx): 344 | assert False, 'Unexpected keyword Mux3 in stateless ALU.' 345 | 346 | @overrides 347 | def visitMux3WithNum(self, ctx): 348 | assert False, 'Unexpected keyword Mux3 in stateless ALU.' 349 | 350 | @overrides 351 | def visitOpt(self, ctx): 352 | assert False, 'Unexpected keyword Opt in stateless ALU.' 353 | 354 | @overrides 355 | def visitRelOp(self, ctx): 356 | assert False, 'Unexpected keyword rel_op in stateless ALU.' 
357 | 358 | @overrides 359 | def visitArithOp(self, ctx): 360 | assert False, 'Unexpected keyword arith_op in stateless ALU.' 361 | 362 | @overrides 363 | def visitConstant(self, ctx): 364 | assert False, 'Unexpected keyword C() in stateless ALU.' 365 | -------------------------------------------------------------------------------- /chipc/sketch_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import subprocess 3 | from pathlib import Path 4 | 5 | SLV_TIMEOUT_MINS = 0.1 6 | 7 | 8 | def check_syntax(sketch_file_name): 9 | # Check syntax of given sketch file. 10 | (return_code, output) = subprocess.getstatusoutput( 11 | 'sketch ' + sketch_file_name + ' --slv-timeout=0.001') 12 | if (return_code == 0): 13 | print(sketch_file_name + ' passed syntax check. ') 14 | assert(output.rfind('Program Parse Error:') == -1) 15 | else: 16 | if (output.rfind('Program Parse Error:') != -1): 17 | raise Exception( 18 | sketch_file_name + ' contains a syntax error.' + 19 | 'Output pasted below:\n\n' + output) 20 | 21 | 22 | def synthesize(sketch_file_name, bnd_inbits, slv_seed, slv_parallel=False): 23 | assert(slv_parallel in [True, False]) 24 | check_syntax(sketch_file_name) 25 | par_string = ' --slv-parallel' if slv_parallel else '' 26 | # Consider switching to subprocess.run as subprocess.getstatusoutput is 27 | # considered legacy. 
28 | # https://docs.python.org/3.5/library/subprocess.html#legacy-shell-invocation-functions 29 | (return_code, output) = subprocess.getstatusoutput('time sketch -V 12 ' + 30 | '--slv-nativeints ' + 31 | sketch_file_name + 32 | ' --bnd-inbits=' + 33 | str(bnd_inbits) + 34 | ' --slv-seed=' + 35 | str(slv_seed) + 36 | par_string) 37 | assert(output.rfind('Program Parse Error:') == -1) 38 | return (return_code, output) 39 | 40 | 41 | def generate_smt2_formula(sketch_file_name, smt_file_name, bit_range): 42 | check_syntax(sketch_file_name) 43 | (return_code, output) = subprocess.getstatusoutput('sketch ' + 44 | sketch_file_name + 45 | ' --bnd-inbits=' + 46 | str(bit_range) + 47 | ' --slv-timeout=' + 48 | str(SLV_TIMEOUT_MINS) + 49 | ' --beopt:writeSMT ' + 50 | smt_file_name) 51 | 52 | 53 | def generate_ir(sketch_file_name): 54 | """Given a sketch file, returns its IR (intermediate representation). 55 | 56 | This function calls sketch and generates a .dag file having IR for the 57 | sketch file. Then reads the .dag file and returns its content.""" 58 | check_syntax(sketch_file_name) 59 | # Swap a trailing '.sk' for '.dag' (or append '.dag') so the dag path never equals the input path; count=1 stops a second, empty end-of-string match. 60 | dag_file_name = re.sub(r'(\.sk)?$', '.dag', sketch_file_name, count=1) 61 | subprocess.run([ 62 | 'sketch', 63 | '-V', '3', 64 | sketch_file_name, 65 | '--debug-output-dag', dag_file_name, 66 | # We only want the dag and sketch output is irrelevant here. So quickly 67 | # return from it using --slv-timeout. 68 | '--slv-seed', '1', 69 | '--slv-timeout', str(SLV_TIMEOUT_MINS) 70 | ], 71 | # Pipe stdout and stderr to /dev/null as we don't need them.
72 | stdout=subprocess.DEVNULL, 73 | stderr=subprocess.DEVNULL) 74 | 75 | return Path(dag_file_name).read_text() 76 | -------------------------------------------------------------------------------- /chipc/templates/code_generator.j2: -------------------------------------------------------------------------------- 1 | // This is an autogenerated sketch file corresponding to 2 | // the router's data path and is used to solve the Chipmunk compilation problem. 3 | // spec_filename = {{ spec_filename }} num_pipeline_stages = {{ num_pipeline_stages }} 4 | // num_alus_per_stage = {{ num_alus_per_stage }} 5 | // num_phv_containers = {{ num_phv_containers }} 6 | 7 | {% if mode.is_CODEGEN() %} 8 | {{ hole_definitions }} 9 | {% elif mode.is_SOL_VERIFY() or mode.is_VERIFY() %} 10 | {{ hole_assignments }} 11 | {% endif %} 12 | 13 | // Definitions of muxes and ALUs of the router 14 | {% include "muxes_and_alus.j2" %} 15 | 16 | struct StateGroup { 17 | {% for slot_number in range(num_state_slots) %} 18 | int state_{{slot_number}}; 19 | {% endfor %} 20 | } 21 | 22 | // Data type for holding result from spec and implementation 23 | struct StateAndPacket { 24 | {% for field_number in range(num_fields_in_prog) %} 25 | int pkt_{{field_number}}; 26 | {% endfor %} 27 | {% for group_number in range(num_state_groups) %} 28 | {% for slot_number in range(num_state_slots) %} 29 | int state_group_{{group_number}}_state_{{slot_number}}; 30 | {% endfor %} 31 | {% endfor %} 32 | } 33 | 34 | // Specification 35 | {{spec_as_sketch}} 36 | 37 | // Implementation 38 | {% include "router_data_path_sketch.j2" %} 39 | -------------------------------------------------------------------------------- /chipc/templates/mux.j2: -------------------------------------------------------------------------------- 1 | int {{mux_name}}({{arg_list|join(',')}}, int {{mux_name}}_ctrl_local) { 2 | {% if num_operands == 1 %} 3 | return {{operand_list[0]}}; 4 | {% else %} 5 | int mux_ctrl = {{mux_name}}_ctrl_local; 6 | if 
(mux_ctrl == 0) { 7 | return {{operand_list[0]}}; 8 | } 9 | 10 | {% for operand_number in range(1, num_operands - 1) %} 11 | else if (mux_ctrl == {{operand_number}}) { 12 | return {{operand_list[operand_number]}}; 13 | } 14 | {% endfor %} 15 | 16 | else { return {{operand_list[num_operands - 1]}}; } 17 | {% endif %} 18 | } 19 | -------------------------------------------------------------------------------- /chipc/templates/muxes_and_alus.j2: -------------------------------------------------------------------------------- 1 | // Operand muxes for each ALU in each stage 2 | // Total of {{ num_pipeline_stages }} * {{ num_alus_per_stage }} * 3 {{num_phv_containers}}-to-1 muxes 3 | // The 3 is for two stateless operands and one stateful operand. 4 | 5 | {{ stateful_operand_mux_definitions }} 6 | 7 | // Output mux for each PHV container 8 | // Allows the container to be written from either its own stateless ALU or any stateful ALU 9 | 10 | {{ output_mux_definitions }} 11 | 12 | // Definition for ALUs 13 | 14 | {{ alu_definitions }} 15 | -------------------------------------------------------------------------------- /chipc/templates/opt_verify.j2: -------------------------------------------------------------------------------- 1 | // Data type for holding result from spec and implementation 2 | struct StateGroup { 3 | {% for state_number in range(num_state_slots) %} 4 | int state_{{state_number}}; 5 | {% endfor %} 6 | } 7 | 8 | // Data type for holding result from both sketches 9 | struct StateAndPacket { 10 | {% for field_number in range(num_fields_in_prog) %} 11 | int pkt_{{field_number}}; 12 | {% endfor %} 13 | {% for group_number in range(num_state_groups) %} 14 | {% for slot_number in range(num_state_slots) %} 15 | int state_group_{{group_number}}_state_{{slot_number}}; 16 | {% endfor %} 17 | {% endfor %} 18 | } 19 | 20 | include "{{sketch1_file_name}}"; 21 | include "{{sketch2_file_name}}"; 22 | 23 | harness void main( 24 | {% for field_number in 
range(num_fields_in_prog) %} 25 | int sanp_pkt_{{field_number}}, 26 | {% endfor %} 27 | {% for state_group_number in range(num_state_groups) %} 28 | {% for slot_number in range(num_state_slots) %} 29 | int sanp_state_group_{{state_group_number}}_{{slot_number}} , 30 | {% endfor %} 31 | {% endfor %} 32 | {{hole1_arguments|join(',')}}, 33 | {{hole2_arguments|join(',')}}) { 34 | 35 | // Preconditions: 36 | 37 | {% for hole in sketch1_holes %} 38 | assume(({{hole.name}} >= 0) && ({{hole.name}} <= {{hole.max}})); 39 | {% endfor %} 40 | 41 | {% for predicate in sketch1_asserts %} 42 | assume({{predicate}}); 43 | {% endfor %} 44 | 45 | // Transformation from sketch 1's holes to sketch 2's holes 46 | {{ transform_function }} 47 | 48 | // Check that sketches are equivalent. 49 | StateAndPacket state_and_packet = new StateAndPacket(); 50 | {% for field_number in range(num_fields_in_prog) %} 51 | state_and_packet.pkt_{{field_number}} = sanp_pkt_{{field_number}}; 52 | {% endfor %} 53 | {% for state_group_number in range(num_state_groups) %} 54 | {% for slot_number in range(num_state_slots) %} 55 | state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}} = sanp_state_group_{{state_group_number}}_{{slot_number}}; 56 | {% endfor %} 57 | {% endfor %} 58 | 59 | assert(pipeline@{{sketch1_name}}(state_and_packet, {{sketch1_holes|map(attribute='name')|join(',')}}) == pipeline@{{sketch2_name}}(state_and_packet, {{sketch2_holes|map(attribute='name')|join(',')}})); 60 | 61 | // Postconditions: 62 | 63 | {% for hole in sketch2_holes %} 64 | assert(({{hole.name}} >= 0) && ({{hole.name}} <= {{hole.max}})); 65 | {% endfor %} 66 | 67 | {% for predicate in sketch2_asserts %} 68 | assert({{predicate}}); 69 | {% endfor %} 70 | } 71 | -------------------------------------------------------------------------------- /chipc/templates/router_data_path_sketch.j2: -------------------------------------------------------------------------------- 1 | |StateAndPacket| pipeline 
(|StateAndPacket| state_and_packet) { 2 | // Any additional constraints to speed up synthesis through parallel execution. 3 | {{additional_constraints}} 4 | 5 | // Consolidate all constraints on holes here. 6 | {{all_assertions}} 7 | 8 | // One variable for each container in the PHV 9 | // Container i will be allocated to packet field i from the spec. 10 | {% for container_number in range(num_phv_containers) %} 11 | int input_0_{{container_number}} = 0; 12 | {% endfor %} 13 | 14 | // One variable for each stateful ALU's state operand 15 | // This will be allocated to a state variable from the program using indicator variables. 16 | {% for stage_number in range(num_pipeline_stages) %} 17 | {% if synthesized_allocation %} 18 | {% for container_number in range(num_phv_containers) %} 19 | |StateGroup| state_operand_salu_{{stage_number}}_{{container_number}} = |StateGroup|( 20 | {% for slot_number in range(num_state_slots-1) %} 21 | state_{{slot_number}} = 0, 22 | {% endfor %} 23 | state_{{num_state_slots-1}} = 0 24 | ); 25 | {% endfor %} 26 | {% else %} 27 | {% for state_group_number in range(num_state_groups) %} 28 | |StateGroup| state_operand_salu_{{stage_number}}_{{state_group_number}} = |StateGroup|( 29 | {% for slot_number in range(num_state_slots-1) %} 30 | state_{{slot_number}} = 0, 31 | {% endfor %} 32 | state_{{num_state_slots-1}} = 0 33 | ); 34 | {% endfor %} 35 | {% endif %} 36 | {% endfor %} 37 | 38 | {% for stage_number in range(num_pipeline_stages) %} 39 | /*********** Stage {{stage_number}} *********/ 40 | 41 | // Inputs 42 | {% if stage_number == 0 %} 43 | // Read each PHV container from corresponding packet field. 
44 | {% for field_number in input_packet_fields %} 45 | {% if synthesized_allocation %} 46 | // TODO: deal with the case where not all packets are fed as input 47 | {% for container_number in range(num_phv_containers) %} 48 | if (phv_config_{{field_number}}_{{container_number}} == 1){ 49 | input_0_{{container_number}} = state_and_packet.pkt_{{field_number}}; 50 | } 51 | {% endfor %} 52 | {% else %} 53 | // loop.index starts from 1 that's why we need to -1 54 | input_0_{{loop.index - 1}} = state_and_packet.pkt_{{field_number}}; 55 | {% endif %} 56 | {% endfor %} 57 | 58 | {% else %} 59 | // Input of this stage is the output of the previous one. 60 | {% for container_number in range(num_phv_containers) %} 61 | int input_{{stage_number}}_{{container_number}} = output_{{stage_number-1}}_{{container_number}}; 62 | {% endfor %} 63 | 64 | {% endif %} 65 | 66 | // Stateless ALUs 67 | {% for alu_number in range(num_alus_per_stage) %} 68 | int destination_{{stage_number}}_{{alu_number}} = {{sketch_name}}_stateless_alu_{{stage_number}}_{{alu_number}}( 69 | {% for container_number in range(num_phv_containers) %} 70 | {% if container_number != num_phv_containers - 1 %} 71 | input_{{stage_number}}_{{container_number}}, 72 | {% else %} 73 | input_{{stage_number}}_{{container_number}} 74 | {% endif %} 75 | {% endfor %}, 76 | 77 | {{sketch_name}}_stateless_alu_{{stage_number}}_{{alu_number}}_opcode, 78 | {{sketch_name}}_stateless_alu_{{stage_number}}_{{alu_number}}_immediate_operand, 79 | {% for mux_num in range (num_stateless_muxes) %} 80 | {{sketch_name}}_stateless_alu_{{stage_number}}_{{alu_number}}_operand_mux_{{mux_num}}_ctrl{% if not loop.last %},{% endif %} 81 | {% endfor %} 82 | ); 83 | {% endfor %} 84 | 85 | // Stateful operands 86 | {% if synthesized_allocation %} 87 | {% for stateful_container_num in range(num_phv_containers) %} 88 | {% for operand_number in range(num_operands_to_stateful_alu) %} 89 | int 
packet_operand_salu{{stage_number}}_{{stateful_container_num}}_{{operand_number}} = {{sketch_name}}_stateful_alu_{{stage_number}}_{{stateful_container_num}}_operand_mux_{{operand_number}}( 90 | {% for container_number in range(num_phv_containers) %} 91 | {% if container_number != num_phv_containers - 1 %} 92 | input_{{stage_number}}_{{container_number}}, 93 | {% else %} 94 | input_{{stage_number}}_{{container_number}} 95 | {% endif %} 96 | {% endfor %} 97 | , {{sketch_name}}_stateful_alu_{{stage_number}}_{{stateful_container_num}}_operand_mux_{{operand_number}}_ctrl); 98 | {% endfor %} 99 | {% endfor %} 100 | {% else %} 101 | {% for state_group_number in range(num_state_groups) %} 102 | {% for operand_number in range(num_operands_to_stateful_alu) %} 103 | int packet_operand_salu{{stage_number}}_{{state_group_number}}_{{operand_number}} = {{sketch_name}}_stateful_alu_{{stage_number}}_{{state_group_number}}_operand_mux_{{operand_number}}( 104 | {% for container_number in range(num_phv_containers) %} 105 | {% if container_number != num_phv_containers - 1 %} 106 | input_{{stage_number}}_{{container_number}}, 107 | {% else %} 108 | input_{{stage_number}}_{{container_number}} 109 | {% endif %} 110 | {% endfor %} 111 | , {{sketch_name}}_stateful_alu_{{stage_number}}_{{state_group_number}}_operand_mux_{{operand_number}}_ctrl); 112 | {% endfor %} 113 | {% endfor %} 114 | {% endif %} 115 | 116 | // Read stateful ALU slots from allocated state vars. 
117 | {% if synthesized_allocation %} 118 | {% for state_group_number in range(num_state_groups) %} 119 | {% for container_number in range(num_phv_containers) %} 120 | if ({{sketch_name}}_salu_config_{{state_group_number}}_{{stage_number}}_{{container_number}} == 1){ 121 | state_operand_salu_{{stage_number}}_{{container_number}} = 122 | |StateGroup|({% for slot_number in range(num_state_slots) %} 123 | {% if slot_number < num_state_slots - 1 %} 124 | state_{{slot_number}} = state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}}, 125 | {% else %} 126 | state_{{slot_number}} = state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}} 127 | {% endif %} 128 | {%endfor%}); 129 | } 130 | {% endfor %} 131 | {% endfor %} 132 | {% else %} 133 | {% for state_group_number in range(num_state_groups) %} 134 | if ({{sketch_name}}_salu_config_{{stage_number}}_{{state_group_number}} == 1) { 135 | state_operand_salu_{{stage_number}}_{{state_group_number}} = 136 | |StateGroup|({% for slot_number in range(num_state_slots) %} 137 | {% if slot_number < num_state_slots - 1 %} 138 | state_{{slot_number}} = state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}}, 139 | {% else %} 140 | state_{{slot_number}} = state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}} 141 | {% endif %} 142 | {%endfor%});} 143 | {% endfor %} 144 | {% endif %} 145 | 146 | // Stateful ALUs 147 | // TODO: maybe we need to combine the following if-else branch together because they share a lot of common things 148 | {% if synthesized_allocation %} 149 | {% for container_number in range(num_phv_containers) %} 150 | int returned_state_{{stage_number}}_{{container_number}} = {{sketch_name}}_stateful_alu_{{stage_number}}_{{container_number}}(state_operand_salu_{{stage_number}}_{{container_number}}, 151 | {% for operand_number in range(num_operands_to_stateful_alu) %} 152 | packet_operand_salu{{stage_number}}_{{container_number}}_{{operand_number}}, 
153 | {% endfor %} 154 | {% set prefix_string = sketch_name + "_stateful_alu_" + (stage_number|string) + "_" + (container_number|string) + "_" %} 155 | {{stateful_alu_hole_arguments|map('add_prefix_suffix', prefix_string, "_global")|join(',')}}); 156 | {% endfor %} 157 | {% else %} 158 | {% for state_group_number in range(num_state_groups) %} 159 | int returned_state_{{stage_number}}_{{state_group_number}} = {{sketch_name}}_stateful_alu_{{stage_number}}_{{state_group_number}}(state_operand_salu_{{stage_number}}_{{state_group_number}}, 160 | {% for operand_number in range(num_operands_to_stateful_alu) %} 161 | packet_operand_salu{{stage_number}}_{{state_group_number}}_{{operand_number}}, 162 | {% endfor %} 163 | {% set prefix_string = sketch_name + "_stateful_alu_" + (stage_number|string) + "_" + (state_group_number|string) + "_" %} 164 | {{stateful_alu_hole_arguments|map('add_prefix_suffix', prefix_string, "_global")|join(',')}}); 165 | {% endfor %} 166 | {% endif %} 167 | 168 | // Outputs 169 | {% for container_number in range(num_phv_containers) %} 170 | int output_{{stage_number}}_{{container_number}} = {{sketch_name}}_output_mux_phv_{{stage_number}}_{{container_number}}( 171 | {% if synthesized_allocation %} 172 | {% for container_number in range(num_phv_containers) %} 173 | {% for state_slot in range(num_state_slots) %} 174 | returned_state_{{stage_number}}_{{container_number}}, 175 | {% endfor %} 176 | {% endfor %} 177 | {% else %} 178 | {% for state_group_number in range(num_state_groups) %} 179 | {% for state_slot in range(num_state_slots) %} 180 | returned_state_{{stage_number}}_{{state_group_number}}, 181 | {% endfor %} 182 | {% endfor %} 183 | {% endif %} 184 | destination_{{stage_number}}_{{container_number}}, 185 | {{sketch_name}}_output_mux_phv_{{stage_number}}_{{container_number}}_ctrl 186 | ); 187 | {% endfor %} 188 | 189 | {% for state_group_number in range(num_state_groups) %} 190 | // Write stateful_vars 191 | {% if synthesized_allocation %} 192 
| {% for container_number in range(num_phv_containers) %} 193 | if ({{sketch_name}}_salu_config_{{state_group_number}}_{{stage_number}}_{{container_number}} == 1){ 194 | {% for slot_number in range(num_state_slots) %} 195 | state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}} = state_operand_salu_{{stage_number}}_{{container_number}}.state_{{slot_number}}; 196 | {% endfor %} 197 | } 198 | {% endfor %} 199 | {% else %} 200 | if ({{sketch_name}}_salu_config_{{stage_number}}_{{state_group_number}} == 1) { 201 | {% for slot_number in range(num_state_slots) %} 202 | state_and_packet.state_group_{{state_group_number}}_state_{{slot_number}} = state_operand_salu_{{stage_number}}_{{state_group_number}}.state_{{slot_number}}; 203 | {% endfor %} 204 | } 205 | {% endif %} 206 | {% endfor %} 207 | {% endfor %} 208 | 209 | {% for field_number in output_packet_fields %} 210 | // Write pkt_{{field_number}} 211 | {% if synthesized_allocation %} 212 | {% for container_number in range(num_phv_containers) %} 213 | if (phv_config_{{field_number}}_{{container_number}} == 1){ 214 | state_and_packet.pkt_{{field_number}} = output_{{num_pipeline_stages - 1}}_{{container_number}}; 215 | } 216 | {% endfor %} 217 | {% else%} 218 | state_and_packet.pkt_{{field_number}} = output_{{num_pipeline_stages - 1}}_{{loop.index - 1}}; 219 | {% endif %} 220 | {% endfor %} 221 | 222 | // Return updated packet fields and state vars 223 | return state_and_packet; 224 | } 225 | 226 | {% if mode.is_CODEGEN() or mode.is_VERIFY() or mode.is_SOL_VERIFY() %} 227 | harness void main( 228 | {{range(num_fields_in_prog)|map('add_prefix_suffix', "int pkt_", "")|join(',')}} 229 | {% for state_group_number in range(num_state_groups) %} 230 | {% for slot_number in range(num_state_slots) %} 231 | , int state_group_{{state_group_number}}_state_{{slot_number}} 232 | {% endfor %} 233 | {% endfor %}) { 234 | 235 | |StateAndPacket| x = |StateAndPacket|({% for field_number in range(num_fields_in_prog) %} 236 
| pkt_{{field_number}} = pkt_{{field_number}}, 237 | {% endfor %} 238 | {% for state_group_number in range(num_state_groups) %} 239 | {% for slot_number in range(num_state_slots) %} 240 | {% if (state_group_number == num_state_groups - 1) and (slot_number == num_state_slots - 1) %} 241 | state_group_{{state_group_number}}_state_{{slot_number}} = state_group_{{state_group_number}}_state_{{slot_number}} 242 | {% else %} 243 | state_group_{{state_group_number}}_state_{{slot_number}} = state_group_{{state_group_number}}_state_{{slot_number}}, 244 | {% endif %} 245 | {% endfor %} 246 | {% endfor %}); 247 | 248 | |StateAndPacket| pipeline_result = pipeline(x); 249 | |StateAndPacket| program_result = program(x); 250 | 251 | {% for state_group_number in output_state_groups %} 252 | {% for slot_number in range(num_state_slots) %} 253 | assert(pipeline_result.state_group_{{state_group_number}}_state_{{slot_number}} 254 | == program_result.state_group_{{state_group_number}}_state_{{slot_number}}); 255 | {% endfor %} 256 | {% endfor %} 257 | 258 | {% for field_number in output_packet_fields %} 259 | assert(pipeline_result.pkt_{{field_number}} == program_result.pkt_{{field_number}}); 260 | {% endfor %} 261 | 262 | {{additional_testcases}} 263 | 264 | } 265 | {% endif %} 266 | -------------------------------------------------------------------------------- /chipc/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for Chipmunk""" 2 | import math 3 | from collections import OrderedDict 4 | from pathlib import Path 5 | from re import findall 6 | 7 | from ordered_set import OrderedSet 8 | 9 | 10 | def get_state_group_info(program): 11 | """Returns a dictionary from state group indices to set of state variables 12 | indices. 
13 | For state_group_0_state_1, the dict will have an entry {0: set(1)}""" 14 | 15 | state_group_info = OrderedDict() 16 | for i, j in findall( 17 | r'state_and_packet.state_group_(\d+)_state_(\d+)', program): 18 | indices = state_group_info.get(i, OrderedSet()) 19 | indices.add(j) 20 | state_group_info[i] = indices 21 | 22 | return state_group_info 23 | 24 | 25 | def get_num_pkt_fields(program): 26 | """Returns number of packet fields in the program. 27 | If this returns n, packet field indices are from 0 to n-1. 28 | """ 29 | pkt_fields = set() 30 | for x in findall(r'state_and_packet.pkt_(\d+)', program): 31 | pkt_fields.add(int(x)) 32 | 33 | return len(pkt_fields) 34 | 35 | 36 | def get_hole_dicts(sketch): 37 | """Returns a dictionary from hole names to hole bit sizes given a sketch. 38 | """ 39 | return { 40 | name: bits 41 | for name, bits in findall(r'(\w+)= \?\?\((\d+)\);', sketch) 42 | } 43 | 44 | 45 | def get_hole_value_assignments(hole_names, sketch): 46 | """Returns a dictionary from hole names to hole value assignments given a 47 | list of hole names and a completed sketch. 48 | """ 49 | 50 | holes_to_values = {} 51 | 52 | for name in hole_names: 53 | values = findall('' + name + '__' + r'\w+ = (\d+)', sketch) 54 | assert len(values) == 1, ( 55 | 'Unexpected number of assignment statements found for hole %s, ' 56 | 'with values %s' % (name, values)) 57 | holes_to_values[name] = values[0] 58 | 59 | return holes_to_values 60 | 61 | 62 | def get_hole_bit_width(k: int) -> int: 63 | """Returns the number of bits to represent k number of different options 64 | in binary.""" 65 | 66 | return math.ceil(math.log2(k)) 67 | 68 | 69 | def compilation_success(sketch_name, hole_assignments, output): 70 | print('Compilation succeeded. 
Hole value assignments are following:') 71 | for hole, value in sorted(hole_assignments.items()): 72 | print('int', hole, '=', value) 73 | p = Path(sketch_name + '.success') 74 | p.write_text(output) 75 | print('Output left in', p.name) 76 | 77 | 78 | def compilation_failure(sketch_name, output): 79 | print('Compilation failed.') 80 | p = Path(sketch_name + '.errors') 81 | p.write_text(output) 82 | print('Output left in', p.name) 83 | -------------------------------------------------------------------------------- /chipc/z3_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import re 3 | 4 | import z3 5 | 6 | 7 | def parse_smt2_file(smt2_filename): 8 | """Reads a smt2 file and returns the first formula. 9 | 10 | Args: 11 | smt2_filename: smt2 file that was generated from Sketch. 12 | 13 | Raises: 14 | An assertion if the original smt2 file didn't contain any assert 15 | statements. 16 | """ 17 | # parse_smt2_file returns a vector of ASTs, and each element corresponds to 18 | # one assert statement in the original file. The smt2 file generated by 19 | # sketch only has one assertions, simply take the first. 20 | formulas = z3.parse_smt2_file(smt2_filename) 21 | assert len(formulas) == 1, (smt2_filename, 22 | 'contains 0 or more than 1 asserts.') 23 | return formulas[0] 24 | 25 | 26 | def negated_body(formula): 27 | """Given a z3.QuantiferRef formula with z3.Int variables, 28 | return negation of the body. 29 | 30 | Returns: 31 | A z3.BoolRef which is the negation of the formula body. 32 | """ 33 | assert z3.is_quantifier(formula), ('Formula is not a quantifier:\n', 34 | formula) 35 | var_names = [formula.var_name(i) for i in range(formula.num_vars())] 36 | vs = [z3.Int(n) for n in var_names] 37 | 38 | # Here simply doing z3.Not(formula.body()) doesn't work. 
formula.body() 39 | # returns an expression without any bounded variable, i.e., it refers 40 | # variables using indices in the order they appear instead of its names, 41 | # Var(0), Var(1), Var(2). Thus, we have to re-bind the variables using 42 | # substitute_vars. It is also necessary to reverse the list of variables. 43 | # See https://github.com/Z3Prover/z3/issues/402 for more details. 44 | return z3.Not(z3.substitute_vars(formula.body(), *reversed(vs))) 45 | 46 | 47 | def generate_counterexamples(formula): 48 | """Given a z3 formula generated from a sketch, returns counterexample 49 | values for the formula. 50 | 51 | Returns: 52 | A tuple of two dicts from string to ints, where the first one 53 | represents counterexamples for packet variables and the second for 54 | state group variables. 55 | """ 56 | # We negate the body of formula, and check whether the new formula is 57 | # satisfiable. If so, we extract the input values and they are 58 | # counterexamples for the original formula. Otherwise, the original formula 59 | # is satisfiable and there is no counterexample. 60 | new_formula = negated_body(formula) 61 | 62 | z3_slv = z3.Solver() 63 | # To get counterexamples we need to set proof and unsat_core. Random seed 64 | # is set for determinism. 65 | z3_slv.set(proof=True, unsat_core=True, random_seed=1) 66 | z3_slv.add(new_formula) 67 | 68 | # Use OrderedDict here for deterministic compilation results. We can also 69 | # use built-in dict() for Python versions 3.6 and later, as it's inherently 70 | # ordered. 
71 | pkt_fields = collections.OrderedDict() 72 | state_vars = collections.OrderedDict() 73 | 74 | result = z3_slv.check() 75 | if result != z3.sat: 76 | print('Failed to generate counterexamples, z3 returned', result) 77 | return (pkt_fields, state_vars) 78 | 79 | model = z3_slv.model() 80 | for var in model.decls(): 81 | value = model.get_interp(var).as_long() 82 | match_object = re.match(r'pkt_\d+', var.name()) 83 | if match_object: 84 | var_name = match_object.group(0) 85 | pkt_fields[var_name] = value 86 | continue 87 | 88 | match_object = re.match(r'state_group_\d+_state_\d+', var.name()) 89 | if match_object: 90 | var_name = match_object.group(0) 91 | state_vars[var_name] = value 92 | 93 | return (pkt_fields, state_vars) 94 | 95 | 96 | def check_sort(z3_var): 97 | assert (z3.is_bool(z3_var) or z3.is_int(z3_var)),\ 98 | str(z3_var) + ' has unsupported type ' + str(type(z3_var)) 99 | 100 | 101 | def make_int(z3_var): 102 | if z3.is_bool(z3_var): 103 | return z3.If(z3_var, 1, 0) 104 | else: 105 | return z3_var 106 | 107 | 108 | def make_bool(z3_var): 109 | if z3.is_int(z3_var): 110 | # Use > 0 to convert int to bool as per BooleanNodes.h 111 | # in the SKETCH code base. 
112 | return z3_var > 0 113 | else: 114 | return z3_var 115 | 116 | 117 | def get_z3_formula(sketch_ir: str, input_bits: int) -> z3.QuantifierRef: 118 | """Given an intermediate representation of a sketch file and returns a z3 119 | formula corresponding to that IR with the specified input bits for source 120 | variables.""" 121 | 122 | z3_vars = collections.OrderedDict() 123 | z3_asserts = [] 124 | z3_srcs = [] 125 | for line in sketch_ir.splitlines(): 126 | records = line.split() 127 | start = records[0] 128 | if (start in ['dag', 'TUPLE_DEF']): 129 | continue 130 | else: 131 | # common processing across all nodes 132 | output_var = '_n' + records[0] 133 | operation = records[2] 134 | if operation in ['NEG', 'NOT']: 135 | operand1 = z3_vars['_n' + records[4]] 136 | check_sort(operand1) 137 | elif operation in ['AND', 'OR', 'XOR', 'PLUS', 138 | 'TIMES', 'DIV', 'MOD', 'LT', 139 | 'EQ']: 140 | operand1 = z3_vars['_n' + records[4]] 141 | operand2 = z3_vars['_n' + records[5]] 142 | check_sort(operand1) 143 | check_sort(operand2) 144 | 145 | # node-specific processing 146 | if operation == 'ASSERT': 147 | z3_asserts += ['_n' + records[3]] 148 | elif operation == 'S': 149 | var_type = records[3] 150 | source_name = records[4] 151 | assert var_type == 'INT', ('Unexpected variable type found in \ 152 | sketch IR:', line) 153 | z3_vars[source_name] = z3.Int(source_name) 154 | z3_vars[output_var] = z3.Int(source_name) 155 | z3_srcs += [source_name] 156 | elif operation in ['NEG']: 157 | z3_vars[output_var] = -make_int(operand1) 158 | elif operation in ['NOT']: 159 | z3_vars[output_var] = z3.Not(make_bool(operand1)) 160 | elif operation in [ 161 | 'AND', 'OR', 'XOR', 'PLUS', 'TIMES', 'DIV', 'MOD', 'LT', 162 | 'EQ' 163 | ]: 164 | if operation == 'AND': 165 | z3_vars[output_var] = z3.And( 166 | make_bool(operand1), make_bool(operand2)) 167 | elif operation == 'OR': 168 | z3_vars[output_var] = z3.Or( 169 | make_bool(operand1), make_bool(operand2)) 170 | elif operation == 'XOR': 
171 | z3_vars[output_var] = z3.Xor( 172 | make_bool(operand1), make_bool(operand2)) 173 | elif operation == 'PLUS': 174 | z3_vars[output_var] = make_int( 175 | operand1) + make_int(operand2) 176 | elif operation == 'TIMES': 177 | z3_vars[output_var] = make_int( 178 | operand1) * make_int(operand2) 179 | elif operation == 'DIV': 180 | z3_vars[output_var] = make_int( 181 | operand1) / make_int(operand2) 182 | elif operation == 'MOD': 183 | z3_vars[output_var] = make_int( 184 | operand1) % make_int(operand2) 185 | elif operation == 'LT': 186 | z3_vars[output_var] = make_int( 187 | operand1) < make_int(operand2) 188 | elif operation == 'EQ': 189 | z3_vars[output_var] = make_int( 190 | operand1) == make_int(operand2) 191 | else: 192 | assert False, ('Invalid operation', operation) 193 | # One can consider ARRACC and ARRASS as array access and 194 | # assignment. For more details please refer this sketchusers 195 | # mailing list thread. 196 | # https://lists.csail.mit.edu/pipermail/sketchusers/2019-August/000104.html 197 | elif operation in ['ARRACC']: 198 | predicate = make_bool((z3_vars['_n' + records[4]])) 199 | yes_val = z3_vars['_n' + records[7]] 200 | no_val = z3_vars['_n' + records[6]] 201 | z3_vars[output_var] = z3.If(predicate, yes_val, no_val) 202 | elif operation in ['ARRASS']: 203 | var_type = type(z3_vars['_n' + records[4]]) 204 | if var_type == z3.BoolRef: 205 | assert records[6] in ['0', '1'] 206 | cmp_constant = records[6] == '1' 207 | elif var_type == z3.ArithRef: 208 | cmp_constant = int(records[6]) 209 | else: 210 | assert False, ('Variable type', var_type, 'not supported') 211 | predicate = z3_vars['_n' + records[4]] == cmp_constant 212 | yes_val = z3_vars['_n' + records[8]] 213 | no_val = z3_vars['_n' + records[7]] 214 | z3_vars[output_var] = z3.If(predicate, yes_val, no_val) 215 | elif operation in ['CONST']: 216 | var_type = records[3] 217 | if var_type == 'INT': 218 | z3_vars[output_var] = z3.IntVal(int(records[4])) 219 | elif var_type == 'BOOL': 
220 | assert records[4] in ['0', '1'] 221 | z3_vars[output_var] = z3.BoolVal(records[4] == '1') 222 | else: 223 | assert False, ('Constant type', var_type, 'not supported') 224 | else: 225 | assert False, ('Unknown operation:', line) 226 | 227 | # To handle cases where we don't have any assert or source variable, add 228 | # a dummy bool variable. 229 | constraints = z3.BoolVal(True) 230 | for var in z3_asserts: 231 | constraints = z3.And(constraints, z3_vars[var]) 232 | 233 | variable_range = z3.BoolVal(True) 234 | for var in z3_srcs: 235 | variable_range = z3.And( 236 | variable_range, 237 | z3.And(0 <= z3_vars[var], z3_vars[var] < 2**input_bits)) 238 | 239 | final_assert = z3.ForAll([z3_vars[x] for x in z3_srcs], 240 | z3.Implies(variable_range, constraints)) 241 | # We could use z3.simplify on the final assert, however that could result 242 | # in a formula that is oversimplified and doesn't have a QuantfierRef which 243 | # is expected from the negated_body() function above. 244 | return final_assert 245 | 246 | 247 | def simple_check(smt2_filename): 248 | """Given a smt2 file generated from a sketch, parses assertion from the 249 | file and checks with z3. We assume that the file already has input bit 250 | ranges defined by sketch. 251 | 252 | Returns: 253 | True if satisfiable else False. 254 | """ 255 | formula = parse_smt2_file(smt2_filename) 256 | 257 | # The original formula's body is comprised of Implies(A, B) where A 258 | # specifies range of input variables and where B is a condition. 
We're 259 | # interested to check whether B is True within the range specified by A 260 | 261 | z3_slv = z3.Solver() 262 | z3_slv.add(formula) 263 | 264 | return z3_slv.check() == z3.sat 265 | -------------------------------------------------------------------------------- /example_alus/stateful_alus/if_else_raw.alu: -------------------------------------------------------------------------------- 1 | type: stateful 2 | state variables : {state_0} 3 | hole variables : {} 4 | packet fields : {pkt_0, pkt_1} 5 | 6 | int old_state_0 = state_0; 7 | if (rel_op(Opt(state_0), Mux3(pkt_0, pkt_1, C()))) { 8 | state_0 = Opt(state_0) + Mux3(pkt_0, pkt_1, C()); 9 | } 10 | else { 11 | state_0 = Opt(state_0) + Mux3(pkt_0, pkt_1, C()); 12 | } 13 | return Mux2(old_state_0, state_0); 14 | -------------------------------------------------------------------------------- /example_alus/stateful_alus/nested_ifs.alu: -------------------------------------------------------------------------------- 1 | type : stateful 2 | state variables : {state_0} 3 | hole variables : {} 4 | packet fields : {pkt_0, pkt_1} 5 | 6 | int old_state_0 = state_0; 7 | if (rel_op(Opt(state_0) + Mux3(pkt_0, pkt_1, 0) - Mux3(pkt_0, pkt_1, 0), C())) { 8 | if (rel_op(Opt(state_0) + Mux3(pkt_0, pkt_1, 0) - Mux3(pkt_0, pkt_1, 0), C())) { 9 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_0, pkt_1, C()), Mux3(pkt_0, pkt_1, C())); 10 | } else { 11 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_0, pkt_1, C()), Mux3(pkt_0, pkt_1, C())); 12 | } 13 | } else { 14 | if (rel_op(Opt(state_0) + Mux3(pkt_0, pkt_1, 0) - Mux3(pkt_0, pkt_1, 0), C())) { 15 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_0, pkt_1, C()), Mux3(pkt_0, pkt_1, C())); 16 | } else { 17 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_0, pkt_1, C()), Mux3(pkt_0, pkt_1, C())); 18 | } 19 | } 20 | return Mux2(old_state_0, state_0); 21 | -------------------------------------------------------------------------------- /example_alus/stateful_alus/pair.alu: 
-------------------------------------------------------------------------------- 1 | type : stateful 2 | state variables : {state_0, state_1} 3 | hole variables : {} 4 | packet fields : {pkt_1, pkt_2, pkt_3, pkt_4, pkt_5} 5 | 6 | int old_state_0 = state_0; 7 | int old_state_1 = state_1; 8 | if (rel_op(Mux3(state_0, state_1, 0) + Mux3(pkt_1, pkt_2, 0) - Mux3(pkt_1, pkt_2, 0), C())) { 9 | if (rel_op(Mux3(state_0, state_1, 0) + Mux3(pkt_1, pkt_2, 0) - Mux3(pkt_1, pkt_2, 0), C())) { 10 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 11 | state_1 = Opt(state_1) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 12 | } else { 13 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 14 | state_1 = Opt(state_1) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 15 | } 16 | } elif (rel_op(Mux3(state_0, state_1, 0) + Mux3(pkt_1, pkt_2, 0) - Mux3(pkt_1, pkt_2, 0), C())) { 17 | if (rel_op(Mux3(state_0, state_1, 0) + Mux3(pkt_1, pkt_2, 0) - Mux3(pkt_1, pkt_2, 0), C())) { 18 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 19 | state_1 = Opt(state_1) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 20 | } else { 21 | state_0 = Opt(state_0) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 22 | state_1 = Opt(state_1) + arith_op(Mux3(pkt_1, pkt_2, C()), Mux3(pkt_1, pkt_2, C())); 23 | } 24 | } 25 | 26 | return Mux4(old_state_0, old_state_1, state_0, state_1); 27 | -------------------------------------------------------------------------------- /example_alus/stateful_alus/pred_raw.alu: -------------------------------------------------------------------------------- 1 | type : stateful 2 | state variables : {state_0} 3 | hole variables : {} 4 | packet fields : {pkt_0, pkt_1} 5 | 6 | int old_state_0 = state_0; 7 | if (rel_op(Opt(state_0), Mux3(pkt_0, pkt_1, C()))) { 8 | state_0 = Opt(state_0) + Mux3(pkt_0, pkt_1, C()); 9 | } 10 | 
return Mux2(old_state_0, state_0); 11 | -------------------------------------------------------------------------------- /example_alus/stateful_alus/raw.alu: -------------------------------------------------------------------------------- 1 | type : stateful 2 | state variables : {state_0} 3 | hole variables : {} 4 | packet fields : {pkt_0} 5 | 6 | int old_state_0 = state_0; 7 | state_0 = Opt(state_0) + Mux2(pkt_0, C()); 8 | return Mux2(old_state_0, state_0); 9 | -------------------------------------------------------------------------------- /example_alus/stateful_alus/sub.alu: -------------------------------------------------------------------------------- 1 | type : stateful 2 | state variables : {state_0} 3 | hole variables : {} 4 | packet fields : {pkt_0, pkt_1} 5 | 6 | int old_state_0 = state_0; 7 | if (rel_op(Opt(state_0), Mux3(pkt_0, pkt_1, C()))) { 8 | state_0 = Opt(state_0) + 9 | arith_op(Mux3(pkt_0, pkt_1, C()), Mux3(pkt_0, pkt_1, C())); 10 | } else { 11 | state_0 = Opt(state_0) + 12 | arith_op(Mux3(pkt_0, pkt_1, C()), Mux3(pkt_0, pkt_1, C())); 13 | } 14 | return Mux2(old_state_0, state_0); 15 | -------------------------------------------------------------------------------- /example_alus/stateless_alus/stateless_alu.alu: -------------------------------------------------------------------------------- 1 | // Max value of opcode is 20 2 | type : stateless 3 | state variables : {} 4 | hole variables : {opcode, immediate_operand} 5 | packet fields : {pkt_0, pkt_1, pkt_2} 6 | 7 | if (opcode == 0) { 8 | return immediate_operand; 9 | } elif (opcode == 1) { 10 | return pkt_0 + pkt_1; 11 | } elif (opcode == 2) { 12 | return pkt_0 + immediate_operand; 13 | } elif (opcode == 3) { 14 | return pkt_0 - pkt_1; 15 | } elif (opcode == 4) { 16 | return pkt_0 - immediate_operand; 17 | } elif (opcode == 5) { 18 | return immediate_operand - pkt_0; 19 | } elif (opcode == 6) { 20 | return pkt_0!=pkt_1; 21 | } elif (opcode == 7) { 22 | return (pkt_0 != immediate_operand); 23 | 
} elif (opcode == 8) { 24 | return (pkt_0 == pkt_1); 25 | } elif (opcode == 9) { 26 | return (pkt_0 == immediate_operand); 27 | } elif (opcode == 10) { 28 | return (pkt_0 >= pkt_1); 29 | } elif (opcode == 11) { 30 | return (pkt_0 >= immediate_operand); 31 | } elif (opcode == 12) { 32 | return (pkt_0 < pkt_1); 33 | } elif (opcode == 13) { 34 | return (pkt_0 < immediate_operand); 35 | } elif (opcode == 14) { 36 | return pkt_0 != 0 ? pkt_1 : pkt_2; 37 | } elif (opcode == 15) { 38 | return pkt_0 != 0 ? pkt_1 : immediate_operand; 39 | } elif (opcode == 16) { 40 | return ((pkt_0 != 0) || (pkt_1 != 0)); 41 | } elif (opcode == 17) { 42 | return ((pkt_0 != 0) || (immediate_operand != 0)); 43 | } elif (opcode == 18) { 44 | return ((pkt_0 != 0) && (pkt_1 != 0)); 45 | } elif (opcode == 19) { 46 | return ((pkt_0 != 0) && (immediate_operand != 0)); 47 | } else { 48 | return (pkt_0 == 0); 49 | } 50 | -------------------------------------------------------------------------------- /example_alus/stateless_alus/stateless_alu_arith.alu: -------------------------------------------------------------------------------- 1 | // Max value of opcode is 5 2 | type : stateless 3 | state variables : {} 4 | hole variables : {opcode, immediate_operand} 5 | packet fields : {pkt_0, pkt_1} 6 | 7 | if (opcode == 0){ 8 | return immediate_operand; 9 | } elif (opcode == 1 ){ 10 | return pkt_0 + pkt_1; 11 | } elif (opcode == 2){ 12 | return pkt_0 + immediate_operand; 13 | } elif (opcode == 3){ 14 | return pkt_0 - pkt_1; 15 | } elif (opcode == 4){ 16 | return pkt_0 - immediate_operand; 17 | } else{ 18 | return immediate_operand - pkt_0; 19 | } 20 | -------------------------------------------------------------------------------- /example_alus/stateless_alus/stateless_alu_arith_rel.alu: -------------------------------------------------------------------------------- 1 | // Max value of opcode is 13 2 | type : stateless 3 | state variables : {} 4 | hole variables : {opcode, immediate_operand} 5 | packet 
fields : {pkt_0, pkt_1} 6 | 7 | if (opcode == 0) { 8 | return immediate_operand; 9 | } elif (opcode == 1) { 10 | return pkt_0 + pkt_1; 11 | } elif (opcode == 2) { 12 | return pkt_0 + immediate_operand; 13 | } elif (opcode == 3) { 14 | return pkt_0 - pkt_1; 15 | } elif (opcode == 4) { 16 | return pkt_0 - immediate_operand; 17 | } elif (opcode == 5) { 18 | return immediate_operand - pkt_0; 19 | } elif (opcode == 6) { 20 | return (pkt_0 != pkt_1); 21 | } elif (opcode == 7) { 22 | return (pkt_0 != immediate_operand); 23 | } elif (opcode == 8) { 24 | return (pkt_0 == pkt_1); 25 | } elif (opcode == 9) { 26 | return (pkt_0 == immediate_operand); 27 | } elif (opcode == 10) { 28 | return (pkt_0 >= pkt_1); 29 | } elif (opcode == 11) { 30 | return (pkt_0 >= immediate_operand); 31 | } elif (opcode == 12) { 32 | return (pkt_0 < pkt_1); 33 | } else { 34 | return (pkt_0 < immediate_operand); 35 | } 36 | -------------------------------------------------------------------------------- /example_alus/stateless_alus/stateless_alu_arith_rel_cond.alu: -------------------------------------------------------------------------------- 1 | // Max value of opcode is 15 2 | type :stateless 3 | state variables : {} 4 | hole variables : {opcode, immediate_operand} 5 | packet fields : {pkt_0, pkt_1, pkt_2} 6 | 7 | if (opcode == 0) { 8 | return immediate_operand; 9 | } elif (opcode == 1) { 10 | return pkt_0 + pkt_1; 11 | } elif (opcode == 2) { 12 | return pkt_0 + immediate_operand; 13 | } elif (opcode == 3) { 14 | return pkt_0 - pkt_1; 15 | } elif (opcode == 4) { 16 | return pkt_0 - immediate_operand; 17 | } elif (opcode == 5) { 18 | return immediate_operand - pkt_0; 19 | } elif (opcode == 6) { 20 | return (pkt_0 != pkt_1); 21 | } elif (opcode == 7) { 22 | return ( pkt_0 != immediate_operand); 23 | } elif (opcode == 8) { 24 | return (pkt_0 == pkt_1); 25 | } elif (opcode == 9) { 26 | return (pkt_0 == immediate_operand); 27 | } elif (opcode == 10) { 28 | return (pkt_0 >= pkt_1); 29 | } elif (opcode 
== 11) { 30 | return (pkt_0 >= immediate_operand); 31 | } elif (opcode == 12) { 32 | return (pkt_0 < pkt_1); 33 | } elif (opcode == 13) { 34 | return (pkt_0 < immediate_operand); 35 | } elif (opcode == 14) { 36 | return pkt_0 != 0 ? pkt_1 : pkt_2; 37 | } else { 38 | return pkt_0 != 0 ? pkt_1 : immediate_operand 39 | } 40 | -------------------------------------------------------------------------------- /example_alus/stateless_alus/stateless_alu_arith_rel_cond_bool.alu: -------------------------------------------------------------------------------- 1 | // Max value of opcode is 20 2 | type : stateless 3 | state variables : {} 4 | hole variables : {opcode, immediate_operand} 5 | packet fields : {pkt_0, pkt_1, pkt_2} 6 | 7 | if (opcode == 0) { 8 | return immediate_operand; 9 | } elif (opcode == 1) { 10 | return pkt_0 + pkt_1; 11 | } elif (opcode == 2) { 12 | return pkt_0 + immediate_operand; 13 | } elif (opcode == 3) { 14 | return pkt_0 - pkt_1; 15 | } elif (opcode == 4) { 16 | return pkt_0 - immediate_operand; 17 | } elif (opcode == 5) { 18 | return immediate_operand - pkt_0; 19 | } elif (opcode == 6) { 20 | return ( pkt_0 != pkt_1); 21 | } elif (opcode == 7) { 22 | return (pkt_0 != immediate_operand); 23 | } elif (opcode == 8) { 24 | return (pkt_0 == pkt_1); 25 | } elif (opcode == 9) { 26 | return (pkt_0 == immediate_operand); 27 | } elif (opcode == 10) { 28 | return (pkt_0 >= pkt_1); 29 | } elif (opcode == 11) { 30 | return (pkt_0 >= immediate_operand); 31 | } elif (opcode == 12) { 32 | return ( pkt_0 < pkt_1); 33 | } elif (opcode == 13) { 34 | return ( pkt_0 < immediate_operand); 35 | } elif (opcode == 14) { 36 | return pkt_0 != 0 ? pkt_1 : pkt_2; 37 | } elif (opcode == 15) { 38 | return pkt_0 != 0 ? 
pkt_1 : immediate_operand; 39 | } elif (opcode == 16) { 40 | return ((pkt_0 != 0) || (pkt_1 != 0)); 41 | } elif (opcode == 17) { 42 | return ((pkt_0 != 0) || (immediate_operand != 0)); 43 | } elif (opcode == 18) { 44 | return ((pkt_0 != 0) && (pkt_1 != 0)); 45 | } elif (opcode == 19) { 46 | return ((pkt_0 != 0) && (immediate_operand != 0)); 47 | } else { 48 | return (pkt_0 == 0); 49 | } 50 | -------------------------------------------------------------------------------- /example_specs/blue_decrease.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | struct Packet { 4 | int loss; 5 | int qlen; 6 | int pkt_0; 7 | int link_idle; 8 | int cond1; 9 | int pkt_1; 10 | }; 11 | int state_group_1_state_0; 12 | int state_group_0_state_0; 13 | void func(struct Packet p) { 14 | p.pkt_1 = p.pkt_0 - 10; 15 | if (p.pkt_1 > state_group_1_state_0) { 16 | state_group_0_state_0 = state_group_0_state_0 - 2; 17 | state_group_1_state_0 = p.pkt_0; 18 | } 19 | } 20 | */ 21 | 22 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 23 | state_and_packet.pkt_1 = state_and_packet.pkt_0 - 1; 24 | if (state_and_packet.pkt_1 > state_and_packet.state_group_1_state_0) { 25 | state_and_packet.state_group_0_state_0 = 26 | state_and_packet.state_group_0_state_0 - 2; 27 | state_and_packet.state_group_1_state_0 = state_and_packet.pkt_0; 28 | } 29 | return state_and_packet; 30 | } 31 | -------------------------------------------------------------------------------- /example_specs/blue_increase.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | struct Packet { 4 | int loss; 5 | int qlen; 6 | int pkt_0; 7 | int link_idle; 8 | int cond1; 9 | int pkt_1; 10 | }; 11 | int state_group_1_state_0; 12 | int state_group_0_state_0; 13 | void func(struct Packet p) { 14 | p.pkt_1 = p.pkt_0 - 10; 15 | if (p.pkt_1 > state_group_1_state_0) { 16 | state_group_0_state_0 = 
state_group_0_state_0 + 1; 17 | state_group_1_state_0 = p.pkt_0; 18 | } 19 | } 20 | */ 21 | 22 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 23 | state_and_packet.pkt_1 = state_and_packet.pkt_0 - 1; 24 | if (state_and_packet.pkt_1 > state_and_packet.state_group_1_state_0) { 25 | state_and_packet.state_group_0_state_0 = 26 | state_and_packet.state_group_0_state_0 + 1; 27 | state_and_packet.state_group_1_state_0 = state_and_packet.pkt_0; 28 | } 29 | return state_and_packet; 30 | } 31 | -------------------------------------------------------------------------------- /example_specs/learn_filter_modified_for_test.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | struct Packet{ 4 | int sport; 5 | int dport; 6 | int pkt_0; 7 | int filter1_idx; 8 | int filter2_idx; 9 | int filter3_idx; 10 | }; 11 | int state_group_0_state_0 = {0}; 12 | int state_group_1_state_0 = {0}; 13 | int state_group_2_state_0 = {0}; 14 | void func(struct Packet p){ 15 | p.pkt_0=p.pkt_0; 16 | state_group_0_state_0=1; 17 | state_group_1_state_0=1; 18 | state_group_2_state_0=1; 19 | } */ 20 | 21 | |StateAndPacket| program (|StateAndPacket| state_and_packet) { 22 | state_and_packet.pkt_0=state_and_packet.pkt_0; 23 | state_and_packet.state_group_0_state_0=1; 24 | state_and_packet.state_group_1_state_0=1; 25 | state_and_packet.state_group_2_state_0=1; 26 | return state_and_packet; 27 | } 28 | -------------------------------------------------------------------------------- /example_specs/marple_new_flow.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | int state_group_0_state_0 = 0; 4 | struct Packet { 5 | int pkt_0; 6 | }; 7 | void func(struct Packet p) { 8 | if (state_group_0_state_0 == 0) { 9 | state_group_0_state_0 = 1; 10 | p.pkt_0 = 1; 11 | } 12 | } 13 | */ 14 | 15 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 16 | if 
(state_and_packet.state_group_0_state_0 == 0) { 17 | state_and_packet.state_group_0_state_0 = 1; 18 | state_and_packet.pkt_0 = 1; 19 | } 20 | return state_and_packet; 21 | } 22 | -------------------------------------------------------------------------------- /example_specs/marple_tcp_nmo.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | struct Packet { 4 | int pkt_0; 5 | }; 6 | int state_group_0_state_0 = 0; 7 | int state_group_1_state_0 = 0; 8 | void func(struct Packet p) { 9 | if (p.pkt_0 < state_group_1_state_0) { 10 | state_group_0_state_0 = state_group_0_state_0 + 1; 11 | } else { 12 | state_group_1_state_0 = p.pkt_0; 13 | } 14 | } 15 | */ 16 | 17 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 18 | if (state_and_packet.pkt_0 < state_and_packet.state_group_1_state_0) { 19 | state_and_packet.state_group_0_state_0 = 20 | state_and_packet.state_group_0_state_0 + 1; 21 | } else { 22 | state_and_packet.state_group_1_state_0 = state_and_packet.pkt_0; 23 | } 24 | return state_and_packet; 25 | } 26 | -------------------------------------------------------------------------------- /example_specs/rcp.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | int state_group_0_state_0 = 0; 4 | int state_group_1_state_0 = 0; 5 | int state_group_2_state_0 = 0; 6 | struct Packet { 7 | int pkt_0; 8 | int pkt_1; 9 | }; 10 | void func(struct Packet p) { 11 | state_group_0_state_0 += p.pkt_0; 12 | if (p.pkt_1 < 2) { 13 | state_group_1_state_0 += p.pkt_1; 14 | state_group_2_state_0 += 1; 15 | } 16 | } 17 | */ 18 | 19 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 20 | state_and_packet.state_group_0_state_0 += state_and_packet.pkt_0; 21 | if (state_and_packet.pkt_1 < 2) { 22 | state_and_packet.state_group_1_state_0 += state_and_packet.pkt_1; 23 | state_and_packet.state_group_2_state_0 += 1; 24 | } 25 | 
return state_and_packet; 26 | } 27 | -------------------------------------------------------------------------------- /example_specs/sampling.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | struct Packet { 4 | int pkt_0; 5 | }; 6 | int state_group_0_state_0 = 0; 7 | void func(struct Packet p) { 8 | if (state_group_0_state_0 == 3 - 1) { 9 | p.pkt_0 = 1; 10 | state_group_0_state_0 = 0; 11 | 12 | ; 13 | } else { 14 | p.pkt_0 = 0; 15 | state_group_0_state_0 = state_group_0_state_0 + 1; 16 | } 17 | } 18 | */ 19 | 20 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 21 | if (state_and_packet.state_group_0_state_0 == 3 - 1) { 22 | state_and_packet.pkt_0 = 1; 23 | state_and_packet.state_group_0_state_0 = 0; 24 | } else { 25 | state_and_packet.pkt_0 = 0; 26 | state_and_packet.state_group_0_state_0 = 27 | state_and_packet.state_group_0_state_0 + 1; 28 | } 29 | return state_and_packet; 30 | } 31 | -------------------------------------------------------------------------------- /example_specs/sampling_revised.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | // Sample every 30th packet in a flow 4 | #define N 30 5 | 6 | struct Packet { 7 | int sample; 8 | }; 9 | 10 | int count = 0; 11 | 12 | void func(struct Packet pkt) { 13 | if (count == N - 1) { 14 | count = 0; 15 | } else if (count == 8){ 16 | count = 2; 17 | } else{ 18 | count = 1; 19 | } 20 | pkt.sample = 1; 21 | } 22 | */ 23 | 24 | // Output the rename map: 25 | // stateless variable rename list: 26 | 27 | // state_and_packet.pkt_0 = pkt.sample 28 | 29 | // stateful variable rename list: 30 | 31 | // state_and_packet.state_group_0_state_0 = count 32 | 33 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 34 | if (state_and_packet.state_group_0_state_0 == 30 - 1) { 35 | state_and_packet.state_group_0_state_0 = 0; 36 | } else { 37 | if 
(state_and_packet.state_group_0_state_0 == 8) { 38 | state_and_packet.state_group_0_state_0 = 2; 39 | } else { 40 | state_and_packet.state_group_0_state_0 = 1; 41 | } 42 | } 43 | state_and_packet.pkt_0 = 1; 44 | return state_and_packet; 45 | } 46 | -------------------------------------------------------------------------------- /example_specs/simple.sk: -------------------------------------------------------------------------------- 1 | // Spec for Sketch 2 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 3 | state_and_packet.pkt_0 = 1 + state_and_packet.state_group_0_state_0; 4 | return state_and_packet; 5 | } 6 | -------------------------------------------------------------------------------- /example_specs/simple2.sk: -------------------------------------------------------------------------------- 1 | // Spec for Sketch 2 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 3 | state_and_packet.pkt_0 = 1 + state_and_packet.pkt_0; 4 | state_and_packet.state_group_0_state_0 = 5 | state_and_packet.state_group_0_state_0 + 1; 6 | return state_and_packet; 7 | } 8 | -------------------------------------------------------------------------------- /example_specs/simplest.sk: -------------------------------------------------------------------------------- 1 | // Spec for Sketch 2 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 3 | state_and_packet.pkt_0 = state_and_packet.state_group_0_state_0; 4 | return state_and_packet; 5 | } 6 | -------------------------------------------------------------------------------- /example_specs/simplest2.sk: -------------------------------------------------------------------------------- 1 | // Spec for Sketch 2 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 3 | state_and_packet.pkt_0 = state_and_packet.state_group_0_state_0; 4 | state_and_packet.pkt_1 = state_and_packet.pkt_0 + state_and_packet.pkt_1; 5 | return state_and_packet; 6 | } 7 | 
-------------------------------------------------------------------------------- /example_specs/simplified_hull.sk: -------------------------------------------------------------------------------- 1 | /* 2 | // Original program: 3 | #define ECN_THRESH 20 4 | 5 | int counter = ECN_THRESH; 6 | int last_time = 0; 7 | 8 | struct Packet { 9 | int bytes; 10 | int time; 11 | int mark; 12 | }; 13 | 14 | void func(struct Packet p) { 15 | // Decrement counter according to drain rate 16 | counter = counter - (p.time - last_time); 17 | if (counter < 0) counter = 0; 18 | 19 | // Increment counter 20 | counter += p.bytes; 21 | 22 | // If we are above the ECN_THRESH, mark 23 | if (counter > ECN_THRESH) p.mark = 1; 24 | 25 | // Store last time 26 | last_time = p.time; 27 | } 28 | */ 29 | 30 | // Output the rename map: 31 | // stateless variable rename list: 32 | 33 | // state_and_packet.pkt_0 = p.time 34 | // state_and_packet.pkt_1 = p.bytes 35 | // Only this matters 36 | // state_and_packet.pkt_2 = p.mark 37 | 38 | // stateful variable rename list: 39 | 40 | // state_and_packet.state_group_0_state_0 = counter 41 | // state_and_packet.state_group_1_state_0 = last_time 42 | 43 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 44 | state_and_packet.state_group_0_state_0 = 45 | state_and_packet.state_group_0_state_0 - 46 | (state_and_packet.pkt_0 - state_and_packet.state_group_1_state_0); 47 | if (state_and_packet.state_group_0_state_0 < 0) { 48 | state_and_packet.state_group_0_state_0 = 0; 49 | } 50 | state_and_packet.state_group_0_state_0 += state_and_packet.pkt_1; 51 | if (state_and_packet.state_group_0_state_0 > 20) { 52 | state_and_packet.pkt_2 = 1; 53 | } 54 | state_and_packet.state_group_1_state_0 = state_and_packet.pkt_0; 55 | return state_and_packet; 56 | } 57 | -------------------------------------------------------------------------------- /example_specs/snap_heavy_hitter.sk: -------------------------------------------------------------------------------- 
1 | /* 2 | // Original program: 3 | struct Packet { 4 | int pkt_0; 5 | }; 6 | int state_group_0_state_0 = {0}; 7 | int state_group_0_state_1 = {0}; 8 | void func(struct Packet p) { 9 | p.pkt_0 = p.pkt_0; 10 | if (state_group_0_state_0 == 0) { 11 | state_group_0_state_1 = state_group_0_state_1 + 1; 12 | if (state_group_0_state_1 == 1000) { 13 | state_group_0_state_0 = 1; 14 | } 15 | } 16 | } 17 | */ 18 | 19 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 20 | state_and_packet.pkt_0 = state_and_packet.pkt_0; 21 | if (state_and_packet.state_group_0_state_0 == 0) { 22 | state_and_packet.state_group_0_state_1 = 23 | state_and_packet.state_group_0_state_1 + 1; 24 | if (state_and_packet.state_group_0_state_1 == 1000) { 25 | state_and_packet.state_group_0_state_0 = 1; 26 | } 27 | } 28 | return state_and_packet; 29 | } 30 | -------------------------------------------------------------------------------- /example_specs/test.sk: -------------------------------------------------------------------------------- 1 | // Spec for Sketch 2 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 3 | state_and_packet.pkt_1 = state_and_packet.pkt_0 + 1 + state_and_packet.pkt_1; 4 | state_and_packet.state_group_0_state_0 = 5 | state_and_packet.state_group_0_state_0 + 1; 6 | state_and_packet.state_group_1_state_0 = 7 | state_and_packet.state_group_0_state_0 + 8 | state_and_packet.state_group_1_state_0; 9 | state_and_packet.pkt_0 = 10 | state_and_packet.state_group_0_state_0 + state_and_packet.pkt_1; 11 | state_and_packet.pkt_2 = state_and_packet.state_group_0_state_0 + 12 | state_and_packet.state_group_1_state_0; 13 | return state_and_packet; 14 | } 15 | -------------------------------------------------------------------------------- /example_specs/times_two.sk: -------------------------------------------------------------------------------- 1 | | StateAndPacket | program(| StateAndPacket | state_and_packet) { 2 | if (state_and_packet.pkt_0 * 2 == 
state_and_packet.pkt_1) { 3 | state_and_packet.state_group_0_state_0 = 1; 4 | } else { 5 | state_and_packet.state_group_0_state_0 = 0; 6 | } 7 | return state_and_packet; 8 | } 9 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | 3 | autopep8 4 | flake8 5 | nose 6 | pre-commit 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | from distutils import log 5 | from pathlib import Path 6 | 7 | from setuptools import find_packages 8 | from setuptools import setup 9 | from setuptools.command.build_py import build_py 10 | from setuptools.command.develop import develop 11 | 12 | _PACKAGE_NAME = 'chipc' 13 | 14 | 15 | def _generate_parser(): 16 | """Generates chipmunk grammar parser using chipmunk/stateful_alu.g4 17 | file. Assumes the user has java binary.""" 18 | 19 | grammar_name = 'alu' 20 | antlr_ext = '.g4' 21 | 22 | alu_filepath = _PACKAGE_NAME + '/' + grammar_name + antlr_ext 23 | assert os.access(alu_filepath, 24 | os.R_OK), "Can't find grammar file: %s" % alu_filepath 25 | 26 | antlr_jar = Path(_PACKAGE_NAME, 'lib', 'antlr-4.7.2-complete.jar') 27 | run_args = [ 28 | 'java', '-jar', 29 | str(antlr_jar), alu_filepath, '-Dlanguage=Python3', '-visitor', 30 | '-package', _PACKAGE_NAME 31 | ] 32 | 33 | subprocess.run(run_args, check=True) 34 | generated_files = glob.glob(_PACKAGE_NAME + '/' + grammar_name + '*.py') 35 | # Check whether Antlr actually generated any file. 36 | assert generated_files, 'Antlr4 failed to generate Parser/Lexer.' 
37 | log.info('Antlr generated Python files: %s' % ', '.join( 38 | [str(f) for f in generated_files])) 39 | 40 | 41 | class DevelopWrapper(develop): 42 | def run(self): 43 | _generate_parser() 44 | develop.run(self) 45 | 46 | 47 | class BuildPyWrapper(build_py): 48 | def run(self): 49 | _generate_parser() 50 | build_py.run(self) 51 | 52 | 53 | setup( 54 | name=_PACKAGE_NAME, 55 | version='0.1', 56 | description='A switch code generator based on end-to-end program ' + 57 | 'synthesis.', 58 | url='https://github.com/anirudhSK/chipmunk', 59 | author='Chipmunk Contributors', 60 | packages=find_packages(exclude=['tests*', '*.interp', '*.tokens']), 61 | # This will let setuptools to copy ver what"s listed in MANIFEST.in 62 | include_package_data=True, 63 | install_requires=[ 64 | 'antlr4-python3-runtime>=4.7.2', 'Jinja2>=2.10', 'ordered_set>=3.1.1', 65 | 'overrides>=1.9', 'psutil>=5.6.1', 'z3-solver>=4.8.0.0' 66 | ], 67 | cmdclass={ 68 | 'build_py': BuildPyWrapper, 69 | 'develop': DevelopWrapper 70 | }, 71 | entry_points={ 72 | 'console_scripts': [ 73 | 'iterative_solver=' + _PACKAGE_NAME + '.iterative_solver:run_main' 74 | ] 75 | }) 76 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chipmunk-project/chipmunk/a86eb9030c7accfd552dc4c899fa247193588251/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/hello.dag: -------------------------------------------------------------------------------- 1 | dag main__WrapperNospec : 2 | TUPLE_DEF Fmain_ANONYMOUS 3 | TUPLE_DEF Fmain__WrapperNospec_ANONYMOUS 4 | TUPLE_DEF Fmain__Wrapper_ANONYMOUS 5 | 0 = S INT pkt_1_1_0 2 6 | 1 = PLUS INT 0 0 7 | 2 = CONST INT 2 8 | 3 = TIMES INT 0 2 9 | 4 = EQ BOOL 3 1 10 | 5 = ASSERT 4 "Assert at hello.sk:2 (2)" 11 | 
-------------------------------------------------------------------------------- /tests/data/hello.smt: -------------------------------------------------------------------------------- 1 | (assert (forall ((pkt_1_1_0 Int )) (let ((_n0 pkt_1_1_0 )) 2 | (let ((_n1 (+ _n0 _n0 ) )) 3 | (let ((_n2 2 )) 4 | (let ((_n3 (* _n0 _n2 ) )) 5 | (let ((_n4 (= _n3 _n1 ) )) 6 | (implies (and (>= pkt_1_1_0 0) (< pkt_1_1_0 4 )) _n4 )))))))) 7 | (check-sat) 8 | (get-model) 9 | (exit) 10 | -------------------------------------------------------------------------------- /tests/data/sampling.dag: -------------------------------------------------------------------------------- 1 | dag main__WrapperNospec : 2 | TUPLE_DEF FglblInit_constant_vector__ANONYMOUS_s90_ANONYMOUS INT INT INT INT 3 | TUPLE_DEF Fmain_ANONYMOUS INT INT INT INT 4 | TUPLE_DEF Fmain__WrapperNospec_ANONYMOUS 5 | TUPLE_DEF Fmain__Wrapper_ANONYMOUS 6 | TUPLE_DEF Fpipeline_ANONYMOUS INT INT INT INT INT INT 7 | TUPLE_DEF Fprogram_ANONYMOUS INT INT 8 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_output_mux_phv_0_0_ANONYMOUS INT 9 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_output_mux_phv_1_0_ANONYMOUS INT 10 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_ANONYMOUS INT INT INT INT INT INT 11 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_C_0_ANONYMOUS INT INT INT INT INT 12 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_C_1_ANONYMOUS INT INT INT INT INT 13 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_C_2_ANONYMOUS INT INT INT INT INT 14 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_Mux3_0_ANONYMOUS INT 15 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_Mux3_1_ANONYMOUS INT 16 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_Mux3_2_ANONYMOUS INT 17 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_Opt_0_ANONYMOUS INT 18 | TUPLE_DEF 
Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_Opt_1_ANONYMOUS INT 19 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_Opt_2_ANONYMOUS INT 20 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_0_0_rel_op_0_ANONYMOUS INT 21 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_ANONYMOUS INT INT INT INT INT INT 22 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_C_0_ANONYMOUS INT INT INT INT INT 23 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_C_1_ANONYMOUS INT INT INT INT INT 24 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_C_2_ANONYMOUS INT INT INT INT INT 25 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_Mux3_0_ANONYMOUS INT 26 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_Mux3_1_ANONYMOUS INT 27 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_Mux3_2_ANONYMOUS INT 28 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_Opt_0_ANONYMOUS INT 29 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_Opt_1_ANONYMOUS INT 30 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_Opt_2_ANONYMOUS INT 31 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_alu_1_0_rel_op_0_ANONYMOUS INT 32 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_operand_mux_0_0_0_ANONYMOUS INT 33 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_operand_mux_0_0_1_ANONYMOUS INT 34 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_operand_mux_1_0_0_ANONYMOUS INT 35 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateful_operand_mux_1_0_1_ANONYMOUS INT 36 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_0_0_ANONYMOUS INT INT INT INT INT 37 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_0_0_mux1_ANONYMOUS INT 38 | TUPLE_DEF 
Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_0_0_mux2_ANONYMOUS INT 39 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_0_0_mux3_ANONYMOUS INT 40 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_1_0_ANONYMOUS INT INT INT INT INT 41 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_1_0_mux1_ANONYMOUS INT 42 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_1_0_mux2_ANONYMOUS INT 43 | TUPLE_DEF Fsampling_if_else_raw_stateless_alu_2_1_stateless_alu_1_0_mux3_ANONYMOUS INT 44 | 0 = S INT state_group_0_state_0_4_6_0 2 45 | 1 = CONST BOOL 1 46 | 2 = PLUS INT 0 1 47 | 3 = CONST BOOL 0 48 | 4 = ARRASS INT 0 == 2 2 3 49 | 5 = CONST INT 2 50 | 6 = EQ BOOL 0 5 51 | 7 = NOT BOOL 6 52 | 8 = PLUS INT 0 7 53 | 9 = ARRASS INT 0 == 2 0 3 54 | 10 = ARRASS INT 6 == 1 8 9 55 | 11 = EQ BOOL 10 4 56 | 12 = ASSERT 11 "Assert at test2.sk:620 (1)" 57 | 13 = EQ BOOL 10 3 58 | 14 = NOT BOOL 13 59 | 15 = ARRASS INT 14 == 1 13 3 60 | 16 = EQ BOOL 15 6 61 | 17 = ASSERT 16 "Assert at test2.sk:623 (1)" 62 | 18 = S INT pkt_0_3_5_0 2 63 | -------------------------------------------------------------------------------- /tests/data/sampling.smt: -------------------------------------------------------------------------------- 1 | (assert (forall ((pkt_0_3_5_0 Int )(state_group_0_state_0_4_6_0 Int )) (let ((_n0 state_group_0_state_0_4_6_0 )) 2 | (let ((_n1 true )) 3 | (let ((_n2 (+ _n0 (ite _n1 1 0) ) )) 4 | (let ((_n3 false )) 5 | (let ((_n4 (ite (= _n0 2) (ite _n3 1 0) _n2 ) )) 6 | (let ((_n5 2 )) 7 | (let ((_n6 (= _n0 _n5 ) )) 8 | (let ((_n7 (not _n6 ) )) 9 | (let ((_n8 (+ _n0 (ite _n7 1 0) ) )) 10 | (let ((_n9 (ite (= _n0 2) (ite _n3 1 0) _n0 ) )) 11 | (let ((_n10 (ite (= (ite _n6 1 0) 1) _n9 _n8 ) )) 12 | (let ((_n11 (= _n10 _n4 ) )) 13 | (let ((_n13 (= _n10 (ite _n3 1 0) ) )) 14 | (let ((_n14 (not _n13 ) )) 15 | (let ((_n15 (ite (= (ite _n14 1 0) 1) _n3 _n13 ) )) 16 | (let ((_n16 (= (ite _n15 1 0) (ite _n6 1 0) ) )) 17 | 
(let ((_n18 pkt_0_3_5_0 )) 18 | (implies (and (and (>= state_group_0_state_0_4_6_0 0) (< state_group_0_state_0_4_6_0 32 )) (and (>= pkt_0_3_5_0 0) (< pkt_0_3_5_0 32 )) ) (and _n11 _n16 ) )))))))))))))))))))) 19 | (check-sat) 20 | (get-model) 21 | (exit) 22 | -------------------------------------------------------------------------------- /tests/data/simple_raw_2_2_codegen_iteration_1.sk: -------------------------------------------------------------------------------- 1 | // This is an autogenerated sketch file corresponding to 2 | // the router's data path and is used to solve the Chipmunk compilation problem. 3 | // program_file = /Users/Xiangyu/chipmunk/tests/../example_specs/simple.sk num_pipeline_stages = 2 4 | // num_alus_per_stage = 2 5 | // num_phv_containers = 2 6 | 7 | int[4] constant_vector = {0,1,2,3}; 8 | 9 | int simple_raw_2_2_stateless_alu_0_0_mux1_ctrl= ??(1); 10 | int simple_raw_2_2_stateless_alu_0_0_mux2_ctrl= ??(1); 11 | int simple_raw_2_2_stateless_alu_0_0_mux3_ctrl= ??(1); 12 | int simple_raw_2_2_stateless_alu_0_0_immediate= ??(2); 13 | int simple_raw_2_2_stateless_alu_0_0_opcode= ??(5); 14 | int simple_raw_2_2_stateless_alu_0_1_mux1_ctrl= ??(1); 15 | int simple_raw_2_2_stateless_alu_0_1_mux2_ctrl= ??(1); 16 | int simple_raw_2_2_stateless_alu_0_1_mux3_ctrl= ??(1); 17 | int simple_raw_2_2_stateless_alu_0_1_immediate= ??(2); 18 | int simple_raw_2_2_stateless_alu_0_1_opcode= ??(5); 19 | int simple_raw_2_2_stateful_alu_0_0_Mux2_0_global= ??(1); 20 | int simple_raw_2_2_stateful_alu_0_0_Opt_0_global= ??(1); 21 | int simple_raw_2_2_stateful_alu_0_0_const_0_global= ??(2); 22 | int simple_raw_2_2_stateful_alu_0_0_output_mux_global= ??(1); 23 | int simple_raw_2_2_stateless_alu_1_0_mux1_ctrl= ??(1); 24 | int simple_raw_2_2_stateless_alu_1_0_mux2_ctrl= ??(1); 25 | int simple_raw_2_2_stateless_alu_1_0_mux3_ctrl= ??(1); 26 | int simple_raw_2_2_stateless_alu_1_0_immediate= ??(2); 27 | int simple_raw_2_2_stateless_alu_1_0_opcode= ??(5); 28 | int 
simple_raw_2_2_stateless_alu_1_1_mux1_ctrl= ??(1); 29 | int simple_raw_2_2_stateless_alu_1_1_mux2_ctrl= ??(1); 30 | int simple_raw_2_2_stateless_alu_1_1_mux3_ctrl= ??(1); 31 | int simple_raw_2_2_stateless_alu_1_1_immediate= ??(2); 32 | int simple_raw_2_2_stateless_alu_1_1_opcode= ??(5); 33 | int simple_raw_2_2_stateful_alu_1_0_Mux2_0_global= ??(1); 34 | int simple_raw_2_2_stateful_alu_1_0_Opt_0_global= ??(1); 35 | int simple_raw_2_2_stateful_alu_1_0_const_0_global= ??(2); 36 | int simple_raw_2_2_stateful_alu_1_0_output_mux_global= ??(1); 37 | int simple_raw_2_2_stateful_operand_mux_0_0_0_ctrl= ??(1); 38 | int simple_raw_2_2_stateful_operand_mux_1_0_0_ctrl= ??(1); 39 | int simple_raw_2_2_output_mux_phv_0_0_ctrl= ??(1); 40 | int simple_raw_2_2_output_mux_phv_0_1_ctrl= ??(1); 41 | int simple_raw_2_2_output_mux_phv_1_0_ctrl= ??(1); 42 | int simple_raw_2_2_output_mux_phv_1_1_ctrl= ??(1); 43 | int simple_raw_2_2_salu_config_0_0= ??(1); 44 | int simple_raw_2_2_salu_config_1_0= ??(1); 45 | 46 | // Definitions of muxes and ALUs of the router 47 | // Operand muxes for each ALU in each stage 48 | // Total of 2 * 2 * 3 2-to-1 muxes 49 | // The 3 is for two stateless operands and one stateful operand. 
50 | 51 | int simple_raw_2_2_stateful_operand_mux_0_0_0(int input0,int input1, int simple_raw_2_2_stateful_operand_mux_0_0_0_ctrl_local) { 52 | int mux_ctrl = simple_raw_2_2_stateful_operand_mux_0_0_0_ctrl_local; 53 | if (mux_ctrl == 0) { 54 | return input0; 55 | } 56 | 57 | 58 | else { return input1; } 59 | } 60 | int simple_raw_2_2_stateful_operand_mux_1_0_0(int input0,int input1, int simple_raw_2_2_stateful_operand_mux_1_0_0_ctrl_local) { 61 | int mux_ctrl = simple_raw_2_2_stateful_operand_mux_1_0_0_ctrl_local; 62 | if (mux_ctrl == 0) { 63 | return input0; 64 | } 65 | 66 | 67 | else { return input1; } 68 | } 69 | 70 | 71 | // Output mux for each PHV container 72 | // Allows the container to be written from either its own stateless ALU or any stateful ALU 73 | 74 | int simple_raw_2_2_output_mux_phv_0_0(int input0,int input1, int simple_raw_2_2_output_mux_phv_0_0_ctrl_local) { 75 | int mux_ctrl = simple_raw_2_2_output_mux_phv_0_0_ctrl_local; 76 | if (mux_ctrl == 0) { 77 | return input0; 78 | } 79 | 80 | 81 | else { return input1; } 82 | } 83 | int simple_raw_2_2_output_mux_phv_0_1(int input0,int input1, int simple_raw_2_2_output_mux_phv_0_1_ctrl_local) { 84 | int mux_ctrl = simple_raw_2_2_output_mux_phv_0_1_ctrl_local; 85 | if (mux_ctrl == 0) { 86 | return input0; 87 | } 88 | 89 | 90 | else { return input1; } 91 | } 92 | int simple_raw_2_2_output_mux_phv_1_0(int input0,int input1, int simple_raw_2_2_output_mux_phv_1_0_ctrl_local) { 93 | int mux_ctrl = simple_raw_2_2_output_mux_phv_1_0_ctrl_local; 94 | if (mux_ctrl == 0) { 95 | return input0; 96 | } 97 | 98 | 99 | else { return input1; } 100 | } 101 | int simple_raw_2_2_output_mux_phv_1_1(int input0,int input1, int simple_raw_2_2_output_mux_phv_1_1_ctrl_local) { 102 | int mux_ctrl = simple_raw_2_2_output_mux_phv_1_1_ctrl_local; 103 | if (mux_ctrl == 0) { 104 | return input0; 105 | } 106 | 107 | 108 | else { return input1; } 109 | } 110 | 111 | 112 | // Definition for ALUs 113 | 114 | 115 | 116 | 117 | int 
simple_raw_2_2_stateless_alu_0_0_mux1(int input0,int input1, int simple_raw_2_2_stateless_alu_0_0_mux1_ctrl_local) { 118 | int mux_ctrl = simple_raw_2_2_stateless_alu_0_0_mux1_ctrl_local; 119 | if (mux_ctrl == 0) { 120 | return input0; 121 | } 122 | 123 | 124 | else { return input1; } 125 | }int simple_raw_2_2_stateless_alu_0_0_mux2(int input0,int input1, int simple_raw_2_2_stateless_alu_0_0_mux2_ctrl_local) { 126 | int mux_ctrl = simple_raw_2_2_stateless_alu_0_0_mux2_ctrl_local; 127 | if (mux_ctrl == 0) { 128 | return input0; 129 | } 130 | 131 | 132 | else { return input1; } 133 | }int simple_raw_2_2_stateless_alu_0_0_mux3(int input0,int input1, int simple_raw_2_2_stateless_alu_0_0_mux3_ctrl_local) { 134 | int mux_ctrl = simple_raw_2_2_stateless_alu_0_0_mux3_ctrl_local; 135 | if (mux_ctrl == 0) { 136 | return input0; 137 | } 138 | 139 | 140 | else { return input1; } 141 | }int simple_raw_2_2_stateless_alu_0_0(int input0,int input1,int opcode_hole_local,int immediate_operand_hole_local, int mux1_ctrl_hole_local, int mux2_ctrl_hole_local, int mux3_ctrl_hole_local ){ 142 | int opcode = opcode_hole_local; 143 | int immediate_operand = constant_vector[immediate_operand_hole_local]; 144 | int pkt_0 = simple_raw_2_2_stateless_alu_0_0_mux1(input0,input1,mux1_ctrl_hole_local); 145 | int pkt_1 = simple_raw_2_2_stateless_alu_0_0_mux2(input0,input1,mux2_ctrl_hole_local); 146 | int pkt_2 = simple_raw_2_2_stateless_alu_0_0_mux3(input0,input1,mux3_ctrl_hole_local); 147 | if (opcode==0) { 148 | return immediate_operand; 149 | } 150 | else if (opcode==1) { 151 | return pkt_0+pkt_1; 152 | } 153 | else if (opcode==2) { 154 | return pkt_0+immediate_operand; 155 | } 156 | else if (opcode==3) { 157 | return pkt_0-pkt_1; 158 | } 159 | else if (opcode==4) { 160 | return pkt_0-immediate_operand; 161 | } 162 | else if (opcode==5) { 163 | return immediate_operand-pkt_0; 164 | } 165 | else if (opcode==6) { 166 | return pkt_0!=pkt_1; 167 | } 168 | else if (opcode==7) { 169 | return 
(pkt_0!=immediate_operand); 170 | } 171 | else if (opcode==8) { 172 | return (pkt_0==pkt_1); 173 | } 174 | else if (opcode==9) { 175 | return (pkt_0==immediate_operand); 176 | } 177 | else if (opcode==10) { 178 | return (pkt_0>=pkt_1); 179 | } 180 | else if (opcode==11) { 181 | return (pkt_0>=immediate_operand); 182 | } 183 | else if (opcode==12) { 184 | return (pkt_0=pkt_1); 289 | } 290 | else if (opcode==11) { 291 | return (pkt_0>=immediate_operand); 292 | } 293 | else if (opcode==12) { 294 | return (pkt_0=pkt_1); 428 | } 429 | else if (opcode==11) { 430 | return (pkt_0>=immediate_operand); 431 | } 432 | else if (opcode==12) { 433 | return (pkt_0=pkt_1); 538 | } 539 | else if (opcode==11) { 540 | return (pkt_0>=immediate_operand); 541 | } 542 | else if (opcode==12) { 543 | return (pkt_0 b) 29 | 30 | self.assertEqual(z3.Not(a > b), 31 | z3_utils.negated_body(formula)) 32 | 33 | def test_variable_order(self): 34 | # Little more complex case, to see whether ordering of variables are 35 | # kept right. 
36 | a, b, c = z3.Ints('a b c') 37 | formula = z3.ForAll([c, a, b], z3.And(b > a, a > c)) 38 | 39 | self.assertEqual(z3.Not(z3.And(b > a, a > c)), 40 | z3_utils.negated_body(formula)) 41 | 42 | def test_raise_assert_non_quantifiers(self): 43 | a, b = z3.Bools('a b') 44 | formula = z3.Implies(a, b) 45 | with self.assertRaisesRegex(AssertionError, 'not a quantifier'): 46 | z3_utils.negated_body(formula) 47 | 48 | 49 | class GenerateCounterexamplesTest(unittest.TestCase): 50 | def test_successs_with_mock(self): 51 | x = z3.Int('pkt_0_0_0_0') 52 | simple_formula = z3.ForAll([x], z3.And(x > 3, x < 2)) 53 | pkt_fields, state_vars = z3_utils.generate_counterexamples( 54 | simple_formula) 55 | self.assertTrue('pkt_0' in pkt_fields) 56 | self.assertDictEqual(state_vars, {}) 57 | 58 | def test_unsat_formula(self): 59 | x = z3.Int('x') 60 | equality = z3.ForAll([x], x == x) 61 | pkt_fields, state_vars = z3_utils.generate_counterexamples(equality) 62 | self.assertDictEqual(pkt_fields, {}) 63 | self.assertDictEqual(state_vars, {}) 64 | 65 | def test_state_group_with_alphabets(self): 66 | x = z3.Int('state_group_1_state_0_b_b_0') 67 | simple_formula = z3.ForAll([x], z3.And(x > 3, x < 2)) 68 | pkt_fields, state_vars = z3_utils.generate_counterexamples( 69 | simple_formula) 70 | self.assertDictEqual(pkt_fields, {}) 71 | self.assertTrue('state_group_1_state_0' in state_vars) 72 | 73 | class GetZ3FormulaTest(unittest.TestCase): 74 | def test_conversion(self): 75 | # Smoke test for bool-to-int and int-to-bool conversion 76 | # doesn't do anything too substantial, but checks that 77 | # the string is being parsed and converted into z3. 
78 | z3_utils.get_z3_formula("""0 = CONST BOOL 1 79 | 1 = CONST BOOL 0 80 | 2 = LT BOOL 0 1 81 | 3 = TIMES INT 0 1 82 | 4 = PLUS INT 0 1 83 | 5 = S INT foobar 2""", 10) 84 | z3_utils.get_z3_formula("""0 = CONST INT 1 85 | 1 = CONST INT 0 86 | 2 = AND BOOL 0 1 87 | 3 = OR BOOL 0 1 88 | 4 = XOR BOOL 0 1 89 | 5 = S INT foobar 2""", 10) 90 | 91 | def test_hello(self): 92 | base_path = Path(__file__).parent 93 | sketch_ir = Path(base_path / './data/hello.dag').resolve().read_text() 94 | formula_from_ir = z3_utils.get_z3_formula(sketch_ir, input_bits=2) 95 | ir_pkt_fields, ir_state_vars = z3_utils.generate_counterexamples( 96 | formula_from_ir) 97 | 98 | formula_from_smt = z3_utils.parse_smt2_file( 99 | str(Path(base_path / './data/hello.smt').resolve())) 100 | smt_pkt_fields, smt_state_vars = z3_utils.generate_counterexamples( 101 | formula_from_smt) 102 | 103 | self.assertDictEqual(ir_pkt_fields, smt_pkt_fields) 104 | self.assertDictEqual(ir_state_vars, smt_state_vars) 105 | 106 | def test_sampling(self): 107 | base_path = Path(__file__).parent 108 | sketch_ir = Path( 109 | base_path / './data/sampling.dag').resolve().read_text() 110 | formula_from_ir = z3_utils.get_z3_formula(sketch_ir, input_bits=2) 111 | ir_pkt_fields, ir_state_vars = z3_utils.generate_counterexamples( 112 | formula_from_ir) 113 | 114 | formula_from_smt = z3_utils.parse_smt2_file( 115 | str(Path(base_path / './data/sampling.smt').resolve())) 116 | smt_pkt_fields, smt_state_vars = z3_utils.generate_counterexamples( 117 | formula_from_smt) 118 | 119 | self.assertDictEqual(ir_pkt_fields, smt_pkt_fields) 120 | self.assertDictEqual(ir_state_vars, smt_state_vars) 121 | 122 | 123 | class SimpleCheckTest(unittest.TestCase): 124 | def test_success(self): 125 | a = z3.Int('a') 126 | input_formula = z3.ForAll([a], z3.Implies(a > 0, a + 1 > a)) 127 | with patch('z3.parse_smt2_file', return_value=[input_formula]): 128 | self.assertTrue(z3_utils.simple_check('foobar')) 129 | 130 | 131 | if __name__ == '__main__': 
132 | unittest.main() 133 | --------------------------------------------------------------------------------