# ├── README.md ├── testing.py ├── LICENSE ├── config.py ├── asymmetric.py ├── symmetric.py ├── bootstrappable.py └── polynomial.py
#
# /README.md:
# --------------------------------------------------------------------------------
# # morphine
# Homomorphic Encryption Library in Python
#
# ### Introduction
# This implementation of homomorphic encryption is based on [Computing Arbitrary Functions of Encrypted Data](https://crypto.stanford.edu/craig/easy-fhe.pdf) by [Craig Gentry](https://www.macfound.org/fellows/914/), published in [Communications of the ACM, Vol. 53 No. 3](http://cacm.acm.org/magazines/2010/3/76272-computing-arbitrary-functions-of-encrypted-data/abstract) in March 2010.
#
# --------------------------------------------------------------------------------
# /testing.py:
# --------------------------------------------------------------------------------
# Python modules
import time
from datetime import timedelta


def consistency(func, args, expected, n=10**4):
    """
    Analyze and report on the consistency of a function.

    Calls func(*args) n times and counts how many results equal 'expected'. Progress is
    printed in place (at most once every 0.01 s) and a final summary line is printed at the
    end, in the form "elapsed|fraction_done: correct/total = ratio".

    The header line is func's docstring formatted with *args (so it may contain '{0}'-style
    placeholders); the function's name is used instead when it has no docstring.

    Returns the number of runs whose result equalled 'expected'.
    """
    description = (func.__doc__ or getattr(func, '__name__', repr(func))).format(*args)
    print('\n[CONSISTENCY TEST] {0}'.format(description))

    def show(num, den, t, p, end='\r'):
        # den is 0 for the final summary when n == 0; report a ratio of 0 instead of dividing by zero
        ratio = num/den if den else 0
        print('{3}|{4:.3f}: {0}/{1} = {2}'.format(num, den, ratio, str(timedelta(seconds=t)), p), end=end)

    start = time.time()
    last_report = start
    tally = 0
    for i in range(n):
        if func(*args) == expected:
            tally += 1
        now = time.time()
        # Throttle progress output to at most one update every 0.01 s
        if now - last_report > 0.01:
            last_report = now
            show(tally, i + 1, now - start, (i + 1)/n)
    # All n runs are done, so progress is 1.0. This also avoids reading the loop variable,
    # which is unbound when n == 0.
    show(tally, n, time.time() - start, 1.0, '\n')
    return tally


def max_over(n, func, args=None):
    """
    Compute the maximum value returned by func(*args) in n runs.

    func is called without arguments when args is None or empty. Returns 0 when n <= 0.
    """
    call_args = args or ()
    return max((func(*call_args) for _ in range(n)), default=0)
# --------------------------------------------------------------------------------
# /LICENSE:
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Daniel Augusto Rizzi Salvadori 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
#
# --------------------------------------------------------------------------------
# /config.py:
# --------------------------------------------------------------------------------
# Python modules
from math import log2, floor


SECURITY_PARAMETER = 4

# Bit lengths used by the symmetric scheme: N for the noise 'm' (default mbits of
# symmetric.encrypt), P for the secret key (symmetric.keygen) and Q for the random
# multiplier 'q' (symmetric.encrypt; also the Decimal precision in bootstrappable.py).
N = SECURITY_PARAMETER
P = SECURITY_PARAMETER**2
Q = SECURITY_PARAMETER**5


# Weight of the hidden subset and size of the SSP set published by bootstrappable.keygen
HINT_SUBSET_SIZE = floor(SECURITY_PARAMETER/log2(SECURITY_PARAMETER))  # ALPHA
HINT_SIZE = 2 * HINT_SUBSET_SIZE

PK_SIZE = 15  # This is the number of encryptions of zero in the public key

# Need to tweak the following 2 parameters such that the maximum noise of an asymmetric cyphertext is <= that of a symmetric one
PK_M_BIT_LENGTH = 3
PK_SUBSET_SIZE = 4  # 5 almost works with 2 mults
# TODO: work through the math (instead of finding values empirically).
# When computing asymmetric encryption we perform (PK_SUBSET_SIZE + 1) sums of elements
# whose noise is at most PK_M_BIT_LENGTH bits long.
# The upper bound on the bit length of the noise of the sum is then (PK_M_BIT_LENGTH + PK_SUBSET_SIZE + 1)
# testUpperBound = (PK_M_BIT_LENGTH + PK_SUBSET_SIZE -2)#PK_M_BIT_LENGTH + PK_SUBSET_SIZE + 1
# print('asymmetric upperbound for noise (test)', testUpperBound) # this last value should be <= N
# print(2**N, 2**testUpperBound)

