├── .gitignore ├── LICENSE ├── README.md ├── __about__.py ├── __init__.py ├── example.py ├── requirements.txt ├── smpcp ├── __init__.py └── smpcp.py ├── test.py └── test_case ├── BeautifulReport.py ├── __init__.py ├── report ├── CSVReport.csv └── HTMLReport.html ├── template └── template └── test_smpcp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | /__pycache__ 5 | /.idea 6 | /pypi 7 | # Datasource local storage ignored files 8 | /dataSources/ 9 | /dataSources.local.xml 10 | # 基于编辑器的 HTTP 客户端请求 11 | /httpRequests/ 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Zhan Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

安全多方计算协议

2 | 3 | 4 | 5 | 6 | pypi 7 | 8 | --- 9 | 10 | ## 项目背景 11 | 12 | 安全多方计算(Secure Multi-Party Computation)的研究主要是针对无可信第三方的情况下,如何安全地计算一个约定函数的问题。安全多方计算是电子选举、门限签名以及电子拍卖等诸多应用得以实施的密码学基础。 13 | 14 | 一个安全多方计算协议,如果对于拥有无限计算能力攻击者而言是安全的,则称作是信息论安全的或无条件安全的;如果对于拥有多项式计算能力的攻击者是安全的,则称为是密码学安全的或条件安全的。 15 | 16 | 已有的结果证明了在无条件安全模型下,当且仅当恶意参与者的人数少于总人数的1/3时,安全的方案才存在。而在条件安全模型下,当且仅当恶意参与者的人数少于总人数的一半时,安全的方案才存在。 17 | 18 | 安全多方计算起源于1982年[姚期智](https://baike.baidu.com/item/姚期智)的百万富翁问题。后来Oded Goldreich有比较细致系统的论述。 19 | 20 | 基于[phe](https://github.com/data61/python-paillier)库 (Paillier Homomorphic Encryption) 的安全多方计算协议实现,包含: 21 | 22 | - 安全乘法协议 23 | - 安全除法协议 24 | - 安全最大值计算协议 25 | - 安全最小值计算协议 26 | - 安全奇偶性判断协议 27 | - 安全二进制分解协议 28 | - 安全二进制与协议 29 | - 安全二进制或协议 30 | - 安全二进制非协议 31 | - 安全二进制异或协议 32 | - 安全相等协议 33 | - 安全不相等协议 34 | - 安全大于协议 35 | - 安全大于等于协议 36 | - 安全小于协议 37 | - 安全小于等于协议 38 | 39 | --- 40 | 41 | ## 项目环境 42 | 43 | - `python3.8` 44 | - `gmpy2>=2.0.8` 45 | - `pandas>=1.2.4` 46 | - `phe>=1.4.0` 47 | - `tqdm>=4.59.0` 48 | - `numpy>=1.20.2` 49 | 50 | 详见`requirements.txt`。 51 | 52 | --- 53 | 54 | ## 项目示例 55 | 56 | ### 准备工作 57 | 58 | 安全依赖环境: `pip install -r requirements.txt` 59 | 60 | 安装`smpcp`库: `pip install smpcp` 61 | 62 | 引入`phe`库: `import phe` 63 | 64 | 引入`smpcp`库: `from smpcp import CloudPlatform, CloudPlatformThird, SecureMultiPartyComputationProtocol` 65 | 66 | ### 生成密钥 67 | 68 | ```python 69 | public_key, secret_key = phe.generate_paillier_keypair(n_length=2048) 70 | ``` 71 | 72 | 其中`n_length`为密钥长度。 73 | 74 | ### 定义云服务器 75 | 76 | ```python 77 | cloud1 = CloudPlatform(public_key=public_key) 78 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key) 79 | ``` 80 | 81 | ### 定义安全多方计算协议 82 | 83 | ```python 84 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2) 85 | ``` 86 | 87 | ### 编码 88 | 89 | ```python 90 | n1 = protocol.encode(public_key.encrypt(6)) 91 | n2 = public_key.encrypt(3) 92 | b1 = protocol.encode(public_key.encrypt(1)) 93 | b2 = public_key.encrypt(0) 94 | ``` 95 | 96 | ### 解码 97 | 98 | ```python 99 | assert secret_key.decrypt(n1.decode()) == 6 100 | ``` 101 | 102 | ### 安全多方计算协议实现 103 | 104 | ```python 105 | # TODO 安全乘法协议 106 | assert secret_key.decrypt(n1 * n2) == 18 107 | # TODO 安全除法协议 108 | assert secret_key.decrypt(n1 / n2) == 2 109 | # TODO 安全最大值协议 110 | assert secret_key.decrypt(n1.optimum(n2, 'max')) == 6 111 | # TODO 安全最小值协议 112 | assert secret_key.decrypt(n1.optimum(n2, 'min')) == 3 113 | # TODO 安全奇偶性判断协议 114 | assert secret_key.decrypt(n1.parity()) == 0 115 | assert secret_key.decrypt(protocol.encode(n2).parity()) == 1 116 | # TODO 安全二进制分解协议 117 | bit = [] 118 | for v in n1.bit_dec(3): 119 | bit.append(secret_key.decrypt(v)) 120 | assert bit == [1, 1, 0] 121 | # TODO 安全二进制与协议 122 | assert secret_key.decrypt(b1 | b2) == 1 123 | # TODO 安全二进制或协议 124 | assert secret_key.decrypt(b1 & b2) == 0 125 | # TODO 安全二进制非协议 126 | assert secret_key.decrypt(b1.bit_not()) == 0 127 | # TODO 安全二进制异或协议 128 | assert secret_key.decrypt(b1 ^ b2) == 1 129 | # TODO 安全相等协议 130 | assert secret_key.decrypt(n1 == n2) == 0 131 | assert secret_key.decrypt(n1 == n2 * 2) == 1 132 | # TODO 安全不相等协议 133 | assert secret_key.decrypt(n1 != n2) == 1 134 | assert secret_key.decrypt(n1 != n2 * 2) == 0 135 | # TODO 安全大于协议 136 | assert secret_key.decrypt(n1 > n2) == 1 137 | assert secret_key.decrypt(n2 > n1) == 0 138 | # TODO 安全大于等于协议 139 | assert secret_key.decrypt(n1 >= n2) == 1 140 | assert secret_key.decrypt(n2 >= n1) == 0 141 | # TODO 安全小于协议 142 | assert secret_key.decrypt(n1 < n2) == 0 143 | assert secret_key.decrypt(n2 < n1) == 1 144 | # TODO 安全小于等于协议 145 | assert secret_key.decrypt(n1 <= n2) == 0 146 | assert secret_key.decrypt(n2 <= n1) == 1 147 | ``` 148 | 149 | 详见`example.py`。 150 | 151 | --- 152 | 153 | ## 项目测试 154 | 155 | 经过`Unit Test`测试,测试结果如下: 156 | 157 | ```python 158 | key_length = 2048 # TODO 密钥长度 159 | 160 | public_key, secret_key = phe.generate_paillier_keypair(n_length=key_length) # 生成密钥对 161 | 162 | cloud1 = CloudPlatform(public_key=public_key) # 云服务器1 163 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key) # 云服务器2 164 | 165 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2) # 安全多方计算协议类 166 | 167 | 168 | class SMPCPTest(unittest.TestCase): 169 | """ 170 | 安全多方计算协议测试类 171 | """ 172 | 173 | def setUp(self): 174 | """ 175 | 测试前 176 | """ 177 | # 生成浮点数 178 | self.float1 = int( 179 | gmpy2.mpz_random(gmpy2.random_state( 180 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), 181 | key_length))), key_length)) * random.uniform(0.1, 1.0) 182 | self.float2 = int( 183 | gmpy2.mpz_random(gmpy2.random_state( 184 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))), 185 | key_length)) * random.uniform(0.1, 1.0) 186 | self.float_n1 = protocol.encode(public_key.encrypt(self.float1)) 187 | self.float_n2 = public_key.encrypt(self.float2) 188 | # 生成整数 189 | self.int1 = int(gmpy2.mpz_random(gmpy2.random_state( 190 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))), 191 | key_length)) 192 | self.int2 = int(gmpy2.mpz_random(gmpy2.random_state( 193 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))), 194 | key_length)) 195 | self.int_n1 = protocol.encode(public_key.encrypt(self.int1)) 196 | self.int_n2 = public_key.encrypt(self.int2) 197 | return super().setUp() 198 | 199 | def tearDown(self): 200 | """ 201 | 测试后 202 | """ 203 | return super().tearDown() 204 | 205 | # TODO 安全乘法协议测试 206 | # @unittest.skip('跳过安全乘法协议') 207 | def test_mul(self): 208 | """ 209 | 安全乘法协议 210 | """ 211 | # 浮点乘法测试:经过测试,最高支持8位浮点乘法 212 | self.assertEqual(round(self.float1 * self.float2, 8), 213 | round(secret_key.decrypt(self.float_n1 * self.float_n2), 8)) 214 | 215 | # 整数乘法测试:经过测试,无明显问题 216 | self.assertEqual(self.int1 * self.int2, secret_key.decrypt(self.int_n1 * self.int_n2)) 217 | 218 | # TODO 安全除法协议测试 219 | # @unittest.skip('跳过安全除法协议') 220 | def test_div(self): 221 | """ 222 | 安全除法协议 223 | """ 224 | # 浮点除法测试:经过测试,最高支持10位浮点除法 225 | self.assertEqual(round(self.float1 / self.float2, 10), 226 | round(secret_key.decrypt(self.float_n1 / self.float_n2), 10)) 227 | 228 | # 整数除法测试:经过测试,最高支持10位整数除法 229 | self.assertEqual(round(self.int1 / self.int2, 10), round(secret_key.decrypt(self.int_n1 / self.int_n2), 10)) 230 | 231 | # TODO 安全最值计算协议测试 232 | # @unittest.skip('跳过安全最值计算协议') 233 | def test_optimum(self): 234 | """ 235 | 安全最值计算协议 236 | """ 237 | mode = 'max' if random.random() > 0.5 else 'min' 238 | if mode == 'max': 239 | # 浮点最大值计算测试:经过测试,无明显问题 240 | self.assertEqual(max(self.float1, self.float2), 241 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'max'))) 242 | 243 | # 整数最大值计算测试:经过测试,无明显问题 244 | self.assertEqual(max(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'max'))) 245 | else: 246 | # 浮点最小值计算测试:经过测试,无明显问题 247 | self.assertEqual(min(self.float1, self.float2), 248 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'min'))) 249 | 250 | # 整数最小值计算测试:经过测试,无明显问题 251 | self.assertEqual(min(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'min'))) 252 | 253 | # TODO 安全奇偶性判断协议测试 254 | # @unittest.skip('跳过安全奇偶性判断协议') 255 | def test_parity(self): 256 | """ 257 | 安全奇偶性判断协议 258 | """ 259 | # 整数奇偶性判断测试:经过测试,无明显问题 260 | self.assertEqual(self.int1 % 2, secret_key.decrypt(self.int_n1.parity())) 261 | 262 | # TODO 安全二进制分解协议测试 263 | # @unittest.skip('跳过安全二进制分解协议') 264 | def test_bit_dec(self): 265 | """ 266 | 安全二进制分解协议 267 | """ 268 | # 整数二进制分解测试:经过测试,无明显问题 269 | bit = len(bin(self.int1).split('b')[1]) 270 | result = ''.join([str(secret_key.decrypt(v)) for v in self.int_n1.bit_dec(bit)]) 271 | self.assertEqual(bin(self.int1).split('b')[1], result) 272 | 273 | # TODO 安全二进制与协议测试 274 | # @unittest.skip('跳过安全二进制与协议') 275 | def test_and(self): 276 | """ 277 | 安全二进制与协议 278 | """ 279 | bit1 = random.SystemRandom().randint(0, 1) 280 | bit2 = random.SystemRandom().randint(0, 1) 281 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 282 | bit_n2 = public_key.encrypt(bit2) 283 | # 二进制或测试:经过测试,无明显问题 284 | self.assertEqual(bit1 & bit2, secret_key.decrypt(bit_n1 & bit_n2)) 285 | 286 | # TODO 安全二进制或协议测试 287 | # @unittest.skip('跳过安全二进制或协议') 288 | def test_or(self): 289 | """ 290 | 安全二进制或协议 291 | """ 292 | bit1 = random.SystemRandom().randint(0, 1) 293 | bit2 = random.SystemRandom().randint(0, 1) 294 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 295 | bit_n2 = public_key.encrypt(bit2) 296 | # 二进制或测试:经过测试,无明显问题 297 | self.assertEqual(bit1 | bit2, secret_key.decrypt(bit_n1 | bit_n2)) 298 | 299 | # TODO 安全二进制非协议测试 300 | # @unittest.skip('跳过安全二进制非协议') 301 | def test_bit_not(self): 302 | """ 303 | 安全二进制非协议 304 | """ 305 | bit1 = random.SystemRandom().randint(0, 1) 306 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 307 | # 二进制或测试:经过测试,无明显问题 308 | self.assertEqual(1 - bit1, secret_key.decrypt(bit_n1.bit_not())) 309 | 310 | # TODO 安全二进制异或协议测试 311 | # @unittest.skip('跳过安全二进制异或协议') 312 | def test_xor(self): 313 | """ 314 | 安全二进制异或协议 315 | """ 316 | bit1 = random.SystemRandom().randint(0, 1) 317 | bit2 = random.SystemRandom().randint(0, 1) 318 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 319 | bit_n2 = public_key.encrypt(bit2) 320 | # 二进制或测试:经过测试,无明显问题 321 | self.assertEqual(bit1 ^ bit2, secret_key.decrypt(bit_n1 ^ bit_n2)) 322 | 323 | # TODO 安全相等协议测试 324 | # @unittest.skip('跳过安全相等协议') 325 | def test_eq(self): 326 | """ 327 | 安全相等协议 328 | """ 329 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 330 | self.assertEqual(1 if self.float1 == self.float1 else 0, 331 | secret_key.decrypt(self.float_n1 == self.float_n1.decode())) 332 | 333 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 334 | self.assertEqual(1 if self.int1 == self.int1 else 0, secret_key.decrypt(self.int_n1 == self.int_n1.decode())) 335 | 336 | # TODO 安全不相等协议测试 337 | # @unittest.skip('跳过安全不相等协议') 338 | def test_ne(self): 339 | """ 340 | 安全不相等协议 341 | """ 342 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 343 | self.assertEqual(1 if self.float1 != self.float2 else 0, secret_key.decrypt(self.float_n1 != self.float_n2)) 344 | 345 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 346 | self.assertEqual(1 if self.int1 != self.int2 else 0, secret_key.decrypt(self.int_n1 != self.int_n2)) 347 | 348 | # TODO 安全大于协议测试 349 | # @unittest.skip('跳过安全大于协议') 350 | def test_gt(self): 351 | """ 352 | 安全大于协议 353 | """ 354 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 355 | self.assertEqual(1 if self.float1 > self.float2 else 0, secret_key.decrypt(self.float_n1 > self.float_n2)) 356 | 357 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 358 | self.assertEqual(1 if self.int1 > self.int2 else 0, secret_key.decrypt(self.int_n1 > self.int_n2)) 359 | 360 | # TODO 安全大于等于协议测试 361 | # @unittest.skip('跳过安全大于等于协议') 362 | def test_ge(self): 363 | """ 364 | 安全大于等于协议 365 | """ 366 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 367 | self.assertEqual(1 if self.float1 >= self.float2 else 0, secret_key.decrypt(self.float_n1 >= self.float_n2)) 368 | 369 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 370 | self.assertEqual(1 if self.int1 >= self.int2 else 0, secret_key.decrypt(self.int_n1 >= self.int_n2)) 371 | 372 | # TODO 安全小于协议测试 373 | # @unittest.skip('跳过安全小于协议') 374 | def test_lt(self): 375 | """ 376 | 安全小于协议 377 | """ 378 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 379 | self.assertEqual(1 if self.float1 < self.float2 else 0, secret_key.decrypt(self.float_n1 < self.float_n2)) 380 | 381 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 382 | self.assertEqual(1 if self.int1 < self.int2 else 0, secret_key.decrypt(self.int_n1 < self.int_n2)) 383 | 384 | # TODO 安全小于等于协议测试 385 | # @unittest.skip('跳过安全小于等于协议') 386 | def test_le(self): 387 | """ 388 | 安全小于等于协议 389 | """ 390 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 391 | self.assertEqual(1 if self.float1 <= self.float2 else 0, secret_key.decrypt(self.float_n1 <= self.float_n2)) 392 | 393 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 394 | self.assertEqual(1 if self.int1 <= self.int2 else 0, secret_key.decrypt(self.int_n1 <= self.int_n2)) 395 | ``` 396 | 397 | 详见`test_case/test_smpcp.py`, 项目报告依赖基于`unittest`的[项目](https://github.com/TesterlifeRaymond/BeautifulReport)`test_case/BeautifulReport.py`。 398 | 399 | --- 400 | 401 | ## 联系方式 402 | 403 | 作者:沈阳航空航天大学 数据安全与隐私计算课题组 施展 404 | 405 | Github: https://github.com/shine813/ 406 | 407 | Pypi: https://pypi.org/project/smpcp/ 408 | 409 | 邮箱:phe.zshi@gmail.com 410 | 411 | 如有问题,可及时联系作者 412 | -------------------------------------------------------------------------------- /__about__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | __all__ = [ 4 | "__title__", "__summary__", "__uri__", 5 | "__version__", "__author__", "__email__" 6 | ] 7 | 8 | __title__ = "smpcp" 9 | __summary__ = "Secure Multi-Party Computation Protocol base on Partially Homomorphic Encryption for Python" 10 | __uri__ = "https://github.com/shine813/Secure-Multi-Party-Computation-Protocol/" 11 | 12 | __version__ = "2.0.2" 13 | 14 | __author__ = "Zhan Shi" 15 | __email__ = "phe.zshi@gmail.com" 16 | 17 | __license__ = "MIT" 18 | __copyright__ = "Copyright (c) 2021 {0}".format(__author__) 19 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/5 12:17 8 | @File: __init__.py 9 | @License: MIT 10 | """ 11 | 12 | __name__ = 'Secure-Multi-Party-Computation-Protocol' 13 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/1 09:35 8 | @File: example.py 9 | @License: MIT 10 | """ 11 | import phe 12 | 13 | from smpcp.smpcp import CloudPlatform, CloudPlatformThird, SecureMultiPartyComputationProtocol 14 | 15 | # TODO 生成密钥 16 | public_key, secret_key = phe.generate_paillier_keypair(n_length=128) 17 | # TODO 定义云服务器 18 | cloud1 = CloudPlatform(public_key=public_key) 19 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key) 20 | # TODO 定义安全多方计算协议 21 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2) 22 | 23 | if __name__ == '__main__': 24 | # TODO 安全多方计算协议编码 25 | n1 = protocol.encode(public_key.encrypt(6)) 26 | n2 = public_key.encrypt(3) 27 | b1 = protocol.encode(public_key.encrypt(1)) 28 | b2 = public_key.encrypt(0) 29 | # TODO 协议解码 30 | assert secret_key.decrypt(n1.decode()) == 6 31 | # TODO 安全乘法协议 32 | assert secret_key.decrypt(n1 * n2) == 18 33 | # TODO 安全除法协议 34 | assert secret_key.decrypt(n1 / n2) == 2 35 | # TODO 安全最大值协议 36 | assert secret_key.decrypt(n1.optimum(n2, 'max')) == 6 37 | # TODO 安全最小值协议 38 | assert secret_key.decrypt(n1.optimum(n2, 'min')) == 3 39 | # TODO 安全奇偶性判断协议 40 | assert secret_key.decrypt(n1.parity()) == 0 41 | assert secret_key.decrypt(protocol.encode(n2).parity()) == 1 42 | # TODO 安全二进制分解协议 43 | bit = [] 44 | for v in n1.bit_dec(3): 45 | bit.append(secret_key.decrypt(v)) 46 | assert bit == [1, 1, 0] 47 | # TODO 安全二进制与协议 48 | assert secret_key.decrypt(b1 & b2) == 1 49 | # TODO 安全二进制或协议 50 | assert secret_key.decrypt(b1 | b2) == 0 51 | # TODO 安全二进制非协议 52 | assert secret_key.decrypt(b1.bit_not()) == 0 53 | # TODO 安全二进制异或协议 54 | assert secret_key.decrypt(b1 ^ b2) == 1 55 | # TODO 安全相等协议 56 | assert secret_key.decrypt(n1 == n2) == 0 57 | assert secret_key.decrypt(n1 == n2 * 2) == 1 58 | # TODO 安全不相等协议 59 | assert secret_key.decrypt(n1 != n2) == 1 60 | assert secret_key.decrypt(n1 != n2 * 2) == 0 61 | # TODO 安全大于协议 62 | assert secret_key.decrypt(n1 > n2) == 1 63 | assert secret_key.decrypt(n2 > n1) == 0 64 | # TODO 安全大于等于协议 65 | assert secret_key.decrypt(n1 >= n2) == 1 66 | assert secret_key.decrypt(n2 >= n1) == 0 67 | # TODO 安全小于协议 68 | assert secret_key.decrypt(n1 < n2) == 0 69 | assert secret_key.decrypt(n2 < n1) == 1 70 | # TODO 安全小于等于协议 71 | assert secret_key.decrypt(n1 <= n2) == 0 72 | assert secret_key.decrypt(n2 <= n1) == 1 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gmpy2==2.0.8 2 | pandas==1.2.4 3 | phe==1.4.0 4 | tqdm==4.59.0 5 | numpy==1.20.2 6 | -------------------------------------------------------------------------------- /smpcp/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/6 13:26 8 | @File: __init__.py 9 | @License: MIT 10 | """ 11 | 12 | __name__ = 'smpcp' 13 | -------------------------------------------------------------------------------- /smpcp/smpcp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/3 12:23 8 | @File: smpcp.py 9 | @License: MIT 10 | """ 11 | import random 12 | 13 | import gmpy2 14 | 15 | 16 | class SecureMultiPartyComputationProtocol: 17 | """ 18 | 安全多方计算协议类 19 | """ 20 | 21 | def __init__(self, c1, c2, cipher=None): 22 | """ 23 | 安全多方计算协议类 定义 24 | :param c1: 云服务器 25 | :param c2: 第三方云服务器 26 | :param cipher: 密文 27 | """ 28 | self.c1 = c1 29 | self.c2 = c2 30 | self.cipher = cipher 31 | 32 | def encode(self, encrypted_number): 33 | """ 34 | 安全多方计算协议类 编码 35 | :param encrypted_number: 加密数字 36 | :return: 编码后的加密数字 37 | """ 38 | return SecureMultiPartyComputationProtocol(c1=self.c1, c2=self.c2, cipher=encrypted_number) 39 | 40 | def decode(self): 41 | """ 42 | 安全多方计算协议类 解码 43 | :return: 解码后的加密数字 44 | """ 45 | return self.cipher 46 | 47 | def __mul__(self, other): 48 | """ 49 | TODO 安全多方计算协议类 安全乘法协议 50 | :param other: 密文 51 | :return: 安全乘法协议结果 E(self.cipher * other) 52 | """ 53 | return self.c1.mul(self.cipher, other, self.c2) 54 | 55 | def __truediv__(self, other): 56 | """ 57 | TODO 安全多方计算协议类 安全除法协议 58 | :param other: 密文 59 | :return: 安全除法协议结果 E(self.cipher / other) 60 | """ 61 | return self.c1.truediv(self.cipher, other, self.c2) 62 | 63 | def optimum(self, other, mode): 64 | """ 65 | TODO 安全多方计算协议类 安全最值计算协议 66 | :param other: 密文 67 | :param mode: 'max' or 'min' 68 | :return: 安全最值计算协议结果 E(max(self.cipher, other)) or E(min(self.cipher, other)) 69 | """ 70 | return self.c1.optimum(self.cipher, other, self.c2, mode) 71 | 72 | def parity(self): 73 | """ 74 | TODO 安全多方计算协议类 安全奇偶性判断协议 75 | :return: 安全奇偶性判断协议结果 奇数: E(1) 偶数: E(0) 76 | """ 77 | return self.c1.parity(self.cipher, self.c2) 78 | 79 | def bit_dec(self, bit): 80 | """ 81 | TODO 安全多方计算协议类 安全二进制分解协议 82 | :param bit: 位数 83 | :return: 安全二进制分解协议结果 self.cipher的二进制数列 -> [E(1) or E(0), ...] 长度为bit 84 | """ 85 | return self.c1.bit_dec(self.cipher, bit, self.c2) 86 | 87 | def __and__(self, other): 88 | """ 89 | TODO 安全多方计算协议类 安全二进制与协议 90 | ! 只能用于二进制数 91 | :param other: 密文 92 | :return: 安全二进制与协议结果 E(self.cipher & other) 93 | """ 94 | return self.c1.bit_and(self.cipher, other, self.c2) 95 | 96 | def __or__(self, other): 97 | """ 98 | TODO 安全多方计算协议类 安全二进制或协议 99 | ! 只能用于二进制数 100 | :param other: 密文 101 | :return: 安全二进制或协议结果 E(self.cipher | other) 102 | """ 103 | return self.c1.bit_or(self.cipher, other, self.c2) 104 | 105 | def bit_not(self): 106 | """ 107 | TODO 安全多方计算协议类 安全二进制非协议 108 | ! 只能用于二进制数 109 | :return: 安全二进制非协议结果 E(!self.cipher) 110 | """ 111 | return self.c1.bit_not(self.cipher) 112 | 113 | def __xor__(self, other): 114 | """ 115 | TODO 安全多方计算协议类 安全二进制异或协议 116 | ! 只能用于二进制数 117 | :return: 安全二进制异或协议结果 E(self.cipher ^ other) 118 | """ 119 | return self.c1.bit_xor(self.cipher, other, self.c2) 120 | 121 | def __eq__(self, other): 122 | """ 123 | TODO 安全多方计算协议类 安全相等协议 124 | :param other: 密文 125 | :return: 安全相等协议结果 E(self.cipher == other) 126 | """ 127 | return self.c1.eq(self.cipher, other, self.c2) 128 | 129 | def __ne__(self, other): 130 | """ 131 | TODO 安全多方计算协议类 安全不相等协议 132 | :param other: 密文 133 | :return: 安全不相等协议结果 E(self.cipher != other) 134 | """ 135 | return self.c1.ne(self.cipher, other, self.c2) 136 | 137 | def __gt__(self, other): 138 | """ 139 | TODO 安全多方计算协议类 安全大于协议 140 | :param other: 密文 141 | :return: 安全大于协议结果 E(self.cipher > other) 142 | """ 143 | return self.c1.gt(self.cipher, other, self.c2) 144 | 145 | def __ge__(self, other): 146 | """ 147 | TODO 安全多方计算协议类 安全大于等于协议 148 | :param other: 密文 149 | :return: 安全大于等于协议结果 E(self.cipher >= other) 150 | """ 151 | return self.c1.ge(self.cipher, other, self.c2) 152 | 153 | def __lt__(self, other): 154 | """ 155 | TODO 安全多方计算协议类 安全小于协议 156 | :param other: 密文 157 | :return: 安全小于协议结果 E(self.cipher < other) 158 | """ 159 | return self.c1.lt(self.cipher, other, self.c2) 160 | 161 | def __le__(self, other): 162 | """ 163 | TODO 安全多方计算协议类 安全小于等于协议 164 | :param other: 密文 165 | :return: 安全小于等于协议结果 E(self.cipher <= other) 166 | """ 167 | return self.c1.le(self.cipher, other, self.c2) 168 | 169 | 170 | class CloudPlatform: 171 | """ 172 | 云服务器类 173 | """ 174 | 175 | def __init__(self, public_key): 176 | """ 177 | 云服务器类 定义 178 | :param public_key: 公钥 179 | """ 180 | self.public_key = public_key 181 | self.key_length = len(str(self.public_key.n)) 182 | 183 | def mul(self, c1, c2, cloud_platform_third): 184 | """ 185 | TODO 云服务器类 安全乘法协议 186 | :param c1: 密文1 187 | :param c2: 密文2 188 | :param cloud_platform_third: 第三方云服务器 189 | :return: 加密乘法结果 190 | """ 191 | r1 = self._generate_random() 192 | r2 = self._generate_random() 193 | 194 | h1 = c1 + r1 195 | h2 = c2 + r2 196 | 197 | return cloud_platform_third.mul(h1, h2) - (c1 * r2 + c2 * r1 + r1 * r2) 198 | 199 | def truediv(self, c1, c2, cloud_platform_third): 200 | """ 201 | TODO 云服务器类 安全除法协议 202 | :param c1: 密文1 203 | :param c2: 密文2 204 | :param cloud_platform_third: 第三方云服务器 205 | :return: 加密除法结果 206 | """ 207 | r1 = self._generate_random() 208 | r2 = self._generate_random() 209 | 210 | h1 = c1 * r1 + c2 * r1 * r2 211 | h2 = c2 * r1 212 | 213 | return cloud_platform_third.truediv(h1, h2) - r2 214 | 215 | def optimum(self, c1, c2, cloud_platform_third, mode): 216 | """ 217 | TODO 云服务器类 安全最值计算协议 218 | :param c1: 密文1 219 | :param c2: 密文2 220 | :param cloud_platform_third: 第三方云服务器 221 | :param mode: 'max' or 'min' 222 | :return: 加密最值计算结果 223 | """ 224 | r1 = self._generate_random() 225 | r2 = self._generate_random() 226 | r3 = self._generate_random() 227 | 228 | if random.random() > 5e-1: 229 | h1 = (c1 - c2) * r1 230 | h2 = c1 + r2 231 | h3 = c2 + r3 232 | else: 233 | h1 = (c2 - c1) * r1 234 | h2 = c2 + r2 235 | h3 = c1 + r3 236 | 237 | alpha, beta = cloud_platform_third.optimum(h1, h2, h3, mode) 238 | 239 | return c1 + c2 - beta + alpha * r3 + (1 - alpha) * r2 240 | 241 | def parity(self, c, cloud_platform_third): 242 | """ 243 | TODO 云服务器类 安全奇偶性判断协议 244 | :param c: 密文 245 | :param cloud_platform_third: 第三方云服务器 246 | :return: 加密奇偶性判断结果 247 | """ 248 | r = self._generate_random() 249 | h = c + r 250 | alpha = cloud_platform_third.parity(h) 251 | 252 | return alpha if r % 2 == 0 else 1 - alpha 253 | 254 | def bit_dec(self, c, bit, cloud_platform_third): 255 | """ 256 | TODO 云服务器类 安全二进制分解协议 257 | :param c: 密文 258 | :param bit: 位数 259 | :param cloud_platform_third: 第三方云服务器 260 | :return: 加密二进制分解结果 261 | """ 262 | sigma = 5e-1 263 | result = [] 264 | for i in range(bit): 265 | result.append(self.parity(c, cloud_platform_third)) 266 | zeta = c - result[i] 267 | c = zeta * sigma 268 | result.reverse() 269 | 270 | return result 271 | 272 | def bit_and(self, c1, c2, cloud_platform_third): 273 | """ 274 | TODO 云服务器类 安全二进制与协议 275 | :param c1: 密文1 276 | :param c2: 密文2 277 | :param cloud_platform_third: 第三方云服务器 278 | :return: 加密二进制与结果 279 | """ 280 | return self.mul(c1, c2, cloud_platform_third) 281 | 282 | def bit_or(self, c1, c2, cloud_platform_third): 283 | """ 284 | TODO 云服务器类 安全二进制或协议 285 | :param c1: 密文1 286 | :param c2: 密文2 287 | :param cloud_platform_third: 第三方云服务器 288 | :return: 加密二进制或结果 289 | """ 290 | return c1 + c2 - self.bit_and(c1, c2, cloud_platform_third) 291 | 292 | @staticmethod 293 | def bit_not(c): 294 | """ 295 | TODO 云服务器类 安全二进制非协议 296 | :param c: 密文 297 | :return: 加密二进制非结果 298 | """ 299 | return 1 - c 300 | 301 | def bit_xor(self, c1, c2, cloud_platform_third): 302 | """ 303 | TODO 云服务器类 安全二进制异或协议 304 | :param c1: 密文1 305 | :param c2: 密文2 306 | :param cloud_platform_third: 第三方云服务器 307 | :return: 加密二进制异或结果 308 | """ 309 | return c1 + c2 - 2 * self.mul(c1, c2, cloud_platform_third) 310 | 311 | def eq(self, c1, c2, cloud_platform_third): 312 | """ 313 | TODO 云服务器类 安全相等协议 314 | :param c1: 密文1 315 | :param c2: 密文2 316 | :param cloud_platform_third: 第三方云服务器 317 | :return: 加密相等结果 318 | """ 319 | sigma = -1 if random.random() > 5e-1 else 1 320 | r1 = self._generate_random() 321 | r2 = self._generate_random() 322 | if r2 > r1: 323 | r2, r1 = r1, r2 324 | alpha = r1 * sigma * self.mul(c1 - c2, c1 - c2, cloud_platform_third) - sigma * r2 325 | 326 | return cloud_platform_third.eq(alpha) if sigma == 1 else 1 - cloud_platform_third.eq(alpha) 327 | 328 | def ne(self, c1, c2, cloud_platform_third): 329 | """ 330 | TODO 云服务器类 安全不相等协议 331 | :param c1: 密文1 332 | :param c2: 密文2 333 | :param cloud_platform_third: 第三方云服务器 334 | :return: 加密不相等结果 335 | """ 336 | return 1 - self.eq(c1, c2, cloud_platform_third) 337 | 338 | def gt(self, c1, c2, cloud_platform_third): 339 | """ 340 | TODO 云服务器类 安全大于协议 341 | :param c1: 密文1 342 | :param c2: 密文2 343 | :param cloud_platform_third: 第三方云服务器 344 | :return: 加密大于结果 345 | """ 346 | sigma = -1 if random.random() > 5e-1 else 1 347 | r1 = self._generate_random() 348 | r2 = self._generate_random() 349 | (r2, r1) = (r1, r2) if r2 > r1 else (r2, r1) 350 | alpha = r1 * sigma * (c2 - c1) + sigma * r2 351 | 352 | return cloud_platform_third.eq(alpha) if sigma == 1 else 1 - cloud_platform_third.eq(alpha) 353 | 354 | def ge(self, c1, c2, cloud_platform_third): 355 | """ 356 | TODO 云服务器类 安全大于等于协议 357 | :param c1: 密文1 358 | :param c2: 密文2 359 | :param cloud_platform_third: 第三方云服务器 360 | :return: 加密大于等于结果 361 | """ 362 | return self.bit_or(self.eq(c1, c2, cloud_platform_third), self.gt(c1, c2, cloud_platform_third), 363 | cloud_platform_third) 364 | 365 | def lt(self, c1, c2, cloud_platform_third): 366 | """ 367 | TODO 云服务器类 安全小于协议 368 | :param c1: 密文1 369 | :param c2: 密文2 370 | :param cloud_platform_third: 第三方云服务器 371 | :return: 加密小于结果 372 | """ 373 | sigma = -1 if random.random() > 5e-1 else 1 374 | r1 = self._generate_random() 375 | r2 = self._generate_random() 376 | (r2, r1) = (r1, r2) if r2 > r1 else (r2, r1) 377 | alpha = r1 * sigma * (c1 - c2) + sigma * r2 378 | 379 | return cloud_platform_third.eq(alpha) if sigma == 1 else 1 - cloud_platform_third.eq(alpha) 380 | 381 | def le(self, c1, c2, cloud_platform_third): 382 | """ 383 | TODO 云服务器类 安全小于等于协议 384 | :param c1: 密文1 385 | :param c2: 密文2 386 | :param cloud_platform_third: 第三方云服务器 387 | :return: 加密小于等于结果 388 | """ 389 | return self.bit_or(self.eq(c1, c2, cloud_platform_third), self.lt(c1, c2, cloud_platform_third), 390 | cloud_platform_third) 391 | 392 | def _generate_random(self): 393 | """ 394 | 云服务器 随机数生成 395 | :return: 密钥长度的随机数 396 | """ 397 | return int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), self.key_length)) 398 | 399 | 400 | class CloudPlatformThird: 401 | """ 402 | 第三方云服务器类 403 | """ 404 | 405 | def __init__(self, public_key, secret_key): 406 | """ 407 | 第三方云服务器类 定义 408 | :param public_key: 公钥 409 | :param secret_key: 私钥 410 | """ 411 | self.public_key = public_key 412 | self.secret_key = secret_key 413 | 414 | def mul(self, h1, h2): 415 | """ 416 | TODO 第三方云服务器类 安全乘法协议 417 | :param h1: 参数1 418 | :param h2: 参数2 419 | :return: 安全乘法协议结果 420 | """ 421 | return self.public_key.encrypt(self.secret_key.decrypt(h1) * self.secret_key.decrypt(h2)) 422 | 423 | def truediv(self, h1, h2): 424 | """ 425 | TODO 第三方云服务器类 安全除法协议 426 | :param h1: 参数1 427 | :param h2: 参数2 428 | :return: 安全除法协议结果 429 | """ 430 | h2 = self.secret_key.decrypt(h2) 431 | if h2 != 0: 432 | return self.public_key.encrypt(self.secret_key.decrypt(h1) / h2) 433 | else: 434 | assert ValueError("Divisor cannot be 0") 435 | 436 | def optimum(self, h1, h2, h3, mode): 437 | """ 438 | TODO 第三方云服务器类 安全最值计算协议 439 | :param h1: 参数 440 | :param h2: 参数 441 | :param h3: 参数 442 | :param mode: 'max' or 'min' 443 | :return: 安全最值计算协议结果 444 | """ 445 | mode = self.secret_key.decrypt(h1) > 0 if mode == 'max' else self.secret_key.decrypt(h1) < 0 446 | alpha = 1 if mode else 0 447 | 448 | return self.public_key.encrypt(alpha), h3 if alpha == 1 else h2 449 | 450 | def parity(self, h): 451 | """ 452 | TODO 第三方云服务器类 安全奇偶性判断协议 453 | :param h: 参数 454 | :return: 安全奇偶性判断协议结果 455 | """ 456 | return self.public_key.encrypt(0) if self.secret_key.decrypt(h) % 2 == 0 else self.public_key.encrypt(1) 457 | 458 | def eq(self, h): 459 | """ 460 | TODO 第三方云服务器类 安全相等协议 461 | :param h: 参数 462 | :return: 安全相等协议结果 463 | """ 464 | return self.public_key.encrypt(1) if self.secret_key.decrypt(h) < 0 else self.public_key.encrypt(0) 465 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/4 17:42 8 | @File: test.py 9 | @License: MIT 10 | """ 11 | import json 12 | import os 13 | import sys 14 | import time 15 | import unittest 16 | from multiprocessing import Pool, cpu_count 17 | 18 | import pandas as pd 19 | import tqdm 20 | 21 | sys.path.append("test_case/") # 添加测试文件路径 22 | from BeautifulReport import BeautifulReport 23 | 24 | TEMPLATE_PATH = 'test_case/template/template' # 模板文件路径 25 | REPORT_PATH = 'test_case/report' # 报告文件路径 26 | REPORT_FILE = 'HTMLReport.html' # 报告文件 27 | DESCRIPTION = '安全多方计算协议测试' # 报告名称 28 | 29 | 30 | def format_report(_report): 31 | """ 32 | 规范报告格式 33 | :param _report: 测试报告数据 34 | :return: 规范后的测试报告数据 35 | """ 36 | _output = {'testPass': 0, 'testResult': [], 'testFail': 0, 'testSkip': 0, 'testError': 0} 37 | 38 | for v in _report: 39 | _output['testPass'] += v.get()['testPass'] 40 | for m in v.get()['testResult']: 41 | _output.get('testResult').append(m) 42 | _output['testAll'] = len(v.get()['testResult']) 43 | _output['testName'] = v.get()['testName'] 44 | _output['testFail'] += v.get()['testFail'] 45 | _output['beginTime'] = v.get()['beginTime'] 46 | _output['totalTime'] = v.get()['totalTime'] 47 | _output['testError'] += v.get()['testError'] 48 | _output['testSkip'] += v.get()['testSkip'] 49 | 50 | return _output 51 | 52 | 53 | def output_report(_report): 54 | """ 55 | 输出测试报告 56 | :param _report: 测试报告数据 57 | """ 58 | pd.DataFrame(_report).to_csv("{0}/CSVReport.csv".format(REPORT_PATH), header=False, index=False) 59 | 60 | template_path = TEMPLATE_PATH 61 | override_path = os.path.abspath(REPORT_PATH) \ 62 | if os.path.abspath(REPORT_PATH).endswith('/') \ 63 | else os.path.abspath(REPORT_PATH) + '/' 64 | 65 | with open(template_path, 'rb') as file: 66 | body = file.readlines() 67 | with open(override_path + REPORT_FILE, 'wb') as write_file: 68 | for item in body: 69 | if item.strip().startswith(b'var resultData'): 70 | head = ' var resultData = ' 71 | item = item.decode().split(head) 72 | item[1] = head + json.dumps( 73 | _report, ensure_ascii=False, indent=4) 74 | item = ''.join(item).encode() 75 | item = bytes(item) + b';\n' 76 | write_file.write(item) 77 | 78 | 79 | def run(): 80 | """ 81 | 开始测试 82 | :return: 测试报告数据 83 | """ 84 | # 构造测试用例 85 | cases = unittest.defaultTestLoader.discover("test_case", pattern="test_smpcp.py", top_level_dir=None) 86 | # 读取测试用例 运行测试 87 | return BeautifulReport(cases).report(filename=REPORT_FILE, log_path=REPORT_PATH, description=DESCRIPTION) 88 | 89 | 90 | if __name__ == '__main__': 91 | start = time.time() # 开始时间 92 | times = 100 # TODO 测试次数 93 | process_pool = Pool(cpu_count()) # 开启进程池 94 | # 进度条 95 | process_bar = tqdm.tqdm(iterable=range(times), ncols=80, nrows=20, desc=DESCRIPTION) 96 | # 多进程测试 97 | report = [process_pool.apply_async(run, (), callback=lambda _: process_bar.update()) for _ in range(times)] 98 | # 进程开始 99 | process_pool.close() 100 | process_pool.join() 101 | # 测试报告 102 | output = format_report(report) 103 | output_report(output) 104 | -------------------------------------------------------------------------------- /test_case/BeautifulReport.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Version: 1.0 3 | @Project: BeautyReport 4 | @Author: Raymond 5 | @Data: 2017/11/15 下午5:28 6 | @File: BeautifulReport.py 7 | @License: MIT 8 | """ 9 | 10 | import base64 11 | import json 12 | import os 13 | import platform 14 | import sys 15 | import time 16 | import traceback 17 | import unittest 18 | from distutils.sysconfig import get_python_lib 19 | from functools import wraps 20 | from io import StringIO as StringIO 21 | 22 | __all__ = ['BeautifulReport'] 23 | 24 | HTML_IMG_TEMPLATE = """ 25 | 26 | 27 | 28 |

