├── requirements.txt
├── .gitignore
├── hwtypes
│   ├── __init__.py
│   ├── compatibility.py
│   ├── smt_utils.py
│   ├── adt_util.py
│   ├── util.py
│   ├── smt_int.py
│   ├── modifiers.py
│   ├── fp_vector_abc.py
│   ├── smt_fp_vector.py
│   ├── bit_vector_util.py
│   ├── adt.py
│   ├── bit_vector_abc.py
│   ├── fp_vector.py
│   ├── z3_bit_vector.py
│   └── bit_vector.py
├── tests
│   ├── test_hash.py
│   ├── test_div.py
│   ├── test_overflow.py
│   ├── test_concat.py
│   ├── test_adc.py
│   ├── test_util.py
│   ├── test_ext.py
│   ├── test_smt_utils.py
│   ├── test_smt_bit.py
│   ├── test_adt_visitor.py
│   ├── test_smt_bv.py
│   ├── test_ite.py
│   ├── test_bit.py
│   ├── test_smt_fp.py
│   ├── test_meta.py
│   ├── test_uint.py
│   ├── test_bv_protocol.py
│   ├── test_poly.py
│   ├── test_smt_int.py
│   ├── test_sint.py
│   ├── test_optypes.py
│   ├── test_modifiers.py
│   ├── test_bv.py
│   ├── test_rebind.py
│   ├── test_fp.py
│   └── test_adt.py
├── README.md
├── .github
│   └── workflows
│       ├── linux-test.yml
│       └── deploy.yml
├── setup.py
└── LICENSE.txt

/requirements.txt:
--------------------------------------------------------------------------------
pysmt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
build
# setuptools names the metadata directory after the distribution
# ('hwtypes'); the glob also still covers the legacy 'bit_vector' name.
*.egg-info
__pycache__
dist
.cache

--------------------------------------------------------------------------------
/hwtypes/__init__.py:
--------------------------------------------------------------------------------
from .bit_vector import *
from .bit_vector_abc import *
from .adt import *
from .smt_bit_vector import *
from .z3_bit_vector import *
from .fp_vector_abc import *
from .fp_vector import *
from .smt_fp_vector import *
from .modifiers import *
from .smt_int import *

--------------------------------------------------------------------------------
/hwtypes/compatibility.py:
--------------------------------------------------------------------------------
"""Integer/string type aliases kept so existing imports keep working.

The package is Python 3 only (its sources use f-strings), so the former
Python 2 branch was unreachable and has been dropped; every name it defined
on Python 3 is still defined here with the same value.
"""
import builtins  # noqa: F401  -- re-exported as `compatibility.builtins`

__all__ = ['IntegerTypes', 'StringTypes']

IntegerTypes = (int,)
StringTypes = (str,)
long = int

--------------------------------------------------------------------------------
/tests/test_hash.py:
--------------------------------------------------------------------------------
from hwtypes import Bit, BitVector, UIntVector, SIntVector


