├── multicollinearity_test.py ├── corr_coefficient_matrix.py ├── full_table_statistics.py ├── README.md ├── data_visual.py ├── data_process.py ├── LICENSE ├── mlp.py └── neural_network_gui.py /multicollinearity_test.py: -------------------------------------------------------------------------------- 1 | from secretflow.stats.ss_vif_v import VIF 2 | from data_process import data_dict,spu 3 | 4 | def VIF_calculation(vdf_hat): 5 | ''' 6 | 计算VIF来进行多重共线性检验 7 | ''' 8 | # 创建一个 VIF 计算器对象,传入 SPU 作为参数 9 | vif_calculator = VIF(spu) 10 | 11 | # 使用 VIF 计算器计算 vdf_hat 数据集的 VIF 值 12 | vif_results = vif_calculator.vif(vdf_hat) 13 | 14 | # 打印分隔符和多重共线性检验标题 15 | print("="*40 + "多重共线性检验" + "="*40) 16 | 17 | # 打印 vdf_hat 数据集的列名 18 | print(vdf_hat.columns) 19 | 20 | # 打印计算得到的 VIF 结果 21 | print(vif_results) 22 | 23 | if __name__ == '__main__': 24 | vdf_hat = data_dict['vdf_hat'] 25 | VIF_calculation(vdf_hat) -------------------------------------------------------------------------------- /corr_coefficient_matrix.py: -------------------------------------------------------------------------------- 1 | from data_process import data_dict,spu 2 | from secretflow.stats.ss_pearsonr_v import PearsonR 3 | import numpy as np 4 | 5 | def correlation_coefficient_matrix(vdf_hat): 6 | ''' 7 | 计算相关系数矩阵(利用皮尔逊相关系数) 8 | 在计算相关系数矩阵的时候要排除无序类别的数据 9 | ''' 10 | # 创建一个 PearsonR 对象,用于计算相关系数矩阵 11 | pearson_r = PearsonR(spu) 12 | 13 | # 使用 PearsonR 对象计算 vdf_hat 的相关系数矩阵 14 | corr_matrix = pearson_r.pearsonr(vdf_hat) 15 | 16 | # 打印标题,表示接下来输出的是相关系数矩阵 17 | print("==================== 相关系数矩阵 ====================\n") 18 | 19 | # 设置 numpy 的打印选项,格式化浮点数输出为小数点后三位 20 | np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)}) 21 | 22 | # 打印相关系数矩阵 23 | print(corr_matrix) 24 | 25 | if __name__ == '__main__': 26 | vdf_hat = data_dict['vdf_hat'] 27 | correlation_coefficient_matrix(vdf_hat) -------------------------------------------------------------------------------- /full_table_statistics.py: 
-------------------------------------------------------------------------------- 1 | from secretflow.stats.table_statistics import table_statistics 2 | from data_process import data_dict 3 | import os 4 | import tempfile 5 | import subprocess 6 | 7 | def full_table_statistics(vdf): 8 | ''' 9 | 全表统计并保存为Excel文件 10 | ''' 11 | # 生成全表统计数据 12 | data_stats = table_statistics(vdf) 13 | 14 | # 创建临时文件路径 15 | temp_dir = tempfile.gettempdir() # 获取系统临时目录 16 | output_file = os.path.join(temp_dir, 'full_table_stats.xlsx') # 设置文件路径 17 | 18 | # 将数据保存为Excel文件 19 | data_stats.to_excel(output_file, index=False) 20 | 21 | print(f"统计结果已保存至 {output_file}") 22 | 23 | # 打开生成的Excel文件 24 | try: 25 | # 使用 LibreOffice Calc 打开 Excel 文件 26 | subprocess.run(['libreoffice', '--calc', output_file], check=True) 27 | except subprocess.CalledProcessError as e: 28 | print(f"An error occurred while opening the Excel file: {e}") 29 | # 对于非Windows系统,可以使用以下代码: 30 | # os.system(f'open "{output_file}"') # macOS 31 | os.system(f'xdg-open "{output_file}"') # Linux 32 | return 33 | 34 | if __name__ == '__main__': 35 | full_table_statistics(data_dict['vdf']) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Neural Network Based on SecretFlow 3 | 4 | 本项目使用SecretFlow框架实现了多方安全计算,旨在通过联合数据训练,预测城市居民的年收入是否超过50k。该项目由城市内的三个政府机构合作进行,运用了多层感知器(MLP)模型来进行神经网络训练和预测。为了确保数据的准确性和可靠性,我们进行了全面的统计分析,包括多重共线性检验(VIF)和相关系数矩阵分析,并通过直方图、饼图等可视化方式呈现数据分析结果。此外,我们还开发了一个图形用户界面(GUI),使项目的操作更加简便直观。该项目的实施不仅体现了隐私保护技术在现实中的应用潜力,同时也为政府部门间的数据共享与合作提供了安全、高效的解决方案。 5 | 6 | 7 | 8 | [![Contributors][contributors-shield]][contributors-url] 9 | [![Forks][forks-shield]][forks-url] 10 | [![Stargazers][stars-shield]][stars-url] 11 | [![Issues][issues-shield]][issues-url] 12 | [![MIT License][license-shield]][license-url] 13 | 14 | 15 |
16 | 17 |

18 | 19 | Logo 20 | 21 | 22 |

Neural Network Based on SecretFlow

23 |

24 | 一个使用SecretFlow、MLP模型来进行神经网络训练和预测的项目! 25 |
26 | 探索本项目的文档 » 27 |
28 |
29 | 查看Demo 30 | · 31 | 报告Bug 32 | · 33 | 提出新特性 34 |

35 | 36 |