29 | """ 30 | 31 | 32 | class OutputRedirector(object): 33 | """ 34 | Wrapper to redirect stdout or stderr 35 | """ 36 | 37 | def __init__(self, fp): 38 | self.fp = fp 39 | 40 | def write(self, s): 41 | self.fp.write(s) 42 | 43 | def writelines(self, lines): 44 | self.fp.writelines(lines) 45 | 46 | def flush(self): 47 | self.fp.flush() 48 | 49 | 50 | stdout_redirector = OutputRedirector(sys.stdout) 51 | stderr_redirector = OutputRedirector(sys.stderr) 52 | 53 | SYSSTR = platform.system() 54 | SITE_PAKAGE_PATH = get_python_lib() 55 | 56 | FIELDS = { 57 | "testPass": 0, 58 | "testResult": [ 59 | ], 60 | "testName": "", 61 | "testAll": 0, 62 | "testFail": 0, 63 | "beginTime": "", 64 | "totalTime": "", 65 | "testSkip": 0 66 | } 67 | 68 | 69 | class PATH: 70 | """ 71 | all file PATH meta 72 | """ 73 | config_tmp_path = 'test_case/template/template' 74 | 75 | 76 | class MakeResultJson: 77 | """ 78 | make html table tags 79 | """ 80 | 81 | def __init__(self, datas: tuple): 82 | """ 83 | init self object 84 | :param datas: 拿到所有返回数据结构 85 | """ 86 | self.datas = datas 87 | self.result_schema = {} 88 | 89 | def __setitem__(self, key, value): 90 | """ 91 | 92 | :param key: self[key] 93 | :param value: value 94 | :return: 95 | """ 96 | self[key] = value 97 | 98 | def __repr__(self) -> str: 99 | """ 100 | 返回对象的html结构体 101 | :rtype: dict 102 | :return: self的repr对象, 返回一个构造完成的tr表单 103 | """ 104 | keys = ( 105 | 'className', 106 | 'methodName', 107 | 'description', 108 | 'spendTime', 109 | 'status', 110 | 'log', 111 | ) 112 | for key, data in zip(keys, self.datas): 113 | self.result_schema.setdefault(key, data) 114 | return json.dumps(self.result_schema) 115 | 116 | 117 | class ReportTestResult(unittest.TestResult): 118 | """ 119 | override 120 | """ 121 | 122 | def __init__(self, suite, stream=sys.stdout): 123 | """ 124 | pass 125 | """ 126 | super(ReportTestResult, self).__init__() 127 | self.begin_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 128 | self.start_time = 0 129 | self.stream = stream 130 | self.end_time = 0 131 | self.failure_count = 0 132 | self.error_count = 0 133 | self.success_count = 0 134 | self.skipped = 0 135 | self.verbosity = 1 136 | self.success_case_info = [] 137 | self.skipped_case_info = [] 138 | self.failures_case_info = [] 139 | self.errors_case_info = [] 140 | self.all_case_counter = 0 141 | self.suite = suite 142 | self.status = '' 143 | self.result_list = [] 144 | self.case_log = '' 145 | self.default_report_name = '自动化测试报告' 146 | self.FIELDS = None 147 | self.sys_stdout = None 148 | self.sys_stderr = None 149 | self.outputBuffer = None 150 | 151 | @property 152 | def success_counter(self) -> int: 153 | """ 154 | set success counter 155 | """ 156 | return self.success_count 157 | 158 | @success_counter.setter 159 | def success_counter(self, value) -> None: 160 | """ 161 | success_counter函数的setter方法, 用于改变成功的case数量 162 | :param value: 当前传递进来的成功次数的int数值 163 | :return: 164 | """ 165 | self.success_count = value 166 | 167 | def startTest(self, test) -> None: 168 | """ 169 | 当测试用例测试即将运行时调用 170 | :return: 171 | """ 172 | unittest.TestResult.startTest(self, test) 173 | self.outputBuffer = StringIO() 174 | stdout_redirector.fp = self.outputBuffer 175 | stderr_redirector.fp = self.outputBuffer 176 | self.sys_stdout = sys.stdout 177 | self.sys_stdout = sys.stderr 178 | sys.stdout = stdout_redirector 179 | sys.stderr = stderr_redirector 180 | self.start_time = time.time() 181 | 182 | def stopTest(self, test) -> None: 183 | """ 184 | 当测试用例执行完成后进行调用 185 | :return: 186 | """ 187 | self.end_time = '{0:.3} s'.format((time.time() - self.start_time)) 188 | self.result_list.append(self.get_all_result_info_tuple(test)) 189 | self.complete_output() 190 | 191 | def complete_output(self): 192 | """ 193 | Disconnect output redirection and return buffer. 194 | Safe to call multiple times. 195 | """ 196 | if self.sys_stdout: 197 | sys.stdout = self.sys_stdout 198 | sys.stderr = self.sys_stdout 199 | self.sys_stdout = None 200 | self.sys_stdout = None 201 | return self.outputBuffer.getvalue() 202 | 203 | def stopTestRun(self, title=None) -> dict: 204 | """ 205 | 所有测试执行完成后, 执行该方法 206 | :param title: 207 | :return: 208 | """ 209 | FIELDS['testPass'] = self.success_counter 210 | for item in self.result_list: 211 | item = json.loads(str(MakeResultJson(item))) 212 | FIELDS.get('testResult').append(item) 213 | FIELDS['testAll'] = len(self.result_list) 214 | FIELDS['testName'] = title if title else self.default_report_name 215 | FIELDS['testFail'] = self.failure_count 216 | FIELDS['beginTime'] = self.begin_time 217 | end_time = int(time.time()) 218 | start_time = int(time.mktime(time.strptime(self.begin_time, '%Y-%m-%d %H:%M:%S'))) 219 | FIELDS['totalTime'] = str(end_time - start_time) + 's' 220 | FIELDS['testError'] = self.error_count 221 | FIELDS['testSkip'] = self.skipped 222 | self.FIELDS = FIELDS 223 | return FIELDS 224 | 225 | def get_all_result_info_tuple(self, test) -> tuple: 226 | """ 227 | 接受test 相关信息, 并拼接成一个完成的tuple结构返回 228 | :param test: 229 | :return: 230 | """ 231 | return tuple([*self.get_testcase_property(test), self.end_time, self.status, self.case_log]) 232 | 233 | @staticmethod 234 | def error_or_failure_text(err) -> str: 235 | """ 236 | 获取sys.exc_info()的参数并返回字符串类型的数据, 去掉t6 error 237 | :param err: 238 | :return: 239 | """ 240 | return traceback.format_exception(*err) 241 | 242 | def addSuccess(self, test) -> None: 243 | """ 244 | pass 245 | :param test: 246 | :return: 247 | """ 248 | logs = [] 249 | output = self.complete_output() 250 | logs.append(output) 251 | # if self.verbosity > 1: 252 | # sys.stderr.write('ok ') 253 | # sys.stderr.write(str(test)) 254 | # sys.stderr.write('\n') 255 | # else: 256 | # sys.stderr.write('#') 257 | self.success_counter += 1 258 | self.status = '成功' 259 | self.case_log = output.split('\n') 260 | self._mirrorOutput = True # print(class_name, method_name, method_doc) 261 | 262 | def addError(self, test, err): 263 | """ 264 | add Some Error Result and infos 265 | :param test: 266 | :param err: 267 | :return: 268 | """ 269 | logs = [] 270 | output = self.complete_output() 271 | logs.append(output) 272 | logs.extend(self.error_or_failure_text(err)) 273 | self.failure_count += 1 274 | self.add_test_type('失败', logs) 275 | # if self.verbosity > 1: 276 | # sys.stderr.write('F ') 277 | # sys.stderr.write(str(test)) 278 | # sys.stderr.write('\n') 279 | # self.process_bar.update(1) 280 | # else: 281 | # sys.stderr.write('F') 282 | # self.process_bar.update(1) 283 | 284 | self._mirrorOutput = True 285 | 286 | def addFailure(self, test, err): 287 | """ 288 | add Some Failures Result and infos 289 | :param test: 290 | :param err: 291 | :return: 292 | """ 293 | logs = [] 294 | output = self.complete_output() 295 | logs.append(output) 296 | logs.extend(self.error_or_failure_text(err)) 297 | self.failure_count += 1 298 | self.add_test_type('失败', logs) 299 | # if self.verbosity > 1: 300 | # sys.stderr.write('F ') 301 | # sys.stderr.write(str(test)) 302 | # sys.stderr.write('\n') 303 | # self.process_bar.update(1) 304 | # else: 305 | # sys.stderr.write('F') 306 | # self.process_bar.update(1) 307 | 308 | self._mirrorOutput = True 309 | 310 | def addSkip(self, test, reason) -> None: 311 | """ 312 | 获取全部的跳过的case信息 313 | :param test: 314 | :param reason: 315 | :return: None 316 | """ 317 | logs = [reason] 318 | self.complete_output() 319 | self.skipped += 1 320 | self.add_test_type('跳过', logs) 321 | 322 | # if self.verbosity > 1: 323 | # sys.stderr.write('S ') 324 | # sys.stderr.write(str(test)) 325 | # sys.stderr.write('\n') 326 | # self.process_bar.update(1) 327 | # else: 328 | # sys.stderr.write('S') 329 | # self.process_bar.update(1) 330 | self._mirrorOutput = True 331 | 332 | def add_test_type(self, status: str, case_log: list) -> None: 333 | """ 334 | abstruct add test type and return tuple 335 | :param status: 336 | :param case_log: 337 | :return: 338 | """ 339 | self.status = status 340 | self.case_log = case_log 341 | 342 | @staticmethod 343 | def get_testcase_property(test) -> tuple: 344 | """ 345 | 接受一个test, 并返回一个test的class_name, method_name, method_doc属性 346 | :param test: 347 | :return: (class_name, method_name, method_doc) -> tuple 348 | """ 349 | class_name = test.__class__.__qualname__ 350 | method_name = test.__dict__['_testMethodName'] 351 | method_doc = test.__dict__['_testMethodDoc'] 352 | return class_name, method_name, method_doc 353 | 354 | 355 | class BeautifulReport(ReportTestResult, PATH): 356 | img_path = 'img/' if platform.system() != 'Windows' else 'img\\' 357 | 358 | def __init__(self, suites): 359 | super(BeautifulReport, self).__init__(suites) 360 | self.suites = suites 361 | self.log_path = None 362 | self.title = '自动化测试报告' 363 | self.filename = 'report.html' 364 | 365 | def report(self, description, filename: str = None, log_path='.'): 366 | """ 367 | 生成测试报告,并放在当前运行路径下 368 | :param log_path: 生成report的文件存储路径 369 | :param filename: 生成文件的filename 370 | :param description: 生成文件的注释 371 | :return: 372 | """ 373 | if filename: 374 | self.filename = filename if filename.endswith('.html') else filename + '.html' 375 | 376 | if description: 377 | self.title = description 378 | 379 | self.log_path = os.path.abspath(log_path) 380 | self.suites.run(result=self) 381 | self.stopTestRun(self.title) 382 | # self.output_report() 383 | # text = '测试已全部完成, 可前往{}查询测试报告'.format(self.log_path) 384 | # print(text) 385 | return self.FIELDS 386 | 387 | def output_report(self): 388 | """ 389 | 生成测试报告到指定路径下 390 | :return: 391 | """ 392 | template_path = self.config_tmp_path 393 | override_path = os.path.abspath(self.log_path) if \ 394 | os.path.abspath(self.log_path).endswith('/') else \ 395 | os.path.abspath(self.log_path) + '/' 396 | 397 | with open(template_path, 'rb') as file: 398 | body = file.readlines() 399 | with open(override_path + self.filename, 'wb') as write_file: 400 | for item in body: 401 | if item.strip().startswith(b'var resultData'): 402 | head = ' var resultData = ' 403 | item = item.decode().split(head) 404 | item[1] = head + json.dumps(self.FIELDS, ensure_ascii=False, indent=4) 405 | item = ''.join(item).encode() 406 | item = bytes(item) + b';\n' 407 | write_file.write(item) 408 | 409 | @staticmethod 410 | def img2base(img_path: str, file_name: str) -> str: 411 | """ 412 | 接受传递进函数的filename 并找到文件转换为base64格式 413 | :param img_path: 通过文件名及默认路径找到的img绝对路径 414 | :param file_name: 用户在装饰器中传递进来的问价匿名 415 | :return: 416 | """ 417 | pattern = '/' if platform != 'Windows' else '\\' 418 | 419 | with open(img_path + pattern + file_name, 'rb') as file: 420 | data = file.read() 421 | return base64.b64encode(data).decode() 422 | 423 | def add_test_img(*pargs): 424 | """ 425 | 接受若干个图片元素, 并展示在测试报告中 426 | :param pargs: 427 | :return: 428 | """ 429 | 430 | def _wrap(func): 431 | @wraps(func) 432 | def __wrap(*args, **kwargs): 433 | img_path = os.path.abspath('{}'.format(BeautifulReport.img_path)) 434 | try: 435 | result = func(*args, **kwargs) 436 | except Exception: 437 | if 'save_img' in dir(args[0]): 438 | save_img = getattr(args[0], 'save_img') 439 | save_img(func.__name__) 440 | data = BeautifulReport.img2base(img_path, pargs[0] + '.png') 441 | print(HTML_IMG_TEMPLATE.format(data, data)) 442 | sys.exit(0) 443 | print('

') 444 | 445 | if len(pargs) > 1: 446 | for parg in pargs: 447 | print(parg + ':') 448 | data = BeautifulReport.img2base(img_path, parg + '.png') 449 | print(HTML_IMG_TEMPLATE.format(data, data)) 450 | return result 451 | if not os.path.exists(img_path + pargs[0] + '.png'): 452 | return result 453 | data = BeautifulReport.img2base(img_path, pargs[0] + '.png') 454 | print(HTML_IMG_TEMPLATE.format(data, data)) 455 | return result 456 | 457 | return __wrap 458 | 459 | return _wrap 460 | -------------------------------------------------------------------------------- /test_case/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/5 12:13 8 | @File: __init__.py 9 | @License: MIT 10 | """ 11 | __name__ = 'test_case' 12 | -------------------------------------------------------------------------------- /test_case/test_smpcp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Version: 2.0.2 5 | @Project: Secure-Multi-Party-Computation-Protocol 6 | @Author: Zhan Shi 7 | @Time : 2022/5/4 15:54 8 | @File: test_smpcp.py 9 | @License: MIT 10 | """ 11 | import random 12 | import sys 13 | import unittest 14 | 15 | import gmpy2 16 | import phe 17 | 18 | from smpcp.smpcp import CloudPlatform, CloudPlatformThird, SecureMultiPartyComputationProtocol 19 | 20 | sys.path.append("test_case/") # 添加测试文件路径 21 | 22 | key_length = 2048 # TODO 密钥长度 23 | 24 | public_key, secret_key = phe.generate_paillier_keypair(n_length=key_length) # 生成密钥对 25 | 26 | cloud1 = CloudPlatform(public_key=public_key) # 云服务器1 27 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key) # 云服务器2 28 | 29 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2) # 安全多方计算协议类 30 | 31 | 32 | class SMPCPTest(unittest.TestCase): 33 | """ 34 | 安全多方计算协议测试类 35 | """ 36 | 37 | def setUp(self): 38 | """ 39 | 测试前 40 | """ 41 | # 生成浮点数 42 | self.float1 = int( 43 | gmpy2.mpz_random(gmpy2.random_state( 44 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), 45 | key_length))), key_length)) * random.uniform(0.1, 1.0) 46 | self.float2 = int( 47 | gmpy2.mpz_random(gmpy2.random_state( 48 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))), 49 | key_length)) * random.uniform(0.1, 1.0) 50 | self.float_n1 = protocol.encode(public_key.encrypt(self.float1)) 51 | self.float_n2 = public_key.encrypt(self.float2) 52 | # 生成整数 53 | self.int1 = int(gmpy2.mpz_random(gmpy2.random_state( 54 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))), 55 | key_length)) 56 | self.int2 = int(gmpy2.mpz_random(gmpy2.random_state( 57 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))), 58 | key_length)) 59 | self.int_n1 = protocol.encode(public_key.encrypt(self.int1)) 60 | self.int_n2 = public_key.encrypt(self.int2) 61 | return super().setUp() 62 | 63 | def tearDown(self): 64 | """ 65 | 测试后 66 | """ 67 | return super().tearDown() 68 | 69 | # TODO 安全乘法协议测试 70 | # @unittest.skip('跳过安全乘法协议') 71 | def test_mul(self): 72 | """ 73 | 安全乘法协议 74 | """ 75 | # 浮点乘法测试:经过测试,最高支持8位浮点乘法 76 | self.assertEqual(round(self.float1 * self.float2, 8), 77 | round(secret_key.decrypt(self.float_n1 * self.float_n2), 8)) 78 | 79 | # 整数乘法测试:经过测试,无明显问题 80 | self.assertEqual(self.int1 * self.int2, secret_key.decrypt(self.int_n1 * self.int_n2)) 81 | 82 | # TODO 安全除法协议测试 83 | # @unittest.skip('跳过安全除法协议') 84 | def test_div(self): 85 | """ 86 | 安全除法协议 87 | """ 88 | # 浮点除法测试:经过测试,最高支持10位浮点除法 89 | self.assertEqual(round(self.float1 / self.float2, 10), 90 | round(secret_key.decrypt(self.float_n1 / self.float_n2), 10)) 91 | 92 | # 整数除法测试:经过测试,最高支持10位整数除法 93 | self.assertEqual(round(self.int1 / self.int2, 10), round(secret_key.decrypt(self.int_n1 / self.int_n2), 10)) 94 | 95 | # TODO 安全最值计算协议测试 96 | # @unittest.skip('跳过安全最值计算协议') 97 | def test_optimum(self): 98 | """ 99 | 安全最值计算协议 100 | """ 101 | mode = 'max' if random.random() > 0.5 else 'min' 102 | if mode == 'max': 103 | # 浮点最大值计算测试:经过测试,无明显问题 104 | self.assertEqual(max(self.float1, self.float2), 105 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'max'))) 106 | 107 | # 整数最大值计算测试:经过测试,无明显问题 108 | self.assertEqual(max(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'max'))) 109 | else: 110 | # 浮点最小值计算测试:经过测试,无明显问题 111 | self.assertEqual(min(self.float1, self.float2), 112 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'min'))) 113 | 114 | # 整数最小值计算测试:经过测试,无明显问题 115 | self.assertEqual(min(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'min'))) 116 | 117 | # TODO 安全奇偶性判断协议测试 118 | # @unittest.skip('跳过安全奇偶性判断协议') 119 | def test_parity(self): 120 | """ 121 | 安全奇偶性判断协议 122 | """ 123 | # 整数奇偶性判断测试:经过测试,无明显问题 124 | self.assertEqual(self.int1 % 2, secret_key.decrypt(self.int_n1.parity())) 125 | 126 | # TODO 安全二进制分解协议测试 127 | # @unittest.skip('跳过安全二进制分解协议') 128 | def test_bit_dec(self): 129 | """ 130 | 安全二进制分解协议 131 | """ 132 | # 整数二进制分解测试:经过测试,无明显问题 133 | bit = len(bin(self.int1).split('b')[1]) 134 | result = ''.join([str(secret_key.decrypt(v)) for v in self.int_n1.bit_dec(bit)]) 135 | self.assertEqual(bin(self.int1).split('b')[1], result) 136 | 137 | # TODO 安全二进制与协议测试 138 | # @unittest.skip('跳过安全二进制与协议') 139 | def test_and(self): 140 | """ 141 | 安全二进制与协议 142 | """ 143 | bit1 = random.SystemRandom().randint(0, 1) 144 | bit2 = random.SystemRandom().randint(0, 1) 145 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 146 | bit_n2 = public_key.encrypt(bit2) 147 | # 二进制或测试:经过测试,无明显问题 148 | self.assertEqual(bit1 & bit2, secret_key.decrypt(bit_n1 & bit_n2)) 149 | 150 | # TODO 安全二进制或协议测试 151 | # @unittest.skip('跳过安全二进制或协议') 152 | def test_or(self): 153 | """ 154 | 安全二进制或协议 155 | """ 156 | bit1 = random.SystemRandom().randint(0, 1) 157 | bit2 = random.SystemRandom().randint(0, 1) 158 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 159 | bit_n2 = public_key.encrypt(bit2) 160 | # 二进制或测试:经过测试,无明显问题 161 | self.assertEqual(bit1 | bit2, secret_key.decrypt(bit_n1 | bit_n2)) 162 | 163 | # TODO 安全二进制非协议测试 164 | # @unittest.skip('跳过安全二进制非协议') 165 | def test_bit_not(self): 166 | """ 167 | 安全二进制非协议 168 | """ 169 | bit1 = random.SystemRandom().randint(0, 1) 170 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 171 | # 二进制或测试:经过测试,无明显问题 172 | self.assertEqual(1 - bit1, secret_key.decrypt(bit_n1.bit_not())) 173 | 174 | # TODO 安全二进制异或协议测试 175 | # @unittest.skip('跳过安全二进制异或协议') 176 | def test_xor(self): 177 | """ 178 | 安全二进制异或协议 179 | """ 180 | bit1 = random.SystemRandom().randint(0, 1) 181 | bit2 = random.SystemRandom().randint(0, 1) 182 | bit_n1 = protocol.encode(public_key.encrypt(bit1)) 183 | bit_n2 = public_key.encrypt(bit2) 184 | # 二进制或测试:经过测试,无明显问题 185 | self.assertEqual(bit1 ^ bit2, secret_key.decrypt(bit_n1 ^ bit_n2)) 186 | 187 | # TODO 安全相等协议测试 188 | # @unittest.skip('跳过安全相等协议') 189 | def test_eq(self): 190 | """ 191 | 安全相等协议 192 | """ 193 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 194 | self.assertEqual(1 if self.float1 == self.float1 else 0, 195 | secret_key.decrypt(self.float_n1 == self.float_n1.decode())) 196 | 197 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 198 | self.assertEqual(1 if self.int1 == self.int1 else 0, secret_key.decrypt(self.int_n1 == self.int_n1.decode())) 199 | 200 | # TODO 安全不相等协议测试 201 | # @unittest.skip('跳过安全不相等协议') 202 | def test_ne(self): 203 | """ 204 | 安全不相等协议 205 | """ 206 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 207 | self.assertEqual(1 if self.float1 != self.float2 else 0, secret_key.decrypt(self.float_n1 != self.float_n2)) 208 | 209 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 210 | self.assertEqual(1 if self.int1 != self.int2 else 0, secret_key.decrypt(self.int_n1 != self.int_n2)) 211 | 212 | # TODO 安全大于协议测试 213 | # @unittest.skip('跳过安全大于协议') 214 | def test_gt(self): 215 | """ 216 | 安全大于协议 217 | """ 218 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 219 | self.assertEqual(1 if self.float1 > self.float2 else 0, secret_key.decrypt(self.float_n1 > self.float_n2)) 220 | 221 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 222 | self.assertEqual(1 if self.int1 > self.int2 else 0, secret_key.decrypt(self.int_n1 > self.int_n2)) 223 | 224 | # TODO 安全大于等于协议测试 225 | # @unittest.skip('跳过安全大于等于协议') 226 | def test_ge(self): 227 | """ 228 | 安全大于等于协议 229 | """ 230 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 231 | self.assertEqual(1 if self.float1 >= self.float2 else 0, secret_key.decrypt(self.float_n1 >= self.float_n2)) 232 | 233 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 234 | self.assertEqual(1 if self.int1 >= self.int2 else 0, secret_key.decrypt(self.int_n1 >= self.int_n2)) 235 | 236 | # TODO 安全小于协议测试 237 | # @unittest.skip('跳过安全小于协议') 238 | def test_lt(self): 239 | """ 240 | 安全小于协议 241 | """ 242 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 243 | self.assertEqual(1 if self.float1 < self.float2 else 0, secret_key.decrypt(self.float_n1 < self.float_n2)) 244 | 245 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 246 | self.assertEqual(1 if self.int1 < self.int2 else 0, secret_key.decrypt(self.int_n1 < self.int_n2)) 247 | 248 | # TODO 安全小于等于协议测试 249 | # @unittest.skip('跳过安全小于等于协议') 250 | def test_le(self): 251 | """ 252 | 安全小于等于协议 253 | """ 254 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果 255 | self.assertEqual(1 if self.float1 <= self.float2 else 0, secret_key.decrypt(self.float_n1 <= self.float_n2)) 256 | 257 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果 258 | self.assertEqual(1 if self.int1 <= self.int2 else 0, secret_key.decrypt(self.int_n1 <= self.int_n2)) 259 | --------------------------------------------------------------------------------