├── README.md ├── gmssl ├── __init__.py ├── func.py ├── optimized_curve.py ├── optimized_field_elements.py ├── optimized_pairing.py ├── sm2.py ├── sm3.py ├── sm4.py └── sm9.py └── tests ├── test_sm2.py ├── test_sm3.py ├── test_sm4.py └── test_sm9.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | Original: https://github.com/duanhongyi/gmssl 3 | 4 | Now includes SM9. 5 | 6 | GMSSL 7 | ======== 8 | GmSSL是一个开源的加密包的python实现,支持SM2/SM3/SM4/SM9等国密(国家商用密码)算法。项目采用对商业应用友好的类BSD开源许可证,开源且可以用于闭源的商业应用。 9 | 10 | ### Setup and Test 11 | 12 | ``` 13 | export PYTHONPATH=/path/to/gmssl:$PYTHONPATH 14 | ``` 15 | 16 | Replace /path/to/gmssl with the directory that contains the `gmssl` package (the repository root). Then run: 17 | 18 | ``` 19 | python3 tests/test_sm2.py 20 | python3 tests/test_sm3.py 21 | python3 tests/test_sm4.py 22 | python3 tests/test_sm9.py 23 | ``` 24 | 25 | If you run these commands from another directory, replace `tests` with the path to the tests directory. 26 | 27 | ### SM2算法 28 | RSA算法的危机在于其存在亚指数算法,对ECC算法而言一般没有亚指数攻击算法。 29 | SM2椭圆曲线公钥密码算法:我国自主知识产权的商用密码算法,是ECC(Elliptic Curve Cryptosystem)算法的一种,基于椭圆曲线离散对数问题,计算复杂度是指数级,求解难度较大,同等安全程度要求下,椭圆曲线密码较其他公钥算法所需密钥长度小很多。 30 | 31 | gmssl是包含国密SM2算法的Python实现, 提供了 `encrypt`、 `decrypt`、 `sign`、 `verify`等函数用于加密解密和签名验签, 用法如下: 32 | 33 | #### 1. 初始化`CryptSM2` 34 | 35 | ```python 36 | import base64 37 | import binascii 38 | from gmssl import sm2, func 39 | #16进制的公钥和私钥 40 | private_key = '00B9AB0B828FF68872F21A837FC303668428DEA11DCD1B24429D0C99E24EED83D5' 41 | public_key = 'B9C9A6E04E9C91F7BA880429273747D7EF5DDEB0BB2FF6317EB00BEF331A83081A6994B8993F3F5D6EADDDB81872266C87C018FB4162F5AF347B483E24620207' 42 | sm2_crypt = sm2.CryptSM2( 43 | public_key=public_key, private_key=private_key) 44 | ``` 45 | 46 | #### 2. `encrypt`和`decrypt` 47 | 48 | ```python 49 | #数据和加密后数据为bytes类型 50 | data = b"111" 51 | enc_data = sm2_crypt.encrypt(data) 52 | dec_data = sm2_crypt.decrypt(enc_data) 53 | assert dec_data == data 54 | ``` 55 | 56 | #### 3. 
`sign`和`verify` 57 | ```python 58 | data = b"111" # bytes类型 59 | random_hex_str = func.random_hex(sm2_crypt.para_len) 60 | sign = sm2_crypt.sign(data, random_hex_str) # 16进制 61 | assert sm2_crypt.verify(sign, data) # 16进制 62 | ``` 63 | 64 | ### SM4算法 65 | 66 | 国密SM4(无线局域网SMS4)算法, 一个分组算法, 分组长度为128bit, 密钥长度为128bit, 67 | 算法具体内容参照[SM4算法](https://drive.google.com/file/d/0B0o25hRlUdXcbzdjT0hrYkkwUjg/view?usp=sharing)。 68 | 69 | gmssl是包含国密SM4算法的Python实现, 提供了 `crypt_ecb`(ECB模式)、 `crypt_cbc`(CBC模式)等函数用于加密解密, 70 | 加密还是解密由 `set_key` 的模式参数 `SM4_ENCRYPT`/`SM4_DECRYPT` 决定, 用法如下: 71 | 72 | #### 1. 初始化`CryptSM4` 73 | 74 | ```python 75 | from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT 76 | 77 | key = b'3l5butlj26hvv313' 78 | value = b'111' # bytes类型 79 | iv = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # bytes类型 80 | crypt_sm4 = CryptSM4() 81 | ``` 82 | 83 | #### 2. ECB模式加密和解密(`crypt_ecb`) 84 | 85 | ```python 86 | 87 | crypt_sm4.set_key(key, SM4_ENCRYPT) 88 | encrypt_value = crypt_sm4.crypt_ecb(value) # bytes类型 89 | crypt_sm4.set_key(key, SM4_DECRYPT) 90 | decrypt_value = crypt_sm4.crypt_ecb(encrypt_value) # bytes类型 91 | assert value == decrypt_value 92 | 93 | ``` 94 | 95 | #### 3. CBC模式加密和解密(`crypt_cbc`) 96 | 97 | ```python 98 | 99 | crypt_sm4.set_key(key, SM4_ENCRYPT) 100 | encrypt_value = crypt_sm4.crypt_cbc(iv , value) # bytes类型 101 | crypt_sm4.set_key(key, SM4_DECRYPT) 102 | decrypt_value = crypt_sm4.crypt_cbc(iv , encrypt_value) # bytes类型 103 | assert value == decrypt_value 104 | 105 | ``` 106 | 107 | ### SM9算法 108 | 109 | #### 1. `sign`和`verify` 110 | 111 | ```python 112 | from gmssl import sm9 113 | idA = 'A' 114 | idB = 'B' 115 | master_public, master_secret = sm9.setup ('sign') 116 | Da = sm9.private_key_extract ('sign', master_public, master_secret, idA) 117 | message = 'abc' 118 | signature = sm9.sign (master_public, Da, message) 119 | assert (sm9.verify (master_public, idA, message, signature)) 120 | 121 | ``` 122 | 123 | #### 2. 
`key agreement` 124 | 125 | ```python 126 | 127 | idA = 'A' 128 | idB = 'B' 129 | master_public, master_secret = sm9.setup ('keyagreement') 130 | Da = sm9.private_key_extract ('keyagreement', master_public, master_secret, idA) 131 | Db = sm9.private_key_extract ('keyagreement', master_public, master_secret, idB) 132 | xa, Ra = sm9.generate_ephemeral (master_public, idB) 133 | xb, Rb = sm9.generate_ephemeral (master_public, idA) 134 | ska = sm9.generate_session_key (idA, idB, Ra, Rb, Da, xa, master_public, 'A', 128) 135 | skb = sm9.generate_session_key (idA, idB, Ra, Rb, Db, xb, master_public, 'B', 128) 136 | assert (ska == skb) 137 | 138 | ``` 139 | 140 | #### 3. `encrypt`和`decrypt` 141 | 142 | ```python 143 | 144 | idA = 'A' 145 | master_public, master_secret = sm9.setup ('encrypt') 146 | Da = sm9.private_key_extract ('encrypt', master_public, master_secret, idA) 147 | message = 'abc' 148 | ct = sm9.kem_dem_enc (master_public, idA, message, 32) 149 | pt = sm9.kem_dem_dec (master_public, idA, Da, ct, 32) 150 | assert (message == pt) 151 | 152 | 153 | -------------------------------------------------------------------------------- /gmssl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gongxian-ding/gmssl-python/498ba0545a3f7667ab4575e93cd10e9b19baf4f0/gmssl/__init__.py -------------------------------------------------------------------------------- /gmssl/func.py: -------------------------------------------------------------------------------- 1 | from random import choice 2 | 3 | 4 | xor = lambda a, b:list(map(lambda x, y: x ^ y, a, b)) 5 | 6 | rotl = lambda x, n:((x << n) & 0xffffffff) | ((x >> (32 - n)) & 0xffffffff) 7 | 8 | get_uint32_be = lambda key_data:((key_data[0] << 24) | (key_data[1] << 16) | (key_data[2] << 8) | (key_data[3])) 9 | 10 | put_uint32_be = lambda n:[((n>>24)&0xff), ((n>>16)&0xff), ((n>>8)&0xff), ((n)&0xff)] 11 | 12 | padding = lambda data, block=16: data + [(16 - 
len(data) % block)for _ in range(16 - len(data) % block)] 13 | 14 | unpadding = lambda data: data[:-data[-1]] 15 | 16 | list_to_bytes = lambda data: b''.join([bytes((i,)) for i in data]) 17 | 18 | bytes_to_list = lambda data: [i for i in data] 19 | 20 | random_hex = lambda x: ''.join([choice('0123456789abcdef') for _ in range(x)]) 21 | -------------------------------------------------------------------------------- /gmssl/optimized_curve.py: -------------------------------------------------------------------------------- 1 | from gmssl.optimized_field_elements import FQ2, FQ12, field_modulus, FQ 2 | 3 | curve_order = 21888242871839275222246405745257275088548364400416034343698204186575808495617 4 | 5 | # Curve order should be prime 6 | assert pow(2, curve_order, curve_order) == 2 7 | # Curve order should be a factor of field_modulus**12 - 1 8 | assert (field_modulus ** 12 - 1) % curve_order == 0 9 | 10 | # Curve is y**2 = x**3 + 3 11 | b = FQ(3) 12 | # Twisted curve over FQ**2 13 | b2 = FQ2([3, 0]) / FQ2([9, 1]) 14 | # Extension curve over FQ**12; same b value as over FQ 15 | b12 = FQ12([3] + [0] * 11) 16 | 17 | # Generator for curve over FQ 18 | G1 = (FQ(1), FQ(2), FQ(1)) 19 | # Generator for twisted curve over FQ2 20 | G2 = (FQ2([10857046999023057135944570762232829481370756359578518086990519993285655852781, 11559732032986387107991004021392285783925812861821192530917403151452391805634]), 21 | FQ2([8495653923123431417604973247489272438418190587263600148770280649306958101930, 4082367875863433681332203403145435568316851327593401208105741076214120093531]), FQ2.one()) 22 | 23 | # Check if a point is the point at infinity 24 | def is_inf(pt): 25 | return pt[-1] == pt[-1].__class__.zero() 26 | 27 | # Check that a point is on the curve defined by y**2 == x**3 + b 28 | def is_on_curve(pt, b): 29 | if is_inf(pt): 30 | return True 31 | x, y, z = pt 32 | return y**2 * z - x**3 == b * z**3 33 | 34 | assert is_on_curve(G1, b) 35 | assert is_on_curve(G2, b2) 36 | 37 | # Elliptic 
curve doubling 38 | def double(pt): 39 | x, y, z = pt 40 | W = 3 * x * x 41 | S = y * z 42 | B = x * y * S 43 | H = W * W - 8 * B 44 | S_squared = S * S 45 | newx = 2 * H * S 46 | newy = W * (4 * B - H) - 8 * y * y * S_squared 47 | newz = 8 * S * S_squared 48 | return newx, newy, newz 49 | 50 | # Elliptic curve addition 51 | def add(p1, p2): 52 | one, zero = p1[0].__class__.one(), p1[0].__class__.zero() 53 | if p1[2] == zero or p2[2] == zero: 54 | return p1 if p2[2] == zero else p2 55 | x1, y1, z1 = p1 56 | x2, y2, z2 = p2 57 | U1 = y2 * z1 58 | U2 = y1 * z2 59 | V1 = x2 * z1 60 | V2 = x1 * z2 61 | if V1 == V2 and U1 == U2: 62 | return double(p1) 63 | elif V1 == V2: 64 | return (one, one, zero) 65 | U = U1 - U2 66 | V = V1 - V2 67 | V_squared = V * V 68 | V_squared_times_V2 = V_squared * V2 69 | V_cubed = V * V_squared 70 | W = z1 * z2 71 | A = U * U * W - V_cubed - 2 * V_squared_times_V2 72 | newx = V * A 73 | newy = U * (V_squared_times_V2 - A) - V_cubed * U2 74 | newz = V_cubed * W 75 | return (newx, newy, newz) 76 | 77 | # Elliptic curve point multiplication 78 | def multiply(pt, n): 79 | if n == 0: 80 | return (pt[0].__class__.one(), pt[0].__class__.one(), pt[0].__class__.zero()) 81 | elif n == 1: 82 | return pt 83 | elif not n % 2: 84 | return multiply(double(pt), n // 2) 85 | else: 86 | return add(multiply(double(pt), int(n // 2)), pt) 87 | 88 | def eq(p1, p2): 89 | x1, y1, z1 = p1 90 | x2, y2, z2 = p2 91 | return x1 * z2 == x2 * z1 and y1 * z2 == y2 * z1 92 | 93 | def normalize(pt): 94 | x, y, z = pt 95 | return (x / z, y / z) 96 | 97 | # "Twist" a point in E(FQ2) into a point in E(FQ12) 98 | w = FQ12([0, 1] + [0] * 10) 99 | 100 | # Convert P => -P 101 | def neg(pt): 102 | if pt is None: 103 | return None 104 | x, y, z = pt 105 | return (x, -y, z) 106 | 107 | def twist(pt): 108 | if pt is None: 109 | return None 110 | _x, _y, _z = pt 111 | # Field isomorphism from Z[p] / x**2 to Z[p] / x**2 - 18*x + 82 112 | xcoeffs = [_x.coeffs[0] - _x.coeffs[1] * 9, 
_x.coeffs[1]] 113 | ycoeffs = [_y.coeffs[0] - _y.coeffs[1] * 9, _y.coeffs[1]] 114 | zcoeffs = [_z.coeffs[0] - _z.coeffs[1] * 9, _z.coeffs[1]] 115 | x, y, z = _x - _y * 9, _y, _z 116 | nx = FQ12([xcoeffs[0]] + [0] * 5 + [xcoeffs[1]] + [0] * 5) 117 | ny = FQ12([ycoeffs[0]] + [0] * 5 + [ycoeffs[1]] + [0] * 5) 118 | nz = FQ12([zcoeffs[0]] + [0] * 5 + [zcoeffs[1]] + [0] * 5) 119 | return (nx * w **2, ny * w**3, nz) 120 | 121 | # Check that the twist creates a point that is on the curve 122 | G12 = twist(G2) 123 | assert is_on_curve(G12, b12) 124 | -------------------------------------------------------------------------------- /gmssl/optimized_field_elements.py: -------------------------------------------------------------------------------- 1 | field_modulus = 21888242871839275222246405745257275088696311157297823662689037894645226208583 2 | FQ12_modulus_coeffs = [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0] # Implied + [1] 3 | FQ12_mc_tuples = [(i, c) for i, c in enumerate(FQ12_modulus_coeffs) if c] 4 | 5 | # python3 compatibility 6 | try: 7 | foo = long 8 | except: 9 | long = int 10 | 11 | # Extended euclidean algorithm to find modular inverses for 12 | # integers 13 | def prime_field_inv(a, n): 14 | if a == 0: 15 | return 0 16 | lm, hm = 1, 0 17 | low, high = a % n, n 18 | while low > 1: 19 | r = high//low 20 | nm, new = hm-lm*r, high-low*r 21 | lm, low, hm, high = nm, new, lm, low 22 | return lm % n 23 | 24 | # A class for field elements in FQ. Wrap a number in this class, 25 | # and it becomes a field element. 
26 | class FQ(): # element of GF(field_modulus); the reduced residue is stored in self.n 27 | def __init__(self, n): 28 | if isinstance(n, self.__class__): 29 | self.n = n.n 30 | else: 31 | self.n = n % field_modulus 32 | assert isinstance(self.n, (int, long)) 33 | 34 | def __add__(self, other): 35 | on = other.n if isinstance(other, FQ) else other 36 | return FQ((self.n + on) % field_modulus) 37 | 38 | def __mul__(self, other): 39 | on = other.n if isinstance(other, FQ) else other 40 | return FQ((self.n * on) % field_modulus) 41 | 42 | def __rmul__(self, other): 43 | return self * other 44 | 45 | def __radd__(self, other): 46 | return self + other 47 | 48 | def __rsub__(self, other): 49 | on = other.n if isinstance(other, FQ) else other 50 | return FQ((on - self.n) % field_modulus) 51 | 52 | def __sub__(self, other): 53 | on = other.n if isinstance(other, FQ) else other 54 | return FQ((self.n - on) % field_modulus) 55 | 56 | def __div__(self, other): # Python 2 name, __truediv__ delegates here; dividing by zero yields 0 because prime_field_inv(0) returns 0 57 | on = other.n if isinstance(other, FQ) else other 58 | assert isinstance(on, (int, long)) 59 | return FQ(self.n * prime_field_inv(on, field_modulus) % field_modulus) 60 | 61 | def __truediv__(self, other): 62 | return self.__div__(other) 63 | 64 | def __rdiv__(self, other): 65 | on = other.n if isinstance(other, FQ) else other 66 | assert isinstance(on, (int, long)), on 67 | return FQ(prime_field_inv(self.n, field_modulus) * on % field_modulus) 68 | 69 | def __rtruediv__(self, other): 70 | return self.__rdiv__(other) 71 | 72 | def __pow__(self, other): # recursive square-and-multiply 73 | if other == 0: 74 | return FQ(1) 75 | elif other == 1: 76 | return FQ(self.n) 77 | elif other % 2 == 0: 78 | return (self * self) ** (other // 2) 79 | else: 80 | return ((self * self) ** int(other // 2)) * self 81 | 82 | def __eq__(self, other): # non-FQ operands are compared directly with self.n (no reduction) 83 | if isinstance(other, FQ): 84 | return self.n == other.n 85 | else: 86 | return self.n == other 87 | 88 | def __ne__(self, other): 89 | return not self == other 90 | 91 | def __neg__(self): 92 | return FQ(-self.n) 93 | 94 | def __repr__(self): 95 | return repr(self.n) 96 |
@classmethod 98 | def one(cls): 99 | return cls(1) 100 | 101 | @classmethod 102 | def zero(cls): 103 | return cls(0) 104 | 105 | # Utility methods for polynomial math on coefficient lists (index i holds the coefficient of x**i) 106 | def deg(p): # index of the highest nonzero coefficient; 0 for an all-zero list 107 | d = len(p) - 1 108 | while p[d] == 0 and d: 109 | d -= 1 110 | return d 111 | 112 | def poly_rounded_div(a, b): # quotient-like helper for FQP.inv; NOTE(review): the inner loop subtracts o[c] rather than o[i] * b[c], so this is not exact polynomial division -- confirm inv() only relies on the leading quotient term 113 | dega = deg(a) 114 | degb = deg(b) 115 | temp = [x for x in a] 116 | o = [0 for x in a] 117 | for i in range(dega - degb, -1, -1): 118 | o[i] = (o[i] + temp[degb + i] * prime_field_inv(b[degb], field_modulus)) 119 | for c in range(degb + 1): 120 | temp[c + i] = (temp[c + i] - o[c]) 121 | return [x % field_modulus for x in o[:deg(o)+1]] 122 | 123 | # A class for elements in polynomial extension fields 124 | class FQP(): 125 | def __init__(self, coeffs, modulus_coeffs): 126 | assert len(coeffs) == len(modulus_coeffs) 127 | self.coeffs = coeffs 128 | # The coefficients of the modulus, without the leading [1] 129 | self.modulus_coeffs = modulus_coeffs 130 | # The degree of the extension field 131 | self.degree = len(self.modulus_coeffs) 132 | 133 | def __add__(self, other): 134 | assert isinstance(other, self.__class__) 135 | return self.__class__([(x+y) % field_modulus for x,y in zip(self.coeffs, other.coeffs)]) 136 | 137 | def __sub__(self, other): 138 | assert isinstance(other, self.__class__) 139 | return self.__class__([(x-y) % field_modulus for x,y in zip(self.coeffs, other.coeffs)]) 140 | 141 | def __mul__(self, other): 142 | if isinstance(other, (int, long)): 143 | return self.__class__([c * other % field_modulus for c in self.coeffs]) 144 | else: 145 | # assert isinstance(other, self.__class__) 146 | b = [0] * (self.degree * 2 - 1) # the schoolbook product has up to 2*degree - 1 coefficients 147 | inner_enumerate = list(enumerate(other.coeffs)) 148 | for i, eli in enumerate(self.coeffs): 149 | for j, elj in inner_enumerate: 150 | b[i + j] += eli * elj 151 | # MID = len(self.coeffs) // 2 152 | for exp in range(self.degree - 2, -1, -1): # fold the high terms back using the modulus, top term first 153 | top = b.pop() 154 | for i, c in self.mc_tuples: 155 | b[exp + i] -= top * c
156 | return self.__class__([x % field_modulus for x in b]) 157 | 158 | def __rmul__(self, other): 159 | return self * other 160 | 161 | def __div__(self, other): 162 | if isinstance(other, (int, long)): 163 | return self.__class__([c * prime_field_inv(other, field_modulus) % field_modulus for c in self.coeffs]) 164 | else: 165 | assert isinstance(other, self.__class__) 166 | return self * other.inv() 167 | 168 | def __truediv__(self, other): 169 | return self.__div__(other) 170 | 171 | def __pow__(self, other): # iterative square-and-multiply over the bits of `other` 172 | o = self.__class__([1] + [0] * (self.degree - 1)) 173 | t = self 174 | while other > 0: 175 | if other & 1: 176 | o = o * t 177 | other >>= 1 178 | t = t * t 179 | return o 180 | 181 | # Extended euclidean algorithm used to find the modular inverse 182 | def inv(self): # iterate until `low` is a constant polynomial, then scale lm by 1 / low[0] 183 | lm, hm = [1] + [0] * self.degree, [0] * (self.degree + 1) 184 | low, high = self.coeffs + [0], self.modulus_coeffs + [1] 185 | while deg(low): 186 | r = poly_rounded_div(high, low) 187 | r += [0] * (self.degree + 1 - len(r)) 188 | nm = [x for x in hm] 189 | new = [x for x in high] 190 | # assert len(lm) == len(hm) == len(low) == len(high) == len(nm) == len(new) == self.degree + 1 191 | for i in range(self.degree + 1): 192 | for j in range(self.degree + 1 - i): 193 | nm[i+j] -= lm[i] * r[j] 194 | new[i+j] -= low[i] * r[j] 195 | nm = [x % field_modulus for x in nm] 196 | new = [x % field_modulus for x in new] 197 | lm, low, hm, high = nm, new, lm, low 198 | return self.__class__(lm[:self.degree]) / low[0] 199 | 200 | def __repr__(self): 201 | return repr(self.coeffs) 202 | 203 | def __eq__(self, other): 204 | assert isinstance(other, self.__class__) # unlike FQ.__eq__, a non-FQP operand raises AssertionError 205 | for c1, c2 in zip(self.coeffs, other.coeffs): 206 | if c1 != c2: 207 | return False 208 | return True 209 | 210 | def __ne__(self, other): 211 | return not self == other 212 | 213 | def __neg__(self): 214 | return self.__class__([-c for c in self.coeffs]) # coefficients are left unreduced (possibly negative) 215 | 216 | @classmethod 217 | def one(cls): # reads cls.degree, which FQ2/FQ12 only set on the class inside __init__, so an instance must exist first 218 | return cls([1] + [0] * 
(cls.degree - 1)) 219 | 220 | @classmethod 221 | def zero(cls): 222 | return cls([0] * cls.degree) 223 | 224 | # The quadratic extension field FQ[u] / (u**2 + 1) 225 | class FQ2(FQP): 226 | def __init__(self, coeffs): 227 | self.coeffs = coeffs 228 | self.modulus_coeffs = [1, 0] 229 | self.mc_tuples = [(0, 1)] 230 | self.degree = 2 231 | self.__class__.degree = 2 # class attribute read by the one()/zero() classmethods 232 | 233 | # The 12th-degree extension field FQ[w] / (w**12 - 18*w**6 + 82) 234 | class FQ12(FQP): 235 | def __init__(self, coeffs): 236 | self.coeffs = coeffs 237 | self.modulus_coeffs = FQ12_modulus_coeffs 238 | self.mc_tuples = FQ12_mc_tuples 239 | self.degree = 12 240 | self.__class__.degree = 12 # class attribute read by the one()/zero() classmethods 241 | -------------------------------------------------------------------------------- /gmssl/optimized_pairing.py: -------------------------------------------------------------------------------- 1 | from gmssl.optimized_curve import double, add, multiply, is_on_curve, neg, twist, b, b2, b12, curve_order, G1, G2, G12, normalize 2 | from gmssl.optimized_field_elements import FQ2, FQ12, field_modulus, FQ 3 | 4 | ate_loop_count = 29793968203157093288 5 | log_ate_loop_count = 63 6 | pseudo_binary_encoding = [0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 1, -1, 0, 0, 1, 0, 7 | 0, 1, 1, 0, -1, 0, 0, 1, 0, -1, 0, 0, 0, 0, 1, 1, 8 | 1, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 0, 0, 1, 9 | 1, 0, 0, -1, 0, 0, 0, 1, 1, 0, -1, 0, 0, 1, 0, 1, 1] 10 | 11 | assert sum([e * 2**i for i, e in enumerate(pseudo_binary_encoding)]) == ate_loop_count 12 | 13 | 14 | def normalize1(p): # affine (x/z, y/z) with z replaced by one 15 | x, y = normalize(p) 16 | return x, y, x.__class__.one() 17 | 18 | # Create a function representing the line between P1 and P2, 19 | # and evaluate it at T.
Returns a numerator and a denominator 20 | # to avoid unneeded divisions 21 | def linefunc(P1, P2, T): 22 | zero = P1[0].__class__.zero() 23 | x1, y1, z1 = P1 24 | x2, y2, z2 = P2 25 | xt, yt, zt = T 26 | # points in projective coords: (x / z, y / z) 27 | # hence, m = (y2/z2 - y1/z1) / (x2/z2 - x1/z1) 28 | # multiply numerator and denominator by z1z2 to get values below 29 | m_numerator = y2 * z1 - y1 * z2 30 | m_denominator = x2 * z1 - x1 * z2 31 | if m_denominator != zero: 32 | # m * ((xt/zt) - (x1/z1)) - ((yt/zt) - (y1/z1)) 33 | return m_numerator * (xt * z1 - x1 * zt) - m_denominator * (yt * z1 - y1 * zt), \ 34 | m_denominator * zt * z1 35 | elif m_numerator == zero: 36 | # m = 3(x/z)^2 / 2(y/z), multiply num and den by z**2 37 | m_numerator = 3 * x1 * x1 38 | m_denominator = 2 * y1 * z1 39 | return m_numerator * (xt * z1 - x1 * zt) - m_denominator * (yt * z1 - y1 * zt), \ 40 | m_denominator * zt * z1 41 | else: 42 | return xt * z1 - x1 * zt, z1 * zt # same x but different y: the vertical line through P1 43 | 44 | def cast_point_to_fq12(pt): # embed FQ coordinates as constant FQ12 polynomials 45 | if pt is None: 46 | return None 47 | x, y, z = pt 48 | return (FQ12([x.n] + [0] * 11), FQ12([y.n] + [0] * 11), FQ12([z.n] + [0] * 11)) 49 | 50 | # Check consistency of the "line function" 51 | one, two, three = G1, double(G1), multiply(G1, 3) 52 | negone, negtwo, negthree = multiply(G1, curve_order - 1), multiply(G1, curve_order - 2), multiply(G1, curve_order - 3) 53 | 54 | assert linefunc(one, two, one)[0] == FQ(0) 55 | assert linefunc(one, two, two)[0] == FQ(0) 56 | assert linefunc(one, two, three)[0] != FQ(0) 57 | assert linefunc(one, two, negthree)[0] == FQ(0) 58 | assert linefunc(one, negone, one)[0] == FQ(0) 59 | assert linefunc(one, negone, negone)[0] == FQ(0) 60 | assert linefunc(one, negone, two)[0] != FQ(0) 61 | assert linefunc(one, one, one)[0] == FQ(0) 62 | assert linefunc(one, one, two)[0] != FQ(0) 63 | assert linefunc(one, one, negtwo)[0] == FQ(0) 64 | 65 | # Main miller loop 66 | def miller_loop(Q, P, final_exponentiate=True): # pairing() passes both points already mapped into FQ12 67 | if Q is None or P is None:
68 | return FQ12.one() 69 | R = Q 70 | f_num, f_den = FQ12.one(), FQ12.one() 71 | for b in pseudo_binary_encoding[63::-1]: # start below the top digit (index 64), which R = Q already accounts for; `b` is a local digit here, not the curve constant 72 | #for i in range(log_ate_loop_count, -1, -1): 73 | _n, _d = linefunc(R, R, P) 74 | f_num = f_num * f_num * _n 75 | f_den = f_den * f_den * _d 76 | R = double(R) 77 | #if ate_loop_count & (2**i): 78 | if b == 1: 79 | _n, _d = linefunc(R, Q, P) 80 | f_num = f_num * _n 81 | f_den = f_den * _d 82 | R = add(R, Q) 83 | elif b == -1: 84 | nQ = neg(Q) 85 | _n, _d = linefunc(R, nQ, P) 86 | f_num = f_num * _n 87 | f_den = f_den * _d 88 | R = add(R, nQ) 89 | # assert R == multiply(Q, ate_loop_count) 90 | Q1 = (Q[0] ** field_modulus, Q[1] ** field_modulus, Q[2] ** field_modulus) # Frobenius: each coordinate raised to the p-th power 91 | # assert is_on_curve(Q1, b12) 92 | nQ2 = (Q1[0] ** field_modulus, -Q1[1] ** field_modulus, Q1[2] ** field_modulus) # Frobenius applied again with y negated (unary minus binds after **) 93 | # assert is_on_curve(nQ2, b12) 94 | _n1, _d1 = linefunc(R, Q1, P) 95 | R = add(R, Q1) 96 | _n2, _d2 = linefunc(R, nQ2, P) 97 | f = f_num * _n1 * _n2 / (f_den * _d1 * _d2) 98 | # R = add(R, nQ2) This line is in many specifications but it technically does nothing 99 | if final_exponentiate: 100 | return f ** ((field_modulus ** 12 - 1) // curve_order) 101 | else: 102 | return f 103 | 104 | # Pairing computation 105 | def pairing(Q, P, final_exponentiate=True): # Q on the FQ2 twist, P on the FQ curve (checked by the asserts) 106 | assert is_on_curve(Q, b2) 107 | assert is_on_curve(P, b) 108 | if P[-1] == P[-1].__class__.zero() or Q[-1] == Q[-1].__class__.zero(): 109 | return FQ12.one() 110 | return miller_loop(twist(Q), cast_point_to_fq12(P), final_exponentiate=final_exponentiate) 111 | 112 | def final_exponentiate(p): # the same exponent miller_loop applies when final_exponentiate=True 113 | return p ** ((field_modulus ** 12 - 1) // curve_order) 114 | -------------------------------------------------------------------------------- /gmssl/sm2.py: -------------------------------------------------------------------------------- 1 | import binascii 2 | from random import choice # NOTE(review): not used anywhere in sm2.py 3 | from .
import sm3, func 4 | # 选择素域,设置椭圆曲线参数 5 | 6 | default_ecc_table = { 7 | 'n': 'FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123', 8 | 'p': 'FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF', 9 | 'g': '32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7'\ 10 | 'bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0', 11 | 'a': 'FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC', 12 | 'b': '28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93', 13 | } 14 | 15 | class CryptSM2(object): 16 | 17 | def __init__(self, private_key, public_key, ecc_table=default_ecc_table): 18 | self.private_key = private_key 19 | self.public_key = public_key 20 | self.para_len = len(ecc_table['n']) 21 | self.ecc_a3 = ( 22 | int(ecc_table['a'], base=16) + 3) % int(ecc_table['p'], base=16) 23 | self.ecc_table = ecc_table 24 | 25 | def _kg(self, k, Point): # kP运算 26 | Point = '%s%s' % (Point, '1') 27 | mask_str = '8' 28 | for i in range(self.para_len - 1): 29 | mask_str += '0' 30 | mask = int(mask_str, 16) 31 | Temp = Point 32 | flag = False 33 | for n in range(self.para_len * 4): 34 | if (flag): 35 | Temp = self._double_point(Temp) 36 | if (k & mask) != 0: 37 | if (flag): 38 | Temp = self._add_point(Temp, Point) 39 | else: 40 | flag = True 41 | Temp = Point 42 | k = k << 1 43 | return self._convert_jacb_to_nor(Temp) 44 | 45 | def _double_point(self, Point): # 倍点 46 | l = len(Point) 47 | len_2 = 2 * self.para_len 48 | if l< self.para_len * 2: 49 | return None 50 | else: 51 | x1 = int(Point[0:self.para_len], 16) 52 | y1 = int(Point[self.para_len:len_2], 16) 53 | if l == len_2: 54 | z1 = 1 55 | else: 56 | z1 = int(Point[len_2:], 16) 57 | 58 | T6 = (z1 * z1) % int(self.ecc_table['p'], base=16) 59 | T2 = (y1 * y1) % int(self.ecc_table['p'], base=16) 60 | T3 = (x1 + T6) % int(self.ecc_table['p'], base=16) 61 | T4 = (x1 - T6) % int(self.ecc_table['p'], base=16) 62 | T1 = (T3 * T4) % int(self.ecc_table['p'], base=16) 63 
| T3 = (y1 * z1) % int(self.ecc_table['p'], base=16) 64 | T4 = (T2 * 8) % int(self.ecc_table['p'], base=16) 65 | T5 = (x1 * T4) % int(self.ecc_table['p'], base=16) 66 | T1 = (T1 * 3) % int(self.ecc_table['p'], base=16) 67 | T6 = (T6 * T6) % int(self.ecc_table['p'], base=16) 68 | T6 = (self.ecc_a3 * T6) % int(self.ecc_table['p'], base=16) 69 | T1 = (T1 + T6) % int(self.ecc_table['p'], base=16) 70 | z3 = (T3 + T3) % int(self.ecc_table['p'], base=16) 71 | T3 = (T1 * T1) % int(self.ecc_table['p'], base=16) 72 | T2 = (T2 * T4) % int(self.ecc_table['p'], base=16) 73 | x3 = (T3 - T5) % int(self.ecc_table['p'], base=16) 74 | 75 | if (T5 % 2) == 1: 76 | T4 = (T5 + ((T5 + int(self.ecc_table['p'], base=16)) >> 1) - T3) % int(self.ecc_table['p'], base=16) 77 | else: 78 | T4 = (T5 + (T5 >> 1) - T3) % int(self.ecc_table['p'], base=16) 79 | 80 | T1 = (T1 * T4) % int(self.ecc_table['p'], base=16) 81 | y3 = (T1 - T2) % int(self.ecc_table['p'], base=16) 82 | 83 | form = '%%0%dx' % self.para_len 84 | form = form * 3 85 | return form % (x3, y3, z3) 86 | 87 | def _add_point(self, P1, P2): # 点加函数,P2点为仿射坐标即z=1,P1为Jacobian加重射影坐标 88 | len_2 = 2 * self.para_len 89 | l1 = len(P1) 90 | l2 = len(P2) 91 | if (l1 < len_2) or (l2 < len_2): 92 | return None 93 | else: 94 | X1 = int(P1[0:self.para_len], 16) 95 | Y1 = int(P1[self.para_len:len_2], 16) 96 | if (l1 == len_2): 97 | Z1 = 1 98 | else: 99 | Z1 = int(P1[len_2:], 16) 100 | x2 = int(P2[0:self.para_len], 16) 101 | y2 = int(P2[self.para_len:len_2], 16) 102 | 103 | T1 = (Z1 * Z1) % int(self.ecc_table['p'], base=16) 104 | T2 = (y2 * Z1) % int(self.ecc_table['p'], base=16) 105 | T3 = (x2 * T1) % int(self.ecc_table['p'], base=16) 106 | T1 = (T1 * T2) % int(self.ecc_table['p'], base=16) 107 | T2 = (T3 - X1) % int(self.ecc_table['p'], base=16) 108 | T3 = (T3 + X1) % int(self.ecc_table['p'], base=16) 109 | T4 = (T2 * T2) % int(self.ecc_table['p'], base=16) 110 | T1 = (T1 - Y1) % int(self.ecc_table['p'], base=16) 111 | Z3 = (Z1 * T2) % 
int(self.ecc_table['p'], base=16) 112 | T2 = (T2 * T4) % int(self.ecc_table['p'], base=16) 113 | T3 = (T3 * T4) % int(self.ecc_table['p'], base=16) 114 | T5 = (T1 * T1) % int(self.ecc_table['p'], base=16) 115 | T4 = (X1 * T4) % int(self.ecc_table['p'], base=16) 116 | X3 = (T5 - T3) % int(self.ecc_table['p'], base=16) 117 | T2 = (Y1 * T2) % int(self.ecc_table['p'], base=16) 118 | T3 = (T4 - X3) % int(self.ecc_table['p'], base=16) 119 | T1 = (T1 * T3) % int(self.ecc_table['p'], base=16) 120 | Y3 = (T1 - T2) % int(self.ecc_table['p'], base=16) 121 | 122 | form = '%%0%dx' % self.para_len 123 | form = form * 3 124 | return form % (X3, Y3, Z3) 125 | 126 | def _convert_jacb_to_nor(self, Point): # Jacobian加重射影坐标转换成仿射坐标 127 | len_2 = 2 * self.para_len 128 | x = int(Point[0:self.para_len], 16) 129 | y = int(Point[self.para_len:len_2], 16) 130 | z = int(Point[len_2:], 16) 131 | z_inv = pow(z, int(self.ecc_table['p'], base=16) - 2, int(self.ecc_table['p'], base=16)) 132 | z_invSquar = (z_inv * z_inv) % int(self.ecc_table['p'], base=16) 133 | z_invQube = (z_invSquar * z_inv) % int(self.ecc_table['p'], base=16) 134 | x_new = (x * z_invSquar) % int(self.ecc_table['p'], base=16) 135 | y_new = (y * z_invQube) % int(self.ecc_table['p'], base=16) 136 | z_new = (z * z_inv) % int(self.ecc_table['p'], base=16) 137 | if z_new == 1: 138 | form = '%%0%dx' % self.para_len 139 | form = form * 2 140 | return form % (x_new, y_new) 141 | else: 142 | return None 143 | 144 | def verify(self, Sign, data): 145 | # 验签函数,sign签名r||s,E消息hash,public_key公钥 146 | r = int(Sign[0:self.para_len], 16) 147 | s = int(Sign[self.para_len:2*self.para_len], 16) 148 | e = int(data.hex(), 16) 149 | t = (r + s) % int(self.ecc_table['n'], base=16) 150 | if t == 0: 151 | return 0 152 | 153 | P1 = self._kg(s, self.ecc_table['g']) 154 | P2 = self._kg(t, self.public_key) 155 | # print(P1) 156 | # print(P2) 157 | if P1 == P2: 158 | P1 = '%s%s' % (P1, 1) 159 | P1 = self._double_point(P1) 160 | else: 161 | P1 = '%s%s' % (P1, 
1) 162 | P1 = self._add_point(P1, P2) 163 | P1 = self._convert_jacb_to_nor(P1) 164 | 165 | x = int(P1[0:self.para_len], 16) 166 | return (r == ((e + x) % int(self.ecc_table['n'], base=16))) 167 | 168 | def sign(self, data, K): # 签名函数, data消息的hash,private_key私钥,K随机数,均为16进制字符串 169 | E = data.hex() # 消息转化为16进制字符串 170 | e = int(E, 16) 171 | 172 | d = int(self.private_key, 16) 173 | k = int(K, 16) 174 | 175 | P1 = self._kg(k, self.ecc_table['g']) 176 | 177 | x = int(P1[0:self.para_len], 16) 178 | R = ((e + x) % int(self.ecc_table['n'], base=16)) 179 | if R == 0 or R + k == int(self.ecc_table['n'], base=16): 180 | return None 181 | d_1 = pow(d+1, int(self.ecc_table['n'], base=16) - 2, int(self.ecc_table['n'], base=16)) 182 | S = (d_1*(k + R) - R) % int(self.ecc_table['n'], base=16) 183 | if S == 0: 184 | return None 185 | else: 186 | return '%064x%064x' % (R,S) 187 | 188 | def encrypt(self, data): 189 | # 加密函数,data消息(bytes) 190 | msg = data.hex() # 消息转化为16进制字符串 191 | k = func.random_hex(self.para_len) 192 | C1 = self._kg(int(k,16),self.ecc_table['g']) 193 | xy = self._kg(int(k,16),self.public_key) 194 | x2 = xy[0:self.para_len] 195 | y2 = xy[self.para_len:2*self.para_len] 196 | ml = len(msg) 197 | t = sm3.sm3_kdf(xy.encode('utf8'), ml/2) 198 | if int(t,16)==0: 199 | return None 200 | else: 201 | form = '%%0%dx' % ml 202 | C2 = form % (int(msg, 16) ^ int(t, 16)) 203 | C3 = sm3.sm3_hash([ 204 | i for i in bytes.fromhex('%s%s%s'% (x2,msg,y2)) 205 | ]) 206 | return bytes.fromhex('%s%s%s' % (C1,C3,C2)) 207 | 208 | def decrypt(self, data): 209 | # 解密函数,data密文(bytes) 210 | data = data.hex() 211 | len_2 = 2 * self.para_len 212 | len_3 = len_2 + 64 213 | C1 = data[0:len_2] 214 | C3 = data[len_2:len_3] 215 | C2 = data[len_3:] 216 | xy = self._kg(int(self.private_key,16),C1) 217 | # print('xy = %s' % xy) 218 | x2 = xy[0:self.para_len] 219 | y2 = xy[self.para_len:len_2] 220 | cl = len(C2) 221 | t = sm3.sm3_kdf(xy.encode('utf8'), cl/2) 222 | if int(t, 16) == 0: 223 | return None 224 
| else: 225 | form = '%%0%dx' % cl 226 | M = form % (int(C2,16) ^ int(t,16)) 227 | u = sm3.sm3_hash([ 228 | i for i in bytes.fromhex('%s%s%s'% (x2,M,y2)) 229 | ]) 230 | return bytes.fromhex(M) 231 | 232 | 233 | def _sm3_z(self, data): 234 | """ 235 | SM3WITHSM2 签名规则: SM2.sign(SM3(Z+MSG),PrivateKey) 236 | 其中: z = Hash256(Len(ID) + ID + a + b + xG + yG + xA + yA) 237 | """ 238 | # sm3withsm2 的 z 值 239 | z = '0080'+'31323334353637383132333435363738' + \ 240 | self.ecc_table['a'] + self.ecc_table['b'] + self.ecc_table['g'] + \ 241 | self.public_key 242 | z = binascii.a2b_hex(z) 243 | Za = sm3.sm3_hash(func.bytes_to_list(z)) 244 | M_ = (Za + data.hex()).encode('utf-8') 245 | e = sm3.sm3_hash(func.bytes_to_list(binascii.a2b_hex(M_))) 246 | return e 247 | 248 | 249 | def sign_with_sm3(self, data, random_hex_str=None): 250 | sign_data = binascii.a2b_hex(self._sm3_z(data).encode('utf-8')) 251 | if random_hex_str is None: 252 | random_hex_str = func.random_hex(self.para_len) 253 | sign = self.sign(sign_data, random_hex_str) # 16进制 254 | return sign 255 | 256 | 257 | def verify_with_sm3(self, sign, data): 258 | sign_data = binascii.a2b_hex(self._sm3_z(data).encode('utf-8')) 259 | return self.verify(sign, sign_data) 260 | -------------------------------------------------------------------------------- /gmssl/sm3.py: -------------------------------------------------------------------------------- 1 | import binascii 2 | from math import ceil 3 | from .func import rotl, bytes_to_list 4 | 5 | IV = [ 6 | 1937774191, 1226093241, 388252375, 3666478592, 7 | 2842636476, 372324522, 3817729613, 2969243214, 8 | ] 9 | 10 | T_j = [ 11 | 2043430169, 2043430169, 2043430169, 2043430169, 2043430169, 2043430169, 12 | 2043430169, 2043430169, 2043430169, 2043430169, 2043430169, 2043430169, 13 | 2043430169, 2043430169, 2043430169, 2043430169, 2055708042, 2055708042, 14 | 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 15 | 2055708042, 2055708042, 2055708042, 2055708042, 
2055708042, 2055708042, 16 | 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 17 | 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 18 | 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 19 | 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 20 | 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 21 | 2055708042, 2055708042, 2055708042, 2055708042 22 | ] 23 | 24 | def sm3_ff_j(x, y, z, j): 25 | if 0 <= j and j < 16: 26 | ret = x ^ y ^ z 27 | elif 16 <= j and j < 64: 28 | ret = (x & y) | (x & z) | (y & z) 29 | return ret 30 | 31 | def sm3_gg_j(x, y, z, j): 32 | if 0 <= j and j < 16: 33 | ret = x ^ y ^ z 34 | elif 16 <= j and j < 64: 35 | #ret = (X | Y) & ((2 ** 32 - 1 - X) | Z) 36 | ret = (x & y) | ((~ x) & z) 37 | return ret 38 | 39 | def sm3_p_0(x): 40 | return x ^ (rotl(x, 9 % 32)) ^ (rotl(x, 17 % 32)) 41 | 42 | def sm3_p_1(x): 43 | return x ^ (rotl(x, 15 % 32)) ^ (rotl(x, 23 % 32)) 44 | 45 | def sm3_cf(v_i, b_i): 46 | w = [] 47 | for i in range(16): 48 | weight = 0x1000000 49 | data = 0 50 | for k in range(i*4,(i+1)*4): 51 | data = data + b_i[k]*weight 52 | weight = int(weight/0x100) 53 | w.append(data) 54 | 55 | for j in range(16, 68): 56 | w.append(0) 57 | w[j] = sm3_p_1(w[j-16] ^ w[j-9] ^ (rotl(w[j-3], 15 % 32))) ^ (rotl(w[j-13], 7 % 32)) ^ w[j-6] 58 | str1 = "%08x" % w[j] 59 | w_1 = [] 60 | for j in range(0, 64): 61 | w_1.append(0) 62 | w_1[j] = w[j] ^ w[j+4] 63 | str1 = "%08x" % w_1[j] 64 | 65 | a, b, c, d, e, f, g, h = v_i 66 | 67 | for j in range(0, 64): 68 | ss_1 = rotl( 69 | ((rotl(a, 12 % 32)) + 70 | e + 71 | (rotl(T_j[j], j % 32))) & 0xffffffff, 7 % 32 72 | ) 73 | ss_2 = ss_1 ^ (rotl(a, 12 % 32)) 74 | tt_1 = (sm3_ff_j(a, b, c, j) + d + ss_2 + w_1[j]) & 0xffffffff 75 | tt_2 = (sm3_gg_j(e, f, g, j) + h + ss_1 + w[j]) & 0xffffffff 76 | d = c 77 | c = rotl(b, 9 % 32) 78 | b = a 79 | a = tt_1 80 | h = g 81 | g = rotl(f, 19 % 32) 82 | f = e 83 | e 
= sm3_p_0(tt_2) 84 | 85 | a, b, c, d, e, f, g, h = map( 86 | lambda x:x & 0xFFFFFFFF ,[a, b, c, d, e, f, g, h]) 87 | 88 | v_j = [a, b, c, d, e, f, g, h] 89 | return [v_j[i] ^ v_i[i] for i in range(8)] 90 | 91 | def sm3_hash(msg): 92 | # print(msg) 93 | len1 = len(msg) 94 | reserve1 = len1 % 64 95 | msg.append(0x80) 96 | reserve1 = reserve1 + 1 97 | # 56-64, add 64 byte 98 | range_end = 56 99 | if reserve1 > range_end: 100 | range_end = range_end + 64 101 | 102 | for i in range(reserve1, range_end): 103 | msg.append(0x00) 104 | 105 | bit_length = (len1) * 8 106 | bit_length_str = [bit_length % 0x100] 107 | for i in range(7): 108 | bit_length = int(bit_length / 0x100) 109 | bit_length_str.append(bit_length % 0x100) 110 | for i in range(8): 111 | msg.append(bit_length_str[7-i]) 112 | 113 | group_count = round(len(msg) / 64) 114 | 115 | B = [] 116 | for i in range(0, group_count): 117 | B.append(msg[i*64:(i+1)*64]) 118 | 119 | V = [] 120 | V.append(IV) 121 | for i in range(0, group_count): 122 | V.append(sm3_cf(V[i], B[i])) 123 | 124 | y = V[i+1] 125 | result = "" 126 | for i in y: 127 | result = '%s%08x' % (result, i) 128 | return result 129 | 130 | def sm3_kdf(z, klen): # z为16进制表示的比特串(str),klen为密钥长度(单位byte) 131 | klen = int(klen) 132 | ct = 0x00000001 133 | rcnt = ceil(klen/32) 134 | zin = [i for i in bytes.fromhex(z.decode('utf8'))] 135 | ha = "" 136 | for i in range(rcnt): 137 | msg = zin + [i for i in binascii.a2b_hex(('%08x' % ct).encode('utf8'))] 138 | ha = ha + sm3_hash(msg) 139 | ct += 1 140 | return ha[0: klen * 2] 141 | -------------------------------------------------------------------------------- /gmssl/sm4.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import copy 3 | from .func import xor, rotl, get_uint32_be, put_uint32_be, \ 4 | bytes_to_list, list_to_bytes, padding, unpadding 5 | 6 | #Expanded SM4 box table 7 | SM4_BOXES_TABLE = [ 8 | 
0xd6,0x90,0xe9,0xfe,0xcc,0xe1,0x3d,0xb7,0x16,0xb6,0x14,0xc2,0x28,0xfb,0x2c, 9 | 0x05,0x2b,0x67,0x9a,0x76,0x2a,0xbe,0x04,0xc3,0xaa,0x44,0x13,0x26,0x49,0x86, 10 | 0x06,0x99,0x9c,0x42,0x50,0xf4,0x91,0xef,0x98,0x7a,0x33,0x54,0x0b,0x43,0xed, 11 | 0xcf,0xac,0x62,0xe4,0xb3,0x1c,0xa9,0xc9,0x08,0xe8,0x95,0x80,0xdf,0x94,0xfa, 12 | 0x75,0x8f,0x3f,0xa6,0x47,0x07,0xa7,0xfc,0xf3,0x73,0x17,0xba,0x83,0x59,0x3c, 13 | 0x19,0xe6,0x85,0x4f,0xa8,0x68,0x6b,0x81,0xb2,0x71,0x64,0xda,0x8b,0xf8,0xeb, 14 | 0x0f,0x4b,0x70,0x56,0x9d,0x35,0x1e,0x24,0x0e,0x5e,0x63,0x58,0xd1,0xa2,0x25, 15 | 0x22,0x7c,0x3b,0x01,0x21,0x78,0x87,0xd4,0x00,0x46,0x57,0x9f,0xd3,0x27,0x52, 16 | 0x4c,0x36,0x02,0xe7,0xa0,0xc4,0xc8,0x9e,0xea,0xbf,0x8a,0xd2,0x40,0xc7,0x38, 17 | 0xb5,0xa3,0xf7,0xf2,0xce,0xf9,0x61,0x15,0xa1,0xe0,0xae,0x5d,0xa4,0x9b,0x34, 18 | 0x1a,0x55,0xad,0x93,0x32,0x30,0xf5,0x8c,0xb1,0xe3,0x1d,0xf6,0xe2,0x2e,0x82, 19 | 0x66,0xca,0x60,0xc0,0x29,0x23,0xab,0x0d,0x53,0x4e,0x6f,0xd5,0xdb,0x37,0x45, 20 | 0xde,0xfd,0x8e,0x2f,0x03,0xff,0x6a,0x72,0x6d,0x6c,0x5b,0x51,0x8d,0x1b,0xaf, 21 | 0x92,0xbb,0xdd,0xbc,0x7f,0x11,0xd9,0x5c,0x41,0x1f,0x10,0x5a,0xd8,0x0a,0xc1, 22 | 0x31,0x88,0xa5,0xcd,0x7b,0xbd,0x2d,0x74,0xd0,0x12,0xb8,0xe5,0xb4,0xb0,0x89, 23 | 0x69,0x97,0x4a,0x0c,0x96,0x77,0x7e,0x65,0xb9,0xf1,0x09,0xc5,0x6e,0xc6,0x84, 24 | 0x18,0xf0,0x7d,0xec,0x3a,0xdc,0x4d,0x20,0x79,0xee,0x5f,0x3e,0xd7,0xcb,0x39, 25 | 0x48, 26 | ] 27 | 28 | # System parameter 29 | SM4_FK = [0xa3b1bac6,0x56aa3350,0x677d9197,0xb27022dc] 30 | 31 | # fixed parameter 32 | SM4_CK = [ 33 | 0x00070e15,0x1c232a31,0x383f464d,0x545b6269, 34 | 0x70777e85,0x8c939aa1,0xa8afb6bd,0xc4cbd2d9, 35 | 0xe0e7eef5,0xfc030a11,0x181f262d,0x343b4249, 36 | 0x50575e65,0x6c737a81,0x888f969d,0xa4abb2b9, 37 | 0xc0c7ced5,0xdce3eaf1,0xf8ff060d,0x141b2229, 38 | 0x30373e45,0x4c535a61,0x686f767d,0x848b9299, 39 | 0xa0a7aeb5,0xbcc3cad1,0xd8dfe6ed,0xf4fb0209, 40 | 0x10171e25,0x2c333a41,0x484f565d,0x646b7279 41 | ] 42 | 43 | SM4_ENCRYPT = 0 44 | SM4_DECRYPT = 1 45 | 46 | 47 | class 
CryptSM4(object): 48 | 49 | def __init__(self, mode=SM4_ENCRYPT): 50 | self.sk = [0]*32 51 | self.mode = mode 52 | # Calculating round encryption key. 53 | # args: [in] a: a is a 32 bits unsigned value; 54 | # return: sk[i]: i{0,1,2,3,...31}. 55 | @classmethod 56 | def _round_key(cls, ka): 57 | b = [0, 0, 0, 0] 58 | a = put_uint32_be(ka) 59 | b[0] = SM4_BOXES_TABLE[a[0]] 60 | b[1] = SM4_BOXES_TABLE[a[1]] 61 | b[2] = SM4_BOXES_TABLE[a[2]] 62 | b[3] = SM4_BOXES_TABLE[a[3]] 63 | bb = get_uint32_be(b[0:4]) 64 | rk = bb ^ (rotl(bb, 13)) ^ (rotl(bb, 23)) 65 | return rk 66 | 67 | # Calculating and getting encryption/decryption contents. 68 | # args: [in] x0: original contents; 69 | # args: [in] x1: original contents; 70 | # args: [in] x2: original contents; 71 | # args: [in] x3: original contents; 72 | # args: [in] rk: encryption/decryption key; 73 | # return the contents of encryption/decryption contents. 74 | @classmethod 75 | def _f(cls, x0, x1, x2, x3, rk): 76 | # "T algorithm" == "L algorithm" + "t algorithm". 
77 | # args: [in] a: a is a 32 bits unsigned value; 78 | # return: c: c is calculated with line algorithm "L" and nonline algorithm "t" 79 | def _sm4_l_t(ka): 80 | b = [0, 0, 0, 0] 81 | a = put_uint32_be(ka) 82 | b[0] = SM4_BOXES_TABLE[a[0]] 83 | b[1] = SM4_BOXES_TABLE[a[1]] 84 | b[2] = SM4_BOXES_TABLE[a[2]] 85 | b[3] = SM4_BOXES_TABLE[a[3]] 86 | bb = get_uint32_be(b[0:4]) 87 | c = bb ^ (rotl(bb, 2)) ^ (rotl(bb, 10)) ^ (rotl(bb, 18)) ^ (rotl(bb, 24)) 88 | return c 89 | return (x0 ^ _sm4_l_t(x1 ^ x2 ^ x3 ^ rk)) 90 | 91 | def set_key(self, key, mode): 92 | key = bytes_to_list(key) 93 | MK = [0, 0, 0, 0] 94 | k = [0]*36 95 | MK[0] = get_uint32_be(key[0:4]) 96 | MK[1] = get_uint32_be(key[4:8]) 97 | MK[2] = get_uint32_be(key[8:12]) 98 | MK[3] = get_uint32_be(key[12:16]) 99 | k[0:4] = xor(MK[0:4], SM4_FK[0:4]) 100 | for i in range(32): 101 | k[i + 4] = k[i] ^ ( 102 | self._round_key(k[i + 1] ^ k[i + 2] ^ k[i + 3] ^ SM4_CK[i])) 103 | self.sk[i] = k[i + 4] 104 | self.mode = mode 105 | if mode == SM4_DECRYPT: 106 | for idx in range(16): 107 | t = self.sk[idx] 108 | self.sk[idx] = self.sk[31 - idx] 109 | self.sk[31 - idx] = t 110 | 111 | def one_round(self, sk, in_put): 112 | out_put = [] 113 | ulbuf = [0]*36 114 | ulbuf[0] = get_uint32_be(in_put[0:4]) 115 | ulbuf[1] = get_uint32_be(in_put[4:8]) 116 | ulbuf[2] = get_uint32_be(in_put[8:12]) 117 | ulbuf[3] = get_uint32_be(in_put[12:16]) 118 | for idx in range(32): 119 | ulbuf[idx + 4] = self._f(ulbuf[idx], ulbuf[idx + 1], ulbuf[idx + 2], ulbuf[idx + 3], sk[idx]) 120 | 121 | out_put += put_uint32_be(ulbuf[35]) 122 | out_put += put_uint32_be(ulbuf[34]) 123 | out_put += put_uint32_be(ulbuf[33]) 124 | out_put += put_uint32_be(ulbuf[32]) 125 | return out_put 126 | 127 | def crypt_ecb(self, input_data): 128 | # SM4-ECB block encryption/decryption 129 | input_data = bytes_to_list(input_data) 130 | if self.mode == SM4_ENCRYPT: 131 | input_data = padding(input_data) 132 | length = len(input_data) 133 | i = 0 134 | output_data = [] 135 
| while length > 0: 136 | output_data += self.one_round(self.sk, input_data[i:i+16]) 137 | i += 16 138 | length -= 16 139 | if self.mode == SM4_DECRYPT: 140 | return list_to_bytes(unpadding(output_data)) 141 | return list_to_bytes(output_data) 142 | 143 | def crypt_cbc(self, iv, input_data): 144 | #SM4-CBC buffer encryption/decryption 145 | i = 0 146 | output_data = [] 147 | tmp_input = [0]*16 148 | iv = bytes_to_list(iv) 149 | if self.mode == SM4_ENCRYPT: 150 | input_data = padding(bytes_to_list(input_data)) 151 | length = len(input_data) 152 | while length > 0: 153 | tmp_input[0:16] = xor(input_data[i:i+16], iv[0:16]) 154 | output_data += self.one_round(self.sk, tmp_input[0:16]) 155 | iv = copy.deepcopy(output_data[i:i+16]) 156 | i += 16 157 | length -= 16 158 | return list_to_bytes(output_data) 159 | else: 160 | length = len(input_data) 161 | while length > 0: 162 | output_data += self.one_round(self.sk, input_data[i:i+16]) 163 | output_data[i:i+16] = xor(output_data[i:i+16], iv[0:16]) 164 | iv = copy.deepcopy(input_data[i:i + 16]) 165 | i += 16 166 | length -= 16 167 | return list_to_bytes(unpadding(output_data)) 168 | -------------------------------------------------------------------------------- /gmssl/sm9.py: -------------------------------------------------------------------------------- 1 | 2 | import binascii 3 | from math import ceil, floor, log 4 | from gmssl.sm3 import sm3_kdf, sm3_hash 5 | 6 | from random import SystemRandom 7 | 8 | import gmssl.optimized_field_elements as fq 9 | import gmssl.optimized_curve as ec 10 | import gmssl.optimized_pairing as ate 11 | 12 | FAILURE = False 13 | SUCCESS = True 14 | 15 | 16 | def bitlen (n): 17 | return floor (log(n,2) + 1) 18 | 19 | def i2sp (m, l): 20 | format_m = ('%x' % m).zfill(l*2).encode('utf-8') 21 | octets = [j for j in binascii.a2b_hex(format_m)] 22 | octets = octets[0:l] 23 | return ''.join (['%02x' %oc for oc in octets]) 24 | 25 | def fe2sp (fe): 26 | fe_str = ''.join (['%x' %c for c in fe.coeffs]) 
27 | if (len(fe_str) % 2) == 1: 28 | fe_str = '0' + fe_str 29 | return fe_str 30 | 31 | def ec2sp (P): 32 | ec_str = ''.join([fe2sp(fe) for fe in P]) 33 | return ec_str 34 | 35 | def str2hexbytes (str_in): 36 | return [b for b in str_in.encode ('utf-8')] 37 | 38 | def h2rf (i, z, n): 39 | l = 8 * ceil ((5*bitlen(n)) / 32) 40 | msg = i2sp(i,1).encode('utf-8') 41 | ha = sm3_kdf (msg+z, l) 42 | h = int (ha, 16) 43 | return (h % (n-1)) + 1 44 | 45 | def setup (scheme): 46 | P1 = ec.G2 47 | P2 = ec.G1 48 | 49 | rand_gen = SystemRandom() 50 | s = rand_gen.randrange (ec.curve_order) 51 | 52 | if (scheme == 'sign'): 53 | Ppub = ec.multiply (P2, s) 54 | g = ate.pairing (P1, Ppub) 55 | elif (scheme == 'keyagreement') | (scheme == 'encrypt'): 56 | Ppub = ec.multiply (P1, s) 57 | g = ate.pairing (Ppub, P2) 58 | else: 59 | raise Exception('Invalid scheme') 60 | 61 | master_public_key = (P1, P2, Ppub, g) 62 | return (master_public_key, s) 63 | 64 | def private_key_extract (scheme, master_public, master_secret, identity): 65 | P1 = master_public[0] 66 | P2 = master_public[1] 67 | 68 | user_id = sm3_hash (str2hexbytes (identity)) 69 | m = h2rf (1, (user_id + '01').encode('utf-8'), ec.curve_order) 70 | m = master_secret + m 71 | if (m % ec.curve_order) == 0: 72 | return FAILURE 73 | m = master_secret * fq.prime_field_inv (m, ec.curve_order) 74 | 75 | if (scheme == 'sign'): 76 | Da = ec.multiply (P1, m) 77 | elif (scheme == 'keyagreement') | (scheme == 'encrypt'): 78 | Da = ec.multiply (P2, m) 79 | else: 80 | raise Exception('Invalid scheme') 81 | 82 | return Da 83 | 84 | def public_key_extract (scheme, master_public, identity): 85 | P1, P2, Ppub, g = master_public 86 | 87 | user_id = sm3_hash (str2hexbytes (identity)) 88 | h1 = h2rf (1, (user_id + '01').encode('utf-8'), ec.curve_order) 89 | 90 | if (scheme == 'sign'): 91 | Q = ec.multiply (P2, h1) 92 | elif (scheme == 'keyagreement') | (scheme == 'encrypt'): 93 | Q = ec.multiply (P1, h1) 94 | else: 95 | raise Exception('Invalid 
scheme') 96 | 97 | Q = ec.add (Q, Ppub) 98 | 99 | return Q 100 | 101 | # scheme = 'sign' 102 | def sign (master_public, Da, msg): 103 | g = master_public[3] 104 | 105 | rand_gen = SystemRandom() 106 | x = rand_gen.randrange (ec.curve_order) 107 | w = g**x 108 | 109 | msg_hash = sm3_hash (str2hexbytes(msg)) 110 | z = (msg_hash + fe2sp(w)).encode('utf-8') 111 | h = h2rf (2, z, ec.curve_order) 112 | l = (x - h) % ec.curve_order 113 | 114 | S = ec.multiply (Da, l) 115 | return (h, S) 116 | 117 | def verify (master_public, identity, msg, signature): 118 | (h, S) = signature 119 | 120 | if (h < 0) | (h >= ec.curve_order): 121 | return FAILURE 122 | if ec.is_on_curve (S, ec.b2) == False: 123 | return FAILURE 124 | 125 | Q = public_key_extract ('sign', master_public, identity) 126 | 127 | g = master_public[3] 128 | u = ate.pairing (S, Q) 129 | t = g**h 130 | wprime = u * t 131 | 132 | msg_hash = sm3_hash (str2hexbytes(msg)) 133 | z = (msg_hash + fe2sp(wprime)).encode('utf-8') 134 | h2 = h2rf (2, z, ec.curve_order) 135 | 136 | if h != h2: 137 | return FAILURE 138 | return SUCCESS 139 | 140 | # scheme = 'keyagreement' 141 | def generate_ephemeral (master_public, identity): 142 | Q = public_key_extract ('keyagreement', master_public, identity) 143 | 144 | rand_gen = SystemRandom() 145 | x = rand_gen.randrange (ec.curve_order) 146 | R = ec.multiply (Q, x) 147 | 148 | return (x, R) 149 | 150 | def generate_session_key (idA, idB, Ra, Rb, D, x, master_public, entity, l): 151 | P1, P2, Ppub, g = master_public 152 | 153 | if entity == 'A': 154 | R = Rb 155 | elif entity == 'B': 156 | R = Ra 157 | else: 158 | raise Exception('Invalid entity') 159 | 160 | g1 = ate.pairing (R, D) 161 | g2 = g**x 162 | g3 = g1**x 163 | 164 | if (entity == 'B'): 165 | (g1, g2) = (g2, g1) 166 | 167 | uidA = sm3_hash (str2hexbytes (idA)) 168 | uidB = sm3_hash (str2hexbytes (idB)) 169 | 170 | kdf_input = uidA + uidB 171 | kdf_input += ec2sp(Ra) + ec2sp (Rb) 172 | kdf_input += fe2sp(g1) + fe2sp(g2) + 
fe2sp(g3) 173 | 174 | sk = sm3_kdf (kdf_input.encode ('utf-8'), l) 175 | 176 | return sk 177 | 178 | # encrypt 179 | 180 | def kem_encap (master_public, identity, l): 181 | P1, P2, Ppub, g = master_public 182 | 183 | Q = public_key_extract ('encrypt', master_public, identity) 184 | 185 | rand_gen = SystemRandom() 186 | x = rand_gen.randrange (ec.curve_order) 187 | 188 | C1 = ec.multiply (Q, x) 189 | t = g**x 190 | 191 | uid = sm3_hash (str2hexbytes (identity)) 192 | kdf_input = ec2sp(C1) + fe2sp(t) + uid 193 | k = sm3_kdf (kdf_input.encode ('utf-8'), l) 194 | 195 | return (k, C1) 196 | 197 | def kem_decap (master_public, identity, D, C1, l): 198 | if ec.is_on_curve (C1, ec.b2) == False: 199 | return FAILURE 200 | 201 | t = ate.pairing (C1, D) 202 | 203 | uid = sm3_hash (str2hexbytes (identity)) 204 | kdf_input = ec2sp(C1) + fe2sp(t) + uid 205 | k = sm3_kdf (kdf_input.encode ('utf-8'), l) 206 | 207 | return k 208 | 209 | def kem_dem_enc (master_public, identity, message, v): 210 | hex_msg = str2hexbytes (message) 211 | mbytes = len(hex_msg) 212 | mbits = mbytes * 8 213 | 214 | k, C1 = kem_encap (master_public, identity, mbits + v) 215 | k = str2hexbytes (k) 216 | k1 = k[:mbytes] 217 | k2 = k[mbytes:] 218 | 219 | C2 = [] 220 | for i in range (mbytes): 221 | C2.append (hex_msg[i] ^ k1[i]) 222 | 223 | hash_input = C2 + k2 224 | C3 = sm3_hash(hash_input)[:int(v/8)] 225 | 226 | return (C1, C2, C3) 227 | 228 | def kem_dem_dec (master_public, identity, D, ct, v): 229 | C1, C2, C3 = ct 230 | 231 | mbytes = len(C2) 232 | l = mbytes*8 + v 233 | k = kem_decap (master_public, identity, D, C1, l) 234 | 235 | k = str2hexbytes (k) 236 | k1 = k[:mbytes] 237 | k2 = k[mbytes:] 238 | 239 | hash_input = C2 + k2 240 | C3prime = sm3_hash(hash_input)[:int(v/8)] 241 | 242 | if C3 != C3prime: 243 | return FAILURE 244 | 245 | pt = [] 246 | for i in range (mbytes): 247 | pt.append (chr (C2[i] ^ k1[i])) 248 | 249 | message = ''.join(pt) 250 | 251 | return message 252 | 
-------------------------------------------------------------------------------- /tests/test_sm2.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import binascii 3 | from gmssl import sm2, func 4 | 5 | 6 | def test_sm2(): 7 | private_key = '00B9AB0B828FF68872F21A837FC303668428DEA11DCD1B24429D0C99E24EED83D5' 8 | public_key = 'B9C9A6E04E9C91F7BA880429273747D7EF5DDEB0BB2FF6317EB00BEF331A83081A6994B8993F3F5D6EADDDB81872266C87C018FB4162F5AF347B483E24620207' 9 | 10 | sm2_crypt = sm2.CryptSM2( 11 | public_key=public_key, private_key=private_key) 12 | data = b"111" 13 | enc_data = sm2_crypt.encrypt(data) 14 | #print("enc_data:%s" % enc_data) 15 | #print("enc_data_base64:%s" % base64.b64encode(bytes.fromhex(enc_data))) 16 | dec_data = sm2_crypt.decrypt(enc_data) 17 | print(b"dec_data:%s" % dec_data) 18 | assert data == dec_data 19 | 20 | print("-----------------test sign and verify---------------") 21 | random_hex_str = func.random_hex(sm2_crypt.para_len) 22 | sign = sm2_crypt.sign(data, random_hex_str) 23 | print('sign:%s' % sign) 24 | verify = sm2_crypt.verify(sign, data) 25 | print('verify:%s' % verify) 26 | assert verify 27 | 28 | 29 | def test_sm2sm3(): 30 | private_key = "3945208F7B2144B13F36E38AC6D39F95889393692860B51A42FB81EF4DF7C5B8" 31 | public_key = "09F9DF311E5421A150DD7D161E4BC5C672179FAD1833FC076BB08FF356F35020"\ 32 | "CCEA490CE26775A52DC6EA718CC1AA600AED05FBF35E084A6632F6072DA9AD13" 33 | random_hex_str = "59276E27D506861A16680F3AD9C02DCCEF3CC1FA3CDBE4CE6D54B80DEAC1BC21" 34 | 35 | sm2_crypt = sm2.CryptSM2(public_key=public_key, private_key=private_key) 36 | data = b"message digest" 37 | 38 | print("-----------------test SM2withSM3 sign and verify---------------") 39 | sign = sm2_crypt.sign_with_sm3(data, random_hex_str) 40 | print('sign: %s' % sign) 41 | verify = sm2_crypt.verify_with_sm3(sign, data) 42 | print('verify: %s' % verify) 43 | assert verify 44 | 45 | 46 | if __name__ == '__main__': 47 | 
test_sm2() 48 | test_sm2sm3() 49 | 50 | -------------------------------------------------------------------------------- /tests/test_sm3.py: -------------------------------------------------------------------------------- 1 | from gmssl import sm3, func 2 | 3 | if __name__ == '__main__': 4 | y = sm3.sm3_hash(func.bytes_to_list(b"abc")) 5 | print(y) 6 | -------------------------------------------------------------------------------- /tests/test_sm4.py: -------------------------------------------------------------------------------- 1 | from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT 2 | 3 | key = b'3l5butlj26hvv313' 4 | value = b'111' 5 | iv = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' 6 | crypt_sm4 = CryptSM4() 7 | 8 | crypt_sm4.set_key(key, SM4_ENCRYPT) 9 | encrypt_value = crypt_sm4.crypt_ecb(value) 10 | crypt_sm4.set_key(key, SM4_DECRYPT) 11 | decrypt_value = crypt_sm4.crypt_ecb(encrypt_value) 12 | assert value == decrypt_value 13 | 14 | crypt_sm4.set_key(key, SM4_ENCRYPT) 15 | encrypt_value = crypt_sm4.crypt_cbc(iv , value) 16 | crypt_sm4.set_key(key, SM4_DECRYPT) 17 | decrypt_value = crypt_sm4.crypt_cbc(iv , encrypt_value) 18 | assert value == decrypt_value 19 | -------------------------------------------------------------------------------- /tests/test_sm9.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from gmssl import sm9 4 | 5 | if __name__ == '__main__': 6 | idA = 'a' 7 | idB = 'b' 8 | 9 | print ("-----------------test sign and verify---------------") 10 | 11 | master_public, master_secret = sm9.setup ('sign') 12 | 13 | Da = sm9.private_key_extract ('sign', master_public, master_secret, idA) 14 | 15 | message = 'abc' 16 | signature = sm9.sign (master_public, Da, message) 17 | 18 | assert (sm9.verify (master_public, idA, message, signature)) 19 | 20 | print ("\t\t\t success") 21 | 22 | print ("-----------------test key agreement---------------") 23 | 24 | master_public, 
master_secret = sm9.setup ('keyagreement') 25 | 26 | Da = sm9.private_key_extract ('keyagreement', master_public, master_secret, idA) 27 | Db = sm9.private_key_extract ('keyagreement', master_public, master_secret, idB) 28 | 29 | xa, Ra = sm9.generate_ephemeral (master_public, idB) 30 | xb, Rb = sm9.generate_ephemeral (master_public, idA) 31 | 32 | ska = sm9.generate_session_key (idA, idB, Ra, Rb, Da, xa, master_public, 'A', 128) 33 | skb = sm9.generate_session_key (idA, idB, Ra, Rb, Db, xb, master_public, 'B', 128) 34 | 35 | assert (ska == skb) 36 | 37 | print ("\t\t\t success") 38 | 39 | print ("-----------------test encrypt and decrypt---------------") 40 | 41 | master_public, master_secret = sm9.setup ('encrypt') 42 | 43 | Da = sm9.private_key_extract ('encrypt', master_public, master_secret, idA) 44 | 45 | message = 'abc' 46 | ct = sm9.kem_dem_enc (master_public, idA, message, 32) 47 | pt = sm9.kem_dem_dec (master_public, idA, Da, ct, 32) 48 | 49 | assert (message == pt) 50 | 51 | print ("\t\t\t success") 52 | --------------------------------------------------------------------------------