if __name__ == '__main__':
    print('N, P, Q:', N, P, Q)
    print('HINT_SUBSET_SIZE:', HINT_SUBSET_SIZE)
    print('HINT_SIZE:', HINT_SIZE)
# --------------------------------------------------------------------------------
# /asymmetric.py:
# --------------------------------------------------------------------------------
# Project modules
import symmetric
from config import *

# Python modules
import numpy as np


def keygen(kt=symmetric.KeyType.RANDOM):
    """
    Generate private and public keys.

    The private key is a regular symmetric key. The public key is a list of PK_SIZE symmetric
    encryptions of zero whose noise is PK_M_BIT_LENGTH bits long. Returns (sk, pk).
    """

    # Generate regular secret key
    sk = symmetric.keygen(kt)

    # Generate a set of encryptions of zero
    pk = [symmetric.encrypt(sk, 0, PK_M_BIT_LENGTH) for _ in range(PK_SIZE)]

    return (sk, pk)


def encrypt(pk, b):
    """
    Encrypt a bit b into an integer based on a public key.

    The cyphertext is b plus the sum of PK_SUBSET_SIZE distinct, randomly chosen encryptions
    of zero taken from pk. Raises ValueError (from numpy) if pk has fewer than
    PK_SUBSET_SIZE entries.
    """

    # Choose a random subset of the pk's encrypted zeros; the index range comes from pk
    # itself rather than assuming it holds exactly PK_SIZE entries
    idxs = np.random.choice(len(pk), PK_SUBSET_SIZE, replace=False)

    # Sum subset of encrypted zeros along with the bit to be encrypted
    cypher = b + sum(pk[i] for i in idxs)

    return cypher


def decrypt(sk, c):
    """Decrypt using the normal symmetric algorithm."""
    return symmetric.decrypt(sk, c)


if __name__ == '__main__':

    sk, pk = keygen()
    assert decrypt(sk, encrypt(pk, 0)) == 0
    assert decrypt(sk, encrypt(pk, 1)) == 1

    from testing import consistency

    T = 10**5

    def keygenEncryptDecrypt(b):
        """asymmetric keygen-encrypt-decrypt (bit = {0})"""
        sk, pk = keygen()
        c = encrypt(pk, b)
        return decrypt(sk, c)

    consistency(keygenEncryptDecrypt, (0,), 0, T)
    consistency(keygenEncryptDecrypt, (1,), 1, T)

    def keygenEncryptSum(b1, b2):
        """asymmetric keygen-encrypt-sum ({0}+{1})"""
        sk, pk = keygen()
        c1 = encrypt(pk, b1)
        c2 = encrypt(pk, b2)
        return decrypt(sk, c1 + c2)

    consistency(keygenEncryptSum, (0,0), 0, T)
    consistency(keygenEncryptSum, (0,1), 1, T)
    consistency(keygenEncryptSum, (1,1), 0, T)

    def keygenEncryptMult(b1, b2):
        """asymmetric keygen-encrypt-mult ({0}*{1})"""
        sk, pk = keygen()
        c1 = encrypt(pk, b1)
        c2 = encrypt(pk, b2)
        return decrypt(sk, c1 * c2)

    consistency(keygenEncryptMult, (0,0), 0, T)
    consistency(keygenEncryptMult, (0,1), 0, T)
    consistency(keygenEncryptMult, (1,1), 1, T)

# --------------------------------------------------------------------------------
# /symmetric.py:
# --------------------------------------------------------------------------------
# Project modules
from config import *

# Python modules
import random


class KeyType:
    SMALLEST = 0
    RANDOM = 1
    LARGEST = 2


