├── .gitignore
├── LICENSE
├── README.md
├── __about__.py
├── __init__.py
├── example.py
├── requirements.txt
├── smpcp
├── __init__.py
└── smpcp.py
├── test.py
└── test_case
├── BeautifulReport.py
├── __init__.py
├── report
├── CSVReport.csv
└── HTMLReport.html
├── template
└── template
└── test_smpcp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # 默认忽略的文件
2 | /shelf/
3 | /workspace.xml
4 | /__pycache__
5 | /.idea
6 | /pypi
7 | # Datasource local storage ignored files
8 | /dataSources/
9 | /dataSources.local.xml
10 | # 基于编辑器的 HTTP 客户端请求
11 | /httpRequests/
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Zhan Shi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
安全多方计算协议
2 |
3 |
4 |
5 |
6 |
7 |
8 | ---
9 |
10 | ## 项目背景
11 |
12 | 安全多方计算(Secure Multi-Party Computation)的研究主要是针对无可信第三方的情况下,如何安全地计算一个约定函数的问题。安全多方计算是电子选举、门限签名以及电子拍卖等诸多应用得以实施的密码学基础。
13 |
14 | 一个安全多方计算协议,如果对于拥有无限计算能力攻击者而言是安全的,则称作是信息论安全的或无条件安全的;如果对于拥有多项式计算能力的攻击者是安全的,则称为是密码学安全的或条件安全的。
15 |
16 | 已有的结果证明了在无条件安全模型下,当且仅当恶意参与者的人数少于总人数的1/3时,安全的方案才存在。而在条件安全模型下,当且仅当恶意参与者的人数少于总人数的一半时,安全的方案才存在。
17 |
18 | 安全多方计算起源于1982年[姚期智](https://baike.baidu.com/item/姚期智)的百万富翁问题。后来Oded Goldreich有比较细致系统的论述。
19 |
20 | 基于[phe](https://github.com/data61/python-paillier)库 (Paillier Homomorphic Encryption) 的安全多方计算协议实现,包含:
21 |
22 | - 安全乘法协议
23 | - 安全除法协议
24 | - 安全最大值计算协议
25 | - 安全最小值计算协议
26 | - 安全奇偶性判断协议
27 | - 安全二进制分解协议
28 | - 安全二进制与协议
29 | - 安全二进制或协议
30 | - 安全二进制非协议
31 | - 安全二进制异或协议
32 | - 安全相等协议
33 | - 安全不相等协议
34 | - 安全大于协议
35 | - 安全大于等于协议
36 | - 安全小于协议
37 | - 安全小于等于协议
38 |
39 | ---
40 |
41 | ## 项目环境
42 |
43 | - `python3.8`
44 | - `gmpy2>=2.0.8`
45 | - `pandas>=1.2.4`
46 | - `phe>=1.4.0`
47 | - `tqdm>=4.59.0`
48 | - `numpy>=1.20.2`
49 |
50 | 详见`requirements.txt`。
51 |
52 | ---
53 |
54 | ## 项目示例
55 |
56 | ### 准备工作
57 |
58 | 安全依赖环境: `pip install -r requirements.txt`
59 |
60 | 安装`smpcp`库: `pip install smpcp`
61 |
62 | 引入`phe`库: `import phe`
63 |
64 | 引入`smpcp`库: `from smpcp import CloudPlatform, CloudPlatformThird, SecureMultiPartyComputationProtocol`
65 |
66 | ### 生成密钥
67 |
68 | ```python
69 | public_key, secret_key = phe.generate_paillier_keypair(n_length=2048)
70 | ```
71 |
72 | 其中`n_length`为密钥长度。
73 |
74 | ### 定义云服务器
75 |
76 | ```python
77 | cloud1 = CloudPlatform(public_key=public_key)
78 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key)
79 | ```
80 |
81 | ### 定义安全多方计算协议
82 |
83 | ```python
84 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2)
85 | ```
86 |
87 | ### 编码
88 |
89 | ```python
90 | n1 = protocol.encode(public_key.encrypt(6))
91 | n2 = public_key.encrypt(3)
92 | b1 = protocol.encode(public_key.encrypt(1))
93 | b2 = public_key.encrypt(0)
94 | ```
95 |
96 | ### 解码
97 |
98 | ```python
99 | assert secret_key.decrypt(n1.decode()) == 6
100 | ```
101 |
102 | ### 安全多方计算协议实现
103 |
104 | ```python
105 | # TODO 安全乘法协议
106 | assert secret_key.decrypt(n1 * n2) == 18
107 | # TODO 安全除法协议
108 | assert secret_key.decrypt(n1 / n2) == 2
109 | # TODO 安全最大值协议
110 | assert secret_key.decrypt(n1.optimum(n2, 'max')) == 6
111 | # TODO 安全最小值协议
112 | assert secret_key.decrypt(n1.optimum(n2, 'min')) == 3
113 | # TODO 安全奇偶性判断协议
114 | assert secret_key.decrypt(n1.parity()) == 0
115 | assert secret_key.decrypt(protocol.encode(n2).parity()) == 1
116 | # TODO 安全二进制分解协议
117 | bit = []
118 | for v in n1.bit_dec(3):
119 | bit.append(secret_key.decrypt(v))
120 | assert bit == [1, 1, 0]
121 | # TODO 安全二进制与协议
122 | assert secret_key.decrypt(b1 | b2) == 1
123 | # TODO 安全二进制或协议
124 | assert secret_key.decrypt(b1 & b2) == 0
125 | # TODO 安全二进制非协议
126 | assert secret_key.decrypt(b1.bit_not()) == 0
127 | # TODO 安全二进制异或协议
128 | assert secret_key.decrypt(b1 ^ b2) == 1
129 | # TODO 安全相等协议
130 | assert secret_key.decrypt(n1 == n2) == 0
131 | assert secret_key.decrypt(n1 == n2 * 2) == 1
132 | # TODO 安全不相等协议
133 | assert secret_key.decrypt(n1 != n2) == 1
134 | assert secret_key.decrypt(n1 != n2 * 2) == 0
135 | # TODO 安全大于协议
136 | assert secret_key.decrypt(n1 > n2) == 1
137 | assert secret_key.decrypt(n2 > n1) == 0
138 | # TODO 安全大于等于协议
139 | assert secret_key.decrypt(n1 >= n2) == 1
140 | assert secret_key.decrypt(n2 >= n1) == 0
141 | # TODO 安全小于协议
142 | assert secret_key.decrypt(n1 < n2) == 0
143 | assert secret_key.decrypt(n2 < n1) == 1
144 | # TODO 安全小于等于协议
145 | assert secret_key.decrypt(n1 <= n2) == 0
146 | assert secret_key.decrypt(n2 <= n1) == 1
147 | ```
148 |
149 | 详见`example.py`。
150 |
151 | ---
152 |
153 | ## 项目测试
154 |
155 | 经过`Unit Test`测试,测试结果如下:
156 |
157 | ```python
158 | key_length = 2048 # TODO 密钥长度
159 |
160 | public_key, secret_key = phe.generate_paillier_keypair(n_length=key_length) # 生成密钥对
161 |
162 | cloud1 = CloudPlatform(public_key=public_key) # 云服务器1
163 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key) # 云服务器2
164 |
165 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2) # 安全多方计算协议类
166 |
167 |
168 | class SMPCPTest(unittest.TestCase):
169 | """
170 | 安全多方计算协议测试类
171 | """
172 |
173 | def setUp(self):
174 | """
175 | 测试前
176 | """
177 | # 生成浮点数
178 | self.float1 = int(
179 | gmpy2.mpz_random(gmpy2.random_state(
180 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)),
181 | key_length))), key_length)) * random.uniform(0.1, 1.0)
182 | self.float2 = int(
183 | gmpy2.mpz_random(gmpy2.random_state(
184 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))),
185 | key_length)) * random.uniform(0.1, 1.0)
186 | self.float_n1 = protocol.encode(public_key.encrypt(self.float1))
187 | self.float_n2 = public_key.encrypt(self.float2)
188 | # 生成整数
189 | self.int1 = int(gmpy2.mpz_random(gmpy2.random_state(
190 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))),
191 | key_length))
192 | self.int2 = int(gmpy2.mpz_random(gmpy2.random_state(
193 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))),
194 | key_length))
195 | self.int_n1 = protocol.encode(public_key.encrypt(self.int1))
196 | self.int_n2 = public_key.encrypt(self.int2)
197 | return super().setUp()
198 |
199 | def tearDown(self):
200 | """
201 | 测试后
202 | """
203 | return super().tearDown()
204 |
205 | # TODO 安全乘法协议测试
206 | # @unittest.skip('跳过安全乘法协议')
207 | def test_mul(self):
208 | """
209 | 安全乘法协议
210 | """
211 | # 浮点乘法测试:经过测试,最高支持8位浮点乘法
212 | self.assertEqual(round(self.float1 * self.float2, 8),
213 | round(secret_key.decrypt(self.float_n1 * self.float_n2), 8))
214 |
215 | # 整数乘法测试:经过测试,无明显问题
216 | self.assertEqual(self.int1 * self.int2, secret_key.decrypt(self.int_n1 * self.int_n2))
217 |
218 | # TODO 安全除法协议测试
219 | # @unittest.skip('跳过安全除法协议')
220 | def test_div(self):
221 | """
222 | 安全除法协议
223 | """
224 | # 浮点除法测试:经过测试,最高支持10位浮点除法
225 | self.assertEqual(round(self.float1 / self.float2, 10),
226 | round(secret_key.decrypt(self.float_n1 / self.float_n2), 10))
227 |
228 | # 整数除法测试:经过测试,最高支持10位整数除法
229 | self.assertEqual(round(self.int1 / self.int2, 10), round(secret_key.decrypt(self.int_n1 / self.int_n2), 10))
230 |
231 | # TODO 安全最值计算协议测试
232 | # @unittest.skip('跳过安全最值计算协议')
233 | def test_optimum(self):
234 | """
235 | 安全最值计算协议
236 | """
237 | mode = 'max' if random.random() > 0.5 else 'min'
238 | if mode == 'max':
239 | # 浮点最大值计算测试:经过测试,无明显问题
240 | self.assertEqual(max(self.float1, self.float2),
241 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'max')))
242 |
243 | # 整数最大值计算测试:经过测试,无明显问题
244 | self.assertEqual(max(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'max')))
245 | else:
246 | # 浮点最小值计算测试:经过测试,无明显问题
247 | self.assertEqual(min(self.float1, self.float2),
248 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'min')))
249 |
250 | # 整数最小值计算测试:经过测试,无明显问题
251 | self.assertEqual(min(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'min')))
252 |
253 | # TODO 安全奇偶性判断协议测试
254 | # @unittest.skip('跳过安全奇偶性判断协议')
255 | def test_parity(self):
256 | """
257 | 安全奇偶性判断协议
258 | """
259 | # 整数奇偶性判断测试:经过测试,无明显问题
260 | self.assertEqual(self.int1 % 2, secret_key.decrypt(self.int_n1.parity()))
261 |
262 | # TODO 安全二进制分解协议测试
263 | # @unittest.skip('跳过安全二进制分解协议')
264 | def test_bit_dec(self):
265 | """
266 | 安全二进制分解协议
267 | """
268 | # 整数二进制分解测试:经过测试,无明显问题
269 | bit = len(bin(self.int1).split('b')[1])
270 | result = ''.join([str(secret_key.decrypt(v)) for v in self.int_n1.bit_dec(bit)])
271 | self.assertEqual(bin(self.int1).split('b')[1], result)
272 |
273 | # TODO 安全二进制与协议测试
274 | # @unittest.skip('跳过安全二进制与协议')
275 | def test_and(self):
276 | """
277 | 安全二进制与协议
278 | """
279 | bit1 = random.SystemRandom().randint(0, 1)
280 | bit2 = random.SystemRandom().randint(0, 1)
281 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
282 | bit_n2 = public_key.encrypt(bit2)
283 | # 二进制或测试:经过测试,无明显问题
284 | self.assertEqual(bit1 & bit2, secret_key.decrypt(bit_n1 & bit_n2))
285 |
286 | # TODO 安全二进制或协议测试
287 | # @unittest.skip('跳过安全二进制或协议')
288 | def test_or(self):
289 | """
290 | 安全二进制或协议
291 | """
292 | bit1 = random.SystemRandom().randint(0, 1)
293 | bit2 = random.SystemRandom().randint(0, 1)
294 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
295 | bit_n2 = public_key.encrypt(bit2)
296 | # 二进制或测试:经过测试,无明显问题
297 | self.assertEqual(bit1 | bit2, secret_key.decrypt(bit_n1 | bit_n2))
298 |
299 | # TODO 安全二进制非协议测试
300 | # @unittest.skip('跳过安全二进制非协议')
301 | def test_bit_not(self):
302 | """
303 | 安全二进制非协议
304 | """
305 | bit1 = random.SystemRandom().randint(0, 1)
306 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
307 | # 二进制或测试:经过测试,无明显问题
308 | self.assertEqual(1 - bit1, secret_key.decrypt(bit_n1.bit_not()))
309 |
310 | # TODO 安全二进制异或协议测试
311 | # @unittest.skip('跳过安全二进制异或协议')
312 | def test_xor(self):
313 | """
314 | 安全二进制异或协议
315 | """
316 | bit1 = random.SystemRandom().randint(0, 1)
317 | bit2 = random.SystemRandom().randint(0, 1)
318 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
319 | bit_n2 = public_key.encrypt(bit2)
320 | # 二进制或测试:经过测试,无明显问题
321 | self.assertEqual(bit1 ^ bit2, secret_key.decrypt(bit_n1 ^ bit_n2))
322 |
323 | # TODO 安全相等协议测试
324 | # @unittest.skip('跳过安全相等协议')
325 | def test_eq(self):
326 | """
327 | 安全相等协议
328 | """
329 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
330 | self.assertEqual(1 if self.float1 == self.float1 else 0,
331 | secret_key.decrypt(self.float_n1 == self.float_n1.decode()))
332 |
333 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
334 | self.assertEqual(1 if self.int1 == self.int1 else 0, secret_key.decrypt(self.int_n1 == self.int_n1.decode()))
335 |
336 | # TODO 安全不相等协议测试
337 | # @unittest.skip('跳过安全不相等协议')
338 | def test_ne(self):
339 | """
340 | 安全不相等协议
341 | """
342 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
343 | self.assertEqual(1 if self.float1 != self.float2 else 0, secret_key.decrypt(self.float_n1 != self.float_n2))
344 |
345 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
346 | self.assertEqual(1 if self.int1 != self.int2 else 0, secret_key.decrypt(self.int_n1 != self.int_n2))
347 |
348 | # TODO 安全大于协议测试
349 | # @unittest.skip('跳过安全大于协议')
350 | def test_gt(self):
351 | """
352 | 安全大于协议
353 | """
354 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
355 | self.assertEqual(1 if self.float1 > self.float2 else 0, secret_key.decrypt(self.float_n1 > self.float_n2))
356 |
357 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
358 | self.assertEqual(1 if self.int1 > self.int2 else 0, secret_key.decrypt(self.int_n1 > self.int_n2))
359 |
360 | # TODO 安全大于等于协议测试
361 | # @unittest.skip('跳过安全大于等于协议')
362 | def test_ge(self):
363 | """
364 | 安全大于等于协议
365 | """
366 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
367 | self.assertEqual(1 if self.float1 >= self.float2 else 0, secret_key.decrypt(self.float_n1 >= self.float_n2))
368 |
369 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
370 | self.assertEqual(1 if self.int1 >= self.int2 else 0, secret_key.decrypt(self.int_n1 >= self.int_n2))
371 |
372 | # TODO 安全小于协议测试
373 | # @unittest.skip('跳过安全小于协议')
374 | def test_lt(self):
375 | """
376 | 安全小于协议
377 | """
378 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
379 | self.assertEqual(1 if self.float1 < self.float2 else 0, secret_key.decrypt(self.float_n1 < self.float_n2))
380 |
381 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
382 | self.assertEqual(1 if self.int1 < self.int2 else 0, secret_key.decrypt(self.int_n1 < self.int_n2))
383 |
384 | # TODO 安全小于等于协议测试
385 | # @unittest.skip('跳过安全小于等于协议')
386 | def test_le(self):
387 | """
388 | 安全小于等于协议
389 | """
390 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
391 | self.assertEqual(1 if self.float1 <= self.float2 else 0, secret_key.decrypt(self.float_n1 <= self.float_n2))
392 |
393 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
394 | self.assertEqual(1 if self.int1 <= self.int2 else 0, secret_key.decrypt(self.int_n1 <= self.int_n2))
395 | ```
396 |
397 | 详见`test_case/test_smpcp.py`, 项目报告依赖基于`unittest`的[项目](https://github.com/TesterlifeRaymond/BeautifulReport)`test_case/BeautifulReport.py`。
398 |
399 | ---
400 |
401 | ## 联系方式
402 |
403 | 作者:沈阳航空航天大学 数据安全与隐私计算课题组 施展
404 |
405 | Github: https://github.com/shine813/
406 |
407 | Pypi: https://pypi.org/project/smpcp/
408 |
409 | 邮箱:phe.zshi@gmail.com
410 |
411 | 如有问题,可及时联系作者
412 |
--------------------------------------------------------------------------------
/__about__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | __all__ = [
4 | "__title__", "__summary__", "__uri__",
5 | "__version__", "__author__", "__email__"
6 | ]
7 |
8 | __title__ = "smpcp"
9 | __summary__ = "Secure Multi-Party Computation Protocol base on Partially Homomorphic Encryption for Python"
10 | __uri__ = "https://github.com/shine813/Secure-Multi-Party-Computation-Protocol/"
11 |
12 | __version__ = "2.0.2"
13 |
14 | __author__ = "Zhan Shi"
15 | __email__ = "phe.zshi@gmail.com"
16 |
17 | __license__ = "MIT"
18 | __copyright__ = "Copyright (c) 2021 {0}".format(__author__)
19 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/5 12:17
8 | @File: __init__.py
9 | @License: MIT
10 | """
11 |
12 | __name__ = 'Secure-Multi-Party-Computation-Protocol'
13 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/1 09:35
8 | @File: example.py
9 | @License: MIT
10 | """
11 | import phe
12 |
13 | from smpcp.smpcp import CloudPlatform, CloudPlatformThird, SecureMultiPartyComputationProtocol
14 |
15 | # TODO 生成密钥
16 | public_key, secret_key = phe.generate_paillier_keypair(n_length=128)
17 | # TODO 定义云服务器
18 | cloud1 = CloudPlatform(public_key=public_key)
19 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key)
20 | # TODO 定义安全多方计算协议
21 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2)
22 |
23 | if __name__ == '__main__':
24 | # TODO 安全多方计算协议编码
25 | n1 = protocol.encode(public_key.encrypt(6))
26 | n2 = public_key.encrypt(3)
27 | b1 = protocol.encode(public_key.encrypt(1))
28 | b2 = public_key.encrypt(0)
29 | # TODO 协议解码
30 | assert secret_key.decrypt(n1.decode()) == 6
31 | # TODO 安全乘法协议
32 | assert secret_key.decrypt(n1 * n2) == 18
33 | # TODO 安全除法协议
34 | assert secret_key.decrypt(n1 / n2) == 2
35 | # TODO 安全最大值协议
36 | assert secret_key.decrypt(n1.optimum(n2, 'max')) == 6
37 | # TODO 安全最小值协议
38 | assert secret_key.decrypt(n1.optimum(n2, 'min')) == 3
39 | # TODO 安全奇偶性判断协议
40 | assert secret_key.decrypt(n1.parity()) == 0
41 | assert secret_key.decrypt(protocol.encode(n2).parity()) == 1
42 | # TODO 安全二进制分解协议
43 | bit = []
44 | for v in n1.bit_dec(3):
45 | bit.append(secret_key.decrypt(v))
46 | assert bit == [1, 1, 0]
47 | # TODO 安全二进制与协议
48 | assert secret_key.decrypt(b1 & b2) == 1
49 | # TODO 安全二进制或协议
50 | assert secret_key.decrypt(b1 | b2) == 0
51 | # TODO 安全二进制非协议
52 | assert secret_key.decrypt(b1.bit_not()) == 0
53 | # TODO 安全二进制异或协议
54 | assert secret_key.decrypt(b1 ^ b2) == 1
55 | # TODO 安全相等协议
56 | assert secret_key.decrypt(n1 == n2) == 0
57 | assert secret_key.decrypt(n1 == n2 * 2) == 1
58 | # TODO 安全不相等协议
59 | assert secret_key.decrypt(n1 != n2) == 1
60 | assert secret_key.decrypt(n1 != n2 * 2) == 0
61 | # TODO 安全大于协议
62 | assert secret_key.decrypt(n1 > n2) == 1
63 | assert secret_key.decrypt(n2 > n1) == 0
64 | # TODO 安全大于等于协议
65 | assert secret_key.decrypt(n1 >= n2) == 1
66 | assert secret_key.decrypt(n2 >= n1) == 0
67 | # TODO 安全小于协议
68 | assert secret_key.decrypt(n1 < n2) == 0
69 | assert secret_key.decrypt(n2 < n1) == 1
70 | # TODO 安全小于等于协议
71 | assert secret_key.decrypt(n1 <= n2) == 0
72 | assert secret_key.decrypt(n2 <= n1) == 1
73 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gmpy2==2.0.8
2 | pandas==1.2.4
3 | phe==1.4.0
4 | tqdm==4.59.0
5 | numpy==1.20.2
6 |
--------------------------------------------------------------------------------
/smpcp/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/6 13:26
8 | @File: __init__.py
9 | @License: MIT
10 | """
11 |
12 | __name__ = 'smpcp'
13 |
--------------------------------------------------------------------------------
/smpcp/smpcp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/3 12:23
8 | @File: smpcp.py
9 | @License: MIT
10 | """
11 | import random
12 |
13 | import gmpy2
14 |
15 |
16 | class SecureMultiPartyComputationProtocol:
17 | """
18 | 安全多方计算协议类
19 | """
20 |
21 | def __init__(self, c1, c2, cipher=None):
22 | """
23 | 安全多方计算协议类 定义
24 | :param c1: 云服务器
25 | :param c2: 第三方云服务器
26 | :param cipher: 密文
27 | """
28 | self.c1 = c1
29 | self.c2 = c2
30 | self.cipher = cipher
31 |
32 | def encode(self, encrypted_number):
33 | """
34 | 安全多方计算协议类 编码
35 | :param encrypted_number: 加密数字
36 | :return: 编码后的加密数字
37 | """
38 | return SecureMultiPartyComputationProtocol(c1=self.c1, c2=self.c2, cipher=encrypted_number)
39 |
40 | def decode(self):
41 | """
42 | 安全多方计算协议类 解码
43 | :return: 解码后的加密数字
44 | """
45 | return self.cipher
46 |
47 | def __mul__(self, other):
48 | """
49 | TODO 安全多方计算协议类 安全乘法协议
50 | :param other: 密文
51 | :return: 安全乘法协议结果 E(self.cipher * other)
52 | """
53 | return self.c1.mul(self.cipher, other, self.c2)
54 |
55 | def __truediv__(self, other):
56 | """
57 | TODO 安全多方计算协议类 安全除法协议
58 | :param other: 密文
59 | :return: 安全除法协议结果 E(self.cipher / other)
60 | """
61 | return self.c1.truediv(self.cipher, other, self.c2)
62 |
63 | def optimum(self, other, mode):
64 | """
65 | TODO 安全多方计算协议类 安全最值计算协议
66 | :param other: 密文
67 | :param mode: 'max' or 'min'
68 | :return: 安全最值计算协议结果 E(max(self.cipher, other)) or E(min(self.cipher, other))
69 | """
70 | return self.c1.optimum(self.cipher, other, self.c2, mode)
71 |
72 | def parity(self):
73 | """
74 | TODO 安全多方计算协议类 安全奇偶性判断协议
75 | :return: 安全奇偶性判断协议结果 奇数: E(1) 偶数: E(0)
76 | """
77 | return self.c1.parity(self.cipher, self.c2)
78 |
79 | def bit_dec(self, bit):
80 | """
81 | TODO 安全多方计算协议类 安全二进制分解协议
82 | :param bit: 位数
83 | :return: 安全二进制分解协议结果 self.cipher的二进制数列 -> [E(1) or E(0), ...] 长度为bit
84 | """
85 | return self.c1.bit_dec(self.cipher, bit, self.c2)
86 |
87 | def __and__(self, other):
88 | """
89 | TODO 安全多方计算协议类 安全二进制与协议
90 | ! 只能用于二进制数
91 | :param other: 密文
92 | :return: 安全二进制与协议结果 E(self.cipher & other)
93 | """
94 | return self.c1.bit_and(self.cipher, other, self.c2)
95 |
96 | def __or__(self, other):
97 | """
98 | TODO 安全多方计算协议类 安全二进制或协议
99 | ! 只能用于二进制数
100 | :param other: 密文
101 | :return: 安全二进制或协议结果 E(self.cipher | other)
102 | """
103 | return self.c1.bit_or(self.cipher, other, self.c2)
104 |
105 | def bit_not(self):
106 | """
107 | TODO 安全多方计算协议类 安全二进制非协议
108 | ! 只能用于二进制数
109 | :return: 安全二进制非协议结果 E(!self.cipher)
110 | """
111 | return self.c1.bit_not(self.cipher)
112 |
113 | def __xor__(self, other):
114 | """
115 | TODO 安全多方计算协议类 安全二进制异或协议
116 | ! 只能用于二进制数
117 | :return: 安全二进制异或协议结果 E(self.cipher ^ other)
118 | """
119 | return self.c1.bit_xor(self.cipher, other, self.c2)
120 |
121 | def __eq__(self, other):
122 | """
123 | TODO 安全多方计算协议类 安全相等协议
124 | :param other: 密文
125 | :return: 安全相等协议结果 E(self.cipher == other)
126 | """
127 | return self.c1.eq(self.cipher, other, self.c2)
128 |
129 | def __ne__(self, other):
130 | """
131 | TODO 安全多方计算协议类 安全不相等协议
132 | :param other: 密文
133 | :return: 安全不相等协议结果 E(self.cipher != other)
134 | """
135 | return self.c1.ne(self.cipher, other, self.c2)
136 |
137 | def __gt__(self, other):
138 | """
139 | TODO 安全多方计算协议类 安全大于协议
140 | :param other: 密文
141 | :return: 安全大于协议结果 E(self.cipher > other)
142 | """
143 | return self.c1.gt(self.cipher, other, self.c2)
144 |
145 | def __ge__(self, other):
146 | """
147 | TODO 安全多方计算协议类 安全大于等于协议
148 | :param other: 密文
149 | :return: 安全大于等于协议结果 E(self.cipher >= other)
150 | """
151 | return self.c1.ge(self.cipher, other, self.c2)
152 |
153 | def __lt__(self, other):
154 | """
155 | TODO 安全多方计算协议类 安全小于协议
156 | :param other: 密文
157 | :return: 安全小于协议结果 E(self.cipher < other)
158 | """
159 | return self.c1.lt(self.cipher, other, self.c2)
160 |
161 | def __le__(self, other):
162 | """
163 | TODO 安全多方计算协议类 安全小于等于协议
164 | :param other: 密文
165 | :return: 安全小于等于协议结果 E(self.cipher <= other)
166 | """
167 | return self.c1.le(self.cipher, other, self.c2)
168 |
169 |
170 | class CloudPlatform:
171 | """
172 | 云服务器类
173 | """
174 |
175 | def __init__(self, public_key):
176 | """
177 | 云服务器类 定义
178 | :param public_key: 公钥
179 | """
180 | self.public_key = public_key
181 | self.key_length = len(str(self.public_key.n))
182 |
183 | def mul(self, c1, c2, cloud_platform_third):
184 | """
185 | TODO 云服务器类 安全乘法协议
186 | :param c1: 密文1
187 | :param c2: 密文2
188 | :param cloud_platform_third: 第三方云服务器
189 | :return: 加密乘法结果
190 | """
191 | r1 = self._generate_random()
192 | r2 = self._generate_random()
193 |
194 | h1 = c1 + r1
195 | h2 = c2 + r2
196 |
197 | return cloud_platform_third.mul(h1, h2) - (c1 * r2 + c2 * r1 + r1 * r2)
198 |
199 | def truediv(self, c1, c2, cloud_platform_third):
200 | """
201 | TODO 云服务器类 安全除法协议
202 | :param c1: 密文1
203 | :param c2: 密文2
204 | :param cloud_platform_third: 第三方云服务器
205 | :return: 加密除法结果
206 | """
207 | r1 = self._generate_random()
208 | r2 = self._generate_random()
209 |
210 | h1 = c1 * r1 + c2 * r1 * r2
211 | h2 = c2 * r1
212 |
213 | return cloud_platform_third.truediv(h1, h2) - r2
214 |
215 | def optimum(self, c1, c2, cloud_platform_third, mode):
216 | """
217 | TODO 云服务器类 安全最值计算协议
218 | :param c1: 密文1
219 | :param c2: 密文2
220 | :param cloud_platform_third: 第三方云服务器
221 | :param mode: 'max' or 'min'
222 | :return: 加密最值计算结果
223 | """
224 | r1 = self._generate_random()
225 | r2 = self._generate_random()
226 | r3 = self._generate_random()
227 |
228 | if random.random() > 5e-1:
229 | h1 = (c1 - c2) * r1
230 | h2 = c1 + r2
231 | h3 = c2 + r3
232 | else:
233 | h1 = (c2 - c1) * r1
234 | h2 = c2 + r2
235 | h3 = c1 + r3
236 |
237 | alpha, beta = cloud_platform_third.optimum(h1, h2, h3, mode)
238 |
239 | return c1 + c2 - beta + alpha * r3 + (1 - alpha) * r2
240 |
241 | def parity(self, c, cloud_platform_third):
242 | """
243 | TODO 云服务器类 安全奇偶性判断协议
244 | :param c: 密文
245 | :param cloud_platform_third: 第三方云服务器
246 | :return: 加密奇偶性判断结果
247 | """
248 | r = self._generate_random()
249 | h = c + r
250 | alpha = cloud_platform_third.parity(h)
251 |
252 | return alpha if r % 2 == 0 else 1 - alpha
253 |
254 | def bit_dec(self, c, bit, cloud_platform_third):
255 | """
256 | TODO 云服务器类 安全二进制分解协议
257 | :param c: 密文
258 | :param bit: 位数
259 | :param cloud_platform_third: 第三方云服务器
260 | :return: 加密二进制分解结果
261 | """
262 | sigma = 5e-1
263 | result = []
264 | for i in range(bit):
265 | result.append(self.parity(c, cloud_platform_third))
266 | zeta = c - result[i]
267 | c = zeta * sigma
268 | result.reverse()
269 |
270 | return result
271 |
272 | def bit_and(self, c1, c2, cloud_platform_third):
273 | """
274 | TODO 云服务器类 安全二进制与协议
275 | :param c1: 密文1
276 | :param c2: 密文2
277 | :param cloud_platform_third: 第三方云服务器
278 | :return: 加密二进制与结果
279 | """
280 | return self.mul(c1, c2, cloud_platform_third)
281 |
282 | def bit_or(self, c1, c2, cloud_platform_third):
283 | """
284 | TODO 云服务器类 安全二进制或协议
285 | :param c1: 密文1
286 | :param c2: 密文2
287 | :param cloud_platform_third: 第三方云服务器
288 | :return: 加密二进制或结果
289 | """
290 | return c1 + c2 - self.bit_and(c1, c2, cloud_platform_third)
291 |
292 | @staticmethod
293 | def bit_not(c):
294 | """
295 | TODO 云服务器类 安全二进制非协议
296 | :param c: 密文
297 | :return: 加密二进制非结果
298 | """
299 | return 1 - c
300 |
301 | def bit_xor(self, c1, c2, cloud_platform_third):
302 | """
303 | TODO 云服务器类 安全二进制异或协议
304 | :param c1: 密文1
305 | :param c2: 密文2
306 | :param cloud_platform_third: 第三方云服务器
307 | :return: 加密二进制异或结果
308 | """
309 | return c1 + c2 - 2 * self.mul(c1, c2, cloud_platform_third)
310 |
311 | def eq(self, c1, c2, cloud_platform_third):
312 | """
313 | TODO 云服务器类 安全相等协议
314 | :param c1: 密文1
315 | :param c2: 密文2
316 | :param cloud_platform_third: 第三方云服务器
317 | :return: 加密相等结果
318 | """
319 | sigma = -1 if random.random() > 5e-1 else 1
320 | r1 = self._generate_random()
321 | r2 = self._generate_random()
322 | if r2 > r1:
323 | r2, r1 = r1, r2
324 | alpha = r1 * sigma * self.mul(c1 - c2, c1 - c2, cloud_platform_third) - sigma * r2
325 |
326 | return cloud_platform_third.eq(alpha) if sigma == 1 else 1 - cloud_platform_third.eq(alpha)
327 |
328 | def ne(self, c1, c2, cloud_platform_third):
329 | """
330 | TODO 云服务器类 安全不相等协议
331 | :param c1: 密文1
332 | :param c2: 密文2
333 | :param cloud_platform_third: 第三方云服务器
334 | :return: 加密不相等结果
335 | """
336 | return 1 - self.eq(c1, c2, cloud_platform_third)
337 |
338 | def gt(self, c1, c2, cloud_platform_third):
339 | """
340 | TODO 云服务器类 安全大于协议
341 | :param c1: 密文1
342 | :param c2: 密文2
343 | :param cloud_platform_third: 第三方云服务器
344 | :return: 加密大于结果
345 | """
346 | sigma = -1 if random.random() > 5e-1 else 1
347 | r1 = self._generate_random()
348 | r2 = self._generate_random()
349 | (r2, r1) = (r1, r2) if r2 > r1 else (r2, r1)
350 | alpha = r1 * sigma * (c2 - c1) + sigma * r2
351 |
352 | return cloud_platform_third.eq(alpha) if sigma == 1 else 1 - cloud_platform_third.eq(alpha)
353 |
354 | def ge(self, c1, c2, cloud_platform_third):
355 | """
356 | TODO 云服务器类 安全大于等于协议
357 | :param c1: 密文1
358 | :param c2: 密文2
359 | :param cloud_platform_third: 第三方云服务器
360 | :return: 加密大于等于结果
361 | """
362 | return self.bit_or(self.eq(c1, c2, cloud_platform_third), self.gt(c1, c2, cloud_platform_third),
363 | cloud_platform_third)
364 |
365 | def lt(self, c1, c2, cloud_platform_third):
366 | """
367 | TODO 云服务器类 安全小于协议
368 | :param c1: 密文1
369 | :param c2: 密文2
370 | :param cloud_platform_third: 第三方云服务器
371 | :return: 加密小于结果
372 | """
373 | sigma = -1 if random.random() > 5e-1 else 1
374 | r1 = self._generate_random()
375 | r2 = self._generate_random()
376 | (r2, r1) = (r1, r2) if r2 > r1 else (r2, r1)
377 | alpha = r1 * sigma * (c1 - c2) + sigma * r2
378 |
379 | return cloud_platform_third.eq(alpha) if sigma == 1 else 1 - cloud_platform_third.eq(alpha)
380 |
381 | def le(self, c1, c2, cloud_platform_third):
382 | """
383 | TODO 云服务器类 安全小于等于协议
384 | :param c1: 密文1
385 | :param c2: 密文2
386 | :param cloud_platform_third: 第三方云服务器
387 | :return: 加密小于等于结果
388 | """
389 | return self.bit_or(self.eq(c1, c2, cloud_platform_third), self.lt(c1, c2, cloud_platform_third),
390 | cloud_platform_third)
391 |
392 | def _generate_random(self):
393 | """
394 | 云服务器 随机数生成
395 | :return: 密钥长度的随机数
396 | """
397 | return int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), self.key_length))
398 |
399 |
400 | class CloudPlatformThird:
401 | """
402 | 第三方云服务器类
403 | """
404 |
405 | def __init__(self, public_key, secret_key):
406 | """
407 | 第三方云服务器类 定义
408 | :param public_key: 公钥
409 | :param secret_key: 私钥
410 | """
411 | self.public_key = public_key
412 | self.secret_key = secret_key
413 |
414 | def mul(self, h1, h2):
415 | """
416 | TODO 第三方云服务器类 安全乘法协议
417 | :param h1: 参数1
418 | :param h2: 参数2
419 | :return: 安全乘法协议结果
420 | """
421 | return self.public_key.encrypt(self.secret_key.decrypt(h1) * self.secret_key.decrypt(h2))
422 |
423 | def truediv(self, h1, h2):
424 | """
425 | TODO 第三方云服务器类 安全除法协议
426 | :param h1: 参数1
427 | :param h2: 参数2
428 | :return: 安全除法协议结果
429 | """
430 | h2 = self.secret_key.decrypt(h2)
431 | if h2 != 0:
432 | return self.public_key.encrypt(self.secret_key.decrypt(h1) / h2)
433 | else:
434 | assert ValueError("Divisor cannot be 0")
435 |
436 | def optimum(self, h1, h2, h3, mode):
437 | """
438 | TODO 第三方云服务器类 安全最值计算协议
439 | :param h1: 参数
440 | :param h2: 参数
441 | :param h3: 参数
442 | :param mode: 'max' or 'min'
443 | :return: 安全最值计算协议结果
444 | """
445 | mode = self.secret_key.decrypt(h1) > 0 if mode == 'max' else self.secret_key.decrypt(h1) < 0
446 | alpha = 1 if mode else 0
447 |
448 | return self.public_key.encrypt(alpha), h3 if alpha == 1 else h2
449 |
450 | def parity(self, h):
451 | """
452 | TODO 第三方云服务器类 安全奇偶性判断协议
453 | :param h: 参数
454 | :return: 安全奇偶性判断协议结果
455 | """
456 | return self.public_key.encrypt(0) if self.secret_key.decrypt(h) % 2 == 0 else self.public_key.encrypt(1)
457 |
458 | def eq(self, h):
459 | """
460 | TODO 第三方云服务器类 安全相等协议
461 | :param h: 参数
462 | :return: 安全相等协议结果
463 | """
464 | return self.public_key.encrypt(1) if self.secret_key.decrypt(h) < 0 else self.public_key.encrypt(0)
465 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/4 17:42
8 | @File: test.py
9 | @License: MIT
10 | """
11 | import json
12 | import os
13 | import sys
14 | import time
15 | import unittest
16 | from multiprocessing import Pool, cpu_count
17 |
18 | import pandas as pd
19 | import tqdm
20 |
21 | sys.path.append("test_case/") # 添加测试文件路径
22 | from BeautifulReport import BeautifulReport
23 |
24 | TEMPLATE_PATH = 'test_case/template/template' # 模板文件路径
25 | REPORT_PATH = 'test_case/report' # 报告文件路径
26 | REPORT_FILE = 'HTMLReport.html' # 报告文件
27 | DESCRIPTION = '安全多方计算协议测试' # 报告名称
28 |
29 |
30 | def format_report(_report):
31 | """
32 | 规范报告格式
33 | :param _report: 测试报告数据
34 | :return: 规范后的测试报告数据
35 | """
36 | _output = {'testPass': 0, 'testResult': [], 'testFail': 0, 'testSkip': 0, 'testError': 0}
37 |
38 | for v in _report:
39 | _output['testPass'] += v.get()['testPass']
40 | for m in v.get()['testResult']:
41 | _output.get('testResult').append(m)
42 | _output['testAll'] = len(v.get()['testResult'])
43 | _output['testName'] = v.get()['testName']
44 | _output['testFail'] += v.get()['testFail']
45 | _output['beginTime'] = v.get()['beginTime']
46 | _output['totalTime'] = v.get()['totalTime']
47 | _output['testError'] += v.get()['testError']
48 | _output['testSkip'] += v.get()['testSkip']
49 |
50 | return _output
51 |
52 |
53 | def output_report(_report):
54 | """
55 | 输出测试报告
56 | :param _report: 测试报告数据
57 | """
58 | pd.DataFrame(_report).to_csv("{0}/CSVReport.csv".format(REPORT_PATH), header=False, index=False)
59 |
60 | template_path = TEMPLATE_PATH
61 | override_path = os.path.abspath(REPORT_PATH) \
62 | if os.path.abspath(REPORT_PATH).endswith('/') \
63 | else os.path.abspath(REPORT_PATH) + '/'
64 |
65 | with open(template_path, 'rb') as file:
66 | body = file.readlines()
67 | with open(override_path + REPORT_FILE, 'wb') as write_file:
68 | for item in body:
69 | if item.strip().startswith(b'var resultData'):
70 | head = ' var resultData = '
71 | item = item.decode().split(head)
72 | item[1] = head + json.dumps(
73 | _report, ensure_ascii=False, indent=4)
74 | item = ''.join(item).encode()
75 | item = bytes(item) + b';\n'
76 | write_file.write(item)
77 |
78 |
79 | def run():
80 | """
81 | 开始测试
82 | :return: 测试报告数据
83 | """
84 | # 构造测试用例
85 | cases = unittest.defaultTestLoader.discover("test_case", pattern="test_smpcp.py", top_level_dir=None)
86 | # 读取测试用例 运行测试
87 | return BeautifulReport(cases).report(filename=REPORT_FILE, log_path=REPORT_PATH, description=DESCRIPTION)
88 |
89 |
90 | if __name__ == '__main__':
91 | start = time.time() # 开始时间
92 | times = 100 # TODO 测试次数
93 | process_pool = Pool(cpu_count()) # 开启进程池
94 | # 进度条
95 | process_bar = tqdm.tqdm(iterable=range(times), ncols=80, nrows=20, desc=DESCRIPTION)
96 | # 多进程测试
97 | report = [process_pool.apply_async(run, (), callback=lambda _: process_bar.update()) for _ in range(times)]
98 | # 进程开始
99 | process_pool.close()
100 | process_pool.join()
101 | # 测试报告
102 | output = format_report(report)
103 | output_report(output)
104 |
--------------------------------------------------------------------------------
/test_case/BeautifulReport.py:
--------------------------------------------------------------------------------
1 | """
2 | @Version: 1.0
3 | @Project: BeautyReport
4 | @Author: Raymond
5 | @Data: 2017/11/15 下午5:28
6 | @File: BeautifulReport.py
7 | @License: MIT
8 | """
9 |
10 | import base64
11 | import json
12 | import os
13 | import platform
14 | import sys
15 | import time
16 | import traceback
17 | import unittest
18 | from distutils.sysconfig import get_python_lib
19 | from functools import wraps
20 | from io import StringIO as StringIO
21 |
22 | __all__ = ['BeautifulReport']
23 |
24 | HTML_IMG_TEMPLATE = """
25 |
26 |
27 |
28 |
29 | """
30 |
31 |
32 | class OutputRedirector(object):
33 | """
34 | Wrapper to redirect stdout or stderr
35 | """
36 |
37 | def __init__(self, fp):
38 | self.fp = fp
39 |
40 | def write(self, s):
41 | self.fp.write(s)
42 |
43 | def writelines(self, lines):
44 | self.fp.writelines(lines)
45 |
46 | def flush(self):
47 | self.fp.flush()
48 |
49 |
50 | stdout_redirector = OutputRedirector(sys.stdout)
51 | stderr_redirector = OutputRedirector(sys.stderr)
52 |
53 | SYSSTR = platform.system()
54 | SITE_PAKAGE_PATH = get_python_lib()
55 |
56 | FIELDS = {
57 | "testPass": 0,
58 | "testResult": [
59 | ],
60 | "testName": "",
61 | "testAll": 0,
62 | "testFail": 0,
63 | "beginTime": "",
64 | "totalTime": "",
65 | "testSkip": 0
66 | }
67 |
68 |
69 | class PATH:
70 | """
71 | all file PATH meta
72 | """
73 | config_tmp_path = 'test_case/template/template'
74 |
75 |
76 | class MakeResultJson:
77 | """
78 | make html table tags
79 | """
80 |
81 | def __init__(self, datas: tuple):
82 | """
83 | init self object
84 | :param datas: 拿到所有返回数据结构
85 | """
86 | self.datas = datas
87 | self.result_schema = {}
88 |
89 | def __setitem__(self, key, value):
90 | """
91 |
92 | :param key: self[key]
93 | :param value: value
94 | :return:
95 | """
96 | self[key] = value
97 |
98 | def __repr__(self) -> str:
99 | """
100 | 返回对象的html结构体
101 | :rtype: dict
102 | :return: self的repr对象, 返回一个构造完成的tr表单
103 | """
104 | keys = (
105 | 'className',
106 | 'methodName',
107 | 'description',
108 | 'spendTime',
109 | 'status',
110 | 'log',
111 | )
112 | for key, data in zip(keys, self.datas):
113 | self.result_schema.setdefault(key, data)
114 | return json.dumps(self.result_schema)
115 |
116 |
117 | class ReportTestResult(unittest.TestResult):
118 | """
119 | override
120 | """
121 |
122 | def __init__(self, suite, stream=sys.stdout):
123 | """
124 | pass
125 | """
126 | super(ReportTestResult, self).__init__()
127 | self.begin_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
128 | self.start_time = 0
129 | self.stream = stream
130 | self.end_time = 0
131 | self.failure_count = 0
132 | self.error_count = 0
133 | self.success_count = 0
134 | self.skipped = 0
135 | self.verbosity = 1
136 | self.success_case_info = []
137 | self.skipped_case_info = []
138 | self.failures_case_info = []
139 | self.errors_case_info = []
140 | self.all_case_counter = 0
141 | self.suite = suite
142 | self.status = ''
143 | self.result_list = []
144 | self.case_log = ''
145 | self.default_report_name = '自动化测试报告'
146 | self.FIELDS = None
147 | self.sys_stdout = None
148 | self.sys_stderr = None
149 | self.outputBuffer = None
150 |
151 | @property
152 | def success_counter(self) -> int:
153 | """
154 | set success counter
155 | """
156 | return self.success_count
157 |
158 | @success_counter.setter
159 | def success_counter(self, value) -> None:
160 | """
161 | success_counter函数的setter方法, 用于改变成功的case数量
162 | :param value: 当前传递进来的成功次数的int数值
163 | :return:
164 | """
165 | self.success_count = value
166 |
167 | def startTest(self, test) -> None:
168 | """
169 | 当测试用例测试即将运行时调用
170 | :return:
171 | """
172 | unittest.TestResult.startTest(self, test)
173 | self.outputBuffer = StringIO()
174 | stdout_redirector.fp = self.outputBuffer
175 | stderr_redirector.fp = self.outputBuffer
176 | self.sys_stdout = sys.stdout
177 | self.sys_stdout = sys.stderr
178 | sys.stdout = stdout_redirector
179 | sys.stderr = stderr_redirector
180 | self.start_time = time.time()
181 |
182 | def stopTest(self, test) -> None:
183 | """
184 | 当测试用例执行完成后进行调用
185 | :return:
186 | """
187 | self.end_time = '{0:.3} s'.format((time.time() - self.start_time))
188 | self.result_list.append(self.get_all_result_info_tuple(test))
189 | self.complete_output()
190 |
191 | def complete_output(self):
192 | """
193 | Disconnect output redirection and return buffer.
194 | Safe to call multiple times.
195 | """
196 | if self.sys_stdout:
197 | sys.stdout = self.sys_stdout
198 | sys.stderr = self.sys_stdout
199 | self.sys_stdout = None
200 | self.sys_stdout = None
201 | return self.outputBuffer.getvalue()
202 |
203 | def stopTestRun(self, title=None) -> dict:
204 | """
205 | 所有测试执行完成后, 执行该方法
206 | :param title:
207 | :return:
208 | """
209 | FIELDS['testPass'] = self.success_counter
210 | for item in self.result_list:
211 | item = json.loads(str(MakeResultJson(item)))
212 | FIELDS.get('testResult').append(item)
213 | FIELDS['testAll'] = len(self.result_list)
214 | FIELDS['testName'] = title if title else self.default_report_name
215 | FIELDS['testFail'] = self.failure_count
216 | FIELDS['beginTime'] = self.begin_time
217 | end_time = int(time.time())
218 | start_time = int(time.mktime(time.strptime(self.begin_time, '%Y-%m-%d %H:%M:%S')))
219 | FIELDS['totalTime'] = str(end_time - start_time) + 's'
220 | FIELDS['testError'] = self.error_count
221 | FIELDS['testSkip'] = self.skipped
222 | self.FIELDS = FIELDS
223 | return FIELDS
224 |
225 | def get_all_result_info_tuple(self, test) -> tuple:
226 | """
227 | 接受test 相关信息, 并拼接成一个完成的tuple结构返回
228 | :param test:
229 | :return:
230 | """
231 | return tuple([*self.get_testcase_property(test), self.end_time, self.status, self.case_log])
232 |
233 | @staticmethod
234 | def error_or_failure_text(err) -> str:
235 | """
236 | 获取sys.exc_info()的参数并返回字符串类型的数据, 去掉t6 error
237 | :param err:
238 | :return:
239 | """
240 | return traceback.format_exception(*err)
241 |
242 | def addSuccess(self, test) -> None:
243 | """
244 | pass
245 | :param test:
246 | :return:
247 | """
248 | logs = []
249 | output = self.complete_output()
250 | logs.append(output)
251 | # if self.verbosity > 1:
252 | # sys.stderr.write('ok ')
253 | # sys.stderr.write(str(test))
254 | # sys.stderr.write('\n')
255 | # else:
256 | # sys.stderr.write('#')
257 | self.success_counter += 1
258 | self.status = '成功'
259 | self.case_log = output.split('\n')
260 | self._mirrorOutput = True # print(class_name, method_name, method_doc)
261 |
262 | def addError(self, test, err):
263 | """
264 | add Some Error Result and infos
265 | :param test:
266 | :param err:
267 | :return:
268 | """
269 | logs = []
270 | output = self.complete_output()
271 | logs.append(output)
272 | logs.extend(self.error_or_failure_text(err))
273 | self.failure_count += 1
274 | self.add_test_type('失败', logs)
275 | # if self.verbosity > 1:
276 | # sys.stderr.write('F ')
277 | # sys.stderr.write(str(test))
278 | # sys.stderr.write('\n')
279 | # self.process_bar.update(1)
280 | # else:
281 | # sys.stderr.write('F')
282 | # self.process_bar.update(1)
283 |
284 | self._mirrorOutput = True
285 |
286 | def addFailure(self, test, err):
287 | """
288 | add Some Failures Result and infos
289 | :param test:
290 | :param err:
291 | :return:
292 | """
293 | logs = []
294 | output = self.complete_output()
295 | logs.append(output)
296 | logs.extend(self.error_or_failure_text(err))
297 | self.failure_count += 1
298 | self.add_test_type('失败', logs)
299 | # if self.verbosity > 1:
300 | # sys.stderr.write('F ')
301 | # sys.stderr.write(str(test))
302 | # sys.stderr.write('\n')
303 | # self.process_bar.update(1)
304 | # else:
305 | # sys.stderr.write('F')
306 | # self.process_bar.update(1)
307 |
308 | self._mirrorOutput = True
309 |
310 | def addSkip(self, test, reason) -> None:
311 | """
312 | 获取全部的跳过的case信息
313 | :param test:
314 | :param reason:
315 | :return: None
316 | """
317 | logs = [reason]
318 | self.complete_output()
319 | self.skipped += 1
320 | self.add_test_type('跳过', logs)
321 |
322 | # if self.verbosity > 1:
323 | # sys.stderr.write('S ')
324 | # sys.stderr.write(str(test))
325 | # sys.stderr.write('\n')
326 | # self.process_bar.update(1)
327 | # else:
328 | # sys.stderr.write('S')
329 | # self.process_bar.update(1)
330 | self._mirrorOutput = True
331 |
332 | def add_test_type(self, status: str, case_log: list) -> None:
333 | """
334 | abstruct add test type and return tuple
335 | :param status:
336 | :param case_log:
337 | :return:
338 | """
339 | self.status = status
340 | self.case_log = case_log
341 |
342 | @staticmethod
343 | def get_testcase_property(test) -> tuple:
344 | """
345 | 接受一个test, 并返回一个test的class_name, method_name, method_doc属性
346 | :param test:
347 | :return: (class_name, method_name, method_doc) -> tuple
348 | """
349 | class_name = test.__class__.__qualname__
350 | method_name = test.__dict__['_testMethodName']
351 | method_doc = test.__dict__['_testMethodDoc']
352 | return class_name, method_name, method_doc
353 |
354 |
355 | class BeautifulReport(ReportTestResult, PATH):
356 | img_path = 'img/' if platform.system() != 'Windows' else 'img\\'
357 |
358 | def __init__(self, suites):
359 | super(BeautifulReport, self).__init__(suites)
360 | self.suites = suites
361 | self.log_path = None
362 | self.title = '自动化测试报告'
363 | self.filename = 'report.html'
364 |
365 | def report(self, description, filename: str = None, log_path='.'):
366 | """
367 | 生成测试报告,并放在当前运行路径下
368 | :param log_path: 生成report的文件存储路径
369 | :param filename: 生成文件的filename
370 | :param description: 生成文件的注释
371 | :return:
372 | """
373 | if filename:
374 | self.filename = filename if filename.endswith('.html') else filename + '.html'
375 |
376 | if description:
377 | self.title = description
378 |
379 | self.log_path = os.path.abspath(log_path)
380 | self.suites.run(result=self)
381 | self.stopTestRun(self.title)
382 | # self.output_report()
383 | # text = '测试已全部完成, 可前往{}查询测试报告'.format(self.log_path)
384 | # print(text)
385 | return self.FIELDS
386 |
387 | def output_report(self):
388 | """
389 | 生成测试报告到指定路径下
390 | :return:
391 | """
392 | template_path = self.config_tmp_path
393 | override_path = os.path.abspath(self.log_path) if \
394 | os.path.abspath(self.log_path).endswith('/') else \
395 | os.path.abspath(self.log_path) + '/'
396 |
397 | with open(template_path, 'rb') as file:
398 | body = file.readlines()
399 | with open(override_path + self.filename, 'wb') as write_file:
400 | for item in body:
401 | if item.strip().startswith(b'var resultData'):
402 | head = ' var resultData = '
403 | item = item.decode().split(head)
404 | item[1] = head + json.dumps(self.FIELDS, ensure_ascii=False, indent=4)
405 | item = ''.join(item).encode()
406 | item = bytes(item) + b';\n'
407 | write_file.write(item)
408 |
409 | @staticmethod
410 | def img2base(img_path: str, file_name: str) -> str:
411 | """
412 | 接受传递进函数的filename 并找到文件转换为base64格式
413 | :param img_path: 通过文件名及默认路径找到的img绝对路径
414 | :param file_name: 用户在装饰器中传递进来的问价匿名
415 | :return:
416 | """
417 | pattern = '/' if platform != 'Windows' else '\\'
418 |
419 | with open(img_path + pattern + file_name, 'rb') as file:
420 | data = file.read()
421 | return base64.b64encode(data).decode()
422 |
423 | def add_test_img(*pargs):
424 | """
425 | 接受若干个图片元素, 并展示在测试报告中
426 | :param pargs:
427 | :return:
428 | """
429 |
430 | def _wrap(func):
431 | @wraps(func)
432 | def __wrap(*args, **kwargs):
433 | img_path = os.path.abspath('{}'.format(BeautifulReport.img_path))
434 | try:
435 | result = func(*args, **kwargs)
436 | except Exception:
437 | if 'save_img' in dir(args[0]):
438 | save_img = getattr(args[0], 'save_img')
439 | save_img(func.__name__)
440 | data = BeautifulReport.img2base(img_path, pargs[0] + '.png')
441 | print(HTML_IMG_TEMPLATE.format(data, data))
442 | sys.exit(0)
443 | print('
')
444 |
445 | if len(pargs) > 1:
446 | for parg in pargs:
447 | print(parg + ':')
448 | data = BeautifulReport.img2base(img_path, parg + '.png')
449 | print(HTML_IMG_TEMPLATE.format(data, data))
450 | return result
451 | if not os.path.exists(img_path + pargs[0] + '.png'):
452 | return result
453 | data = BeautifulReport.img2base(img_path, pargs[0] + '.png')
454 | print(HTML_IMG_TEMPLATE.format(data, data))
455 | return result
456 |
457 | return __wrap
458 |
459 | return _wrap
460 |
--------------------------------------------------------------------------------
/test_case/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/5 12:13
8 | @File: __init__.py
9 | @License: MIT
10 | """
11 | __name__ = 'test_case'
12 |
--------------------------------------------------------------------------------
/test_case/test_smpcp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | """
4 | @Version: 2.0.2
5 | @Project: Secure-Multi-Party-Computation-Protocol
6 | @Author: Zhan Shi
7 | @Time : 2022/5/4 15:54
8 | @File: test_smpcp.py
9 | @License: MIT
10 | """
11 | import random
12 | import sys
13 | import unittest
14 |
15 | import gmpy2
16 | import phe
17 |
18 | from smpcp.smpcp import CloudPlatform, CloudPlatformThird, SecureMultiPartyComputationProtocol
19 |
20 | sys.path.append("test_case/") # 添加测试文件路径
21 |
22 | key_length = 2048 # TODO 密钥长度
23 |
24 | public_key, secret_key = phe.generate_paillier_keypair(n_length=key_length) # 生成密钥对
25 |
26 | cloud1 = CloudPlatform(public_key=public_key) # 云服务器1
27 | cloud2 = CloudPlatformThird(public_key=public_key, secret_key=secret_key) # 云服务器2
28 |
29 | protocol = SecureMultiPartyComputationProtocol(c1=cloud1, c2=cloud2) # 安全多方计算协议类
30 |
31 |
32 | class SMPCPTest(unittest.TestCase):
33 | """
34 | 安全多方计算协议测试类
35 | """
36 |
37 | def setUp(self):
38 | """
39 | 测试前
40 | """
41 | # 生成浮点数
42 | self.float1 = int(
43 | gmpy2.mpz_random(gmpy2.random_state(
44 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)),
45 | key_length))), key_length)) * random.uniform(0.1, 1.0)
46 | self.float2 = int(
47 | gmpy2.mpz_random(gmpy2.random_state(
48 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))),
49 | key_length)) * random.uniform(0.1, 1.0)
50 | self.float_n1 = protocol.encode(public_key.encrypt(self.float1))
51 | self.float_n2 = public_key.encrypt(self.float2)
52 | # 生成整数
53 | self.int1 = int(gmpy2.mpz_random(gmpy2.random_state(
54 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))),
55 | key_length))
56 | self.int2 = int(gmpy2.mpz_random(gmpy2.random_state(
57 | int(gmpy2.mpz_random(gmpy2.random_state(random.SystemRandom().randint(1, 0xffffffff)), key_length))),
58 | key_length))
59 | self.int_n1 = protocol.encode(public_key.encrypt(self.int1))
60 | self.int_n2 = public_key.encrypt(self.int2)
61 | return super().setUp()
62 |
63 | def tearDown(self):
64 | """
65 | 测试后
66 | """
67 | return super().tearDown()
68 |
69 | # TODO 安全乘法协议测试
70 | # @unittest.skip('跳过安全乘法协议')
71 | def test_mul(self):
72 | """
73 | 安全乘法协议
74 | """
75 | # 浮点乘法测试:经过测试,最高支持8位浮点乘法
76 | self.assertEqual(round(self.float1 * self.float2, 8),
77 | round(secret_key.decrypt(self.float_n1 * self.float_n2), 8))
78 |
79 | # 整数乘法测试:经过测试,无明显问题
80 | self.assertEqual(self.int1 * self.int2, secret_key.decrypt(self.int_n1 * self.int_n2))
81 |
82 | # TODO 安全除法协议测试
83 | # @unittest.skip('跳过安全除法协议')
84 | def test_div(self):
85 | """
86 | 安全除法协议
87 | """
88 | # 浮点除法测试:经过测试,最高支持10位浮点除法
89 | self.assertEqual(round(self.float1 / self.float2, 10),
90 | round(secret_key.decrypt(self.float_n1 / self.float_n2), 10))
91 |
92 | # 整数除法测试:经过测试,最高支持10位整数除法
93 | self.assertEqual(round(self.int1 / self.int2, 10), round(secret_key.decrypt(self.int_n1 / self.int_n2), 10))
94 |
95 | # TODO 安全最值计算协议测试
96 | # @unittest.skip('跳过安全最值计算协议')
97 | def test_optimum(self):
98 | """
99 | 安全最值计算协议
100 | """
101 | mode = 'max' if random.random() > 0.5 else 'min'
102 | if mode == 'max':
103 | # 浮点最大值计算测试:经过测试,无明显问题
104 | self.assertEqual(max(self.float1, self.float2),
105 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'max')))
106 |
107 | # 整数最大值计算测试:经过测试,无明显问题
108 | self.assertEqual(max(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'max')))
109 | else:
110 | # 浮点最小值计算测试:经过测试,无明显问题
111 | self.assertEqual(min(self.float1, self.float2),
112 | secret_key.decrypt(self.float_n1.optimum(self.float_n2, 'min')))
113 |
114 | # 整数最小值计算测试:经过测试,无明显问题
115 | self.assertEqual(min(self.int1, self.int2), secret_key.decrypt(self.int_n1.optimum(self.int_n2, 'min')))
116 |
117 | # TODO 安全奇偶性判断协议测试
118 | # @unittest.skip('跳过安全奇偶性判断协议')
119 | def test_parity(self):
120 | """
121 | 安全奇偶性判断协议
122 | """
123 | # 整数奇偶性判断测试:经过测试,无明显问题
124 | self.assertEqual(self.int1 % 2, secret_key.decrypt(self.int_n1.parity()))
125 |
126 | # TODO 安全二进制分解协议测试
127 | # @unittest.skip('跳过安全二进制分解协议')
128 | def test_bit_dec(self):
129 | """
130 | 安全二进制分解协议
131 | """
132 | # 整数二进制分解测试:经过测试,无明显问题
133 | bit = len(bin(self.int1).split('b')[1])
134 | result = ''.join([str(secret_key.decrypt(v)) for v in self.int_n1.bit_dec(bit)])
135 | self.assertEqual(bin(self.int1).split('b')[1], result)
136 |
137 | # TODO 安全二进制与协议测试
138 | # @unittest.skip('跳过安全二进制与协议')
139 | def test_and(self):
140 | """
141 | 安全二进制与协议
142 | """
143 | bit1 = random.SystemRandom().randint(0, 1)
144 | bit2 = random.SystemRandom().randint(0, 1)
145 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
146 | bit_n2 = public_key.encrypt(bit2)
147 | # 二进制或测试:经过测试,无明显问题
148 | self.assertEqual(bit1 & bit2, secret_key.decrypt(bit_n1 & bit_n2))
149 |
150 | # TODO 安全二进制或协议测试
151 | # @unittest.skip('跳过安全二进制或协议')
152 | def test_or(self):
153 | """
154 | 安全二进制或协议
155 | """
156 | bit1 = random.SystemRandom().randint(0, 1)
157 | bit2 = random.SystemRandom().randint(0, 1)
158 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
159 | bit_n2 = public_key.encrypt(bit2)
160 | # 二进制或测试:经过测试,无明显问题
161 | self.assertEqual(bit1 | bit2, secret_key.decrypt(bit_n1 | bit_n2))
162 |
163 | # TODO 安全二进制非协议测试
164 | # @unittest.skip('跳过安全二进制非协议')
165 | def test_bit_not(self):
166 | """
167 | 安全二进制非协议
168 | """
169 | bit1 = random.SystemRandom().randint(0, 1)
170 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
171 | # 二进制或测试:经过测试,无明显问题
172 | self.assertEqual(1 - bit1, secret_key.decrypt(bit_n1.bit_not()))
173 |
174 | # TODO 安全二进制异或协议测试
175 | # @unittest.skip('跳过安全二进制异或协议')
176 | def test_xor(self):
177 | """
178 | 安全二进制异或协议
179 | """
180 | bit1 = random.SystemRandom().randint(0, 1)
181 | bit2 = random.SystemRandom().randint(0, 1)
182 | bit_n1 = protocol.encode(public_key.encrypt(bit1))
183 | bit_n2 = public_key.encrypt(bit2)
184 | # 二进制或测试:经过测试,无明显问题
185 | self.assertEqual(bit1 ^ bit2, secret_key.decrypt(bit_n1 ^ bit_n2))
186 |
187 | # TODO 安全相等协议测试
188 | # @unittest.skip('跳过安全相等协议')
189 | def test_eq(self):
190 | """
191 | 安全相等协议
192 | """
193 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
194 | self.assertEqual(1 if self.float1 == self.float1 else 0,
195 | secret_key.decrypt(self.float_n1 == self.float_n1.decode()))
196 |
197 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
198 | self.assertEqual(1 if self.int1 == self.int1 else 0, secret_key.decrypt(self.int_n1 == self.int_n1.decode()))
199 |
200 | # TODO 安全不相等协议测试
201 | # @unittest.skip('跳过安全不相等协议')
202 | def test_ne(self):
203 | """
204 | 安全不相等协议
205 | """
206 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
207 | self.assertEqual(1 if self.float1 != self.float2 else 0, secret_key.decrypt(self.float_n1 != self.float_n2))
208 |
209 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
210 | self.assertEqual(1 if self.int1 != self.int2 else 0, secret_key.decrypt(self.int_n1 != self.int_n2))
211 |
212 | # TODO 安全大于协议测试
213 | # @unittest.skip('跳过安全大于协议')
214 | def test_gt(self):
215 | """
216 | 安全大于协议
217 | """
218 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
219 | self.assertEqual(1 if self.float1 > self.float2 else 0, secret_key.decrypt(self.float_n1 > self.float_n2))
220 |
221 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
222 | self.assertEqual(1 if self.int1 > self.int2 else 0, secret_key.decrypt(self.int_n1 > self.int_n2))
223 |
224 | # TODO 安全大于等于协议测试
225 | # @unittest.skip('跳过安全大于等于协议')
226 | def test_ge(self):
227 | """
228 | 安全大于等于协议
229 | """
230 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
231 | self.assertEqual(1 if self.float1 >= self.float2 else 0, secret_key.decrypt(self.float_n1 >= self.float_n2))
232 |
233 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
234 | self.assertEqual(1 if self.int1 >= self.int2 else 0, secret_key.decrypt(self.int_n1 >= self.int_n2))
235 |
236 | # TODO 安全小于协议测试
237 | # @unittest.skip('跳过安全小于协议')
238 | def test_lt(self):
239 | """
240 | 安全小于协议
241 | """
242 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
243 | self.assertEqual(1 if self.float1 < self.float2 else 0, secret_key.decrypt(self.float_n1 < self.float_n2))
244 |
245 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
246 | self.assertEqual(1 if self.int1 < self.int2 else 0, secret_key.decrypt(self.int_n1 < self.int_n2))
247 |
248 | # TODO 安全小于等于协议测试
249 | # @unittest.skip('跳过安全小于等于协议')
250 | def test_le(self):
251 | """
252 | 安全小于等于协议
253 | """
254 | # 浮点数相等测试:经过测试,极少数情况下,浮点数会影响结果
255 | self.assertEqual(1 if self.float1 <= self.float2 else 0, secret_key.decrypt(self.float_n1 <= self.float_n2))
256 |
257 | # 整数相等测试:经过测试,极少数情况下,浮点数会影响结果
258 | self.assertEqual(1 if self.int1 <= self.int2 else 0, secret_key.decrypt(self.int_n1 <= self.int_n2))
259 |
--------------------------------------------------------------------------------