def test_hash():
    # Zero-valued instances of different hwtypes must behave as distinct
    # dict keys, each mapping back to its own value.
    x = {
        Bit(0): 0,
        BitVector[3](0): 1,
        UIntVector[3](0): 2,
        SIntVector[3](0): 3,
    }
    assert x[Bit(0)] == 0
    assert x[BitVector[3](0)] == 1
    assert x[UIntVector[3](0)] == 2
    assert x[SIntVector[3](0)] == 3

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![Linux Test](https://github.com/leonardt/hwtypes/actions/workflows/linux-test.yml/badge.svg)](https://github.com/leonardt/hwtypes/actions/workflows/linux-test.yml)


# Install

`hwtypes` depends on `gmpy2`, which needs the GMP, MPFR and MPC libraries.
Install them with your system package manager first, then install `hwtypes`
with pip.

## Debian
```
apt install libgmp-dev libmpfr-dev libmpc-dev
pip install hwtypes
```

## macOS
```
brew install gmp mpfr libmpc
pip install hwtypes
```

## CentOS
```
yum install libmpc-devel mpfr-devel gmp-devel
pip install hwtypes
```

--------------------------------------------------------------------------------
/tests/test_div.py:
--------------------------------------------------------------------------------
from hwtypes import BitVector as BV
import random
import pytest


# Always include one division by zero; the rest are random n-bit operands.
div_params = [(32, 0xdeadbeaf, 0)]

for _ in range(32):
    n = random.randint(1, 32)
    a = random.randint(0, (1 << n) - 1)
    b = random.randint(0, (1 << n) - 1)
    div_params.append((n, a, b))

@pytest.mark.parametrize("n,a,b", div_params)
def test_div(n, a, b):
    if b != 0:
        res = a // b
    else:
        # Unsigned division by zero yields all ones (SMT-LIB2 bvudiv).
        res = (1 << n) - 1

    assert res == (BV[n](a) // BV[n](b)).as_uint()

--------------------------------------------------------------------------------
/tests/test_overflow.py:
--------------------------------------------------------------------------------
from hwtypes import SIntVector as SV
from hwtypes import overflow
import random
import pytest


ovfl_params = []

for _ in range(32):
    n = random.randint(1, 32)
    a = SV[n](random.randint(-(1 << (n - 1)), (1 << (n - 1)) - 1))
    b = SV[n](random.randint(-(1 << (n - 1)), (1 << (n - 1)) - 1))
    ovfl_params.append((a, b))


@pytest.mark.parametrize("a,b", ovfl_params)
def test_ovfl(a, b):
    # Signed overflow: both operands share a sign that the sum does not.
    expected = ((a < 0) and (b < 0) and ((a + b) >= 0)) or \
        ((a >= 0) and (b >= 0) and ((a + b) < 0))
    assert overflow(a, b, a + b) == expected

--------------------------------------------------------------------------------
/tests/test_concat.py:
--------------------------------------------------------------------------------
from hwtypes import BitVector
import random

NTESTS = 10
MAX_BITS = 128

def test_concat_const():
    a = BitVector[4](4)
    b = BitVector[4](1)
    c = a.concat(b)
    # Bits are listed LSB first: `a` supplies the low half, `b` the high half.
    expected = BitVector[8]([0, 0, 1, 0, 1, 0, 0, 0])
    assert expected == c

def test_concat_random():
    for _ in range(NTESTS):
        n1 = random.randint(1, MAX_BITS)
        n2 = random.randint(1, MAX_BITS)
        a = BitVector.random(n1)
        b = BitVector.random(n2)
        c = a.concat(b)
        assert c.size == a.size + b.size
        assert c == BitVector[n1 + n2](a.bits() + b.bits())
        assert c.binary_string() == b.binary_string() + a.binary_string()

--------------------------------------------------------------------------------
/tests/test_adc.py:
--------------------------------------------------------------------------------
from hwtypes import BitVector as BV
import random
import pytest


adc_params = []

for _ in range(32):
    n = random.randint(1, 32)
    a = BV[n](random.randint(0, (1 << n) - 1))
    b = BV[n](random.randint(0, (1 << n) - 1))
    c = BV[1](random.randint(0, 1))
    adc_params.append((a, b, c))


def _check_adc(a, b, c):
    """Check ``a.adc(b, c)`` against a zero-extended reference addition.

    The result is compared at the operand width; the carry is the top bit of
    the same addition carried out one bit wider.
    """
    res, carry = a.adc(b, c)
    assert res == a + b + c.zext(a.size - c.size)
    assert carry == (a.zext(1) + b.zext(1) + c.zext(1 + a.size - c.size))[-1]


@pytest.mark.parametrize("a,b,c", adc_params)
def test_adc(a, b, c):
    _check_adc(a, b, c)


def test_adc1():
    _check_adc(BV[16](27734), BV[16](13207), BV[1](0))

--------------------------------------------------------------------------------
/.github/workflows/linux-test.yml:
--------------------------------------------------------------------------------
name: Linux Test

on:
  push:
  pull_request:

jobs:
  test:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.8', '3.10']

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        sudo apt install -y libgmp-dev libmpfr-dev libmpc-dev
        git clone https://github.com/aleaxit/gmpy
        cd gmpy
        python setup.py install
        cd ..
        pip install -r requirements.txt
        pip install pytest-cov
        pip install coveralls
        pip install -e .
    - name: Test with pytest
      run: |
        py.test --cov=hwtypes tests/
    - name: Coverage
      run: |
        coveralls --service=github
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
import pytest

from hwtypes.util import TypedProperty

def test_typedproperty():
    class A:
        _x : int
        def __init__(self, x):
            self.x = x

        @TypedProperty(int)
        def x(self):
            return self._x

        @x.setter
        def x(self, val):
            self._x = val

        @TypedProperty(str)
        def foo(self):
            return 'foo'

        @property
        def bar(self):
            return 'bar'

    # Class-level access to a TypedProperty yields its declared type.
    assert A.x is int
    assert A.foo is str

    a = A(0)

    assert a.x == 0
    a.x = 1
    assert a.x == 1
    assert a.foo == 'foo'

    with pytest.raises(TypeError):
        a.x = 'a'

    with pytest.raises(AttributeError):
        class B(A):
            @A.foo.setter
            def foo(self, value):
                pass

    class C(A):
        @A.bar.setter
        def bar(self, value):
            pass

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

# Read explicitly as UTF-8 so packaging does not depend on the build
# machine's locale encoding.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name='hwtypes',
    url='https://github.com/leonardt/hwtypes',
    author='Leonard Truong',
    author_email='lenny@cs.stanford.edu',
    version='1.4.7',
    description='Python implementations of fixed size hardware types (Bit, '
                'BitVector, UInt, SInt, ...) based on the SMT-LIB2 semantics',
    scripts=[],
    packages=[
        "hwtypes",
    ],
    install_requires=['pysmt', 'z3-solver', 'gmpy2'],
    long_description=long_description,
    long_description_content_type="text/markdown",
    # The package sources use f-strings, which require Python 3.6+.
    python_requires='>=3.6',
)

--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
name: Linux Deploy

on:
  push:
    tags:
      - v*

jobs:
  deploy:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.8']

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        sudo apt install -y libgmp-dev libmpfr-dev libmpc-dev
        git clone https://github.com/aleaxit/gmpy
        cd gmpy
        python setup.py install
        cd ..
        pip install -r requirements.txt
        pip install pytest
        pip install -e .
    - name: Test with pytest
      run: |
        py.test tests/
    - name: Install deploy packages
      shell: bash -l {0}
      run: |
        pip install twine
    # NOTE(review): PyPI no longer accepts account passwords for uploads; this
    # step presumably needs an API token with `-u __token__` -- confirm the
    # contents of the PYPI_PASSWORD secret.
    - name: Upload to PyPI
      shell: bash -l {0}
      run: |
        python setup.py sdist build
        twine upload dist/* -u leonardt -p $PYPI_PASSWORD
      env:
        PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }}

--------------------------------------------------------------------------------
/tests/test_ext.py:
--------------------------------------------------------------------------------
from hwtypes import BitVector as BV
from hwtypes import SIntVector as SV
import random
import pytest


zext_params = []

for _ in range(64):
    # Parenthesized: `1 << 32 - 1` parses as `1 << 31`, which never samples
    # the full 32-bit unsigned range.
    n = random.randint(0, (1 << 32) - 1)
    # At least 1 bit, so n == 0 never asks for a zero-width vector.
    num_bits = random.randint(max(n.bit_length(), 1), 32)
    ext_amount = 32 - num_bits
    zext_params.append((n, num_bits, ext_amount))


@pytest.mark.parametrize("n,num_bits,ext_amount", zext_params)
def test_zext(n, num_bits, ext_amount):
    a = BV[num_bits](n)
    assert num_bits + ext_amount == a.zext(ext_amount).num_bits
    assert n == a.zext(ext_amount).as_uint()
    assert BV[num_bits + ext_amount](n) == a.zext(ext_amount)


sext_params = []

for _ in range(64):
    n = random.randint(-(2 ** 15), (2 ** 15) - 1)
    # One extra bit for the sign.
    num_bits = random.randint(n.bit_length() + 1, 17)
    ext_amount = 32 - num_bits
    sext_params.append((n, num_bits, ext_amount))


@pytest.mark.parametrize("n,num_bits,ext_amount", sext_params)
def test_sext(n, num_bits, ext_amount):
    a = SV[num_bits](n)
    assert num_bits + ext_amount == a.sext(ext_amount).num_bits
    assert SV[num_bits + ext_amount](n).bits() == a.sext(ext_amount).bits()
    assert SV[num_bits + ext_amount](n) == a.sext(ext_amount)

--------------------------------------------------------------------------------
/tests/test_smt_utils.py:
--------------------------------------------------------------------------------
from hwtypes import SMTBit
from hwtypes import smt_utils as utils

def _var_idx(v):
    """Return the numeric suffix of an SMTBit's generated name (after 'x_')."""
    return v._name[2:]

def test_fc():
    x = SMTBit(prefix='x')
    y = SMTBit(prefix='y')
    z = SMTBit(prefix='z')
    XI = _var_idx(x)
    YI = _var_idx(y)
    ZI = _var_idx(z)
    f = utils.Implies(
        utils.And([
            x,
            utils.Or([
                ~y,
                ~z,
            ]),
        ]),
        SMTBit(1)
    )

    f_ht = f.to_hwtypes()
    assert isinstance(f_ht, SMTBit)
    f_str = f.serialize()
    assert f_str == \
'''\
Implies(
| And(
| | x_XI,
| | Or(
| | | (! y_YI),
| | | (! z_ZI)
| | )
| ),
| True
)'''.replace("XI", XI).replace("YI", YI).replace("ZI", ZI)
    # NOTE(review): whitespace in this expectation was reconstructed from a
    # whitespace-collapsed dump; the prefix widths below are derived from
    # indent=' ' -- confirm against the repository copy.
    f_str = f.serialize(line_prefix='*****', indent=' ')
    assert f_str == \
'''\
*****Implies(
***** And(
*****  x_XI,
*****  Or(
*****   (! y_YI),
*****   (! z_ZI)
*****  )
***** ),
***** True
*****)'''.replace("XI", XI).replace("YI", YI).replace("ZI", ZI)

def test_0_len():
    # Empty And/Or reduce to their identity constants.
    f = utils.And([])
    assert f.to_hwtypes().value.constant_value() is True
    f = utils.Or([])
    assert f.to_hwtypes().value.constant_value() is False

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright 2018 Leonard Truong

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software without
specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/tests/test_smt_bit.py:
--------------------------------------------------------------------------------
import pytest
import operator
from hwtypes import SMTBit, SMTBitVector, z3Bit

@pytest.mark.parametrize("op", [
    operator.and_,
    operator.or_,
    operator.xor,
    operator.eq,
    operator.ne,
])
@pytest.mark.parametrize("Bit", [SMTBit, z3Bit])
def test_bin_op(op, Bit):
    assert isinstance(op(Bit(), Bit()), Bit)

@pytest.mark.parametrize("op", [
    operator.inv,
])
@pytest.mark.parametrize("Bit", [SMTBit, z3Bit])
def test_unary_op(op, Bit):
    assert isinstance(op(Bit()), Bit)

def test_substitute():
    a0 = SMTBit()
    a1 = SMTBit()
    b0 = SMTBit()
    b1 = SMTBit()
    expr0 = a0 | b0
    expr1 = expr0.substitute((a0, a1), (b0, b1))
    assert expr1.value is (a1 | b1).value


def test_ite_tuple():
    a = SMTBitVector[8](), SMTBit(), SMTBitVector[4]()
    b = SMTBitVector[8](), SMTBit(), SMTBitVector[4]()
    c = SMTBit()

    res = c.ite(a, b)
    assert isinstance(res, tuple)
    assert len(res) == 3
    assert isinstance(res[0], SMTBitVector[8])
    assert isinstance(res[1], SMTBit)
    assert isinstance(res[2], SMTBitVector[4])


def test_ite_fail():
    p = SMTBit()
    t = SMTBit()
    f = SMTBitVector[1]()
    # Branches of different types must be rejected.
    with pytest.raises(TypeError):
        p.ite(t, f)


def test_bool():
    # A symbolic bit has no concrete truth value.
    b = SMTBit()
    with pytest.raises(TypeError):
        bool(b)

--------------------------------------------------------------------------------
/tests/test_adt_visitor.py:
--------------------------------------------------------------------------------
from hwtypes.adt import Tuple, Product, Enum, TaggedUnion, Sum
from hwtypes.adt_util import ADTVisitor, ADTInstVisitor


class FlattenT(ADTVisitor):
    """Collect every leaf type (and Enum type) reached while visiting an ADT type."""
    def __init__(self):
        self.leaves = []

    def visit_leaf(self, n):
        self.leaves.append(n)

    def visit_Enum(self, n):
        self.leaves.append(n)


class FlattenI(ADTInstVisitor):
    """Collect every leaf value (and raw Enum value) reached while visiting an ADT instance."""
    def __init__(self):
        self.leaves = []

    def visit_leaf(self, n):
        self.leaves.append(n)

    def visit_Enum(self, n):
        self.leaves.append(n._value_)


class T1: pass
class T2: pass
class T3: pass
class T4: pass
class T5: pass

class Root(Product):
    leaf = T1

    class Branch(TaggedUnion):
        s_child = Sum[T2, T3]
        t_child = Tuple[T4, T5]

    class Tag(Enum):
        tag_a = Enum.Auto()
        tag_i = 1


def test_visit():
    flattener = FlattenT()
    flattener.visit(Root)

    assert set(flattener.leaves) == {T1, T2, T3, T4, T5, Root.Tag}

def test_visit_i():
    t1 = T1()
    t3 = T3()

    root = Root(
        leaf=t1,
        Branch=Root.Branch(s_child=Sum[T2, T3](t3)),
        Tag=Root.Tag.tag_i
    )

    flattener = FlattenI()
    flattener.visit(root)

    # Only the populated alternatives are visited.
    assert set(flattener.leaves) == {t1, t3, 1}

--------------------------------------------------------------------------------
/tests/test_smt_bv.py:
--------------------------------------------------------------------------------
import pytest
import operator
from hwtypes import SMTBitVector, z3BitVector, SMTBit


WIDTHS = [1, 2, 4, 8]
@pytest.mark.parametrize("width", WIDTHS)
@pytest.mark.parametrize("op", [
    operator.add,
    operator.mul,
    operator.sub,
    operator.floordiv,
    operator.mod,
    operator.and_,
    operator.or_,
    operator.xor,
    operator.lshift,
    operator.rshift,
])
@pytest.mark.parametrize("BV", [SMTBitVector, z3BitVector])
def test_bin_op(width, op, BV):
    assert isinstance(op(BV[width](), BV[width]()), BV[width])

@pytest.mark.parametrize("width", WIDTHS)
@pytest.mark.parametrize("op", [
    operator.neg,
    operator.inv,
])
@pytest.mark.parametrize("BV", [SMTBitVector, z3BitVector])
def test_unary_op(width, op, BV):
    assert isinstance(op(BV[width]()), BV[width])

@pytest.mark.parametrize("width", WIDTHS)
@pytest.mark.parametrize("op", [
    operator.eq,
    operator.ne,
    operator.lt,
    operator.le,
    operator.gt,
    operator.ge,
])
@pytest.mark.parametrize("BV", [SMTBitVector, z3BitVector])
def test_bit_op(width, op, BV):
    assert isinstance(op(BV[width](), BV[width]()), BV.get_family().Bit)

def test_substitute():
    a0 = SMTBitVector[3]()
    a1 = SMTBitVector[3]()
    b0 = SMTBitVector[3]()
    b1 = SMTBitVector[3]()
    expr0 = a0 + b0 * a0
    expr1 = expr0.substitute((a0, a1), (b0, b1))
    assert expr1.value is (a1 + b1 * a1).value

--------------------------------------------------------------------------------
/tests/test_ite.py:
--------------------------------------------------------------------------------
import random

import pytest

from hwtypes import BitVector
from hwtypes import Bit

ite_params = []
N_TESTS = 32

def rand_bit():
    return Bit(random.randint(0, 1))

# Scalar cases; the selector is a 1-bit BitVector here and a Bit below.
for _ in range(N_TESTS):
    n = random.randint(1, 32)
    a = BitVector.random(n)
    b = BitVector.random(n)
    c = BitVector.random(1)
    ite_params.append((a, b, c))


for _ in range(N_TESTS):
    n = random.randint(1, 32)
    a = BitVector.random(n)
    b = BitVector.random(n)
    c = rand_bit()
    ite_params.append((a, b, c))

@pytest.mark.parametrize("a, b, c", ite_params)
def test_ite(a, b, c):
    res = c.ite(a, b)
    assert res == (a if int(c) else b)

def gen_rand_val(t_idx):
    """Return a tuple with one random value per width in *t_idx*.

    A width of 0 produces a Bit; a width k > 0 produces a k-bit BitVector.
    """
    val = []
    for k in t_idx:
        if k == 0:
            val.append(rand_bit())
        else:
            val.append(BitVector.random(k))
    return tuple(val)

tuple_params = []
for _ in range(N_TESTS):
    length = random.randint(1, 8)
    t_idx = [random.randint(0, 4) for _ in range(length)]
    a = gen_rand_val(t_idx)
    b = gen_rand_val(t_idx)
    c = BitVector.random(1)
    tuple_params.append((a, b, c))

for _ in range(N_TESTS):
    length = random.randint(1, 8)
    t_idx = [random.randint(0, 4) for _ in range(length)]
    a = gen_rand_val(t_idx)
    b = gen_rand_val(t_idx)
    c = rand_bit()
    tuple_params.append((a, b, c))

# Must not be named `test_ite`: a second definition with that name would
# replace the scalar test above, so pytest would never run `ite_params`.
@pytest.mark.parametrize("a, b, c", tuple_params)
def test_ite_tuple(a, b, c):
    res = c.ite(a, b)
    assert res == (a if int(c) else b)

--------------------------------------------------------------------------------
/tests/test_bit.py:
--------------------------------------------------------------------------------
import pytest
import operator
from hwtypes import Bit


class TrueType:
    @staticmethod
    def __bool__():
        return True


class FalseType:
    @staticmethod
    def __bool__():
        return False

def test_illegal():
    with pytest.raises(TypeError):
        Bit(object())

    with pytest.raises(ValueError):
        Bit(2)

@pytest.mark.parametrize("value, arg", [
    (False, False),
    (False, 0),
    (False, FalseType()),
    (False, Bit(0)),
    (True, True),
    (True, 1),
    (True, TrueType()),
    (True, Bit(1)),
])
def test_value(value, arg):
    assert bool(Bit(arg)) == value


@pytest.mark.parametrize("op, reference", [
    (operator.invert, lambda x: not x),
])
@pytest.mark.parametrize("v", [False, True])
def test_operator_bit1(op, reference, v):
    assert reference(v) == bool(op(Bit(v)))



@pytest.mark.parametrize("op, reference", [
    (operator.and_, lambda x, y: x & y),
    (operator.or_, lambda x, y: x | y),
    (operator.xor, lambda x, y: x ^ y),
    (operator.eq, lambda x, y: x == y),
    (operator.ne, lambda x, y: x != y),
])
@pytest.mark.parametrize("v1", [False, True])
@pytest.mark.parametrize("v2", [False, True])
def test_operator_bit2(op, reference, v1, v2):
    assert reference(v1, v2) == bool(op(Bit(v1), Bit(v2)))


def test_random():
    assert Bit.random() in [0, 1]

--------------------------------------------------------------------------------
/tests/test_smt_fp.py:
--------------------------------------------------------------------------------
import operator
import pytest

from hwtypes import SMTFPVector, RoundingMode, SMTBit

# Shared by every test below so the parameter grid stays in sync.
ROUNDING_MODES = [
    RoundingMode.RNE,
    RoundingMode.RNA,
    RoundingMode.RTP,
    RoundingMode.RTN,
    RoundingMode.RTZ,
]

@pytest.mark.parametrize('op', [operator.neg, operator.abs, lambda x: abs(x).fp_sqrt()])
@pytest.mark.parametrize('eb', [4, 16])
@pytest.mark.parametrize('mb', [4, 16])
@pytest.mark.parametrize('mode', ROUNDING_MODES)
@pytest.mark.parametrize('ieee', [False, True])
def test_unary_op(op, eb, mb, mode, ieee):
    T = SMTFPVector[eb, mb, mode, ieee]
    x = T()
    z = op(x)
    assert isinstance(z, T)


@pytest.mark.parametrize('op', [operator.add, operator.sub, operator.mul, operator.truediv])
@pytest.mark.parametrize('eb', [4, 16])
@pytest.mark.parametrize('mb', [4, 16])
@pytest.mark.parametrize('mode', ROUNDING_MODES)
@pytest.mark.parametrize('ieee', [False, True])
def test_bin_ops(op, eb, mb, mode, ieee):
    T = SMTFPVector[eb, mb, mode, ieee]
    x = T()
    y = T()
    z = op(x, y)
    assert isinstance(z, T)

@pytest.mark.parametrize('op', [operator.eq, operator.ne, operator.lt, operator.le, operator.gt, operator.ge])
@pytest.mark.parametrize('eb', [4, 16])
@pytest.mark.parametrize('mb', [4, 16])
@pytest.mark.parametrize('mode', ROUNDING_MODES)
@pytest.mark.parametrize('ieee', [False, True])
def test_bool_ops(op, eb, mb, mode, ieee):
    T = SMTFPVector[eb, mb, mode, ieee]
    x = T()
    y = T()
    z = op(x, y)
    assert isinstance(z, SMTBit)

--------------------------------------------------------------------------------
/tests/test_meta.py:
--------------------------------------------------------------------------------
import pytest
from hwtypes import AbstractBitVector as ABV
from hwtypes import BitVector as BV

def test_subclass():
    class A(ABV): pass
    class B(ABV): pass
    class C(A): pass
    class D(A, B): pass
    class E(ABV[8]): pass

    assert issubclass(A, ABV)
    assert issubclass(A[8], ABV)
    assert issubclass(A[8], ABV[8])
    assert not issubclass(A, ABV[8])
    assert not issubclass(A[7], ABV[8])

    assert issubclass(C[8], A)
    assert issubclass(C[8], ABV)
    assert issubclass(C[8], A[8])
    assert issubclass(C[8], ABV[8])

    assert issubclass(D[8], A)
    assert issubclass(D[8], B)
    assert issubclass(D[8], ABV)
    assert issubclass(D[8], A[8])
    assert issubclass(D[8], B[8])
    assert issubclass(D[8], ABV[8])


    assert issubclass(E, ABV[8])
    assert issubclass(E, ABV)

    # An already-sized type cannot be re-sized.
    with pytest.raises(TypeError):
        ABV[8][2]

    with pytest.raises(TypeError):
        E[2]

    with pytest.raises(TypeError):
        class F(ABV[8], ABV[7]): pass


def test_size():
    class A(ABV): pass

    assert A.size is None
    assert A.unsized_t is A
    bv = A[8]
    assert bv.size == 8
    assert bv.unsized_t is A

    class B(ABV[8]): pass

    assert B.size == 8
    with pytest.raises(AttributeError):
        B.unsized_t

    with pytest.raises(TypeError):
        B[2]

    class C(A, B): pass
    assert C.size == 8
    with pytest.raises(AttributeError):
        C.unsized_t

    with pytest.raises(TypeError):
        C[6]

def test_instance():
    x = BV[8](0)
    assert isinstance(x, ABV)
    assert isinstance(x, ABV[8])
    assert isinstance(x, BV)
    with pytest.raises(TypeError):
        ABV[4](0)

    class A(BV): pass

    y = A[8](0)
    assert isinstance(y, ABV)
    assert isinstance(y, ABV[8])
    assert isinstance(y, BV)
    assert isinstance(y, BV[8])

--------------------------------------------------------------------------------
/hwtypes/smt_utils.py:
--------------------------------------------------------------------------------
from . import SMTBit
import operator
from functools import partial, reduce
import abc

# Left-fold an iterable of bits with | / & (raises TypeError if it is empty,
# so callers handle the empty case themselves).
_or_reduce = partial(reduce, operator.or_)
_and_reduce = partial(reduce, operator.and_)

class FormulaConstructor(metaclass=abc.ABCMeta):
    """Base class for small boolean-formula builders over SMTBit operands."""

    @abc.abstractmethod
    def serialize(self, line_prefix, indent):
        """Return a multi-line textual rendering, one operand per line."""
        ...

    @abc.abstractmethod
    def to_hwtypes(self):
        """Return the formula lowered to a single hwtypes bit value."""
        ...
def _to_hwtypes(v):
    """Lower ``v`` to an hwtypes value.

    FormulaConstructor nodes are lowered via their ``to_hwtypes`` method;
    any other value (an SMTBit, as enforced by ``_check``) is returned as-is.
    """
    if isinstance(v, FormulaConstructor):
        return v.to_hwtypes()
    return v


def _value_to_str(v, line_prefix, indent):
    """Render a single operand, starting with ``line_prefix``.

    FormulaConstructor operands serialize themselves recursively; any other
    operand is treated as an SMTBit and its underlying term (``v.value``) is
    serialized.
    """
    if isinstance(v, FormulaConstructor):
        return v.serialize(line_prefix, indent)
    return f"{line_prefix}{v.value.serialize()}"


def _op_to_str(vs, opname, line_prefix, indent):
    """Render ``opname(vs...)`` as a multi-line string.

    The operator name and closing parenthesis sit at ``line_prefix``; each
    operand goes on its own line, indented by one extra ``indent`` and
    separated from the next by a comma.
    """
    new_line_prefix = line_prefix + indent
    return "\n".join([
        f"{line_prefix}{opname}(",
        ",\n".join(_value_to_str(v, new_line_prefix, indent) for v in vs),
        f"{line_prefix})",
    ])


def _check(vs):
    """Validate formula operands.

    Raises:
        ValueError: if any element of ``vs`` is neither a FormulaConstructor
            nor an SMTBit.

    ``vs`` is iterated, so callers that keep using it afterwards must pass a
    re-iterable collection (e.g. a list), not a one-shot iterator.
    """
    if not all(isinstance(v, (FormulaConstructor, SMTBit)) for v in vs):
        raise ValueError("Formula Constructor requires SMTBit or other FormulaConstructors")


class And(FormulaConstructor):
    """N-ary conjunction of SMTBits and/or other FormulaConstructors.

    ``values`` may be any iterable of operands (a copy is stored). An empty
    conjunction lowers to ``SMTBit(True)``.
    """

    def __init__(self, values: list):
        # Materialize before validating: checking a one-shot iterator (e.g. a
        # generator) would exhaust it and silently leave self.values empty.
        values = list(values)
        _check(values)
        self.values = values

    def serialize(self, line_prefix="", indent="| "):
        return _op_to_str(self.values, "And", line_prefix, indent)

    def to_hwtypes(self):
        if not self.values:
            # Identity element of conjunction.
            return SMTBit(True)
        return _and_reduce(_to_hwtypes(v) for v in self.values)


class Or(FormulaConstructor):
    """N-ary disjunction of SMTBits and/or other FormulaConstructors.

    ``values`` may be any iterable of operands (a copy is stored). An empty
    disjunction lowers to ``SMTBit(False)``.
    """

    def __init__(self, values: list):
        # Materialize before validating: checking a one-shot iterator (e.g. a
        # generator) would exhaust it and silently leave self.values empty.
        values = list(values)
        _check(values)
        self.values = values

    def serialize(self, line_prefix="", indent="| "):
        return _op_to_str(self.values, "Or", line_prefix, indent)

    def to_hwtypes(self):
        if not self.values:
            # Identity element of disjunction.
            return SMTBit(False)
        return _or_reduce(_to_hwtypes(v) for v in self.values)


class Implies(FormulaConstructor):
    """Implication ``p -> q``, lowered as ``(~p) | q``."""

    def __init__(self, p, q):
        _check([p, q])
        self.p = p
        self.q = q

    def serialize(self, line_prefix="", indent="| "):
        return _op_to_str((self.p, self.q), "Implies", line_prefix, indent)

    def to_hwtypes(self):
        return (~_to_hwtypes(self.p)) | _to_hwtypes(self.q)
-------------------------------------------------------------------------------- /tests/test_uint.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import operator 3 | from hwtypes import UIntVector, Bit 4 | 5 | def unsigned(value, width): 6 | return UIntVector[width](value).as_uint() 7 | 8 | NTESTS = 4 9 | WIDTHS = [8] 10 | 11 | @pytest.mark.parametrize("op, reference", [ 12 | (operator.invert, lambda x: ~x), 13 | (operator.neg, lambda x: -x), 14 | ]) 15 | @pytest.mark.parametrize("width", WIDTHS) 16 | def test_operator_uint1(op, reference, width): 17 | for _ in range(NTESTS): 18 | I = UIntVector.random(width) 19 | expected = unsigned(reference(int(I)), width) 20 | assert expected == int(op(I)) 21 | 22 | @pytest.mark.parametrize("op, reference", [ 23 | (operator.and_, lambda x, y: x & y ), 24 | (operator.or_, lambda x, y: x | y ), 25 | (operator.xor, lambda x, y: x ^ y ), 26 | (operator.lshift, lambda x, y: x << y), 27 | (operator.rshift, lambda x, y: x >> y), 28 | (operator.add, lambda x, y: x + y ), 29 | (operator.sub, lambda x, y: x - y ), 30 | (operator.mul, lambda x, y: x * y ), 31 | (operator.floordiv, lambda x, y: x // y if y != 0 else -1), 32 | ]) 33 | @pytest.mark.parametrize("width", WIDTHS) 34 | def test_operator_uint2(op, reference, width): 35 | for _ in range(NTESTS): 36 | I0, I1 = UIntVector.random(width), UIntVector.random(width) 37 | expected = unsigned(reference(int(I0), int(I1)), width) 38 | assert expected == int(op(I0, I1)) 39 | 40 | @pytest.mark.parametrize("op, reference", [ 41 | (operator.eq, lambda x, y: x == y), 42 | (operator.ne, lambda x, y: x != y), 43 | (operator.lt, lambda x, y: x < y), 44 | (operator.le, lambda x, y: x <= y), 45 | (operator.gt, lambda x, y: x > y), 46 | (operator.ge, lambda x, y: x >= y), 47 | ]) 48 | @pytest.mark.parametrize("width", WIDTHS) 49 | def test_operator_comparison(op, reference, width): 50 | for _ in range(NTESTS): 51 | I0, I1 = 
UIntVector.random(width), UIntVector.random(width) 52 | expected = Bit(reference(int(I0), int(I1))) 53 | assert expected == bool(op(I0, I1)) 54 | 55 | def test_unsigned(): 56 | a = UIntVector[4](4) 57 | assert int(a) == 4 58 | -------------------------------------------------------------------------------- /tests/test_bv_protocol.py: -------------------------------------------------------------------------------- 1 | from hwtypes.bit_vector_util import BitVectorProtocol, BitVectorProtocolMeta 2 | from hwtypes import BitVector, Bit 3 | from hwtypes import SMTBitVector, SMTBit 4 | 5 | 6 | def gen_counter(BitVector): 7 | Word = BitVector[32] 8 | 9 | 10 | class CounterMeta(type): 11 | def _bitvector_t_(cls): 12 | return Word 13 | 14 | 15 | class Counter(metaclass=CounterMeta): 16 | def __init__(self, initial_value=0): 17 | self._cnt : Word = Word(initial_value) 18 | 19 | @classmethod 20 | def _from_bitvector_(cls, bv: Word) -> 'cls': 21 | obj = cls() 22 | obj._cnt = bv 23 | return obj 24 | 25 | def _to_bitvector_(self) -> Word: 26 | return self._cnt 27 | 28 | def inc(self): 29 | self._cnt += 1 30 | 31 | @property 32 | def val(self): 33 | return self._cnt 34 | 35 | return Counter, CounterMeta, Word 36 | 37 | 38 | def test_implicit_inheritance(): 39 | assert not issubclass(BitVectorProtocol, BitVectorProtocolMeta) 40 | Counter, CounterMeta, Word = gen_counter(BitVector) 41 | counter = Counter() 42 | assert isinstance(counter, BitVectorProtocol) 43 | assert isinstance(Counter, BitVectorProtocolMeta) 44 | assert not isinstance(counter, BitVectorProtocolMeta) 45 | 46 | assert issubclass(Counter, BitVectorProtocol) 47 | assert issubclass(CounterMeta, BitVectorProtocolMeta) 48 | assert not issubclass(Counter, BitVectorProtocolMeta) 49 | 50 | 51 | def test_protocol_ite(): 52 | Counter, CounterMeta, Word = gen_counter(BitVector) 53 | cnt1 = Counter() 54 | cnt2 = Counter() 55 | 56 | cnt1.inc() 57 | 58 | assert cnt1.val == 1 59 | assert cnt2.val == 0 60 | 61 | cnt3 = 
Bit(0).ite(cnt1, cnt2) 62 | 63 | assert cnt3.val == cnt2.val 64 | 65 | cnt4 = Bit(1).ite(cnt1, cnt2) 66 | 67 | assert cnt4.val == cnt1.val 68 | 69 | 70 | def test_protocol_ite_smt(): 71 | Counter, CounterMeta, Word = gen_counter(SMTBitVector) 72 | 73 | init = Word(name='init') 74 | cnt1 = Counter(init) 75 | cnt2 = Counter(init) 76 | 77 | cnt1.inc() 78 | 79 | # pysmt == is structural equiv 80 | assert cnt1.val.value == (init.value + 1) 81 | assert cnt1.val.value != init.value 82 | assert cnt2.val.value == init.value 83 | 84 | cnt3 = SMTBit(0).ite(cnt1, cnt2) 85 | 86 | assert cnt3.val.value == cnt2.val.value 87 | 88 | cnt4 = SMTBit(1).ite(cnt1, cnt2) 89 | 90 | assert cnt4.val.value == cnt1.val.value 91 | 92 | cond = SMTBit(name='cond') 93 | 94 | cnt5 = cond.ite(cnt1, cnt2) 95 | 96 | assert cnt5.val.value == cond.ite(cnt1.val, cnt2.val).value 97 | -------------------------------------------------------------------------------- /tests/test_poly.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pysmt import shortcuts as sc 4 | 5 | from hwtypes import BitVector, SIntVector, UIntVector 6 | from hwtypes import Bit 7 | 8 | from hwtypes import SMTBit, SMTBitVector 9 | from hwtypes import SMTUIntVector, SMTSIntVector 10 | 11 | @pytest.mark.parametrize("cond_0", [Bit(0), Bit(1)]) 12 | @pytest.mark.parametrize("cond_1", [Bit(0), Bit(1)]) 13 | def test_poly_bv(cond_0, cond_1): 14 | S = SIntVector[8] 15 | U = UIntVector[8] 16 | val = cond_0.ite(S(0), U(0)) - 1 17 | 18 | assert val < 0 if cond_0 else val > 0 19 | val2 = cond_1.ite(S(-1), val) 20 | val2 = val2.ext(1) 21 | assert val2 == val.sext(1) if cond_0 or cond_1 else val2 == val.zext(1) 22 | 23 | val3 = cond_1.ite(cond_0.ite(U(0), S(1)), cond_0.ite(S(-1), U(2))) 24 | 25 | if cond_1: 26 | if cond_0: 27 | assert val3 == 0 28 | assert val3 - 1 > 0 29 | else: 30 | assert val3 == 1 31 | assert val3 - 2 < 0 32 | else: 33 | if cond_0: 34 | assert val3 == -1 35 | 
assert val3 < 0 36 | else: 37 | assert val3 == 2 38 | assert val3 - 3 > 0 39 | 40 | def test_poly_smt(): 41 | S = SMTSIntVector[8] 42 | U = SMTUIntVector[8] 43 | 44 | c1 = SMTBit(name='c1') 45 | u1 = U(name='u1') 46 | u2 = U(name='u2') 47 | s1 = S(name='s1') 48 | s2 = S(name='s2') 49 | 50 | # NOTE: __eq__ on pysmt terms is strict structural equivalence 51 | # for example: 52 | assert u1.value == u1.value # .value extract pysmt term 53 | assert u1.value != u2.value 54 | assert (u1 * 2).value != (u1 + u1).value 55 | assert (u1 + u2).value == (u1 + u2).value 56 | assert (u1 + u2).value != (u2 + u1).value 57 | 58 | # On to the real test 59 | expr = c1.ite(u1, s1) < 1 60 | # get the pysmt values 61 | _c1, _u1, _s1 = c1.value, u1.value, s1.value 62 | e1 = sc.Ite(_c1, _u1, _s1) 63 | one = sc.BV(1, 8) 64 | # Here we see that `< 1` dispatches symbolically 65 | f = sc.Ite(_c1, sc.BVULT(e1, one), sc.BVSLT(e1, one)) 66 | assert expr.value == f 67 | 68 | expr = expr.ite(c1.ite(u1, s1), c1.ite(s2, u2)).ext(1) 69 | 70 | e2 = sc.Ite(_c1, s2.value, u2.value) 71 | e3 = sc.Ite(f, e1, e2) 72 | 73 | se = sc.BVSExt(e3, 1) 74 | ze = sc.BVZExt(e3, 1) 75 | 76 | 77 | g = sc.Ite( 78 | f, 79 | sc.Ite(_c1, ze, se), 80 | sc.Ite(_c1, se, ze) 81 | ) 82 | # Here we see that ext dispatches symbolically / recursively 83 | assert expr.value == g 84 | 85 | 86 | # Here we see that polymorphic types only build muxes if they need to 87 | expr = c1.ite(u1, s1) + 1 88 | assert expr.value == sc.BVAdd(e1, one) 89 | # Note how it is not: 90 | assert expr.value != sc.Ite(_c1, sc.BVAdd(e1, one), sc.BVAdd(e1, one)) 91 | # which was the pattern for sign dependent operators 92 | 93 | -------------------------------------------------------------------------------- /tests/test_smt_int.py: -------------------------------------------------------------------------------- 1 | from hwtypes import SMTInt, SMTBitVector, SMTBit 2 | import pysmt.shortcuts as smt 3 | import pytest 4 | 5 | def test_sat_unsat(): 6 | x = 
SMTInt(prefix='x') 7 | f1 = (x > 0) & (x < 0) 8 | logic = None 9 | #Test unsat 10 | with smt.Solver(logic=logic, name='z3') as solver: 11 | solver.add_assertion(f1.value) 12 | res = solver.solve() 13 | assert not res 14 | 15 | f2 = (x >= 0) & (x*x+x == 30) 16 | #test sat 17 | with smt.Solver(logic=logic, name='z3') as solver: 18 | solver.add_assertion(f2.value) 19 | res = solver.solve() 20 | assert res 21 | x_val = solver.get_value(x.value) 22 | assert x_val.constant_value() == 5 23 | 24 | bin_ops = dict( 25 | sub = lambda x,y: x-y, 26 | add = lambda x,y: x+y, 27 | mul = lambda x,y: x*y, 28 | div = lambda x,y: x//y, 29 | lte = lambda x,y: x<=y, 30 | lt = lambda x,y: x=y, 32 | gt = lambda x,y: x>y, 33 | eq = lambda x,y: x != y, 34 | neq = lambda x,y: x == y, 35 | ) 36 | 37 | unary_ops = dict( 38 | neg = lambda x: -x, 39 | ) 40 | 41 | import random 42 | 43 | @pytest.mark.parametrize("name, fun", bin_ops.items()) 44 | def test_bin_ops(name, fun): 45 | x = SMTInt(prefix='x') 46 | y = SMTInt(prefix='y') 47 | x_val = random.randint(-100, 100) 48 | y_val = random.choice(list(set(x for x in range(-100,100))-set((0,)))) 49 | f_val = fun(x_val,y_val) 50 | f = (fun(x, y) == f_val) & (y != 0) 51 | with smt.Solver(name='z3') as solver: 52 | solver.add_assertion(f.value) 53 | res = solver.solve() 54 | assert res 55 | x_solved = solver.get_value(x.value).constant_value() 56 | y_solved = solver.get_value(y.value).constant_value() 57 | assert fun(x_solved, y_solved) == f_val 58 | 59 | 60 | @pytest.mark.parametrize("name, fun", unary_ops.items()) 61 | def test_unary_ops(name, fun): 62 | x = SMTInt(prefix='x') 63 | x_val = random.randint(-100, 100) 64 | f_val = fun(x_val) 65 | f = (fun(x) == f_val) 66 | with smt.Solver(name='z3') as solver: 67 | solver.add_assertion(f.value) 68 | res = solver.solve() 69 | assert res 70 | x_solved = solver.get_value(x.value).constant_value() 71 | assert fun(x_solved) == f_val 72 | 73 | def test_init_bv(): 74 | SMTInt(SMTBitVector[5]()) 75 | 
SMTInt(SMTBitVector[5](10)) 76 | SMTInt(SMTBitVector[5](10).value) 77 | 78 | 79 | def test_as_sint(): 80 | x = SMTBitVector[5](-5) 81 | x_int = x.as_sint() 82 | assert isinstance(x_int, SMTInt) 83 | assert x_int.value.constant_value() == -5 84 | 85 | def test_as_uint(): 86 | x = SMTBitVector[5](5) 87 | x_int = x.as_uint() 88 | assert isinstance(x_int, SMTInt) 89 | assert x.as_uint().value.constant_value() == 5 90 | 91 | def test_name_table_bug(): 92 | SMTBit(prefix='x') 93 | SMTInt(prefix='x') 94 | 95 | bin_ops_r = dict( 96 | sub = lambda x,y: x-y, 97 | add = lambda x,y: x+y, 98 | mul = lambda x,y: x*y, 99 | div = lambda x,y: x//y, 100 | ) 101 | 102 | @pytest.mark.parametrize('name, op', bin_ops_r.items()) 103 | def test_r_ops(name, op): 104 | res = op(5, SMTInt(2)) 105 | assert isinstance(res, SMTInt) 106 | assert res.value.is_constant() 107 | assert op(5, 2) == res.value.constant_value() -------------------------------------------------------------------------------- /tests/test_sint.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import operator 3 | import random 4 | from hwtypes import UIntVector, SIntVector, Bit 5 | 6 | def signed(value, width): 7 | return SIntVector[width](value).as_sint() 8 | 9 | NTESTS = 4 10 | WIDTHS = [8] 11 | 12 | @pytest.mark.parametrize("op, reference", [ 13 | (operator.invert, lambda x: ~x), 14 | (operator.neg, lambda x: -x), 15 | ]) 16 | @pytest.mark.parametrize("width", WIDTHS) 17 | def test_operator_int1(op, reference, width): 18 | for _ in range(NTESTS): 19 | I = SIntVector.random(width) 20 | expected = signed(reference(int(I)), width) 21 | assert expected == int(op(I)) 22 | 23 | @pytest.mark.parametrize("op, reference", [ 24 | (operator.and_, lambda x, y: x & y ), 25 | (operator.or_, lambda x, y: x | y ), 26 | (operator.xor, lambda x, y: x ^ y ), 27 | (operator.add, lambda x, y: x + y ), 28 | (operator.sub, lambda x, y: x - y ), 29 | (operator.mul, lambda x, y: x * y ), 
30 | (operator.floordiv, lambda x, y: x // y if y != 0 else -1), 31 | ]) 32 | @pytest.mark.parametrize("width", WIDTHS) 33 | def test_operator_int2(op, reference, width): 34 | for _ in range(NTESTS): 35 | I0, I1 = SIntVector.random(width), SIntVector.random(width) 36 | expected = signed(reference(int(I0), int(I1)), width) 37 | assert expected == int(op(I0, I1)) 38 | 39 | @pytest.mark.parametrize("op, reference", [ 40 | (operator.eq, lambda x, y: x == y), 41 | (operator.ne, lambda x, y: x != y), 42 | (operator.lt, lambda x, y: x < y), 43 | (operator.le, lambda x, y: x <= y), 44 | (operator.gt, lambda x, y: x > y), 45 | (operator.ge, lambda x, y: x >= y), 46 | ]) 47 | @pytest.mark.parametrize("width", WIDTHS) 48 | def test_comparison(op, reference, width): 49 | for _ in range(NTESTS): 50 | I0, I1 = SIntVector.random(width), SIntVector.random(width) 51 | if op is operator.floordiv and I1 == 0: 52 | # Skip divide by zero 53 | continue 54 | expected = Bit(reference(int(I0), int(I1))) 55 | assert expected == bool(op(I0, I1)) 56 | 57 | @pytest.mark.parametrize("op, reference", [ 58 | (operator.lshift, lambda x, y: x << y), 59 | (operator.rshift, lambda x, y: x >> y), 60 | ]) 61 | @pytest.mark.parametrize("width", WIDTHS) 62 | def test_operator_int_shift(op, reference, width): 63 | for _ in range(NTESTS): 64 | I0, I1 = SIntVector.random(width), UIntVector.random(width) 65 | expected = signed(reference(int(I0), int(I1)), width) 66 | assert expected == int(op(I0, I1)) 67 | 68 | def test_signed(): 69 | a = SIntVector[4](4) 70 | assert int(a) == 4 71 | a = SIntVector[4](-4) 72 | assert a._value != 4, "Stored as unsigned two's complement value" 73 | assert int(a) == -4, "int returns the native signed int representation" 74 | 75 | 76 | @pytest.mark.parametrize("op, reference", [ 77 | (operator.floordiv, lambda x, y: x // y if y != 0 else -1), 78 | (operator.mod, lambda x, y: x % y if y != 0 else x), 79 | ]) 80 | def test_operator_by_0(op, reference): 81 | I0, I1 = 
SIntVector.random(5), 0 82 | expected = signed(reference(int(I0), int(I1)), 5) 83 | assert expected == int(op(I0, I1)) 84 | -------------------------------------------------------------------------------- /tests/test_optypes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import operator 3 | import random 4 | from itertools import product 5 | 6 | from hwtypes import SIntVector, BitVector, Bit 7 | from hwtypes.bit_vector_abc import InconsistentSizeError 8 | from hwtypes.bit_vector_util import PolyVector, PolyBase 9 | 10 | def _rand_bv(width): 11 | return BitVector[width](random.randint(0, (1 << width) - 1)) 12 | 13 | def _rand_signed(width): 14 | return SIntVector[width](random.randint(0, (1 << width) - 1)) 15 | 16 | 17 | def _rand_int(width): 18 | return random.randint(0, (1 << width) - 1) 19 | 20 | @pytest.mark.parametrize("op", [ 21 | operator.and_, 22 | operator.or_, 23 | operator.xor, 24 | operator.lshift, 25 | operator.rshift, 26 | ]) 27 | @pytest.mark.parametrize("width1", (1, 2)) 28 | @pytest.mark.parametrize("width2", (1, 2)) 29 | @pytest.mark.parametrize("use_int", (False, True)) 30 | def test_bin(op, width1, width2, use_int): 31 | x = _rand_bv(width1) 32 | if use_int: 33 | y = _rand_int(width2) 34 | res = op(x, y) 35 | assert type(res) is type(x) 36 | else: 37 | y = _rand_bv(width2) 38 | if width1 != width2: 39 | assert type(x) is not type(y) 40 | with pytest.raises(InconsistentSizeError): 41 | op(x, y) 42 | else: 43 | assert type(x) is type(y) 44 | res = op(x, y) 45 | assert type(res) is type(x) 46 | 47 | 48 | @pytest.mark.parametrize("op", [ 49 | operator.eq, 50 | operator.ne, 51 | operator.lt, 52 | operator.le, 53 | operator.gt, 54 | operator.ge, 55 | ]) 56 | @pytest.mark.parametrize("width1", (1, 2)) 57 | @pytest.mark.parametrize("width2", (1, 2)) 58 | @pytest.mark.parametrize("use_int", (False, True)) 59 | def test_comp(op, width1, width2, use_int): 60 | x = _rand_bv(width1) 61 | if use_int: 
62 | y = _rand_int(width2) 63 | res = op(x, y) 64 | assert type(res) is Bit 65 | else: 66 | y = _rand_bv(width2) 67 | if width1 != width2: 68 | assert type(x) is not type(y) 69 | with pytest.raises(InconsistentSizeError): 70 | op(x, y) 71 | else: 72 | assert type(x) is type(y) 73 | res = op(x, y) 74 | assert type(res) is Bit 75 | 76 | 77 | @pytest.mark.parametrize("t_constructor", (_rand_bv, _rand_signed, _rand_int)) 78 | @pytest.mark.parametrize("t_size", (1, 2, 4)) 79 | @pytest.mark.parametrize("f_constructor", (_rand_bv, _rand_signed, _rand_int)) 80 | @pytest.mark.parametrize("f_size", (1, 2, 4)) 81 | def test_ite(t_constructor, t_size, f_constructor, f_size): 82 | pred = Bit(_rand_int(1)) 83 | t = t_constructor(t_size) 84 | f = f_constructor(f_size) 85 | 86 | t_is_bv_constructor = t_constructor in {_rand_signed, _rand_bv} 87 | f_is_bv_constructor = f_constructor in {_rand_signed, _rand_bv} 88 | sizes_equal = t_size == f_size 89 | 90 | if (t_constructor is f_constructor and t_is_bv_constructor and sizes_equal): 91 | # The same bv_constructor 92 | res = pred.ite(t, f) 93 | assert type(res) is type(t) 94 | elif t_is_bv_constructor and f_is_bv_constructor and sizes_equal: 95 | # Different bv_constuctor 96 | res = pred.ite(t, f) 97 | # The bases should be the most specific types that are common 98 | # to both branches and PolyBase. 
99 | assert isinstance(res, PolyBase) 100 | assert isinstance(res, BitVector[t_size]) 101 | elif t_is_bv_constructor and f_is_bv_constructor and not sizes_equal: 102 | # BV with different size 103 | with pytest.raises(InconsistentSizeError): 104 | res = pred.ite(t, f) 105 | else: 106 | # Trying to coerce an int 107 | with pytest.raises(TypeError): 108 | res = pred.ite(t, f) 109 | -------------------------------------------------------------------------------- /tests/test_modifiers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from hwtypes import Bit, BitVector, AbstractBit 4 | import hwtypes.modifiers as modifiers 5 | from hwtypes.modifiers import make_modifier, is_modified, is_modifier, unwrap_modifier, wrap_modifier 6 | from hwtypes.modifiers import get_modifier, get_unmodified 7 | from hwtypes.modifiers import strip_modifiers, push_modifiers 8 | from hwtypes.adt import Tuple, Product, Sum, Enum, TaggedUnion 9 | 10 | modifiers._DEBUG = True 11 | 12 | def test_basic(): 13 | Global = make_modifier("Global") 14 | GlobalBit = Global(Bit) 15 | 16 | assert GlobalBit is Global(Bit) 17 | 18 | assert issubclass(GlobalBit, Bit) 19 | assert issubclass(GlobalBit, AbstractBit) 20 | assert issubclass(GlobalBit, Global) 21 | assert issubclass(GlobalBit, Global(AbstractBit)) 22 | 23 | global_bit = GlobalBit(0) 24 | 25 | assert isinstance(global_bit, GlobalBit) 26 | assert isinstance(global_bit, Bit) 27 | assert isinstance(global_bit, AbstractBit) 28 | assert isinstance(global_bit, Global) 29 | assert isinstance(global_bit, Global(AbstractBit)) 30 | 31 | assert is_modifier(Global) 32 | assert is_modified(GlobalBit) 33 | assert not is_modifier(Bit) 34 | assert not is_modified(Bit) 35 | assert not is_modified(Global) 36 | 37 | assert get_modifier(GlobalBit) is Global 38 | assert get_unmodified(GlobalBit) is Bit 39 | 40 | with pytest.raises(TypeError): 41 | get_modifier(Bit) 42 | 43 | with pytest.raises(TypeError): 44 
| get_unmodified(Bit) 45 | 46 | 47 | def test_modify_adt(): 48 | Mod = make_modifier("Mod") 49 | 50 | T = Tuple[int, str] 51 | MT = Mod(T) 52 | assert issubclass(MT, T) 53 | assert MT(0, 'y')[0] == 0 54 | assert MT(0, 'y')[1] == 'y' 55 | 56 | class P(Product): 57 | x = int 58 | y = str 59 | 60 | MP = Mod(P) 61 | assert issubclass(MP, T) 62 | assert issubclass(MP, P) 63 | assert issubclass(MP, MT) 64 | assert MP(x=0, y='y').y == 'y' 65 | assert MP(y='y', x=0).x == 0 66 | 67 | S = Sum[int, str] 68 | MS = Mod(S) 69 | assert issubclass(MS, S) 70 | assert MS(0)[int].value == 0 71 | assert MS('x')[str].value == 'x' 72 | 73 | 74 | def test_cache(): 75 | G1 = make_modifier("Global", cache=True) 76 | G2 = make_modifier("Global", cache=True) 77 | G3 = make_modifier("Global") 78 | 79 | assert G1 is G2 80 | assert G1 is not G3 81 | 82 | def test_nested(): 83 | A = make_modifier("A") 84 | B = make_modifier("B") 85 | C = make_modifier("C") 86 | ABCBit = C(B(A(Bit))) 87 | base, mods = unwrap_modifier(ABCBit) 88 | assert base is Bit 89 | assert mods == [A, B, C] 90 | assert wrap_modifier(Bit, mods) == ABCBit 91 | with pytest.raises(TypeError): 92 | wrap_modifier(Bit, [A, B, C, A]) 93 | 94 | def test_strip(): 95 | M0 = make_modifier("M0") 96 | M1 = make_modifier("M1") 97 | M2 = make_modifier("M2") 98 | BV = BitVector 99 | 100 | class E(Enum): 101 | a=1 102 | b=2 103 | class A(Product, cache=True): 104 | b = M0(Sum[Bit, M1(BV[6]), M2(E)]) 105 | c = M1(E) 106 | d = M2(Tuple[M0(BV[3]), M1(M0(Bit))]) 107 | 108 | A_stripped = strip_modifiers(A) 109 | assert A_stripped.b is Sum[Bit, BV[6], E] 110 | assert A_stripped.c is E 111 | assert A_stripped.d is Tuple[BV[3], Bit] 112 | 113 | def test_ta_strip(): 114 | class Ta(TaggedUnion, cache=True): 115 | a = int 116 | b = int 117 | c = str 118 | class B(Ta): pass 119 | class C(B): pass 120 | M0 = make_modifier("M0") 121 | 122 | C_stripped = strip_modifiers(M0(C)) 123 | assert C == C_stripped 124 | 125 | def test_push(): 126 | M0 = 
make_modifier("M0") 127 | M1 = make_modifier("M1") 128 | M2 = make_modifier("M2") 129 | BV = BitVector 130 | 131 | class E(Enum): 132 | a=1 133 | b=2 134 | class A(Product, cache=True): 135 | b = M0(Sum[Bit, M1(BV[6]), M2(E)]) 136 | c = M1(E) 137 | d = M2(Tuple[M0(BV[3]), M1(M0(Bit))]) 138 | 139 | A_pushed = push_modifiers(A) 140 | A_pushed.b is Sum[M0(Bit), M0(M1(BV[6])), M0(M2(E))] 141 | A_pushed.c is M1(E) 142 | A_pushed.d is Tuple[M2(M0(BV[3])), M2(M1(M0(Bit)))] 143 | -------------------------------------------------------------------------------- /hwtypes/adt_util.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from .adt_meta import BoundMeta 3 | from .bit_vector_abc import AbstractBitVectorMeta, AbstractBitVector, AbstractBit 4 | 5 | from .util import _issubclass 6 | from hwtypes.modifiers import unwrap_modifier, wrap_modifier, is_modified 7 | from .adt import Product, Sum, Tuple, Enum, TaggedUnion 8 | from inspect import isclass 9 | 10 | 11 | 12 | class _ADTVisitor(metaclass=ABCMeta): 13 | def visit(self, adt_t): 14 | # The order here is important because Product < Tuple 15 | # and TaggedUnion < Sum 16 | if self.check_t(adt_t, Enum): 17 | self.visit_Enum(adt_t) 18 | elif self.check_t(adt_t, Product): 19 | self.visit_Product(adt_t) 20 | elif self.check_t(adt_t, Tuple): 21 | self.visit_Tuple(adt_t) 22 | elif self.check_t(adt_t, TaggedUnion): 23 | self.visit_TaggedUnion(adt_t) 24 | elif self.check_t(adt_t, Sum): 25 | self.visit_Sum(adt_t) 26 | else: 27 | self.visit_leaf(adt_t) 28 | 29 | @abstractmethod 30 | def check_t(self, adt_t): pass 31 | 32 | @abstractmethod 33 | def generic_visit(self, adt_t): pass 34 | 35 | def visit_Leaf(self, adt_t): pass 36 | 37 | def visit_Enum(self, adt_t): pass 38 | 39 | def visit_Product(self, adt_t): 40 | self.generic_visit(adt_t) 41 | 42 | def visit_Tuple(self, adt_t): 43 | self.generic_visit(adt_t) 44 | 45 | def visit_TaggedUnion(self, adt_t): 
46 | self.generic_visit(adt_t) 47 | 48 | def visit_Sum(self, adt_t): 49 | self.generic_visit(adt_t) 50 | 51 | 52 | class ADTVisitor(_ADTVisitor): 53 | ''' 54 | Visitor for ADTs 55 | ''' 56 | check_t = staticmethod(_issubclass) 57 | 58 | def generic_visit(self, adt_t): 59 | for T in adt_t.field_dict.values(): 60 | self.visit(T) 61 | 62 | 63 | class ADTInstVisitor(_ADTVisitor): 64 | ''' 65 | Visitor for ADT instances 66 | ''' 67 | check_t = staticmethod(isinstance) 68 | 69 | 70 | def generic_visit(self, adt): 71 | for k, v in adt.value_dict.items(): 72 | if v is not None: 73 | self.visit(v) 74 | 75 | def rebind_bitvector( 76 | adt, 77 | bv_type_0: AbstractBitVectorMeta, 78 | bv_type_1: AbstractBitVectorMeta, 79 | keep_modifiers=False): 80 | if keep_modifiers and is_modified(adt): 81 | unmod, mods = unwrap_modifier(adt) 82 | return wrap_modifier(rebind_bitvector(unmod,bv_type_0,bv_type_1,True),mods) 83 | 84 | if _issubclass(adt, bv_type_0): 85 | if adt.is_sized: 86 | return bv_type_1[adt.size] 87 | else: 88 | return bv_type_1 89 | elif isinstance(adt, BoundMeta): 90 | _to_new = [] 91 | for field in adt.fields: 92 | new_field = rebind_bitvector(field, bv_type_0, bv_type_1,keep_modifiers) 93 | _to_new.append((field,new_field)) 94 | new_adt = adt 95 | 96 | for field,new_field in _to_new: 97 | new_adt = new_adt.rebind(field, new_field, rebind_recursive=False) 98 | return new_adt 99 | else: 100 | return adt 101 | 102 | def rebind_keep_modifiers(adt, A, B): 103 | if is_modified(adt): 104 | unmod, mods = unwrap_modifier(adt) 105 | return wrap_modifier(rebind_keep_modifiers(unmod,A,B),mods) 106 | 107 | if _issubclass(adt,A): 108 | return B 109 | elif isinstance(adt, BoundMeta): 110 | new_adt = adt 111 | for field in adt.fields: 112 | new_field = rebind_keep_modifiers(field, A, B) 113 | new_adt = new_adt.rebind(field, new_field) 114 | return new_adt 115 | else: 116 | return adt 117 | 118 | #rebind_type will rebind a type to a different family 119 | #Types that will be 
rebinded: 120 | # Product,Tuple,Sum, BitVector, Bit 121 | # Modified types 122 | #If the passed in type cannot be rebinded, it will just be returned unmodified 123 | def rebind_type(T, family): 124 | def _rebind_bv(T): 125 | return rebind_bitvector(T, AbstractBitVector, family.BitVector).rebind(AbstractBit, family.Bit, True) 126 | if not isclass(T): 127 | return T 128 | elif is_modified(T): 129 | return get_modifier(T)(rebind_type(get_unmodified(T), family, dont_rebind, do_rebind, is_magma)) 130 | elif issubclass(T, AbstractBitVector): 131 | return rebind_bitvector(T, AbstractBitVector, family.BitVector) 132 | elif issubclass(T, AbstractBit): 133 | return family.Bit 134 | elif issubclass(T, (Product, Tuple, Sum)): 135 | return _rebind_bv(T) 136 | else: 137 | return T 138 | -------------------------------------------------------------------------------- /hwtypes/util.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from collections.abc import Mapping, MutableMapping 3 | import typing as tp 4 | import types 5 | 6 | class FrozenDict(Mapping): 7 | __slots__ = '_d', '_hash' 8 | 9 | def __init__(self, *args, **kwargs): 10 | self._d = dict(*args, **kwargs) 11 | self._hash = hash(frozenset(self.items())) 12 | 13 | def __getitem__(self, key): 14 | return self._d.__getitem__(key) 15 | 16 | def __iter__(self): 17 | return self._d.__iter__() 18 | 19 | def __len__(self): 20 | return self._d.__len__() 21 | 22 | def __eq__(self, other): 23 | if isinstance(other, type(self)): 24 | return self._d == other._d 25 | else: 26 | return self._d == other 27 | 28 | def __ne__(self, other): 29 | return not (self == other) 30 | 31 | def __hash__(self): 32 | return self._hash 33 | 34 | 35 | class OrderedFrozenDict(FrozenDict): 36 | __slots__ = () 37 | 38 | def __init__(self, *args, **kwargs): 39 | self._d = OrderedDict(*args, **kwargs) 40 | self._hash = hash(tuple(self.items())) 41 | 42 | 43 | class TypedProperty: 
44 | ''' 45 | Behaves mostly like property except: 46 | class A: 47 | @property 48 | def foo(self): ... 49 | 50 | @foo.setter 51 | def foo(self, value): ... 52 | 53 | A.Foo -> 54 | 55 | class B: 56 | @TypedProperty(T) 57 | def foo(self): ... 58 | 59 | @foo.setter 60 | def foo(self, value): ... 61 | 62 | B.Foo -> T 63 | B().foo = T() #works 64 | B().foo = A() # TypeError expected T not A 65 | 66 | class C(A): 67 | @A.foo.setter #works 68 | 69 | class D(B): 70 | @B.foo.setter #error unless T has atrribute setter in which case 71 | #T.setter is called for the decorator as B.Foo -> T. 72 | #In generalTypedProperty objects must not be modified 73 | #outside of the class in which they are declared 74 | 75 | ''' 76 | def __init__(self, T): 77 | self.T = T 78 | self.final = False 79 | self.fget = None 80 | self.fset = None 81 | self.fdel = None 82 | self.__doc__ = None 83 | 84 | def __call__(self, fget=None, fset=None, fdel=None, doc=None): 85 | self.fget = fget 86 | self.fset = fset 87 | self.fdel = fdel 88 | if doc is None and fget is not None: 89 | doc = fget.__doc__ 90 | self.__doc__ = doc 91 | return self 92 | 93 | def __get__(self, obj, objtype=None): 94 | if obj is None: 95 | if self.final: 96 | return self.T 97 | else: 98 | return self 99 | 100 | if self.fget is None: 101 | raise AttributeError("unreadable attribute") 102 | return self.fget(obj) 103 | 104 | def __set__(self, obj, value): 105 | if self.fset is None: 106 | raise AttributeError("can't set attribute") 107 | elif not isinstance(value, self.T): 108 | raise TypeError(f'Expected {self.T} not {type(value)}') 109 | self.fset(obj, value) 110 | 111 | def __delete__(self, obj): 112 | if self.fdel is None: 113 | raise AttributeError("can't delete attribute") 114 | self.fdel(obj) 115 | 116 | def __set_name__(self, cls, name): 117 | self.final = True 118 | 119 | def getter(self, fget): 120 | return type(self)(self.T)(fget, self.fset, self.fdel, self.__doc__) 121 | 122 | def setter(self, fset): 123 | return 
type(self)(self.T)(self.fget, fset, self.fdel, self.__doc__) 124 | 125 | def deleter(self, fdel): 126 | return type(self)(self.T)(self.fget, self.fset, fdel, self.__doc__) 127 | 128 | class Method: 129 | ''' 130 | Method descriptor which automatically sets the name of the bound function 131 | ''' 132 | def __init__(self, m): 133 | self.m = m 134 | 135 | def __get__(self, obj, objtype=None): 136 | if obj is not None: 137 | return types.MethodType(self.m, obj) 138 | else: 139 | return self.m 140 | 141 | def __set_name__(self, owner, name): 142 | self.m.__name__ = name 143 | self.m.__qualname__ = owner.__qualname__ + '.' + name 144 | 145 | 146 | def __call__(self, *args, **kwargs): 147 | # HACK 148 | # need this because of vcall works 149 | return self.m(*args, **kwargs) 150 | 151 | def _issubclass(sub : tp.Any, parent : type) -> bool: 152 | try: 153 | return issubclass(sub, parent) 154 | except TypeError: 155 | return False 156 | -------------------------------------------------------------------------------- /hwtypes/smt_int.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | from .smt_bit_vector import SMTBit, SMTBitVector, _gen_name, _name_re, _name_table, SMYBOLIC, AUTOMATIC 3 | 4 | import pysmt 5 | import pysmt.shortcuts as smt 6 | from pysmt.typing import INT 7 | 8 | import warnings 9 | 10 | __ALL__ = ['SMTInt'] 11 | 12 | 13 | def int_cast(fn): 14 | @ft.wraps(fn) 15 | def wrapped(self, other): 16 | if isinstance(other, SMTInt): 17 | return fn(self, other) 18 | else: 19 | try: 20 | other = SMTInt(other) 21 | except TypeError: 22 | return NotImplemented 23 | return fn(self, other) 24 | return wrapped 25 | 26 | class SMTInt: 27 | def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC, prefix=AUTOMATIC): 28 | if (name is not AUTOMATIC or prefix is not AUTOMATIC) and value is not SMYBOLIC: 29 | raise TypeError('Can only name symbolic variables') 30 | elif name is not AUTOMATIC and prefix is not AUTOMATIC: 
31 | raise ValueError('Can only set either name or prefix not both') 32 | elif name is not AUTOMATIC: 33 | if not isinstance(name, str): 34 | raise TypeError('Name must be string') 35 | elif name in _name_table: 36 | raise ValueError(f'Name {name} already in use') 37 | elif _name_re.fullmatch(name): 38 | warnings.warn('Name looks like an auto generated name, this might break things') 39 | _name_table[name] = self 40 | elif prefix is not AUTOMATIC: 41 | name = _gen_name(prefix) 42 | _name_table[name] = self 43 | elif name is AUTOMATIC and value is SMYBOLIC: 44 | name = _gen_name() 45 | _name_table[name] = self 46 | 47 | if value is SMYBOLIC: 48 | self._value = smt.Symbol(name, INT) 49 | elif isinstance(value, pysmt.fnode.FNode): 50 | if value.get_type().is_int_type(): 51 | self._value = value 52 | elif value.get_type().is_bv_type(): 53 | self._value = smt.BVToNatural(value) 54 | else: 55 | raise TypeError(f'Expected int type not {value.get_type()}') 56 | elif isinstance(value, SMTInt): 57 | self._value = value._value 58 | elif isinstance(value, SMTBitVector): 59 | self._value = smt.BVToNatural(value.value) 60 | elif isinstance(value, bool): 61 | self._value = smt.Int(int(value)) 62 | elif isinstance(value, int): 63 | self._value = smt.Int(value) 64 | elif hasattr(value, '__int__'): 65 | self._value = smt.Int(int(value)) 66 | else: 67 | raise TypeError("Can't coerce {} to Int".format(type(value))) 68 | 69 | self._name = name 70 | self._value = smt.simplify(self._value) 71 | 72 | def __repr__(self): 73 | if self._name is not AUTOMATIC: 74 | return f'{type(self)}({self._name})' 75 | else: 76 | return f'{type(self)}({self._value})' 77 | 78 | @property 79 | def value(self): 80 | return self._value 81 | 82 | def __neg__(self): 83 | return SMTInt(0) - self 84 | 85 | @int_cast 86 | def __sub__(self, other: 'SMTInt') -> 'SMTInt': 87 | return SMTInt(self.value - other.value) 88 | 89 | @int_cast 90 | def __rsub__(self, other: 'SMTInt') -> 'SMTInt': 91 | return 
SMTInt(other.value - self.value) 92 | 93 | @int_cast 94 | def __add__(self, other: 'SMTInt') -> 'SMTInt': 95 | return SMTInt(self.value + other.value) 96 | 97 | def __radd__(self, other: 'SMTInt') -> 'SMTInt': 98 | return self + other 99 | 100 | @int_cast 101 | def __mul__(self, other: 'SMTInt') -> 'SMTInt': 102 | return SMTInt(self.value * other.value) 103 | 104 | def __rmul__(self, other: 'SMTInt') -> 'SMTInt': 105 | return self * other 106 | 107 | @int_cast 108 | def __floordiv__(self, other: 'SMTInt') -> 'SMTInt': 109 | return SMTInt(smt.Div(self.value, other.value)) 110 | 111 | @int_cast 112 | def __rfloordiv__(self, other: 'SMTInt') -> 'SMTInt': 113 | return SMTInt(smt.Div(other.value, self.value)) 114 | 115 | @int_cast 116 | def __ge__(self, other: 'SMTInt') -> SMTBit: 117 | return SMTBit(self.value >= other.value) 118 | 119 | @int_cast 120 | def __gt__(self, other: 'SMTInt') -> SMTBit: 121 | return SMTBit(self.value > other.value) 122 | 123 | @int_cast 124 | def __le__(self, other: 'SMTInt') -> SMTBit: 125 | return SMTBit(self.value <= other.value) 126 | 127 | @int_cast 128 | def __lt__(self, other: 'SMTInt') -> SMTBit: 129 | return SMTBit(self.value < other.value) 130 | 131 | @int_cast 132 | def __eq__(self, other: 'SMTInt') -> SMTBit: 133 | return SMTBit(smt.Equals(self.value, other.value)) 134 | 135 | @int_cast 136 | def __ne__(self, other: 'SMTInt') -> SMTBit: 137 | return SMTBit(smt.NotEquals(self.value, other.value)) 138 | -------------------------------------------------------------------------------- /tests/test_bv.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import operator 3 | from hwtypes import BitVector, Bit 4 | 5 | NTESTS = 4 6 | WIDTHS = [1,2,4,8] 7 | 8 | def unsigned(value, width): 9 | return BitVector[width](value) 10 | 11 | 12 | def test_illegal(): 13 | with pytest.raises(TypeError): 14 | BitVector[1](object()) 15 | 16 | @pytest.mark.parametrize("value, arg",[ 17 | (0, 0), 18 | 
(0, [0, 0]), 19 | (0, Bit(0)), 20 | (1, Bit(1)), 21 | (1, [1, 0]), 22 | (2, [0, 1]), 23 | (1, 5), 24 | ]) 25 | def test_uint(value, arg): 26 | assert BitVector[2](arg).as_uint() == value 27 | 28 | @pytest.mark.parametrize("op, reference", [ 29 | (operator.invert, lambda x: ~x), 30 | ]) 31 | @pytest.mark.parametrize("width", WIDTHS) 32 | def test_operator_bit1(op, reference, width): 33 | for _ in range(NTESTS): 34 | I = BitVector.random(width) 35 | expected = unsigned(reference(int(I)), width) 36 | assert expected == int(op(I)) 37 | 38 | @pytest.mark.parametrize("op, reference", [ 39 | (operator.add, lambda x, y: x + y ), 40 | (operator.mul, lambda x, y: x * y ), 41 | (operator.sub, lambda x, y: x - y ), 42 | (operator.floordiv, lambda x, y: x // y if y != 0 else -1), 43 | (operator.mod, lambda x, y: x % y if y != 0 else x), 44 | (operator.and_, lambda x, y: x & y ), 45 | (operator.or_, lambda x, y: x | y ), 46 | (operator.xor, lambda x, y: x ^ y ), 47 | (operator.lshift, lambda x, y: x << y), 48 | (operator.rshift, lambda x, y: x >> y), 49 | ]) 50 | @pytest.mark.parametrize("width", WIDTHS) 51 | def test_operator_bit2(op, reference, width): 52 | for _ in range(NTESTS): 53 | I0, I1 = BitVector.random(width), BitVector.random(width) 54 | expected = unsigned(reference(int(I0), int(I1)), width) 55 | assert expected == int(op(I0, I1)) 56 | 57 | @pytest.mark.parametrize("op, reference", [ 58 | (operator.eq, lambda x, y: x == y), 59 | (operator.ne, lambda x, y: x != y), 60 | (operator.lt, lambda x, y: x < y), 61 | (operator.le, lambda x, y: x <= y), 62 | (operator.gt, lambda x, y: x > y), 63 | (operator.ge, lambda x, y: x >= y), 64 | ]) 65 | @pytest.mark.parametrize("width", WIDTHS) 66 | def test_comparisons(op, reference, width): 67 | for _ in range(NTESTS): 68 | I0, I1 = BitVector.random(width), BitVector.random(width) 69 | expected = Bit(reference(int(I0), int(I1))) 70 | assert expected == bool(op(I0, I1)) 71 | 72 | def test_as(): 73 | bv = BitVector[4](1) 74 | assert 
bv.as_sint() == 1 75 | assert bv.as_uint() == 1 76 | assert int(bv) == 1 77 | assert bool(bv) == True 78 | assert bv.bits() == [1,0,0,0] 79 | assert bv.as_binary_string()== '0b0001' 80 | assert str(bv) == '1' 81 | assert repr(bv) == 'BitVector[4](1)' 82 | 83 | def test_eq(): 84 | assert BitVector[4](1) == 1 85 | assert BitVector[4](1) == BitVector[4](1) 86 | assert [BitVector[4](1)] == [BitVector[4](1)] 87 | assert [[BitVector[4](1)]] == [[BitVector[4](1)]] 88 | 89 | 90 | def test_setitem(): 91 | bv = BitVector[3](5) 92 | assert bv.as_uint() ==5 93 | bv[0] = 0 94 | assert repr(bv) == 'BitVector[3](4)' 95 | bv[1] = 1 96 | assert repr(bv) == 'BitVector[3](6)' 97 | bv[2] = 0 98 | assert repr(bv) == 'BitVector[3](2)' 99 | 100 | @pytest.mark.parametrize("val", [ 101 | BitVector.random(8), 102 | BitVector.random(8).as_sint(), 103 | BitVector.random(8).as_uint(), 104 | [0,1,1,0], 105 | ]) 106 | def test_deprecated(val): 107 | with pytest.warns(DeprecationWarning): 108 | BitVector(val) 109 | 110 | @pytest.mark.parametrize("val", [ 111 | BitVector.random(4), 112 | BitVector.random(4).as_sint(), 113 | BitVector.random(4).as_uint(), 114 | [0,1,1,0], 115 | ]) 116 | def test_old_style(val): 117 | with pytest.warns(DeprecationWarning): 118 | with pytest.raises(TypeError): 119 | BitVector(val, 4) 120 | 121 | 122 | @pytest.mark.parametrize("op, reference", [ 123 | (operator.floordiv, lambda x, y: x // y if y != 0 else -1), 124 | (operator.mod, lambda x, y: x % y if y != 0 else x), 125 | ]) 126 | def test_operator_by_0(op, reference): 127 | I0, I1 = BitVector.random(5), 0 128 | expected = unsigned(reference(int(I0), int(I1)), 5) 129 | assert expected == int(op(I0, I1)) 130 | 131 | 132 | 133 | @pytest.mark.parametrize("op", [ 134 | operator.add, 135 | operator.mul, 136 | operator.sub, 137 | operator.floordiv, 138 | operator.mod, 139 | operator.and_, 140 | operator.or_, 141 | operator.xor, 142 | operator.lshift, 143 | operator.rshift, 144 | operator.eq, 145 | operator.ne, 146 | 
operator.lt, 147 | operator.le, 148 | operator.gt, 149 | operator.ge, 150 | ]) 151 | def test_coercion(op): 152 | a = BitVector.random(16) 153 | b = BitVector.random(16) 154 | assert op(a, b) == op(int(a), b) == op(a, int(b)) 155 | -------------------------------------------------------------------------------- /hwtypes/modifiers.py: -------------------------------------------------------------------------------- 1 | import types 2 | import weakref 3 | from .adt_meta import GetitemSyntax, AttrSyntax, EnumMeta 4 | 5 | 6 | __ALL__ = ['new', 'make_modifier', 'is_modified', 'is_modifier', 'get_modifier', 'get_unmodified', 'strip_modifiers', 'push_modifiers'] 7 | 8 | 9 | _DEBUG = False 10 | #special sentinal value 11 | class _MISSING: pass 12 | 13 | def new(klass, bind=_MISSING, *, name=_MISSING, module=_MISSING): 14 | class T(klass): pass 15 | if name is not _MISSING: 16 | T.__name__ = name 17 | if module is not _MISSING: 18 | T.__module__ = module 19 | 20 | if bind is not _MISSING: 21 | return T[bind] 22 | else: 23 | return T 24 | 25 | class _ModifierMeta(type): 26 | _modifier_lookup = weakref.WeakKeyDictionary() 27 | def __instancecheck__(cls, obj): 28 | if cls is AbstractModifier: 29 | return super().__instancecheck__(obj) 30 | else: 31 | return type(obj) in cls._sub_classes 32 | 33 | def __subclasscheck__(cls, T): 34 | if cls is AbstractModifier: 35 | return super().__subclasscheck__(T) 36 | else: 37 | return T in cls._sub_classes 38 | 39 | def __call__(cls, *args): 40 | if cls is AbstractModifier: 41 | raise TypeError('Cannot instance or apply AbstractModifier') 42 | 43 | if len(args) != 1: 44 | return super().__call__(*args) 45 | 46 | unmod_cls = args[0] 47 | try: 48 | return cls._sub_class_cache[unmod_cls] 49 | except KeyError: 50 | pass 51 | 52 | mod_name = cls.__name__ + unmod_cls.__name__ 53 | bases = [unmod_cls] 54 | for base in unmod_cls.__bases__: 55 | bases.append(cls(base)) 56 | mod_cls = type(mod_name, tuple(bases), {}) 57 | 
cls._register_modified(unmod_cls, mod_cls) 58 | return mod_cls 59 | 60 | class AbstractModifier(metaclass=_ModifierMeta): 61 | def __init_subclass__(cls, **kwargs): 62 | cls._sub_class_cache = weakref.WeakValueDictionary() 63 | cls._sub_classes = weakref.WeakSet() 64 | 65 | @classmethod 66 | def _register_modified(cls, unmod_cls, mod_cls): 67 | type(cls)._modifier_lookup[mod_cls] = cls 68 | cls._sub_classes.add(mod_cls) 69 | cls._sub_class_cache[unmod_cls] = mod_cls 70 | if _DEBUG: 71 | # O(n) assert, but its a pretty key invariant 72 | assert set(cls._sub_classes) == set(cls._sub_class_cache.values()) 73 | 74 | def is_modified(T): 75 | return T in _ModifierMeta._modifier_lookup 76 | 77 | def is_modifier(T): 78 | return issubclass(T, AbstractModifier) 79 | 80 | def get_modifier(T): 81 | if is_modified(T): 82 | return _ModifierMeta._modifier_lookup[T] 83 | else: 84 | raise TypeError(f'{T} has no modifiers') 85 | 86 | def get_unmodified(T): 87 | if is_modified(T): 88 | unmod = T.__bases__[0] 89 | if _DEBUG: 90 | # Not an expensive assert but as there is a 91 | # already a debug guard might as well use it. 92 | mod = get_modifier(T) 93 | assert mod._sub_class_cache[unmod] is T 94 | return unmod 95 | else: 96 | raise TypeError(f'{T} has no modifiers') 97 | 98 | def get_all_modifiers(T): 99 | if is_modified(T): 100 | yield from get_all_modifiers(get_unmodified(T)) 101 | yield get_modifier(T) 102 | 103 | _mod_cache = weakref.WeakValueDictionary() 104 | # This is a factory for type modifiers. 
105 | def make_modifier(name, cache=False): 106 | if cache: 107 | try: 108 | return _mod_cache[name] 109 | except KeyError: 110 | pass 111 | 112 | ModType = _ModifierMeta(name, (AbstractModifier, ), {}) 113 | 114 | if cache: 115 | return _mod_cache.setdefault(name, ModType) 116 | 117 | return ModType 118 | 119 | def unwrap_modifier(T): 120 | if not is_modified(T): 121 | return T, [] 122 | mod = get_modifier(T) 123 | unmod = get_unmodified(T) 124 | unmod, mods = unwrap_modifier(unmod) 125 | mods.append(mod) 126 | return unmod, mods 127 | 128 | def wrap_modifier(T, mods): 129 | wrapped = T 130 | if len(set(mods)) != len(mods): 131 | raise TypeError(f"{mods} must contain no duplicates") 132 | for mod in mods: 133 | wrapped = mod(wrapped) 134 | return wrapped 135 | 136 | #Takes an ADT type and removes any modifiers on any level of the Tree 137 | def strip_modifiers(adt_t): 138 | #remove modifiers from this level 139 | adt_t, _ = unwrap_modifier(adt_t) 140 | if not hasattr(adt_t, "unbound_t"): 141 | return adt_t 142 | elif isinstance(adt_t, AttrSyntax) and not isinstance(adt_t, EnumMeta): 143 | new_fields = {n:strip_modifiers(sub_adt_t) for n, sub_adt_t in adt_t.field_dict.items()} 144 | return adt_t.unbound_t.from_fields(adt_t.__name__, new_fields) 145 | elif isinstance(adt_t, GetitemSyntax): 146 | new_fields = [strip_modifiers(sub_adt_t) for sub_adt_t in adt_t.fields] 147 | return adt_t.unbound_t[new_fields] 148 | else: 149 | return adt_t 150 | 151 | #Takes an ADT type and pushes all the modifiers from internal nodes to leaf nodes 152 | def push_modifiers(adt_t, mods=[]): 153 | #remove modifiers from this level 154 | adt_t, new_mods = unwrap_modifier(adt_t) 155 | mods = new_mods + mods 156 | if isinstance(adt_t, AttrSyntax) and not isinstance(adt_t, EnumMeta): 157 | new_fields = {n:push_modifiers(sub_adt_t, mods) for n, sub_adt_t in adt_t.field_dict.items()} 158 | return adt_t.unbound_t.from_fields(adt_t.__name__, new_fields) 159 | elif isinstance(adt_t, 
GetitemSyntax): 160 | new_fields = [push_modifiers(sub_adt_t, mods) for sub_adt_t in adt_t.fields] 161 | return adt_t.unbound_t[new_fields] 162 | else: 163 | return wrap_modifier(adt_t, mods) 164 | 165 | 166 | -------------------------------------------------------------------------------- /tests/test_rebind.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from hwtypes.adt import Product, Sum, Enum, Tuple, TaggedUnion 4 | from hwtypes.adt_util import rebind_bitvector, rebind_keep_modifiers, rebind_type 5 | from hwtypes.bit_vector import AbstractBitVector, BitVector, AbstractBit, Bit 6 | from hwtypes.smt_bit_vector import SMTBit, SMTBitVector 7 | from hwtypes.util import _issubclass 8 | from hwtypes.modifiers import make_modifier 9 | 10 | class A: pass 11 | class B: pass 12 | class C(A): pass 13 | class D(B): pass 14 | 15 | class E(Enum): 16 | A = 0 17 | B = 1 18 | C = 2 19 | E = 3 20 | 21 | T0 = Tuple[A, B, C, E] 22 | 23 | class P0(Product, cache=True): 24 | A = A 25 | B = B 26 | C = C 27 | E = E 28 | 29 | S0 = Sum[A, B, C, E] 30 | 31 | class P1(Product, cache=True): 32 | P0 = P0 33 | S0 = S0 34 | T0 = T0 35 | D = D 36 | 37 | S1 = Sum[P0, S0, T0, D] 38 | 39 | 40 | 41 | @pytest.mark.parametrize("type_0", [A, B, C, D, E]) 42 | @pytest.mark.parametrize("type_1", [A, B, C, D, E]) 43 | @pytest.mark.parametrize("rebind_sub_types", [False, True]) 44 | def test_rebind_enum(type_0, type_1, rebind_sub_types): 45 | assert E is E.rebind(type_0, type_1, rebind_sub_types) 46 | 47 | 48 | @pytest.mark.parametrize("T", [T0, S0]) 49 | @pytest.mark.parametrize("type_0", [A, B, C, D, E]) 50 | @pytest.mark.parametrize("type_1", [A, B, C, D, E]) 51 | @pytest.mark.parametrize("rebind_sub_types", [False, True]) 52 | def test_rebind_sum_tuple(T, type_0, type_1, rebind_sub_types): 53 | fields = T.fields 54 | T_ = T.rebind(type_0, type_1, rebind_sub_types) 55 | 56 | if rebind_sub_types: 57 | map_fn = lambda s : type_1 if 
_issubclass(s, type_0) else s 58 | else: 59 | map_fn = lambda s : type_1 if s == type_0 else s 60 | 61 | new_fields = map(map_fn, fields) 62 | 63 | assert T_ is T.unbound_t[new_fields] 64 | 65 | 66 | @pytest.mark.parametrize("type_0", [A, B, C, D, E]) 67 | @pytest.mark.parametrize("type_1", [A, B, C, D, E]) 68 | @pytest.mark.parametrize("rebind_sub_types", [False, True]) 69 | def test_rebind_product(type_0, type_1, rebind_sub_types): 70 | field_dict = P0.field_dict 71 | P_ = P0.rebind(type_0, type_1, rebind_sub_types) 72 | 73 | if rebind_sub_types: 74 | map_fn = lambda s : type_1 if _issubclass(s, type_0) else s 75 | else: 76 | map_fn = lambda s : type_1 if s == type_0 else s 77 | 78 | new_fields = {} 79 | for k,v in field_dict.items(): 80 | new_fields[k] = map_fn(v) 81 | 82 | assert P_ is Product.from_fields('P0', new_fields) 83 | 84 | 85 | @pytest.mark.parametrize("rebind_sub_types", [False, True]) 86 | def test_rebind_recursive(rebind_sub_types): 87 | S_ = S1.rebind(B, A, rebind_sub_types) 88 | if rebind_sub_types: 89 | gold = Sum[ 90 | P0.rebind(B, A, rebind_sub_types), 91 | S0.rebind(B, A, rebind_sub_types), 92 | T0.rebind(B, A, rebind_sub_types), 93 | A 94 | ] 95 | else: 96 | gold = Sum[ 97 | P0.rebind(B, A, rebind_sub_types), 98 | S0.rebind(B, A, rebind_sub_types), 99 | T0.rebind(B, A, rebind_sub_types), 100 | D 101 | ] 102 | 103 | assert S_ is gold 104 | 105 | P_ = P1.rebind(B, A, rebind_sub_types) 106 | if rebind_sub_types: 107 | gold = Product.from_fields('P1', { 108 | 'P0' : P0.rebind(B, A, rebind_sub_types), 109 | 'S0' : S0.rebind(B, A, rebind_sub_types), 110 | 'T0' : T0.rebind(B, A, rebind_sub_types), 111 | 'D' : A 112 | }) 113 | else: 114 | gold = Product.from_fields('P1', { 115 | 'P0' : P0.rebind(B, A, rebind_sub_types), 116 | 'S0' : S0.rebind(B, A, rebind_sub_types), 117 | 'T0' : T0.rebind(B, A, rebind_sub_types), 118 | 'D' : D 119 | }) 120 | 121 | 122 | assert P_ is gold 123 | 124 | 125 | class P(Product): 126 | X = AbstractBitVector[16] 127 | S = 
Sum[AbstractBitVector[4], AbstractBitVector[8]] 128 | T = Tuple[AbstractBitVector[32]] 129 | class F(Product): 130 | Y = AbstractBitVector 131 | 132 | 133 | def test_rebind_bv(): 134 | P_bound = rebind_bitvector(P, AbstractBitVector, BitVector) 135 | assert P_bound.X == BitVector[16] 136 | assert P_bound.S == Sum[BitVector[4], BitVector[8]] 137 | assert P_bound.T[0] == BitVector[32] 138 | assert P_bound.F.Y == BitVector 139 | 140 | P_unbound = rebind_bitvector(P_bound, BitVector, AbstractBitVector) 141 | assert P_unbound.X == AbstractBitVector[16] 142 | assert P_unbound.S == Sum[AbstractBitVector[4], AbstractBitVector[8]] 143 | assert P_unbound.T[0] == AbstractBitVector[32] 144 | 145 | 146 | def test_rebind_bv_Tu(): 147 | class T(TaggedUnion): 148 | a=AbstractBit 149 | b=AbstractBitVector[5] 150 | 151 | T_bv_rebound = rebind_bitvector(T, AbstractBitVector, BitVector) 152 | assert T_bv_rebound.b == BitVector[5] 153 | T_bit_rebound = T_bv_rebound.rebind(AbstractBit, Bit, True) 154 | assert T_bit_rebound.b == BitVector[5] 155 | assert T_bit_rebound.a == Bit 156 | 157 | def test_issue_74(): 158 | class A(Product): 159 | a = Bit 160 | 161 | A_smt = A.rebind(AbstractBit, SMTBit, True) 162 | assert A_smt.a is SMTBit 163 | 164 | def test_rebind_mod(): 165 | M = make_modifier("M") 166 | class A(Product): 167 | a=M(Bit) 168 | b=M(BitVector[4]) 169 | 170 | A_smt = rebind_bitvector(A,AbstractBitVector, SMTBitVector, True) 171 | A_smt = rebind_keep_modifiers(A_smt, AbstractBit, SMTBit) 172 | assert A_smt.b == M(SMTBitVector[4]) 173 | assert A_smt.a == M(SMTBit) 174 | 175 | #Tests an issue with rebind_bitvector and Sum 176 | #The issue: 177 | # If we had Sum[A,B] and A was contained within B 178 | # (like a field of a Tuple or Product), rebind would fail non-deterministically. 
179 | def test_sum_issue(): 180 | for _ in range(1000): 181 | BV = BitVector 182 | SBV = SMTBitVector 183 | T = Tuple[BV[8],BV[1]] 184 | T_smt = rebind_bitvector(T,AbstractBitVector,SMTBitVector, True) 185 | S = Sum[T,BV[8]] 186 | S_smt = rebind_bitvector(S,AbstractBitVector,SMTBitVector, True) 187 | assert T_smt in S_smt.fields 188 | 189 | def test_rebind_type(): 190 | class E(Enum): 191 | a=1 192 | b=2 193 | def gen_type(family): 194 | Bit = family.Bit 195 | BV = family.BitVector 196 | class A(Product, cache=True): 197 | a=Sum[Bit,BV[8],E] 198 | b=BV[16] 199 | c=Tuple[E,Bit,BV[7]] 200 | return A 201 | A_bv = gen_type(Bit.get_family()) 202 | A_smt = gen_type(SMTBit.get_family()) 203 | Ar_smt = rebind_type(A_bv, SMTBit.get_family()) 204 | Ar_bv = rebind_type(A_smt, Bit.get_family()) 205 | assert A_bv == Ar_bv 206 | assert A_smt == Ar_smt 207 | -------------------------------------------------------------------------------- /hwtypes/fp_vector_abc.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import typing as tp 3 | import weakref 4 | import warnings 5 | import enum 6 | 7 | from . 
import AbstractBitVectorMeta, AbstractBitVector, AbstractBit 8 | 9 | class RoundingMode(enum.Enum): 10 | RNE = enum.auto() # roundTiesToEven 11 | RNA = enum.auto() # roundTiesToAway 12 | RTP = enum.auto() # roundTowardPositive 13 | RTN = enum.auto() # roundTowardNegative 14 | RTZ = enum.auto() # roundTowardZero 15 | 16 | class AbstractFPVectorMeta(ABCMeta): 17 | # FPVectorType, (eb, mb, mode, ieee_compliance) : FPVectorType[eb, mb, mode, ieee_compliance] 18 | _class_cache = weakref.WeakValueDictionary() 19 | 20 | def __new__(mcs, name, bases, namespace, info=(None, None), **kwargs): 21 | if '_info_' in namespace: 22 | raise TypeError('class attribute _info_ is reversed by the type machinery') 23 | 24 | binding = info[1] 25 | for base in bases: 26 | if getattr(base, 'is_bound', False): 27 | if binding is None: 28 | binding = base.binding 29 | elif binding != base.binding: 30 | raise TypeError("Can't inherit from multiple FP types") 31 | 32 | namespace['_info_'] = info[0], binding 33 | t = super().__new__(mcs, name, bases, namespace, **kwargs) 34 | 35 | if binding is None: 36 | #class is unbound so t.unbound_t -> t 37 | t._info_ = t, binding 38 | elif info[0] is None: 39 | #class inherited from bound type so there is no unbound_t 40 | t._info_ = None, binding 41 | 42 | return t 43 | 44 | def __getitem__(cls, idx : tp.Tuple[int, int, RoundingMode, bool]): 45 | mcs = type(cls) 46 | try: 47 | return mcs._class_cache[cls, idx] 48 | except KeyError: 49 | pass 50 | 51 | if cls.is_bound: 52 | raise TypeError(f'{cls} is already bound') 53 | 54 | if len(idx) != 4 or tuple(map(type,idx)) != (int, int, RoundingMode, bool): 55 | raise IndexError(f'Constructing a floating point type requires:\n' 56 | 'exponent bits : int, mantisa bits : int, RoundingMode, ieee_compliance : bool') 57 | 58 | eb, mb, mode, ieee_compliance = idx 59 | 60 | if eb <= 0 or mb <= 0: 61 | raise ValueError('exponents bits and mantisa bits must be greater than 0') 62 | 63 | 64 | bases = [cls] 65 | 
bases.extend(b[idx] for b in cls.__bases__ if isinstance(b, AbstractFPVectorMeta)) 66 | bases = tuple(bases) 67 | class_name = f'{cls.__name__}[{eb},{mb},{mode},{ieee_compliance}]' 68 | t = mcs(class_name, bases, {}, info=(cls, idx)) 69 | t.__module__ = cls.__module__ 70 | mcs._class_cache[cls, idx] = t 71 | return t 72 | 73 | @property 74 | def unbound_t(cls) -> 'AbstractBitVectorMeta': 75 | t = cls._info_[0] 76 | if t is not None: 77 | return t 78 | else: 79 | raise AttributeError('type {} has no unsized_t'.format(cls)) 80 | 81 | @property 82 | def size(cls): 83 | return 1 + cls.exponent_size + cls.mantissa_size 84 | 85 | @property 86 | def is_bound(cls): 87 | return cls.binding is not None 88 | 89 | @property 90 | def binding(cls): 91 | return cls._info_[1] 92 | 93 | @property 94 | def exponent_size(cls): 95 | if cls.is_bound: 96 | return cls.binding[0] 97 | else: 98 | raise AttributeError('unbound type has no exponent_size') 99 | 100 | @property 101 | def mantissa_size(cls): 102 | if cls.is_bound: 103 | return cls.binding[1] 104 | else: 105 | raise AttributeError('unbound type has no mantissa_size') 106 | 107 | @property 108 | def mode(cls) -> RoundingMode: 109 | if cls.is_bound: 110 | return cls.binding[2] 111 | else: 112 | raise AttributeError('unbound type has no mode') 113 | 114 | @property 115 | def ieee_compliance(cls) -> bool: 116 | if cls.is_bound: 117 | return cls.binding[3] 118 | else: 119 | raise AttributeError('unbound type has no ieee_compliance') 120 | 121 | class AbstractFPVector(metaclass=AbstractFPVectorMeta): 122 | @property 123 | def size(self) -> int: 124 | return type(self).size 125 | 126 | @abstractmethod 127 | def fp_abs(self) -> 'AbstractFPVector': pass 128 | 129 | @abstractmethod 130 | def fp_neg(self) -> 'AbstractFPVector': pass 131 | 132 | @abstractmethod 133 | def fp_add(self, other: 'AbstractFPVector') -> 'AbstractFPVector': pass 134 | 135 | @abstractmethod 136 | def fp_sub(self, other: 'AbstractFPVector') -> 'AbstractFPVector': 
pass 137 | 138 | @abstractmethod 139 | def fp_mul(self, other: 'AbstractFPVector') -> 'AbstractFPVector': pass 140 | 141 | @abstractmethod 142 | def fp_div(self, other: 'AbstractFPVector') -> 'AbstractFPVector': pass 143 | 144 | @abstractmethod 145 | def fp_fma(self, coef: 'AbstractFPVector', offset: 'AbstractFPVector') -> 'AbstractFPVector': pass 146 | 147 | @abstractmethod 148 | def fp_sqrt(self) -> 'AbstractFPVector': pass 149 | 150 | @abstractmethod 151 | def fp_rem(self, other: 'AbstractFPVector') -> 'AbstractFPVector': pass 152 | 153 | @abstractmethod 154 | def fp_round_to_integral(self) -> 'AbstractFPVector': pass 155 | 156 | @abstractmethod 157 | def fp_min(self, other: 'AbstractFPVector') -> 'AbstractFPVector': pass 158 | 159 | @abstractmethod 160 | def fp_max(self, other: 'AbstractFPVector') -> 'AbstractFPVector': pass 161 | 162 | @abstractmethod 163 | def fp_leq(self, other: 'AbstractFPVector') -> AbstractBit: pass 164 | 165 | @abstractmethod 166 | def fp_lt(self, other: 'AbstractFPVector') -> AbstractBit: pass 167 | 168 | @abstractmethod 169 | def fp_geq(self, other: 'AbstractFPVector') -> AbstractBit: pass 170 | 171 | @abstractmethod 172 | def fp_gt(self, other: 'AbstractFPVector') -> AbstractBit: pass 173 | 174 | @abstractmethod 175 | def fp_eq(self, other: 'AbstractFPVector') -> AbstractBit: pass 176 | 177 | @abstractmethod 178 | def fp_is_normal(self) -> AbstractBit: pass 179 | 180 | @abstractmethod 181 | def fp_is_subnormal(self) -> AbstractBit: pass 182 | 183 | @abstractmethod 184 | def fp_is_zero(self) -> AbstractBit: pass 185 | 186 | @abstractmethod 187 | def fp_is_infinite(self) -> AbstractBit: pass 188 | 189 | @abstractmethod 190 | def fp_is_NaN(self) -> AbstractBit: pass 191 | 192 | @abstractmethod 193 | def fp_is_negative(self) -> AbstractBit: pass 194 | 195 | @abstractmethod 196 | def fp_is_positive(self) -> AbstractBit: pass 197 | 198 | @abstractmethod 199 | def to_ubv(self, size : int) -> AbstractBitVector: pass 200 | 201 | 
@abstractmethod 202 | def to_sbv(self, size : int) -> AbstractBitVector: pass 203 | 204 | @abstractmethod 205 | def reinterpret_as_bv(self) -> AbstractBitVector: pass 206 | 207 | @classmethod 208 | @abstractmethod 209 | def reinterpret_from_bv(self, value: AbstractBitVector) -> 'AbstractFPVector': pass 210 | -------------------------------------------------------------------------------- /hwtypes/smt_fp_vector.py: -------------------------------------------------------------------------------- 1 | import weakref 2 | 3 | import pysmt 4 | import pysmt.shortcuts as shortcuts 5 | from pysmt.typing import BVType, BOOL, FunctionType 6 | 7 | from .smt_bit_vector import SMTBit, SMTBitVector, SMTSIntVector, SMYBOLIC, AUTOMATIC 8 | from .fp_vector_abc import AbstractFPVector, RoundingMode 9 | 10 | # using None to represent cls 11 | _SIGS = [ 12 | ('fp_abs', None, None,), 13 | ('fp_neg', None, None,), 14 | ('fp_add', None, None, None,), 15 | ('fp_sub', None, None, None,), 16 | ('fp_mul', None, None, None,), 17 | ('fp_div', None, None, None,), 18 | ('fp_fma', None, None, None, None,), 19 | ('fp_sqrt', None, None,), 20 | ('fp_rem', None, None,None), 21 | ('fp_round_to_integral', None, None,), 22 | ('fp_min', None, None,), 23 | ('fp_max', None, None,), 24 | ('fp_leq', None, None, BOOL,), 25 | ('fp_lt', None, None, BOOL,), 26 | ('fp_geq', None, None, BOOL,), 27 | ('fp_gt', None, None, BOOL,), 28 | ('fp_eq', None, None, BOOL,), 29 | ('fp_is_normal', None, BOOL,), 30 | ('fp_is_subnormal', None, BOOL,), 31 | ('fp_is_zero', None, BOOL,), 32 | ('fp_is_infinite', None, BOOL,), 33 | ('fp_is_NaN', None, BOOL,), 34 | ('fp_is_negative', None, BOOL,), 35 | ('fp_is_positive', None, BOOL,), 36 | ] 37 | 38 | _name_table = weakref.WeakValueDictionary() 39 | _uf_table = weakref.WeakKeyDictionary() 40 | 41 | 42 | class SMTFPVector(AbstractFPVector): 43 | def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC): 44 | if name is not AUTOMATIC and value is not SMYBOLIC: 45 | raise TypeError('Can only 
name symbolic variables') 46 | elif name is not AUTOMATIC: 47 | if not isinstance(name, str): 48 | raise TypeError('Name must be string') 49 | elif name in _name_table: 50 | raise ValueError(f'Name {name} already in use') 51 | _name_table[name] = self 52 | 53 | T = BVType(self.size) 54 | 55 | if value is SMYBOLIC: 56 | if name is AUTOMATIC: 57 | value = shortcuts.FreshSymbol(T) 58 | else: 59 | value = shortcuts.Symbol(name, T) 60 | elif isinstance(value, pysmt.fnode.FNode): 61 | t = value.get_type() 62 | if t is not T: 63 | raise TypeError(f'Expected {T} not {t}') 64 | elif isinstance(value, type(self)): 65 | value = value._value 66 | elif isinstance(value, int): 67 | value = shortcuts.BV(value, self.size) 68 | else: 69 | raise TypeError(f"Can't coerce {value} to SMTFPVector") 70 | 71 | self._name = name 72 | self._value = value 73 | 74 | def __init_subclass__(cls): 75 | _uf_table[cls] = ufs = dict() 76 | T = BVType(cls.size) 77 | for method_name, *args in _SIGS: 78 | args = [T if x is None else x for x in args] 79 | rtype = args[-1] 80 | params = args[:-1] 81 | name = '.'.join((cls.__name__, method_name)) 82 | ufs[method_name] = shortcuts.Symbol(name, FunctionType(rtype, params)) 83 | 84 | ufs['to_sbv'] = dict() 85 | ufs['to_ubv'] = dict() 86 | 87 | @classmethod 88 | def _fp_method(cls, method_name, *args): 89 | uf = _uf_table[cls][method_name] 90 | return cls(uf(*args)) 91 | 92 | @classmethod 93 | def _bit_method(cls, method_name, *args): 94 | uf = _uf_table[cls][method_name] 95 | return SMTBit(uf(*args)) 96 | 97 | def fp_abs(self) -> 'SMTFPVector': 98 | return self._fp_method('fp_abs', self._value) 99 | 100 | def fp_neg(self) -> 'SMTFPVector': 101 | return self._fp_method('fp_neg', self._value) 102 | 103 | def fp_add(self, other: 'SMTFPVector') -> 'SMTFPVector': 104 | return self._fp_method('fp_add', self._value, other._value) 105 | 106 | def fp_sub(self, other: 'SMTFPVector') -> 'SMTFPVector': 107 | return self._fp_method('fp_sub', self._value, other._value) 
108 | 109 | def fp_mul(self, other: 'SMTFPVector') -> 'SMTFPVector': 110 | return self._fp_method('fp_mul', self._value, other._value) 111 | 112 | def fp_div(self, other: 'SMTFPVector') -> 'SMTFPVector': 113 | return self._fp_method('fp_div', self._value, other._value) 114 | 115 | def fp_fma(self, coef: 'SMTFPVector', offset: 'SMTFPVector') -> 'SMTFPVector': 116 | return self._fp_method('fp_fma', self._value, coef._value, offset._value) 117 | 118 | def fp_sqrt(self) -> 'SMTFPVector': 119 | return self._fp_method('fp_sqrt', self._value) 120 | 121 | def fp_rem(self, other: 'SMTFPVector') -> 'SMTFPVector': 122 | return self._fp_method('fp_rem', self._value, other._value) 123 | 124 | def fp_round_to_integral(self) -> 'SMTFPVector': 125 | return self._fp_method('fp_round_to_integral', self._value) 126 | 127 | def fp_min(self, other: 'SMTFPVector') -> 'SMTFPVector': 128 | return self._fp_method('fp_min', self._value, other._value) 129 | 130 | def fp_max(self, other: 'SMTFPVector') -> 'SMTFPVector': 131 | return self._fp_method('fp_max', self._value, other._value) 132 | 133 | def fp_leq(self, other: 'SMTFPVector') -> SMTBit: 134 | return self._bit_method('fp_leq', self._value, other._value) 135 | 136 | def fp_lt(self, other: 'SMTFPVector') -> SMTBit: 137 | return self._bit_method('fp_lt', self._value, other._value) 138 | 139 | def fp_geq(self, other: 'SMTFPVector') -> SMTBit: 140 | return self._bit_method('fp_geq', self._value, other._value) 141 | 142 | def fp_gt(self, other: 'SMTFPVector') -> SMTBit: 143 | return self._bit_method('fp_gt', self._value, other._value) 144 | 145 | def fp_eq(self, other: 'SMTFPVector') -> SMTBit: 146 | return self._bit_method('fp_eq', self._value, other._value) 147 | 148 | def fp_is_normal(self) -> SMTBit: 149 | return self._bit_method('fp_is_normal', self._value) 150 | 151 | def fp_is_subnormal(self) -> SMTBit: 152 | return self._bit_method('fp_is_subnormal', self._value) 153 | 154 | def fp_is_zero(self) -> SMTBit: 155 | return 
self._bit_method('fp_is_zero', self._value) 156 | 157 | def fp_is_infinite(self) -> SMTBit: 158 | return self._bit_method('fp_is_infinite', self._value) 159 | 160 | def fp_is_NaN(self) -> SMTBit: 161 | return self._bit_method('fp_is_NaN', self._value) 162 | 163 | def fp_is_negative(self) -> SMTBit: 164 | return self._bit_method('fp_is_negative', self._value) 165 | 166 | def fp_is_positive(self) -> SMTBit: 167 | return self._bit_method('fp_is_positive', self._value) 168 | 169 | def to_ubv(self, size : int) -> SMTBitVector: 170 | cls = type(self) 171 | ufs = _uf_table[cls]['to_ubv'] 172 | if size not in ufs: 173 | name = '.'.join((cls.__name__, f'to_ubv[{size}]')) 174 | ufs[size] = shortcuts.Symbol( 175 | name, 176 | FunctionType(BVType(size), (BVType(self.size),)) 177 | ) 178 | 179 | return SMTBitVector[size](ufs[size](self._value)) 180 | 181 | def to_sbv(self, size : int) -> SMTBitVector: 182 | cls = type(self) 183 | ufs = _uf_table[cls]['to_usbv'] 184 | if size not in ufs: 185 | name = '.'.join((cls.__name__, f'to_sbv[{size}]')) 186 | ufs[size] = shortcuts.Symbol( 187 | name, 188 | FunctionType(BVType(size), (BVType(self.size),)) 189 | ) 190 | 191 | return SMTBitVector[size](ufs[size](self._value)) 192 | 193 | def reinterpret_as_bv(self) -> SMTBitVector: 194 | return SMTBitVector[self.size](self._value) 195 | 196 | @classmethod 197 | def reinterpret_from_bv(cls, value: SMTBitVector) -> 'SMTFPVector': 198 | return cls(value._value) 199 | 200 | def __neg__(self): return self.fp_neg() 201 | def __abs__(self): return self.fp_abs() 202 | def __add__(self, other): return self.fp_add(other) 203 | def __sub__(self, other): return self.fp_sub(other) 204 | def __mul__(self, other): return self.fp_mul(other) 205 | def __truediv__(self, other): return self.fp_div(other) 206 | def __mod__(self, other): return self.fp_rem(other) 207 | 208 | def __eq__(self, other): return self.fp_eq(other) 209 | def __ne__(self, other): return ~(self.fp_eq(other)) 210 | def __ge__(self, other): 
return self.fp_geq(other) 211 | def __gt__(self, other): return self.fp_gt(other) 212 | def __le__(self, other): return self.fp_leq(other) 213 | def __lt__(self, other): return self.fp_lt(other) 214 | -------------------------------------------------------------------------------- /hwtypes/bit_vector_util.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import functools as ft 3 | import inspect 4 | import types 5 | 6 | from .util import Method 7 | from .bit_vector_abc import InconsistentSizeError 8 | from .bit_vector_abc import BitVectorMeta, AbstractBitVector, AbstractBit 9 | 10 | 11 | # needed to define isinstance(..., BitVectorProtocolMeta) 12 | class _BitVectorProtocolMetaMeta(type): 13 | def __instancecheck__(mcs, cls): 14 | # mcs should be BitVectorProtocolMeta 15 | return issubclass(type(cls), mcs) 16 | 17 | def __subclasscheck__(mcs, sub): 18 | # `not isinstance(sub, mcs)` blocks BitVectorProtocol from 19 | # looking like a subclass of BitVectorProtocolMeta 20 | # which it otherwise would because its type defines `_bitvector_t_`. 
21 | return hasattr(sub, '_bitvector_t_') and not isinstance(sub, mcs) 22 | 23 | 24 | class BitVectorProtocolMeta(type, metaclass=_BitVectorProtocolMetaMeta): 25 | # Any type that has _bitvector_t_ shall be considered 26 | # a instance of BitVectorProtocolMeta 27 | def __instancecheck__(cls, obj): 28 | return issubclass(type(obj), cls) 29 | 30 | def __subclasscheck__(cls, sub): 31 | return (isinstance(sub, type(cls)) 32 | and hasattr(sub, '_from_bitvector_') 33 | and hasattr(sub, '_to_bitvector_')) 34 | 35 | @abstractmethod 36 | def _bitvector_t_(cls): pass 37 | 38 | 39 | class BitVectorProtocol(metaclass=BitVectorProtocolMeta): 40 | # Any object whose type is an instance of BitVectorProtocolMeta (see above) 41 | # and which defines _to_bitvector_ and _from_bitvector_ shall be considered 42 | # an instance of BitVectorProtocol 43 | @classmethod 44 | @abstractmethod 45 | def _from_bitvector_(cls, bv): pass 46 | 47 | @abstractmethod 48 | def _to_bitvector_(self): pass 49 | 50 | 51 | # used as a tag 52 | class PolyBase: pass 53 | 54 | class PolyType(type): 55 | def __getitem__(cls, args): 56 | # From a typing perspective it would be better to make select an 57 | # argument to init instead of making it a type param. This would 58 | # allow types to be cached etc... However, making it an init arg 59 | # means type(self)(val) is no longer sufficient to to cast val. 60 | # Instead one would need to write type(self)(val, self._select_) or 61 | # equivalent and hence would require a major change in the engineering 62 | # of bitvector types (they would need to be aware of polymorphism). 63 | # Note we can't cache as select is not necessarily hashable. 
64 | 65 | T0, T1, select = args 66 | 67 | if not cls._type_check(T0, T1): 68 | raise TypeError(f'Cannot construct {cls} from {T0} and {T1}') 69 | if not isinstance(select, AbstractBit): 70 | raise TypeError('select must be a Bit') 71 | if (T0.get_family() is not T1.get_family() 72 | or T0.get_family() is not select.get_family()): 73 | raise TypeError('Cannot construct PolyTypes across families') 74 | 75 | 76 | # stupid generator to make sure PolyBase is not replicated 77 | # and always comes last 78 | bases = *(b for b in cls._get_bases(T0, T1) if b is not PolyBase), PolyBase 79 | class_name = f'{cls.__name__}[{T0.__name__}, {T1.__name__}, {select}]' 80 | meta, namespace, _ = types.prepare_class(class_name, bases) 81 | 82 | d0 = dict(inspect.getmembers(T0)) 83 | d1 = dict(inspect.getmembers(T1)) 84 | 85 | attrs = d0.keys() & d1.keys() 86 | for k in attrs: 87 | if k in {'_info_', '__int__', '__repr__', '__str__'}: 88 | continue 89 | 90 | m0 = inspect.getattr_static(T0, k) 91 | m1 = inspect.getattr_static(T1, k) 92 | namespace[k] = build_VCall(select, m0, m1) 93 | 94 | 95 | new_cls = meta(class_name, bases, namespace) 96 | final = cls._finalize(new_cls, T0, T1) 97 | return final 98 | 99 | 100 | def _get_common_bases(T0, T1): 101 | if issubclass(T0, T1): 102 | return T1, 103 | elif issubclass(T1, T0): 104 | return T0, 105 | else: 106 | bases = set() 107 | for t in T0.__bases__: 108 | bases.update(_get_common_bases(t, T1)) 109 | 110 | for t in T1.__bases__: 111 | bases.update(_get_common_bases(t, T0)) 112 | 113 | # Filter to most specific types 114 | bases_ = set() 115 | for bi in bases: 116 | if not any(issubclass(bj, bi) for bj in bases if bi is not bj): 117 | bases_.add(bi) 118 | 119 | return tuple(bases_) 120 | 121 | class PolyVector(metaclass=PolyType): 122 | @classmethod 123 | def _type_check(cls, T0, T1): 124 | if (issubclass(T0, AbstractBitVector) 125 | and issubclass(T1, AbstractBitVector)): 126 | if T0.size != T1.size: 127 | raise 
InconsistentSizeError(f'Cannot construct {cls} from {T0} and {T1}') 128 | else: 129 | return True 130 | else: 131 | return False 132 | 133 | @classmethod 134 | def _get_bases(cls, T0, T1): 135 | bases = _get_common_bases(T0, T1) 136 | 137 | # get the unsized versions 138 | bases_ = set() 139 | for base in bases: 140 | try: 141 | bases_.add(base.unsized_t) 142 | except AttributeError: 143 | bases_.add(base) 144 | 145 | return tuple(bases_) 146 | 147 | 148 | @classmethod 149 | def _finalize(cls, new_class, T0, T1): 150 | return new_class[T0.size] 151 | 152 | class PolyBit(metaclass=PolyType): 153 | @classmethod 154 | def _type_check(cls, T0, T1): 155 | return (issubclass(T0, AbstractBit) 156 | and issubclass(T1, AbstractBit)) 157 | 158 | @classmethod 159 | def _get_bases(cls, T0, T1): 160 | return _get_common_bases(T0, T1) 161 | 162 | @classmethod 163 | def _finalize(cls, new_class, T0, T1): 164 | return new_class 165 | 166 | def build_VCall(select, m0, m1): 167 | if m0 is m1: 168 | return m0 169 | else: 170 | def VCall(*args, **kwargs): 171 | v0 = m0(*args, **kwargs) 172 | v1 = m1(*args, **kwargs) 173 | if v0 is NotImplemented or v0 is NotImplemented: 174 | return NotImplemented 175 | return select.ite(v0, v1) 176 | return Method(VCall) 177 | 178 | 179 | def get_branch_type(branch): 180 | if isinstance(branch, tuple): 181 | return tuple(map(get_branch_type, branch)) 182 | else: 183 | return type(branch) 184 | 185 | def determine_return_type(select, t_branch, f_branch): 186 | def _recurse(t_branch, f_branch): 187 | tb_t = get_branch_type(t_branch) 188 | fb_t = get_branch_type(f_branch) 189 | 190 | if (isinstance(tb_t, tuple) 191 | and isinstance(fb_t, tuple) 192 | and len(tb_t) == len(fb_t)): 193 | try: 194 | return tuple(_recurse(t, f) for t, f in zip(t_branch, f_branch)) 195 | except (TypeError, InconsistentSizeError): 196 | raise TypeError(f'Branches have inconsistent types: ' 197 | f'{tb_t} and {fb_t}') 198 | elif (isinstance(tb_t, tuple) 199 | or 
isinstance(fb_t, tuple)): 200 | raise TypeError(f'Branches have inconsistent types: {tb_t} and {fb_t}') 201 | elif issubclass(tb_t, AbstractBit) and issubclass(fb_t, AbstractBit): 202 | if tb_t is fb_t: 203 | return tb_t 204 | return PolyBit[tb_t, fb_t, select] 205 | elif issubclass(tb_t, AbstractBitVector) and issubclass(fb_t, AbstractBitVector): 206 | if tb_t is fb_t: 207 | return tb_t 208 | return PolyVector[tb_t, fb_t, select] 209 | elif tb_t is fb_t and issubclass(tb_t, BitVectorProtocol): 210 | def from_bv_args(val): 211 | return tb_t._from_bitvector_(tb_t._bitvector_t_()(val)) 212 | return from_bv_args 213 | else: 214 | raise TypeError(f'tb_t: {tb_t}, fb_t: {fb_t}') 215 | 216 | return _recurse(t_branch, f_branch) 217 | 218 | def coerce_branch(r_type, branch): 219 | if isinstance(r_type, tuple): 220 | assert isinstance(branch, tuple) 221 | assert len(r_type) == len(branch) 222 | return tuple(coerce_branch(t, arg) for t, arg in zip(r_type, branch)) 223 | else: 224 | return r_type(branch) 225 | 226 | def push_ite(ite, select, t_branch, f_branch): 227 | def _recurse(t_branch, f_branch): 228 | if isinstance(t_branch, tuple): 229 | assert isinstance(f_branch, tuple) 230 | assert len(t_branch) == len(f_branch) 231 | return tuple(_recurse(t, f) for t, f in zip(t_branch, f_branch)) 232 | elif isinstance(t_branch, BitVectorProtocol): 233 | assert type(t_branch) is type(f_branch) 234 | return _recurse(t_branch._to_bitvector_(), f_branch._to_bitvector_()) 235 | else: 236 | return ite(select, t_branch, f_branch) 237 | return _recurse(t_branch, f_branch) 238 | 239 | def build_ite(ite, select, t_branch, f_branch): 240 | r_type = determine_return_type(select, t_branch, f_branch) 241 | r_val = push_ite(ite, select, t_branch, f_branch) 242 | r_val = coerce_branch(r_type, r_val) 243 | return r_val 244 | -------------------------------------------------------------------------------- /hwtypes/adt.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from types import MappingProxyType 3 | import typing as tp 4 | import warnings 5 | 6 | from .adt_meta import TupleMeta, AnonymousProductMeta, ProductMeta 7 | from .adt_meta import SumMeta, TaggedUnionMeta, EnumMeta, is_adt_type 8 | 9 | __all__ = ['Tuple', 'Product', 'Sum', 'Enum', 'TaggedUnion'] 10 | __all__ += ['new_instruction', 'is_adt_type'] 11 | 12 | #special sentinal value 13 | class _MISSING: pass 14 | 15 | class Tuple(metaclass=TupleMeta): 16 | def __new__(cls, *value): 17 | if cls.is_bound: 18 | return super().__new__(cls) 19 | else: 20 | idx = tuple(type(v) for v in value) 21 | return cls[idx].__new__(cls[idx], *value) 22 | 23 | def __init__(self, *value): 24 | cls = type(self) 25 | if len(value) != len(cls.fields): 26 | raise ValueError('Incorrect number of arguments') 27 | for v,t in zip(value, cls.fields): 28 | if not isinstance(v,t): 29 | raise TypeError('Value {} is not of type {}'.format(repr(v), repr(t))) 30 | 31 | self._value_ = value 32 | 33 | def __eq__(self, other): 34 | if isinstance(other, type(self)): 35 | return self._value_ == other._value_ 36 | else: 37 | return NotImplemented 38 | 39 | def __ne__(self, other): 40 | return not (self == other) 41 | 42 | def __hash__(self): 43 | return hash(self._value_) 44 | 45 | def __getitem__(self, idx): 46 | return self._value_[idx] 47 | 48 | def __setitem__(self, idx, value): 49 | if isinstance(value, type(self)[idx]): 50 | v = list(self._value_) 51 | v[idx] = value 52 | self._value_ = v 53 | else: 54 | raise TypeError(f'Value {value} is not of type {type(self)[idx]}') 55 | 56 | def __repr__(self): 57 | return f'{type(self).__name__}({", ".join(map(repr, self._value_))})' 58 | 59 | @property 60 | def value_dict(self): 61 | d = {} 62 | for k in type(self).field_dict: 63 | d[k] = self[k] 64 | return MappingProxyType(d) 65 | 66 | @classmethod 67 | def from_values(cls, value_dict): 
68 | if value_dict.keys() != cls.field_dict.keys(): 69 | raise ValueError('Keys do not match field_dict') 70 | 71 | for k, v in value_dict.items(): 72 | if not isinstance(v, cls.field_dict[k]): 73 | raise TypeError(f'Expected object of type {cls.field_dict[k]}' 74 | f' got {v}') 75 | 76 | 77 | # So Product can use the type checking logic above 78 | return cls._from_kwargs(value_dict) 79 | 80 | 81 | @classmethod 82 | def _from_kwargs(cls, kwargs): 83 | args = [None]*len(kwargs) 84 | for k, v in kwargs.items(): 85 | args[k] = v 86 | 87 | return cls(*args) 88 | 89 | @property 90 | def value(self): 91 | warnings.warn('DEPRECATION WARNING: ADT.value is deprecated', DeprecationWarning, 2) 92 | return self._value_ 93 | 94 | class AnonymousProduct(Tuple, metaclass=AnonymousProductMeta): 95 | def __repr__(self): 96 | return f'{type(self).__name__}({", ".join(f"{k}={v}" for k,v in self.value_dict.items())})' 97 | 98 | @property 99 | def value_dict(self): 100 | d = OrderedDict() 101 | for k in type(self).field_dict: 102 | d[k] = getattr(self, k) 103 | return MappingProxyType(d) 104 | 105 | @classmethod 106 | def _from_kwargs(cls, kwargs): 107 | return cls(**kwargs) 108 | 109 | 110 | class Product(AnonymousProduct, metaclass=ProductMeta): 111 | pass 112 | 113 | class Sum(metaclass=SumMeta): 114 | class Match: 115 | __slots__ = ('_match', '_value', '_safe') 116 | def __init__(self, match, value, *, safe: bool = True): 117 | self._match = match 118 | self._value = value 119 | self._safe = safe 120 | 121 | @property 122 | def match(self): 123 | return self._match 124 | 125 | @property 126 | def value(self): 127 | if not self._safe or self.match: 128 | return self._value 129 | else: 130 | raise TypeError(f'No value for unmatched type') 131 | 132 | def __init__(self, value): 133 | if type(value) not in type(self): 134 | raise TypeError(f'Value {value} is not of types {type(self).fields}') 135 | self._value_ = value 136 | 137 | def __eq__(self, other): 138 | if isinstance(other, 
type(self)): 139 | return self._value_ == other._value_ 140 | else: 141 | return NotImplemented 142 | 143 | def __ne__(self, other): 144 | return not (self == other) 145 | 146 | def __hash__(self): 147 | return hash(self._value_) 148 | 149 | def __getitem__(self, T) -> 'Sum.Match': 150 | cls = type(self) 151 | if T in cls: 152 | return cls.Match(isinstance(self._value_, T), self._value_) 153 | else: 154 | raise TypeError(f'{T} not in {cls}') 155 | 156 | 157 | def __setitem__(self, T, value): 158 | if T not in type(self): 159 | raise TypeError(f'indices must be in {type(self).fields} not {T}') 160 | elif not isinstance(value, T): 161 | raise TypeError(f'expected {T} not {type(value)}') 162 | else: 163 | self._value_ = value 164 | 165 | def __repr__(self) -> str: 166 | return f'{type(self)}({self._value_})' 167 | 168 | @property 169 | def value_dict(self): 170 | d = {} 171 | for k,t in type(self).field_dict.items(): 172 | if self[t].match: 173 | d[k] = self._value_ 174 | else: 175 | d[k] = None 176 | return MappingProxyType(d) 177 | 178 | @classmethod 179 | def from_values(cls, value_dict): 180 | if value_dict.keys() != cls.field_dict.keys(): 181 | raise ValueError('Keys do not match field_dict') 182 | 183 | kwargs = {} 184 | for k, v in value_dict.items(): 185 | if v is not None and not isinstance(v, cls.field_dict[k]): 186 | raise TypeError(f'Expected object of type {cls.field_dict[k]}' 187 | f' got {v}') 188 | elif v is not None: 189 | kwargs[k] = v 190 | 191 | if not len(kwargs) == 1: 192 | raise ValueError(f'value_dict must have exactly one non None entry') 193 | 194 | # So TaggedUnion can use the type checking logic above 195 | return cls._from_kwargs(kwargs) 196 | 197 | @classmethod 198 | def _from_kwargs(cls, kwargs): 199 | assert len(kwargs) == 1 200 | _, v = kwargs.popitem() 201 | return cls(v) 202 | 203 | @property 204 | def value(self): 205 | warnings.warn('DEPRECATION WARNING: ADT.value is deprecated', DeprecationWarning, 2) 206 | return self._value_ 207 
| 208 | 209 | class TaggedUnion(Sum, metaclass=TaggedUnionMeta): 210 | def __init__(self, **kwargs): 211 | if len(kwargs) == 0: 212 | raise ValueError('Must specify a value') 213 | elif len(kwargs) != 1: 214 | raise ValueError('Expected one value') 215 | 216 | cls = type(self) 217 | field, value = next(iter(kwargs.items())) 218 | if field not in cls.field_dict: 219 | raise ValueError(f'Invalid field {field}') 220 | 221 | setattr(self, field, value) 222 | 223 | def __setitem__(self, T, value): 224 | raise TypeError(f'setitem syntax is not supported on {type(self)}') 225 | 226 | def __eq__(self, other): 227 | if isinstance(other, type(self)): 228 | return self._tag_ == other._tag_ and self._value_ == other._value_ 229 | else: 230 | return NotImplemented 231 | 232 | def __hash__(self): 233 | return hash(self._tag_) + hash(self._value_) 234 | 235 | def __repr__(self): 236 | return f'{type(self).__name__}({", ".join(f"{k}={v}" for k,v in self.value_dict.items())})' 237 | 238 | @property 239 | def value_dict(self): 240 | d = dict() 241 | for k in type(self).field_dict: 242 | m = getattr(self, k) 243 | if m.match: 244 | d[k] = m.value 245 | else: 246 | d[k] = None 247 | return MappingProxyType(d) 248 | 249 | @classmethod 250 | def _from_kwargs(cls, kwargs): 251 | assert len(kwargs) == 1 252 | return cls(**kwargs) 253 | 254 | 255 | class Enum(metaclass=EnumMeta): 256 | def __init__(self, value): 257 | self._value_ = value 258 | 259 | @property 260 | def name(self): 261 | return self._name_ 262 | 263 | def __repr__(self): 264 | return f'{type(self).__name__}.{self.name}' 265 | 266 | def __eq__(self, other): 267 | if isinstance(other, type(self)): 268 | return self._value_ == other._value_ 269 | else: 270 | return NotImplemented 271 | 272 | def __ne__(self, other): 273 | return not (self == other) 274 | 275 | def __hash__(self): 276 | return hash(self._value_) 277 | 278 | def __getattribute__(self, attr): 279 | # prevent: 280 | # class E(Enum): 281 | # a = 0 282 | # b = 1 
283 | # E.a.b == E.b 284 | if attr in type(self).field_dict: 285 | raise AttributeError('Cannot access enum members from enum instances') 286 | else: 287 | return super().__getattribute__(attr) 288 | 289 | @property 290 | def value(self): 291 | warnings.warn('DEPRECATION WARNING: ADT.value is deprecated', DeprecationWarning, 3) 292 | return self._value_ 293 | 294 | def new_instruction(): 295 | return EnumMeta.Auto() 296 | -------------------------------------------------------------------------------- /hwtypes/bit_vector_abc.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from collections import namedtuple 3 | import typing as tp 4 | import functools as ft 5 | import weakref 6 | import warnings 7 | 8 | from .util import _issubclass 9 | 10 | TypeFamily = namedtuple('TypeFamily', ['Bit', 'BitVector', 'Unsigned', 'Signed']) 11 | 12 | # Should be raised when bv[k].op(bv[j]) and j != k 13 | 14 | class InconsistentSizeError(TypeError): pass 15 | 16 | #I want to be able differentiate an old style call 17 | #BitVector(val, None) from BitVector(val) 18 | _MISSING = object() 19 | class AbstractBitVectorMeta(type): #:(ABCMeta): 20 | # BitVectorType, size : BitVectorType[size] 21 | _class_cache = weakref.WeakValueDictionary() 22 | 23 | def __call__(cls, value=_MISSING, *args, **kwargs): 24 | if cls.is_sized: 25 | if value is _MISSING: 26 | return super().__call__(*args, **kwargs) 27 | else: 28 | return super().__call__(value, *args, **kwargs) 29 | else: 30 | warnings.warn('DEPRECATION WARNING: Use of implicitly sized ' 31 | 'BitVectors is deprecated', DeprecationWarning) 32 | 33 | if value is _MISSING: 34 | raise TypeError('Cannot construct {} without a value'.format(cls, value)) 35 | elif isinstance(value, AbstractBitVector): 36 | size = value.size 37 | elif isinstance(value, AbstractBit): 38 | size = 1 39 | elif isinstance(value, tp.Sequence): 40 | size = max(len(value), 1) 41 | elif 
isinstance(value, int): 42 | size = max(value.bit_length(), 1) 43 | elif hasattr(value, '__int__'): 44 | size = max(int(value).bit_length(), 1) 45 | else: 46 | raise TypeError('Cannot construct {} from {}'.format(cls, value)) 47 | 48 | return type(cls).__call__(cls[size], value, *args, **kwargs) 49 | 50 | 51 | def __new__(mcs, name, bases, namespace, info=(None, None), **kwargs): 52 | if '_info_' in namespace: 53 | raise TypeError('class attribute _info_ is reversed by the type machinery') 54 | 55 | size = info[1] 56 | for base in bases: 57 | if getattr(base, 'is_sized', False): 58 | if size is None: 59 | size = base.size 60 | elif size != base.size: 61 | raise TypeError("Can't inherit from multiple different sizes") 62 | 63 | namespace['_info_'] = info[0], size 64 | t = super().__new__(mcs, name, bases, namespace, **kwargs) 65 | if size is None: 66 | #class is unsized so t.unsized_t -> t 67 | t._info_ = t, size 68 | elif info[0] is None: 69 | #class inherited from sized types so there is no unsized_t 70 | t._info_ = None, size 71 | 72 | return t 73 | 74 | 75 | def __getitem__(cls, idx : int) -> 'AbstractBitVectorMeta': 76 | mcs = type(cls) 77 | try: 78 | return mcs._class_cache[cls, idx] 79 | except KeyError: 80 | pass 81 | 82 | if not isinstance(idx, int): 83 | raise TypeError('Size of BitVectors must be of type int not {}'.format(type(idx))) 84 | if idx < 0: 85 | raise ValueError('Size of BitVectors must be positive') 86 | 87 | if cls.is_sized: 88 | raise TypeError('{} is already sized'.format(cls)) 89 | 90 | bases = [cls] 91 | bases.extend(b[idx] for b in cls.__bases__ if isinstance(b, mcs)) 92 | bases = tuple(bases) 93 | class_name = '{}[{}]'.format(cls.__name__, idx) 94 | t = mcs(class_name, bases, {}, info=(cls,idx)) 95 | t.__module__ = cls.__module__ 96 | mcs._class_cache[cls, idx] = t 97 | return t 98 | 99 | @property 100 | def unsized_t(cls) -> 'AbstractBitVectorMeta': 101 | t = cls._info_[0] 102 | if t is not None: 103 | return t 104 | else: 105 | raise 
AttributeError('type {} has no unsized_t'.format(cls)) 106 | 107 | @property 108 | def size(cls) -> int: 109 | return cls._info_[1] 110 | 111 | @property 112 | def is_sized(cls) -> bool: 113 | return cls.size is not None 114 | 115 | def __len__(cls): 116 | if cls.is_sized: 117 | return cls.size 118 | else: 119 | raise AttributeError('unsized type has no len') 120 | 121 | def __repr__(cls): 122 | return cls.__name__ 123 | 124 | 125 | class AbstractBit(metaclass=ABCMeta): 126 | @staticmethod 127 | def get_family() -> TypeFamily: 128 | return _Family_ 129 | 130 | @abstractmethod 131 | def __eq__(self, other) -> 'AbstractBit': 132 | pass 133 | 134 | def __ne__(self, other) -> 'AbstractBit': 135 | return ~(self == other) 136 | 137 | @abstractmethod 138 | def __invert__(self) -> 'AbstractBit': 139 | pass 140 | 141 | @abstractmethod 142 | def __and__(self, other : 'AbstractBit') -> 'AbstractBit': 143 | pass 144 | 145 | @abstractmethod 146 | def __or__(self, other : 'AbstractBit') -> 'AbstractBit': 147 | pass 148 | 149 | @abstractmethod 150 | def __xor__(self, other : 'AbstractBit') -> 'AbstractBit': 151 | pass 152 | 153 | @abstractmethod 154 | def ite(self, t_branch, f_branch): 155 | pass 156 | 157 | class AbstractBitVector(metaclass=AbstractBitVectorMeta): 158 | @staticmethod 159 | def get_family() -> TypeFamily: 160 | return _Family_ 161 | 162 | @property 163 | def size(self) -> int: 164 | return type(self).size 165 | 166 | @classmethod 167 | @abstractmethod 168 | def make_constant(self, value, num_bits:tp.Optional[int]=None) -> 'AbstractBitVector': 169 | pass 170 | 171 | @abstractmethod 172 | def __getitem__(self, index) -> AbstractBit: 173 | pass 174 | 175 | @abstractmethod 176 | def __setitem__(self, index : int, value : AbstractBit): 177 | pass 178 | 179 | @abstractmethod 180 | def __len__(self) -> int: 181 | pass 182 | 183 | @abstractmethod 184 | def concat(self, other) -> 'AbstractBitVector': 185 | pass 186 | 187 | @abstractmethod 188 | def bvnot(self) -> 
'AbstractBitVector': 189 | pass 190 | 191 | @abstractmethod 192 | def bvand(self, other) -> 'AbstractBitVector': 193 | pass 194 | 195 | def bvnand(self, other) -> 'AbstractBitVector': 196 | return self.bvand(other).bvnot() 197 | 198 | @abstractmethod 199 | def bvor(self, other) -> 'AbstractBitVector': 200 | pass 201 | 202 | def bvnor(self, other) -> 'AbstractBitVector': 203 | return self.bvor(other).bvnot() 204 | 205 | @abstractmethod 206 | def bvxor(self, other) -> 'AbstractBitVector': 207 | pass 208 | 209 | def bvxnor(self, other) -> 'AbstractBitVector': 210 | return self.bvxor(other).bvnot() 211 | 212 | @abstractmethod 213 | def bvshl(self, other) -> 'AbstractBitVector': 214 | pass 215 | 216 | @abstractmethod 217 | def bvlshr(self, other) -> 'AbstractBitVector': 218 | pass 219 | 220 | @abstractmethod 221 | def bvashr(self, other) -> 'AbstractBitVector': 222 | pass 223 | 224 | @abstractmethod 225 | def bvrol(self, other) -> 'AbstractBitVector': 226 | pass 227 | 228 | @abstractmethod 229 | def bvror(self, other) -> 'AbstractBitVector': 230 | pass 231 | 232 | @abstractmethod 233 | def bvcomp(self, other) -> 'AbstractBitVector[1]': 234 | pass 235 | 236 | @abstractmethod 237 | def bveq(self, other) -> AbstractBit: 238 | pass 239 | 240 | def bvne(self, other) -> AbstractBit: 241 | return ~self.bveq(other) 242 | 243 | @abstractmethod 244 | def bvult(self, other) -> AbstractBit: 245 | pass 246 | 247 | def bvule(self, other) -> AbstractBit: 248 | return self.bvult(other) | self.bveq(other) 249 | 250 | def bvugt(self, other) -> AbstractBit: 251 | return ~self.bvule(other) 252 | 253 | def bvuge(self, other) -> AbstractBit: 254 | return ~self.bvult(other) 255 | 256 | @abstractmethod 257 | def bvslt(self, other) -> AbstractBit: 258 | pass 259 | 260 | def bvsle(self, other) -> AbstractBit: 261 | return self.bvslt(other) | self.bveq(other) 262 | 263 | def bvsgt(self, other) -> AbstractBit: 264 | return ~self.bvsle(other) 265 | 266 | def bvsge(self, other) -> AbstractBit: 267 | 
return ~self.bvslt(other) 268 | 269 | @abstractmethod 270 | def bvneg(self) -> 'AbstractBitVector': 271 | pass 272 | 273 | @abstractmethod 274 | def adc(self, other, carry) -> tp.Tuple['AbstractBitVector', AbstractBit]: 275 | pass 276 | 277 | @abstractmethod 278 | def ite(i,t,e) -> 'AbstractBitVector': 279 | pass 280 | 281 | @abstractmethod 282 | def bvadd(self, other) -> 'AbstractBitVector': 283 | pass 284 | 285 | @abstractmethod 286 | def bvsub(self, other) -> 'AbstractBitVector': 287 | pass 288 | 289 | @abstractmethod 290 | def bvmul(self, other) -> 'AbstractBitVector': 291 | pass 292 | 293 | @abstractmethod 294 | def bvudiv(self, other) -> 'AbstractBitVector': 295 | pass 296 | 297 | @abstractmethod 298 | def bvurem(self, other) -> 'AbstractBitVector': 299 | pass 300 | 301 | @abstractmethod 302 | def bvsdiv(self, other) -> 'AbstractBitVector': 303 | pass 304 | 305 | @abstractmethod 306 | def bvsrem(self, other) -> 'AbstractBitVector': 307 | pass 308 | 309 | @abstractmethod 310 | def repeat(self, other) -> 'AbstractBitVector': 311 | pass 312 | 313 | @abstractmethod 314 | def sext(self, other) -> 'AbstractBitVector': 315 | pass 316 | 317 | @abstractmethod 318 | def ext(self, other) -> 'AbstractBitVector': 319 | pass 320 | 321 | @abstractmethod 322 | def zext(self, other) -> 'AbstractBitVector': 323 | pass 324 | 325 | BitVectorMeta = AbstractBitVectorMeta 326 | 327 | _Family_ = TypeFamily(AbstractBit, AbstractBitVector, None, None) 328 | -------------------------------------------------------------------------------- /tests/test_fp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import operator 3 | 4 | import ctypes 5 | import random 6 | from hwtypes import FPVector, RoundingMode,BitVector 7 | import math 8 | import gmpy2 9 | 10 | #A sort of reference vector to test against 11 | #Wraps either ctypes.c_float or ctypes.c_double 12 | def _c_type_vector(T): 13 | class vector: 14 | def __init__(self, value): 15 
| self._value = T(value) 16 | 17 | def __repr__(self): 18 | return f'{self.value}' 19 | 20 | @property 21 | def value(self) -> float: 22 | return self._value.value 23 | 24 | def fp_abs(self): 25 | if self.value < 0: 26 | return type(self)(-self.value) 27 | else: 28 | return type(self)(self.value) 29 | 30 | def fp_neg(self): 31 | return type(self)(-self.value) 32 | 33 | def fp_add(self, other): 34 | return type(self)(self.value + other.value) 35 | 36 | def fp_sub(self, other): 37 | return type(self)(self.value - other.value) 38 | 39 | def fp_mul(self, other): 40 | return type(self)(self.value * other.value) 41 | 42 | def fp_div(self, other): 43 | return type(self)(self.value / other.value) 44 | 45 | def fp_fma(self, coef, offset): 46 | raise NotImplementedError() 47 | 48 | def fp_sqrt(self): 49 | return type(self)(math.sqrt(self.value)) 50 | 51 | def fp_rem(self, other): 52 | return type(self)(math.remainder(self.value)) 53 | 54 | def fp_round_to_integral(self): 55 | return type(self)(math.round(self.value)) 56 | 57 | def fp_min(self, other): 58 | return type(self)(min(self.value, other.value)) 59 | 60 | def fp_max(self, other): 61 | return type(self)(max(self.value, other.value)) 62 | 63 | def fp_leq(self, other): 64 | return self.value <= other.value 65 | 66 | def fp_lt(self, other): 67 | return self.value < other.value 68 | 69 | def fp_geq(self, other): 70 | return self.value >= other.value 71 | 72 | def fp_gt(self, other): 73 | return self.value > other.value 74 | 75 | def fp_eq(self, other): 76 | return self.value == other.value 77 | 78 | def fp_is_normal(self): 79 | raise NotImplementedError() 80 | 81 | def fp_is_subnormal(self): 82 | raise NotImplementedError() 83 | 84 | def fp_is_zero(self): 85 | return self.value == 0.0 86 | 87 | def fp_is_infinite(self): 88 | return math.isinf(self.value) 89 | 90 | def fp_is_NaN(self): 91 | return math.isnan(self.value) 92 | 93 | def fp_is_negative(self): 94 | return self.value < 0.0 95 | 96 | def fp_is_positive(self): 97 
| return self.value > 0.0 98 | 99 | def to_ubv(self, size : int): 100 | raise NotImplementedError() 101 | 102 | def to_sbv(self, size : int): 103 | raise NotImplementedError() 104 | 105 | def reinterpret_as_bv(self): 106 | raise NotImplementedError() 107 | 108 | def __neg__(self): return self.fp_neg() 109 | def __abs__(self): return self.fp_abs() 110 | def __add__(self, other): return self.fp_add(other) 111 | def __sub__(self, other): return self.fp_sub(other) 112 | def __mul__(self, other): return self.fp_mul(other) 113 | def __truediv__(self, other): return self.fp_div(other) 114 | def __mod__(self, other): return self.fp_rem(other) 115 | 116 | def __eq__(self, other): return self.fp_eq(other) 117 | def __ne__(self, other): return ~(self.fp_eq(other)) 118 | def __ge__(self, other): return self.fp_geq(other) 119 | def __gt__(self, other): return self.fp_gt(other) 120 | def __le__(self, other): return self.fp_leq(other) 121 | def __lt__(self, other): return self.fp_lt(other) 122 | 123 | def __float__(self): 124 | return self.value 125 | 126 | return vector 127 | 128 | NTESTS = 128 129 | 130 | @pytest.mark.parametrize("mode", [ 131 | RoundingMode.RNE, 132 | RoundingMode.RNA, 133 | RoundingMode.RTP, 134 | RoundingMode.RTN, 135 | RoundingMode.RTZ, 136 | ]) 137 | @pytest.mark.parametrize("ieee", [False, True]) 138 | def test_init(mode, ieee): 139 | BigFloat = FPVector[27,100, mode, ieee] 140 | class F: 141 | def __float__(self): 142 | return 0.5 143 | 144 | class I: 145 | def __int__(self): 146 | return 1 147 | 148 | assert BigFloat(0.5) == BigFloat(F()) 149 | assert BigFloat(1) == BigFloat(I()) 150 | assert BigFloat(0.5) == BigFloat('0.5') 151 | assert BigFloat('1/3') == BigFloat(1)/BigFloat(3) 152 | assert BigFloat('1/3') != BigFloat(1/3) # as 1/3 is performed in python floats 153 | 154 | @pytest.mark.parametrize("FT", [ 155 | FPVector[8, 7, RoundingMode.RNE, True], 156 | FPVector[8, 7, RoundingMode.RNE, False], 157 | FPVector[11, 52, RoundingMode.RNE, True], 158 | 
FPVector[11, 52, RoundingMode.RNE, False], 159 | ]) 160 | @pytest.mark.parametrize("allow_inf", [False, True]) 161 | def test_random(FT, allow_inf): 162 | for _ in range(NTESTS): 163 | r = FT.random(allow_inf) 164 | assert allow_inf or not r.fp_is_infinite() 165 | assert not r.fp_is_NaN() 166 | 167 | @pytest.mark.parametrize("FT", [ 168 | FPVector[8, 7, RoundingMode.RNE, True], 169 | FPVector[8, 7, RoundingMode.RNE, False], 170 | FPVector[11, 52, RoundingMode.RNE, True], 171 | FPVector[11, 52, RoundingMode.RNE, False], 172 | ]) 173 | def test_reinterpret(FT): 174 | #basic sanity 175 | for x in ('-2.0', '-1.75', '-1.5', '-1.25', 176 | '-1.0', '-0.75', '-0.5', '-0.25', 177 | '0.0', '0.25', '0.5', '0.75', 178 | '1.0', '1.25', '1.5', '1.75',): 179 | f1 = FT(x) 180 | bv = f1.reinterpret_as_bv() 181 | f2 = FT.reinterpret_from_bv(bv) 182 | assert f1 == f2 183 | 184 | #epsilon 185 | f1 = FT(1) 186 | while f1/2 != 0: 187 | bv = f1.reinterpret_as_bv() 188 | f2 = FT.reinterpret_from_bv(bv) 189 | assert f1 == f2 190 | f1 = f1/2 191 | 192 | #using FPVector.random 193 | for _ in range(NTESTS): 194 | f1 = FT.random() 195 | bv = f1.reinterpret_as_bv() 196 | f2 = FT.reinterpret_from_bv(bv) 197 | assert f1 == f2 198 | 199 | @pytest.mark.parametrize("FT", [ 200 | FPVector[8, 23, RoundingMode.RNE, True], 201 | FPVector[8, 23, RoundingMode.RNE, False], 202 | FPVector[11, 52, RoundingMode.RNE, True], 203 | FPVector[11, 52, RoundingMode.RNE, False], 204 | ]) 205 | @pytest.mark.parametrize("mean, variance", [ 206 | (0, 2**-64), 207 | (0, 2**-16), 208 | (0, 2**-4), 209 | (0, 1), 210 | (0, 2**4), 211 | (0, 2**16), 212 | (0, 2**64), 213 | ]) 214 | def test_reinterpret_pyrandom(FT, mean, variance): 215 | for _ in range(NTESTS): 216 | x = random.normalvariate(mean, variance) 217 | f1 = FT(x) 218 | bv = f1.reinterpret_as_bv() 219 | f2 = FT.reinterpret_from_bv(bv) 220 | assert f1 == f2 221 | 222 | @pytest.mark.parametrize("FT", [ 223 | FPVector[8, 7, RoundingMode.RNE, True], 224 | FPVector[8, 7, 
RoundingMode.RNE, False], 225 | FPVector[11, 52, RoundingMode.RNE, True], 226 | FPVector[11, 52, RoundingMode.RNE, False], 227 | ]) 228 | def test_reinterpret_bv(FT): 229 | ms = FT.mantissa_size 230 | for _ in range(NTESTS): 231 | bv1 = BitVector.random(FT.size) 232 | #dont generate denorms or NaN unless ieee compliant 233 | while (not FT.ieee_compliance 234 | and (bv1[ms:-1] == 0 or bv1[ms:-1] == -1) 235 | and bv1[:ms] != 0): 236 | bv1 = BitVector.random(FT.size) 237 | 238 | f = FT.reinterpret_from_bv(bv1) 239 | bv2 = f.reinterpret_as_bv() 240 | if not f.fp_is_NaN(): 241 | assert bv1 == bv2 242 | else: 243 | #exponents should be -1 244 | assert bv1[ms:-1] == BitVector[FT.exponent_size](-1) 245 | assert bv2[ms:-1] == BitVector[FT.exponent_size](-1) 246 | #mantissa should be non 0 247 | assert bv1[:ms] != 0 248 | assert bv2[:ms] != 0 249 | 250 | def test_reinterpret_bv_corner(): 251 | for _ in range(NTESTS): 252 | FT = FPVector[random.randint(3, 16), 253 | random.randint(2, 64), 254 | random.choice(list(RoundingMode)), 255 | True] 256 | bv_pinf = BitVector[FT.mantissa_size](0).concat(BitVector[FT.exponent_size](-1)).concat(BitVector[1](0)) 257 | bv_ninf = BitVector[FT.mantissa_size](0).concat(BitVector[FT.exponent_size](-1)).concat(BitVector[1](1)) 258 | pinf = FT.reinterpret_from_bv(bv_pinf) 259 | ninf = FT.reinterpret_from_bv(bv_ninf) 260 | assert pinf.reinterpret_as_bv() == bv_pinf 261 | assert ninf.reinterpret_as_bv() == bv_ninf 262 | assert pinf.fp_is_positive() 263 | assert pinf.fp_is_infinite() 264 | assert ninf.fp_is_negative() 265 | assert ninf.fp_is_infinite() 266 | 267 | bv_pz = BitVector[FT.size](0) 268 | bv_nz = BitVector[FT.size-1](0).concat(BitVector[1](1)) 269 | pz = FT.reinterpret_from_bv(bv_pz) 270 | nz = FT.reinterpret_from_bv(bv_nz) 271 | assert pz.reinterpret_as_bv() == bv_pz 272 | assert nz.reinterpret_as_bv() == bv_nz 273 | assert pz.fp_is_zero() 274 | assert nz.fp_is_zero() 275 | 276 | bv_nan = 
BitVector[FT.mantissa_size](1).concat(BitVector[FT.exponent_size](-1)).concat(BitVector[1](0)) 277 | nan = FT.reinterpret_from_bv(bv_nan) 278 | assert nan.reinterpret_as_bv() == bv_nan 279 | assert nan.fp_is_NaN() 280 | 281 | @pytest.mark.parametrize("CT, FT", [ 282 | (_c_type_vector(ctypes.c_float), FPVector[8, 23, RoundingMode.RNE, True]), 283 | (_c_type_vector(ctypes.c_double), FPVector[11, 52, RoundingMode.RNE, True]),]) 284 | def test_epsilon(CT, FT): 285 | cx, fx = CT(1), FT(1) 286 | c2, f2 = CT(2), FT(2) 287 | while not ((cx/c2).fp_is_zero() or (fx/f2).fp_is_zero()): 288 | assert float(cx) == float(fx) 289 | cx = cx/c2 290 | fx = fx/f2 291 | 292 | assert not cx.fp_is_zero() 293 | assert not fx.fp_is_zero() 294 | assert (cx/c2).fp_is_zero() 295 | assert (fx/f2).fp_is_zero() 296 | 297 | @pytest.mark.parametrize("op", [operator.neg, operator.abs, lambda x : abs(x).fp_sqrt()]) 298 | @pytest.mark.parametrize("CT, FT", [ 299 | (_c_type_vector(ctypes.c_float), FPVector[8, 23, RoundingMode.RNE, True]), 300 | (_c_type_vector(ctypes.c_double), FPVector[11, 52, RoundingMode.RNE, True]),]) 301 | @pytest.mark.parametrize("mean, variance", [ 302 | (0, 2**-64), 303 | (0, 2**-16), 304 | (0, 2**-4), 305 | (0, 1), 306 | (0, 2**4), 307 | (0, 2**16), 308 | (0, 2**64), 309 | ]) 310 | def test_unary_op(op, CT, FT, mean, variance): 311 | for _ in range(NTESTS): 312 | x = random.normalvariate(mean, variance) 313 | cx = CT(x) 314 | fx = FT(x) 315 | assert float(cx) == float(fx) 316 | cr, fr = op(cx), op(fx) 317 | assert float(cr) == float(fr) 318 | 319 | @pytest.mark.parametrize("op", [operator.add, operator.sub, operator.mul, operator.truediv]) 320 | @pytest.mark.parametrize("CT, FT", [ 321 | (_c_type_vector(ctypes.c_float), FPVector[8, 23, RoundingMode.RNE, True]), 322 | (_c_type_vector(ctypes.c_double), FPVector[11, 52, RoundingMode.RNE, True]),]) 323 | @pytest.mark.parametrize("mean, variance", [ 324 | (0, 2**-64), 325 | (0, 2**-16), 326 | (0, 2**-4), 327 | (0, 1), 328 | (0, 
2**4), 329 | (0, 2**16), 330 | (0, 2**64), 331 | ]) 332 | def test_bin_op(op, CT, FT, mean, variance): 333 | for _ in range(NTESTS): 334 | x = random.normalvariate(mean, variance) 335 | y = random.normalvariate(mean, variance) 336 | cx, cy = CT(x), CT(y) 337 | fx, fy = FT(x), FT(y) 338 | cr, fr = op(cx, cy), op(fx, fy) 339 | assert float(cr) == float(fr) 340 | 341 | @pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.lt, operator.le, operator.gt, operator.ge]) 342 | @pytest.mark.parametrize("CT, FT", [ 343 | (_c_type_vector(ctypes.c_float), FPVector[8, 23, RoundingMode.RNE, True]), 344 | (_c_type_vector(ctypes.c_double), FPVector[11, 52, RoundingMode.RNE, True]),]) 345 | @pytest.mark.parametrize("mean, variance", [ 346 | (0, 2**-64), 347 | (0, 2**-16), 348 | (0, 2**-4), 349 | (0, 1), 350 | (0, 2**4), 351 | (0, 2**16), 352 | (0, 2**64), 353 | ]) 354 | def test_bool_op(op, CT, FT, mean, variance): 355 | for _ in range(NTESTS): 356 | x = random.normalvariate(mean, variance) 357 | y = random.normalvariate(mean, variance) 358 | cx, cy = CT(x), CT(y) 359 | fx, fy = FT(x), FT(y) 360 | cr, fr = op(cx, cy), op(fx, fy) 361 | assert bool(cr) == bool(fr) 362 | -------------------------------------------------------------------------------- /tests/test_adt.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from hwtypes.adt import Product, Sum, Enum, Tuple, TaggedUnion, AnonymousProduct 3 | from hwtypes.adt_meta import RESERVED_ATTRS, ReservedNameError, AttrSyntax, GetitemSyntax 4 | from hwtypes.modifiers import new 5 | from hwtypes.adt_util import rebind_bitvector 6 | from hwtypes import BitVector, AbstractBitVector, Bit, AbstractBit 7 | 8 | 9 | class En1(Enum): 10 | a = 0 11 | b = 1 12 | 13 | 14 | class En2(Enum): 15 | c = 0 16 | d = 1 17 | 18 | Tu = Tuple[En1, En2] 19 | 20 | Ap = AnonymousProduct[{'x': En1, 'y': En2}] 21 | 22 | class Pr(Product): 23 | x = En1 24 | y = En2 25 | 26 | 27 | class Pr2(Product, 
cache=False): 28 | x = En1 29 | y = En2 30 | 31 | 32 | class Pr3(Product, cache=True): 33 | y = En2 34 | x = En1 35 | 36 | 37 | Su = Sum[En1, Pr] 38 | 39 | class Ta(TaggedUnion): 40 | x = En1 41 | y = En1 42 | z = Pr 43 | 44 | class Ta2(TaggedUnion, cache=False): 45 | x = En1 46 | y = En1 47 | z = Pr 48 | 49 | class Ta3(TaggedUnion, cache=True): 50 | y = En1 51 | x = En1 52 | z = Pr 53 | 54 | def test_enum(): 55 | assert set(En1.enumerate()) == { 56 | En1.a, 57 | En1.b, 58 | } 59 | 60 | assert En1.a._value_ == 0 61 | assert En1.a is En1.a 62 | 63 | assert issubclass(En1, Enum) 64 | assert isinstance(En1.a, Enum) 65 | assert isinstance(En1.a, En1) 66 | assert En1.is_bound 67 | 68 | with pytest.raises(AttributeError): 69 | En1.a.b 70 | 71 | 72 | def test_tuple(): 73 | assert set(Tu.enumerate()) == { 74 | Tu(En1.a, En2.c), 75 | Tu(En1.a, En2.d), 76 | Tu(En1.b, En2.c), 77 | Tu(En1.b, En2.d), 78 | } 79 | 80 | assert Tu(En1.a, En2.c)._value_ == (En1.a, En2.c) 81 | 82 | assert issubclass(Tu, Tuple) 83 | assert isinstance(Tu(En1.a, En2.c), Tuple) 84 | assert isinstance(Tu(En1.a, En2.c), Tu) 85 | assert Tu[0] == En1 86 | assert Tu[1] == En2 87 | 88 | t = Tu(En1.a, En2.c) 89 | assert (t[0],t[1]) == (En1.a,En2.c) 90 | t[0] = En1.b 91 | assert (t[0],t[1]) == (En1.b,En2.c) 92 | 93 | assert Tu.field_dict == {0 : En1, 1 : En2 } 94 | assert t.value_dict == {0 : En1.b, 1: En2.c} 95 | 96 | with pytest.raises(TypeError): 97 | Tu(En1.a, 1) 98 | 99 | with pytest.raises(TypeError): 100 | t[1] = 1 101 | 102 | def test_anonymous_product(): 103 | assert set(Ap.enumerate()) == { 104 | Ap(En1.a, En2.c), 105 | Ap(En1.a, En2.d), 106 | Ap(En1.b, En2.c), 107 | Ap(En1.b, En2.d), 108 | } 109 | 110 | assert Ap(En1.a, En2.c)._value_ == (En1.a, En2.c) 111 | assert issubclass(Ap, AnonymousProduct) 112 | assert isinstance(Ap(En1.a, En2.c), AnonymousProduct) 113 | assert isinstance(Ap(En1.a, En2.c), Ap) 114 | 115 | assert Ap(En1.a, En2.c) == Tu(En1.a, En2.c) 116 | assert issubclass(Ap, Tu) 117 | assert 
issubclass(Ap, Tuple) 118 | assert isinstance(Ap(En1.a, En2.c), Tu) 119 | 120 | assert Ap[0] == Ap.x == En1 121 | assert Ap[1] == Ap.y == En2 122 | 123 | assert Ap.field_dict == {'x' : En1, 'y' : En2 } 124 | 125 | p = Ap(En1.a, En2.c) 126 | with pytest.raises(TypeError): 127 | Ap(En1.a, En1.a) 128 | 129 | assert p[0] == p.x == En1.a 130 | assert p[1] == p.y == En2.c 131 | assert p.value_dict == {'x' : En1.a, 'y' : En2.c} 132 | 133 | p.x = En1.b 134 | assert p[0] == p.x == En1.b 135 | assert p.value_dict == {'x' : En1.b, 'y' : En2.c} 136 | 137 | p[0] = En1.a 138 | assert p[0] == p.x == En1.a 139 | 140 | with pytest.raises(TypeError): 141 | p[0] = En2.c 142 | 143 | 144 | def test_product(): 145 | assert set(Pr.enumerate()) == { 146 | Pr(En1.a, En2.c), 147 | Pr(En1.a, En2.d), 148 | Pr(En1.b, En2.c), 149 | Pr(En1.b, En2.d), 150 | } 151 | 152 | assert Pr(En1.a, En2.c)._value_ == (En1.a, En2.c) 153 | 154 | assert issubclass(Pr, Product) 155 | assert isinstance(Pr(En1.a, En2.c), Product) 156 | assert isinstance(Pr(En1.a, En2.c), Pr) 157 | 158 | assert Pr(En1.a, En2.c) == Ap(En1.a, En2.c) 159 | assert issubclass(Pr, Ap) 160 | assert issubclass(Pr, AnonymousProduct) 161 | assert isinstance(Pr(En1.a, En2.c), Ap) 162 | 163 | assert Pr[0] == Pr.x == En1 164 | assert Pr[1] == Pr.y == En2 165 | 166 | assert Pr.field_dict == {'x' : En1, 'y' : En2 } 167 | 168 | p = Pr(En1.a, En2.c) 169 | with pytest.raises(TypeError): 170 | Pr(En1.a, En1.a) 171 | 172 | assert p[0] == p.x == En1.a 173 | assert p[1] == p.y == En2.c 174 | assert p.value_dict == {'x' : En1.a, 'y' : En2.c} 175 | 176 | p.x = En1.b 177 | assert p[0] == p.x == En1.b 178 | assert p.value_dict == {'x' : En1.b, 'y' : En2.c} 179 | 180 | p[0] = En1.a 181 | assert p[0] == p.x == En1.a 182 | 183 | with pytest.raises(TypeError): 184 | p[0] = En2.c 185 | 186 | 187 | def test_product_from_fields(): 188 | P = Product.from_fields('P', {'A' : int, 'B' : str}) 189 | assert issubclass(P, Product) 190 | assert issubclass(P, Tuple[int, 
str]) 191 | assert P.A == int 192 | assert P.B == str 193 | assert P.__module__ == Product.__module__ 194 | 195 | assert P is Product.from_fields('P', {'A' : int, 'B' : str}) 196 | 197 | P2 = Product.from_fields('P', {'B' : str, 'A' : int}) 198 | assert P2 is not P 199 | 200 | P3 = Product.from_fields('P', {'A' : int, 'B' : str}, cache=False) 201 | assert P3 is not P 202 | assert P3 is not P2 203 | 204 | with pytest.raises(TypeError): 205 | Pr.from_fields('P', {'A' : int, 'B' : str}) 206 | 207 | 208 | def test_product_caching(): 209 | global Pr 210 | assert Pr is not Pr2 211 | assert Pr is not Pr3 212 | assert Pr is Product.from_fields('Pr', {'x' : En1, 'y' : En2 }, cache=True) 213 | assert Pr.field_dict == Pr2.field_dict 214 | assert Pr.field_dict != Pr3.field_dict 215 | Pr_ = Pr 216 | 217 | class Pr(Product): 218 | y = En2 219 | x = En1 220 | 221 | # Order matters 222 | assert Pr_ is not Pr 223 | 224 | class Pr(Product): 225 | x = En1 226 | y = En2 227 | 228 | assert Pr_ is Pr 229 | 230 | 231 | 232 | 233 | def test_sum(): 234 | assert set(Su.enumerate()) == { 235 | Su(En1.a), 236 | Su(En1.b), 237 | Su(Pr(En1.a, En2.c)), 238 | Su(Pr(En1.a, En2.d)), 239 | Su(Pr(En1.b, En2.c)), 240 | Su(Pr(En1.b, En2.d)), 241 | } 242 | 243 | assert Su(En1.a)._value_ == En1.a 244 | 245 | assert En1 in Su 246 | assert Pr in Su 247 | assert En2 not in Su 248 | 249 | assert issubclass(Su, Sum) 250 | assert isinstance(Su(En1.a), Su) 251 | assert isinstance(Su(En1.a), Sum) 252 | 253 | assert Su[En1] == En1 254 | assert Su[Pr] == Pr 255 | with pytest.raises(KeyError): 256 | Su[En2] 257 | 258 | assert Su.field_dict == {En1 : En1, Pr : Pr} 259 | 260 | with pytest.raises(TypeError): 261 | Su(1) 262 | 263 | s = Su(En1.a) 264 | assert s[En1].match 265 | assert not s[Pr].match 266 | with pytest.raises(TypeError): 267 | s[En2].match 268 | 269 | assert s[En1].value == En1.a 270 | 271 | with pytest.raises(TypeError): 272 | s[Pr].value 273 | 274 | assert s.value_dict == {En1 : En1.a, Pr : None} 275 
| 276 | s[En1] = En1.b 277 | assert s[En1].value == En1.b 278 | 279 | s[Pr] = Pr(En1.a, En2.c) 280 | assert s[Pr].match 281 | assert not s[En1].match 282 | assert s[Pr].value == Pr(En1.a, En2.c) 283 | 284 | with pytest.raises(TypeError): 285 | s[En1].value 286 | 287 | with pytest.raises(TypeError): 288 | s[En1] = En2.c 289 | 290 | with pytest.raises(TypeError): 291 | s[Pr] = En1.a 292 | 293 | def test_tagged_union(): 294 | assert set(Ta.enumerate()) == { 295 | Ta(x=En1.a), 296 | Ta(x=En1.b), 297 | Ta(y=En1.a), 298 | Ta(y=En1.b), 299 | Ta(z=Pr(En1.a, En2.c)), 300 | Ta(z=Pr(En1.a, En2.d)), 301 | Ta(z=Pr(En1.b, En2.c)), 302 | Ta(z=Pr(En1.b, En2.d)), 303 | } 304 | 305 | 306 | assert Ta(x=En1.a)._value_ == En1.a 307 | 308 | assert En1 in Ta 309 | assert Pr in Ta 310 | assert En2 not in Ta 311 | 312 | assert issubclass(Ta, TaggedUnion) 313 | assert isinstance(Ta(x=En1.a), TaggedUnion) 314 | assert isinstance(Ta(x=En1.a), Ta) 315 | 316 | assert Ta(x=En1.a) == Su(En1.a) 317 | assert issubclass(Ta, Su) 318 | assert issubclass(Ta, Sum) 319 | assert isinstance(Ta(x=En1.a), Su) 320 | 321 | assert Ta.x == Ta.y == Ta[En1] == Su[En1] == En1 322 | assert Ta.z == Ta[Pr] == Su[Pr] == Pr 323 | 324 | assert Ta.field_dict == {'x': En1, 'y': En1, 'z': Pr} 325 | 326 | t = Ta(x=En1.a) 327 | with pytest.raises(TypeError): 328 | Ta(x=En2.c) 329 | 330 | assert t[En1].match and t.x.match and not t.y.match and not t.z.match 331 | assert t[En1].value == t.x.value == En1.a 332 | 333 | with pytest.raises(TypeError): 334 | t.y.value 335 | 336 | with pytest.raises(TypeError): 337 | t.z.value 338 | 339 | assert t.value_dict == {'x': En1.a, 'y': None, 'z': None} 340 | 341 | t.x = En1.b 342 | assert t[En1].match and t.x.match and not t.y.match and not t.z.match 343 | assert t[En1].value == t.x.value == En1.b 344 | 345 | t.y = En1.a 346 | assert t[En1].match and t.y.match and not t.x.match and not t.z.match 347 | assert t[En1].value == t.y.value == En1.a 348 | 349 | with pytest.raises(TypeError): 350 | 
t[En1] = En1.a 351 | 352 | with pytest.raises(TypeError): 353 | t.z = En1.a 354 | 355 | with pytest.raises(TypeError): 356 | t.y = En2.c 357 | 358 | assert Ta != Ta2 359 | assert Ta.field_dict == Ta2.field_dict 360 | assert Ta is TaggedUnion.from_fields('Ta', {'x': En1, 'y': En1, 'z': Pr}, cache=True) 361 | 362 | 363 | def test_tagged_union_from_fields(): 364 | T = TaggedUnion.from_fields('T', {'A' : int, 'B' : str}) 365 | assert issubclass(T, TaggedUnion) 366 | assert issubclass(T, Sum[int, str]) 367 | assert T.A == int 368 | assert T.B == str 369 | assert T.__module__ == TaggedUnion.__module__ 370 | 371 | assert T is TaggedUnion.from_fields('T', {'A' : int, 'B' : str}) 372 | 373 | T2 = TaggedUnion.from_fields('T', {'B' : str, 'A' : int}) 374 | assert T2 is T 375 | 376 | T3 = TaggedUnion.from_fields('T', {'A' : int, 'B' : str}, cache=False) 377 | assert T3 is not T 378 | 379 | with pytest.raises(TypeError): 380 | Ta.from_fields('T', {'A' : int, 'B' : str}) 381 | 382 | 383 | def test_tagged_union_caching(): 384 | global Ta 385 | assert Ta is not Ta2 386 | assert Ta is not Ta3 387 | assert Ta is TaggedUnion.from_fields('Ta', Ta3.field_dict, cache=True) 388 | assert Ta2 is not TaggedUnion.from_fields('Ta2', Ta2.field_dict, cache=True) 389 | assert Ta2 is not TaggedUnion.from_fields('Ta2', Ta2.field_dict, cache=False) 390 | assert Ta.field_dict == Ta2.field_dict == Ta3.field_dict 391 | 392 | Ta_ = Ta 393 | 394 | class Ta(TaggedUnion): 395 | x = En1 396 | y = En1 397 | z = Pr 398 | 399 | assert Ta_ is Ta 400 | 401 | class Ta(TaggedUnion): 402 | y = En1 403 | x = En1 404 | z = Pr 405 | 406 | assert Ta_ is Ta 407 | 408 | 409 | def test_new(): 410 | t = new(Tuple) 411 | s = new(Tuple) 412 | assert issubclass(t, Tuple) 413 | assert issubclass(t[En1], t) 414 | assert issubclass(t[En1], Tuple[En1]) 415 | assert t is not Tuple 416 | assert s is not t 417 | 418 | t = new(Sum, (En1, Pr)) 419 | assert t is not Su 420 | assert Sum[En1, Pr] is Su 421 | assert t.__module__ == 
'hwtypes.modifiers' 422 | 423 | t = new(Sum, (En1, Pr), module=__name__) 424 | assert t.__module__ == __name__ 425 | 426 | 427 | @pytest.mark.parametrize("T", [En1, Tu, Su, Pr, Ta]) 428 | def test_repr(T): 429 | s = repr(T) 430 | assert isinstance(s, str) 431 | assert s != '' 432 | for e in T.enumerate(): 433 | s = repr(e) 434 | assert isinstance(s, str) 435 | assert s != '' 436 | 437 | 438 | @pytest.mark.parametrize("T_field", 439 | [(Enum, '0'), (Product, 'int'), (TaggedUnion, 'int')]) 440 | @pytest.mark.parametrize("field_name", list(RESERVED_ATTRS)) 441 | def test_reserved(T_field, field_name): 442 | T, field = T_field 443 | l_dict = {'T' : T} 444 | cls_str = f''' 445 | class _(T): 446 | {field_name} = {field} 447 | ''' 448 | with pytest.raises(ReservedNameError): 449 | exec(cls_str, l_dict) 450 | 451 | @pytest.mark.parametrize("t, base", [ 452 | (En1, Enum), 453 | (Pr, Product), 454 | (Su, Sum), 455 | (Tu, Tuple), 456 | (Ta, TaggedUnion), 457 | ]) 458 | def test_unbound_t(t, base): 459 | assert t.unbound_t == base 460 | class sub_t(t): pass 461 | with pytest.raises(AttributeError): 462 | sub_t.unbound_t 463 | 464 | @pytest.mark.parametrize("val", [ 465 | En1.a, Su(En1.a), Ta(x=En1.a), 466 | Tu(En1.a, En2.c), Pr(En1.a, En2.c)]) 467 | def test_deprecated(val): 468 | with pytest.warns(DeprecationWarning): 469 | val.value 470 | 471 | def test_adt_syntax(): 472 | # En1, Pr, Su, Tu, Ta 473 | for T in (En1, Pr, Ta): 474 | assert isinstance(T, AttrSyntax) 475 | assert not isinstance(T, GetitemSyntax) 476 | 477 | for T in (Su, Tu): 478 | assert not isinstance(T, AttrSyntax) 479 | assert isinstance(T, GetitemSyntax) 480 | 481 | for T in (str, Bit, BitVector[4], AbstractBit, AbstractBitVector[4], int): 482 | assert not isinstance(T, AttrSyntax) 483 | assert not isinstance(T, GetitemSyntax) 484 | 485 | @pytest.mark.parametrize("adt", [ 486 | Su(En1.a), Ta(x=En1.a), 487 | Tu(En1.a, En2.c), Pr(En1.a, En2.c), Ap(En1.a, En2.c)]) 488 | def test_from_values(adt): 489 | assert 
adt == type(adt).from_values(adt.value_dict) 490 | -------------------------------------------------------------------------------- /hwtypes/fp_vector.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import gmpy2 3 | import random 4 | import typing as tp 5 | import warnings 6 | 7 | from .bit_vector import Bit, BitVector, SIntVector 8 | from .fp_vector_abc import AbstractFPVector, RoundingMode 9 | 10 | __ALL__ = ['FPVector'] 11 | 12 | _mode_2_gmpy2 = { 13 | RoundingMode.RNE : gmpy2.RoundToNearest, 14 | RoundingMode.RNA : gmpy2.RoundAwayZero, 15 | RoundingMode.RTP : gmpy2.RoundUp, 16 | RoundingMode.RTN : gmpy2.RoundDown, 17 | RoundingMode.RTZ : gmpy2.RoundToZero, 18 | } 19 | 20 | def _coerce(T : tp.Type['FPVector'], val : tp.Any) -> 'FPVector': 21 | if not isinstance(val, FPVector): 22 | return T(val) 23 | elif type(val).binding != T.binding: 24 | raise TypeError('Inconsistent FP type') 25 | else: 26 | return val 27 | 28 | def fp_cast(fn : tp.Callable[['FPVector', 'FPVector'], tp.Any]) -> tp.Callable[['FPVector', tp.Any], tp.Any]: 29 | @functools.wraps(fn) 30 | def wrapped(self : 'FPVector', other : tp.Any) -> tp.Any: 31 | other = _coerce(type(self), other) 32 | return fn(self, other) 33 | return wrapped 34 | 35 | def set_context(fn: tp.Callable) -> tp.Callable: 36 | @functools.wraps(fn) 37 | def wrapped(self : 'FPVector', *args, **kwargs): 38 | with gmpy2.local_context(self._ctx_): 39 | return fn(self, *args, **kwargs) 40 | return wrapped 41 | 42 | class FPVector(AbstractFPVector): 43 | @set_context 44 | def __init__(self, value): 45 | # Because for some reason gmpy2.mpfr is a function and not a type 46 | if isinstance(value, type(gmpy2.mpfr(0))): 47 | #need to specify precision because mpfr will use the input 48 | #precision not the context precision when constructing from mpfr 49 | value = gmpy2.mpfr(value, self._ctx_.precision) 50 | elif isinstance(value, FPVector): 51 | value = gmpy2.mpfr(value._value, 
self._ctx_.precision) 52 | elif isinstance(value, (int, float, type(gmpy2.mpz(0)), type(gmpy2.mpq(0)))): 53 | value = gmpy2.mpfr(value) 54 | elif isinstance(value, str): 55 | try: 56 | #Handles '0.5' 57 | value = gmpy2.mpfr(value) 58 | except ValueError: 59 | try: 60 | #Handles '1/2' 61 | value = gmpy2.mpfr(gmpy2.mpq(value)) 62 | except ValueError: 63 | raise ValueError('Invalid string') 64 | elif hasattr(value, '__float__'): 65 | value = gmpy2.mpfr(float(value)) 66 | elif hasattr(value, '__int__'): 67 | value = gmpy2.mpfr(int(value)) 68 | else: 69 | try: 70 | #if gmpy2 doesn't complain I wont 71 | value = gmpy2.mpfr(value) 72 | except TypeError: 73 | raise TypeError(f"Can't construct FPVector from {type(value)}") 74 | 75 | if gmpy2.is_nan(value) and not type(self).ieee_compliance: 76 | if gmpy2.is_signed(value): 77 | self._value = gmpy2.mpfr('-inf') 78 | else: 79 | self._value = gmpy2.mpfr('inf') 80 | 81 | else: 82 | self._value = value 83 | 84 | @classmethod 85 | def __init_subclass__(cls, **kwargs): 86 | if cls.is_bound: 87 | if cls.ieee_compliance: 88 | precision=cls.mantissa_size+1 89 | emax=1<<(cls.exponent_size - 1) 90 | emin=4-emax-precision 91 | subnormalize=True 92 | else: 93 | precision=cls.mantissa_size+1 94 | emax=1<<(cls.exponent_size - 1) 95 | emin=3-emax 96 | subnormalize=False 97 | 98 | ctx = gmpy2.context( 99 | precision=precision, 100 | emin=emin, 101 | emax=emax, 102 | round=_mode_2_gmpy2[cls.mode], 103 | subnormalize=cls.ieee_compliance, 104 | allow_complex=False, 105 | ) 106 | 107 | if hasattr(cls, '_ctx_'): 108 | c_ctx = cls._ctx_ 109 | if not isinstance(c_ctx, type(ctx)): 110 | raise TypeError('class attribute _ctx_ is reversed by FPVector') 111 | #stupid compare because contexts aren't comparable 112 | elif (c_ctx.precision != ctx.precision 113 | or c_ctx.real_prec != ctx.real_prec 114 | or c_ctx.imag_prec != ctx.imag_prec 115 | or c_ctx.round != ctx.round 116 | or c_ctx.real_round != ctx.real_round 117 | or c_ctx.imag_round != 
ctx.imag_round 118 | or c_ctx.emax != ctx.emax 119 | or c_ctx.emin != ctx.emin 120 | or c_ctx.subnormalize != ctx.subnormalize 121 | or c_ctx.trap_underflow != ctx.trap_underflow 122 | or c_ctx.trap_overflow != ctx.trap_overflow 123 | or c_ctx.trap_inexact != ctx.trap_inexact 124 | or c_ctx.trap_erange != ctx.trap_erange 125 | or c_ctx.trap_divzero != ctx.trap_divzero 126 | or c_ctx.trap_expbound != ctx.trap_expbound 127 | or c_ctx.allow_complex != ctx.allow_complex): 128 | # this basically should never happen unless some one does 129 | # class Foo(FPVector): 130 | # _ctx_ = gmpy2.context(....) 131 | raise TypeError('Incompatible context types') 132 | 133 | cls._ctx_ = ctx 134 | 135 | 136 | 137 | def __repr__(self): 138 | return f'{type(self)}({self._value})' 139 | 140 | @set_context 141 | def fp_abs(self) -> 'FPVector': 142 | v = self._value 143 | return type(self)(v if v >= 0 else -v) 144 | 145 | @set_context 146 | def fp_neg(self) -> 'FPVector': 147 | return type(self)(-self._value) 148 | 149 | @set_context 150 | @fp_cast 151 | def fp_add(self, other : 'FPVector') -> 'FPVector': 152 | return type(self)(self._value + other._value) 153 | 154 | @set_context 155 | @fp_cast 156 | def fp_sub(self, other : 'FPVector') -> 'FPVector': 157 | return type(self)(self._value - other._value) 158 | 159 | @set_context 160 | @fp_cast 161 | def fp_mul(self, other : 'FPVector') -> 'FPVector': 162 | return type(self)(self._value * other._value) 163 | 164 | @set_context 165 | @fp_cast 166 | def fp_div(self, other : 'FPVector') -> 'FPVector': 167 | return type(self)(self._value / other._value) 168 | 169 | @set_context 170 | def fp_fma(self, coef, offset) -> 'FPVector': 171 | cls = type(self) 172 | coef = _coerce(cls, coef) 173 | offset = _coerce(cls, offset) 174 | return cls(gmpy2.fma(self._value, coef._value, offset._value)) 175 | 176 | @set_context 177 | def fp_sqrt(self) -> 'FPVector': 178 | return type(self)(gmpy2.sqrt(self._value)) 179 | 180 | @set_context 181 | @fp_cast 182 | 
def fp_rem(self, other : 'FPVector') -> 'FPVector': 183 | return type(self)(gmpy2.remainder(self._value, other._value)) 184 | 185 | @set_context 186 | def fp_round_to_integral(self) -> 'FPVector': 187 | return type(self)(gmpy2.rint(self._value)) 188 | 189 | @set_context 190 | @fp_cast 191 | def fp_min(self, other : 'FPVector') -> 'FPVector': 192 | return type(self)(gmpy2.minnum(self._value, other._value)) 193 | 194 | @set_context 195 | @fp_cast 196 | def fp_max(self, other : 'FPVector') -> 'FPVector': 197 | return type(self)(gmpy2.maxnum(self._value, other._value)) 198 | 199 | @set_context 200 | @fp_cast 201 | def fp_leq(self, other : 'FPVector') -> Bit: 202 | return Bit(self._value <= other._value) 203 | 204 | @set_context 205 | @fp_cast 206 | def fp_lt(self, other : 'FPVector') -> Bit: 207 | return Bit(self._value < other._value) 208 | 209 | @set_context 210 | @fp_cast 211 | def fp_geq(self, other : 'FPVector') -> Bit: 212 | return Bit(self._value >= other._value) 213 | 214 | @set_context 215 | @fp_cast 216 | def fp_gt(self, other : 'FPVector') -> Bit: 217 | return Bit(self._value > other._value) 218 | 219 | @set_context 220 | @fp_cast 221 | def fp_eq(self, other : 'FPVector') -> Bit: 222 | return Bit(self._value == other._value) 223 | 224 | @set_context 225 | def fp_is_normal(self) -> Bit: 226 | return ~(self.fp_is_zero() | self.fp_is_infinite() | self.fp_is_subnormal() | self.fp_is_NaN()) 227 | 228 | @set_context 229 | def fp_is_subnormal(self) -> Bit: 230 | bv = self.reinterpret_as_bv() 231 | if (bv[type(self).mantissa_size:-1] == 0) & ~self.fp_is_zero(): 232 | assert type(self).ieee_compliance 233 | return Bit(True) 234 | else: 235 | return Bit(False) 236 | 237 | 238 | @set_context 239 | def fp_is_zero(self) -> Bit: 240 | return Bit(gmpy2.is_zero(self._value)) 241 | 242 | @set_context 243 | def fp_is_infinite(self) -> Bit: 244 | return Bit(gmpy2.is_infinite(self._value)) 245 | 246 | @set_context 247 | def fp_is_NaN(self) -> Bit: 248 | return 
Bit(gmpy2.is_nan(self._value)) 249 | 250 | @set_context 251 | def fp_is_negative(self) -> Bit: 252 | return Bit(self._value < 0) 253 | 254 | @set_context 255 | def fp_is_positive(self) -> Bit: 256 | return Bit(self._value > 0) 257 | 258 | @set_context 259 | def to_ubv(self, size : int) -> BitVector: 260 | return BitVector[size](int(self._value)) 261 | 262 | @set_context 263 | def to_sbv(self, size : int) -> SIntVector: 264 | return SIntVector[size](int(self._value)) 265 | 266 | @set_context 267 | def reinterpret_as_bv(self) -> BitVector: 268 | cls = type(self) 269 | sign_bit = BitVector[1](gmpy2.is_signed(self._value)) 270 | 271 | if self.fp_is_zero(): 272 | return BitVector[cls.size-1](0).concat(sign_bit) 273 | elif self.fp_is_infinite(): 274 | exp_bits = BitVector[cls.exponent_size](-1) 275 | mantissa_bits = BitVector[cls.mantissa_size](0) 276 | return mantissa_bits.concat(exp_bits).concat(sign_bit) 277 | elif self.fp_is_NaN(): 278 | exp_bits = BitVector[cls.exponent_size](-1) 279 | mantissa_bits = BitVector[cls.mantissa_size](1) 280 | return mantissa_bits.concat(exp_bits).concat(sign_bit) 281 | 282 | 283 | bias = (1 << (cls.exponent_size - 1)) - 1 284 | v = self._value 285 | 286 | mantissa_int, exp = v.as_mantissa_exp() 287 | 288 | exp = exp + cls.mantissa_size 289 | 290 | if exp < 1-bias: 291 | if not cls.ieee_compliance: 292 | warnings.warn('denorm will be flushed to 0') 293 | mantissa_int = 0 294 | else: 295 | while exp < 1 - bias: 296 | mantissa_int >>= 1 297 | exp += 1 298 | exp = 0 299 | else: 300 | exp = exp + bias 301 | 302 | assert exp.bit_length() <= cls.exponent_size 303 | 304 | if sign_bit: 305 | mantissa = -BitVector[cls.mantissa_size+1](mantissa_int) 306 | else: 307 | mantissa = BitVector[cls.mantissa_size+1](mantissa_int) 308 | exp_bits = BitVector[cls.exponent_size](exp) 309 | mantissa_bits = mantissa[:cls.mantissa_size] 310 | return mantissa_bits.concat(exp_bits).concat(sign_bit) 311 | 312 | @classmethod 313 | @set_context 314 | def 
reinterpret_from_bv(cls, value: BitVector) -> 'FPVector': 315 | if cls.size != value.size: 316 | raise TypeError() 317 | 318 | mantissa = value[:cls.mantissa_size] 319 | exp = value[cls.mantissa_size:-1] 320 | sign = value[-1:] 321 | assert exp.size == cls.exponent_size 322 | 323 | bias = (1 << (cls.exponent_size - 1)) - 1 324 | 325 | if exp == 0: 326 | if mantissa != 0: 327 | if not cls.ieee_compliance: 328 | warnings.warn('denorm will be flushed to 0') 329 | if sign[0]: 330 | return cls('-0') 331 | else: 332 | return cls('0') 333 | else: 334 | exp = 1 - bias 335 | s = ['-0.' if sign[0] else '0.', mantissa.binary_string()] 336 | s.append('e') 337 | s.append(str(exp)) 338 | return cls(gmpy2.mpfr(''.join(s),cls.mantissa_size+1, 2)) 339 | elif sign[0]: 340 | return cls('-0') 341 | else: 342 | return cls('0') 343 | elif exp == -1: 344 | if mantissa == 0: 345 | if sign: 346 | return cls('-inf') 347 | else: 348 | return cls('inf') 349 | else: 350 | if not cls.ieee_compliance: 351 | warnings.warn('NaN will be flushed to infinity') 352 | return cls('nan') 353 | else: 354 | #unbias the exponent 355 | exp = exp - bias 356 | 357 | s = ['-1.' 
if sign[0] else '1.', mantissa.binary_string()] 358 | s.append('e') 359 | s.append(str(exp.as_sint())) 360 | return cls(gmpy2.mpfr(''.join(s),cls.mantissa_size+1, 2)) 361 | 362 | 363 | def __neg__(self): return self.fp_neg() 364 | def __abs__(self): return self.fp_abs() 365 | def __add__(self, other): return self.fp_add(other) 366 | def __sub__(self, other): return self.fp_sub(other) 367 | def __mul__(self, other): return self.fp_mul(other) 368 | def __truediv__(self, other): return self.fp_div(other) 369 | def __mod__(self, other): return self.fp_rem(other) 370 | 371 | def __eq__(self, other): return self.fp_eq(other) 372 | def __ne__(self, other): return ~(self.fp_eq(other)) 373 | def __ge__(self, other): return self.fp_geq(other) 374 | def __gt__(self, other): return self.fp_gt(other) 375 | def __le__(self, other): return self.fp_leq(other) 376 | def __lt__(self, other): return self.fp_lt(other) 377 | 378 | @set_context 379 | def __float__(self): 380 | return float(self._value) 381 | 382 | @classmethod 383 | @set_context 384 | def random(cls, allow_inf=True) -> 'FPVector': 385 | bias = (1 << (cls.exponent_size - 1)) - 1 386 | if allow_inf: 387 | sign = random.choice([-1, 1]) 388 | mantissa = gmpy2.mpfr_random(gmpy2.random_state()) + 1 389 | exp = random.randint(1-bias, bias+1) 390 | return cls(mantissa*sign*(gmpy2.mpfr(2)**gmpy2.mpfr(exp))) 391 | else: 392 | sign = random.choice([-1, 1]) 393 | mantissa = gmpy2.mpfr_random(gmpy2.random_state()) + 1 394 | exp = random.randint(1-bias, bias) 395 | return cls(mantissa*sign*(gmpy2.mpfr(2)**gmpy2.mpfr(exp))) 396 | -------------------------------------------------------------------------------- /hwtypes/z3_bit_vector.py: -------------------------------------------------------------------------------- 1 | import hwtypes as ht 2 | import typing as tp 3 | import itertools as it 4 | import functools as ft 5 | from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily 6 | 7 | from abc import abstractmethod 8 | 9 
| import z3 10 | 11 | import re 12 | import warnings 13 | import weakref 14 | 15 | import random 16 | 17 | __ALL__ = ['z3BitVector', 'z3NumVector', 'z3SIntVector', 'z3UIntVector'] 18 | 19 | _var_counter = it.count() 20 | _name_table = weakref.WeakValueDictionary() 21 | _free_names = [] 22 | 23 | def _gen_name(): 24 | if _free_names: 25 | return _free_names.pop() 26 | name = f'V_{next(_var_counter)}' 27 | while name in _name_table: 28 | name = f'V_{next(_var_counter)}' 29 | return name 30 | 31 | _name_re = re.compile(r'V_\d+') 32 | 33 | class _SMYBOLIC: 34 | def __repr__(self): 35 | return 'SYMBOLIC' 36 | 37 | class _AUTOMATIC: 38 | def __repr__(self): 39 | return 'AUTOMATIC' 40 | 41 | SMYBOLIC = _SMYBOLIC() 42 | AUTOMATIC = _AUTOMATIC() 43 | 44 | def bit_cast(fn): 45 | @ft.wraps(fn) 46 | def wrapped(self, other): 47 | if isinstance(other, z3Bit): 48 | return fn(self, other) 49 | else: 50 | return fn(self, z3Bit(other)) 51 | return wrapped 52 | 53 | class z3Bit(AbstractBit): 54 | @staticmethod 55 | def get_family() -> ht.TypeFamily: 56 | return _Family_ 57 | 58 | def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC): 59 | if name is not AUTOMATIC and value is not SMYBOLIC: 60 | raise TypeError('Can only name symbolic variables') 61 | elif name is not AUTOMATIC: 62 | if not isinstance(name, str): 63 | raise TypeError('Name must be string') 64 | elif name in _name_table: 65 | raise ValueError(f'Name {name} already in use') 66 | elif _name_re.fullmatch(name): 67 | warnings.warn('Name looks like an auto generated name, this might break things') 68 | _name_table[name] = self 69 | elif name is AUTOMATIC and value is SMYBOLIC: 70 | name = _gen_name() 71 | _name_table[name] = self 72 | 73 | if value is SMYBOLIC: 74 | self._value = z3.Bool(name) 75 | elif isinstance(value, z3.BoolRef): 76 | self._value = value 77 | elif isinstance(value, z3Bit): 78 | if name is not AUTOMATIC and name != value.name: 79 | warnings.warn('Changing the name of a z3Bit does not cause a new 
underlying smt variable to be created') 80 | self._value = value._value 81 | elif isinstance(value, bool): 82 | self._value = z3.BoolVal(value) 83 | elif isinstance(value, int): 84 | if value not in {0, 1}: 85 | raise ValueError('Bit must have value 0 or 1 not {}'.format(value)) 86 | self._value = z3.BoolVal(bool(value)) 87 | elif hasattr(value, '__bool__'): 88 | self._value = z3.BoolVal(bool(value)) 89 | else: 90 | raise TypeError("Can't coerce {} to Bit".format(type(value))) 91 | 92 | self._name = name 93 | self._value = z3.simplify(self._value) 94 | 95 | def __repr__(self): 96 | if self._name is not AUTOMATIC: 97 | return f'{type(self)}({self._name})' 98 | else: 99 | return f'{type(self)}({self._value})' 100 | 101 | @property 102 | def value(self): 103 | return self._value 104 | 105 | @bit_cast 106 | def __eq__(self, other : 'z3Bit') -> 'z3Bit': 107 | return type(self)(self.value == other.value) 108 | 109 | @bit_cast 110 | def __ne__(self, other : 'z3Bit') -> 'z3Bit': 111 | return type(self)(self.value != other.value) 112 | 113 | def __invert__(self) -> 'z3Bit': 114 | return type(self)(z3.Not(self.value)) 115 | 116 | @bit_cast 117 | def __and__(self, other : 'z3Bit') -> 'z3Bit': 118 | return type(self)(z3.And(self.value, other.value)) 119 | 120 | @bit_cast 121 | def __or__(self, other : 'z3Bit') -> 'z3Bit': 122 | return type(self)(z3.Or(self.value, other.value)) 123 | 124 | @bit_cast 125 | def __xor__(self, other : 'z3Bit') -> 'z3Bit': 126 | return type(self)(z3.Xor(self.value, other.value)) 127 | 128 | def ite(self, t_branch, f_branch): 129 | tb_t = type(t_branch) 130 | fb_t = type(f_branch) 131 | BV_t = self.get_family().BitVector 132 | if isinstance(t_branch, BV_t) and isinstance(f_branch, BV_t): 133 | if tb_t is not fb_t: 134 | raise TypeError('Both branches must have the same type') 135 | T = tb_t 136 | elif isinstance(t_branch, BV_t): 137 | f_branch = tb_t(f_branch) 138 | T = tb_t 139 | elif isinstance(f_branch, BV_t): 140 | t_branch = fb_t(t_branch) 141 | 
T = fb_t 142 | else: 143 | t_branch = BV_t(t_branch) 144 | f_branch = BV_t(f_branch) 145 | ext = t_branch.size - f_branch.size 146 | if ext > 0: 147 | f_branch = f_branch.zext(ext) 148 | elif ext < 0: 149 | t_branch = t_branch.zext(-ext) 150 | 151 | T = type(t_branch) 152 | 153 | 154 | return T(z3.If(self.value, t_branch.value, f_branch.value)) 155 | 156 | def _coerce(T : tp.Type['z3BitVector'], val : tp.Any) -> 'z3BitVector': 157 | if not isinstance(val, z3BitVector): 158 | return T(val) 159 | elif val.size != T.size: 160 | raise TypeError('Inconsistent size') 161 | else: 162 | return val 163 | 164 | def bv_cast(fn : tp.Callable[['z3BitVector', 'z3BitVector'], tp.Any]) -> tp.Callable[['z3BitVector', tp.Any], tp.Any]: 165 | @ft.wraps(fn) 166 | def wrapped(self : 'z3BitVector', other : tp.Any) -> tp.Any: 167 | other = _coerce(type(self), other) 168 | return fn(self, other) 169 | return wrapped 170 | 171 | def int_cast(fn : tp.Callable[['z3BitVector', int], tp.Any]) -> tp.Callable[['z3BitVector', tp.Any], tp.Any]: 172 | @ft.wraps(fn) 173 | def wrapped(self : 'z3BitVector', other : tp.Any) -> tp.Any: 174 | other = int(other) 175 | return fn(self, other) 176 | return wrapped 177 | 178 | class z3BitVector(AbstractBitVector): 179 | @staticmethod 180 | def get_family() -> TypeFamily: 181 | return _Family_ 182 | 183 | def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC): 184 | if name is not AUTOMATIC and value is not SMYBOLIC: 185 | raise TypeError('Can only name symbolic variables') 186 | elif name is not AUTOMATIC: 187 | if not isinstance(name, str): 188 | raise TypeError('Name must be string') 189 | elif name in _name_table: 190 | raise ValueError(f'Name {name} already in use') 191 | elif _name_re.fullmatch(name): 192 | warnings.warn('Name looks like an auto generated name, this might break things') 193 | _name_table[name] = self 194 | elif name is AUTOMATIC and value is SMYBOLIC: 195 | name = _gen_name() 196 | _name_table[name] = self 197 | self._name = name 198 | 
199 | T = z3.BitVecSort(self.size) 200 | 201 | if value is SMYBOLIC: 202 | self._value = z3.BitVec(name, T) 203 | elif isinstance(value, z3.BitVecRef): 204 | t = value.sort() 205 | if t == T: 206 | self._value = value 207 | else: 208 | raise TypeError(f'Expected {T} not {t}') 209 | elif isinstance(value, z3BitVector): 210 | if name is not AUTOMATIC and name != value.name: 211 | warnings.warn('Changing the name of a z3BitVector does not cause a new underlying smt variable to be created') 212 | 213 | ext = self.size - value.size 214 | 215 | if ext < 0: 216 | warnings.warn('Truncating value from {} to {}'.format(type(value), type(self))) 217 | self._value = value[:self.size].value 218 | elif ext > 0: 219 | self._value = value.zext(ext).value 220 | else: 221 | self._value = value.value 222 | elif isinstance(value, z3.BoolRef): 223 | self._value = z3.If(value, z3.BitVecVal(1, T), z3.BitVecVal(0, T)) 224 | 225 | elif isinstance(value, z3Bit): 226 | self._value = z3.If(value.value, z3.BitVecVal(1, T), z3.BitVecVal(0, T)) 227 | 228 | elif isinstance(value, tp.Sequence): 229 | if len(value) != self.size: 230 | raise ValueError('Iterable is not the correct size') 231 | cls = type(self) 232 | B1 = cls.unsized_t[1] 233 | self._value = ft.reduce(lambda acc, elem : acc.concat(elem), map(B1, value)).value 234 | elif isinstance(value, int): 235 | self._value = z3.BitVecVal(value, self.size) 236 | 237 | elif hasattr(value, '__int__'): 238 | value = int(value) 239 | self._value = z3.BitVecVal(value, self.size) 240 | else: 241 | raise TypeError("Can't coerce {} to z3BitVector".format(type(value))) 242 | 243 | self._value = z3.simplify(self._value) 244 | assert self._value.sort() == T 245 | 246 | def make_constant(self, value, size:tp.Optional[int]=None): 247 | if size is None: 248 | size = self.size 249 | return type(self).unsized_t[size](value) 250 | 251 | @property 252 | def value(self): 253 | return self._value 254 | 255 | @property 256 | def num_bits(self): 257 | return self.size 
258 | 259 | def __repr__(self): 260 | if self._name is not AUTOMATIC: 261 | return f'{type(self)}({self._name})' 262 | else: 263 | return f'{type(self)}({self._value})' 264 | 265 | def __getitem__(self, index): 266 | size = self.size 267 | if isinstance(index, slice): 268 | start, stop, step = index.start, index.stop, index.step 269 | 270 | if start is None: 271 | start = 0 272 | elif start < 0: 273 | start = size + start 274 | 275 | if stop is None: 276 | stop = size 277 | elif stop < 0: 278 | stop = size + stop 279 | 280 | stop = min(stop, size) 281 | 282 | if step is None: 283 | step = 1 284 | elif step != 1: 285 | raise IndexError('SMT extract does not support step != 1') 286 | 287 | v = z3.Extract(stop-1, start, self.value) 288 | return type(self).unsized_t[v.sort().size()](v) 289 | elif isinstance(index, int): 290 | if index < 0: 291 | index = size+index 292 | 293 | if not (0 <= index < size): 294 | raise IndexError() 295 | 296 | v = z3.Extract(index, index, self.value) 297 | return self.get_family().Bit(v == z3.BitVecVal(1, 1)) 298 | else: 299 | raise TypeError() 300 | 301 | 302 | def __setitem__(self, index, value): 303 | if isinstance(index, slice): 304 | raise NotImplementedError() 305 | else: 306 | if not (isinstance(value, bool) or isinstance(value, z3Bit) or (isinstance(value, int) and value in {0, 1})): 307 | raise ValueError("Second argument __setitem__ on a single BitVector index should be a bit, boolean or 0 or 1, not {value}".format(value=value)) 308 | 309 | if index < 0: 310 | index = self.size+index 311 | 312 | if not (0 <= index < self.size): 313 | raise IndexError() 314 | 315 | mask = type(self)(1 << index) 316 | self._value = z3Bit(value).ite(self | mask, self & ~mask)._value 317 | 318 | 319 | def __len__(self): 320 | return self.size 321 | 322 | def concat(self, other): 323 | T = type(self).unsized_t 324 | if not isinstance(other, T): 325 | raise TypeError(f'value must of type {T}') 326 | return T[self.size + 
other.size](z3.Concat(other.value, self.value)) 327 | 328 | def bvnot(self): 329 | return type(self)(~self.value) 330 | 331 | @bv_cast 332 | def bvand(self, other): 333 | return type(self)(self.value & other.value) 334 | 335 | @bv_cast 336 | def bvnand(self, other): 337 | return type(self)(~(self.value & other.value)) 338 | 339 | @bv_cast 340 | def bvor(self, other): 341 | return type(self)(self.value | other.value) 342 | 343 | @bv_cast 344 | def bvnor(self, other): 345 | return type(self)(~(self.value | other.value)) 346 | 347 | @bv_cast 348 | def bvxor(self, other): 349 | return type(self)(self.value ^ other.value) 350 | 351 | @bv_cast 352 | def bvxnor(self, other): 353 | return type(self)(~(self.value ^ other.value)) 354 | 355 | @bv_cast 356 | def bvshl(self, other): 357 | return type(self)(self.value << other.value) 358 | 359 | @bv_cast 360 | def bvlshr(self, other): 361 | return type(self)(z3.LShR(self.value, other.value)) 362 | 363 | @bv_cast 364 | def bvashr(self, other): 365 | return type(self)(self.value >> other.value) 366 | 367 | @bv_cast 368 | def bvrol(self, other): 369 | return type(self)(z3.RotateLeft(self.value, other.value)) 370 | 371 | @bv_cast 372 | def bvror(self, other): 373 | return type(self)(z3.RotateRight(self.value, other.value)) 374 | 375 | @bv_cast 376 | def bvcomp(self, other): 377 | return type(self).unsized_t[1](self.value == other.value) 378 | 379 | @bv_cast 380 | def bveq(self, other): 381 | return self.get_family().Bit(self.value == other.value) 382 | 383 | @bv_cast 384 | def bvne(self, other): 385 | return self.get_family().Bit(self.value != other.value) 386 | 387 | @bv_cast 388 | def bvult(self, other): 389 | return self.get_family().Bit(z3.ULT(self.value, other.value)) 390 | 391 | @bv_cast 392 | def bvule(self, other): 393 | return self.get_family().Bit(z3.ULE(self.value, other.value)) 394 | 395 | @bv_cast 396 | def bvugt(self, other): 397 | return self.get_family().Bit(z3.UGT(self.value, other.value)) 398 | 399 | @bv_cast 400 | 
def bvuge(self, other): 401 | return self.get_family().Bit(z3.UGE(self.value, other.value)) 402 | 403 | @bv_cast 404 | def bvslt(self, other): 405 | return self.get_family().Bit(self.value < other.value) 406 | 407 | @bv_cast 408 | def bvsle(self, other): 409 | return self.get_family().Bit(self.value <= other.value) 410 | 411 | @bv_cast 412 | def bvsgt(self, other): 413 | return self.get_family().Bit(self.value > other.value) 414 | 415 | @bv_cast 416 | def bvsge(self, other): 417 | return self.get_family().Bit(self.value >= other.value) 418 | 419 | def bvneg(self): 420 | return type(self)(-self.value) 421 | 422 | def adc(self, other : 'z3BitVector', carry : z3Bit) -> tp.Tuple['z3BitVector', z3Bit]: 423 | """ 424 | add with carry 425 | 426 | returns a two element tuple of the form (result, carry) 427 | 428 | """ 429 | T = type(self) 430 | other = _coerce(T, other) 431 | carry = _coerce(T.unsized_t[1], carry) 432 | 433 | a = self.zext(1) 434 | b = other.zext(1) 435 | c = carry.zext(T.size) 436 | 437 | res = a + b + c 438 | return res[0:-1], res[-1] 439 | 440 | def ite(self, t_branch, f_branch): 441 | return self.bvne(0).ite(t_branch, f_branch) 442 | 443 | @bv_cast 444 | def bvadd(self, other): 445 | return type(self)(self.value + other.value) 446 | 447 | @bv_cast 448 | def bvsub(self, other): 449 | return type(self)(self.value - other.value) 450 | 451 | @bv_cast 452 | def bvmul(self, other): 453 | return type(self)(self.value * other.value) 454 | 455 | @bv_cast 456 | def bvudiv(self, other): 457 | return type(self)(z3.UDiv(self.value, other.value)) 458 | 459 | @bv_cast 460 | def bvurem(self, other): 461 | return type(self)(z3.URem(self.value, other.value)) 462 | 463 | @bv_cast 464 | def bvsdiv(self, other): 465 | return type(self)(self.value / other.value) 466 | 467 | @bv_cast 468 | def bvsrem(self, other): 469 | return type(self)(self.value % other.value) 470 | 471 | __invert__ = bvnot 472 | __and__ = bvand 473 | __or__ = bvor 474 | __xor__ = bvxor 475 | 476 | 
__lshift__ = bvshl 477 | __rshift__ = bvlshr 478 | 479 | __neg__ = bvneg 480 | __add__ = bvadd 481 | __sub__ = bvsub 482 | __mul__ = bvmul 483 | __floordiv__ = bvudiv 484 | __mod__ = bvurem 485 | 486 | __eq__ = bveq 487 | __ne__ = bvne 488 | __ge__ = bvuge 489 | __gt__ = bvugt 490 | __le__ = bvule 491 | __lt__ = bvult 492 | 493 | 494 | 495 | @int_cast 496 | def repeat(self, n): 497 | return type(self)(z3.RepeatBitVec(n, self.value)) 498 | 499 | @int_cast 500 | def sext(self, ext): 501 | if ext < 0: 502 | raise ValueError() 503 | return type(self).unsized_t[self.size + ext](z3.SignExt(ext, self.value)) 504 | 505 | def ext(self, ext): 506 | return self.zext(ext) 507 | 508 | @int_cast 509 | def zext(self, ext): 510 | if ext < 0: 511 | raise ValueError() 512 | return type(self).unsized_t[self.size + ext](z3.ZeroExt(ext, self.value)) 513 | 514 | # Used in testing 515 | # def bits(self): 516 | # return [(self >> i) & 1 for i in range(self.size)] 517 | # 518 | # def __int__(self): 519 | # return self.as_uint() 520 | # 521 | # def as_uint(self): 522 | # return self._value.as_long() 523 | # 524 | # def as_sint(self): 525 | # return self._value.as_signed_long() 526 | # 527 | # @classmethod 528 | # def random(cls, width): 529 | # return cls.unsized_t[width](random.randint(0, (1 << width) - 1)) 530 | # 531 | 532 | class z3NumVector(z3BitVector): 533 | pass 534 | 535 | 536 | class z3UIntVector(z3NumVector): 537 | pass 538 | 539 | class z3SIntVector(z3NumVector): 540 | def __rshift__(self, other): 541 | return self.bvashr(other) 542 | 543 | def __floordiv__(self, other): 544 | return self.bvsdiv(other) 545 | 546 | def __mod__(self, other): 547 | return self.bvsrem(other) 548 | 549 | def __ge__(self, other): 550 | return self.bvsge(other) 551 | 552 | def __gt__(self, other): 553 | return self.bvsgt(other) 554 | 555 | def __lt__(self, other): 556 | return self.bvslt(other) 557 | 558 | def __le__(self, other): 559 | return self.bvsle(other) 560 | 561 | _Family_ = 
ht.TypeFamily(z3Bit, z3BitVector, z3UIntVector, z3SIntVector) 562 | 563 | 564 | -------------------------------------------------------------------------------- /hwtypes/bit_vector.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from .bit_vector_abc import AbstractBitVector, AbstractBit, TypeFamily, InconsistentSizeError 3 | from .bit_vector_util import build_ite 4 | from .util import Method 5 | from .compatibility import IntegerTypes, StringTypes 6 | 7 | import functools 8 | import random 9 | import warnings 10 | 11 | # 12 | # seq to int 13 | # 14 | def seq2int(l : tp.Sequence): 15 | value = 0 16 | for idx,v in enumerate(l): 17 | value |= int(bool(v)) << idx 18 | return value 19 | 20 | # 21 | # int to seq 22 | # 23 | def int2seq(value : int, n : int): 24 | return [(value >> i) & 1 for i in range(n)] 25 | 26 | 27 | def bit_cast(fn : tp.Callable[['Bit', 'Bit'], 'Bit']) -> tp.Callable[['Bit', tp.Union['Bit', bool]], 'Bit']: 28 | @functools.wraps(fn) 29 | def wrapped(self : 'Bit', other : tp.Union['Bit', bool]) -> 'Bit': 30 | if isinstance(other, Bit): 31 | return fn(self, other) 32 | else: 33 | try: 34 | other = Bit(other) 35 | except TypeError: 36 | return NotImplemented 37 | return fn(self, other) 38 | return wrapped 39 | 40 | 41 | class Bit(AbstractBit): 42 | @staticmethod 43 | def get_family() -> TypeFamily: 44 | return _Family_ 45 | 46 | def __init__(self, value): 47 | if isinstance(value, Bit): 48 | self._value = value._value 49 | elif isinstance(value, bool): 50 | self._value = value 51 | elif isinstance(value, int): 52 | if value not in {0, 1}: 53 | raise ValueError('Bit must have value 0 or 1 not {}'.format(value)) 54 | self._value = bool(value) 55 | elif hasattr(value, '__bool__'): 56 | self._value = bool(value) 57 | else: 58 | raise TypeError("Can't coerce {} to Bit".format(type(value))) 59 | 60 | def __invert__(self): 61 | return type(self)(not self._value) 62 | 63 | @bit_cast 64 | def 
__eq__(self, other): 65 | return type(self)(self._value == other._value) 66 | 67 | @bit_cast 68 | def __ne__(self, other): 69 | return type(self)(self._value != other._value) 70 | 71 | @bit_cast 72 | def __and__(self, other): 73 | return type(self)(self._value & other._value) 74 | 75 | @bit_cast 76 | def __rand__(self, other): 77 | return type(self)(other._value & self._value) 78 | 79 | @bit_cast 80 | def __or__(self, other): 81 | return type(self)(self._value | other._value) 82 | 83 | @bit_cast 84 | def __ror__(self, other): 85 | return type(self)(other._value | self._value) 86 | 87 | @bit_cast 88 | def __xor__(self, other): 89 | return type(self)(self._value ^ other._value) 90 | 91 | @bit_cast 92 | def __rxor__(self, other): 93 | return type(self)(other._value ^ self._value) 94 | 95 | 96 | def ite(self, t_branch, f_branch): 97 | ''' 98 | typing works as follows: 99 | if t_branch and f_branch are both Bit[Vector] types 100 | from the same family, and type(t_branch) is type(f_branch), 101 | then return type is t_branch. 102 | 103 | elif t_branch and f_branch are both Bit[Vector] types 104 | from the same family, then return type is a polymorphic 105 | type. 
106 | 107 | elif t_brand and f_branch are tuples of the 108 | same length, then these rules are applied recursively 109 | 110 | else there is an error 111 | 112 | ''' 113 | def _ite(select, t_branch, f_branch): 114 | return t_branch if select else f_branch 115 | 116 | return build_ite(_ite, self, t_branch, f_branch) 117 | 118 | def __bool__(self) -> bool: 119 | return self._value 120 | 121 | def __int__(self) -> int: 122 | return int(self._value) 123 | 124 | def __repr__(self) -> str: 125 | return f'{type(self).__name__}({self._value})' 126 | 127 | def __hash__(self) -> int: 128 | return hash((type(self), self._value)) 129 | 130 | @classmethod 131 | def random(cls) -> AbstractBit: 132 | return cls(random.getrandbits(1)) 133 | 134 | def _coerce(T : tp.Type['BitVector'], val : tp.Any) -> 'BitVector': 135 | if not isinstance(val, BitVector): 136 | return T(val) 137 | elif val.size != T.size: 138 | raise InconsistentSizeError('Inconsistent size') 139 | else: 140 | return val 141 | 142 | def bv_cast(fn : tp.Callable[['BitVector', 'BitVector'], tp.Any]) -> tp.Callable[['BitVector', tp.Any], tp.Any]: 143 | @functools.wraps(fn) 144 | def wrapped(self : 'BitVector', other : tp.Any) -> tp.Any: 145 | other = _coerce(type(self), other) 146 | return fn(self, other) 147 | return wrapped 148 | 149 | 150 | 151 | def dispatch_oper(method: tp.MethodDescriptorType): 152 | def oper(self, other): 153 | try: 154 | return method(self, other) 155 | except InconsistentSizeError as e: 156 | raise e from None 157 | except TypeError: 158 | return NotImplemented 159 | 160 | return Method(oper) 161 | 162 | 163 | # A little inefficient because of double _coerce but whate;er 164 | def dispatch_roper(method: Method): 165 | def roper(self, other): 166 | try: 167 | other = _coerce(type(self), other) 168 | except inconsistentsizeerror as e: 169 | raise e from None 170 | except TypeError: 171 | return NotImplemented 172 | return method(other, self) 173 | 174 | return Method(roper) 175 | 176 | 177 | 
class BitVector(AbstractBitVector):
    """Concrete fixed-width bit vector backed by a Python ``int``.

    The value is stored unsigned in ``_value`` and is always masked to
    ``self.size`` bits, so arithmetic wraps modulo ``2**size``.  A signed
    (two's-complement) view is available through :meth:`as_sint`.  Bit 0 is
    the least significant bit (see :meth:`__getitem__`).
    """

    @staticmethod
    def get_family() -> TypeFamily:
        """Return the type family (Bit, BitVector, UIntVector, SIntVector)."""
        return _Family_

    def __init__(self, value=0):
        """Build a vector of ``self.size`` bits from ``value``.

        Accepted inputs:

        * another ``BitVector`` -- warns if it is wider than ``self``;
        * a ``Bit`` -- becomes 0 or 1;
        * a plain integer -- silently truncated to ``self.size`` bits;
        * a sequence of bits -- converted with ``seq2int`` (presumably LSB
          first, matching :meth:`bits`); warns if it is too long;
        * any other object implementing ``__int__`` -- warns if too wide.

        Raises:
            TypeError: if ``value`` is none of the above.
        """
        if isinstance(value, BitVector):
            if value.size > self.size:
                warnings.warn('Truncating value from {} to {}'.format(type(value), type(self)), stacklevel=3)
            value = value._value
        elif isinstance(value, Bit):
            value = int(bool(value))
        elif isinstance(value, IntegerTypes):
            # Plain ints are truncated silently: a truncation warning here
            # would be triggered constantly.
            pass
        elif isinstance(value, tp.Sequence):
            if len(value) > self.size:
                warnings.warn('Truncating value {} to {}'.format(value, type(self)), stacklevel=3)
            value = seq2int(value)
        elif hasattr(value, '__int__'):
            value = int(value)
            if value.bit_length() > self.size:
                warnings.warn('Truncating value {} to {}'.format(value, type(self)), stacklevel=3)
        else:
            raise TypeError('Cannot construct {} from {}'.format(type(self), value))
        mask = (1 << self.size) - 1
        # Masking also maps negative ints to their two's-complement encoding.
        self._value = value & mask

    @classmethod
    def make_constant(cls, value, size=None):
        """Build a constant of this type, or of ``unsized_t[size]`` if ``size`` is given."""
        if size is None:
            return cls(value)
        else:
            return cls.unsized_t[size](value)

    def __hash__(self):
        # Hash the (type, value) pair directly, like Bit.__hash__, instead of
        # formatting both into a temporary string.
        return hash((type(self), self._value))

    def __str__(self):
        return str(int(self))

    def __repr__(self):
        return f'{type(self).__name__}({self._value})'

    @property
    def value(self):
        """The raw unsigned integer value."""
        return self._value

    def __setitem__(self, index, value):
        """Set the single bit at ``index`` (negative indices count from the MSB).

        Raises:
            NotImplementedError: for slice indices.
            ValueError: if ``value`` is not a bool, a ``Bit``, 0 or 1.
            IndexError: if ``index`` is out of range.
        """
        if isinstance(index, slice):
            raise NotImplementedError()
        else:
            if not (isinstance(value, bool) or isinstance(value, Bit) or (isinstance(value, int) and value in {0, 1})):
                raise ValueError("Second argument __setitem__ on a single BitVector index should be a boolean or 0 or 1, not {value}".format(value=value))

            if index < 0:
                index = self.size+index

            if not (0 <= index < self.size):
                raise IndexError()

            mask = type(self)(1 << index)
            # Set the bit when value is truthy, otherwise clear it.
            self._value = Bit(value).ite(self | mask, self & ~mask)._value

    def __getitem__(self, index : tp.Union[int, slice]) -> tp.Union['BitVector', Bit]:
        """Return one bit as a ``Bit``, or a slice as a new (narrower) vector.

        Index 0 is the least significant bit; negative indices count from the
        MSB.

        Raises:
            IndexError: if an integer index is out of range.
            TypeError: if ``index`` is neither an int nor a slice.
        """
        if isinstance(index, slice):
            v = self.bits()[index]
            return type(self).unsized_t[len(v)](v)
        elif isinstance(index, int):
            if index < 0:
                index = self.size+index
            if not (0 <= index < self.size):
                raise IndexError()
            return Bit((self._value >> index) & 1)
        else:
            raise TypeError()

    @property
    def num_bits(self):
        return self.size

    def __len__(self):
        return self.size

    def concat(self, other):
        """Return ``self`` (low bits) followed by ``other`` (high bits).

        Raises:
            TypeError: if ``other`` is not of the same unsized family type.
        """
        T = type(self).unsized_t
        if not isinstance(other, T):
            raise TypeError(f'value must of type {T}')
        return T[self.size+other.size](self.value | (other.value << self.size))

    def bvnot(self):
        return type(self)(~self.as_uint())

    @bv_cast
    def bvand(self, other):
        return type(self)(self.as_uint() & other.as_uint())

    @bv_cast
    def bvor(self, other):
        return type(self)(self.as_uint() | other.as_uint())

    @bv_cast
    def bvxor(self, other):
        return type(self)(self.as_uint() ^ other.as_uint())

    @bv_cast
    def bvshl(self, other):
        return type(self)(self.as_uint() << other.as_uint())

    @bv_cast
    def bvlshr(self, other):
        # Logical shift: operate on the unsigned value, zero-filling from the MSB.
        return type(self)(self.as_uint() >> other.as_uint())

    @bv_cast
    def bvashr(self, other):
        # Arithmetic shift: operate on the signed value so the sign bit fills in.
        return type(self)(self.as_sint() >> other.as_uint())

    @bv_cast
    def bvrol(self, other):
        # Rotating left by k is the same as rotating right by (size - k).
        other = (len(self) - other.as_uint()) % len(self)
        return self[other:].concat(self[:other])

    @bv_cast
    def bvror(self, other):
        # Rotate right by k: bits k.. move down to become the low bits.
        other = other.as_uint() % len(self)
        return self[other:].concat(self[:other])

    @bv_cast
    def bvcomp(self, other):
        """Return a 1-bit vector that is 1 iff ``self == other``."""
        return type(self).unsized_t[1](self.as_uint() == other.as_uint())

    @bv_cast
    def bveq(self, other):
        return self.get_family().Bit(self.as_uint() == other.as_uint())

    @bv_cast
    def bvne(self, other):
        return self.get_family().Bit(self.as_uint() != other.as_uint())

    @bv_cast
    def bvuge(self, other):
        return self.get_family().Bit(self.as_uint() >= other.as_uint())

    @bv_cast
    def bvugt(self, other):
        return self.get_family().Bit(self.as_uint() > other.as_uint())

    @bv_cast
    def bvule(self, other):
        return self.get_family().Bit(self.as_uint() <= other.as_uint())

    @bv_cast
    def bvult(self, other):
        return self.get_family().Bit(self.as_uint() < other.as_uint())

    @bv_cast
    def bvslt(self, other):
        return self.get_family().Bit(self.as_sint() < other.as_sint())

    def bvneg(self):
        # Two's-complement negation: invert and add one.
        return type(self)(~self.as_uint() + 1)

    def adc(self, other : 'BitVector', carry : Bit) -> tp.Tuple['BitVector', Bit]:
        """
        add with carry

        returns a two element tuple of the form (result, carry)

        """
        T = type(self)
        other = _coerce(T, other)
        carry = _coerce(T.unsized_t[1], carry)

        # Widen by one bit so the carry-out lands in the new MSB.
        a = self.zext(1)
        b = other.zext(1)
        c = carry.zext(T.size)

        res = a + b + c
        return res[0:-1], res[-1]

    def ite(self, t_branch, f_branch):
        """Select ``t_branch`` if this vector is non-zero, else ``f_branch``."""
        return self.bvne(0).ite(t_branch, f_branch)

    @bv_cast
    def bvadd(self, other):
        return type(self)(self.as_uint() + other.as_uint())

    @bv_cast
    def bvsub(self, other):
        return type(self)(self.as_uint() - other.as_uint())

    @bv_cast
    def bvmul(self, other):
        return type(self)(self.as_uint() * other.as_uint())

    @bv_cast
    def bvudiv(self, other):
        other = other.as_uint()
        if other == 0:
            # Division by zero yields all ones (as in SMT-LIB bvudiv).
            return type(self)((1 << self.size) - 1)
        return type(self)(self.as_uint() // other)

    @bv_cast
    def bvurem(self, other):
        other = other.as_uint()
        if other == 0:
            # Remainder by zero yields the dividend (as in SMT-LIB bvurem).
            return self
        return type(self)(self.as_uint() % other)

    @bv_cast
    def bvsdiv(self, other):
        # NOTE(review): Python ``//`` floors toward -inf, whereas SMT-LIB
        # bvsdiv truncates toward zero (e.g. -7 / 2 is -4 here, -3 in SMT-LIB).
        # Also, SMT-LIB gives 1 (not all ones) for a negative dividend over
        # zero.  Kept as-is because callers may rely on it -- confirm intent.
        other = other.as_sint()
        if other == 0:
            return type(self)((1 << self.size) - 1)
        return type(self)(self.as_sint() // other)

    @bv_cast
    def bvsrem(self, other):
        # NOTE(review): Python ``%`` takes the sign of the divisor (SMT-LIB
        # bvsmod semantics), while SMT-LIB bvsrem takes the dividend's sign.
        # Kept as-is because callers may rely on it -- confirm intent.
        other = other.as_sint()
        if other == 0:
            return self
        return type(self)(self.as_sint() % other)

    def __invert__(self):
        return self.bvnot()

    __and__ = dispatch_oper(bvand)
    __rand__ = dispatch_roper(__and__)

    __or__ = dispatch_oper(bvor)
    __ror__ = dispatch_roper(__or__)

    __xor__ = dispatch_oper(bvxor)
    __rxor__ = dispatch_roper(__xor__)

    __lshift__ = dispatch_oper(bvshl)
    __rlshift__ = dispatch_roper(__lshift__)

    __rshift__ = dispatch_oper(bvlshr)
    # Reflected operator: ``other >> self`` must swap the operands, so this
    # goes through dispatch_roper like every other r-operator (dispatch_oper
    # would compute ``self >> other``).
    __rrshift__ = dispatch_roper(__rshift__)

    def __neg__(self):
        return self.bvneg()

    __add__ = dispatch_oper(bvadd)
    __radd__ = dispatch_roper(__add__)

    __sub__ = dispatch_oper(bvsub)
    __rsub__ = dispatch_roper(__sub__)

    __mul__ = dispatch_oper(bvmul)
    __rmul__ = dispatch_roper(__mul__)

    __floordiv__ = dispatch_oper(bvudiv)
    __rfloordiv__ = dispatch_roper(__floordiv__)

    __mod__ = dispatch_oper(bvurem)
    __rmod__ = dispatch_roper(__mod__)

    __eq__ = dispatch_oper(bveq)
    __ne__ = dispatch_oper(bvne)
    __ge__ = dispatch_oper(bvuge)
    __gt__ = dispatch_oper(bvugt)
    __le__ = dispatch_oper(bvule)
    __lt__ = dispatch_oper(bvult)

    def as_uint(self):
        """Return the value as an unsigned int in ``[0, 2**size)``."""
        return self._value

    def as_sint(self):
        """Return the two's-complement signed value in ``[-2**(size-1), 2**(size-1))``."""
        value = self._value
        if self[-1]:
            value = value - (1 << self.size)
        return value

    as_int = as_sint

    def __int__(self):
        return self.as_uint()

    def __bool__(self):
        return bool(int(self))

    def binary_string(self):
        """Return the bits as a string, MSB first, without a prefix."""
        return "".join(str(int(i)) for i in reversed(self.bits()))

    def as_binary_string(self):
        return "0b" + self.binary_string()

    def bits(self):
        """Return the bits as a sequence, LSB first."""
        return int2seq(self._value, self.size)

    def as_bool_list(self):
        return [bool(x) for x in self.bits()]

    def repeat(self, r):
        """Concatenate ``r`` copies of this vector; ``r`` must be positive."""
        r = int(r)
        if r <= 0:
            raise ValueError()

        return type(self).unsized_t[r * self.size](r * self.bits())

    def sext(self, ext):
        """Sign-extend by ``ext`` copies of the MSB; ``ext`` must be >= 0.

        Raises:
            ValueError: if ``ext`` is negative.
        """
        ext = int(ext)
        if ext < 0:
            raise ValueError()

        T = type(self).unsized_t
        if ext == 0:
            # repeat() rejects a count of 0, so handle the no-op case here.
            # Return a fresh vector because BitVector is mutable (__setitem__).
            return T[self.size](self)
        return self.concat(T[1](self[-1]).repeat(ext))

    def ext(self, ext):
        """Default extension for unsigned vectors: zero-extension."""
        return self.zext(ext)

    def zext(self, ext):
        """Zero-extend by ``ext`` bits; ``ext`` must be >= 0.

        Raises:
            ValueError: if ``ext`` is negative.
        """
        ext = int(ext)
        if ext < 0:
            raise ValueError()

        T = type(self).unsized_t
        return self.concat(T[ext](0))

    @staticmethod
    def random(width):
        """Return a uniformly random ``BitVector[width]``."""
        return BitVector[width](random.randint(0, (1 << width) - 1))


class NumVector(BitVector):
    """Common base for the numeric (unsigned/signed) vector types."""
    __hash__ = BitVector.__hash__


class UIntVector(NumVector):
    """Unsigned integer vector; inherits the unsigned BitVector semantics."""
    __hash__ = NumVector.__hash__

    @staticmethod
    def random(width):
        """Return a uniformly random ``UIntVector[width]``."""
        return UIntVector[width](random.randint(0, (1 << width) - 1))


class SIntVector(NumVector):
    """Signed integer vector.

    Overrides ``int()``, ``>>``, ``//``, ``%``, the ordering comparisons and
    ``ext`` to use the signed (two's-complement) interpretation.
    """
    __hash__ = NumVector.__hash__

    def __int__(self):
        return self.as_sint()

    def __rshift__(self, other):
        # Signed right shift is arithmetic (sign-filling).
        try:
            return self.bvashr(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    __rrshift__ = dispatch_roper(__rshift__)

    def __floordiv__(self, other):
        try:
            return self.bvsdiv(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    __rfloordiv__ = dispatch_roper(__floordiv__)

    def __mod__(self, other):
        try:
            return self.bvsrem(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    __rmod__ = dispatch_roper(__mod__)

    def __ge__(self, other):
        try:
            return self.bvsge(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    def __gt__(self, other):
        try:
            return self.bvsgt(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    def __lt__(self, other):
        try:
            return self.bvslt(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    def __le__(self, other):
        try:
            return self.bvsle(other)
        except InconsistentSizeError as e:
            raise e from None
        except TypeError:
            return NotImplemented

    @staticmethod
    def random(width):
        """Return a uniformly random ``SIntVector[width]`` over the signed range."""
        w = width - 1
        return SIntVector[width](random.randint(-(1 << w), (1 << w) - 1))

    def ext(self, other):
        """Default extension for signed vectors: sign-extension."""
        return self.sext(other)


def overflow(a, b, res):
    """Return the signed-overflow flag for ``res``, given operands ``a`` and ``b``.

    The flag is set when ``a`` and ``b`` have the same sign bit but ``res``
    has the opposite one.  This is the standard overflow test for
    ``res = a + b``.
    """
    msb_a = a[-1]
    msb_b = b[-1]
    N = res[-1]
    return (msb_a & msb_b & ~N) | (~msb_a & ~msb_b & N)


# Returned by BitVector.get_family(); ops such as bveq build their results via
# self.get_family().Bit.
_Family_ = TypeFamily(Bit, BitVector, UIntVector, SIntVector)