def keygen(kt=KeyType.RANDOM):
    """
    Generate a key, i.e. an odd integer between 2^(P-1) and 2^P.

    kt selects the smallest possible key, the largest possible key, or a random one.
    Raises ValueError for any other kt.
    """
    lowerBound = 2**(P-2)
    higherBound = 2**(P-1) - 1
    if kt == KeyType.SMALLEST:
        half = lowerBound
    elif kt == KeyType.LARGEST:
        half = higherBound
    elif kt == KeyType.RANDOM:
        half = random.randint(lowerBound, higherBound)
    else:
        raise ValueError('Unknown key type: {0!r}'.format(kt))
    # NOTE: If the lower bound were 0 (i.e. the odd integer were truly random in P), the secret key
    # could be smaller than the N-bit integer 'm' generated during encryption, in which case decryption would fail.
    return (half << 1) + 1
def _sample_m_q(b, mbits):
    """Draw the noise m (an mbits-bit integer with the parity of b) and the Q-bit multiplier q."""
    # Random mbits-bit integer with the same parity as b
    m = (random.randint(2**(mbits-2), 2**(mbits-1) - 1) << 1) + b

    # Random Q-bit integer, i.e. in [2^(Q-1), 2^Q - 1]
    q = random.randint(2**(Q-1), 2**Q - 1)

    return m, q


def encrypt(sk, b, mbits=N):
    """
    Encrypt a bit with the provided key.

    The cyphertext is m + sk*q, where m (the noise) is a random mbits-bit integer with the
    same parity as b and q is a random Q-bit integer. mbits must be at least 2, otherwise the
    range for m is not a valid integer range.
    """
    m, q = _sample_m_q(b, mbits)
    return m + sk*q


def encryptD(sk, b, mbits=N):
    """Same as encrypt except it also returns q, to allow exact computation of the noise (see noiseQ)."""
    m, q = _sample_m_q(b, mbits)
    return ((m + sk*q), q)


def decrypt(sk, c):
    """
    Decrypt a cyphertext based on the provided key: the parity of (c mod sk).

    Only correct while the noise is non-negative and strictly smaller than sk.
    """
    return (c % sk) % 2


def noise(sk, c):
    """Get the noise of a cyphertext based on the provided key (c mod sk, so it wraps once the noise reaches sk)."""
    return (c % sk)


def noiseQ(sk, c, q):
    """Get the exact noise c - sk*q of a cyphertext, given the multiplier q it was built with."""
    return (c - sk*q)


if __name__ == '__main__':

    key = keygen()
    assert decrypt(key, encrypt(key, 0)) == 0
    assert decrypt(key, encrypt(key, 1)) == 1

    # Print smallest and largest key sizes
    sk = keygen(KeyType.SMALLEST)
    print('Smallest Key:', sk)
    print('Largest Key:', keygen(KeyType.LARGEST), '\n')

    # Calculate Max Multiplicative Depth. Decryption is only correct while the noise is
    # strictly smaller than the key, hence the '>=' checks below.
    print('Calculating max multiplicative depth...')
    c, q = encryptD(sk, 0)
    mmd = -1
    for i in range(10):
        m = decrypt(sk, c)
        cumulative_noise = noiseQ(sk, c, q)
        print("[i = {0}] noise = {1} | b = {2}".format(i, cumulative_noise, m))
        if cumulative_noise >= sk:
            break
        mmd += 1
        (c2, q2) = encryptD(sk, 0)
        # (m1 + sk*q1)*(m2 + sk*q2) = m1*m2 + sk*(m1*q2 + m2*q1 + sk*q1*q2), using the exact
        # noises m1 and m2 (equal to c mod sk while they are below the key)
        q = cumulative_noise*q2 + noiseQ(sk, c2, q2)*q + sk*q*q2
        c *= c2
    print('Max Multiplicative Depth:', mmd, '\n')

    # Calculate Max Additive Depth
    print('Calculating max additive depth...')
    mads = []
    for j in range(30):
        c, q = encryptD(sk, 0)
        mad = -1
        for i in range(10000):
            if noiseQ(sk, c, q) >= sk:
                break
            mad += 1
            (c2, q2) = encryptD(sk, 0)
            q += q2
            c += c2
        mads.append(mad)
    print('Max Additive Depth:', min(mads))

    from testing import consistency
    T = 10**5

    def keygenEncryptDecrypt(b):
        """symmetric keygen-encrypt-decrypt (bit = {0})"""
        key = keygen()
        c = encrypt(key, b)
        return decrypt(key, c)

    consistency(keygenEncryptDecrypt, (0,), 0, T)
    consistency(keygenEncryptDecrypt, (1,), 1, T)

    def keygenEncryptSum(b1, b2):
        """symmetric keygen-encrypt-sum ({0}+{1})"""
        key = keygen()
        c1 = encrypt(key, b1)
        c2 = encrypt(key, b2)
        return decrypt(key, c1 + c2)

    consistency(keygenEncryptSum, (0,0), 0, T)
    consistency(keygenEncryptSum, (0,1), 1, T)
    consistency(keygenEncryptSum, (1,1), 0, T)

    def keygenEncryptMult(b1, b2):
        """symmetric keygen-encrypt-mult ({0}*{1})"""
        key = keygen()
        c1 = encrypt(key, b1)
        c2 = encrypt(key, b2)
        return decrypt(key, c1 * c2)

    consistency(keygenEncryptMult, (0,0), 0, T)
    consistency(keygenEncryptMult, (0,1), 0, T)
    consistency(keygenEncryptMult, (1,1), 1, T)

