├── multicollinearity_test.py
├── corr_coefficient_matrix.py
├── full_table_statistics.py
├── README.md
├── data_visual.py
├── data_process.py
├── LICENSE
├── mlp.py
└── neural_network_gui.py
/multicollinearity_test.py:
--------------------------------------------------------------------------------
1 | from secretflow.stats.ss_vif_v import VIF
2 | from data_process import data_dict,spu
3 |
def VIF_calculation(vdf_hat):
    '''
    Run a multicollinearity check by computing the Variance Inflation
    Factor (VIF) for every column of ``vdf_hat``.

    Prints a banner, the column names, and the computed VIF values.
    '''
    # The VIF estimator performs the secure computation on the SPU.
    calculator = VIF(spu)
    results = calculator.vif(vdf_hat)

    # Banner, then the columns the values correspond to, then the values.
    print("=" * 40 + "多重共线性检验" + "=" * 40)
    print(vdf_hat.columns)
    print(results)
22 |
if __name__ == '__main__':
    # Run the VIF check on the processed frame without one-hot columns.
    VIF_calculation(data_dict['vdf_hat'])
--------------------------------------------------------------------------------
/corr_coefficient_matrix.py:
--------------------------------------------------------------------------------
1 | from data_process import data_dict,spu
2 | from secretflow.stats.ss_pearsonr_v import PearsonR
3 | import numpy as np
4 |
def correlation_coefficient_matrix(vdf_hat):
    '''
    Compute and print the Pearson correlation coefficient matrix of
    ``vdf_hat``.

    Unordered categorical columns are expected to have been excluded
    from ``vdf_hat`` before this function is called.
    '''
    # PearsonR runs the secure correlation computation on the SPU.
    estimator = PearsonR(spu)
    matrix = estimator.pearsonr(vdf_hat)

    print("==================== 相关系数矩阵 ====================\n")

    # Format floats to three decimal places when the matrix prints.
    np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
    print(matrix)
24 |
if __name__ == '__main__':
    # Compute the correlation matrix on the non-one-hot frame.
    correlation_coefficient_matrix(data_dict['vdf_hat'])