37 | 38 | 39 | 40 | 41 | ## 目录 42 | 43 | - [上手指南](#上手指南) 44 | - [开发前的配置要求](#开发前的配置要求) 45 | - [安装步骤](#安装步骤) 46 | - [文件目录说明](#文件目录说明) 47 | - [开发的架构](#开发的架构) 48 | - [部署](#部署) 49 | - [使用到的框架](#使用到的框架) 50 | - [贡献者](#贡献者) 51 | - [如何参与开源项目](#如何参与开源项目) 52 | - [版本控制](#版本控制) 53 | - [作者](#作者) 54 | - [鸣谢](#鸣谢) 55 | 56 | ### 上手指南 57 | 58 | 请将所有链接中的“YnRen22852/secretflowgryffindor”改为“your_github_name/your_repository” 59 | 60 | 61 | 62 | ###### 开发前的配置要求 63 | 64 | 1. xxxxx x.x.x 65 | 2. xxxxx x.x.x 66 | 67 | ###### **安装步骤** 68 | 69 | 1. Get a free API Key at [https://example.com](https://example.com) 70 | 2. Clone the repo 71 | 72 | ```sh 73 | git clone https://github.com/YnRen22852/secretflowgryffindor.git 74 | ``` 75 | 76 | ### 文件目录说明 77 | eg: 78 | 79 | ``` 80 | filetree 81 | ├── ARCHITECTURE.md 82 | ├── LICENSE.txt 83 | ├── README.md 84 | ├── /account/ 85 | ├── /bbs/ 86 | ├── /docs/ 87 | │ ├── /rules/ 88 | │ │ ├── backend.txt 89 | │ │ └── frontend.txt 90 | ├── manage.py 91 | ├── /oa/ 92 | ├── /static/ 93 | ├── /templates/ 94 | ├── useless.md 95 | └── /util/ 96 | 97 | ``` 98 | 99 | 100 | 101 | 102 | 103 | ### 开发的架构 104 | 105 | 请阅读[ARCHITECTURE.md](https://github.com/YnRen22852/secretflowgryffindor/blob/master/ARCHITECTURE.md) 查阅为该项目的架构。 106 | 107 | ### 部署 108 | 109 | 暂无 110 | 111 | ### 使用到的框架 112 | 113 | - [xxxxxxx](https://getbootstrap.com) 114 | - [xxxxxxx](https://jquery.com) 115 | - [xxxxxxx](https://laravel.com) 116 | 117 | ### 贡献者 118 | 119 | 请阅读[贡献者](https://github.com/YnRen22852/secretflowgryffindor/graphs/contributors) 查阅为该项目做出贡献的开发者。 120 | 121 | #### 如何参与开源项目 122 | 123 | 贡献使开源社区成为一个学习、激励和创造的绝佳场所。你所作的任何贡献都是**非常感谢**的。 124 | 125 | 126 | 1. Fork the Project 127 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) 128 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`) 129 | 4. Push to the Branch (`git push origin feature/AmazingFeature`) 130 | 5. 
Open a Pull Request 131 | 132 | 133 | 134 | ### 版本控制 135 | 136 | 该项目使用Git进行版本管理。您可以在repository参看当前可用版本。 137 | 138 | ### 作者 139 | 140 | xxx@xxxx 141 | 142 | 知乎:xxxx   qq:xxxxxx 143 | 144 | *您也可以在贡献者名单中参看所有参与该项目的开发者。* 145 | 146 | ### 版权说明 147 | 148 | 该项目签署了Apache License 2.0授权许可,详情请参阅 [LICENSE](https://github.com/YnRen22852/secretflowgryffindor/blob/master/LICENSE) 149 | 150 | ### 鸣谢 151 | 152 | 153 | - [GitHub Emoji Cheat Sheet](https://www.webpagefx.com/tools/emoji-cheat-sheet) 154 | - [Img Shields](https://shields.io) 155 | - [Choose an Open Source License](https://choosealicense.com) 156 | - [GitHub Pages](https://pages.github.com) 157 | - [Animate.css](https://daneden.github.io/animate.css) 158 | - [xxxxxxxxxxxxxx](https://connoratherton.com/loaders) 159 | 160 | 161 | [your-project-path]:https://github.com/YnRen22852/secretflowgryffindor 162 | [contributors-shield]: https://img.shields.io/github/contributors/YnRen22852/secretflowgryffindor.svg?style=flat-square 163 | [contributors-url]: https://github.com/YnRen22852/secretflowgryffindor/graphs/contributors 164 | [forks-shield]: https://img.shields.io/github/forks/YnRen22852/secretflowgryffindor.svg?style=flat-square 165 | [forks-url]: https://github.com/YnRen22852/secretflowgryffindor/network/members 166 | [stars-shield]: https://img.shields.io/github/stars/YnRen22852/secretflowgryffindor.svg?style=flat-square 167 | [stars-url]: https://github.com/YnRen22852/secretflowgryffindor/stargazers 168 | [issues-shield]: https://img.shields.io/github/issues/YnRen22852/secretflowgryffindor.svg?style=flat-square 169 | [issues-url]: https://github.com/YnRen22852/secretflowgryffindor/issues 170 | [license-shield]: https://img.shields.io/github/license/YnRen22852/secretflowgryffindor.svg?style=flat-square 171 | [license-url]: https://github.com/YnRen22852/secretflowgryffindor/blob/master/LICENSE 172 | 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /data_visual.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import pandas as pd 5 | from matplotlib.backends.backend_pdf import PdfPages 6 | import tempfile 7 | import subprocess 8 | from data_process import alice_path, bob_path, carol_path, full_file_path 9 | from mlp import install_package 10 | 11 | # 安装seaborn包,用于数据可视化 12 | # install_package('seaborn') 13 | 14 | def histogram(column_data, pdf_pages): 15 | ''' 16 | 绘制直方图 17 | ''' 18 | plt.figure(figsize=(10, 6)) 19 | 20 | # 使用 seaborn 库绘制直方图,包含核密度估计(kde),并设置直方图的柱数为 30 21 | sns.histplot(column_data, kde=True, bins=30) 22 | 23 | # 设置图表标题,标题内容为 'Histogram of ' 加上列名 24 | plt.title(f'Histogram of {column_data.name}') 25 | 26 | # 设置 x 轴标签,标签内容为列名 27 | plt.xlabel(column_data.name) 28 | 29 | # 设置 y 轴标签,标签内容为 'Frequency' 30 | plt.ylabel('Frequency') 31 | 32 | # 调整 x 轴刻度标签的旋转角度为 90 度 33 | plt.xticks(rotation=90) 34 | 35 | # 保存当前图表到 PDF 文件 36 | pdf_pages.savefig() 37 | 38 | # 关闭当前图表,释放内存 39 | plt.close() 40 | 41 | def bar_chart(column_data, pdf_pages): 42 | ''' 43 | 绘制条形图 44 | ''' 45 | plt.figure(figsize=(10, 6)) 46 | 47 | # 使用 seaborn 库绘制条形图,y 轴为 column_data,颜色调色板为 'viridis' 48 | ax = sns.countplot(y=column_data, palette='viridis') 49 | 50 | # 设置图表标题,标题内容为 'Bar Chart of ' 加上列名 51 | plt.title(f'Bar Chart of {column_data.name}') 52 | 53 | # 设置 x 轴标签,标签内容为 'Count' 54 | plt.xlabel('Count') 55 | 56 | # 设置 y 轴标签,标签内容为列名 57 | plt.ylabel(column_data.name) 58 | 59 | # 设置条形标签 60 | for p in ax.patches: 61 | # 获取条形的宽度(即计数值) 62 | width = p.get_width() 63 | 64 | # 获取条形的高度 65 | height = p.get_height() 66 | 67 | # 获取条形的 x 坐标 68 | x = p.get_x() + width 69 | 70 | # 获取条形的 y 坐标,条形的中心位置 71 | y = p.get_y() + height / 2 72 | 73 | # 确定标签位置以避免重叠,将标签向右偏移一些 74 | label_x = x + 0.1 75 | label_y = y 76 | 77 | # 添加标签,显示条形的宽度(即计数值) 78 | ax.annotate(f'{width}', 79 | (label_x, label_y), 80 | ha='left', va='center', fontsize=10, color='black') 81 | 82 | # 
保存当前图表到 PDF 文件 83 | pdf_pages.savefig() 84 | 85 | # 关闭当前图表,释放内存 86 | plt.close() 87 | 88 | def pie_chart(column_data, colors, pdf_pages): 89 | ''' 90 | 绘制饼图 91 | ''' 92 | plt.figure(figsize=(8, 8)) 93 | 94 | # 计算每个类别的频数 95 | sizes = column_data.value_counts() 96 | 97 | # 绘制饼图 98 | sizes.plot.pie( 99 | autopct='%1.1f%%', # 设置百分比格式 100 | startangle=90, # 设置饼图的起始角度 101 | colors=colors, # 设置饼图的颜色 102 | wedgeprops=dict(edgecolor='grey') # 设置切片的边缘颜色 103 | ) 104 | 105 | # 设置图表标题,标题内容为 'Pie Chart of ' 加上列名,字体大小为 20 106 | plt.title(f'Pie Chart of {column_data.name}', fontsize=20) 107 | 108 | # 隐藏 y 轴标签 109 | plt.ylabel('') 110 | 111 | # 保存当前图表到 PDF 文件 112 | pdf_pages.savefig() 113 | 114 | # 关闭当前图表,释放内存 115 | plt.close() 116 | 117 | def data_visualize(user_file_path): 118 | ''' 119 | 数据可视化 120 | ''' 121 | # 定义需要绘制直方图的列名 122 | cols_histogram = ['age', 'fnlwgt', 'education', 'hours-per-week', 'capital-gain', 'capital-loss', 'native-country'] 123 | 124 | # 定义需要绘制条形图的列名 125 | cols_bar_chart = ['race', 'workclass', 'occupation', 'education-num', 'marital-status', 'relationship'] 126 | 127 | # 定义需要绘制饼图的列名 128 | cols_pie_chart = ['sex', 'income'] 129 | 130 | # 读取用户提供的 CSV 文件,生成 DataFrame 131 | df = pd.read_csv(user_file_path) 132 | 133 | # 创建临时 PDF 文件 134 | with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf: 135 | # 获取临时文件的路径 136 | temp_file_path = temp_pdf.name 137 | print() 138 | 139 | # 使用 PdfPages 创建一个 PDF 文件对象 140 | with PdfPages(temp_file_path) as pdf_pages: 141 | # 遍历 DataFrame 中的每一列 142 | for col in df.columns: 143 | # 如果列名在直方图列名列表中,则绘制直方图 144 | if col in cols_histogram: 145 | histogram(df[col], pdf_pages) 146 | 147 | # 如果列名在条形图列名列表中,则绘制条形图 148 | if col in cols_bar_chart: 149 | bar_chart(df[col], pdf_pages) 150 | 151 | # 如果列名在饼图列名列表中,则绘制饼图 152 | if col in cols_pie_chart: 153 | # 自定义颜色调色板 154 | colors = sns.color_palette('Set2', n_colors=len(df[col].value_counts())) 155 | pie_chart(df[col], colors, pdf_pages) 156 | 157 | # 尝试打开生成的 PDF 文件 158 | try: 159 
| # 使用系统默认的 PDF 查看器打开生成的 PDF 文件 160 | subprocess.run(['evince', temp_file_path], check=True) 161 | except subprocess.CalledProcessError as e: 162 | # 如果打开文件时发生错误,打印错误信息 163 | print(f"An error occurred while opening the PDF file: {e}") 164 | 165 | if __name__ == '__main__': 166 | choice = sys.argv[1] 167 | if choice == '1': 168 | data_visualize(alice_path) 169 | elif choice == '2': 170 | data_visualize(bob_path) 171 | elif choice == '3': 172 | data_visualize(carol_path) 173 | elif choice == '4': 174 | data_visualize(full_file_path) 175 | else: 176 | print("输入选项有误,请重新输入!") -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import secretflow as sf 2 | from secretflow.data.vertical import read_csv as v_read_csv 3 | from secretflow.preprocessing import LabelEncoder 4 | from secretflow.preprocessing import OneHotEncoder 5 | from secretflow.preprocessing import StandardScaler 6 | from secretflow.data.split import train_test_split 7 | from secretflow.stats.table_statistics import table_statistics 8 | from secretflow.stats.ss_vif_v import VIF 9 | from secretflow.stats.ss_pearsonr_v import PearsonR 10 | import jax.numpy as jnp 11 | import pandas as pd 12 | import numpy as np 13 | import tempfile 14 | import os 15 | import matplotlib.pyplot as plt 16 | import seaborn as sns 17 | 18 | def download_dataset(full_file_path): 19 | ''' 20 | 将数据集下载到本地 21 | ''' 22 | # UCI Adult 数据集的 URL 23 | url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data" 24 | 25 | # 列名 26 | column_names = [ 27 | 'age', 'workclass', 'fnlwgt', 'education', 'education-num', 28 | 'marital-status', 'occupation', 'relationship', 'race', 'sex', 29 | 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income' 30 | ] 31 | 32 | # 下载数据集 33 | adult_df = pd.read_csv(url, names=column_names, sep=',\s', na_values=["?"], engine='python') 34 | # 
保存为csv文件 35 | adult_df.to_csv(full_file_path,index=False) 36 | 37 | def load_dataset(full_file_path): 38 | ''' 39 | 在指定目录家在数据集,如果不存在则下载 40 | ''' 41 | # 检查文件路径是否存在,如果不存在则下载数据集 42 | if not os.path.exists(full_file_path): 43 | download_dataset(full_file_path) 44 | 45 | try: 46 | # 尝试读取 CSV 文件到 DataFrame 47 | df = pd.read_csv(full_file_path) 48 | except Exception as e: 49 | # 如果读取文件时发生错误,抛出带有错误信息的 ValueError 50 | raise ValueError(f"文件读取错误:{e}") 51 | 52 | # 添加一列 'uid',其值为 DataFrame 的索引加 1 53 | df['uid'] = df.index + 1 54 | 55 | # 返回处理后的 DataFrame 56 | return df 57 | 58 | def split_dataset(data_df, alice_path, bob_path, carol_path): 59 | ''' 60 | 分割数据集,将数据集垂直切分,并且存储在在各个参与方的临时文件路径 61 | ''' 62 | # 获取数据集的列数,减去 1 是因为最后一列是 'uid' 63 | num_columns = data_df.shape[1] - 1 64 | 65 | # 设置分割点,将列数分成三部分 66 | split_point1 = num_columns // 3 67 | split_point2 = 2 * num_columns // 3 68 | 69 | # 分割数据集,并且确保每一方都拥有 'uid' 列 70 | # Alice 获取前 split_point1 列和最后一列 'uid',并随机抽取 90% 的数据 71 | data_alice = data_df.iloc[:, np.r_[0:split_point1, -1]].sample(frac=0.9) 72 | 73 | # Bob 获取从 split_point1 到 split_point2 列和最后一列 'uid',并随机抽取 90% 的数据 74 | data_bob = data_df.iloc[:, np.r_[split_point1:split_point2, -1]].sample(frac=0.9) 75 | 76 | # Carol 获取从 split_point2 到最后一列(包括 'uid'),并随机抽取 90% 的数据 77 | data_carol = data_df.iloc[:, split_point2:].sample(frac=0.9) 78 | 79 | # 将三方数据集保存至 CSV 文件 80 | data_alice.reset_index(drop=True).to_csv(alice_path, index=False) 81 | data_bob.reset_index(drop=True).to_csv(bob_path, index=False) 82 | data_carol.reset_index(drop=True).to_csv(carol_path, index=False) 83 | 84 | # 返回保存的文件路径 85 | return alice_path, bob_path, carol_path 86 | 87 | def secret_psi(alice_path, bob_path, carol_path): 88 | ''' 89 | 隐私求交实现数据对齐,将求交结果保存至VDataFrame 90 | ''' 91 | # 使用 v_read_csv 函数读取 CSV 文件,生成一个虚拟 DataFrame (vdf) 92 | vdf = v_read_csv( 93 | {alice: alice_path, bob: bob_path, carol: carol_path}, 94 | spu=spu,# - spu:指定安全计算单元(SPU) 95 | keys='uid',# - keys='uid':指定用于 PSI(私密集合交集)的键列 96 | 
drop_keys='uid',# - drop_keys='uid':指定读取后要删除的键列 97 | psi_protocl="ECDH_PSI_3PC",# - psi_protocl="ECDH_PSI_3PC":指定使用 ECDH_PSI_3PC 协议进行三方 PSI 98 | ) 99 | # 返回生成的虚拟 DataFrame (vdf) 100 | return vdf 101 | 102 | def Missing_Value_Filling(vdf): 103 | ''' 104 | 缺失值填充 105 | 填充规则为:填充该列中的‘众数’ 106 | ''' 107 | # 定义需要填充缺失值的列名 108 | cols = ['workclass', 'occupation', 'native-country'] 109 | 110 | # 遍历每一列 111 | for col in cols: 112 | # 找到该列的众数 113 | most_frequent_value = vdf[col].mode()[0] 114 | 115 | # 使用众数填充该列中的缺失值 116 | vdf[col].fillna(most_frequent_value, inplace=True) 117 | 118 | # 返回填充缺失值后的 DataFrame 119 | return vdf 120 | 121 | def label_encode_function(vdf): 122 | ''' 123 | 对无序且二值的序列,采用label encoding,转换为0/1表示 124 | ''' 125 | # 创建一个 LabelEncoder 对象,用于将分类数据转换为数值数据 126 | label_encoder = LabelEncoder() 127 | 128 | # 定义需要进行标签编码的列名 129 | cols = ['sex', 'income'] 130 | 131 | # 遍历每一列 132 | for col in cols: 133 | # 拟合标签编码器到该列数据 134 | label_encoder.fit(vdf[col]) 135 | 136 | # 将该列数据转换为数值数据 137 | vdf[col] = label_encoder.transform(vdf[col]) 138 | 139 | # 返回进行标签编码后的 DataFrame 140 | return vdf 141 | 142 | def Ordinal_Cate_Features(vdf): 143 | ''' 144 | 对于有序的类别数据,构建映射,将类别数据转换为0~n-1的整数 145 | ''' 146 | vdf['education'] = vdf['education'].replace( 147 | { 148 | "Preschool":0, 149 | "1st-4th":1, 150 | "5th-6th":2, 151 | "7th-8th":3, 152 | "9th":4, 153 | "10th":5, 154 | "11th":6, 155 | "12th":7, 156 | "HS-grad":8, 157 | "Some-college":9, 158 | "Assoc-voc":10, 159 | "Assoc-acdm":11, 160 | "Bachelors":12, 161 | "Masters":13, 162 | "Prof-school":14, 163 | "Doctorate":15 164 | } 165 | ) 166 | return vdf 167 | 168 | def One_Hot_Function(vdf): 169 | ''' 170 | 对于无序类别数据,采用one-hot编码 171 | ''' 172 | # 定义需要进行one-hot编码的列名 173 | onehot_cols = ['workclass', 'marital-status', 'occupation', 'relationship', 'race', 'native-country'] 174 | 175 | # 创建一个 OneHotEncoder 对象,用于将分类数据转换为one-hot编码 176 | onehot_encoder = OneHotEncoder() 177 | 178 | # 拟合one-hot编码器到指定列的数据 179 | onehot_encoder.fit(vdf[onehot_cols]) 180 | 
181 | # 创建一个新的 DataFrame vdf_hat,删除需要one-hot编码的列 182 | vdf_hat = vdf.drop(columns=onehot_cols) 183 | 184 | # 将指定列的数据转换为one-hot编码 185 | enc_feats = onehot_encoder.transform(vdf[onehot_cols]) 186 | 187 | # 获取one-hot编码后的特征名称 188 | features_name = enc_feats.columns 189 | 190 | # 从原始 DataFrame 中删除需要one-hot编码的列 191 | vdf = vdf.drop(columns=onehot_cols) 192 | 193 | # 将one-hot编码后的特征添加到 DataFrame 中 194 | vdf[features_name] = enc_feats 195 | 196 | # 返回进行one-hot编码后的 DataFrame 和删除指定列后的 DataFrame 197 | return vdf, vdf_hat 198 | 199 | def standard_scaler_func(vdf): 200 | ''' 201 | 对数值进行标准化 202 | ''' 203 | # 创建 vdf 的副本,以避免对原始数据进行修改 204 | vdf = vdf.copy() 205 | 206 | # 从 vdf 中删除 'income' 列,并将剩余的数据存储在 X 中 207 | X = vdf.drop(columns=['income']) 208 | 209 | # 将 'income' 列的数据存储在 y 中 210 | y = vdf['income'] 211 | 212 | # 创建一个 StandardScaler 对象,用于标准化数据 213 | scaler = StandardScaler() 214 | 215 | # 拟合标准化器并转换 X 中的数据 216 | X = scaler.fit_transform(X) 217 | 218 | # 将标准化后的数据赋值回 vdf 的相应列 219 | vdf[X.columns] = X 220 | 221 | # 返回标准化后的 DataFrame 222 | return vdf 223 | 224 | def split_train_test(vdf, train_size, random_state): 225 | ''' 226 | 训练集和测试集拆分 227 | ''' 228 | # 使用 train_test_split 函数将数据集拆分为训练集和测试集 229 | train_vdf, test_vdf = train_test_split(vdf, train_size=train_size, random_state=random_state) 230 | 231 | # 从训练集中删除 'income' 列,并将剩余的数据存储在 train_X 中 232 | train_X = train_vdf.drop(columns=['income']) 233 | 234 | # 将训练集中的 'income' 列存储在 train_y 中 235 | train_y = train_vdf['income'] 236 | 237 | # 从测试集中删除 'income' 列,并将剩余的数据存储在 test_X 中 238 | test_X = test_vdf.drop(columns=['income']) 239 | 240 | # 将测试集中的 'income' 列存储在 test_y 中 241 | test_y = test_vdf['income'] 242 | 243 | # 返回训练集和测试集的特征和标签 244 | return train_X, test_X, train_y, test_y 245 | 246 | def vdataframe_to_spu(vdf): 247 | ''' 248 | 将VDataFrame数据类型转换为SPUObject 249 | ''' 250 | # 创建一个空列表,用于存储每个设备上的 SPU 分区 251 | spu_partitions = [] 252 | 253 | # 遍历 vdf 的每个分区 254 | for device in vdf.partitions: 255 | # 将每个分区的数据转换为 SPU 格式,并添加到 
spu_partitions 列表中 256 | spu_partitions.append(vdf.partitions[device].data.to(spu)) 257 | 258 | # 取出第一个 SPU 分区作为基础分区 259 | base_partition = spu_partitions[0] 260 | 261 | # 遍历剩余的 SPU 分区 262 | for i in range(1, len(spu_partitions)): 263 | # 使用 SPU 计算,将当前基础分区与下一个 SPU 分区在轴 1 上进行拼接 264 | base_partition = spu(lambda x, y: jnp.concatenate([x, y], axis=1))( 265 | base_partition, spu_partitions[i] 266 | ) 267 | 268 | # 返回拼接后的基础分区 269 | return base_partition 270 | 271 | def convert_to_spu(train_X, test_X, train_y, test_y): 272 | ''' 273 | 将训练集特征数据转换为 SPU 格式 274 | ''' 275 | X_train_spu = vdataframe_to_spu(train_X) 276 | 277 | # 将训练集标签数据转换为 SPU 格式 278 | y_train_spu = train_y.partitions[carol].data.to(spu) 279 | 280 | # 将测试集特征数据转换为 SPU 格式 281 | X_test_spu = vdataframe_to_spu(test_X) 282 | 283 | # 将测试集标签数据转换为 SPU 格式 284 | y_test_spu = test_y.partitions[carol].data.to(spu) 285 | 286 | # 返回转换后的训练集和测试集特征及标签数据 287 | return X_train_spu, X_test_spu, y_train_spu, y_test_spu 288 | 289 | 290 | def data_preprocessing(full_file_path): 291 | ''' 292 | 数据预处理,包括加载数据集,数据集切分 293 | ''' 294 | # 给各个参与方分配临时文件路径 295 | _, alice_path = tempfile.mkstemp() 296 | _, bob_path = tempfile.mkstemp() 297 | _, carol_path = tempfile.mkstemp() 298 | 299 | data = load_dataset(full_file_path)# 加载数据集 300 | # print(data.head()) # 查看前几行数据以了解数据集结构 301 | split_dataset(data,alice_path,bob_path,carol_path) # 将数据集切分为三方数据集 302 | return alice_path,bob_path,carol_path 303 | 304 | def data_process(alice_path,bob_path,carol_path): 305 | ''' 306 | 数据处理 307 | ''' 308 | vdf = secret_psi(alice_path,bob_path,carol_path) # 隐私求交实现数据对齐 309 | # print(vdf) # 查看求交结果 310 | vdf_1 = vdf.copy()# 复制求交结果 311 | vdf_1 = Missing_Value_Filling(vdf_1) # 缺失值填充 312 | vdf_1 = label_encode_function(vdf_1) # 将无序且二值的序列转换为0/1表示 313 | vdf_1 = Ordinal_Cate_Features(vdf_1) # 对于有序的类别数据,构建映射,将类别数据转换为0~n-1的整数 314 | vdf_1, vdf_hat = One_Hot_Function(vdf_1) # 对于无序类别数据,采用one-hot编码 315 | vdf_1 = standard_scaler_func(vdf_1) # 对数值进行标准化 316 | 317 | # 训练集和测试集拆分 318 | 
train_size = 0.8 # 训练集占比 319 | random_state = 1234 # 随机种子 320 | train_X,test_X,train_y,test_y = split_train_test(vdf_1,train_size,random_state) 321 | 322 | # 数据类型转换 323 | X_train_spu,X_test_spu,y_train_spu,y_test_spu = convert_to_spu(train_X,test_X,train_y,test_y) 324 | 325 | # 构建返回字典 326 | results = { 327 | 'X_train_spu': X_train_spu, 328 | 'X_test_spu': X_test_spu, 329 | 'y_train_spu': y_train_spu, 330 | 'y_test_spu': y_test_spu, 331 | 'vdf': vdf_1, 332 | 'vdf_hat': vdf_hat 333 | } 334 | return results 335 | 336 | # 文件的下载和存储路径 337 | full_file_path = './adult_data.csv' 338 | 339 | # 配置SPU相关设备 340 | sf.shutdown()# 关闭所有SPU设备 341 | sf.init(['alice','bob','carol'],address='local') 342 | aby3_config = sf.utils.testing.cluster_def(parties=['alice', 'bob', 'carol']) 343 | spu = sf.SPU(aby3_config) 344 | alice = sf.PYU('alice') 345 | bob = sf.PYU('bob') 346 | carol = sf.PYU('carol') 347 | 348 | # 数据集加载和切分 349 | alice_path,bob_path,carol_path = data_preprocessing(full_file_path) 350 | 351 | ''' 352 | 数据处理 353 | 这里的vdf是psi意思求交的结果 354 | vdf_hat是数据处理之后的结果,但是不包括one-hot编码的结果 355 | 处理结果是包括X_train_spu,X_test_spu,y_train_spu,y_test_spu的字典 356 | ''' 357 | data_dict = data_process(alice_path,bob_path,carol_path) 358 | 359 | # 输出处理后的数据 360 | # print(sf.reveal(data_dict['X_train_spu'])) 361 | # print(sf.reveal(data_dict['X_test_spu'])) 362 | # print(sf.reveal(data_dict['y_train_spu'])) 363 | # print(sf.reveal(data_dict['y_test_spu'])) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /mlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import secretflow as sf 3 | from secretflow.data.vertical import read_csv as v_read_csv 4 | from secretflow.preprocessing import LabelEncoder 5 | from secretflow.preprocessing import OneHotEncoder 6 | from secretflow.preprocessing import StandardScaler 7 | from secretflow.data.split import train_test_split 8 | from secretflow.stats.table_statistics import table_statistics 9 | from secretflow.stats.ss_vif_v import VIF 10 | from secretflow.stats.ss_pearsonr_v import PearsonR 11 | import jax.numpy as jnp 12 | import jax 13 | import pandas as pd 14 | import numpy as np 15 | import tempfile 16 | import os 17 | import matplotlib.pyplot as plt 18 | import sys 19 | import subprocess 20 | from data_process import spu, alice, bob, carol, data_dict 21 | 22 | def install_package(package_name): 23 | ''' 24 | 安装指定的Python包 25 | ''' 26 | try: 27 | # 使用 subprocess 执行 pip 安装命令,静默安装指定的包 28 | subprocess.check_call([sys.executable, '-m', 'pip', 'install', 
package_name, '-q']) 29 | except subprocess.CalledProcessError as e: 30 | # 安装出错则打印错误信息 31 | print(f"Error occurred while installing {package_name}: {e}") 32 | else: 33 | # 安装成功后打印成功信息 34 | print(f"Successfully installed {package_name}") 35 | return 36 | 37 | # 安装0.6.0版本的flax包,用于构建神经网络 38 | install_package('flax==0.6.0') 39 | 40 | # 安装seaborn包,用于数据可视化 41 | install_package('seaborn') 42 | 43 | import seaborn as sns 44 | from typing import Sequence 45 | import flax.linen as nn 46 | from sklearn.metrics import roc_auc_score 47 | 48 | class MLP(nn.Module): # 定义一个多层感知器(MLP)类 49 | features: Sequence[int] # 每一层神经元的数量 50 | dropout_rate: float # Dropout 概率 51 | @nn.compact 52 | def __call__(self, x, train, rngs=None): # 定义前向传播过程 53 | for feat in self.features[:-1]: # 遍历每一层(除了最后一层) 54 | x = nn.relu(nn.Dense(feat)(x)) # 应用全连接层和 ReLU 激活函数 55 | x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) # 应用 Dropout 56 | x = nn.BatchNorm(use_running_average=not train, momentum=0.5, epsilon=1e-5)(x) # 应用 Batch Normalization 57 | x = nn.Dense(self.features[-1])(x) # 最后一层只应用全连接层 58 | return x # 返回输出 59 | 60 | 61 | def predict(params, x, train=False, rng_key=None): 62 | ''' 63 | 预测函数,传入权重偏置和输入,训练和测试都要用到 64 | ''' 65 | from typing import Sequence 66 | import flax.linen as nn 67 | 68 | class MLP(nn.Module): # 定义一个多层感知器(MLP)类 69 | features: Sequence[int] # 每一层神经元的数量 70 | dropout_rate: float # Dropout 概率 71 | @nn.compact 72 | def __call__(self, x, train, rngs=None): # 定义前向传播过程 73 | for feat in self.features[:-1]: # 遍历每一层(除了最后一层) 74 | x = nn.relu(nn.Dense(feat)(x)) # 应用全连接层和 ReLU 激活函数 75 | x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) # 应用 Dropout 76 | x = nn.BatchNorm(use_running_average=not train, momentum=0.5, epsilon=1e-5)(x) # 应用 Batch Normalization 77 | x = nn.Dense(self.features[-1])(x) # 最后一层只应用全连接层 78 | return x # 返回输出 79 | # FEATURE :每一层神经元的数目 80 | FEATURES = [dim, 15, 8, 1] # 定义每一层神经元的数量 81 | flax_nn = MLP(features=FEATURES, dropout_rate=0.1) # 创建 MLP 实例,并设置 
Dropout 概率为 0.1 82 | if rng_key is None: # 如果没有提供随机数生成器的键 83 | rng_key = jax.random.PRNGKey(0) # 使用默认的 PRNG 键 84 | y, updates = flax_nn.apply( # 调用 MLP 的 apply 方法进行前向传播 85 | params, # 模型参数 86 | x, # 输入数据 87 | train, # 是否处于训练模式 88 | mutable=['batch_stats'], # 指示 Batch Normalization 的统计数据是可变的 89 | rngs={'dropout': rng_key} # 传入用于 Dropout 的随机数生成器 90 | ) 91 | batch_stats = updates['batch_stats'] # 获取 Batch Normalization 的统计数据 92 | return y # 返回模型的预测结果 93 | 94 | def loss_func(params, x, y, rng_key): 95 | ''' 96 | 使用MSE作为损失函数 97 | ''' 98 | # 调用 predict 函数,使用给定的参数 params 和输入数据 x 进行预测,得到预测值 pred 99 | # 传递 train=True 和随机数生成键 rng_key 作为额外参数 100 | pred = predict(params, x, train=True, rng_key=rng_key) 101 | # 定义均方误差(MSE)函数 102 | def mse(y, pred): 103 | # 定义平方误差函数 104 | def squared_error(y, y_pred): 105 | # 计算每个样本的平方误差,并除以 2.0 106 | return jnp.multiply(y - y_pred, y - y_pred) / 2.0 107 | # 计算所有样本的平均平方误差 108 | return jnp.mean(squared_error(y, pred)) 109 | # 调用 mse 函数,计算并返回 y 和 pred 之间的均方误差 110 | return mse(y, pred) 111 | 112 | def train_auto_grad(X, y, params, batch_size=10, epochs=10, learning_rate=0.01): 113 | ''' 114 | 模型训练 115 | ''' 116 | # 将输入数据 X 和标签 y 按照 batch_size 分割成多个小批次 117 | xs = jnp.array_split(X, len(X) // batch_size, axis=0) 118 | ys = jnp.array_split(y, len(y) // batch_size, axis=0) 119 | 120 | # 打印输入数据的shape 121 | #print(X.shape) 122 | 123 | # 初始化随机数生成器的键 124 | rng_key = jax.random.PRNGKey(0) 125 | 126 | # 进行多个 epoch 的训练 127 | for epoch in range(epochs): 128 | # 遍历每个小批次的数据 129 | for batch_x, batch_y in zip(xs, ys): 130 | # 计算当前批次的损失和梯度 131 | loss, grads = jax.value_and_grad(loss_func)(params, batch_x, batch_y, rng_key) 132 | 133 | # 使用梯度下降法更新模型参数 134 | params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads) 135 | 136 | # 返回更新后的模型参数 137 | return params 138 | 139 | class MLP_spu(nn.Module): # 定义一个多层感知器(MLP)类 spu版本 140 | features: Sequence[int] # 每一层神经元的数量 141 | @nn.compact 142 | def __call__(self, x): # 定义前向传播过程 143 | for feat in 
self.features[:-1]: # 遍历每一层(除了最后一层) 144 | x = nn.relu(nn.Dense(feat)(x)) # 应用全连接层和 ReLU 激活函数 145 | x = nn.Dense(self.features[-1])(x) # 最后一层只应用全连接层 146 | return x # 返回输出 147 | 148 | def predict_spu(params, x): 149 | ''' 150 | spu版本的预测函数 151 | ''' 152 | # 从 typing 模块导入 Sequence 类型 153 | from typing import Sequence 154 | # 从 flax.linen 模块导入 nn 155 | import flax.linen as nn 156 | 157 | # 定义一个多层感知器(MLP)类,适用于 SPU 158 | class MLP_spu(nn.Module): 159 | # 定义每一层神经元的数量 160 | features: Sequence[int] 161 | 162 | # 定义前向传播过程 163 | @nn.compact 164 | def __call__(self, x): 165 | # 遍历每一层(除了最后一层) 166 | for feat in self.features[:-1]: 167 | # 应用全连接层和 ReLU 激活函数 168 | x = nn.relu(nn.Dense(feat)(x)) 169 | # 最后一层只应用全连接层 170 | x = nn.Dense(self.features[-1])(x) 171 | # 返回输出 172 | return x 173 | 174 | # 定义每一层神经元的数量,FEATURES 列表 175 | FEATURES = [dim, 15, 8, 1] 176 | 177 | # 创建一个 MLP_spu 实例,传入 FEATURES 作为参数 178 | flax_nn = MLP_spu(features=FEATURES) 179 | 180 | # 使用给定的参数 params 和输入数据 x 进行预测,返回预测结果 181 | return flax_nn.apply(params, x) 182 | 183 | def loss_func_spu(params, x, y): 184 | ''' 185 | spu版本的损失函数 186 | ''' 187 | # 使用给定的参数 params 和输入数据 x 进行预测,得到预测值 pred 188 | pred = predict_spu(params, x) 189 | # 定义均方误差(MSE)函数 190 | def mse(y, pred): 191 | # 定义平方误差函数 192 | def squared_error(y, y_pred): 193 | # 计算每个样本的平方误差,并除以 2.0 194 | return jnp.multiply(y - y_pred, y - y_pred) / 2.0 195 | 196 | # 计算所有样本的平均平方误差 197 | return jnp.mean(squared_error(y, pred)) 198 | 199 | # 调用 mse 函数,计算并返回 y 和 pred 之间的均方误差 200 | return mse(y, pred) 201 | 202 | def train_auto_grad_spu(X, y, params, batch_size=10, epochs=10, learning_rate=0.01): 203 | ''' 204 | spu版本的模型训练函数 205 | ''' 206 | # 将输入数据 X 按照 batch_size 分割成多个小批量,存储在 xs 列表中 207 | xs = jnp.array_split(X, len(X) // batch_size, axis=0) 208 | 209 | # 将目标数据 y 按照 batch_size 分割成多个小批量,存储在 ys 列表中 210 | ys = jnp.array_split(y, len(y) // batch_size, axis=0) 211 | 212 | # 打印输入数据 X 的形状 213 | # print(X.shape) 214 | 215 | # 进行 epochs 次训练迭代 216 | for epoch in range(epochs): 217 
| # 遍历每个小批量数据 218 | for batch_x, batch_y in zip(xs, ys): 219 | # 计算当前批量数据的损失值和梯度 220 | loss, grads = jax.value_and_grad(loss_func_spu)(params, batch_x, batch_y) 221 | 222 | # 更新模型参数,使用梯度下降法 223 | params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads) 224 | 225 | # 返回训练后的模型参数 226 | return params 227 | 228 | def cpu_version_mlp(X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate): 229 | ''' 230 | 在 CPU 上训练和评估MLP模型 231 | ''' 232 | # 使用自动微分方法训练模型,返回训练后的参数 233 | params = train_auto_grad( 234 | X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate 235 | ) 236 | 237 | # 设置用于预测的随机数生成器的键 238 | rng_key = jax.random.PRNGKey(1) 239 | 240 | # 使用训练后的参数进行预测 241 | y_pred = predict(params, X_test_plaintext, train=False) 242 | 243 | # 计算并打印模型的 AUC 分数 244 | os.system('clear') 245 | print(f"\033[31m(Flax NN CPU) auc: {roc_auc_score(y_test_plaintext, y_pred)}\033[0m") 246 | 247 | def spu_version_mlp(X_train_spu, y_train_spu, params_spu, batch_size, epochs, learning_rate): 248 | ''' 249 | 在 SPU 上训练和评估MLP模型 250 | ''' 251 | # 使用 SPU 环境中的 train_auto_grad 函数训练模型,返回训练后的参数 252 | params_spu = spu( 253 | train_auto_grad_spu, static_argnames=['batch_size', 'epochs', 'learning_rate'] 254 | )( 255 | X_train_spu, # 训练数据 256 | y_train_spu, # 训练标签 257 | params_spu, # 初始参数 258 | batch_size=batch_size, # 批次大小 259 | epochs=epochs, # 训练轮数 260 | learning_rate=learning_rate # 学习率 261 | ) 262 | 263 | # 使用 SPU 环境中的 predict 函数进行预测 264 | y_pred_spu = spu(predict_spu)(params_spu, X_test_spu) 265 | 266 | # 将预测结果从 SPU 环境中揭示出来 267 | y_pred_ = sf.reveal(y_pred_spu) 268 | 269 | # 计算并打印模型的 AUC 分数 270 | print(f"\033[31m(Flax NN SPU) auc: {roc_auc_score(y_test_plaintext, y_pred_)}\033[0m") 271 | 272 | if __name__ == '__main__': 273 | # 获取数据集 274 | X_train_spu = data_dict['X_train_spu'] 275 | y_train_spu = data_dict['y_train_spu'] 276 | X_test_spu = data_dict['X_test_spu'] 277 | y_test_spu = data_dict['y_test_spu'] 278 | 279 | # 
将 SPU 上的训练和测试数据揭露为明文数据 280 | X_train_plaintext = sf.reveal(X_train_spu) # 揭露训练集特征数据 281 | y_train_plaintext = sf.reveal(y_train_spu) # 揭露训练集目标数据 282 | X_test_plaintext = sf.reveal(X_test_spu) # 揭露测试集特征数据 283 | y_test_plaintext = sf.reveal(y_test_spu) # 揭露测试集目标数据 284 | 285 | # 获取训练集特征数据的维度 286 | dim = X_train_plaintext.shape[1] 287 | 288 | # 定义每一层神经元的数量 289 | FEATURES = [dim, 15, 8, 1] 290 | 291 | # 创建 CPU 版本的多层感知器(MLP)实例,设置 Dropout 概率为 0.1 292 | flax_nn = MLP(features=FEATURES, dropout_rate=0.1) 293 | 294 | # 创建 SPU 版本的多层感知器(MLP)实例 295 | flax_nn_spu = MLP_spu(features=FEATURES) 296 | 297 | # 根据数据集的特征维度设置特征维度 298 | feature_dim = dim 299 | 300 | if len(sys.argv[1])==3: 301 | # 设置模型训练的参数 302 | epochs = sys.argv[1][0] # CPU 版本的训练轮数 303 | learning_rate = sys.argv[1][1] # CPU 版本的学习率 304 | batch_size = sys.argv[1][2] # CPU 版本的批量大小 305 | epochs_spu = sys.argv[1][0] # SPU 版本的训练轮数 306 | learning_rate_spu = sys.argv[1][1] # SPU 版本的学习率 307 | batch_size_spu = sys.argv[1][2] # SPU 版本的批量大小 308 | else: 309 | epochs = 2 # CPU 版本的训练轮数 310 | learning_rate = 0.02 # CPU 版本的学习率 311 | batch_size = 100 # CPU 版本的批量大小 312 | epochs_spu = 2 # SPU 版本的训练轮数 313 | learning_rate_spu = 0.02 # SPU 版本的学习率 314 | batch_size_spu = 100 # SPU 版本的批量大小 315 | # 初始化 CPU 版本的模型参数,使用随机数生成键和全 1 的输入数据 316 | init_params = flax_nn.init(jax.random.PRNGKey(1), jnp.ones((batch_size, feature_dim)), train=False) 317 | 318 | # 初始化 SPU 版本的模型参数,使用随机数生成键和全 1 的输入数据 319 | init_params_spu = flax_nn_spu.init(jax.random.PRNGKey(1), jnp.ones((batch_size, feature_dim))) 320 | 321 | # 将初始化的 SPU 版本模型参数从 Alice 传递到 SPU 322 | params = sf.to(alice, init_params_spu).to(spu) 323 | 324 | # 打印训练集特征数据的形状 325 | # print(X_train_plaintext.shape) 326 | 327 | # 调用 CPU 版本的 MLP 训练函数,传入训练集特征数据、目标数据、初始化参数、批量大小、训练轮数和学习率 328 | cpu_version_mlp(X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate) 329 | 330 | # 调用 SPU 版本的 MLP 训练函数,传入训练集特征数据、目标数据、初始化参数、批量大小、训练轮数和学习率 331 | spu_version_mlp(X_train_spu, y_train_spu, params, 
batch_size_spu, epochs_spu, learning_rate_spu) -------------------------------------------------------------------------------- /neural_network_gui.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | from PyQt5.QtWidgets import ( 4 | QApplication, QMainWindow, QWidget, QVBoxLayout, QPushButton, QLabel, QStackedWidget, QSizePolicy, QTextEdit, QLineEdit, QHBoxLayout 5 | ) 6 | from PyQt5.QtCore import Qt 7 | from PyQt5.QtGui import QFont 8 | 9 | def filter_output(output, keywords): 10 | """ 11 | 过滤掉包含任意关键字的行。 12 | 13 | :param output: 要过滤的原始输出字符串。 14 | :param keywords: 需要过滤的关键字列表。 15 | :return: 过滤后的字符串。 16 | """ 17 | lines = output.splitlines() 18 | filtered_lines = [line for line in lines if not any(keyword in line for keyword in keywords)] 19 | return '\n'.join(filtered_lines) 20 | 21 | class MainWindow(QMainWindow): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | # 设置主窗口标题和尺寸 26 | self.setWindowTitle("Gryffindor - Neural Network Based on SecretFlow") 27 | self.setGeometry(100, 100, 1000, 800) # 增加窗口尺寸 28 | 29 | # 创建堆栈窗口用于页面切换 30 | self.stacked_widget = QStackedWidget() 31 | self.setCentralWidget(self.stacked_widget) 32 | 33 | # 创建主页面 34 | self.main_page = QWidget() 35 | self.stacked_widget.addWidget(self.main_page) 36 | 37 | # 设置主页面布局 38 | self.main_layout = QVBoxLayout(self.main_page) 39 | 40 | # 显示参赛题目(占据主界面高度的30%) 41 | self.topic_label = QLabel(" Neural Network Based on SecretFlow", self) 42 | self.topic_label.setAlignment(Qt.AlignCenter) 43 | self.topic_label.setFont(QFont("黑体", 24, QFont.Bold)) # 保持字体为黑体,大小24 44 | self.topic_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding) 45 | self.main_layout.addWidget(self.topic_label) 46 | 47 | # 显示队名(占据主界面高度的20%) 48 | self.team_label = QLabel("Team: Gryffindor", self) 49 | self.team_label.setAlignment(Qt.AlignCenter) 50 | self.team_label.setFont(QFont("黑体", 20, QFont.Bold)) # 保持字体为黑体,大小20 51 | 
self.team_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding) 52 | self.main_layout.addWidget(self.team_label) 53 | 54 | # 创建功能按钮(占据主界面高度的50%) 55 | self.function1_button = QPushButton("Function 1: Data Analysis", self) 56 | self.function2_button = QPushButton("Function 2: Run Neural Network", self) 57 | self.function1_button.setFont(QFont("宋体", 14)) # 保持默认字体,仅调整大小为14 58 | self.function2_button.setFont(QFont("宋体", 14)) # 保持默认字体,仅调整大小为14 59 | self.function1_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding) 60 | self.function2_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.MinimumExpanding) 61 | self.main_layout.addWidget(self.function1_button) 62 | self.main_layout.addWidget(self.function2_button) 63 | 64 | # 连接功能按钮点击事件 65 | self.function1_button.clicked.connect(self.show_function1) 66 | self.function2_button.clicked.connect(self.show_neural_network_page) 67 | 68 | # 创建第一个功能的子功能页面 69 | self.create_function1_pages() 70 | 71 | def create_function1_pages(self): 72 | # 创建第一个功能页面 73 | self.function1_page = QWidget() 74 | self.function1_layout = QVBoxLayout(self.function1_page) 75 | self.stacked_widget.addWidget(self.function1_page) 76 | 77 | # 创建子功能按钮,并改名 78 | self.sub_function1_button = QPushButton("Sub Function 1: Full Table Statistics", self) 79 | self.sub_function2_button = QPushButton("Sub Function 2: VIF Multicollinearity Test", self) 80 | self.sub_function3_button = QPushButton("Sub Function 3: Correlation Coefficient Matrix", self) 81 | self.sub_function4_button = QPushButton("Sub Function 4: Data Visual", self) 82 | self.sub_function1_button.setFont(QFont("宋体", 12)) # 保持默认字体,仅调整大小为17 83 | self.sub_function2_button.setFont(QFont("宋体", 12)) # 保持默认字体,仅调整大小为17 84 | self.sub_function3_button.setFont(QFont("宋体", 12)) # 保持默认字体,仅调整大小为17 85 | self.sub_function4_button.setFont(QFont("宋体", 12)) # 保持默认字体,仅调整大小为17 86 | self.function1_layout.addWidget(self.sub_function1_button) 87 | 
self.function1_layout.addWidget(self.sub_function2_button) 88 | self.function1_layout.addWidget(self.sub_function3_button) 89 | self.function1_layout.addWidget(self.sub_function4_button) 90 | 91 | # 增加回退按钮,返回主页面,并调整按钮尺寸 92 | self.back_to_main_button = QPushButton("Back to Main Menu", self) 93 | self.back_to_main_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) 94 | self.function1_layout.addWidget(self.back_to_main_button) 95 | self.back_to_main_button.clicked.connect(self.go_back_to_main) 96 | 97 | # 增加按钮尺寸 98 | self.sub_function1_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 99 | self.sub_function2_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 100 | self.sub_function3_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 101 | self.sub_function4_button.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 102 | 103 | # 连接子功能按钮点击事件 104 | self.sub_function1_button.clicked.connect(lambda: self.show_output_page("network_demo/full_table_statistics.py", "Full Table Statistics Output")) 105 | self.sub_function2_button.clicked.connect(lambda: self.show_output_page("network_demo/multicollinearity_test.py", "VIF Multicollinearity Test Output")) 106 | self.sub_function3_button.clicked.connect(lambda: self.show_output_page("network_demo/corr_coefficient_matrix.py", "Correlation Coefficient Matrix Output")) 107 | self.sub_function4_button.clicked.connect(self.show_data_visual_page) 108 | 109 | def show_function1(self): 110 | self.stacked_widget.setCurrentWidget(self.function1_page) 111 | 112 | def show_output_page(self, script_path, page_title, main_menu=False): 113 | # 创建输出页面 114 | output_page = QWidget() 115 | output_layout = QVBoxLayout(output_page) 116 | self.stacked_widget.addWidget(output_page) 117 | 118 | # 创建输出文本框 119 | output_text = QTextEdit(output_page) 120 | output_text.setReadOnly(True) 121 | output_layout.addWidget(output_text) 122 | 123 | keywords_to_filter = ["pid", "SPURuntime", "info"] 
124 | 125 | try: 126 | result = subprocess.run(["python", script_path], capture_output=True, text=True, check=True) 127 | filtered_output = filter_output(result.stdout, keywords_to_filter) 128 | output_text.append(f"{page_title}:\n\n{filtered_output}") 129 | except subprocess.CalledProcessError as e: 130 | filtered_error = filter_output(e.stderr, keywords_to_filter) 131 | output_text.append(f"An error occurred while running {script_path}:\n\n{filtered_error}") 132 | 133 | # 增加回退按钮,返回适当的菜单,并调整按钮尺寸 134 | back_button = QPushButton("Back to Main Menu" if main_menu else "Back to Function 1 Menu", self) 135 | back_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) 136 | output_layout.addWidget(back_button) 137 | back_button.clicked.connect(self.go_back_to_main if main_menu else self.go_back_to_function1) 138 | 139 | # 显示输出页面 140 | self.stacked_widget.setCurrentWidget(output_page) 141 | 142 | def show_data_visual_page(self): 143 | # 创建Data Visual交互页面 144 | data_visual_page = QWidget() 145 | data_visual_layout = QVBoxLayout(data_visual_page) 146 | self.stacked_widget.addWidget(data_visual_page) 147 | 148 | # 输入框提示和输入框 149 | input_label = QLabel("Enter a value (1-4):"+ "\n"+ 150 | """ | 1. alice 数据可视化\t\t| 151 | | 2. bob 数据可视化\t\t| 152 | | 3. carol 数据可视化\t\t| 153 | | 4. 
full_file_path 数据可视化\t| """, self) 154 | data_visual_layout.addWidget(input_label) 155 | 156 | input_edit = QLineEdit(self) 157 | data_visual_layout.addWidget(input_edit) 158 | 159 | # 结果展示框 160 | result_text = QTextEdit(self) 161 | result_text.setReadOnly(True) 162 | data_visual_layout.addWidget(result_text) 163 | 164 | # 运行按钮 165 | run_button = QPushButton("Run Data Visual", self) 166 | data_visual_layout.addWidget(run_button) 167 | 168 | # 运行逻辑 169 | def run_data_visual(): 170 | choice = input_edit.text() 171 | 172 | # 根据用户输入的值调整命令 173 | dv_command = ["python", "network_demo/data_visual.py", choice] 174 | keywords_to_filter = ["pid", "SPURuntime", "info"] 175 | try: 176 | result = subprocess.run(dv_command, capture_output=True, text=True, check=True) 177 | filtered_output = filter_output(result.stdout, keywords_to_filter) 178 | result_text.append(filtered_output) 179 | except subprocess.CalledProcessError as e: 180 | result_text.append(f"An error occurred:\n{e.stderr}") 181 | run_button.clicked.connect(run_data_visual) 182 | 183 | # 回退按钮 184 | back_button = QPushButton("Back to Function 1 Menu", self) 185 | back_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) 186 | data_visual_layout.addWidget(back_button) 187 | back_button.clicked.connect(self.go_back_to_function1) 188 | 189 | # 显示页面 190 | self.stacked_widget.setCurrentWidget(data_visual_page) 191 | 192 | def show_neural_network_page(self): 193 | # 创建运行神经网络功能的页面 194 | neural_network_page = QWidget() 195 | neural_network_layout = QVBoxLayout(neural_network_page) 196 | self.stacked_widget.addWidget(neural_network_page) 197 | 198 | # 输入框设置 199 | epochs_label = QLabel("Enter epochs:(默认参数需要输入2)", self) 200 | learning_rate_label = QLabel("Enter learning rate:(默认参数需要输入0.02)", self) 201 | batch_size_label = QLabel("Enter batch size:(默认参数需要输入100)", self) 202 | 203 | epochs_edit = QLineEdit(self) 204 | learning_rate_edit = QLineEdit(self) 205 | batch_size_edit = QLineEdit(self) 206 | 207 | 
neural_network_layout.addWidget(epochs_label) 208 | neural_network_layout.addWidget(epochs_edit) 209 | neural_network_layout.addWidget(learning_rate_label) 210 | neural_network_layout.addWidget(learning_rate_edit) 211 | neural_network_layout.addWidget(batch_size_label) 212 | neural_network_layout.addWidget(batch_size_edit) 213 | 214 | # 结果展示框 215 | nn_result_text = QTextEdit(self) 216 | nn_result_text.setReadOnly(True) 217 | neural_network_layout.addWidget(nn_result_text) 218 | 219 | # 运行按钮 220 | nn_run_button = QPushButton("Run Neural Network", self) 221 | neural_network_layout.addWidget(nn_run_button) 222 | keywords_to_filter = ["pid", "SPURuntime", "info"] 223 | # 运行逻辑 224 | def run_neural_network(): 225 | epochs = epochs_edit.text() 226 | learning_rate = learning_rate_edit.text() 227 | batch_size = batch_size_edit.text() 228 | 229 | nn_command = [ 230 | "python", "network_demo/mlp.py", 231 | "--epochs", epochs, 232 | "--learning_rate", learning_rate, 233 | "--batch_size", batch_size 234 | ] 235 | 236 | try: 237 | result = subprocess.run(nn_command, capture_output=True, text=True, check=True) 238 | filtered_output = filter_output(result.stdout, keywords_to_filter) 239 | nn_result_text.append(filtered_output) 240 | except subprocess.CalledProcessError as e: 241 | nn_result_text.append(f"An error occurred:\n{e.stderr}") 242 | nn_run_button.clicked.connect(run_neural_network) 243 | 244 | # 回退按钮 245 | back_button = QPushButton("Back to Main Menu", self) 246 | back_button.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) 247 | neural_network_layout.addWidget(back_button) 248 | back_button.clicked.connect(self.go_back_to_main) 249 | 250 | # 显示页面 251 | self.stacked_widget.setCurrentWidget(neural_network_page) 252 | 253 | def go_back_to_function1(self): 254 | self.stacked_widget.setCurrentWidget(self.function1_page) 255 | 256 | def go_back_to_main(self): 257 | self.stacked_widget.setCurrentWidget(self.main_page) 258 | 259 | if __name__ == "__main__": 260 | app = 
QApplication(sys.argv) 261 | main_window = MainWindow() 262 | main_window.show() 263 | sys.exit(app.exec_()) 264 | --------------------------------------------------------------------------------