# --------------------------------------------------------------------------------
# /bootstrappable.py:
# --------------------------------------------------------------------------------
# Project modules
import asymmetric
from config import *

# Python modules
from decimal import *
import random
import numpy as np

getcontext().prec = Q


def generate_ssp_set(size, subset_size, sumto):
    """
    Generate random set of Decimals such that a subset sums to a given value.
    https://en.wikipedia.org/wiki/Subset_sum_problem

    Returns (subset, ssp_set): 'subset' is a 0/1 list of length 'size' with exactly
    'subset_size' ones marking the hidden subset, and 'ssp_set' is the list of Decimals.
    """

    # Generate set with entries uniformly distributed between 0 and 2*average
    average = sumto / Decimal(size)
    ssp_set = [Decimal(random.random()) * 2 * average for _ in range(size)]

    # Choose a random subset
    idxs = np.random.choice(np.arange(size), subset_size, replace=False)

    # Adjust values so the chosen subset sums to 'sumto'. The same offset is added to every
    # entry, not just the subset; only the subset's sum is constrained, so entries may end up
    # outside [0, 2*average].
    sumto_initial = sum(ssp_set[i] for i in idxs)
    divdiff = (sumto - sumto_initial) / Decimal(subset_size)
    ssp_set = [v + divdiff for v in ssp_set]

    # Return the choice of subset encoded as a binary array with Hamming weight of 'subset_size'
    chosen = set(idxs.tolist())
    subset = [1 if i in chosen else 0 for i in range(size)]

    return (subset, ssp_set)


def keygen():
    """
    Generate private and public keys for the "greased", bootstrappable encryption scheme.

    The main idea here is that we want to minimize computation during decryption because it originally contains
    an expensive modulus operation (c % p) that is too much for our somewhat homomorphic encryption system to handle
    in a single recrypt. We would like to substitute that expensive operation by a small summation. To do that we
    include a set with the public key that contains a hidden subset that sums to 1/p. We call that the SSP set.
    The secret key then becomes the binary vector encoding that subset. The SSP set is provided as part of the public
    key, along with the set of encryptions of zero.

    Returns (sk, pk_z, pk_y): the subset vector, the encryptions of zero and the SSP set.
    """

    # Generate a regular key pair
    p, pk_z = asymmetric.keygen()

    # Generate a set of rational numbers such that a hidden subset sums to 1/p
    hint = (Decimal(1)/Decimal(p))
    sk, pk_y = generate_ssp_set(HINT_SIZE, HINT_SUBSET_SIZE, hint)

    return (sk, pk_z, pk_y)


def encrypt(pk_z, pk_y, b):
    """Encrypt a bit b into the pair (c, cy): the asymmetric cyphertext and its post-processed SSP set."""

    # Perform regular asymmetric encryption
    c = asymmetric.encrypt(pk_z, b)

    # Post process the cypher text to generate the appropriate 1/p SSP set
    cy = post_process(c, pk_y)

    return (c, cy)


def post_process(c, pk_y):
    """
    Post-process the cyphertext with pk_y to simplify computation during decryption.

    We compute and return the 1/p SSP set for the provided cyphertext before decryption takes place.
    This is one of the key steps in making our homomorphic encryption scheme bootstrappable.
    """
    return [Decimal(c) * Decimal(y) for y in pk_y]