--------------------------------------------------------------------------------
/full_table_statistics.py:
--------------------------------------------------------------------------------
1 | from secretflow.stats.table_statistics import table_statistics
2 | from data_process import data_dict
3 | import os
4 | import tempfile
5 | import subprocess
6 |
def full_table_statistics(vdf):
    '''
    Compute full-table statistics for ``vdf``, save them as an Excel
    file in the system temp directory, and try to open the file in a
    spreadsheet viewer.
    '''
    # Generate the statistics table.
    data_stats = table_statistics(vdf)

    # Build the output path inside the system temp directory.
    temp_dir = tempfile.gettempdir()
    output_file = os.path.join(temp_dir, 'full_table_stats.xlsx')

    # Persist the statistics as an Excel workbook.
    data_stats.to_excel(output_file, index=False)

    print(f"统计结果已保存至 {output_file}")

    # Try LibreOffice Calc first, then fall back to the desktop
    # default handler. FileNotFoundError must be caught as well:
    # when the `libreoffice` binary is not installed, subprocess.run
    # raises FileNotFoundError (not CalledProcessError), and the
    # original code crashed instead of reaching the fallback.
    try:
        subprocess.run(['libreoffice', '--calc', output_file], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        print(f"An error occurred while opening the Excel file: {e}")
        # Linux fallback: let the desktop choose the application.
        # (macOS equivalent would be: os.system(f'open "{output_file}"'))
        os.system(f'xdg-open "{output_file}"')
        return
33 |
if __name__ == '__main__':
    # Report statistics over the PSI-aligned vertical frame.
    vdf = data_dict['vdf']
    full_table_statistics(vdf)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Neural Network Based on SecretFlow
3 |
4 | 本项目使用SecretFlow框架实现了多方安全计算,旨在通过联合数据训练,预测城市居民的年收入是否超过50k。该项目由城市内的三个政府机构合作进行,运用了多层感知器(MLP)模型来进行神经网络训练和预测。为了确保数据的准确性和可靠性,我们进行了全面的统计分析,包括多重共线性检验(VIF)和相关系数矩阵分析,并通过直方图、饼图等可视化方式呈现数据分析结果。此外,我们还开发了一个图形用户界面(GUI),使项目的操作更加简便直观。该项目的实施不仅体现了隐私保护技术在现实中的应用潜力,同时也为政府部门间的数据共享与合作提供了安全、高效的解决方案。
5 |
6 |
7 |
8 | [![Contributors][contributors-shield]][contributors-url]
9 | [![Forks][forks-shield]][forks-url]
10 | [![Stargazers][stars-shield]][stars-url]
11 | [![Issues][issues-shield]][issues-url]
12 | [![Apache-2.0 License][license-shield]][license-url]
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
Neural Network Based on SecretFlow
23 |
24 | 一个使用SecretFlow、MLP模型来进行神经网络训练和预测的项目!
25 |
26 | 探索本项目的文档 »
27 |
28 |
29 | 查看Demo
30 | ·
31 | 报告Bug
32 | ·
33 | 提出新特性
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ## 目录
42 |
43 | - [上手指南](#上手指南)
44 | - [开发前的配置要求](#开发前的配置要求)
45 | - [安装步骤](#安装步骤)
46 | - [文件目录说明](#文件目录说明)
47 | - [开发的架构](#开发的架构)
48 | - [部署](#部署)
49 | - [使用到的框架](#使用到的框架)
50 | - [贡献者](#贡献者)
51 | - [如何参与开源项目](#如何参与开源项目)
52 | - [版本控制](#版本控制)
53 | - [作者](#作者)
54 | - [鸣谢](#鸣谢)
55 |
56 | ### 上手指南
57 |
58 | 请将所有链接中的“YnRen22852/secretflowgryffindor”改为“your_github_name/your_repository”
59 |
60 |
61 |
62 | ###### 开发前的配置要求
63 |
64 | 1. xxxxx x.x.x
65 | 2. xxxxx x.x.x
66 |
67 | ###### **安装步骤**
68 |
69 | 1. Get a free API Key at [https://example.com](https://example.com)
70 | 2. Clone the repo
71 |
72 | ```sh
73 | git clone https://github.com/YnRen22852/secretflowgryffindor.git
74 | ```
75 |
76 | ### 文件目录说明
77 | eg:
78 |
79 | ```
80 | filetree
81 | ├── ARCHITECTURE.md
82 | ├── LICENSE.txt
83 | ├── README.md
84 | ├── /account/
85 | ├── /bbs/
86 | ├── /docs/
87 | │ ├── /rules/
88 | │ │ ├── backend.txt
89 | │ │ └── frontend.txt
90 | ├── manage.py
91 | ├── /oa/
92 | ├── /static/
93 | ├── /templates/
94 | ├── useless.md
95 | └── /util/
96 |
97 | ```
98 |
99 |
100 |
101 |
102 |
103 | ### 开发的架构
104 |
105 | 请阅读 [ARCHITECTURE.md](https://github.com/YnRen22852/secretflowgryffindor/blob/master/ARCHITECTURE.md) 了解该项目的架构。
106 |
107 | ### 部署
108 |
109 | 暂无
110 |
111 | ### 使用到的框架
112 |
113 | - [xxxxxxx](https://getbootstrap.com)
114 | - [xxxxxxx](https://jquery.com)
115 | - [xxxxxxx](https://laravel.com)
116 |
117 | ### 贡献者
118 |
119 | 请阅读[贡献者](https://github.com/YnRen22852/secretflowgryffindor/graphs/contributors) 查阅为该项目做出贡献的开发者。
120 |
121 | #### 如何参与开源项目
122 |
123 | 贡献使开源社区成为一个学习、激励和创造的绝佳场所。你所作的任何贡献都是**非常感谢**的。
124 |
125 |
126 | 1. Fork the Project
127 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
128 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
129 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
130 | 5. Open a Pull Request
131 |
132 |
133 |
134 | ### 版本控制
135 |
136 | 该项目使用 Git 进行版本管理。您可以在 repository 中查看当前可用的版本。
137 |
138 | ### 作者
139 |
140 | xxx@xxxx
141 |
142 | 知乎:xxxx qq:xxxxxx
143 |
144 | *您也可以在贡献者名单中参看所有参与该项目的开发者。*
145 |
146 | ### 版权说明
147 |
148 | 该项目签署了Apache License 2.0授权许可,详情请参阅 [LICENSE](https://github.com/YnRen22852/secretflowgryffindor/blob/master/LICENSE)
149 |
150 | ### 鸣谢
151 |
152 |
153 | - [GitHub Emoji Cheat Sheet](https://www.webpagefx.com/tools/emoji-cheat-sheet)
154 | - [Img Shields](https://shields.io)
155 | - [Choose an Open Source License](https://choosealicense.com)
156 | - [GitHub Pages](https://pages.github.com)
157 | - [Animate.css](https://daneden.github.io/animate.css)
158 | - [xxxxxxxxxxxxxx](https://connoratherton.com/loaders)
159 |
160 |
161 | [your-project-path]:https://github.com/YnRen22852/secretflowgryffindor
162 | [contributors-shield]: https://img.shields.io/github/contributors/YnRen22852/secretflowgryffindor.svg?style=flat-square
163 | [contributors-url]: https://github.com/YnRen22852/secretflowgryffindor/graphs/contributors
164 | [forks-shield]: https://img.shields.io/github/forks/YnRen22852/secretflowgryffindor.svg?style=flat-square
165 | [forks-url]: https://github.com/YnRen22852/secretflowgryffindor/network/members
166 | [stars-shield]: https://img.shields.io/github/stars/YnRen22852/secretflowgryffindor.svg?style=flat-square
167 | [stars-url]: https://github.com/YnRen22852/secretflowgryffindor/stargazers
168 | [issues-shield]: https://img.shields.io/github/issues/YnRen22852/secretflowgryffindor.svg?style=flat-square
169 | [issues-url]: https://github.com/YnRen22852/secretflowgryffindor/issues
170 | [license-shield]: https://img.shields.io/github/license/YnRen22852/secretflowgryffindor.svg?style=flat-square
171 | [license-url]: https://github.com/YnRen22852/secretflowgryffindor/blob/master/LICENSE
172 |
173 |
174 |
175 |
176 |
--------------------------------------------------------------------------------
/data_visual.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import pandas as pd
5 | from matplotlib.backends.backend_pdf import PdfPages
6 | import tempfile
7 | import subprocess
8 | from data_process import alice_path, bob_path, carol_path, full_file_path
9 | from mlp import install_package
10 |
11 | # 安装seaborn包,用于数据可视化
12 | # install_package('seaborn')
13 |
def histogram(column_data, pdf_pages):
    '''
    Render a histogram (with a KDE overlay) of one column and append
    it as a page to the open PDF document.
    '''
    plt.figure(figsize=(10, 6))

    # 30 bins plus a kernel-density-estimate curve.
    sns.histplot(column_data, kde=True, bins=30)

    plt.title(f'Histogram of {column_data.name}')
    plt.xlabel(column_data.name)
    plt.ylabel('Frequency')

    # Rotate tick labels so long category names remain readable.
    plt.xticks(rotation=90)

    # Flush this figure into the PDF, then free the figure memory.
    pdf_pages.savefig()
    plt.close()
40 |
def bar_chart(column_data, pdf_pages):
    '''
    Render a horizontal bar chart of category counts for one column
    and append it as a page to the open PDF document.

    Each bar is annotated with its count. Counts are printed as
    integers: the original code formatted the float patch width
    directly, which rendered labels like "123.0".
    '''
    plt.figure(figsize=(10, 6))

    # Horizontal count plot: categories on the y axis, counts on x.
    ax = sns.countplot(y=column_data, palette='viridis')

    plt.title(f'Bar Chart of {column_data.name}')
    plt.xlabel('Count')
    plt.ylabel(column_data.name)

    # Annotate each bar with its count, just past the end of the bar.
    for p in ax.patches:
        # For horizontal bars the width is the count value.
        width = p.get_width()
        height = p.get_height()

        # End of the bar (x) and its vertical centre (y).
        x = p.get_x() + width
        y = p.get_y() + height / 2

        # Nudge the label right so it does not overlap the bar;
        # cast to int so counts are not displayed as floats.
        ax.annotate(f'{int(width)}',
                    (x + 0.1, y),
                    ha='left', va='center', fontsize=10, color='black')

    # Flush this figure into the PDF, then free the figure memory.
    pdf_pages.savefig()
    plt.close()
87 |
def pie_chart(column_data, colors, pdf_pages):
    '''
    Render a pie chart of one column's value distribution and append
    it as a page to the open PDF document.
    '''
    plt.figure(figsize=(8, 8))

    # Category frequencies drive the slice sizes.
    counts = column_data.value_counts()

    counts.plot.pie(
        autopct='%1.1f%%',                   # percentage label per slice
        startangle=90,                       # first slice starts at 12 o'clock
        colors=colors,                       # caller-supplied palette
        wedgeprops=dict(edgecolor='grey')    # grey borders between slices
    )

    plt.title(f'Pie Chart of {column_data.name}', fontsize=20)

    # Suppress the default series-name label on the y axis.
    plt.ylabel('')

    # Flush this figure into the PDF, then free the figure memory.
    pdf_pages.savefig()
    plt.close()
116 |
def data_visualize(user_file_path):
    '''
    Read the CSV at ``user_file_path`` and render per-column charts
    (histograms, bar charts, pie charts) into a temporary PDF, then
    try to open the PDF in a viewer.
    '''
    # Chart type assignment for each known column.
    cols_histogram = ['age', 'fnlwgt', 'education', 'hours-per-week', 'capital-gain', 'capital-loss', 'native-country']
    cols_bar_chart = ['race', 'workclass', 'occupation', 'education-num', 'marital-status', 'relationship']
    cols_pie_chart = ['sex', 'income']

    df = pd.read_csv(user_file_path)

    # delete=False keeps the PDF on disk after the handle closes so
    # an external viewer can open it. (The original code also had a
    # stray no-op print() here, removed.)
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
        temp_file_path = temp_pdf.name

    with PdfPages(temp_file_path) as pdf_pages:
        for col in df.columns:
            if col in cols_histogram:
                histogram(df[col], pdf_pages)

            if col in cols_bar_chart:
                bar_chart(df[col], pdf_pages)

            if col in cols_pie_chart:
                # One palette colour per distinct category.
                colors = sns.color_palette('Set2', n_colors=len(df[col].value_counts()))
                pie_chart(df[col], colors, pdf_pages)

    # Open the generated PDF. FileNotFoundError is handled as well:
    # when the `evince` binary is missing, subprocess.run raises it
    # (not CalledProcessError) and the original code crashed here.
    try:
        subprocess.run(['evince', temp_file_path], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        print(f"An error occurred while opening the PDF file: {e}")
164 |
if __name__ == '__main__':
    # Map each CLI choice to the matching party's data file.
    paths = {
        '1': alice_path,
        '2': bob_path,
        '3': carol_path,
        '4': full_file_path,
    }
    choice = sys.argv[1]
    if choice in paths:
        data_visualize(paths[choice])
    else:
        print("输入选项有误,请重新输入!")
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
1 | import secretflow as sf
2 | from secretflow.data.vertical import read_csv as v_read_csv
3 | from secretflow.preprocessing import LabelEncoder
4 | from secretflow.preprocessing import OneHotEncoder
5 | from secretflow.preprocessing import StandardScaler
6 | from secretflow.data.split import train_test_split
7 | from secretflow.stats.table_statistics import table_statistics
8 | from secretflow.stats.ss_vif_v import VIF
9 | from secretflow.stats.ss_pearsonr_v import PearsonR
10 | import jax.numpy as jnp
11 | import pandas as pd
12 | import numpy as np
13 | import tempfile
14 | import os
15 | import matplotlib.pyplot as plt
16 | import seaborn as sns
17 |
def download_dataset(full_file_path):
    '''
    Download the UCI Adult dataset and save it as a CSV file at
    ``full_file_path``.
    '''
    # URL of the raw (header-less) UCI Adult data file.
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"

    # Column names for the raw data file.
    column_names = [
        'age', 'workclass', 'fnlwgt', 'education', 'education-num',
        'marital-status', 'occupation', 'relationship', 'race', 'sex',
        'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income'
    ]

    # The source file is ", "-separated, so a regex separator with
    # the python engine is used. The separator is a raw string: the
    # original ',\s' contained an invalid escape sequence that emits
    # a SyntaxWarning on modern Python versions.
    adult_df = pd.read_csv(url, names=column_names, sep=r',\s', na_values=["?"], engine='python')

    # Persist locally so subsequent runs can skip the download.
    adult_df.to_csv(full_file_path, index=False)
36 |
def load_dataset(full_file_path):
    '''
    Load the dataset from ``full_file_path`` (downloading it first if
    the file is missing) and append a 1-based ``uid`` column that is
    later used as the PSI join key.
    '''
    if not os.path.exists(full_file_path):
        # First run: fetch the dataset from the UCI repository.
        download_dataset(full_file_path)

    try:
        frame = pd.read_csv(full_file_path)
    except Exception as err:
        # Surface any read failure as a ValueError with context.
        raise ValueError(f"文件读取错误:{err}")

    # Row identifiers start at 1 rather than 0.
    frame['uid'] = frame.index + 1
    return frame
57 |
def split_dataset(data_df, alice_path, bob_path, carol_path):
    '''
    Vertically partition ``data_df`` across the three parties and
    write each party's share to its CSV path.

    Every share keeps the trailing ``uid`` column, and each party
    keeps a random 90% sample of the rows (the misaligned rows are
    what the later PSI step reconciles).
    '''
    # Number of feature columns: the trailing 'uid' column excluded.
    feature_count = data_df.shape[1] - 1

    # Cut the feature columns into three roughly equal ranges.
    first_cut = feature_count // 3
    second_cut = 2 * feature_count // 3

    # Alice and Bob take their range plus 'uid'; Carol's range runs
    # to the end and therefore already contains 'uid'.
    shares = [
        data_df.iloc[:, np.r_[0:first_cut, -1]].sample(frac=0.9),
        data_df.iloc[:, np.r_[first_cut:second_cut, -1]].sample(frac=0.9),
        data_df.iloc[:, second_cut:].sample(frac=0.9),
    ]

    # Persist each share with a fresh 0-based index.
    for share, path in zip(shares, (alice_path, bob_path, carol_path)):
        share.reset_index(drop=True).to_csv(path, index=False)

    return alice_path, bob_path, carol_path
86 |
def secret_psi(alice_path, bob_path, carol_path):
    '''
    Align the three parties' rows via private set intersection (PSI)
    on the 'uid' key and return the joined VDataFrame.
    '''
    # Map each PYU device to its share of the data.
    party_files = {alice: alice_path, bob: bob_path, carol: carol_path}

    # Three-party ECDH PSI on 'uid'; the key column is dropped from
    # the result. (NOTE: 'psi_protocl' is the parameter spelling this
    # secretflow API expects — do not "fix" it.)
    return v_read_csv(
        party_files,
        spu=spu,
        keys='uid',
        drop_keys='uid',
        psi_protocl="ECDH_PSI_3PC",
    )
101 |
def Missing_Value_Filling(vdf):
    '''
    Fill missing values in the categorical columns known to contain
    NaNs, using each column's mode (most frequent value).
    '''
    # Only these columns contain missing entries in this dataset.
    for column in ('workclass', 'occupation', 'native-country'):
        # The first mode serves as the fill value.
        fill_value = vdf[column].mode()[0]
        vdf[column].fillna(fill_value, inplace=True)

    return vdf
120 |
def label_encode_function(vdf):
    '''
    Label-encode the unordered binary columns ('sex', 'income') so
    each becomes a 0/1 integer column.
    '''
    encoder = LabelEncoder()

    for column in ('sex', 'income'):
        # Fit on the column, then overwrite it with its encoding.
        encoder.fit(vdf[column])
        vdf[column] = encoder.transform(vdf[column])

    return vdf
141 |
def Ordinal_Cate_Features(vdf):
    '''
    Map the ordered categorical 'education' column onto the integers
    0..15, from "Preschool" (lowest) up to "Doctorate" (highest).
    '''
    # Education levels in ascending order; list index == encoded value.
    levels = [
        "Preschool", "1st-4th", "5th-6th", "7th-8th", "9th", "10th",
        "11th", "12th", "HS-grad", "Some-college", "Assoc-voc",
        "Assoc-acdm", "Bachelors", "Masters", "Prof-school", "Doctorate",
    ]
    education_rank = {name: rank for rank, name in enumerate(levels)}

    vdf['education'] = vdf['education'].replace(education_rank)
    return vdf
167 |
def One_Hot_Function(vdf):
    '''
    One-hot encode the unordered categorical columns.

    Returns:
        (vdf, vdf_hat): ``vdf`` with the categorical columns replaced
        by their one-hot features, and ``vdf_hat`` which simply drops
        those columns (used by the statistics scripts).
    '''
    onehot_cols = ['workclass', 'marital-status', 'occupation', 'relationship', 'race', 'native-country']

    # Fit the encoder on the categorical columns.
    encoder = OneHotEncoder()
    encoder.fit(vdf[onehot_cols])

    # vdf_hat: the frame without the categorical columns entirely.
    vdf_hat = vdf.drop(columns=onehot_cols)

    # Expand the categorical columns into one-hot feature columns.
    encoded = encoder.transform(vdf[onehot_cols])

    # Replace the original columns with the encoded features.
    vdf = vdf.drop(columns=onehot_cols)
    vdf[encoded.columns] = encoded

    return vdf, vdf_hat
198 |
def standard_scaler_func(vdf):
    '''
    Standardize (zero mean, unit variance) every feature column,
    leaving the 'income' label column untouched.

    Returns a new frame; the caller's frame is not modified.
    '''
    # Work on a copy so the original data stays intact.
    vdf = vdf.copy()

    # All columns except the label are scaled. (The original code
    # also extracted an unused `y = vdf['income']` local; removed.)
    X = vdf.drop(columns=['income'])

    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # Write the scaled values back over the original feature columns.
    vdf[X.columns] = X
    return vdf
223 |
def split_train_test(vdf, train_size, random_state):
    '''
    Split ``vdf`` into train/test feature frames and label columns.

    Returns:
        (train_X, test_X, train_y, test_y)
    '''
    # Row-wise split first, seeded for reproducibility.
    train_vdf, test_vdf = train_test_split(vdf, train_size=train_size, random_state=random_state)

    # Separate features from the 'income' label in each part.
    train_X = train_vdf.drop(columns=['income'])
    train_y = train_vdf['income']

    test_X = test_vdf.drop(columns=['income'])
    test_y = test_vdf['income']

    return train_X, test_X, train_y, test_y
245 |
def vdataframe_to_spu(vdf):
    '''
    Convert a VDataFrame into a single SPUObject by moving every
    partition onto the SPU and concatenating them column-wise.
    '''
    # Transfer each party's partition to the SPU device.
    parts = [vdf.partitions[device].data.to(spu) for device in vdf.partitions]

    # Fold the partitions together, concatenating along axis 1
    # (columns) inside the SPU computation.
    merged = parts[0]
    for part in parts[1:]:
        merged = spu(lambda x, y: jnp.concatenate([x, y], axis=1))(merged, part)

    return merged
270 |
def convert_to_spu(train_X, test_X, train_y, test_y):
    '''
    Move the train/test features and labels onto the SPU device.

    Features are merged from all parties' partitions; the label
    column is held by carol alone.
    '''
    # Features: merge every party's partition inside the SPU.
    X_train_spu = vdataframe_to_spu(train_X)
    X_test_spu = vdataframe_to_spu(test_X)

    # Labels: carol holds them, so a direct device transfer suffices.
    y_train_spu = train_y.partitions[carol].data.to(spu)
    y_test_spu = test_y.partitions[carol].data.to(spu)

    return X_train_spu, X_test_spu, y_train_spu, y_test_spu
288 |
289 |
def data_preprocessing(full_file_path):
    '''
    Load the full dataset and split it vertically into three party
    files stored at fresh temporary paths.

    Returns:
        (alice_path, bob_path, carol_path)
    '''
    # Allocate one temporary file per party.
    _, alice_path = tempfile.mkstemp()
    _, bob_path = tempfile.mkstemp()
    _, carol_path = tempfile.mkstemp()

    # Load (downloading on first run) and split the dataset.
    data = load_dataset(full_file_path)
    split_dataset(data, alice_path, bob_path, carol_path)

    return alice_path, bob_path, carol_path
303 |
def data_process(alice_path, bob_path, carol_path):
    '''
    Full processing pipeline: PSI alignment, cleaning, encoding,
    scaling, train/test splitting, and conversion to SPU objects.

    Returns a dict with the SPU train/test features and labels, the
    fully processed frame ('vdf'), and the frame without the one-hot
    encoded columns ('vdf_hat').
    '''
    # Align the three parties' rows via private set intersection.
    vdf = secret_psi(alice_path, bob_path, carol_path)

    # Clean and encode a working copy of the aligned frame.
    processed = vdf.copy()
    processed = Missing_Value_Filling(processed)      # mode-fill NaNs
    processed = label_encode_function(processed)      # binary cols -> 0/1
    processed = Ordinal_Cate_Features(processed)      # ordered cats -> 0..n-1
    processed, vdf_hat = One_Hot_Function(processed)  # unordered cats -> one-hot
    processed = standard_scaler_func(processed)       # standardize features

    # Split rows into train/test sets (80% train, fixed seed).
    train_X, test_X, train_y, test_y = split_train_test(processed, 0.8, 1234)

    # Move everything onto the SPU device.
    X_train_spu, X_test_spu, y_train_spu, y_test_spu = convert_to_spu(train_X, test_X, train_y, test_y)

    return {
        'X_train_spu': X_train_spu,
        'X_test_spu': X_test_spu,
        'y_train_spu': y_train_spu,
        'y_test_spu': y_test_spu,
        'vdf': processed,
        'vdf_hat': vdf_hat,
    }
335 |
# Download/storage path for the dataset file.
full_file_path = './adult_data.csv'

# Configure SecretFlow devices: a local 3-party ABY3 SPU plus one
# PYU per party.
sf.shutdown()# shut down any previously running SecretFlow runtime first
sf.init(['alice','bob','carol'],address='local')
aby3_config = sf.utils.testing.cluster_def(parties=['alice', 'bob', 'carol'])
spu = sf.SPU(aby3_config)
alice = sf.PYU('alice')
bob = sf.PYU('bob')
carol = sf.PYU('carol')

# Load the dataset and split it into the three parties' files.
alice_path,bob_path,carol_path = data_preprocessing(full_file_path)

'''
Data processing.
'vdf' here is the PSI (private set intersection) result;
'vdf_hat' is the processed frame without the one-hot encoded columns.
The resulting dict also contains X_train_spu, X_test_spu, y_train_spu
and y_test_spu.
'''
data_dict = data_process(alice_path,bob_path,carol_path)

# Inspect the processed data (reveals secret values — debug only):
# print(sf.reveal(data_dict['X_train_spu']))
# print(sf.reveal(data_dict['X_test_spu']))
# print(sf.reveal(data_dict['y_train_spu']))
# print(sf.reveal(data_dict['y_test_spu']))
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import secretflow as sf
3 | from secretflow.data.vertical import read_csv as v_read_csv
4 | from secretflow.preprocessing import LabelEncoder
5 | from secretflow.preprocessing import OneHotEncoder
6 | from secretflow.preprocessing import StandardScaler
7 | from secretflow.data.split import train_test_split
8 | from secretflow.stats.table_statistics import table_statistics
9 | from secretflow.stats.ss_vif_v import VIF
10 | from secretflow.stats.ss_pearsonr_v import PearsonR
11 | import jax.numpy as jnp
12 | import jax
13 | import pandas as pd
14 | import numpy as np
15 | import tempfile
16 | import os
17 | import matplotlib.pyplot as plt
18 | import sys
19 | import subprocess
20 | from data_process import spu, alice, bob, carol, data_dict
21 |
def install_package(package_name):
    """Install the given Python package quietly via pip.

    Uses the current interpreter (``sys.executable -m pip``) so the package
    lands in the active environment. Prints a one-line success or failure
    message; never raises.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install', package_name, '-q']
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError as e:
        # Installation failed: report it, but keep the program running.
        print(f"Error occurred while installing {package_name}: {e}")
        return
    # Installation succeeded.
    print(f"Successfully installed {package_name}")
36 |
# Install flax 0.6.0 at import time; it provides the neural-network layers used below.
install_package('flax==0.6.0')

# Install seaborn for data visualisation.
install_package('seaborn')

import seaborn as sns
from typing import Sequence
import flax.linen as nn
from sklearn.metrics import roc_auc_score
47 |
class MLP(nn.Module): # multi-layer perceptron used for the CPU baseline
    """MLP with ReLU hidden layers, Dropout and BatchNorm after each hidden layer.

    Attributes:
        features: neurons per layer; the last entry is the output width.
        dropout_rate: dropout probability for the hidden layers.
    """
    features: Sequence[int] # number of neurons in each layer
    dropout_rate: float # dropout probability
    @nn.compact
    def __call__(self, x, train, rngs=None): # forward pass
        for feat in self.features[:-1]: # every layer except the last
            x = nn.relu(nn.Dense(feat)(x)) # dense layer + ReLU activation
            x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) # dropout, active only when train=True
            x = nn.BatchNorm(use_running_average=not train, momentum=0.5, epsilon=1e-5)(x) # batch normalisation
        x = nn.Dense(self.features[-1])(x) # final layer: dense only, no activation
        return x # model output
59 |
60 |
def predict(params, x, train=False, rng_key=None):
    """Forward pass of the MLP (CPU version); used for both training and testing.

    Args:
        params: Flax parameter pytree as produced by ``MLP.init``.
        x: input batch of shape (batch, n_features).
        train: when True, Dropout is active and BatchNorm uses batch statistics.
        rng_key: optional PRNG key for Dropout; defaults to ``PRNGKey(0)``.

    Returns:
        Model output of shape (batch, 1).
    """
    from typing import Sequence
    import flax.linen as nn

    # Local redefinition keeps this function self-contained (mirrors MLP above).
    class MLP(nn.Module):
        features: Sequence[int]  # number of neurons in each layer
        dropout_rate: float      # dropout probability

        @nn.compact
        def __call__(self, x, train, rngs=None):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
                x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
                x = nn.BatchNorm(use_running_average=not train, momentum=0.5, epsilon=1e-5)(x)
            x = nn.Dense(self.features[-1])(x)  # final layer: dense only
            return x

    # Derive the input width from the data itself instead of the module-level
    # global ``dim``; x.shape[-1] equals the feature count the model was
    # initialised with, so behaviour is unchanged while the hidden global
    # dependency is removed.
    FEATURES = [x.shape[-1], 15, 8, 1]
    flax_nn = MLP(features=FEATURES, dropout_rate=0.1)
    if rng_key is None:
        rng_key = jax.random.PRNGKey(0)
    # 'batch_stats' must be marked mutable because BatchNorm updates its
    # running statistics during the forward pass; the updates are discarded.
    y, _ = flax_nn.apply(
        params,
        x,
        train,
        mutable=['batch_stats'],
        rngs={'dropout': rng_key}
    )
    return y
93 |
def loss_func(params, x, y, rng_key):
    """Half mean-squared-error loss between labels and model predictions.

    Runs the network in training mode (Dropout active) with the supplied
    PRNG key, then returns mean((y - pred)^2 / 2).
    """
    pred = predict(params, x, train=True, rng_key=rng_key)
    residual = y - pred
    # Element-wise squared error halved, averaged over the batch.
    return jnp.mean(jnp.multiply(residual, residual) / 2.0)
111 |
def train_auto_grad(X, y, params, batch_size=10, epochs=10, learning_rate=0.01):
    """Train the model with plain SGD using JAX automatic differentiation.

    Args:
        X: training features, shape (n_samples, n_features).
        y: training labels.
        params: initial Flax parameter pytree.
        batch_size: mini-batch size.
        epochs: number of passes over the data.
        learning_rate: SGD step size.

    Returns:
        The trained parameter pytree.
    """
    # Guard against n_samples < batch_size: jnp.array_split requires at least
    # one section, so fall back to a single batch instead of raising.
    n_batches = max(1, len(X) // batch_size)
    xs = jnp.array_split(X, n_batches, axis=0)
    ys = jnp.array_split(y, n_batches, axis=0)

    # Fixed PRNG key for Dropout inside loss_func (deterministic runs).
    rng_key = jax.random.PRNGKey(0)

    for epoch in range(epochs):
        for batch_x, batch_y in zip(xs, ys):
            # Loss value and gradients for this mini-batch.
            loss, grads = jax.value_and_grad(loss_func)(params, batch_x, batch_y, rng_key)

            # Plain gradient-descent update on every leaf of the pytree.
            params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

    return params
138 |
class MLP_spu(nn.Module): # SPU variant of the MLP (no Dropout/BatchNorm)
    """Plain MLP with ReLU hidden layers only, suitable for SPU execution.

    Attributes:
        features: neurons per layer; the last entry is the output width.
    """
    features: Sequence[int] # number of neurons in each layer
    @nn.compact
    def __call__(self, x): # forward pass
        for feat in self.features[:-1]: # every layer except the last
            x = nn.relu(nn.Dense(feat)(x)) # dense layer + ReLU activation
        x = nn.Dense(self.features[-1])(x) # final layer: dense only
        return x # model output
147 |
def predict_spu(params, x):
    """Forward pass of the plain MLP (SPU version, no Dropout/BatchNorm).

    Args:
        params: Flax parameter pytree as produced by ``MLP_spu.init``.
        x: input batch of shape (batch, n_features).

    Returns:
        Model output of shape (batch, 1).
    """
    # Local imports/definitions keep the function self-contained, which
    # matters because this code is serialised and shipped to SPU devices.
    from typing import Sequence
    import flax.linen as nn

    class MLP_spu(nn.Module):
        # Number of neurons in each layer.
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                # Dense layer + ReLU for every hidden layer.
                x = nn.relu(nn.Dense(feat)(x))
            # Final layer: dense only, no activation.
            x = nn.Dense(self.features[-1])(x)
            return x

    # Use the feature count of the input instead of the module-level global
    # ``dim``: identical value by construction, but no hidden dependency on
    # the caller's module state.
    FEATURES = [x.shape[-1], 15, 8, 1]

    flax_nn = MLP_spu(features=FEATURES)

    return flax_nn.apply(params, x)
182 |
def loss_func_spu(params, x, y):
    """Half mean-squared-error loss for the SPU model.

    Predicts with ``predict_spu`` and returns mean((y - pred)^2 / 2).
    """
    pred = predict_spu(params, x)
    residual = y - pred
    # Element-wise squared error halved, averaged over the batch.
    return jnp.mean(jnp.multiply(residual, residual) / 2.0)
201 |
def train_auto_grad_spu(X, y, params, batch_size=10, epochs=10, learning_rate=0.01):
    """Train the plain MLP with SGD; designed to run inside the SPU.

    Args:
        X: training features, shape (n_samples, n_features).
        y: training labels.
        params: initial Flax parameter pytree.
        batch_size: mini-batch size (passed as a static argument on the SPU).
        epochs: number of passes over the data.
        learning_rate: SGD step size.

    Returns:
        The trained parameter pytree.
    """
    # Guard against n_samples < batch_size: jnp.array_split requires at least
    # one section, so fall back to a single batch instead of raising.
    n_batches = max(1, len(X) // batch_size)
    xs = jnp.array_split(X, n_batches, axis=0)
    ys = jnp.array_split(y, n_batches, axis=0)

    for epoch in range(epochs):
        for batch_x, batch_y in zip(xs, ys):
            # Loss value and gradients for this mini-batch.
            loss, grads = jax.value_and_grad(loss_func_spu)(params, batch_x, batch_y)

            # Plain gradient-descent update on every leaf of the pytree.
            params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

    return params
227 |
def cpu_version_mlp(X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate):
    """Train and evaluate the MLP on plaintext data (CPU baseline).

    NOTE(review): evaluation reads the module-level globals
    ``X_test_plaintext`` / ``y_test_plaintext`` set in the ``__main__``
    block — confirm before reusing this function elsewhere.
    Prints the test AUC in red.
    """
    # Train with autodiff-based SGD; returns the updated parameters.
    params = train_auto_grad(
        X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate
    )

    # Predict on the (global) test split with Dropout disabled.
    # (The previous version created an unused rng_key here; removed.)
    y_pred = predict(params, X_test_plaintext, train=False)

    # Clear the terminal so the AUC line is easy to spot, then print it.
    os.system('clear')
    print(f"\033[31m(Flax NN CPU) auc: {roc_auc_score(y_test_plaintext, y_pred)}\033[0m")
246 |
def spu_version_mlp(X_train_spu, y_train_spu, params_spu, batch_size, epochs, learning_rate):
    '''
    Train and evaluate the MLP model on the SPU (secure processing unit).

    NOTE(review): relies on module-level globals set in the __main__ block:
    ``spu`` (device), ``X_test_spu`` and ``y_test_plaintext``.
    Prints the test AUC in red.
    '''
    # Run train_auto_grad_spu inside the SPU; the hyper-parameters are
    # declared static so they are compile-time constants on the device.
    params_spu = spu(
        train_auto_grad_spu, static_argnames=['batch_size', 'epochs', 'learning_rate']
    )(
        X_train_spu, # training features (SPU-resident)
        y_train_spu, # training labels (SPU-resident)
        params_spu, # initial parameters (SPU-resident)
        batch_size=batch_size, # mini-batch size
        epochs=epochs, # number of epochs
        learning_rate=learning_rate # learning rate
    )

    # Run prediction on the SPU with the trained parameters.
    y_pred_spu = spu(predict_spu)(params_spu, X_test_spu)

    # Reveal the secret-shared predictions as plaintext for scoring.
    y_pred_ = sf.reveal(y_pred_spu)

    # Compute and print the AUC score in red.
    print(f"\033[31m(Flax NN SPU) auc: {roc_auc_score(y_test_plaintext, y_pred_)}\033[0m")
271 |
if __name__ == '__main__':
    import argparse

    # Fetch the vertically partitioned train/test splits prepared by data_process.
    X_train_spu = data_dict['X_train_spu']
    y_train_spu = data_dict['y_train_spu']
    X_test_spu = data_dict['X_test_spu']
    y_test_spu = data_dict['y_test_spu']

    # Reveal the SPU-held data as plaintext for the CPU baseline and scoring.
    X_train_plaintext = sf.reveal(X_train_spu)  # training features
    y_train_plaintext = sf.reveal(y_train_spu)  # training labels
    X_test_plaintext = sf.reveal(X_test_spu)    # test features
    y_test_plaintext = sf.reveal(y_test_spu)    # test labels

    # Number of input features.
    dim = X_train_plaintext.shape[1]

    # Layer widths shared by both model variants.
    FEATURES = [dim, 15, 8, 1]

    # CPU variant (with Dropout/BatchNorm) and SPU variant (plain).
    flax_nn = MLP(features=FEATURES, dropout_rate=0.1)
    flax_nn_spu = MLP_spu(features=FEATURES)

    feature_dim = dim

    # Parse hyper-parameters from the command line. The flags match what
    # neural_network_gui.py passes ("--epochs", "--learning_rate",
    # "--batch_size"). The previous ad-hoc parsing compared the *string
    # length* of sys.argv[1] to 3 and indexed individual characters, so it
    # never matched the GUI invocation and crashed with IndexError when no
    # arguments were supplied; argparse fixes both and keeps the documented
    # defaults (2 / 0.02 / 100).
    parser = argparse.ArgumentParser(description='Train the SecretFlow MLP demo.')
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=0.02)
    parser.add_argument('--batch_size', type=int, default=100)
    args = parser.parse_args()

    epochs = args.epochs                  # CPU epochs
    learning_rate = args.learning_rate    # CPU learning rate
    batch_size = args.batch_size          # CPU batch size
    epochs_spu = args.epochs              # SPU epochs
    learning_rate_spu = args.learning_rate  # SPU learning rate
    batch_size_spu = args.batch_size      # SPU batch size

    # Initialise parameters for both variants with dummy all-ones input.
    init_params = flax_nn.init(jax.random.PRNGKey(1), jnp.ones((batch_size, feature_dim)), train=False)
    init_params_spu = flax_nn_spu.init(jax.random.PRNGKey(1), jnp.ones((batch_size, feature_dim)))

    # Move the SPU model's initial parameters from Alice onto the SPU device.
    params = sf.to(alice, init_params_spu).to(spu)

    # CPU baseline run.
    cpu_version_mlp(X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate)

    # Privacy-preserving SPU run.
    spu_version_mlp(X_train_spu, y_train_spu, params, batch_size_spu, epochs_spu, learning_rate_spu)
--------------------------------------------------------------------------------
/neural_network_gui.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import subprocess
3 | from PyQt5.QtWidgets import (
4 | QApplication, QMainWindow, QWidget, QVBoxLayout, QPushButton, QLabel, QStackedWidget, QSizePolicy, QTextEdit, QLineEdit, QHBoxLayout
5 | )
6 | from PyQt5.QtCore import Qt
7 | from PyQt5.QtGui import QFont
8 |
def filter_output(output, keywords):
    """Drop every line of *output* that contains any of the given keywords.

    :param output: raw multi-line output string to filter.
    :param keywords: iterable of substrings; a line containing any one of
        them is removed.
    :return: the surviving lines joined with newlines.
    """
    kept = []
    for line in output.splitlines():
        if any(keyword in line for keyword in keywords):
            continue  # noisy line — skip it
        kept.append(line)
    return '\n'.join(kept)
20 |
class MainWindow(QMainWindow):
    """Main window of the demo GUI.

    A QStackedWidget holds every page: the main menu, the data-analysis
    sub-menu (function 1), and pages created on demand for script output,
    data visualisation and neural-network runs. Each feature works by
    launching the corresponding script with ``subprocess`` and displaying
    its filtered stdout/stderr.
    """

    def __init__(self):
        super().__init__()

        # Window title and initial geometry.
        self.setWindowTitle("Gryffindor - Neural Network Based on SecretFlow")
        self.setGeometry(100, 100, 1000, 800)  # enlarged window size

        # Stacked widget used for page switching.
        self.stacked_widget = QStackedWidget()
        self.setCentralWidget(self.stacked_widget)

        # Main (landing) page.
        self.main_page = QWidget()
        self.stacked_widget.addWidget(self.main_page)

        # Layout for the main page.
        self.main_layout = QVBoxLayout(self.main_page)

        # Competition topic label (about 30% of the page height).
        self.topic_label = QLabel(" Neural Network Based on SecretFlow", self)
        self.topic_label.setAlignment(Qt.AlignCenter)
        self.topic_label.setFont(QFont("黑体", 24, QFont.Bold))  # SimHei ("黑体") font, size 24, bold
        self.topic_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.topic_label)

        # Team-name label (about 20% of the page height).
        self.team_label = QLabel("Team: Gryffindor", self)
        self.team_label.setAlignment(Qt.AlignCenter)
        self.team_label.setFont(QFont("黑体", 20, QFont.Bold))  # SimHei font, size 20, bold
        self.team_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.team_label)

        # Feature buttons (about 50% of the page height).
        self.function1_button = QPushButton("Function 1: Data Analysis", self)
        self.function2_button = QPushButton("Function 2: Run Neural Network", self)
        self.function1_button.setFont(QFont("宋体", 14))  # SimSun ("宋体") font, size 14
        self.function2_button.setFont(QFont("宋体", 14))  # SimSun font, size 14
        self.function1_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding)
        self.function2_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding)
        self.main_layout.addWidget(self.function1_button)
        self.main_layout.addWidget(self.function2_button)

        # Wire up the feature buttons.
        self.function1_button.clicked.connect(self.show_function1)
        self.function2_button.clicked.connect(self.show_neural_network_page)

        # Build the sub-pages for function 1 up front.
        self.create_function1_pages()

    def create_function1_pages(self):
        """Build the data-analysis sub-menu page and wire its buttons."""
        # Page and layout for function 1.
        self.function1_page = QWidget()
        self.function1_layout = QVBoxLayout(self.function1_page)
        self.stacked_widget.addWidget(self.function1_page)

        # Sub-feature buttons.
        self.sub_function1_button = QPushButton("Sub Function 1: Full Table Statistics", self)
        self.sub_function2_button = QPushButton("Sub Function 2: VIF Multicollinearity Test", self)
        self.sub_function3_button = QPushButton("Sub Function 3: Correlation Coefficient Matrix", self)
        self.sub_function4_button = QPushButton("Sub Function 4: Data Visual", self)
        self.sub_function1_button.setFont(QFont("宋体", 12))  # SimSun font, size 12
        self.sub_function2_button.setFont(QFont("宋体", 12))  # SimSun font, size 12
        self.sub_function3_button.setFont(QFont("宋体", 12))  # SimSun font, size 12
        self.sub_function4_button.setFont(QFont("宋体", 12))  # SimSun font, size 12
        self.function1_layout.addWidget(self.sub_function1_button)
        self.function1_layout.addWidget(self.sub_function2_button)
        self.function1_layout.addWidget(self.sub_function3_button)
        self.function1_layout.addWidget(self.sub_function4_button)

        # Back button returning to the main page (kept small).
        self.back_to_main_button = QPushButton("Back to Main Menu", self)
        self.back_to_main_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        self.function1_layout.addWidget(self.back_to_main_button)
        self.back_to_main_button.clicked.connect(self.go_back_to_main)

        # Let the sub-feature buttons grow with the window.
        self.sub_function1_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.sub_function2_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.sub_function3_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.sub_function4_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        # Each sub-feature runs a script and shows its output on a new page.
        self.sub_function1_button.clicked.connect(lambda: self.show_output_page("network_demo/full_table_statistics.py", "Full Table Statistics Output"))
        self.sub_function2_button.clicked.connect(lambda: self.show_output_page("network_demo/multicollinearity_test.py", "VIF Multicollinearity Test Output"))
        self.sub_function3_button.clicked.connect(lambda: self.show_output_page("network_demo/corr_coefficient_matrix.py", "Correlation Coefficient Matrix Output"))
        self.sub_function4_button.clicked.connect(self.show_data_visual_page)

    def show_function1(self):
        """Switch to the data-analysis (function 1) sub-menu."""
        self.stacked_widget.setCurrentWidget(self.function1_page)

    def show_output_page(self, script_path, page_title, main_menu=False):
        """Run *script_path* synchronously and show its filtered output on a new page.

        :param script_path: script executed with ``python``.
        :param page_title: heading prepended to the output text.
        :param main_menu: when True the back button returns to the main menu,
            otherwise to the function-1 menu.
        """
        # Freshly created output page.
        output_page = QWidget()
        output_layout = QVBoxLayout(output_page)
        self.stacked_widget.addWidget(output_page)

        # Read-only output box.
        output_text = QTextEdit(output_page)
        output_text.setReadOnly(True)
        output_layout.addWidget(output_text)

        # Noise lines (SecretFlow runtime logging) to strip from the output.
        keywords_to_filter = ["pid", "SPURuntime", "info"]

        try:
            # NOTE(review): blocking call — the GUI freezes until the script
            # finishes; consider QProcess for long-running scripts.
            result = subprocess.run(["python", script_path], capture_output=True, text=True, check=True)
            filtered_output = filter_output(result.stdout, keywords_to_filter)
            output_text.append(f"{page_title}:\n\n{filtered_output}")
        except subprocess.CalledProcessError as e:
            filtered_error = filter_output(e.stderr, keywords_to_filter)
            output_text.append(f"An error occurred while running {script_path}:\n\n{filtered_error}")

        # Back button returning to the appropriate menu (kept small).
        back_button = QPushButton("Back to Main Menu" if main_menu else "Back to Function 1 Menu", self)
        back_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        output_layout.addWidget(back_button)
        back_button.clicked.connect(self.go_back_to_main if main_menu else self.go_back_to_function1)

        # Show the output page.
        self.stacked_widget.setCurrentWidget(output_page)

    def show_data_visual_page(self):
        """Build and show the interactive data-visualisation page."""
        # Page and layout.
        data_visual_page = QWidget()
        data_visual_layout = QVBoxLayout(data_visual_page)
        self.stacked_widget.addWidget(data_visual_page)

        # Prompt label and input box; the choice selects which party's data
        # (or the full file) data_visual.py should plot.
        input_label = QLabel("Enter a value (1-4):"+ "\n"+
                             """ | 1. alice 数据可视化\t\t|
 | 2. bob 数据可视化\t\t|
 | 3. carol 数据可视化\t\t|
 | 4. full_file_path 数据可视化\t| """, self)
        data_visual_layout.addWidget(input_label)

        input_edit = QLineEdit(self)
        data_visual_layout.addWidget(input_edit)

        # Read-only result box.
        result_text = QTextEdit(self)
        result_text.setReadOnly(True)
        data_visual_layout.addWidget(result_text)

        # Run button.
        run_button = QPushButton("Run Data Visual", self)
        data_visual_layout.addWidget(run_button)

        # Click handler: run data_visual.py with the chosen mode.
        def run_data_visual():
            choice = input_edit.text()

            # Pass the user's choice straight through as a CLI argument.
            dv_command = ["python", "network_demo/data_visual.py", choice]
            keywords_to_filter = ["pid", "SPURuntime", "info"]
            try:
                result = subprocess.run(dv_command, capture_output=True, text=True, check=True)
                filtered_output = filter_output(result.stdout, keywords_to_filter)
                result_text.append(filtered_output)
            except subprocess.CalledProcessError as e:
                result_text.append(f"An error occurred:\n{e.stderr}")
        run_button.clicked.connect(run_data_visual)

        # Back button (kept small).
        back_button = QPushButton("Back to Function 1 Menu", self)
        back_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        data_visual_layout.addWidget(back_button)
        back_button.clicked.connect(self.go_back_to_function1)

        # Show the page.
        self.stacked_widget.setCurrentWidget(data_visual_page)

    def show_neural_network_page(self):
        """Build and show the neural-network runner page."""
        # Page and layout.
        neural_network_page = QWidget()
        neural_network_layout = QVBoxLayout(neural_network_page)
        self.stacked_widget.addWidget(neural_network_page)

        # Hyper-parameter labels and input boxes.
        epochs_label = QLabel("Enter epochs:(默认参数需要输入2)", self)
        learning_rate_label = QLabel("Enter learning rate:(默认参数需要输入0.02)", self)
        batch_size_label = QLabel("Enter batch size:(默认参数需要输入100)", self)

        epochs_edit = QLineEdit(self)
        learning_rate_edit = QLineEdit(self)
        batch_size_edit = QLineEdit(self)

        neural_network_layout.addWidget(epochs_label)
        neural_network_layout.addWidget(epochs_edit)
        neural_network_layout.addWidget(learning_rate_label)
        neural_network_layout.addWidget(learning_rate_edit)
        neural_network_layout.addWidget(batch_size_label)
        neural_network_layout.addWidget(batch_size_edit)

        # Read-only result box.
        nn_result_text = QTextEdit(self)
        nn_result_text.setReadOnly(True)
        neural_network_layout.addWidget(nn_result_text)

        # Run button.
        nn_run_button = QPushButton("Run Neural Network", self)
        neural_network_layout.addWidget(nn_run_button)
        keywords_to_filter = ["pid", "SPURuntime", "info"]
        # Click handler: run mlp.py with the entered hyper-parameters.
        def run_neural_network():
            epochs = epochs_edit.text()
            learning_rate = learning_rate_edit.text()
            batch_size = batch_size_edit.text()

            nn_command = [
                "python", "network_demo/mlp.py",
                "--epochs", epochs,
                "--learning_rate", learning_rate,
                "--batch_size", batch_size
            ]

            try:
                result = subprocess.run(nn_command, capture_output=True, text=True, check=True)
                filtered_output = filter_output(result.stdout, keywords_to_filter)
                nn_result_text.append(filtered_output)
            except subprocess.CalledProcessError as e:
                nn_result_text.append(f"An error occurred:\n{e.stderr}")
        nn_run_button.clicked.connect(run_neural_network)

        # Back button (kept small).
        back_button = QPushButton("Back to Main Menu", self)
        back_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        neural_network_layout.addWidget(back_button)
        back_button.clicked.connect(self.go_back_to_main)

        # Show the page.
        self.stacked_widget.setCurrentWidget(neural_network_page)

    def go_back_to_function1(self):
        """Return to the function-1 (data analysis) sub-menu."""
        self.stacked_widget.setCurrentWidget(self.function1_page)

    def go_back_to_main(self):
        """Return to the main menu page."""
        self.stacked_widget.setCurrentWidget(self.main_page)
258 |
if __name__ == "__main__":
    # Standard Qt bootstrap: create the application, show the main window,
    # then hand control to the event loop until the window is closed.
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    sys.exit(app.exec_())
264 |
--------------------------------------------------------------------------------