def decrypt(sk, c, cy):
    """
    Decrypt a bit (c, cy) based on sk.

    The entries of cy selected by sk sum to (approximately) c/p, so rounding yields the
    quotient k in c = k*p + noise. Since p is odd, the parity of the noise is
    lsb(c) XOR lsb(k).
    """

    # Decode/compute hidden subset sum given secret key (sized from sk/cy, not HINT_SIZE)
    x = round(sum(y for bit, y in zip(sk, cy) if bit > 0))

    lsb_x = x & 1
    lsb_c = c & 1

    # XOR the least significant bits together
    return lsb_c ^ lsb_x


if __name__ == '__main__':

    sk, ssp = generate_ssp_set(5, 2, Decimal(3.14159265))
    assert float(sum([ssp[i] if sk[i] else 0 for i in range(5)])) == 3.14159265

    sk, pk_z, pk_y = keygen()
    assert decrypt(sk, *encrypt(pk_z, pk_y, 0)) == 0
    assert decrypt(sk, *encrypt(pk_z, pk_y, 1)) == 1

    from testing import consistency

    T = 10**5

    def keygenEncryptDecrypt(b):
        """bootstrappable keygen-encrypt-decrypt (bit = {0})"""
        sk, pk_z, pk_y = keygen()
        c, cy = encrypt(pk_z, pk_y, b)
        return decrypt(sk, c, cy)

    consistency(keygenEncryptDecrypt, (0,), 0, T)
    consistency(keygenEncryptDecrypt, (1,), 1, T)

    def keygenEncryptSum(b1, b2):
        """bootstrappable keygen-encrypt-sum ({0}+{1})"""
        sk, pk_z, pk_y = keygen()
        c1, cy1 = encrypt(pk_z, pk_y, b1)
        c2, cy2 = encrypt(pk_z, pk_y, b2)
        s = c1 + c2
        sy = post_process(s, pk_y)
        return decrypt(sk, s, sy)

    consistency(keygenEncryptSum, (0,0), 0, T)
    consistency(keygenEncryptSum, (0,1), 1, T)
    consistency(keygenEncryptSum, (1,1), 0, T)

    def keygenEncryptMult(b1, b2):
        """bootstrappable keygen-encrypt-mult ({0}*{1})"""
        sk, pk_z, pk_y = keygen()
        c1, cy1 = encrypt(pk_z, pk_y, b1)
        c2, cy2 = encrypt(pk_z, pk_y, b2)
        m = c1 * c2
        my = post_process(m, pk_y)
        return decrypt(sk, m, my)

    consistency(keygenEncryptMult, (0,0), 0, T)
    consistency(keygenEncryptMult, (0,1), 0, T)
    consistency(keygenEncryptMult, (1,1), 1, T)

# --------------------------------------------------------------------------------
# /polynomial.py:
# --------------------------------------------------------------------------------
# Python modules
from itertools import combinations, product
from collections import Counter


def _sort_key(item):
    """
    Total ordering key for polynomial terms.

    Leaves (e.g. ints) sort before polynomials. Polynomials sort by type-name length, type
    name and number of distinct terms, as Polynomial.__lt__ does, and then by their own
    sorted terms, which breaks __lt__'s ties. Equal polynomials therefore always list (and
    hash) their terms in the same order.
    """
    if isinstance(item, Polynomial):
        name = type(item).__name__
        return (1, len(name), name, len(item.values), tuple(_sort_key(e) for e in item.elements()))
    return (0, item)


class Polynomial:
    """
    Multiset of terms combined by a single operation (see Sum and Prod).

    Terms are leaves (looked up as values[leaf] by evaluate) or nested polynomials.
    """

    def __init__(self, *args):
        self.values = Counter(args)

    def __iter__(self):
        return iter(self.elements())

    def __getitem__(self, key):
        return self.elements()[key]

    def __str__(self):
        return type(self).__name__ + str(self.elements())

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        # NOTE: this hashes the unflattened structure while __eq__ compares flattened forms, so
        # two polynomials can compare equal yet hash differently (e.g. Sum(Prod(1)) and Sum(1)).
        # normalize() relies on this structural hash to detect its fixed point.
        return hash(type(self).__name__) + hash(tuple(self.elements()))

    def __lt__(self, other):
        if not isinstance(other, Polynomial):
            # A polynomial never sorts before a leaf (e.g. an int)
            return False
        if type(self).__name__ == type(other).__name__:
            return len(self.values.keys()) < len(other.values.keys())
        return len(type(self).__name__) < len(type(other).__name__)

    def __gt__(self, other):
        return not self.__lt__(other)

    def __eq__(self, other):
        # Equality of the flattened (sum-of-products) forms, compared through their hashes
        return hash(flatten(self)) == hash(flatten(other))

    def __ne__(self, other):
        return not self.__eq__(other)

    def elements(self):
        # Use a total order so the same multiset of terms always comes out in the same order
        return sorted(self.values.elements(), key=_sort_key)

    def eval(self, values):
        return evaluate(values, self)


class Sum(Polynomial):
    pass


class Prod(Polynomial):
    pass


def evaluate(values: list, poly: Polynomial):
    """Evaluate a polynomial, replacing every leaf by values[leaf]."""
    pt = type(poly)
    if pt == Sum:
        return sum(evaluate(values, i) for i in poly)
    elif pt == Prod:
        res = 1
        for item in poly:
            res *= evaluate(values, item)
        return res
    else:
        return values[poly]


def elementary_symmetric(size: int, degree: int):
    """
    Return the elementary symmetric polynomial of specified size and degree.

    Variables are the indices 0..size-1. A single term is returned unwrapped (a bare index
    for degree 1 with size 1, a Prod when degree == size). Raises ValueError unless
    1 <= degree <= size.
    """
    if not 1 <= degree <= size:
        raise ValueError('degree must be between 1 and size (got size={0}, degree={1})'.format(size, degree))
    products = [Prod(*combo) if len(combo) > 1 else combo[0] for combo in combinations(range(size), degree)]
    return Sum(*products) if len(products) > 1 else products[0]


def normalize(poly: Polynomial):
    """Unroll equivalent nested operations, repeating until the structure stops changing."""
    new = _normalize(poly)
    while hash(poly) != hash(new):
        poly = new
        new = _normalize(poly)
    return poly


def _normalize(poly: Polynomial):
    """Unroll equivalent nested operations once."""
    pt = type(poly)
    if not issubclass(pt, Polynomial):
        return poly
    # A single-term Sum/Prod is equivalent to that term
    if len(poly.elements()) == 1:
        return _normalize(poly[0])
    new_elements = []
    for p in poly:
        # Normalize the child first: a single-term child unwraps to its (possibly non-iterable
        # or differently-typed) term, which must then be merged only if it has our type
        q = _normalize(p)
        if type(q) == pt:
            new_elements.extend(q)
        else:
            new_elements.append(q)
    return pt(*new_elements)


def flatten(poly: Polynomial):
    """Flatten and normalize a polynomial to the minimal depth."""
    return normalize(_flatten(poly))


def _summands(poly):
    """Return the summands of poly, descending into nested Sums; a non-Sum is its own single summand."""
    if type(poly) != Sum:
        return [poly]
    terms = []
    for p in poly:
        terms.extend(_summands(p))
    return terms


def _flatten(poly: Polynomial):
    """Flatten a polynomial into a sum of products (unnormalized)."""
    pt = type(poly)
    if pt == Prod:
        # Flatten every factor, then distribute: each combination that picks one summand
        # per factor becomes one product term
        factor_terms = [_summands(_flatten(factor)) for factor in poly]
        elements = list(product(*factor_terms))
        if len(elements) == 1:
            return Prod(*elements[0])
        return Sum(*[Prod(*i) for i in elements])
    if pt == Sum:
        return Sum(*[_flatten(i) for i in poly])
    return poly


if __name__ == '__main__':

    assert Sum(1, 2) == Sum(1, 2)
    assert Sum(1, 1) != Sum(1)
    assert Sum(1, 2) == Sum(2, 1)
    assert Sum(1, 2) != frozenset([1, 2])
    assert Sum(1, 2) != Sum(1, 3)
    assert Prod(3, 6) == Prod(6, 3)
    assert Sum(10, 20) != Prod(10, 20)
    assert Sum(3, 2, Prod(21, 4)) == Sum(Prod(4, 21), 3, 2)

    assert hash(Sum(1, 2)) == hash(Sum(1, 2))
    assert hash(Sum(1, 1)) != hash(Sum(1))
    assert hash(Sum(1, 2)) == hash(Sum(2, 1))
    assert hash(Sum(1, 2)) != hash(frozenset([1, 2]))
    assert hash(Sum(1, 2)) != hash(Sum(1, 3))
    assert hash(Prod(3, 6)) == hash(Prod(6, 3))
    assert hash(Sum(10, 20)) != hash(Prod(10, 20))
    assert hash(Sum(3, 2, Prod(21, 4))) == hash(Sum(Prod(4, 21), 3, 2))

    assert evaluate([10, 20, 30, 40], Sum(Prod(0, 1), Prod(2, 3))) == 1400
    assert evaluate([10, 20, 30, 40], Sum(Prod(0, 1, 2), Prod(3))) == 6040
    assert evaluate([1, 2, 3], Sum(Prod(Sum(0, 1), 2), Sum(0, 1), Prod(2), 0)) == 16
    assert Sum(Prod(Sum(0, 1), 2), Sum(0, 1), Prod(2), 0).eval([1, 2, 3]) == 16

    assert elementary_symmetric(4, 1) == Sum(0, 1, 2, 3)
    assert elementary_symmetric(4, 2) == Sum(Prod(0, 1), Prod(0, 2), Prod(0, 3), Prod(1, 2), Prod(1, 3), Prod(2, 3))
    assert elementary_symmetric(4, 3) == Sum(Prod(0, 1, 2), Prod(0, 1, 3), Prod(0, 2, 3), Prod(1, 2, 3))
    assert elementary_symmetric(4, 4) == Prod(0, 1, 2, 3)

    # Ordering of nested polynomials must not affect equality or hashing
    assert Sum(Prod(0, 1), Prod(2, 3)) == Sum(Prod(2, 3), Prod(0, 1))
    assert hash(Sum(Prod(0, 1), Prod(2, 3))) == hash(Sum(Prod(2, 3), Prod(0, 1)))

    assert normalize(Sum(Sum(0, Sum(1, 2)), Sum(0, 1))) == Sum(0, 0, 1, 1, 2)
    assert normalize(Sum(Sum(0, 1), Sum(0, 1, 3, 1))) == Sum(0, 0, 1, 1, 1, 3)
    assert normalize(Prod(Prod(0, 1), Prod(0, 1))) == Prod(0, 0, 1, 1)
    assert normalize(Prod(Sum(3), Prod(Sum(1, 2), 5, 0, 4))) == Prod(0, 3, 4, 5, Sum(1, 2))
    assert normalize(Prod(Sum(0, Prod(1, 2)), Sum(0, 1))) == Prod(Sum(0, Prod(1, 2)), Sum(0, 1))
    assert normalize(Prod(Sum(1, 2))) == Sum(1, 2)
    assert normalize(Sum(1, Sum(2))) == Sum(1, 2)
    assert normalize(Sum(1, Sum(Prod(2, 3)))) == Sum(1, Prod(2, 3))

    # Test that normalizing once is equal to normalizing any number of times
    assert hash(normalize(Sum(Prod(0, Sum(Prod(1, 2))), Prod(1, Sum(Prod(1, 2)))))) == hash(normalize(normalize(Sum(Prod(0, Sum(Prod(1, 2))), Prod(1, Sum(Prod(1, 2)))))))

    assert Sum(Prod(1)) == Sum(1)
    assert Sum(Prod(Sum(1))) == Prod(1)

    assert Prod(0, Prod(1, 2)) == Prod(0, 1, 2)
    assert normalize(Prod(0, Prod(1, 2))) == Prod(0, 1, 2)

    nested = Prod(Sum(0, Prod(1, 2)), Sum(0, 1))
    assert flatten(nested) == Sum(Prod(0, 0), Prod(0, 1), Prod(0, 1, 2), Prod(1, 1, 2))
    assert normalize(normalize(nested)) == nested
    assert flatten(flatten(nested)) == nested

    assert flatten(Prod(0, Sum(0,1))) == Sum(Prod(0, 0), Prod(0, 1))
    assert flatten(Sum(Sum(0, Sum(1, 2)), Sum(0, 1))) == Sum(0, 0, 1, 1, 2)
    assert flatten(Prod(Sum(0, Prod(1, 2)), Sum(0, 1))) == Sum(Prod(0, 0), Prod(0, 1), Prod(0, 1, 2), Prod(1, 1, 2))
    assert flatten(Prod(2, Sum(0, 1))) == Sum(Prod(0, 2), Prod(1, 2))
    assert flatten(Prod(0, Prod(1, Sum(2, 3)))) == Sum(Prod(0, 1, 2), Prod(0, 1, 3))
--------------------------------------------------------------------------------