├── .gitignore ├── LICENSE ├── README.md ├── data └── house │ ├── airbnb_clean.csv │ └── house_clean.csv ├── requirements.txt └── src ├── __init__.py ├── dataset ├── DataSampler.py ├── LocalDataset.py ├── VFLDataset.py ├── VFLRealDataset.py └── __init__.py ├── metric ├── RMSE.py └── __init__.py ├── model ├── FeT.py ├── PosEncoding.py ├── Solo.py ├── SplitNN.py ├── Transformer.py └── __init__.py ├── preprocess ├── ExactSpitter.py ├── FeatureEvaluator.py ├── FeatureSplitter.py ├── FuzzySplitter.py ├── __init__.py ├── gisette │ ├── __init__.py │ └── gisette_loader.py ├── hdb │ ├── __init__.py │ ├── clean_hdb.py │ ├── clean_school.py │ └── hdb_loader.py ├── house │ ├── __init__.py │ ├── beijing_loder.py │ ├── clean_airbnb.py │ └── clean_house.py ├── ml_dataset │ ├── __init__.py │ └── two_party_loader.py ├── nytaxi │ ├── __init__.py │ ├── clean_citibike.py │ ├── clean_tlc.py │ ├── filter_kaggle.py │ └── ny_loader.py ├── split-bias.sh └── vsplit.py ├── privacy ├── GaussianMechanism.py └── __init__.py ├── script ├── ablation_dm_or_not.sh ├── ablation_keynoise.sh ├── ablation_keynoise_baseline.sh ├── ablation_knnk.sh ├── ablation_knnk_real.sh ├── ablation_party_dropout.sh ├── ablation_pe_average_freq.sh ├── ablation_pe_or_not.sh ├── ablation_real_dm_or_not.sh ├── download_dataset.sh ├── run_real_fet.sh ├── run_scale.sh ├── split_scale.sh ├── train_fet.py └── train_solo.py ├── train ├── Evaluate.py ├── Fit.py └── __init__.py └── utils ├── BasicUtils.py ├── __init__.py └── logger.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | 7 | # Virtual Environment 8 | venv/ 9 | env/ 10 | .env 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyCharm 31 | .idea/ 32 | 33 | # VS Code 34 | .vscode/ 35 | 36 | # Jupyter Notebook 37 | .ipynb_checkpoints 38 | 39 | # Project specific 40 | data/syn/ 41 | cache/ 42 | log/ 43 | fig/ 44 | out/ 45 | 46 | # OS generated files 47 | .DS_Store 48 | .DS_Store? 49 | ._* 50 | .Spotlight-V100 51 | .Trashes 52 | ehthumbs.db 53 | Thumbs.db 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [NeurIPS 2024] Federated Transformer (FeT) 2 | 3 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | [![NeurIPS 2024](https://img.shields.io/badge/NeurIPS-2024-red.svg)](https://neurips.cc/Conferences/2024) 5 | [![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3100/) 6 | [![PyTorch 2.1.2](https://img.shields.io/badge/PyTorch-2.1.2-EE4C2C.svg)](https://pytorch.org/) 7 | [![CUDA 12.1](https://img.shields.io/badge/CUDA-12.1-76B900.svg)](https://developer.nvidia.com/cuda-toolkit) 8 | 9 | 10 | 11 | 12 | The paper _"[Federated Transformer: Multi-Party Vertical Federated Learning on Practical Fuzzily Linked Data](https://arxiv.org/pdf/2410.17986)"_ has been accepted at _**NeurIPS 2024**_. 13 | 14 | ## Project Overview 15 | 16 | The **Federated Transformer (FeT)** is a novel framework designed to handle **multi-party Vertical Federated Learning (VFL)** scenarios involving **fuzzy identifiers**, where distinct features of shared data instances are provided by different parties without directly sharing raw data. FeT addresses the challenges of performance degradation and privacy overhead commonly faced in existing multi-party VFL models. It innovatively **encodes fuzzy identifiers into data representations** and **distributes a transformer architecture across parties** to enhance collaborative learning. FeT integrates **differential privacy** and **secure multi-party computation** to ensure strong privacy protection while minimizing utility loss. Experimental results show that FeT significantly improves performance, boosting accuracy by up to 46% over [FedSim](https://github.com/Xtra-Computing/FedSim) when scaled to 50 parties, and outperforms [FedSim](https://github.com/Xtra-Computing/FedSim) even in **two-party fuzzy VFL** scenarios. 17 | 18 | ## Features 19 | - Multi-party vertical federated learning 20 | - Promising performance on fuzzy identifiers 21 | - SplitAvg framework: differential privacy and secure multi-party computation 22 | 23 | ## Prerequisites 24 | ### Hardware Requirements 25 | 26 | The Federated Transformer (FeT) framework is designed to operate efficiently without requiring high-memory GPUs. For small-scale settings, such as two-party fuzzy Vertical Federated Learning (VFL) on modest datasets, a single GPU with 4GB of memory is sufficient. For larger settings, particularly 50-party fuzzy VFL on large-scale datasets, we recommend A100 GPUs with at least 40GB of memory. Note that the current implementation does not support multi-GPU training. 27 | 28 | ### Software Dependencies 29 | 30 | The codebase has been developed and tested with Python `3.10` and CUDA `12.1`. While these specific versions are recommended, the framework is expected to be compatible with later versions of both Python and CUDA. 31 | 32 | ## Installation 33 | 1. Clone the repository: 34 | ```bash 35 | git clone https://github.com/JerryLife/FeT.git 36 | cd FeT 37 | ``` 38 | 2. Set up a virtual environment (recommended): 39 | ```bash 40 | python -m venv fet 41 | source fet/bin/activate # On Windows, use `fet\Scripts\activate` 42 | ``` 43 | 3.
Install the required dependencies: 44 | ```bash 45 | pip install -r requirements.txt 46 | ``` 47 | ## Dataset 48 | 49 | The real-world datasets used in FeT are the same as those in [FedSim](https://github.com/Xtra-Computing/FedSim). The synthetic datasets are generated by splitting the `gisette` and `mnist` datasets. They can be obtained by running: 50 | ```bash 51 | bash ./src/script/download_dataset.sh # download gisette and mnist datasets 52 | bash ./src/script/split_scale.sh # split them into multiple parties 53 | ``` 54 | 55 | ## Usage 56 | 57 | To train the Federated Transformer model, run the `train_fet.py` script located in the `src/script` directory. Below is the API documentation for `train_fet.py` along with example usage: 58 | 59 | ### API Documentation 60 | 61 | #### Arguments 62 | - `-g`, `--gpu` (int): GPU ID. Use `-1` for CPU. (default: `0`) 63 | - `-d`, `--dataset` (str): Dataset to use. 64 | - `-p`, `--n_parties` (int): Number of parties. Should be `>=2`. (default: `4`) 65 | - `-pp`, `--primary_party` (int): Primary party ID. Should be in `[0, n_parties-1]`. (default: `0`) 66 | - `-sp`, `--splitter` (str): Splitter method to use. (default: `'imp'`) 67 | - `-w`, `--weights` (float): Weights for the ImportanceSplitter. (default: `1`) 68 | - `-b`, `--beta` (float): Beta for the CorrelationSplitter. (default: `1`) 69 | - `-e`, `--epochs` (int): Number of training epochs. (default: `100`) 70 | - `-lr`, `--lr` (float): Learning rate. (default: `1e-3`) 71 | - `-wd`, `--weight_decay` (float): Weight decay for regularization. (default: `1e-5`) 72 | - `-bs`, `--batch_size` (int): Batch size. (default: `128`) 73 | - `-c`, `--n_classes` (int): Number of classes. `1` for regression, `2` for binary classification, `>=3` for multi-class classification. (default: `1`) 74 | - `-m`, `--metric` (str): Metric to evaluate the model. Supported metrics: `['accuracy', 'rmse']`. (default: `'acc'`) 75 | - `-rp`, `--result-path` (str): Path to save the result. (default: `None`) 76 | - `-s`, `--seed` (int): Random seed. (default: `0`) 77 | - `-ld`, `--log-dir` (str): Log directory. (default: `'log'`) 78 | - `-ded`, `--data-embed-dim` (int): Data embedding dimension. (default: `200`) 79 | - `-ked`, `--key-embed-dim` (int): Key embedding dimension. (default: `200`) 80 | - `-nh`, `--num-heads` (int): Number of heads in multi-head attention. (default: `4`) 81 | - `--dropout` (float): Dropout rate. (default: `0.0`) 82 | - `--party-dropout` (float): Dropout rate for dropping an entire party. (default: `0.0`) 83 | - `-nlb`, `--n-local-blocks` (int): Number of local blocks. (default: `6`) 84 | - `-nab`, `--n-agg-blocks` (int): Number of aggregation blocks. (default: `6`) 85 | - `--knn-k` (int): k for KNN. (default: `100`) 86 | - `--disable-pe` (bool): Disable positional encoding if set. (default: `False`) 87 | - `--disable-dm` (bool): Disable dynamic masking if set. (default: `False`) 88 | - `-paf`, `--pe-average-freq` (int): Average frequency for positional encoding on each party. (default: `0`) 89 | 90 | ### Example Usage 91 | 92 | To train the FeT model on the `house` dataset, run the following command: 93 | 94 | ```bash 95 | python src/script/train_fet.py -d house -m rmse -c 1 -p 2 -s 0 --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 3 -nab 3 -paf 1 --dropout 0.3 -g 0 96 | ``` 97 | 98 | ### Experimentation Scripts 99 | 100 | The scripts for the experiments reported in the paper are located in the `src/script` directory.
These scripts are designed to facilitate different experimental setups and can be customized as needed for your specific research requirements. The detailed usage of these scripts is as follows: 101 | 102 | - `download_dataset.sh`: Downloads the `gisette` and `mnist` datasets used in the experiments. 103 | - `split_scale.sh`: Splits the `gisette` and `mnist` datasets into multiple parties with different hyperparameters. 104 | - `run_real_fet.sh`: Runs the FeT model on three real-world datasets: `house`, `taxi`, and `hdb`. 105 | - `run_scale.sh`: Runs the FeT model on synthetic multi-party VFL datasets generated by splitting `gisette` and `mnist`. 106 | - `ablation*.sh`: This series of scripts runs the ablation studies on different components of the FeT model. 107 | - `ablation_dm_or_not.sh`: Compares the performance of FeT with and without dynamic masking. 108 | - `ablation_keynoise_baseline.sh`: Evaluates the impact of key noise on baseline models. 109 | - `ablation_keynoise.sh`: Tests the robustness of FeT against different levels of key noise. 110 | - `ablation_knnk_real.sh`: Studies the effect of different k values in KNN on real-world datasets. 111 | - `ablation_knnk.sh`: Examines the impact of varying k values in KNN on synthetic datasets. 112 | - `ablation_party_dropout.sh`: Evaluates the model's performance under different party dropout rates. 113 | - `ablation_pe_average_freq.sh`: Investigates the effect of different average frequencies in positional encoding. 114 | - `ablation_pe_or_not.sh`: Compares the performance of FeT with and without positional encoding. 115 | - `ablation_real_dm_or_not.sh`: Runs the dynamic masking ablation studies on real-world datasets.
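Beyond the provided scripts, small custom experiments can be composed directly from the `train_fet.py` arguments documented above. The snippet below is a minimal sketch, not one of the provided scripts: it reuses the flag values from the example command in the Usage section and sweeps the random seed, and the output directory and result filenames passed to `-rp` are illustrative choices rather than paths used by the repository.

```bash
# Minimal sketch: sweep random seeds on the two-party house dataset.
# All flags follow the train_fet.py API documentation above; the result
# paths given to -rp are illustrative.
mkdir -p out
for seed in 0 1 2 3 4; do
    python src/script/train_fet.py -d house -m rmse -c 1 -p 2 -s "$seed" \
        --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 3 -nab 3 -paf 1 \
        --dropout 0.3 -g 0 -rp "out/house_seed${seed}.txt"
done
```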
116 | 117 | 118 | ## Citation 119 | 120 | If you find this work useful in your research, please consider citing our paper: 121 | 122 | ```bibtex 123 | @inproceedings{wu2024fet, 124 | title={Federated Transformer: Multi-Party Vertical Federated Learning on Practical Fuzzily Linked Data}, 125 | author={Wu, Zhaomin and Hou, Junyi and Diao, Yiqun and He, Bingsheng}, 126 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 127 | year={2024} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools==5.3.2 2 | diskcache==5.6.3 3 | joblib==1.3.2 4 | matplotlib==3.8.2 5 | numpy==1.24.0 6 | nvitop==1.3.2 7 | opacus==1.4.0 8 | opt_einsum==3.3.0 9 | pandas==2.2.1 10 | pymoo==0.6.1.1 11 | pytz==2023.4 12 | Requests==2.31.0 13 | scikit_learn==1.4.0 14 | scipy==1.12.0 15 | shap==0.43.0 16 | torch_optimizer==0.3.0 17 | torchmetrics==1.3.0.post0 18 | torchsummary==1.5.1 19 | tqdm==4.66.1 20 | wget==3.2 21 | tensorboard==2.15.1 22 | torch==2.1.2 23 | torchinfo==1.8.0 24 | torchvision==0.16.2 25 | nmslib-metabrainz==2.1.2 26 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/__init__.py -------------------------------------------------------------------------------- /src/dataset/DataSampler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Sequence, Union 3 | from copy import deepcopy 4 | from collections import defaultdict 5 | from functools import wraps 6 | 7 | from cachetools import cached, LRUCache 8 | from cachetools.keys import hashkey 9 | import numpy as np 10 | import nmslib 11 | 12 | 13 | class DataSampler(abc.ABC): 14 | return_multi = False 15 | def __init__(self, keys: Sequence, primary_party_id: int = 0, seed=None, **kwargs): 16 | """ 17 | Sample the indices from a sequence of dataset according to keys 18 | :param keys: keys of each dataset 19 | :param primary_party_id: primary party id 20 | """ 21 | self.keys = deepcopy(keys) 22 | self.primary_party_id = primary_party_id 23 | self.seed = seed 24 | np.random.seed(seed) 25 | self.n_datasets = len(keys) 26 | 27 | def initialize(self, keys, **kwargs): 28 | pass 29 | 30 | @abc.abstractmethod 31 | def sample(self, p_id): 32 | """ 33 | Sample one or multiple secondary indices from one primary ID 34 | """ 35 | raise NotImplementedError 36 | 37 | 38 | class SimSampler(DataSampler, abc.ABC): 39 | def __init__(self, keys: Sequence, primary_party_id: int = 0, seed=None, indices=None, keep_primary=False, 40 | **kwargs): 41 | """ 42 | Sample the indices from a sequence of dataset according to keys 43 | :param keys: keys of each party 44 | :param primary_party_id: primary party id 45 | :param seed: random see 46 | :param indices: If None, initialize the indices from keys. Otherwise, use the given indices. 
47 | :param kwargs: other parameters for self.initialize 48 | """ 49 | super().__init__(keys, primary_party_id, seed, **kwargs) 50 | self.indices = None # indices for querying 51 | self.keep_primary = keep_primary 52 | if indices is None: 53 | self.initialize(keys, **kwargs) 54 | else: 55 | self.indices = indices 56 | 57 | def initialize(self, keys, method='hnsw', space='l2'): 58 | self.indices = [] 59 | for i, key in enumerate(keys): 60 | if (not self.keep_primary) and i == self.primary_party_id: 61 | self.indices.append(None) # placeholder for primary party 62 | continue 63 | 64 | index = nmslib.init(method=method, space=space, data_type=nmslib.DataType.DENSE_VECTOR) 65 | index.addDataPointBatch(key) 66 | index.createIndex() 67 | self.indices.append(index) 68 | 69 | 70 | class Top1Sampler(SimSampler): 71 | def __init__(self, keys: Sequence, primary_party_id: int = 0, seed=None, indices=None, **kwargs): 72 | super().__init__(keys, primary_party_id, seed, indices, **kwargs) 73 | 74 | def sample(self, idx): 75 | """ 76 | Sample most similar secondary indices from one primary ID (pid) 77 | :param idx: [Int] sample ID on the primary party 78 | :return: List[Int] sampled indices for all parties. The i-th element is the sampled index on the i-th party. 79 | The total length should be `self.n_datasets`. 80 | """ 81 | 82 | primary_key = self.keys[self.primary_party_id][idx] 83 | 84 | sampled_indices = [] 85 | for data_id in range(self.n_datasets): 86 | if data_id == self.primary_party_id: 87 | sampled_indices.append(idx) 88 | else: 89 | nbr_ids, dist = self.indices[data_id].knnQuery(primary_key, k=1) 90 | sampled_indices.append(nbr_ids[0]) 91 | 92 | return sampled_indices 93 | 94 | 95 | class TopkUniformSampler(SimSampler): 96 | def __init__(self, keys: Sequence, ks: Union[Sequence, int], primary_party_id: int = 0, seed=None, **kwargs): 97 | super().__init__(keys, primary_party_id, seed, **kwargs) 98 | if isinstance(ks, int): 99 | self.ks = [ks] * self.n_datasets 100 | else: 101 | self.ks = ks 102 | 103 | if len(self.ks) != self.n_datasets: 104 | raise ValueError(f"The length of ks {len(self.ks)} should be the same as the number of parties " 105 | f"{self.n_datasets}") 106 | 107 | def sample(self, idx): 108 | """ 109 | Uniformly sample secondary indices among topk-most-similar ones from one primary ID (pid) 110 | :param idx: [Int] sample ID on the primary party 111 | :return: List[Int] sampled indices for all parties. The i-th element is the sampled index on the i-th party. 112 | The total length should be `self.n_datasets`. 
113 | """ 114 | primary_key = self.keys[self.primary_party_id][idx] 115 | 116 | sampled_indices = [] 117 | for data_id in range(self.n_datasets): 118 | if data_id == self.primary_party_id: 119 | sampled_indices.append(idx) 120 | else: 121 | nbr_ids, dist = self.indices[data_id].knnQuery(primary_key, k=self.ks[data_id]) 122 | sampled_indices.append(np.random.choice(nbr_ids)) 123 | 124 | return sampled_indices 125 | 126 | 127 | class TopkSimAsProbSampler(SimSampler): 128 | def __init__(self, keys: Sequence, ks: Union[Sequence, int], primary_party_id: int = 0, seed=None, **kwargs): 129 | super().__init__(keys, primary_party_id, seed, **kwargs) 130 | if isinstance(ks, int): 131 | self.ks = [ks] * self.n_datasets 132 | else: 133 | self.ks = ks 134 | 135 | if len(self.ks) != self.n_datasets: 136 | raise ValueError(f"The length of ks {len(self.ks)} should be the same as the number of parties " 137 | f"{self.n_datasets}") 138 | 139 | def sample(self, idx): 140 | """ 141 | Sample secondary indices among topk-most-similar ones from one primary ID (pid). The probability of sampling 142 | is proportional to the exponential negative distance. 143 | :param idx: [Int] sample ID on the primary party 144 | :return: List[Int] sampled indices for all parties. The i-th element is the sampled index on the i-th party. 145 | The total length should be `self.n_datasets`. 146 | """ 147 | primary_key = self.keys[self.primary_party_id][idx] 148 | 149 | sampled_indices = [] 150 | for data_id in range(self.n_datasets): 151 | if data_id == self.primary_party_id: 152 | sampled_indices.append(idx) 153 | else: 154 | nbr_ids, dist = self.indices[data_id].knnQuery(primary_key, k=self.ks[data_id]) 155 | # make dist as probability 156 | scaled_dist = np.exp(-dist) / np.sum(np.exp(-dist)) 157 | sampled_indices.append(np.random.choice(nbr_ids, p=scaled_dist)) 158 | 159 | return sampled_indices 160 | 161 | 162 | def conditional_cached(cache_key, cache=LRUCache(maxsize=10**7)): 163 | def decorator(func): 164 | if cache_key is not None: 165 | # Apply caching if cache_key is not None 166 | return cached(cache=cache, key=lambda self, idx: hashkey((cache_key, idx)))(func) 167 | else: 168 | # Return the original function unmodified if cache_key is None 169 | @wraps(func) 170 | def wrapper(*args, **kwargs): 171 | return func(*args, **kwargs) 172 | return wrapper 173 | return decorator 174 | 175 | 176 | class TopkSampler(SimSampler): 177 | return_multi = True # return multiple indices for secondary parties 178 | def __init__(self, keys: Sequence, ks: Union[Sequence, int], primary_party_id: int = 0, seed=None, indices=None, 179 | multi_primary=False, sample_rate_before_topk=None, cache_key=None, **kwargs): 180 | super().__init__(keys, primary_party_id, seed, method='brute_force', space='l2', indices=indices, 181 | keep_primary=multi_primary, **kwargs) 182 | self.multi_primary = multi_primary 183 | self.sample_rate_before_topk = sample_rate_before_topk 184 | if isinstance(ks, int): 185 | self.ks = [ks] * self.n_datasets 186 | else: 187 | self.ks = ks 188 | 189 | if len(self.ks) != self.n_datasets: 190 | raise ValueError(f"The length of ks {len(self.ks)} should be the same as the number of parties " 191 | f"{self.n_datasets}") 192 | self.cache_key = cache_key 193 | 194 | # @conditional_cached(cache_key=lambda self: self.cache_key) 195 | def sample(self, idx): 196 | """ 197 | Sample secondary indices among topk-most-similar ones from one primary ID (pid). 
The probability of sampling 198 | is proportional to the exponential negative distance. 199 | :param idx: [Int] sample ID on the primary party 200 | :param multi_primary: [Bool] whether to return multiple indices for primary party 201 | :return: List[List[Int]] sampled indices for all parties. The i-th element is the sampled index set on the i-th 202 | The total length should be `self.n_datasets`. The length of the i-th element is `self.ks[i]`. 203 | """ 204 | 205 | primary_key = self.keys[self.primary_party_id][idx] 206 | 207 | sampled_indices = [] 208 | for data_id in range(self.n_datasets): 209 | if (not self.multi_primary) and data_id == self.primary_party_id: 210 | sampled_indices.append(idx) 211 | continue 212 | 213 | if self.sample_rate_before_topk: 214 | index = nmslib.init(method='brute_force', space='l2', data_type=nmslib.DataType.DENSE_VECTOR) 215 | sample_size = int(self.sample_rate_before_topk * len(self.keys[data_id])) 216 | sample_idx = np.random.choice(len(self.keys[data_id]), sample_size, replace=False) 217 | sample_key = self.keys[data_id][sample_idx] 218 | index.addDataPointBatch(sample_key) 219 | index.createIndex() 220 | nbr_ids_i, dist = index.knnQuery(primary_key, k=self.ks[data_id]) 221 | nbr_ids = sample_idx[nbr_ids_i] 222 | if len(nbr_ids) < self.ks[data_id]: 223 | # repeat nbr_ids if not enough 224 | nbr_ids = np.concatenate([nbr_ids, np.random.choice(nbr_ids, self.ks[data_id] - len(nbr_ids))]) 225 | else: 226 | nbr_ids, dist = self.indices[data_id].knnQuery(primary_key, k=self.ks[data_id]) 227 | 228 | sampled_indices.append(nbr_ids) 229 | 230 | return sampled_indices 231 | 232 | 233 | class RandomSampler: 234 | return_multi = True # return multiple indices for secondary parties 235 | def __init__(self, sizes, n_samples, primary_party_id=0, seed=None): 236 | self.sizes = sizes 237 | self.n_samples = n_samples 238 | self.primary_party_id = primary_party_id 239 | self.seed = seed 240 | np.random.seed(seed) 241 | 242 | def sample(self, idx): 243 | """ 244 | Sample one or multiple secondary indices from one primary ID 245 | :param p_id: primary party ID 246 | :return: List[List[Int]] sampled indices for all parties. The i-th element is the sampled index set on the i-th 247 | The total length should be `self.n_datasets`. The length of the i-th element is `self.ks[i]`. 
248 | """ 249 | sampled_indices = [] 250 | for data_id in range(len(self.sizes)): 251 | if data_id == self.primary_party_id: 252 | sampled_indices.append(idx) 253 | continue 254 | sampled_indices.append(np.random.choice(self.sizes[data_id], self.n_samples, replace=False)) 255 | 256 | return sampled_indices 257 | -------------------------------------------------------------------------------- /src/dataset/LocalDataset.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Protocol 3 | import pickle 4 | 5 | import numpy as np 6 | import pandas 7 | import pandas as pd 8 | from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler 9 | from sklearn.model_selection import train_test_split 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class LocalDataset(Dataset): 16 | """ 17 | Base class for local datasets 18 | """ 19 | 20 | def __init__(self, X, y=None, key=None, **kwargs): 21 | """ 22 | Required parameters: 23 | :param X: features (array) 24 | 25 | Optional parameters: 26 | :param key: key of the ID (array) 27 | :param y: labels (1d array) 28 | """ 29 | self.key = key 30 | if isinstance(X, np.ndarray): 31 | self.X = X.astype(np.float32) 32 | elif isinstance(X, torch.Tensor): 33 | self.X = X.float() 34 | else: 35 | raise TypeError(f"X should be either np.ndarray or torch.Tensor, but got {type(X)}") 36 | 37 | if y is None: 38 | self.y = None 39 | elif isinstance(y, np.ndarray): 40 | self.y = y.astype(np.float32) 41 | elif isinstance(y, torch.Tensor): 42 | self.y = y.float() if y is not None else None 43 | else: 44 | raise TypeError("y should be either np.ndarray or torch.Tensor") 45 | 46 | self.check_shape() 47 | 48 | # if key is not provided, set key as nan to avoid collate_fn error 49 | if self.key is None: 50 | self.key = np.arange(self.X.shape[0]).reshape(-1, 1) 51 | # self.key = np.full((self.X.shape[0], 1), np.nan) 52 | 53 | def __add__(self, other): 54 | """ 55 | Concatenate two LocalDataset 56 | """ 57 | if not isinstance(other, LocalDataset): 58 | raise TypeError(f"other should be LocalDataset, but got {type(other)}") 59 | 60 | if self.X.shape[1:] != other.X.shape[1:]: 61 | raise ValueError(f"self.X.shape[1:] != other.X.shape[1:]") 62 | 63 | X = np.concatenate([self.X, other.X], axis=0) 64 | if self.y is None and other.y is None: 65 | y = None 66 | else: 67 | if self.y is None and other.y is not None: 68 | raise ValueError(f"self.y is None, but other.y is not None") 69 | if self.y is not None and other.y is None: 70 | raise ValueError(f"self.y is not None, but other.y is None") 71 | if self.y is not None and other.y is not None: 72 | if self.y.shape[1:] != other.y.shape[1:]: 73 | raise ValueError(f"self.y.shape[1:] != other.y.shape[1:]") 74 | y = np.concatenate([self.y, other.y], axis=0) if self.y is not None else None 75 | if self.key is None and other.key is None: 76 | key = None 77 | else: 78 | if self.key.shape[1:] != other.key.shape[1:]: 79 | raise ValueError(f"self.key.shape[1:] != other.key.shape[1:]") 80 | key = np.concatenate([self.key, other.key], axis=0) 81 | return LocalDataset(X, y, key) 82 | 83 | @torch.no_grad() 84 | def check_shape(self): 85 | if self.y is not None: 86 | assert self.X.shape[0] == self.y.shape[0], "The number of samples in X and y should be the same" 87 | if self.key is not None: 88 | assert self.X.shape[0] == self.key.shape[0], "The number of samples in X and key should be the same" 89 | 90 | def __len__(self): 91 | return 
self.X.shape[0] 92 | 93 | def __getitem__(self, idx): 94 | """ 95 | :param idx: the index of the item 96 | :return: key[idx], X[idx], y[idx] 97 | """ 98 | X = self.X[idx] 99 | key = self.key[idx] if self.key is not None else None 100 | y = self.y[idx] if self.y is not None else None 101 | return (key, X), y 102 | 103 | @property 104 | def data(self): 105 | return self.key, self.X, self.y 106 | 107 | @property 108 | def key_X_dim(self): 109 | if self.key is None: 110 | return self.X.shape[1] 111 | else: 112 | return self.X.shape[1] + self.key.shape[1] 113 | 114 | @classmethod 115 | def from_csv(cls, csv_path, header=None, key_cols=1, **kwargs): 116 | """ 117 | Load dataset from csv file. The key_cols columns are keys, the last column is the label, and the rest 118 | columns are features. 119 | :param csv_path: path to csv file 120 | :param header: row number(s) to use as the column names, and the start of the data. 121 | Same as the header in pandas.read_csv() 122 | :param key_cols: Int. Number of key columns. | key1 | key2 | key.. | keyN | X1 | X2 | X3.. | Xn | y | 123 | """ 124 | df = pd.read_csv(csv_path, header=header) 125 | if key_cols is None: 126 | key = None 127 | X = df.iloc[:, :-1].values 128 | else: 129 | assert df.shape[1] > key_cols + 1, "The number of columns should be larger than key_cols + 1" 130 | key = df.iloc[:, :key_cols].values 131 | X = df.iloc[:, key_cols:-1].values 132 | y = df.iloc[:, -1].values 133 | return cls(X, y, key, **kwargs) 134 | 135 | @classmethod 136 | def from_pickle(cls, pickle_path): 137 | with open(pickle_path, 'rb') as f: 138 | return pickle.load(f) 139 | 140 | def to_pickle(self, path): 141 | with open(path, 'wb') as f: 142 | pickle.dump(self, f) 143 | 144 | def to_csv(self, path, type='raw'): 145 | # flatten >=2 dimensional X (e.g. image) to 1 dimensional 146 | if len(self.X.shape) > 2: 147 | X = self.X.reshape(self.X.shape[0], -1) 148 | else: 149 | X = self.X 150 | 151 | assert type in ['raw', 'fedtree', 'fedtrans'], "type should be in ['raw', 'fedtree', 'fedtrans']" 152 | if type == 'raw': 153 | df = pd.DataFrame(np.concatenate([X, self.y.reshape(-1, 1)], axis=1)) 154 | df.to_csv(path, header=False, index=False) 155 | if type == 'fedtrans': 156 | y = self.y 157 | if y is None: 158 | # create dummy y 159 | y = np.array([None for i in range(X.shape[0])]) 160 | df = pd.DataFrame(np.concatenate([self.key, X, y.reshape(-1, 1)], axis=1)) 161 | for i in range(self.key.shape[1]): 162 | df.rename(columns={i: f'key{i}'}, inplace=True) 163 | for i in range(X.shape[1]): 164 | df.rename(columns={i + self.key.shape[1]: f'x{i}'}, inplace=True) 165 | df.rename(columns={ df.shape[1] - 1: 'y'}, inplace=True) 166 | df.to_csv(path, header=True, index=False) # You have to have a header to be able to read it back in. otherwise we don't know if there is a y column or not 167 | elif type == 'fedtree': 168 | if self.key is None: 169 | raise ValueError("key is None. 
FedTree requires key column.") 170 | if len(self.key.shape) != 1 and self.key.shape[1] != 1: 171 | raise ValueError("FedTree does not support multi-dimensional key.") 172 | if self.y is None: 173 | columns = ['id'] + [f'x{i}' for i in range(X.shape[1])] 174 | df = pd.DataFrame(np.concatenate([self.key.reshape(-1, 1), X], axis=1), columns=columns) 175 | else: 176 | columns = ['id', 'y'] + [f'x{i}' for i in range(X.shape[1])] 177 | df = pd.DataFrame(np.concatenate([self.key.reshape(-1, 1), self.y.reshape(-1, 1), X], axis=1), 178 | columns=columns) 179 | df.to_csv(path, index=False) 180 | else: 181 | raise NotImplementedError(f"CSV type {type} is not implemented.") 182 | 183 | def to_tensor_(self): 184 | """ 185 | Convert X, y, key to torch.Tensor 186 | """ 187 | if isinstance(self.X, np.ndarray): 188 | self.X = torch.from_numpy(self.X).float() 189 | if isinstance(self.y, np.ndarray): 190 | self.y = torch.from_numpy(self.y).float() 191 | if isinstance(self.key, np.ndarray): 192 | self.key = torch.from_numpy(self.key).float() 193 | 194 | def scale_y_(self, lower=0, upper=1, scaler=None): 195 | """ 196 | Scale the label to [lower, upper] 197 | """ 198 | if self.y is None: 199 | return None 200 | 201 | if scaler is None: 202 | scaler = MinMaxScaler(feature_range=(lower, upper)) 203 | self.y = scaler.fit_transform(self.y.reshape(-1, 1)).reshape(-1) 204 | return scaler 205 | else: 206 | self.y = scaler.transform(self.y.reshape(-1, 1)).reshape(-1) 207 | return None 208 | 209 | def normalize_(self, scaler=None, include_key=False): 210 | """ 211 | Normalize the features 212 | """ 213 | if scaler is None: 214 | scaler = StandardScaler() 215 | if include_key: 216 | key_X = np.concatenate([self.key, self.X], axis=1) 217 | key_X = scaler.fit_transform(key_X) 218 | self.key = key_X[:, :self.key.shape[1]] 219 | self.X = key_X[:, self.key.shape[1]:] 220 | else: 221 | self.X = scaler.fit_transform(self.X) 222 | return scaler 223 | else: 224 | if include_key: 225 | key_X = np.concatenate([self.key, self.X], axis=1) 226 | key_X = scaler.transform(key_X) 227 | self.key = key_X[:, :self.key.shape[1]] 228 | self.X = key_X[:, self.key.shape[1]:] 229 | else: 230 | self.X = scaler.transform(self.X) 231 | return None 232 | 233 | def split_train_test(self, val_ratio=0.1, test_ratio=0.2, random_state=None, shuffle=False): 234 | """ 235 | Split the dataset into train and test set. 236 | :param val_ratio: ratio of validation set, if None, no validation set will be generated 237 | :param test_ratio: ratio of test set, if None, no test set will be generated 238 | :param random_state: random state, by default None 239 | :param hard_train_test_split: the split point for hard train-test split, 240 | e.g., cifar10 has 50K train data and 10K test data, we set it as 50K, and the input dataset should be the concatenation of [train_data, test_data]. 241 | Default 0 means no hard train-test split. 
242 | :return: three LocalDataset, train, val, test 243 | """ 244 | key, X, y = self.data 245 | 246 | if y is None: 247 | raise ValueError(f"y should not be None") 248 | 249 | def train_test_split_ignore_none(X, y, key, test_size, random_state, shuffle): 250 | if key is None: 251 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, 252 | random_state=random_state, shuffle=shuffle) 253 | return X_train, X_test, y_train, y_test, None, None 254 | else: 255 | X_train, X_test, y_train, y_test, key_train, key_test = train_test_split(X, y, key, test_size=test_size, 256 | random_state=random_state, shuffle=shuffle) 257 | return X_train, X_test, y_train, y_test, key_train, key_test 258 | 259 | match val_ratio, test_ratio: 260 | case (None, None): 261 | raise ValueError("val_ratio and test_ratio cannot be both None") 262 | case (None, _): 263 | X_train, X_test, y_train, y_test, key_train, key_test = train_test_split_ignore_none(X, y, key, 264 | test_size=test_ratio, 265 | random_state=random_state, 266 | shuffle=shuffle) 267 | return [LocalDataset(X_train, y_train, key_train), 268 | None, 269 | LocalDataset(X_test, y_test, key_test)] 270 | case (_, None): 271 | X_train, X_val, y_train, y_val, key_train, key_val = train_test_split_ignore_none(X, y, key, 272 | test_size=val_ratio, 273 | random_state=random_state, 274 | shuffle=shuffle) 275 | return [LocalDataset(X_train, y_train, key_train), 276 | LocalDataset(X_val, y_val, key_val), 277 | None] 278 | case (_, _): 279 | X_train_val, X_test, y_train_val, y_test, key_train_val, key_test = ( 280 | train_test_split_ignore_none(X, y, key, test_size=test_ratio, random_state=random_state, 281 | shuffle=shuffle)) 282 | 283 | X_train, X_val, y_train, y_val, key_train, key_val = ( 284 | train_test_split_ignore_none(X_train_val, y_train_val, key_train_val, test_size=val_ratio / (1 - test_ratio), 285 | random_state=random_state, shuffle=shuffle)) 286 | 287 | return [LocalDataset(X_train, y_train, key_train), 288 | LocalDataset(X_val, y_val, key_val), 289 | LocalDataset(X_test, y_test, key_test)] 290 | 291 | 292 | -------------------------------------------------------------------------------- /src/dataset/VFLDataset.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os.path 3 | import sys 4 | 5 | import numpy as np 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 11 | 12 | from dataset.LocalDataset import LocalDataset 13 | from src.utils.BasicUtils import PartyPath 14 | from preprocess.FeatureEvaluator import CorrelationEvaluator 15 | 16 | 17 | class VFLDataset: 18 | """ 19 | Base class for vertical federated learning (VFL) datasets. The __len__ and __getitem__ methods are 20 | undefined because the length of local datasets and the way to get items might be different. They should be 21 | defined in the subclass. 
22 | """ 23 | 24 | def __init__(self, num_parties: int, local_datasets, primary_party_id: int = 0): 25 | """ 26 | :param num_parties: number of parties 27 | :param local_datasets: local datasets of each party (list of LocalDataset or None) 28 | :param primary_party_id: primary party (the party with labels) id, should be in range of [0, num_parties) 29 | """ 30 | 31 | self.num_parties = num_parties 32 | self.local_datasets = local_datasets 33 | self.primary_party_id = primary_party_id 34 | 35 | self.check_param() 36 | 37 | def check_param(self): 38 | """ 39 | Check if the data is valid 40 | """ 41 | if self.local_datasets is not None: 42 | assert len(self.local_datasets) == self.num_parties, \ 43 | f"The number of parties {self.num_parties} should be the same as the number of local datasets {len(self.local_datasets)}" 44 | 45 | assert 0 <= self.primary_party_id < self.num_parties, "primary_party_id should be in range of [0, num_parties)" 46 | 47 | 48 | class VFLRawDataset(VFLDataset, abc.ABC): 49 | """ 50 | Linkable VFL dataset where the keys of local datasets are the same. It does not define the __len__ and 51 | __getitem__ methods, thus it cannot be directly used in torch.utils.data.DataLoader to train models. 52 | """ 53 | def __init__(self, num_parties: int, local_datasets: list, primary_party_id: int = 0): 54 | """ 55 | :param num_parties: number of parties 56 | :param local_datasets: (list) local datasets of each party, each element is a LocalDataset 57 | :param primary_party_id: primary party (the party with labels) id, should be in range of [0, num_parties) 58 | """ 59 | super().__init__(num_parties, local_datasets, primary_party_id) 60 | self.check_key() 61 | 62 | def check_key(self): 63 | """ 64 | Check if the keys of local datasets are the same 65 | """ 66 | key = self.local_datasets[0].key 67 | for local_dataset in self.local_datasets: 68 | assert key.shape[1] == local_dataset.key.shape[1], "The number of columns of keys should be the same" 69 | 70 | @abc.abstractmethod 71 | def link(self, *args, **kwargs): 72 | """ 73 | Link the local datasets 74 | :return: VFLAlignedDataset, a linked VFL dataset 75 | """ 76 | pass 77 | 78 | 79 | class VFLAlignedDataset(VFLDataset, Dataset): 80 | """ 81 | Trainable VFL dataset where the number of samples in local datasets is the same. It defines the __len__ and 82 | __getitem__ methods, thus it can be directly used in torch.utils.data.DataLoader to train models. 83 | """ 84 | def __init__(self, num_parties: int, local_datasets, primary_party_id: int = 0): 85 | """ 86 | :param num_parties: number of parties 87 | :param local_datasets: (ndarray) local datasets of each party, each element is a LocalDataset. 
Note that 88 | this CANNOT be changed to a list because a PyTorch multiprocessing issue (see 89 | https://github.com/pytorch/pytorch/issues/13246) 90 | :param primary_party_id: primary party (the party with labels) id, should be in range of [0, num_parties) 91 | """ 92 | super().__init__(num_parties, local_datasets, primary_party_id) 93 | self.local_datasets = np.array([None for _ in range(num_parties)]) 94 | self.local_datasets[:] = local_datasets 95 | 96 | def check_shape(self): 97 | """ 98 | Check if the shape of local datasets is aligned 99 | """ 100 | assert self.local_datasets[self.primary_party_id].y is not None, f"The primary party {self.primary_party_id} does not have labels" 101 | for local_dataset in self.local_datasets: 102 | assert len(local_dataset) == len(self.local_datasets[self.primary_party_id]), \ 103 | "The number of samples in local datasets should be the same" 104 | 105 | def __len__(self): 106 | return len(self.local_datasets[self.primary_party_id]) 107 | 108 | def __getitem__(self, idx): 109 | """ 110 | Invoke __getitem__ of each local dataset to get the item 111 | :param idx: the index of the item 112 | :return: a tuple of tensors. The last tensor is a tensor of y. The rest tensors are tensors of X. 113 | """ 114 | Xs = [] 115 | for local_dataset in self.local_datasets: 116 | _, X, _ = local_dataset[idx] # key is omitted because it is not used in training 117 | Xs.append(X) 118 | _, _, y = self.local_datasets[self.primary_party_id][idx] 119 | 120 | return Xs, y 121 | 122 | @property 123 | def local_input_channels(self): 124 | return [local_dataset.X.shape[1] for local_dataset in self.local_datasets] 125 | 126 | @property 127 | def local_key_channels(self): 128 | return [local_dataset.key.shape[1] for local_dataset in self.local_datasets] 129 | 130 | @property 131 | def local_key_X_channels(self): 132 | return [local_dataset.key.shape[1] + local_dataset.X.shape[1] for local_dataset in self.local_datasets] 133 | 134 | def scale_y_(self, lower=0, upper=1, scaler=None): 135 | """ 136 | Scale the labels to [lower, upper] 137 | :param lower: lower bound 138 | :param upper: upper bound 139 | :param scaler: scaler to use. If None, use the scaler of the primary party. 140 | """ 141 | return self.local_datasets[self.primary_party_id].scale_y_(lower=lower, upper=upper, scaler=scaler) 142 | 143 | 144 | class VFLSynAlignedDataset(VFLAlignedDataset): 145 | def __init__(self, num_parties: int, local_datasets, primary_party_id: int = 0): 146 | super().__init__(num_parties, local_datasets, primary_party_id) 147 | 148 | @classmethod 149 | def from_pickle(cls, dir: str, dataset: str, n_parties, primary_party_id: int = 0, 150 | splitter: str = 'imp', weight: float = 1, beta: float = 1, seed: int = 0, type='train'): 151 | """ 152 | Load a VFLAlignedDataset from pickle file. The pickle files are local datasets of each party. 153 | 154 | Parameters 155 | ---------- 156 | dir : str 157 | The directory of pickle files. 158 | dataset : str 159 | The name of the dataset. 160 | n_parties : int 161 | The number of parties. 
162 | primary_party_id : int, optional 163 | The primary party id, by default 0 164 | splitter : str, optional 165 | The splitter used to split the dataset, by default 'imp' 166 | weight : float, optional 167 | The weight of the primary party, by default 1 168 | beta : float, optional 169 | The beta of the primary party, by default 1 170 | seed : int, optional 171 | The seed of the primary party, by default 0 172 | type : str, optional 173 | The type of the dataset, by default 'train'. It should be ['train', 'test']. 174 | """ 175 | # assert type in ['train', 'test', 'both'], "type should be 'train', 'test', or 'both'" 176 | if type is None or type in ['train', 'test']: 177 | local_datasets = [] 178 | for party_id in range(n_parties): 179 | path_in_dir = PartyPath(dataset_path=dataset, n_parties=n_parties, party_id=party_id, 180 | splitter=splitter, weight=weight, beta=beta, seed=seed, fmt='pkl').data(type) 181 | path = os.path.join(dir, path_in_dir) 182 | if not os.path.exists(path): 183 | raise FileNotFoundError(f"File {path} does not exist") 184 | local_dataset = LocalDataset.from_pickle(path) 185 | if party_id != primary_party_id: # remove y of secondary parties 186 | local_dataset.y = None 187 | local_datasets.append(local_dataset) 188 | return cls(n_parties, local_datasets, primary_party_id) 189 | 190 | if type == 'both': 191 | # load train and test datasets and merge them 192 | local_datasets = [] 193 | for party_id in range(n_parties): 194 | train_path_in_dir = PartyPath(dataset_path=dataset, n_parties=n_parties, party_id=party_id, 195 | splitter=splitter, weight=weight, beta=beta, seed=seed, fmt='pkl').data('train') 196 | train_path = os.path.join(dir, train_path_in_dir) 197 | if not os.path.exists(train_path): 198 | raise FileNotFoundError(f"File {train_path} does not exist") 199 | train_local_dataset = LocalDataset.from_pickle(train_path) 200 | 201 | test_path_in_dir = PartyPath(dataset_path=dataset, n_parties=n_parties, party_id=party_id, 202 | splitter=splitter, weight=weight, beta=beta, seed=seed, fmt='pkl').data('test') 203 | test_path = os.path.join(dir, test_path_in_dir) 204 | if not os.path.exists(test_path): 205 | raise FileNotFoundError(f"File {test_path} does not exist") 206 | test_local_dataset = LocalDataset.from_pickle(test_path) 207 | 208 | local_dataset = train_local_dataset + test_local_dataset 209 | local_datasets.append(local_dataset) 210 | return cls(n_parties, local_datasets, primary_party_id) 211 | 212 | def visualize_corr(self, corr_func='spearmanr', gpu_id=None, output_score=True): 213 | """ 214 | Visualize the correlation of the dataset in heatmap. 
215 | """ 216 | evaluator = CorrelationEvaluator(corr_func=corr_func, gpu_id=gpu_id) 217 | Xs = [local_dataset.X for local_dataset in self.local_datasets] 218 | if output_score: 219 | score = evaluator.fit_evaluate(Xs) 220 | else: 221 | score = evaluator.fit(Xs) # score is None 222 | evaluator.visualize(value=score) 223 | -------------------------------------------------------------------------------- /src/dataset/VFLRealDataset.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os.path 3 | import sys 4 | from typing import Sequence, List, Tuple, Optional 5 | from copy import deepcopy 6 | import pickle 7 | import multiprocessing as mp 8 | import ctypes 9 | 10 | import numpy as np 11 | from sklearn.model_selection import train_test_split 12 | 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 17 | 18 | from src.dataset.LocalDataset import LocalDataset 19 | from src.dataset.VFLDataset import VFLSynAlignedDataset 20 | from src.utils.BasicUtils import PartyPath 21 | from src.dataset.DataSampler import Top1Sampler, TopkUniformSampler, TopkSimAsProbSampler, TopkSampler, RandomSampler 22 | 23 | 24 | class VFLRealDataset(Dataset): 25 | def __init__(self, local_datasets=None, primary_party_id=0, key_cols=None, multi_primary=False, 26 | primary_train_local_dataset=None, ks=100, sample_rate_before_topk=None, cache_key=None, 27 | use_cache=False, primary_train_indices=None, **kwargs): 28 | self.key_cols = key_cols 29 | self.sample_rate_before_topk = sample_rate_before_topk 30 | if local_datasets is None or isinstance(local_datasets[0], LocalDataset): 31 | self.local_datasets = local_datasets 32 | elif isinstance(local_datasets[0], Sequence): 33 | # (List[numpy.ndarray], ndarray) => ([X1, X2, ...], y) 34 | Xs, y = local_datasets 35 | self.local_datasets = [] 36 | for i, X in enumerate(Xs): 37 | if key_cols is not None: 38 | key = X[:, :key_cols] 39 | X = X[:, key_cols:] 40 | else: 41 | key = None 42 | if i == primary_party_id: 43 | self.local_datasets.append(LocalDataset(X, y=y, key=key)) 44 | else: 45 | self.local_datasets.append(LocalDataset(X, key=key)) 46 | 47 | else: 48 | raise TypeError(f"local_datasets should be either LocalDataset or Sequence, " 49 | f"but got {type(local_datasets)}") 50 | 51 | if primary_train_local_dataset is None: 52 | # this object the training dataset 53 | assert primary_train_indices is None 54 | self.primary_train_local_dataset = self.local_datasets[primary_party_id] 55 | else: 56 | # this object the test/val dataset 57 | assert primary_train_indices is not None 58 | self.primary_train_local_dataset = primary_train_local_dataset 59 | 60 | self.primary_party_id = primary_party_id 61 | self.multi_primary = multi_primary 62 | self.cache_key = cache_key 63 | 64 | if key_cols is not None: 65 | print("Using TopkSampler") 66 | self.data_sampler = TopkSampler([local_dataset.key for local_dataset in self.local_datasets], ks=ks, 67 | primary_party_id=primary_party_id, seed=0, multi_primary=multi_primary, 68 | sample_rate_before_topk=sample_rate_before_topk, cache_key=cache_key, 69 | indices=primary_train_indices) 70 | else: 71 | print("Using RandomSampler") 72 | self.data_sampler = RandomSampler([local_dataset.X.shape[0] for local_dataset in self.local_datasets], 73 | n_samples=ks, 74 | primary_party_id=primary_party_id, seed=0) 75 | 76 | self.n_parties = len(local_datasets) 77 | self.ks = ks 78 | self.to_tensor_() 79 | 80 | self.use_cache = 
use_cache 81 | self.cache_key = cache_key 82 | self.cache = None 83 | 84 | def __len__(self): 85 | return len(self.local_datasets[self.primary_party_id]) 86 | 87 | def __getitem__(self, idx): 88 | if self.use_cache: 89 | if self.cache is None: 90 | raise ValueError("Cache is not loaded") 91 | if len(self.cache) <= idx: 92 | raise IndexError(f"Index {idx} is out of range {len(self.cache)}") 93 | return self.cache[idx] 94 | 95 | indices_per_party = self.data_sampler.sample(idx) 96 | if not self.data_sampler.return_multi: 97 | # single index for secondary parties 98 | Xs = [] 99 | y_idx = idx 100 | for pid in range(self.n_parties): 101 | if pid == self.primary_party_id: 102 | key_X, _ = self.primary_train_local_dataset[indices_per_party[pid]] 103 | else: 104 | key_X, _ = self.local_datasets[pid][indices_per_party[pid]] 105 | Xs.append(key_X) 106 | y = self.local_datasets[self.primary_party_id].y[y_idx] 107 | return Xs, y 108 | else: 109 | # list of indices for secondary parties 110 | Xs = [] 111 | for pid in range(self.n_parties): 112 | index = torch.tensor(indices_per_party[pid]) 113 | local_dataset = self.primary_train_local_dataset \ 114 | if pid == self.primary_party_id else self.local_datasets[pid] 115 | 116 | if self.local_datasets[pid].key is None: 117 | key = np.zeros_like(index).reshape(-1, 1) * np.nan # for successful collate_fn 118 | else: 119 | key = torch.index_select(local_dataset.key, 0, index) 120 | 121 | X = torch.index_select(local_dataset.X, 0, index) 122 | 123 | if pid == self.primary_party_id: 124 | # For test set, the nearest neighbors may not be itself. Remove 125 | # the farthest neighbor and add itself. For training set this is 126 | # not necessary, but we do it anyway for simplicity. 127 | x_self = self.local_datasets[pid].X[idx].unsqueeze(0) 128 | X = torch.cat([x_self, X[:-1]], dim=0) 129 | Xs.append((key, X)) 130 | y = self.local_datasets[self.primary_party_id].y[idx] 131 | return Xs, y 132 | 133 | # def create_cache(self): 134 | # """ 135 | # Create cache for the dataset by iterating through all data 136 | # :return: 137 | # """ 138 | # if not self.use_cache: 139 | # return 140 | # os.makedirs(os.path.dirname(self.cache_key), exist_ok=True) 141 | # 142 | # self.use_cache = False # temporarily disable cache, force __getitem__ to calculate data 143 | # self.cache = [None] * len(self) 144 | # for idx in range(len(self)): 145 | # self.cache[idx] = self.__getitem__(idx) 146 | # self.use_cache = True 147 | # 148 | # with open(self.cache_path, 'wb') as f: 149 | # print(f"Creating cache of {len(self.cache)} records to {self.cache_path}") 150 | # pickle.dump(self.cache, f) 151 | # print(f"Saved cache to {self.cache_path}") 152 | # 153 | # def load_cache(self): 154 | # """ 155 | # Load cache for the dataset 156 | # :return: 157 | # """ 158 | # if not self.use_cache: 159 | # return 160 | # 161 | # with open(self.cache_path, 'rb') as f: 162 | # print(f"Loading cache from {self.cache_path}") 163 | # self.cache = pickle.load(f) 164 | # print(f"Loaded cache of {len(self.cache)} records from {self.cache_path}") 165 | 166 | @property 167 | def local_key_channels(self): 168 | key_channels = [] 169 | for local_dataset in self.local_datasets: 170 | if local_dataset.key is None: 171 | key_channels.append(0) 172 | elif len(local_dataset.key.shape) == 1: 173 | key_channels.append(1) 174 | else: 175 | key_channels.append(local_dataset.key.shape[1]) 176 | return key_channels 177 | 178 | @property 179 | def local_input_channels(self): 180 | return [local_dataset.X.shape[1] if 
len(local_dataset.X.shape) == 2 else 1 181 | for local_dataset in self.local_datasets] 182 | 183 | @property 184 | def local_key_X_channels(self): 185 | X_channels = self.local_input_channels 186 | key_channels = self.local_key_channels 187 | return [X + key for X, key in zip(X_channels, key_channels)] 188 | 189 | @classmethod 190 | def from_csv(cls, paths: Sequence, multi_primary=False, ks=100, key_cols=None, **kwargs): 191 | """ 192 | Create a VFLRealDataset from csv files 193 | :param paths: paths to csv files 194 | :param multi_primary: whether to have multiple primary parties 195 | :return: a VFLRealDataset 196 | """ 197 | local_datasets = [LocalDataset.from_csv(path, key_cols=key_cols, **kwargs) for path in paths] 198 | return cls(local_datasets, multi_primary=multi_primary, ks=ks, key_cols=key_cols, **kwargs) 199 | 200 | @classmethod 201 | def from_syn_aligned(cls, dataset: VFLSynAlignedDataset, ks=100, key_cols=None, **kwargs): 202 | """ 203 | Create a VFLRealDataset from a VFLSynAlignedDataset 204 | :param dataset: a VFLSynAlignedDataset 205 | :return: a VFLRealDataset 206 | """ 207 | return cls(dataset.local_datasets, dataset.primary_party_id, ks=ks, key_cols=key_cols, **kwargs) 208 | 209 | @classmethod 210 | def _from_split_datasets(cls, X_train, X_val, X_test, y_train, y_val, y_test, secondary_datasets, 211 | key_train=None, key_val=None, key_test=None, **vfl_args): 212 | train_local_datasets = [LocalDataset(X_train, y_train, key_train)] + secondary_datasets 213 | if X_val is not None: 214 | val_local_datasets = [LocalDataset(X_val, y_val, key_val)] + secondary_datasets 215 | else: 216 | val_local_datasets = None 217 | if X_test is not None: 218 | test_local_datasets = [LocalDataset(X_test, y_test, key_test)] + secondary_datasets 219 | else: 220 | test_local_datasets = None 221 | 222 | # keep only hyperparameters 223 | vfl_args = {k: v for k, v in vfl_args.items() if v is None or isinstance(v, (int, float, str, bool))} 224 | 225 | # remove training-only arguments 226 | val_args = {k: v for k, v in vfl_args.items() if (v is None or isinstance(v, (int, float, str, bool))) and 227 | k not in ['sample_rate_before_topk']} 228 | test_args = deepcopy(val_args) 229 | 230 | if 'cache_key' in vfl_args and vfl_args['cache_key'] is not None: 231 | vfl_args['cache_key'] = vfl_args['cache_key'] + '-train' 232 | if 'cache_key' in val_args and val_args['cache_key'] is not None: 233 | val_args['cache_key'] = val_args['cache_key'] + '-val' 234 | if 'cache_key' in test_args and test_args['cache_key'] is not None: 235 | test_args['cache_key'] = test_args['cache_key'] + '-test' 236 | 237 | multi_primary = vfl_args.get('multi_primary', False) 238 | match val_local_datasets, test_local_datasets: 239 | case (None, None): 240 | return cls(train_local_datasets, **vfl_args), None, None 241 | case (None, _): 242 | train_dataset = cls(train_local_datasets, **vfl_args) 243 | if multi_primary: 244 | test_dataset = cls(test_local_datasets, 245 | primary_train_local_dataset=train_dataset.primary_train_local_dataset, 246 | primary_train_indices=train_dataset.data_sampler.indices, 247 | **test_args) 248 | else: 249 | test_dataset = cls(test_local_datasets, **test_args) 250 | return train_dataset, None, test_dataset 251 | case (_, None): 252 | train_dataset = cls(train_local_datasets, **vfl_args) 253 | if multi_primary: 254 | val_dataset = cls(val_local_datasets, 255 | primary_train_local_dataset=train_dataset.primary_train_local_dataset, 256 | primary_train_indices=train_dataset.data_sampler.indices, 257 | 
**val_args) 258 | else: 259 | val_dataset = cls(val_local_datasets, **val_args) 260 | return train_dataset, val_dataset, None 261 | case (_, _): 262 | if multi_primary: 263 | train_dataset = cls(train_local_datasets, **vfl_args) 264 | val_dataset = cls(val_local_datasets, 265 | primary_train_local_dataset=train_dataset.primary_train_local_dataset, 266 | primary_train_indices=train_dataset.data_sampler.indices, 267 | **val_args) 268 | test_dataset = cls(test_local_datasets, 269 | primary_train_local_dataset=train_dataset.primary_train_local_dataset, 270 | primary_train_indices=train_dataset.data_sampler.indices, 271 | **test_args) 272 | else: 273 | train_dataset = cls(train_local_datasets, **vfl_args) 274 | val_dataset = cls(val_local_datasets, **val_args) 275 | test_dataset = cls(test_local_datasets, **test_args) 276 | return train_dataset, val_dataset, test_dataset 277 | 278 | def to_tensor_(self): 279 | for local_dataset in self.local_datasets: 280 | local_dataset.to_tensor_() 281 | 282 | def scale_y_(self, lower=0, upper=1, scaler=None): 283 | return self.local_datasets[self.primary_party_id].scale_y_(lower=lower, upper=upper, scaler=scaler) 284 | 285 | def normalize_(self, scalers=None, include_key=False): 286 | """ 287 | Normalize the features 288 | :param scalers: If scaler is None, normalize *all* parties and return the scalers. Otherwise, only normalize 289 | the primary party and use the given scalers. 290 | :param include_key: whether to normalize the key 291 | :return: 292 | """ 293 | if scalers is None: 294 | scalers = [None] * self.n_parties 295 | for pid in range(self.n_parties): 296 | scalers[pid] = self.local_datasets[pid].normalize_(include_key=include_key) 297 | return scalers 298 | else: 299 | if len(scalers) != self.n_parties: 300 | raise ValueError(f"Length of scalers {len(scalers)} does not match n_parties {self.n_parties}") 301 | for pid in range(self.n_parties): 302 | if pid == self.primary_party_id: 303 | self.local_datasets[pid].normalize_(scaler=scalers[pid], include_key=include_key) 304 | return None 305 | 306 | def split_train_test_primary(self, val_ratio=0.1, test_ratio=0.2, random_state=None, shuffle=False): 307 | """ 308 | Split the dataset into train and test set. 309 | :param val_ratio: ratio of validation set, if None, no validation set will be generated 310 | :param test_ratio: ratio of test set, if None, no test set will be generated 311 | :param random_state: random state, by default None 312 | :param hard_train_test_split: the split point for hard train-test split, 313 | e.g., cifar10 has 50K train data and 10K test data, we set it as 50K, and the input dataset should be the concatenation of [train_data, test_data]. 314 | Default 0 means no hard train-test split. 
315 | :return: three VFLRealDataset, train, val, test 316 | """ 317 | primary_dataset = self.local_datasets[self.primary_party_id] 318 | secondary_datasets = (list(self.local_datasets[:self.primary_party_id]) + 319 | list(self.local_datasets[self.primary_party_id + 1:])) 320 | key, X, y = primary_dataset.data 321 | 322 | if y is None: 323 | raise ValueError(f"y on the primary party {self.primary_party_id} should not be None") 324 | 325 | if key is not None: 326 | match val_ratio, test_ratio: 327 | case (None, None): 328 | raise ValueError("val_ratio and test_ratio cannot be both None") 329 | case (None, _): 330 | X_train, X_test, y_train, y_test, key_train, key_test = train_test_split(X, y, key, 331 | test_size=test_ratio, 332 | random_state=random_state, 333 | shuffle=shuffle) 334 | 335 | return VFLRealDataset._from_split_datasets(X_train, None, X_test, y_train, None, y_test, 336 | secondary_datasets, 337 | key_train, None, key_test, **self.__dict__) 338 | case (_, None): 339 | X_train, X_val, y_train, y_val, key_train, key_val = train_test_split(X, y, key, 340 | test_size=val_ratio, 341 | random_state=random_state, 342 | shuffle=shuffle) 343 | return VFLRealDataset._from_split_datasets(X_train, X_val, None, y_train, y_val, None, 344 | secondary_datasets, 345 | key_train, key_val, None, **self.__dict__) 346 | case (_, _): 347 | X_train_val, X_test, y_train_val, y_test, key_train_val, key_test = ( 348 | train_test_split(X, y, key, test_size=test_ratio, random_state=random_state, 349 | shuffle=shuffle)) 350 | 351 | X_train, X_val, y_train, y_val, key_train, key_val = ( 352 | train_test_split(X_train_val, y_train_val, key_train_val, 353 | test_size=val_ratio / (1 - test_ratio), 354 | random_state=random_state, 355 | shuffle=shuffle)) 356 | return VFLRealDataset._from_split_datasets(X_train, X_val, X_test, y_train, y_val, y_test, 357 | secondary_datasets, 358 | key_train, key_val, key_test, **self.__dict__) 359 | else: 360 | # key is None 361 | match val_ratio, test_ratio: 362 | case (None, None): 363 | raise ValueError("val_ratio and test_ratio cannot be both None") 364 | case (None, _): 365 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, 366 | random_state=random_state, shuffle=shuffle) 367 | return VFLRealDataset._from_split_datasets(X_train, None, X_test, y_train, None, y_test, 368 | secondary_datasets, 369 | **self.__dict__) 370 | case (_, None): 371 | X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=val_ratio, 372 | random_state=random_state, shuffle=shuffle) 373 | return VFLRealDataset._from_split_datasets(X_train, X_val, None, y_train, y_val, None, 374 | secondary_datasets, 375 | **self.__dict__) 376 | case (_, _): 377 | X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=test_ratio, 378 | random_state=random_state, shuffle=shuffle) 379 | X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, 380 | test_size=val_ratio / (1 - test_ratio), 381 | random_state=random_state,shuffle=shuffle) 382 | return VFLRealDataset._from_split_datasets(X_train, X_val, X_test, y_train, y_val, y_test, 383 | secondary_datasets, 384 | **self.__dict__) 385 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /src/metric/RMSE.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class RMSE: 4 | def __init__(self, is_robust=True): 5 | self.is_robust = is_robust 6 | 7 | def __call__(self, label, pred): 8 | if self.is_robust: 9 | return np.sqrt(np.nanmean((label - pred) ** 2)) 10 | else: 11 | return np.sqrt(np.mean((label - pred) ** 2)) 12 | -------------------------------------------------------------------------------- /src/metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/metric/__init__.py -------------------------------------------------------------------------------- /src/model/PosEncoding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | """ 7 | Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding (NeurIPS 2021) 8 | https://github.com/willGuimont/learnable_fourier_positional_encoding/blob/master/learnable_fourier_pos_encoding.py 9 | """ 10 | 11 | class LearnableFourierPositionalEncoding(nn.Module): 12 | def __init__(self, G: int, M: int, F_dim: int, H_dim: int, D: int, gamma: float): 13 | """ 14 | Learnable Fourier Features from https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1) 15 | Implementation of Algorithm 1: Compute the Fourier feature positional encoding of a multi-dimensional position 16 | Computes the positional encoding of a tensor of shape [N, G, M] 17 | :param G: positional groups (positions in different groups are independent) 18 | :param M: each point has a M-dimensional positional values 19 | :param F_dim: depth of the Fourier feature dimension 20 | :param H_dim: hidden layer dimension 21 | :param D: positional encoding dimension 22 | :param gamma: parameter to initialize Wr 23 | """ 24 | super().__init__() 25 | self.G = G 26 | self.M = M 27 | self.F_dim = F_dim 28 | self.H_dim = H_dim 29 | self.D = D 30 | self.gamma = gamma 31 | 32 | # If D is not divisible by G, we pad D to be divisible by G 33 | if self.D % self.G != 0: 34 | self.pad = self.D % self.G 35 | self.D = self.D // self.G * self.G 36 | else: 37 | self.pad = 0 38 | 39 | # Projection matrix on learned lines (used in eq. 2) 40 | self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False) 41 | # MLP (GeLU(F @ W1 + B1) @ W2 + B2 (eq. 6) 42 | self.mlp = nn.Sequential( 43 | nn.Linear(self.F_dim, self.H_dim, bias=True), 44 | nn.GELU(), 45 | nn.Linear(self.H_dim, self.D // self.G) 46 | ) 47 | 48 | self.init_weights() 49 | 50 | def init_weights(self): 51 | nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) 52 | 53 | def forward(self, x): 54 | """ 55 | Produce positional encodings from x 56 | :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M], 57 | where G is the positional group and each group has M-dimensional positional values. 58 | Positions in different positional groups are independent 59 | :return: positional encoding for X 60 | """ 61 | N, G, M = x.shape 62 | # Step 1. Compute Fourier features (eq. 2) 63 | projected = self.Wr(x) 64 | cosines = torch.cos(projected) 65 | sines = torch.sin(projected) 66 | F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1) 67 | # Step 2. Compute projected Fourier features (eq. 6) 68 | Y = self.mlp(F) 69 | # Step 3. 
Reshape to x's shape 70 | PEx = Y.reshape((N, -1)) 71 | # pad Y with zero 72 | 73 | if self.pad != 0: 74 | PEx = torch.cat([PEx, torch.zeros((N, self.pad)).to(x.device)], dim=-1) 75 | return PEx 76 | 77 | 78 | if __name__ == '__main__': 79 | G = 3 80 | M = 17 81 | x = torch.randn((97, G, M)) 82 | enc = LearnableFourierPositionalEncoding(G, M, 768, 32, 768, 10) 83 | pex = enc(x) 84 | print(pex.shape) -------------------------------------------------------------------------------- /src/model/Solo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models of single party 3 | """ 4 | 5 | from typing import Callable 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, input_size, hidden_sizes: list, output_size=1, activation=None): 13 | super(MLP, self).__init__() 14 | self.hidden_sizes = hidden_sizes 15 | self.activation = activation 16 | if len(hidden_sizes) != 0: 17 | self.fc_layers = nn.ModuleList([nn.Linear(input_size, hidden_sizes[0])]) 18 | for i in range(len(hidden_sizes) - 1): 19 | self.fc_layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1])) 20 | self.fc_layers.append(nn.Linear(hidden_sizes[-1], output_size)) 21 | else: 22 | self.fc_layers = nn.ModuleList([nn.Linear(input_size, output_size)]) 23 | 24 | def forward(self, key_X): 25 | if key_X[0] is not None: 26 | X = torch.cat(key_X, dim=1) 27 | else: 28 | X = key_X[1] 29 | 30 | if len(list(self.fc_layers)) == 0: 31 | return X 32 | 33 | if len((list(self.fc_layers))) == 1: 34 | out = X 35 | else: 36 | out = F.relu(self.fc_layers[0](X)) 37 | 38 | for fc in self.fc_layers[1:-1]: 39 | out = F.relu(fc(out)) 40 | 41 | if self.activation == 'sigmoid': 42 | out = torch.sigmoid(self.fc_layers[-1](out)) 43 | elif self.activation == 'tanh': 44 | out = torch.tanh(self.fc_layers[-1](out)) 45 | elif self.activation == 'relu': 46 | out = torch.relu(self.fc_layers[-1](out)) 47 | elif isinstance(self.activation, Callable): 48 | out = self.activation(self.fc_layers[-1](out)) 49 | elif self.activation is None: 50 | out = self.fc_layers[-1](out) 51 | else: 52 | assert False 53 | return out 54 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/model/__init__.py -------------------------------------------------------------------------------- /src/preprocess/ExactSpitter.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple 3 | import argparse 4 | import os 5 | import os.path 6 | import sys 7 | 8 | from sklearn.decomposition import PCA 9 | from sklearn.datasets import load_svmlight_file 10 | from sklearn.preprocessing import StandardScaler 11 | import numpy as np 12 | import pandas as pd 13 | 14 | # add src to python path 15 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 16 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 17 | 18 | from src.dataset.LocalDataset import LocalDataset 19 | from src.utils.BasicUtils import PartyPath 20 | 21 | 22 | def get_exact_key(n, key_dim=1) -> np.ndarray: 23 | """ 24 | Generate a (nx1) key matrix with exact values. Each value is in [-1, 1]. 
25 | """ 26 | return np.random.uniform(-1, 1, (n, key_dim)) 27 | 28 | 29 | 30 | def exact_split(X: np.ndarray, n_parties: int, key_dim: int = 1) -> List[np.ndarray]: 31 | """ 32 | Split the data into multiple parties. 33 | Each party has the same key matrix. 34 | """ 35 | 36 | raw_Xs = np.array_split(X, n_parties, axis=1) # split the data into multiple parties 37 | 38 | key = get_exact_key(len(raw_Xs[0]), key_dim) 39 | 40 | Xs = [] 41 | for i in range(n_parties): 42 | Xi = np.concatenate((key, raw_Xs[i]), axis=1) 43 | Xs.append(Xi) 44 | 45 | return Xs 46 | 47 | 48 | def load_data(data_path) -> Tuple[np.ndarray, np.ndarray]: 49 | """ 50 | Load data from a file. 51 | :param data_path: [str] path to the data file 52 | :return: [Tuple[np.ndarray, np.ndarray]] data matrix and label vector 53 | """ 54 | if not os.path.exists(data_path): 55 | raise FileNotFoundError(f"{data_path} not found.") 56 | if data_path.endswith('.libsvm'): 57 | X, y = load_svmlight_file(data_path) 58 | X = X.toarray() 59 | elif data_path.endswith('.csv'): 60 | X = pd.read_csv(data_path).values 61 | y = X[:, -1] 62 | X = X[:, :-1] 63 | else: 64 | raise NotImplementedError(f"Unknown file format {data_path}") 65 | return X, y 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser(description='Equal Splitter with exact key') 70 | parser.add_argument('-d', '--dataset', type=str, default='gisette.libsvm', help='path to the data file') 71 | parser.add_argument('-p', '--n_parties', type=int, default=2, help='number of parties') 72 | parser.add_argument('-kd', '--key_dim', type=int, default=1, help='key dimension') 73 | parser.add_argument('--seed', type=int, default=0, help='random seed') 74 | args = parser.parse_args() 75 | 76 | if len(args.dataset.split('.')) < 2: 77 | raise ValueError(f"Invalid dataset name {args.dataset}, should be in the format of 'dataset.format'") 78 | fmt = args.dataset.split('.')[-1] 79 | dataset = args.dataset.split('.')[0] 80 | X, y = load_data(f"data/syn/{dataset}/{dataset}.{fmt}") 81 | 82 | # random shuffle X, y 83 | np.random.seed(args.seed) 84 | random.seed(args.seed) 85 | 86 | idx = np.random.permutation(X.shape[0]) 87 | X = X[idx] 88 | y = y[idx] 89 | np.random.shuffle(X.T) # shuffle the data along the feature dimension 90 | 91 | Xs = exact_split(X, args.n_parties, key_dim=args.key_dim) 92 | save_dir = f"data/syn/exact_key_dataset/{dataset}/" 93 | os.makedirs(save_dir, exist_ok=True) 94 | for i, Xi in enumerate(Xs): 95 | party_path = os.path.join(save_dir, f"{dataset}_party{args.n_parties}-{i}_imp_weight100.0_seed{args.seed}.pkl") 96 | LocalDataset(Xi[:, args.key_dim:], y, key=Xi[:, :args.key_dim]).to_pickle(party_path) 97 | print(f"Saved {party_path}") 98 | -------------------------------------------------------------------------------- /src/preprocess/FuzzySplitter.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Tuple 3 | import argparse 4 | import os 5 | import os.path 6 | import sys 7 | 8 | from sklearn.decomposition import PCA 9 | from sklearn.datasets import load_svmlight_file 10 | from sklearn.preprocessing import StandardScaler 11 | import numpy as np 12 | import pandas as pd 13 | 14 | # add src to python path 15 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 16 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 17 | 18 | from src.dataset.LocalDataset import LocalDataset 19 | from src.utils.BasicUtils import PartyPath 20 | from 
src.preprocess.FeatureSplitter import ImportanceSplitter 21 | 22 | 23 | def get_fuzzy_key(X: np.ndarray, key_dim: int = 5, train_ratio = 0.7) -> np.ndarray: 24 | """ 25 | Split the data into multiple parties. The key of each sample is calculated by PCA. Gaussian noise is added to the 26 | key to make it fuzzy. 27 | :param X: [np.ndarray] (N x D) data matrix 28 | :param key_dim: [int] key dimension 29 | :return: [np.ndarray] (N x key_dim) key matrix 30 | """ 31 | # PCA 32 | 33 | pca = PCA(n_components=key_dim) 34 | train_size = int(X.shape[0] * train_ratio) 35 | pca.fit(X[:train_size]) # fit PCA on the training data to avoid information leakage 36 | key = pca.transform(X) # transform the entire data 37 | 38 | # scale to [-1, 1] 39 | scaler = StandardScaler() 40 | scaler.fit(key[:train_size]) 41 | key = scaler.transform(key) 42 | 43 | return key 44 | 45 | 46 | def get_noise_key(n_instances, key_dim, noise_scale): 47 | """ 48 | Generate a key matrix with Gaussian noise. 49 | :param n_instances: [int] number of instances 50 | :param key_dim: [int] key dimension 51 | :param noise_scale: [float] scale of the Gaussian noise 52 | :return: [np.ndarray] (N x key_dim) key matrix 53 | """ 54 | return np.random.normal(0, noise_scale, (n_instances, key_dim)) 55 | 56 | 57 | def fuzzy_split(X: np.ndarray, n_parties: int, key_dim: int = 5, noise_scale: float = 0.0, 58 | key_base_noise: float = 1.0) -> List[np.ndarray]: 59 | """ 60 | Split the data into multiple parties. The key of each sample is calculated by PCA. Gaussian noise is added to the 61 | key to make it fuzzy. 62 | :param X: [np.ndarray] (N x D) data matrix 63 | :param key_dim: [int] key dimension 64 | :param n_parties: [int] number of parties 65 | :param key_base_noise: [float] scale of the base noise 66 | :param noise_scale: [float] scale of the Gaussian noise 67 | :return: [List[np.ndarray]] list of key matrices for each party 68 | """ 69 | 70 | raw_Xs = np.array_split(X, n_parties, axis=1) # split the data into multiple parties 71 | 72 | key = get_fuzzy_key(raw_Xs[0], key_dim) # first party as primary party to generate key 73 | # key = get_noise_key(X.shape[0], key_dim, key_base_noise) 74 | 75 | Xs = [] 76 | for i in range(n_parties): 77 | key_i = key + np.random.normal(0, noise_scale, key.shape) 78 | Xi = np.concatenate((key_i, raw_Xs[i]), axis=1) 79 | Xs.append(Xi) 80 | 81 | return Xs 82 | 83 | 84 | def load_data(data_path) -> Tuple[np.ndarray, np.ndarray]: 85 | """ 86 | Load data from a file. 
87 | :param data_path: [str] path to the data file 88 | :return: [Tuple[np.ndarray, np.ndarray]] data matrix and label vector 89 | """ 90 | if not os.path.exists(data_path): 91 | raise FileNotFoundError(f"{data_path} not found.") 92 | if data_path.endswith('.libsvm'): 93 | X, y = load_svmlight_file(data_path) 94 | X = X.toarray() 95 | elif data_path.endswith('.csv'): 96 | X = pd.read_csv(data_path).values 97 | y = X[:, -1] 98 | X = X[:, :-1] 99 | else: 100 | raise NotImplementedError(f"Unknown file format {data_path}") 101 | return X, y 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser(description='Fuzzy Splitter') 106 | parser.add_argument('-d', '--dataset', type=str, default='gisette.libsvm', help='path to the data file') 107 | parser.add_argument('-p', '--n_parties', type=int, default=2, help='number of parties') 108 | parser.add_argument('-kd', '--key_dim', type=int, default=4, help='key dimension') 109 | parser.add_argument('-ns', '--noise_scale', type=float, default=0.0, help='scale of the Gaussian noise') 110 | # parser.add_argument('-sd', '--save-dir', type=str, default='data/syn/multi_party_dataset/gisette', help='directory to save the split data') 111 | parser.add_argument('-a', '--alpha', type=float, default=None, help='weight of the importance score') 112 | parser.add_argument('--seed', type=int, default=0, help='random seed') 113 | args = parser.parse_args() 114 | 115 | if len(args.dataset.split('.')) < 2: 116 | raise ValueError(f"Invalid dataset name {args.dataset}, should be in the format of 'dataset.format'") 117 | fmt = args.dataset.split('.')[-1] 118 | dataset = args.dataset.split('.')[0] 119 | X, y = load_data(f"data/syn/{dataset}/{dataset}.{fmt}") 120 | 121 | # random shuffle X, y 122 | np.random.seed(args.seed) 123 | random.seed(args.seed) 124 | 125 | if args.alpha is None: 126 | idx = np.random.permutation(X.shape[0]) 127 | X = X[idx] 128 | y = y[idx] 129 | np.random.shuffle(X.T) # shuffle the data along the feature dimension 130 | Xs = fuzzy_split(X, args.n_parties, key_dim=args.key_dim, noise_scale=args.noise_scale) 131 | else: 132 | splitter = ImportanceSplitter(num_parties=args.n_parties, weights=args.alpha, seed=args.seed) 133 | Xs_no_key = splitter.split(X) 134 | 135 | # add noise 136 | key = get_fuzzy_key(X, key_dim=args.key_dim) 137 | Xs = [] 138 | for i in range(args.n_parties): 139 | key_i = key + np.random.normal(0, args.noise_scale, key.shape) 140 | Xs.append(np.concatenate((key_i, Xs_no_key[i]), axis=1)) 141 | 142 | save_dir = f"data/syn/{dataset}/noise{args.noise_scale}" 143 | os.makedirs(save_dir, exist_ok=True) 144 | for i, Xi in enumerate(Xs): 145 | weight = '100.0' if args.alpha is None else f"{args.alpha:.1f}" # 100 for balanced split (the real alpha may not be 100, but a large number) 146 | party_path = os.path.join(save_dir, f"{dataset}_party{args.n_parties}-{i}_imp_weight{weight}_seed{args.seed}.pkl") 147 | assert Xi[:, :args.key_dim].shape[1] > 0 148 | assert Xi[:, args.key_dim:].shape[1] > 0 149 | LocalDataset(Xi[:, args.key_dim:], y, key=Xi[:, :args.key_dim]).to_pickle(party_path) 150 | print(f"Saved {party_path}") 151 | -------------------------------------------------------------------------------- /src/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/preprocess/__init__.py -------------------------------------------------------------------------------- 
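For orientation, here is a minimal, illustrative sketch (not a file in this repository) of how the fuzzy splitting utility in src/preprocess/FuzzySplitter.py might be exercised on toy data. It assumes the repository root is on PYTHONPATH and that the packages in requirements.txt (numpy, scikit-learn) are installed; the toy shapes, the two-party setting, and the variable names are invented for the example, and the key-columns-first layout simply mirrors what the script's __main__ block pickles via LocalDataset.

import numpy as np

from src.preprocess.FuzzySplitter import fuzzy_split  # assumes the repo root is on sys.path

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 20))        # toy data: 200 samples, 20 features
y = rng.randint(0, 2, size=200)       # toy labels; the __main__ block above pickles y with every party

key_dim, noise_scale = 4, 0.1
Xs = fuzzy_split(X, n_parties=2, key_dim=key_dim, noise_scale=noise_scale)

for i, Xi in enumerate(Xs):
    key_i, feat_i = Xi[:, :key_dim], Xi[:, key_dim:]   # PCA-based key (noised per party), then this party's features
    print(f"party {i}: key {key_i.shape}, features {feat_i.shape}")

Each party receives the same PCA-derived key perturbed by independent Gaussian noise, which is what makes the record linkage fuzzy; arrays in this key-first layout correspond to the ([X1, X2, ...], y) branch of the VFLRealDataset constructor shown earlier when key_cols is set to key_dim.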
/src/preprocess/gisette/__init__.py: -------------------------------------------------------------------------------- 1 | from .gisette_loader import load_both -------------------------------------------------------------------------------- /src/preprocess/gisette/gisette_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loader for gisette dataset. Adaption for FeT 3 | """ 4 | 5 | import pandas as pd 6 | 7 | from src.utils.BasicUtils import move_item_to_start_ 8 | 9 | 10 | def load_both(primary_path, secondary_path): 11 | print(f'Loading primary from {primary_path}') 12 | primary = pd.read_csv(primary_path, index_col=False) 13 | 14 | print(f'Loading secondary from {secondary_path}') 15 | secondary = pd.read_csv(secondary_path, index_col=False) 16 | 17 | labels = primary['y'].to_numpy() 18 | labels[labels == -1] = 0 19 | 20 | primary.drop(columns=['y'], inplace=True) 21 | secondary.drop(columns=['y'], inplace=True) 22 | 23 | data1 = primary.to_numpy() 24 | data2 = secondary.to_numpy() 25 | 26 | return [data1, data2], labels 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /src/preprocess/hdb/__init__.py: -------------------------------------------------------------------------------- 1 | from .hdb_loader import load_both, load_hdb -------------------------------------------------------------------------------- /src/preprocess/hdb/clean_hdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import requests 4 | import json 5 | 6 | from tqdm import tqdm 7 | import pandas as pd 8 | 9 | 10 | def get_blk_loc(hdb_path, out_hdb_w_blk_loc_path): 11 | hdb_df = pd.read_csv(hdb_path, parse_dates=['month']) 12 | 13 | hdb_df['address'] = hdb_df['block'] + " " + hdb_df['street_name'] 14 | addrs_unique = hdb_df['address'].drop_duplicates(keep='first') 15 | 16 | # get the address of blocks by OneMap API 17 | latitude = [] 18 | longitude = [] 19 | blk_no = [] 20 | road_name = [] 21 | postal_code = [] 22 | address = [] 23 | for addr in tqdm(addrs_unique): 24 | query_string = 'https://developers.onemap.sg/commonapi/search?searchVal=' + str( 25 | addr) + '&returnGeom=Y&getAddrDetails=Y' 26 | resp = requests.get(query_string) 27 | 28 | # Convert JSON into Python Object 29 | data_geo_location = json.loads(resp.content) 30 | if data_geo_location['found'] != 0: 31 | latitude.append(data_geo_location['results'][0]['LATITUDE']) 32 | longitude.append(data_geo_location['results'][0]['LONGITUDE']) 33 | blk_no.append(data_geo_location['results'][0]['BLK_NO']) # this one is a unique block No. 
34 | road_name.append(data_geo_location['results'][0]['ROAD_NAME']) 35 | postal_code.append(data_geo_location['results'][0]['POSTAL']) 36 | address.append(addr) 37 | # print(str(addr) + " ,Lat: " + data_geo_location['results'][0]['LATITUDE'] + " Long: " + 38 | # data_geo_location['results'][0]['LONGITUDE']) 39 | else: 40 | print("No Results") 41 | 42 | print("Converting to dataframe") 43 | block_loc_df = pd.DataFrame({'address': address, 'lat': latitude, 'lon': longitude}) 44 | print("Joining with HDB data") 45 | hdb_df = hdb_df.merge(block_loc_df, on='address', how='left') 46 | print("Saving to {}".format(out_hdb_w_blk_loc_path)) 47 | hdb_df.to_csv(out_hdb_w_blk_loc_path, index=False) 48 | 49 | 50 | def clean_hdb(hdb_path, out_hdb_path): 51 | hdb_df = pd.read_csv(hdb_path, parse_dates=['month']) 52 | hdb_df.dropna(inplace=True) 53 | hdb_df.drop(columns=['month', 'block', 'street_name', 'address'], inplace=True) 54 | 55 | hdb_df['lease_commence_year_before_2020'] = 2020 - hdb_df['lease_commence_date'] 56 | hdb_df.drop(columns=['lease_commence_date', 'remaining_lease'], inplace=True) 57 | 58 | hdb_df = pd.get_dummies(hdb_df, 59 | columns=['town', 'flat_type', 'storey_range', 'flat_model'], 60 | prefix=['tn', 'ft', 'sr', 'fm'], drop_first=True) 61 | 62 | hdb_df['resale_price'] = hdb_df['resale_price'] / 1000 # change to kS$ 63 | 64 | hdb_df.to_csv(out_hdb_path, index=False) 65 | 66 | 67 | if __name__ == '__main__': 68 | os.chdir(sys.path[0] + "/../../../data/hdb") # change working directory 69 | # get_blk_loc("resale-flat-prices-based-on-registration-date-from-jan-2017-onwards.csv", 70 | # "hdb_2017_onwards_w_blk_loc.csv") 71 | clean_hdb("hdb_2017_onwards_w_blk_loc.csv", 72 | "hdb_clean.csv") -------------------------------------------------------------------------------- /src/preprocess/hdb/clean_school.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import requests 5 | import json 6 | 7 | from tqdm import tqdm 8 | import pandas as pd 9 | 10 | 11 | def clean_school(rank_path_list, out_path): 12 | school_df_list = [] 13 | for i, rank_path in enumerate(rank_path_list): 14 | school_name_list = [] 15 | n_places_after_2b_list = [] 16 | vacancy_rate_list = [] 17 | with open(rank_path, 'r') as f: 18 | for line in f: 19 | out_params = re.split('–', line) # This '–' is not minus symbol '-' in the keyboard 20 | school_name = out_params[0].strip().lower() 21 | in_params = re.split('[(),]', out_params[1]) 22 | n_places_after_2b = int(in_params[0]) 23 | vacancy_rate = eval(in_params[2].strip()) 24 | school_name_list.append(school_name) 25 | n_places_after_2b_list.append(n_places_after_2b) 26 | vacancy_rate_list.append(vacancy_rate) 27 | school_df_i = pd.DataFrame({ 28 | 'school_name': school_name_list, 29 | 'n_places_{}'.format(i): n_places_after_2b, 30 | 'vacancy_rate_{}'.format(i): vacancy_rate_list 31 | }) 32 | school_df_i.set_index('school_name', inplace=True) 33 | school_df_list.append(school_df_i) 34 | 35 | all_school_df = school_df_list[0].join(school_df_list[1:]) 36 | all_school_df.to_csv(out_path, index=True) 37 | 38 | 39 | def get_school_loc(school_summary_path, out_path): 40 | school_df = pd.read_csv(school_summary_path) 41 | school_list = school_df['school_name'].to_list() 42 | names = [] 43 | lats = [] 44 | lons = [] 45 | for name in tqdm(school_list): 46 | query_str = "https://developers.onemap.sg/commonapi/search?searchVal=" + str( 47 | name) + "&returnGeom=Y&getAddrDetails=Y" 48 | resp = 
requests.get(query_str) 49 | 50 | # Convert JSON into Python Object 51 | try: 52 | data_geo_location = json.loads(resp.content) 53 | except json.decoder.JSONDecodeError: 54 | print("Failed to retrieve result") 55 | continue 56 | if data_geo_location['found'] != 0: 57 | lats.append(data_geo_location['results'][0]['LATITUDE']) 58 | lons.append(data_geo_location['results'][0]['LONGITUDE']) 59 | names.append(name) 60 | else: 61 | print("No Results") 62 | 63 | school_loc_df = pd.DataFrame({'school_name': names, 64 | 'lat': lats, 65 | 'lon': lons}).set_index('school_name') 66 | school_df = school_df.set_index('school_name').join(school_loc_df) 67 | school_df.dropna(inplace=True) 68 | school_df.to_csv(out_path) 69 | 70 | 71 | if __name__ == '__main__': 72 | os.chdir(sys.path[0] + "/../../../data/hdb") # change working directory 73 | # clean_school(["primary_school_rank_2015.txt", 74 | # "primary_school_rank_2016.txt", 75 | # "primary_school_rank_2017.txt", 76 | # "primary_school_rank_2018.txt"], 77 | # "school_summary.csv") 78 | get_school_loc("school_summary.csv", "school_clean.csv") 79 | -------------------------------------------------------------------------------- /src/preprocess/hdb/hdb_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loader for HDB dataset 3 | """ 4 | 5 | import pandas as pd 6 | 7 | from src.utils.BasicUtils import move_item_to_start_ 8 | 9 | 10 | def load_hdb(hdb_path): 11 | print(f'Loading hdb from {hdb_path}') 12 | hdb_data = pd.read_csv(hdb_path) 13 | 14 | hdb_data.drop(columns=['lon', 'lat']) 15 | 16 | hdb_data.info(verbose=True) 17 | 18 | labels = hdb_data['resale_price'].to_numpy() 19 | hdb_data = hdb_data.drop(columns=['resale_price']).to_numpy() 20 | 21 | return hdb_data, labels 22 | 23 | 24 | def load_both(hdb_path, airbnb_path, active_party='hdb'): 25 | print(f'Loading house from {hdb_path}') 26 | hdb_data = pd.read_csv(hdb_path) 27 | print(f'Loading airbnb from {airbnb_path}') 28 | school_data = pd.read_csv(airbnb_path) 29 | 30 | if active_party == 'hdb': 31 | labels = hdb_data['resale_price'].to_numpy() 32 | hdb_data.drop(columns=['resale_price'], inplace=True) 33 | 34 | # move lon and lat to end 35 | hdb_cols = list(hdb_data.columns) 36 | move_item_to_start_(hdb_cols, ['lon', 'lat']) 37 | hdb_data = hdb_data[hdb_cols] 38 | print(f'Current hdb columns {hdb_data.columns}') 39 | 40 | school_data.drop(columns=['school_name'], inplace=True) 41 | 42 | # move lon and lat to start 43 | school_cols = list(school_data.columns) 44 | move_item_to_start_(school_cols, ['lon', 'lat']) 45 | school_data = school_data[school_cols] 46 | print(f'Current airbnb columns {school_data.columns}') 47 | 48 | data1 = hdb_data.to_numpy() 49 | data2 = school_data.to_numpy() 50 | else: 51 | raise NotImplementedError 52 | 53 | return [data1, data2], labels 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /src/preprocess/house/__init__.py: -------------------------------------------------------------------------------- 1 | from .beijing_loder import * -------------------------------------------------------------------------------- /src/preprocess/house/beijing_loder.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from src.utils.BasicUtils import move_item_to_start_, move_item_to_end_ 3 | 4 | 5 | def load_house(house_path): 6 | print("Loading house from {}".format(house_path)) 7 | house_data = pd.read_csv(house_path) 8 | 9 
| house_data.drop(columns=['lon', 'lat'], inplace=True) 10 | 11 | house_data.info(verbose=True) 12 | 13 | labels = house_data['totalPrice'].to_numpy() 14 | house_data = house_data.drop(columns=['totalPrice']).to_numpy() 15 | 16 | return house_data, labels 17 | 18 | 19 | def load_both(house_path, airbnb_path, active_party='house'): 20 | print("Loading house from {}".format(house_path)) 21 | house_data = pd.read_csv(house_path) 22 | print("Loading airbnb from {}".format(airbnb_path)) 23 | airbnb_data = pd.read_csv(airbnb_path) 24 | 25 | if active_party == 'house': 26 | labels = house_data['totalPrice'].to_numpy() 27 | house_data.drop(columns=['totalPrice'], inplace=True) 28 | 29 | # move lon and lat to end 30 | house_cols = list(house_data.columns) 31 | move_item_to_start_(house_cols, ['lon', 'lat']) 32 | house_data = house_data[house_cols] 33 | print("Current house columns {}".format(house_data.columns)) 34 | 35 | # move lon and lat to start 36 | airbnb_cols = list(airbnb_data.columns) 37 | move_item_to_start_(airbnb_cols, ['lon', 'lat']) 38 | airbnb_data = airbnb_data[airbnb_cols] 39 | print("Current airbnb columns {}".format(airbnb_data.columns)) 40 | 41 | data1 = house_data.to_numpy() 42 | data2 = airbnb_data.to_numpy() 43 | else: 44 | raise NotImplementedError 45 | 46 | return [data1, data2], labels 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /src/preprocess/house/clean_airbnb.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import sys 4 | 5 | 6 | def clean_airbnb(airbnb_path, out_airbnb_path): 7 | airbnb_data = pd.read_csv(airbnb_path) 8 | 9 | # remove useless columns and NA 10 | airbnb_data.drop(columns=['id', 'name', 'host_id', 'host_name', 'last_review'], inplace=True) 11 | airbnb_data.dropna(inplace=True) 12 | 13 | # remove extreme high prices 14 | airbnb_data = airbnb_data[airbnb_data['price'] < 3000] 15 | 16 | airbnb_data.rename(columns={'latitude': 'lat', 'longitude': 'lon'}, inplace=True) 17 | 18 | airbnb_data = pd.get_dummies(airbnb_data, 19 | columns=['neighbourhood', 'room_type'], 20 | prefix=['nbr', 'rt'], drop_first=True) 21 | 22 | print("Got columns " + str(airbnb_data.columns)) 23 | print("Got {} lines".format(len(airbnb_data.index))) 24 | 25 | airbnb_data.to_csv(out_airbnb_path, index=False) 26 | 27 | 28 | if __name__ == '__main__': 29 | os.chdir(sys.path[0] + "/../../../data/beijing") # change working directory 30 | clean_airbnb("airbnb.csv", "airbnb_clean.csv") 31 | -------------------------------------------------------------------------------- /src/preprocess/house/clean_house.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import sys 4 | 5 | 6 | def clean_house(house_path, out_house_path, include_cid=False): 7 | house_data = pd.read_csv(house_path, encoding="iso-8859-1", parse_dates=['tradeTime'], 8 | dtype={'Cid': 'category'}) 9 | 10 | house_data.dropna(inplace=True) 11 | 12 | house_data['buildingType'] = house_data['buildingType'].astype('int') 13 | 14 | # remove the houses sold before 2013 15 | house_data = house_data[house_data['tradeTime'].dt.year > 2012] 16 | 17 | house_data['trade_year'] = house_data['tradeTime'].dt.year 18 | house_data['trade_month'] = house_data['tradeTime'].dt.month 19 | 20 | # rename longitude and latitude 21 | house_data.rename(columns={'Lng': 'lon', 'Lat': 'lat', 'Cid': 'cid'}, inplace=True) 22 | 23 | # # filter too large 
data 24 | # house_data = house_data[house_data['DOM'] < 365] 25 | 26 | # remove non-numeric values in constructionTime 27 | house_data['constructionTime'] = house_data['constructionTime'].str.extract('(\d+)', expand=False) 28 | 29 | # remove non-numeric values in floor 30 | house_data['floor'] = house_data['floor'].str.extract('(\d+)', expand=False) 31 | 32 | # remove houses with prices extremely large or small [10w, 2000w) 33 | house_data = house_data[house_data['totalPrice'] >= 10] 34 | house_data = house_data[house_data['totalPrice'] < 1000] 35 | 36 | # one-hot categorical features 37 | if include_cid: 38 | house_data = pd.get_dummies(house_data, 39 | columns=['cid', 'district', 'buildingType', 'renovationCondition', 40 | 'buildingStructure', 'trade_year', 'trade_month'], 41 | prefix=['cid', 'did', 'bt', 'rc', 'bs', 'ty', 'tm'], drop_first=True) 42 | else: 43 | house_data = pd.get_dummies(house_data, 44 | columns=['district', 'buildingType', 'renovationCondition', 'buildingStructure', 45 | 'trade_year', 'trade_month'], 46 | prefix=['did', 'bt', 'rc', 'bs', 'ty', 'tm'], drop_first=True) 47 | 48 | # price is not needed to predict totalPrice, otherwise totalPrice = price * squares 49 | house_data.drop(columns=['url', 'id', 'communityAverage', 'price', 'tradeTime'], inplace=True) 50 | 51 | print("Got columns " + str(house_data.columns)) 52 | print("Got {} lines".format(len(house_data.index))) 53 | 54 | house_data.dropna(inplace=True) 55 | 56 | house_data.to_csv(out_house_path, index=False) 57 | 58 | 59 | if __name__ == '__main__': 60 | os.chdir(sys.path[0] + "/../../../data/beijing") # change working directory 61 | clean_house("house.csv", "house_clean.csv") 62 | 63 | -------------------------------------------------------------------------------- /src/preprocess/ml_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/preprocess/ml_dataset/__init__.py -------------------------------------------------------------------------------- /src/preprocess/ml_dataset/two_party_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import random 5 | from sklearn.datasets import load_svmlight_file 6 | import pickle 7 | import inspect 8 | from scipy.sparse import csr_matrix 9 | 10 | 11 | class TwoPartyLoader: 12 | def __init__(self, num_features, num_common_features: int, 13 | common_feature_noise_scale=0.0, data_fmt='libsvm', dataset_name=None, cache_path=None, 14 | n_classes=2, seed=0): 15 | """ 16 | :param cache_path: path for cache of the object 17 | :param dataset_name: name of the dataset 18 | :param num_features_per_party: number of features on both party, including common features 19 | :param num_common_features: number of common features 20 | """ 21 | self.cache_path = cache_path 22 | self.dataset_name = dataset_name 23 | self.data_fmt = data_fmt 24 | self.n_classes = n_classes 25 | self.common_feature_noise_scale = common_feature_noise_scale 26 | self.num_common_features = num_common_features 27 | self.num_features = num_features 28 | self.seeds = list(range(seed, seed + 3)) 29 | 30 | self.X = None 31 | self.y = None 32 | self.Xs = None 33 | 34 | def load_dataset(self, path=None, use_cache=True, scale_label=False): 35 | """ 36 | :param use_cache: whether to use cache 37 | :param path: path of the ml dataset 38 | :param scale_label: whether to scale back the 
label from [0,1] to int. True in covtype.scale01. 39 | :return: features, labels 40 | """ 41 | if use_cache and self.X is not None and self.y is not None: 42 | assert self.num_features == self.X.shape[1], "Total number of features mismatch." 43 | return self.X, self.y 44 | 45 | assert path is not None 46 | print("Loading {} dataset".format(self.dataset_name)) 47 | if inspect.isfunction(self.data_fmt): 48 | X, y = self.data_fmt(path) 49 | elif self.data_fmt == 'libsvm': 50 | X, y = load_svmlight_file(path) 51 | X = X.toarray() 52 | 53 | # hard code for a strange dataset whose labels are 1 & 2 54 | if self.dataset_name == 'covtype.binary': 55 | y -= 1 56 | elif self.data_fmt == 'csv': 57 | dataset = np.loadtxt(path, delimiter=',', skiprows=1) 58 | X = dataset[:, :-1] 59 | y = dataset[:, -1].reshape(-1) 60 | else: 61 | assert False, "Unsupported ML dataset format" 62 | 63 | if scale_label: 64 | y = np.rint(y * (self.n_classes - 1)).astype('int32') 65 | 66 | assert self.num_features == X.shape[1], "Total number of features mismatch." 67 | print("Done") 68 | if use_cache: 69 | self.X, self.y = X, y 70 | 71 | return X, y 72 | 73 | def load_parties(self, path=None, use_cache=True, scale_label=False): 74 | X, y = self.load_dataset(path, use_cache, scale_label) 75 | if use_cache and self.Xs is not None: 76 | print("Loading parties from cache") 77 | return self.Xs, self.y 78 | 79 | # assuming all the features are useful 80 | print("Splitting features to two parties") 81 | 82 | # randomly divide trained features to two parties 83 | shuffle_state = np.random.RandomState(self.seeds[0]) 84 | shuffle_state.shuffle(X.T) # shuffle columns 85 | trained_features = X[:, self.num_common_features:] 86 | trained_features1 = trained_features[:, :trained_features.shape[1] // 2] 87 | trained_features2 = trained_features[:, trained_features.shape[1] // 2:] 88 | 89 | # append common features 90 | common_features = X[:, :self.num_common_features] 91 | noise_state = np.random.RandomState(self.seeds[2]) 92 | noised_common_features = common_features.copy() + noise_state.normal( 93 | scale=self.common_feature_noise_scale, size=common_features.shape) 94 | X1 = np.concatenate([trained_features1, common_features], axis=1) 95 | X2 = np.concatenate([noised_common_features, trained_features2], axis=1) 96 | 97 | assert X1.shape[1] + X2.shape[1] - self.num_common_features == self.X.shape[1] 98 | 99 | if use_cache: 100 | # refresh cached Xs 101 | self.Xs = [X1, X2] 102 | print("Done") 103 | return [X1, X2], y 104 | 105 | def to_pickle(self, save_path: str): 106 | with open(save_path, 'wb') as f: 107 | pickle.dump(self, f) 108 | 109 | @staticmethod 110 | def from_pickle(load_path: str): 111 | with open(load_path, 'rb') as f: 112 | return pickle.load(f) 113 | 114 | 115 | class ThreePartyLoader: 116 | def __init__(self, num_features, num_common_features: int, 117 | common_feature_noise_scale=0.0, data_fmt='libsvm', dataset_name=None, cache_path=None, 118 | n_classes=2, seed=0): 119 | """ 120 | :param cache_path: path for cache of the object 121 | :param dataset_name: name of the dataset 122 | :param num_features_per_party: number of features on both party, including common features 123 | :param num_common_features: number of common features 124 | """ 125 | self.cache_path = cache_path 126 | self.dataset_name = dataset_name 127 | self.data_fmt = data_fmt 128 | self.n_classes = n_classes 129 | self.common_feature_noise_scale = common_feature_noise_scale 130 | self.num_common_features = num_common_features 131 | self.num_features = 
num_features 132 | self.seeds = list(range(seed, seed + 3)) 133 | 134 | self.X = None 135 | self.y = None 136 | self.Xs = None 137 | 138 | def load_dataset(self, path=None, use_cache=True, scale_label=False): 139 | """ 140 | :param use_cache: whether to use cache 141 | :param path: path of the ml dataset 142 | :param scale_label: whether to scale back the label from [0,1] to int. True in covtype.scale01. 143 | :return: features, labels 144 | """ 145 | if use_cache and self.X is not None and self.y is not None: 146 | assert self.num_features == self.X.shape[1], "Total number of features mismatch." 147 | return self.X, self.y 148 | 149 | assert path is not None 150 | print("Loading {} dataset".format(self.dataset_name)) 151 | if inspect.isfunction(self.data_fmt): 152 | X, y = self.data_fmt(path) 153 | elif self.data_fmt == 'libsvm': 154 | X, y = load_svmlight_file(path) 155 | X = X.toarray() 156 | 157 | # hard code for a strange dataset whose labels are 1 & 2 158 | if self.dataset_name == 'covtype.binary': 159 | y -= 1 160 | elif self.data_fmt == 'csv': 161 | dataset = np.loadtxt(path, delimiter=',', skiprows=1) 162 | X = dataset[:, :-1] 163 | y = dataset[:, -1].reshape(-1) 164 | else: 165 | assert False, "Unsupported ML dataset format" 166 | 167 | if scale_label: 168 | y = np.rint(y * (self.n_classes - 1)).astype('int32') 169 | 170 | assert self.num_features == X.shape[1], "Total number of features mismatch." 171 | print("Done") 172 | if use_cache: 173 | self.X, self.y = X, y 174 | 175 | return X, y 176 | 177 | def load_parties(self, path=None, use_cache=True, scale_label=False): 178 | X, y = self.load_dataset(path, use_cache, scale_label) 179 | if use_cache and self.Xs is not None: 180 | print("Loading parties from cache") 181 | return self.Xs, self.y 182 | 183 | # assuming all the features are useful 184 | print("Splitting features to two parties") 185 | 186 | # randomly divide trained features to three parties 187 | shuffle_state = np.random.RandomState(self.seeds[0]) 188 | shuffle_state.shuffle(X.T) # shuffle columns 189 | trained_features = X[:, self.num_common_features:] 190 | trained_features1, trained_features2, trained_features3 = np.split(trained_features, 3, axis=1) 191 | 192 | # append common features 193 | common_features = X[:, :self.num_common_features] 194 | noise_state1 = np.random.RandomState(self.seeds[1]) 195 | noise_state2 = np.random.RandomState(self.seeds[2]) 196 | noised_common_features1 = common_features.copy() + noise_state1.normal( 197 | scale=self.common_feature_noise_scale, size=common_features.shape) 198 | noised_common_features2 = common_features.copy() + noise_state2.normal( 199 | scale=self.common_feature_noise_scale, size=common_features.shape) 200 | X1 = np.concatenate([trained_features1, common_features], axis=1) 201 | X2 = np.concatenate([noised_common_features1, trained_features2, noised_common_features2], axis=1) 202 | X3 = np.concatenate([noised_common_features2, trained_features3], axis=1) 203 | 204 | if use_cache: 205 | # refresh cached Xs 206 | self.Xs = [X1, X2, X3] 207 | print("Done") 208 | return [X1, X2, X3], y 209 | 210 | def to_pickle(self, save_path: str): 211 | with open(save_path, 'wb') as f: 212 | pickle.dump(self, f) 213 | 214 | @staticmethod 215 | def from_pickle(load_path: str): 216 | with open(load_path, 'rb') as f: 217 | return pickle.load(f) 218 | -------------------------------------------------------------------------------- /src/preprocess/nytaxi/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/preprocess/nytaxi/__init__.py -------------------------------------------------------------------------------- /src/preprocess/nytaxi/clean_citibike.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | 5 | import pandas as pd 6 | 7 | 8 | def fill_zero_padding(s: str): 9 | """ 10 | Fill a datetime string with zero paddings in days and months 11 | :param s: original datetime string with format '%-m/%-d/%Y %H:%M:%S' 12 | :return: 13 | """ 14 | month, day, other = s.split('/') 15 | month = "0" + month if len(month) == 1 else month 16 | day = "0" + day if len(day) == 1 else day 17 | return month + "/" + day + "/" + other 18 | 19 | 20 | def clean_bike(bike_ori_data_path, out_bike_data_path, sample_n=None): 21 | print("Reading from {}".format(bike_ori_data_path)) 22 | date_parser = lambda x: datetime.strptime(fill_zero_padding(x), '%m/%d/%Y %H:%M:%S') 23 | bike_ori_data = pd.read_csv(bike_ori_data_path, parse_dates=['starttime', 'stoptime'], 24 | date_parser=date_parser) 25 | 26 | print("Remove all nonsense data") 27 | bike_ori_data.dropna(inplace=True) 28 | bike_ori_data = bike_ori_data[bike_ori_data['tripduration'] < 2000] 29 | 30 | print("Remove useless features from dataset") 31 | bike_ori_data.drop(columns=['bikeid', 'usertype', 'start station name', 'end station name'], inplace=True) 32 | 33 | print("Get pick-up and drop-off hour") 34 | bike_ori_data['start_hour'] = bike_ori_data['starttime'].dt.hour 35 | bike_ori_data['end_hour'] = bike_ori_data['stoptime'].dt.hour 36 | 37 | print("Drop specific time information") 38 | bike_ori_data.drop(columns=['starttime', 'stoptime'], inplace=True) 39 | 40 | print("Rename columns") 41 | bike_ori_data.rename(columns={'start station id': 'start_id', 42 | 'end station id': 'end_id', 43 | 'start station longitude': 'start_lon', 44 | 'start station latitude': 'start_lat', 45 | 'end station longitude': 'end_lon', 46 | 'end station latitude': 'end_lat'}, inplace=True) 47 | 48 | print("Change birth year to age") 49 | bike_ori_data['age'] = bike_ori_data['birth year'].apply(lambda x: 2016 - x) 50 | bike_ori_data.drop(columns=['birth year'], inplace=True) 51 | 52 | print("Columns: " + str(bike_ori_data.columns)) 53 | 54 | out_bike_data = pd.get_dummies(bike_ori_data, 55 | columns=['gender', 'start_id', 'end_id'], 56 | prefix=['gender', 'sid', 'eid'], drop_first=True) 57 | 58 | print("sampling from dataset") 59 | if sample_n is not None: 60 | out_bike_data = out_bike_data.sample(n=sample_n, random_state=0) 61 | 62 | print("Saving cleaned dataset to {}".format(out_bike_data_path)) 63 | out_bike_data.to_pickle(out_bike_data_path) 64 | print("Saved {} samples to file".format(len(out_bike_data.index))) 65 | 66 | 67 | if __name__ == '__main__': 68 | os.chdir(sys.path[0] + "/../../../data/nytaxi") # change working directory 69 | clean_bike("201606-citibike-tripdata.csv", "bike_201606_clean_sample_2e5.pkl", sample_n=200000) 70 | -------------------------------------------------------------------------------- /src/preprocess/nytaxi/clean_tlc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pandas as pd 5 | 6 | 7 | def clean_tlc_for_airbnb(tlc_ori_data_path, out_tlc_data_path, sample_n=None, keep_col=None): 8 | 
print("Reading from {}".format(tlc_ori_data_path)) 9 | tlc_ori_data = pd.read_csv(tlc_ori_data_path, parse_dates=['tpep_pickup_datetime', 'tpep_dropoff_datetime']) 10 | 11 | print("get pick-up and drop-off hour") 12 | tlc_ori_data.drop(columns=['store_and_fwd_flag'], inplace=True) 13 | 14 | print("get pick-up and drop-off hour") 15 | tlc_ori_data['pickup_hour'] = tlc_ori_data['tpep_pickup_datetime'].dt.hour 16 | tlc_ori_data['dropoff_hour'] = tlc_ori_data['tpep_dropoff_datetime'].dt.hour 17 | 18 | print("drop specific time information") 19 | tlc_ori_data.drop(columns=['tpep_pickup_datetime', 'tpep_dropoff_datetime'], inplace=True) 20 | 21 | print("divide pickup and dropoff dataset") 22 | tlc_ori_data_pickup = tlc_ori_data.drop(columns=['dropoff_hour', 'dropoff_longitude', 'dropoff_latitude']) 23 | tlc_ori_data_pickup['is_pickup'] = 1 24 | tlc_ori_data_pickup.rename(columns={'pickup_hour': 'hour', 25 | 'pickup_longitude': 'lon', 26 | 'pickup_latitude': 'lat'}, inplace=True) 27 | tlc_ori_data_dropoff = tlc_ori_data.drop(columns=['pickup_hour', 'pickup_longitude', 'pickup_latitude']) 28 | tlc_ori_data_dropoff.rename(columns={'dropoff_hour': 'hour', 29 | 'dropoff_longitude': 'lon', 30 | 'dropoff_latitude': 'lat'}, inplace=True) 31 | tlc_ori_data_dropoff['is_pickup'] = 0 32 | 33 | print("concat pickup and dropoff dataset by rows") 34 | out_tlc_data = pd.concat([tlc_ori_data_pickup, tlc_ori_data_dropoff]) 35 | print("Finished, print all the columns:") 36 | print(out_tlc_data.dtypes) 37 | 38 | if keep_col is None: 39 | print("make categorical features one-hot") 40 | out_tlc_data = pd.get_dummies(out_tlc_data, 41 | columns=['hour', 'VendorID', 'RatecodeID', 'payment_type'], 42 | prefix=['hr', 'vid', 'rid', 'pt'], drop_first=True) 43 | else: 44 | print("Filter columns {}".format(keep_col)) 45 | out_tlc_data = out_tlc_data[keep_col + ['lon', 'lat']] 46 | print("make categorical features one-hot") 47 | dummy_col, dummy_prefix = [], [] 48 | col_prefix = { 49 | 'hour': 'hr', 50 | 'VendorID': 'vid', 51 | 'RatecodeID': 'rid', 52 | 'payment_type': 'pt' 53 | } 54 | for col, prefix in col_prefix.items(): 55 | if col in out_tlc_data.columns: 56 | dummy_col.append(col) 57 | dummy_prefix.append(prefix) 58 | out_tlc_data = pd.get_dummies(out_tlc_data, columns=dummy_col, prefix=dummy_prefix, drop_first=True) 59 | 60 | print("sampling from dataset") 61 | if sample_n is not None: 62 | out_tlc_data = out_tlc_data.sample(n=sample_n, random_state=0) 63 | 64 | print("Saving cleaned dataset to {}".format(out_tlc_data_path)) 65 | out_tlc_data.to_csv(out_tlc_data_path, index=False) 66 | print("Saved {} samples to file".format(len(out_tlc_data.index))) 67 | 68 | 69 | def clean_tlc_for_bike(tlc_ori_data_path, out_tlc_data_path, sample_n=None): 70 | print("Reading from {}".format(tlc_ori_data_path)) 71 | tlc_ori_data = pd.read_csv(tlc_ori_data_path, parse_dates=['tpep_pickup_datetime', 'tpep_dropoff_datetime']) 72 | 73 | print("Drop values that are not reasonable") 74 | tlc_ori_data.dropna(inplace=True) 75 | tlc_ori_data = tlc_ori_data[tlc_ori_data['trip_distance'] > 0] 76 | tlc_ori_data = tlc_ori_data[tlc_ori_data['trip_distance'] < 10] 77 | 78 | print("get duration of the trip") 79 | tlc_ori_data['taxi_duration'] = (tlc_ori_data['tpep_dropoff_datetime'] 80 | - tlc_ori_data['tpep_pickup_datetime']).astype('timedelta64[s]') 81 | tlc_ori_data = tlc_ori_data[tlc_ori_data['taxi_duration'] > 0] 82 | tlc_ori_data = tlc_ori_data[tlc_ori_data['taxi_duration'] < 10000] 83 | 84 | print("get pick-up and drop-off hour") 85 | 
tlc_ori_data['start_hour'] = tlc_ori_data['tpep_pickup_datetime'].dt.hour 86 | tlc_ori_data['end_hour'] = tlc_ori_data['tpep_dropoff_datetime'].dt.hour 87 | 88 | print("drop specific time information") 89 | tlc_ori_data.drop(columns=['tpep_pickup_datetime', 'tpep_dropoff_datetime'], inplace=True) 90 | 91 | print("divide pickup and dropoff dataset") 92 | tlc_ori_data.rename(columns={'pickup_longitude': 'start_lon', 93 | 'pickup_latitude': 'start_lat', 94 | 'dropoff_longitude': 'end_lon', 95 | 'dropoff_latitude': 'end_lat'}, inplace=True) 96 | 97 | print("Drop useless features") 98 | out_tlc_data = tlc_ori_data[['start_lon', 'start_lat', 'end_lon', 'end_lat', 99 | 'start_hour', 'end_hour', 'trip_distance', 'taxi_duration']] 100 | 101 | print("sampling from dataset") 102 | if sample_n is not None: 103 | out_tlc_data = out_tlc_data.sample(n=sample_n, random_state=0) 104 | 105 | print("Saving cleaned dataset to {}".format(out_tlc_data_path)) 106 | out_tlc_data.to_pickle(out_tlc_data_path) 107 | print("Saved {} samples to file".format(len(out_tlc_data.index))) 108 | 109 | 110 | if __name__ == '__main__': 111 | os.chdir(sys.path[0] + "/../../../data/nytaxi") # change working directory 112 | # clean_tlc("yellow_tripdata_2016-06.csv", "taxi_201606_clean.csv", sample_n=None) 113 | # clean_tlc_for_airbnb("yellow_tripdata_2016-06.csv", "taxi_201606_clean_sample_1e6.csv", 114 | # sample_n=1000000, keep_col=['RatecodeID', 'tip_amount']) 115 | clean_tlc_for_bike("yellow_tripdata_2016-06.csv", "taxi_201606_clean_sample_1e5.pkl", 116 | sample_n=100000) 117 | -------------------------------------------------------------------------------- /src/preprocess/nytaxi/filter_kaggle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pandas as pd 5 | 6 | 7 | def filter_kaggle(kaggle_train_path, kaggle_out_path): 8 | print("Start filtering") 9 | 10 | print("Loading training data from csv files") 11 | kaggle_train = pd.read_csv(kaggle_train_path, index_col=0, parse_dates=['key']) 12 | 13 | print("Filtering data") 14 | filtered_train = kaggle_train.loc['2009-01-01': '2009-01-31'] 15 | print("Finished filtering training set, got {} samples".format(len(filtered_train.index))) 16 | 17 | print("Saving the filtered data") 18 | filtered_train.to_csv(kaggle_out_path) 19 | print("Done") 20 | 21 | 22 | if __name__ == '__main__': 23 | os.chdir(sys.path[0] + "/../../../data/nytaxi") # change working directory 24 | filter_kaggle("kaggle_train_ori.csv", "kaggle_data.csv") 25 | -------------------------------------------------------------------------------- /src/preprocess/nytaxi/ny_loader.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from src.utils.BasicUtils import move_item_to_start_, move_item_to_end_ 4 | 5 | 6 | class NYAirbnbTaxiLoader: 7 | def __init__(self, airbnb_path, taxi_path=None, link=False): 8 | print("Loading airbnb from {}".format(airbnb_path)) 9 | self.airbnb_data = pd.read_csv(airbnb_path) 10 | print("Loaded.") 11 | if taxi_path is not None: 12 | print("Loading taxi from {}".format(taxi_path)) 13 | self.taxi_data = pd.read_csv(taxi_path) 14 | print("Loaded.") 15 | 16 | if link: 17 | self.labels = self.airbnb_data['price'].to_numpy() 18 | self.airbnb_data.drop(columns=['price'], inplace=True) 19 | 20 | # move lon and lat to end of airbnb 21 | ab_cols = list(self.airbnb_data) 22 | ab_cols.insert(len(ab_cols), ab_cols.pop(ab_cols.index('longitude'))) 23 | 
ab_cols.insert(len(ab_cols), ab_cols.pop(ab_cols.index('latitude'))) 24 | self.airbnb_data = self.airbnb_data[ab_cols] 25 | print("Current airbnb columns: " + str(list(self.airbnb_data))) 26 | self.airbnb_data = self.airbnb_data.to_numpy() 27 | 28 | # move lon and lat to the front of taxi 29 | tx_cols = list(self.taxi_data) 30 | tx_cols.insert(0, tx_cols.pop(tx_cols.index('lat'))) 31 | tx_cols.insert(0, tx_cols.pop(tx_cols.index('lon'))) 32 | self.taxi_data = self.taxi_data[tx_cols] 33 | print("Current taxi columns: " + str(list(self.taxi_data))) 34 | self.taxi_data = self.taxi_data.to_numpy() 35 | else: 36 | self.airbnb_data.drop(columns=['longitude', 'latitude'], inplace=True) 37 | self.labels = self.airbnb_data['price'].to_numpy() 38 | self.airbnb_data = self.airbnb_data.drop(columns=['price']).to_numpy() 39 | 40 | def load_single(self): 41 | return self.airbnb_data, self.labels 42 | 43 | def load_parties(self): 44 | return [self.airbnb_data, self.taxi_data], self.labels 45 | 46 | 47 | class NYBikeTaxiLoader: 48 | def __init__(self, bike_path, taxi_path=None, link=False): 49 | print("Loading bike from {}".format(bike_path)) 50 | self.bike_data = pd.read_pickle(bike_path) 51 | # self.bike_data = self.bike_data.head(10000) 52 | # print("Remove N/A from bike") 53 | # self.bike_data.dropna() 54 | print("Loaded.") 55 | if taxi_path is not None: 56 | print("Loading taxi from {}".format(taxi_path)) 57 | self.taxi_data = pd.read_pickle(taxi_path) 58 | print("Loaded.") 59 | 60 | if link: 61 | self.labels = self.bike_data['tripduration'].to_numpy() 62 | self.bike_data.drop(columns=['tripduration'], inplace=True) 63 | 64 | # move lon and lat to end of airbnb 65 | bike_cols = list(self.bike_data) 66 | move_item_to_start_(bike_cols, ['start_lon', 'start_lat', 'end_lon', 'end_lat', 67 | 'start_hour', 'end_hour']) 68 | self.bike_data = self.bike_data[bike_cols] 69 | self.bike_data.drop(columns=['start_hour', 'end_hour'], inplace=True) 70 | print("Current bike columns: " + str(list(self.bike_data))) 71 | self.bike_data = self.bike_data.to_numpy() 72 | 73 | # move lon and lat to the front of taxi 74 | tx_cols = list(self.taxi_data) 75 | move_item_to_start_(tx_cols, ['start_lon', 'start_lat', 'end_lon', 'end_lat', 76 | 'start_hour', 'end_hour']) 77 | self.taxi_data = self.taxi_data[tx_cols] 78 | self.taxi_data.drop(columns=['start_hour', 'end_hour'], inplace=True) 79 | print("Current taxi columns: " + str(list(self.taxi_data))) 80 | self.taxi_data = self.taxi_data.to_numpy() 81 | else: 82 | print("Remove columns that are used for linkage") 83 | self.bike_data.drop(columns=['start_lon', 'start_lat', 'end_lon', 'end_lat', 84 | 'start_hour', 'end_hour'], inplace=True) 85 | print('Extract labels') 86 | self.labels = self.bike_data['tripduration'].to_numpy() 87 | print("Extract data") 88 | self.bike_data = self.bike_data.drop(columns=['tripduration']).to_numpy() 89 | 90 | def load_single(self): 91 | return self.bike_data, self.labels 92 | 93 | def load_parties(self): 94 | return [self.bike_data, self.taxi_data], self.labels 95 | -------------------------------------------------------------------------------- /src/preprocess/split-bias.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Split VFL dataset using vertibench under different alpha values 4 | for p in 10 20; do 5 | for ns in 0.05; do 6 | for a in 0.1 0.5 1.0 5.0 10.0 50.0; do 7 | python src/preprocess/FuzzySplitter.py -d $1 -p $p -kd 4 -ns $ns -a $a & 8 | done 9 | done 10 | done 11 | 
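Usage note (a minimal sketch, not taken from the repository): split-bias.sh forwards its first positional argument to FuzzySplitter.py as the -d dataset name while sweeping party counts (-p), key noise (-ns), and alpha (-a). The invocation below assumes the script is run from the repository root and that the dataset name follows the same <name>.libsvm form used by split_scale.sh; gisette.libsvm is only an illustrative choice.

# hypothetical invocation; dataset name and working directory are assumptions
bash src/preprocess/split-bias.sh gisette.libsvm

Because every FuzzySplitter.py call is backgrounded and the script has no trailing wait, the command returns before the splits are written.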
-------------------------------------------------------------------------------- /src/preprocess/vsplit.py: -------------------------------------------------------------------------------- 1 | import wget 2 | import os 3 | import vertibench 4 | import pandas as pd 5 | from sklearn.datasets import load_svmlight_file 6 | from sklearn.preprocessing import MinMaxScaler 7 | 8 | from vertibench.Splitter import ImportanceSplitter 9 | 10 | os.path.join(os.path.dirname(__file__), '..') 11 | os.path.join(os.path.dirname(__file__), '..', '..') 12 | 13 | from src.dataset.LocalDataset import LocalDataset 14 | from src.dataset.VFLRealDataset import VFLRealDataset 15 | from src.utils.BasicUtils import PartyPath 16 | 17 | # check if the data is already downloaded 18 | syn_root = "data/syn/" 19 | dataset_paths = { 20 | # 'covtype': 'covtype.libsvm', 21 | # 'gisette': 'gisette.libsvm', 22 | # 'letter': 'letter.libsvm', 23 | # 'radar': 'radar.csv', 24 | 'realsim': 'realsim.libsvm', 25 | } 26 | 27 | for dataset, filename in dataset_paths.items(): 28 | data_path = os.path.join(syn_root, dataset, filename) 29 | if not os.path.exists(data_path): 30 | raise FileNotFoundError(f"{data_path} not found. Please download the data first.") 31 | 32 | if filename.endswith('.libsvm'): 33 | X, y = load_svmlight_file(data_path) 34 | X = X.toarray() 35 | if dataset in ['gisette']: 36 | # gisette is a binary classification dataset with labels in {-1, 1} 37 | scaler = MinMaxScaler((0, 1)) 38 | y = scaler.fit_transform(y.reshape(-1, 1)).flatten() 39 | elif filename.endswith('.csv'): 40 | df = pd.read_csv(data_path) 41 | X = df.values[:, :-1] 42 | y = df.values[:, -1] 43 | else: 44 | raise NotImplementedError(f"Unknown file format {filename}") 45 | 46 | # split the data 47 | for n_parties in [2, 4, 8]: 48 | for alpha in [0.1, 1, 10, 100]: 49 | splitter = ImportanceSplitter(num_parties=n_parties, weights=alpha, seed=0) 50 | Xs = splitter.split(X) 51 | 52 | # save the data 53 | for i in range(n_parties): 54 | party_path = PartyPath(dataset_path=dataset, n_parties=n_parties, party_id=i, 55 | splitter='imp', weight=alpha, beta=0, seed=0, fmt='pkl').data(None) 56 | party_full_path = os.path.join(syn_root, dataset, party_path) 57 | LocalDataset(Xs[i], y).to_pickle(party_full_path) 58 | print(f"Saved {party_full_path}") 59 | 60 | -------------------------------------------------------------------------------- /src/privacy/GaussianMechanism.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/BorjaBalle/analytic-gaussian-mechanism/blob/master/agm-example.py 3 | Under the Apache License 2.0 4 | """ 5 | 6 | from math import exp, sqrt 7 | from scipy.special import erf 8 | 9 | 10 | def calibrateAnalyticGaussianMechanism(epsilon, delta, GS, tol=1.e-12): 11 | """ Calibrate a Gaussian perturbation for differential privacy using the analytic Gaussian mechanism of [Balle and Wang, ICML'18] 12 | 13 | Arguments: 14 | epsilon : target epsilon (epsilon > 0) 15 | delta : target delta (0 < delta < 1) 16 | GS : upper bound on L2 global sensitivity (GS >= 0) 17 | tol : error tolerance for binary search (tol > 0) 18 | 19 | Output: 20 | sigma : standard deviation of Gaussian noise needed to achieve (epsilon,delta)-DP under global sensitivity GS 21 | """ 22 | 23 | def Phi(t): 24 | return 0.5 * (1.0 + erf(float(t) / sqrt(2.0))) 25 | 26 | def caseA(epsilon, s): 27 | return Phi(sqrt(epsilon * s)) - exp(epsilon) * Phi(-sqrt(epsilon * (s + 2.0))) 28 | 29 | def caseB(epsilon, s): 30 | return 
Phi(-sqrt(epsilon * s)) - exp(epsilon) * Phi(-sqrt(epsilon * (s + 2.0))) 31 | 32 | def doubling_trick(predicate_stop, s_inf, s_sup): 33 | while (not predicate_stop(s_sup)): 34 | s_inf = s_sup 35 | s_sup = 2.0 * s_inf 36 | return s_inf, s_sup 37 | 38 | def binary_search(predicate_stop, predicate_left, s_inf, s_sup): 39 | s_mid = s_inf + (s_sup - s_inf) / 2.0 40 | while (not predicate_stop(s_mid)): 41 | if (predicate_left(s_mid)): 42 | s_sup = s_mid 43 | else: 44 | s_inf = s_mid 45 | s_mid = s_inf + (s_sup - s_inf) / 2.0 46 | return s_mid 47 | 48 | delta_thr = caseA(epsilon, 0.0) 49 | 50 | if (delta == delta_thr): 51 | alpha = 1.0 52 | 53 | else: 54 | if (delta > delta_thr): 55 | predicate_stop_DT = lambda s: caseA(epsilon, s) >= delta 56 | function_s_to_delta = lambda s: caseA(epsilon, s) 57 | predicate_left_BS = lambda s: function_s_to_delta(s) > delta 58 | function_s_to_alpha = lambda s: sqrt(1.0 + s / 2.0) - sqrt(s / 2.0) 59 | 60 | else: 61 | predicate_stop_DT = lambda s: caseB(epsilon, s) <= delta 62 | function_s_to_delta = lambda s: caseB(epsilon, s) 63 | predicate_left_BS = lambda s: function_s_to_delta(s) < delta 64 | function_s_to_alpha = lambda s: sqrt(1.0 + s / 2.0) + sqrt(s / 2.0) 65 | 66 | predicate_stop_BS = lambda s: abs(function_s_to_delta(s) - delta) <= tol 67 | 68 | s_inf, s_sup = doubling_trick(predicate_stop_DT, 0.0, 1.0) 69 | s_final = binary_search(predicate_stop_BS, predicate_left_BS, s_inf, s_sup) 70 | alpha = function_s_to_alpha(s_final) 71 | 72 | sigma = alpha * GS / sqrt(2.0 * epsilon) 73 | 74 | return sigma -------------------------------------------------------------------------------- /src/privacy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/privacy/__init__.py -------------------------------------------------------------------------------- /src/script/ablation_dm_or_not.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(1 3 4) 5 | num_gpus=${#gpus[@]} 6 | cnt=0 7 | 8 | pdo=0.6 9 | 10 | dataset=gisette 11 | outdir=out/ablation-dm-or-not/$dataset 12 | mkdir -p $outdir 13 | for seed in 0 1 2 3 4; do 14 | for noise in 0.05; do 15 | # without dm 16 | gpu=${gpus[$cnt]} 17 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 2 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 18 | -e 50 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo --disable-dm \ 19 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}_no_dm.log & 20 | cnt=$((cnt+1)) 21 | if [ $cnt -eq "$num_gpus" ]; then 22 | wait 23 | cnt=0 24 | fi 25 | 26 | # with dm 27 | gpu=${gpus[$cnt]} 28 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 2 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 29 | -e 50 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo \ 30 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}.log & 31 | cnt=$((cnt+1)) 32 | if [ $cnt -eq "$num_gpus" ]; then 33 | wait 34 | cnt=0 35 | fi 36 | done 37 | done 38 | 39 | 40 | dataset=mnist 41 | outdir=out/ablation-dm-or-not/$dataset 42 | mkdir -p $outdir 43 | for seed in 0 1 2 3 4; do 44 | for noise in 0.05; do 45 | # disable dm 46 | gpu=${gpus[$cnt]} 47 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 10 -p 50 --key-noise $noise -s $seed 
-g "$gpu" \ 48 | -e 30 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo --disable-dm \ 49 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}_no_dm.log & 50 | cnt=$((cnt+1)) 51 | if [ $cnt -eq "$num_gpus" ]; then 52 | wait 53 | cnt=0 54 | fi 55 | 56 | gpu=${gpus[$cnt]} 57 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 10 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 58 | -e 30 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo \ 59 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}.log & 60 | cnt=$((cnt+1)) 61 | if [ $cnt -eq "$num_gpus" ]; then 62 | wait 63 | cnt=0 64 | fi 65 | done 66 | done 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /src/script/ablation_keynoise.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(1 2 3 4 5 6) 5 | num_gpus=${#gpus[@]} 6 | cnt=0 7 | party=50 8 | 9 | dataset=gisette 10 | outdir=out/ablation-keynoise/$dataset 11 | mkdir -p $outdir 12 | for seed in 0 1 2 3 4; do 13 | for noise in 0.02 0.05 0.06 0.07 0.1; do 14 | for knnk in 100; do 15 | gpu=${gpus[$cnt]} 16 | python src/script/train_fet.py -d $dataset -m acc -c 2 -p $party --key-noise $noise -s $seed -g "$gpu" \ 17 | -e 50 -v 6 --knn-k ${knnk} -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout 0.6 \ 18 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 19 | cnt=$((cnt+1)) 20 | if [ $cnt -eq "$num_gpus" ]; then 21 | wait 22 | cnt=0 23 | fi 24 | done 25 | done 26 | done 27 | 28 | 29 | dataset=mnist 30 | outdir=out/ablation-keynoise/$dataset 31 | mkdir -p $outdir 32 | for seed in 0 1 2 3 4; do 33 | for noise in 0.02 0.05 0.06 0.07 0.1; do 34 | for knnk in 100; do 35 | gpu=${gpus[$cnt]} 36 | python src/script/train_fet.py -d $dataset -m acc -c 10 -p $party --key-noise $noise -s $seed -g "$gpu" \ 37 | -e 30 -v 6 --knn-k ${knnk} -w 100 -nh 4 -ded 200 -ked 200 -nlb 6 -nab 6 --dropout 0.0 -paf 1 --party-dropout 0.6 \ 38 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 39 | cnt=$((cnt+1)) 40 | if [ $cnt -eq "$num_gpus" ]; then 41 | wait 42 | cnt=0 43 | fi 44 | done 45 | done 46 | done 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/script/ablation_keynoise_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(1 2 3 4 5 6 7) 5 | process_per_gpu=2 6 | num_gpus=${#gpus[@]} 7 | cnt=0 8 | party=50 9 | 10 | dataset=gisette 11 | outdir=out/ablation-keynoise/$dataset 12 | mkdir -p $outdir 13 | for seed in 0 1 2 3 4; do 14 | for noise in 0.02 0.05 0.06 0.07 0.1; do 15 | for knnk in 100; do 16 | gpu=${gpus[$cnt]} 17 | python src/script/train_solo.py -d $dataset -m acc -c 2 -p $party -sp imp -w 100 -s $seed -g $gpu --key-noise $noise \ 18 | > "$outdir"/"$dataset"_solo_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 19 | cnt=$((cnt+1)) 20 | if [ $cnt -eq "$num_gpus" ]; then 21 | wait 22 | cnt=0 23 | fi 24 | 25 | gpu=${gpus[$cnt]} 26 | python src/script/train_top1.py -d $dataset -m acc -c 2 -p $party -sp imp -w 100 -s $seed -g $gpu -v 6 --knn-k 1 -nh 1 \ 27 | -ded 50 -ked 50 -nlb 1 -nab 1 --dropout 0.0 --key-noise $noise \ 28 | > 
"$outdir"/"$dataset"_top1_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 29 | cnt=$((cnt+1)) 30 | if [ $cnt -eq "$num_gpus" ]; then 31 | wait 32 | cnt=0 33 | fi 34 | done 35 | done 36 | done 37 | 38 | 39 | dataset=mnist 40 | outdir=out/ablation-keynoise/$dataset 41 | mkdir -p $outdir 42 | for seed in 0 1 2 3 4; do 43 | for noise in 0.02 0.05 0.06 0.07 0.1; do 44 | for knnk in 100; do 45 | gpu=${gpus[$cnt]} 46 | python src/script/train_solo.py -d $dataset -e 50 -m acc -c 10 -p $party -sp imp -w 100 -s $seed -g $gpu --key-noise $noise \ 47 | > "$outdir"/"$dataset"_solo_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 48 | cnt=$((cnt+1)) 49 | if [ $cnt -eq "$num_gpus" ]; then 50 | wait 51 | cnt=0 52 | fi 53 | done 54 | done 55 | done 56 | 57 | 58 | batch=0 59 | dataset=mnist 60 | outdir=out/ablation-keynoise/$dataset 61 | mkdir -p $outdir 62 | for seed in 0 1 2 3 4; do 63 | for noise in 0.02 0.05 0.06 0.07 0.1; do 64 | for knnk in 100; do 65 | if [ $fcnt -lt $finish ]; then 66 | fcnt=$((fcnt+1)) # skip the first 14 67 | continue 68 | fi 69 | 70 | gpu=${gpus[$cnt]} 71 | python src/script/train_top1.py -d $dataset -e 30 -m acc -c 10 -p $party -sp imp -w 100 -s $seed -g $gpu -v 6 --knn-k 1 -nh 1 \ 72 | -ded 50 -ked 50 -nlb 1 -nab 1 --dropout 0.0 --key-noise $noise \ 73 | > "$outdir"/"$dataset"_top1_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 74 | cnt=$((cnt+1)) 75 | if [ $cnt -eq "$num_gpus" ]; then 76 | batch=$((batch+1)) 77 | if [ $batch -eq "$process_per_gpu" ]; then 78 | wait 79 | batch=0 80 | fi 81 | cnt=0 82 | fi 83 | done 84 | done 85 | done 86 | 87 | 88 | -------------------------------------------------------------------------------- /src/script/ablation_knnk.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(0 1 2 3) 5 | num_gpus=${#gpus[@]} 6 | cnt=0 7 | party=20 8 | 9 | dataset=gisette 10 | outdir=out/ablation-knnk/$dataset 11 | mkdir -p $outdir 12 | for seed in 0 1 2 3 4; do 13 | for noise in 0.05; do 14 | for knnk in 1 5 10 20 40 60 80 100; do 15 | gpu=${gpus[$cnt]} 16 | python src/script/train_fet.py -d $dataset -m acc -c 2 -p $party --key-noise $noise -s $seed -g "$gpu" \ 17 | -e 50 -v 6 --knn-k ${knnk} -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout 0.6 \ 18 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 19 | cnt=$((cnt+1)) 20 | if [ $cnt -eq "$num_gpus" ]; then 21 | wait 22 | cnt=0 23 | fi 24 | done 25 | done 26 | done 27 | 28 | 29 | dataset=mnist 30 | outdir=out/ablation-knnk/$dataset 31 | mkdir -p $outdir 32 | for seed in 0 1 2 3 4; do 33 | for noise in 0.05; do 34 | for knnk in 1 5 10 20 40 60 80 100; do 35 | gpu=${gpus[$cnt]} 36 | python src/script/train_fet.py -d $dataset -m acc -c 10 -p $party --key-noise $noise -s $seed -g "$gpu" \ 37 | -e 50 -v 6 --knn-k ${knnk} -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout 0.6 \ 38 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 39 | cnt=$((cnt+1)) 40 | if [ $cnt -eq "$num_gpus" ]; then 41 | wait 42 | cnt=0 43 | fi 44 | done 45 | done 46 | done 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/script/ablation_knnk_real.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(0 1 2 3 4 5 6 7) 5 | num_gpus=${#gpus[@]} 6 | cnt=0 7 | party=2 8 | 9 | dataset=house 10 | 
outdir=out/ablation-knnk/$dataset 11 | mkdir -p $outdir 12 | for seed in 0 1 2 3 4; do 13 | for knnk in 1 5 10 20 40 60 80 100; do 14 | gpu=${gpus[$cnt]} 15 | python src/script/train_fet.py -d $dataset -m rmse -c 1 -p $party -s $seed -g "$gpu" \ 16 | -e 100 -v 6 --knn-k ${knnk} -nh 1 -ded 100 -ked 100 -nlb 3 -nab 3 --dropout 0.3 -paf 1 \ 17 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_s${seed}.log & 18 | cnt=$((cnt+1)) 19 | if [ $cnt -eq "$num_gpus" ]; then 20 | wait 21 | cnt=0 22 | fi 23 | done 24 | done 25 | 26 | 27 | dataset=taxi 28 | outdir=out/ablation-knnk/$dataset 29 | mkdir -p $outdir 30 | for seed in 0 1 2 3 4; do 31 | for knnk in 1 5 10 20 40 60 80 100; do 32 | gpu=${gpus[$cnt]} 33 | python src/script/train_fet.py -d $dataset -m rmse -c 1 -p $party -s $seed -g "$gpu" \ 34 | -e 100 -v 6 --knn-k ${knnk} -nh 1 -ded 100 -ked 100 -nlb 3 -nab 3 --dropout 0.3 -paf 1 \ 35 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_s${seed}.log & 36 | cnt=$((cnt+1)) 37 | if [ $cnt -eq "$num_gpus" ]; then 38 | wait 39 | cnt=0 40 | fi 41 | done 42 | done 43 | 44 | dataset=hdb 45 | outdir=out/ablation-knnk/$dataset 46 | mkdir -p $outdir 47 | for seed in 0 1 2 3 4; do 48 | for knnk in 1 5 10 20 40 60 80 100; do 49 | gpu=${gpus[$cnt]} 50 | python src/script/train_fet.py -d $dataset -m rmse -c 1 -p $party -s $seed -g "$gpu" \ 51 | -e 100 -v 6 --knn-k ${knnk} -nh 1 -ded 100 -ked 100 -nlb 1 -nab 2 --dropout 0.1 -paf 1 \ 52 | > "$outdir"/"$dataset"_fet_p${party}_k${knnk}_noise${noise}_pdo0_s${seed}.log & 53 | cnt=$((cnt+1)) 54 | if [ $cnt -eq "$num_gpus" ]; then 55 | wait 56 | cnt=0 57 | fi 58 | done 59 | done 60 | 61 | -------------------------------------------------------------------------------- /src/script/ablation_party_dropout.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(2 3 4 5 6 7) 5 | num_gpus=${#gpus[@]} 6 | cnt=0 7 | 8 | dataset=gisette 9 | outdir=out/ablation-party-dropout/$dataset 10 | mkdir -p $outdir 11 | for seed in 0 1 2 3 4; do 12 | for noise in 0.05; do 13 | for pdo in 0 0.2 0.4 0.6 0.8 1.0; do 14 | gpu=${gpus[$cnt]} 15 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 2 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 16 | -e 50 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 2 --party-dropout $pdo \ 17 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}.log & 18 | cnt=$((cnt+1)) 19 | if [ $cnt -eq "$num_gpus" ]; then 20 | wait 21 | cnt=0 22 | fi 23 | done 24 | done 25 | done 26 | 27 | 28 | dataset=mnist 29 | outdir=out/ablation-party-dropout/$dataset 30 | mkdir -p $outdir 31 | for seed in 0 1 2 3 4; do 32 | for noise in 0.05; do 33 | for pdo in 0 0.2 0.4 0.6 0.8 1.0; do 34 | gpu=${gpus[$cnt]} 35 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 10 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 36 | -e 30 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 2 --party-dropout $pdo \ 37 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}.log & 38 | cnt=$((cnt+1)) 39 | if [ $cnt -eq "$num_gpus" ]; then 40 | wait 41 | cnt=0 42 | fi 43 | done 44 | done 45 | done 46 | 47 | 48 | -------------------------------------------------------------------------------- /src/script/ablation_pe_average_freq.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | gpus=(2 3 4 5 6 7) 4 | num_gpus=${#gpus[@]} 5 | cnt=0 6 | k=100 7 | 8 
| dataset=gisette 9 | folder=out/ablation-pe-avg-freq/$dataset 10 | mkdir -p $folder 11 | for noise in 0.05; do 12 | for seed in 0 1 2 3 4; do 13 | for paf in 0 1 2 3 5 10; do 14 | for np in 2 5 20 50; do 15 | gpu=${gpus[$cnt]} 16 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 2 -p $np --key-noise $noise -s $seed -v 6 --knn-k $k \ 17 | -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 -e 50 --dropout 0.0 -paf $paf -g "$gpu" \ 18 | > $folder/${dataset}_fet_p${np}_k${k}_noise${noise}_paf${paf}_s${seed}.log & 19 | cnt=$((cnt+1)) 20 | if [ $cnt -eq "$num_gpus" ]; then 21 | wait 22 | cnt=0 23 | fi 24 | done 25 | done 26 | done 27 | done 28 | 29 | 30 | dataset=mnist 31 | folder=out/ablation-pe-avg-freq/$dataset 32 | mkdir -p $folder 33 | for np in 2 5 20 50; do 34 | for noise in 0.05; do 35 | for seed in 0 1 2 3 4; do 36 | for paf in 0 1 2 3 5 10; do 37 | gpu=${gpus[$cnt]} 38 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 10 -p $np --key-noise $noise -s $seed -v 6 --knn-k $k \ 39 | -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 -e 30 --dropout 0.0 -paf $paf -g "$gpu" \ 40 | > ${folder}/${dataset}_fet_p${np}_k${k}_noise${noise}_paf${paf}_s${seed}.log & 41 | cnt=$((cnt+1)) 42 | if [ $cnt -eq "$num_gpus" ]; then 43 | wait 44 | cnt=0 45 | fi 46 | done 47 | done 48 | done 49 | done 50 | -------------------------------------------------------------------------------- /src/script/ablation_pe_or_not.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | gpus=(2 3 4 5 6 7) 5 | num_gpus=${#gpus[@]} 6 | cnt=0 7 | 8 | pdo=0.6 9 | 10 | dataset=gisette 11 | outdir=out/ablation-pe-or-not/$dataset 12 | mkdir -p $outdir 13 | for seed in 0 1 2 3 4; do 14 | for noise in 0.05; do 15 | # without pe 16 | gpu=${gpus[$cnt]} 17 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 2 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 18 | -e 50 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo --disable-pe \ 19 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}_no_pe.log & 20 | cnt=$((cnt+1)) 21 | if [ $cnt -eq "$num_gpus" ]; then 22 | wait 23 | cnt=0 24 | fi 25 | 26 | # with pe 27 | gpu=${gpus[$cnt]} 28 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 2 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 29 | -e 50 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo \ 30 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}.log & 31 | cnt=$((cnt+1)) 32 | if [ $cnt -eq "$num_gpus" ]; then 33 | wait 34 | cnt=0 35 | fi 36 | done 37 | done 38 | 39 | 40 | dataset=mnist 41 | outdir=out/ablation-pe-or-not/$dataset 42 | mkdir -p $outdir 43 | for seed in 0 1 2 3 4; do 44 | for noise in 0.05; do 45 | # disable pe 46 | gpu=${gpus[$cnt]} 47 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 10 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 48 | -e 30 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo --disable-pe \ 49 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}_no_pe.log & 50 | cnt=$((cnt+1)) 51 | if [ $cnt -eq "$num_gpus" ]; then 52 | wait 53 | cnt=0 54 | fi 55 | 56 | gpu=${gpus[$cnt]} 57 | taskset -c 0-55 python src/script/train_fet.py -d $dataset -m acc -c 10 -p 50 --key-noise $noise -s $seed -g "$gpu" \ 58 | -e 30 -v 6 --knn-k 100 -w 100 -nh 1 -ded 200 -ked 
200 -nlb 1 -nab 1 --dropout 0.0 -paf 1 --party-dropout $pdo \ 59 | > "$outdir"/"$dataset"_fet_p50_k100_noise${noise}_pdo${pdo}_s${seed}.log & 60 | cnt=$((cnt+1)) 61 | if [ $cnt -eq "$num_gpus" ]; then 62 | wait 63 | cnt=0 64 | fi 65 | done 66 | done 67 | 68 | 69 | -------------------------------------------------------------------------------- /src/script/ablation_real_dm_or_not.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=out/real/ 4 | mkdir -p $outdir 5 | 6 | for seed in 0 1 2 3 4; do 7 | python src/script/train_fet.py -d house -m rmse -c 1 -p 2 -s $seed -v 6 --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 3 -nab 3 -paf 1 --dropout 0.3 -g 4 --disable-dm > \ 8 | ${outdir}/house_seed${seed}_nodm.log & 9 | python src/script/train_fet.py -d taxi -m rmse -c 1 -p 2 -s $seed -v 6 -e 50 -lr 3e-4 --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 3 -nab 3 -paf 1 --dropout 0.3 -g 7 --disable-dm > \ 10 | ${outdir}/taxi_seed${seed}_nodm.log & 11 | python src/script/train_fet.py -d hdb -m rmse -c 1 -p 2 -s $seed -v 6 --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 1 -nab 2 -paf 1 --dropout 0.3 -g 5 --disable-dm > \ 12 | ${outdir}/hdbs_seed${seed}_nodm.log & 13 | wait 14 | done 15 | 16 | -------------------------------------------------------------------------------- /src/script/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir -p data/syn/gisette data/syn/mnist 4 | wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/gisette_scale.bz2 -O data/syn/gisette/gisette.libsvm.bz2 5 | echo "Unziping gisette.libsvm.bz2" 6 | bzip2 -d data/syn/gisette/gisette.libsvm.bz2 7 | 8 | wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2 -O data/syn/mnist/mnist.libsvm.bz2 9 | echo "Unziping mnist.libsvm.bz2" 10 | bzip2 -d data/syn/mnist/mnist.libsvm.bz2 11 | -------------------------------------------------------------------------------- /src/script/run_real_fet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=out/real/ 4 | mkdir -p $outdir 5 | 6 | for seed in 0 1 2 3 4; do 7 | python src/script/train_fet.py -d house -m rmse -c 1 -p 2 -s $seed --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 3 -nab 3 -paf 1 --dropout 0.3 -g 0 > \ 8 | ${outdir}/house_seed${seed}.log 9 | python src/script/train_fet.py -d taxi -m rmse -c 1 -p 2 -s $seed -e 50 -lr 3e-4 --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 3 -nab 3 -paf 1 --dropout 0.3 -g 0 > \ 10 | ${outdir}/taxi_seed${seed}.log 11 | python src/script/train_fet.py -d hdb -m rmse -c 1 -p 2 -s $seed --knn-k 100 -nh 4 -ded 100 -ked 100 -nlb 1 -nab 2 -paf 1 --dropout 0.3 -g 0 > \ 12 | ${outdir}/hdbs_seed${seed}.log 13 | wait 14 | done 15 | -------------------------------------------------------------------------------- /src/script/run_scale.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to get GPU ID 4 | get_gpu() { 5 | # echo $((RANDOM % 8)) 6 | echo 0 7 | } 8 | 9 | 10 | FET_HOME="." 
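# FET_HOME="." resolves every path below (the train_*.py scripts and LOG_DIR)
# against the current working directory, so the sweep assumes it is launched
# from the repository root.
# get_gpu() above currently always echoes 0 (the RANDOM % 8 variant is
# commented out), so all runs in this sweep are placed on GPU 0.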
11 | 12 | # Main execution 13 | datasets=("gisette" "mnist") 14 | n_classes=(2 10) 15 | 16 | 17 | for ((i=0; i<${#datasets[@]}; i++)); do 18 | dataset=${datasets[i]} 19 | n_class=${n_classes[i]} 20 | 21 | LOG_DIR="${FET_HOME}/out/scale/${dataset}" 22 | mkdir -p "${LOG_DIR}" 23 | 24 | for seed in {0..4}; do 25 | for party in 10 20 30 40 50; do 26 | noise=0.05 27 | 28 | # Solo 29 | gpu=$(get_gpu) 30 | python "${FET_HOME}/src/script/train_solo.py" \ 31 | -d "${dataset}" -c "${n_class}" -m "acc" -p "${party}" \ 32 | -s "${seed}" -w 100 --key-noise "${noise}" \ 33 | -g "${gpu}" \ 34 | > "${LOG_DIR}/scaletest-solo_${dataset}_c${n_class}-noise${noise}_macc_party${party}_seed${seed}.txt" 35 | 36 | # FeT 37 | gpu=$(get_gpu) 38 | python "${FET_HOME}/src/script/train_fet.py" \ 39 | -d "${dataset}" -m "acc" -c "${n_class}" -p "${party}" \ 40 | -s "${seed}" --knn-k 100 -w 100 -nh 1 -ded 200 -ked 200 \ 41 | -nlb 1 -nab 1 --dropout 0.0 --key-noise "${noise}" --party-dropout 0.6 -paf 1 \ 42 | -g "${gpu}" \ 43 | > "${LOG_DIR}/scaletest-fet_${dataset}_c${n_class}-noise${noise}_macc_party${party}_seed${seed}_k100_nh1_ded200_ked200_nAb1_nLb1_dropOut0.0.txt" 44 | done 45 | done 46 | done 47 | 48 | echo "All tasks completed" 49 | -------------------------------------------------------------------------------- /src/script/split_scale.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | for dataset in gisette mnist; do 4 | for ns in 0.05; do 5 | for p in 10 20 30 40 50; do 6 | python src/preprocess/FuzzySplitter.py -d ${dataset}.libsvm -p $p -kd 4 -ns $ns & 7 | done 8 | done 9 | wait 10 | echo "Done $dataset" 11 | done 12 | -------------------------------------------------------------------------------- /src/script/train_fet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import warnings 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch.multiprocessing 10 | import torch.nn as nn 11 | import torch_optimizer as optim 12 | from torch.utils.data import DataLoader 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | # add src to python path 16 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 17 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 18 | 19 | from src.dataset.VFLDataset import VFLSynAlignedDataset 20 | from src.dataset.VFLRealDataset import VFLRealDataset 21 | from src.preprocess.hdb.hdb_loader import load_both as load_both_hdb 22 | from src.preprocess.ml_dataset.two_party_loader import TwoPartyLoader as FedSimSynLoader 23 | from src.preprocess.nytaxi.ny_loader import NYBikeTaxiLoader 24 | from src.train.Fit import fit 25 | from src.utils.BasicUtils import (PartyPath, get_metric_from_str, get_metric_positive_from_str) 26 | from src.utils.logger import CommLogger 27 | from src.model.FeT import FeT 28 | 29 | # Avoid "Too many open files" error 30 | torch.multiprocessing.set_sharing_strategy('file_system') 31 | 32 | 33 | 34 | if __name__ == '__main__': 35 | # arguments 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--gpu', '-g', type=int, default=0, 38 | help="GPU ID. Set to None if you want to use CPU") 39 | 40 | # parameters for dataset 41 | parser.add_argument('--dataset', '-d', type=str, 42 | help="dataset to use.") 43 | parser.add_argument('--n_parties', '-p', type=int, default=4, 44 | help="number of parties. 
Should be >=2") 45 | parser.add_argument('--primary_party', '-pp', type=int, default=0, 46 | help="primary party. Should be in [0, n_parties-1]") 47 | parser.add_argument('--splitter', '-sp', type=str, default='imp') 48 | parser.add_argument('--weights', '-w', type=float, default=1, help="weights for the ImportanceSplitter") 49 | parser.add_argument('--beta', '-b', type=float, default=1, help="beta for the CorrelationSplitter") 50 | 51 | # parameters for model 52 | parser.add_argument('--epochs', '-e', type=int, default=100) 53 | parser.add_argument('--lr', '-lr', type=float, default=1e-3) 54 | parser.add_argument('--weight_decay', '-wd', type=float, default=1e-5) 55 | parser.add_argument('--batch_size', '-bs', type=int, default=128) 56 | parser.add_argument('--n_classes', '-c', type=int, default=1, 57 | help="number of classes. 1 for regression, 2 for binary classification," 58 | ">=3 for multi-class classification") 59 | parser.add_argument('--metric', '-m', type=str, default='acc', 60 | help="metric to evaluate the model. Supported metrics: [accuracy, rmse]") 61 | parser.add_argument('--result-path', '-rp', type=str, default=None, 62 | help="path to save the result") 63 | parser.add_argument('--seed', '-s', type=int, default=0, help="random seed") 64 | parser.add_argument('--log-dir', '-ld', type=str, default='log', help='log directory') 65 | parser.add_argument('--data-embed-dim', '-ded', type=int, default=200, help='data embedding dimension') 66 | parser.add_argument('--key-embed-dim', '-ked', type=int, default=200, help='key embedding dimension') 67 | parser.add_argument('--num-heads', '-nh', type=int, default=4, help='number of heads in multi-head attention') 68 | parser.add_argument('--dropout', type=float, default=0.0, help='dropout rate') 69 | parser.add_argument('--party-dropout', type=float, default=0.0, help='dropout rate for entire party') 70 | parser.add_argument('--n-local-blocks', '-nlb', type=int, default=6, help='number of local blocks') 71 | parser.add_argument('--n-agg-blocks', '-nab', type=int, default=6, help='number of aggregation blocks') 72 | parser.add_argument('--knn-k', type=int, default=100, help='k for knn') 73 | parser.add_argument('--disable-pe', action='store_true', help='disable positional encoding') 74 | parser.add_argument('--disable-dm', action='store_true', help='disable dynamic masking') 75 | parser.add_argument('-paf', '--pe-average-freq', type=int, default=0, 76 | help='average frequency for positional encoding on each party') 77 | 78 | # parameters for fedsim synthetic dataset 79 | parser.add_argument('--key-noise', type=float, default=0.0, help='key noise in FedSim synthetic data') 80 | 81 | # parameters for differential privacy 82 | parser.add_argument('--dp-noise', type=float, default=None, help='noise scale for differential privacy') 83 | parser.add_argument('--dp-clip', type=float, default=1.0, help='clip bound for differential privacy') 84 | parser.add_argument('--dp-sample', type=float, default=None, help='sample rate for differential privacy (privacy amplification)') 85 | 86 | # cache parameters 87 | parser.add_argument('--flush-cache', action='store_true', help='flush cache') 88 | parser.add_argument('--disable-cache', action='store_true', help='disable cache', default=True) # default to True 89 | args = parser.parse_args() 90 | 91 | # print hostname 92 | print(f"Hostname: {os.uname().nodename}") 93 | 94 | path = PartyPath(f"data/syn/{args.dataset}", args.n_parties, 0, args.splitter, args.weights, args.beta, 95 | args.seed, fmt='pkl', 
comm_root="log") 96 | comm_logger = CommLogger(args.n_parties, path.comm_log) 97 | 98 | real_root = "data/" 99 | syn_root = "data/syn" 100 | cache_root = "cache" 101 | normalize_key = True 102 | if args.dataset == 'house': 103 | house_root = f"{real_root}/house/" 104 | key_dim = 2 105 | house_dataset = VFLRealDataset.from_csv([os.path.join(house_root, "house_clean.csv"), 106 | os.path.join(house_root, "airbnb_clean.csv")], key_cols=key_dim, 107 | header=1, 108 | multi_primary=False, ks=args.knn_k, 109 | cache_key=os.path.join(cache_root, "nbr/house/main.pkl"), 110 | use_cache=not args.disable_cache, 111 | sample_rate_before_topk=args.dp_sample) 112 | train_dataset, val_dataset, test_dataset = house_dataset.split_train_test_primary( 113 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed, shuffle=True) 114 | train_dataset.data_sampler.sample_rate_before_topk = args.dp_sample 115 | elif args.dataset == 'taxi': 116 | key_dim = 4 117 | taxi_root = f"{real_root}/nytaxi/" 118 | bike_path = "bike_201606_clean_sample_2e5.pkl" 119 | taxi_path = "taxi_201606_clean_sample_1e5.pkl" 120 | base_loader = NYBikeTaxiLoader(bike_path=os.path.join(taxi_root, bike_path), 121 | taxi_path=os.path.join(taxi_root, taxi_path), link=True) 122 | [X1, X2], y = base_loader.load_parties() 123 | 124 | # Append two empty columns to X2 since it feature_dim is smaller than key_dim, which leads to an error in 125 | # PositionalEncoding 126 | X2 = np.concatenate([X2, np.zeros([X2.shape[0], 2])], axis=1) 127 | 128 | taxi_dataset = VFLRealDataset(([X1, X2], y), primary_party_id=0, key_cols=key_dim, ks=args.knn_k, 129 | sample_rate_before_topk=args.dp_sample) 130 | train_dataset, val_dataset, test_dataset = taxi_dataset.split_train_test_primary( 131 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed, shuffle=True) 132 | elif args.dataset == 'hdb': 133 | key_dim = 2 134 | hdb_path = f"{real_root}/hdb/hdb_clean.csv" 135 | school_path = f"{real_root}/hdb/school_clean.csv" 136 | [X1, X2], y = load_both_hdb(hdb_path, school_path, active_party='hdb') 137 | hdb_dataset = VFLRealDataset(([X1, X2], y), primary_party_id=0, key_cols=key_dim, ks=args.knn_k, 138 | sample_rate_before_topk=args.dp_sample) 139 | train_dataset, val_dataset, test_dataset = hdb_dataset.split_train_test_primary( 140 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed, shuffle=True) 141 | elif args.dataset in ("gisette", "mnist"): 142 | # multi_party_dataset for scalability 143 | normalize_key = False 144 | key_dim = 4 145 | 146 | syn_dataset_dir = f"data/syn/{args.dataset}/noise{args.key_noise}/" 147 | 148 | syn_aligned_dataset = VFLSynAlignedDataset.from_pickle(syn_dataset_dir, args.dataset, args.n_parties, 149 | primary_party_id=args.primary_party, 150 | splitter=args.splitter, 151 | weight=args.weights, beta=args.beta, seed=0, 152 | type=None) 153 | syn_dataset = VFLRealDataset.from_syn_aligned(syn_aligned_dataset, ks=args.knn_k, key_cols=key_dim, 154 | sample_rate_before_topk=args.dp_sample, 155 | use_cache=not args.disable_cache, 156 | cache_key=f"{args.dataset}-{args.n_parties}", 157 | multi_primary=False) 158 | train_dataset, val_dataset, test_dataset = syn_dataset.split_train_test_primary( 159 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 160 | else: 161 | key_dim = 0 162 | syn_dataset_dir = f"{syn_root}/{args.dataset}" 163 | print(f"Loading synthetic dataset from {syn_dataset_dir}") 164 | syn_aligned_dataset = VFLSynAlignedDataset.from_pickle(syn_dataset_dir, args.dataset, args.n_parties, 165 | primary_party_id=args.primary_party, 
166 | splitter=args.splitter, 167 | weight=args.weights, beta=args.beta, seed=args.seed, 168 | type=None) 169 | for local_dataset in syn_aligned_dataset.local_datasets: 170 | # local_dataset.key = None 171 | local_dataset.key = np.arange(len(local_dataset)).reshape(-1, 1) # debug 172 | syn_dataset = VFLRealDataset.from_syn_aligned(syn_aligned_dataset, ks=args.knn_k) 173 | syn_dataset.key_cols = 1 # debug 174 | train_dataset, val_dataset, test_dataset = syn_dataset.split_train_test_primary( 175 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 176 | 177 | # normalize features 178 | X_scalers = train_dataset.normalize_(include_key=normalize_key) 179 | if val_dataset is not None: 180 | val_dataset.normalize_(scalers=X_scalers, include_key=normalize_key) 181 | test_dataset.normalize_(scalers=X_scalers, include_key=normalize_key) 182 | 183 | # create the model 184 | y_scaler = None 185 | if args.n_classes == 1: # regression 186 | task = 'reg' 187 | loss_fn = nn.MSELoss() 188 | out_dim = 1 189 | out_activation = nn.Sigmoid() 190 | if args.metric == 'acc': # if metric is accuracy, change it to rmse 191 | args.metric = 'rmse' 192 | warnings.warn("Metric is changed to rmse for regression task") 193 | # scale the labels to [0, 1] 194 | y_scaler = train_dataset.scale_y_() 195 | if val_dataset is not None: 196 | val_dataset.scale_y_(scaler=y_scaler) 197 | test_dataset.scale_y_(scaler=y_scaler) 198 | elif args.n_classes == 2: # binary classification 199 | task = 'bin-cls' 200 | loss_fn = nn.BCELoss() 201 | out_dim = 1 202 | out_activation = nn.Sigmoid() 203 | # make sure the labels are in [0, 1] 204 | train_dataset.scale_y_(0, 1) 205 | if val_dataset is not None: 206 | val_dataset.scale_y_(0, 1) 207 | test_dataset.scale_y_(0, 1) 208 | else: # multi-class classification 209 | task = 'multi-cls' 210 | loss_fn = nn.CrossEntropyLoss() 211 | out_dim = args.n_classes 212 | out_activation = None # No need for softmax since it is included in CrossEntropyLoss 213 | 214 | model = FeT(key_dims=train_dataset.local_key_channels, data_dims=train_dataset.local_input_channels, 215 | out_dim=out_dim, data_embed_dim=args.data_embed_dim, 216 | key_embed_dim=args.key_embed_dim, 217 | num_heads=args.num_heads, dropout=args.dropout, party_dropout=args.party_dropout, 218 | # n_embeddings=len(train_dataset) + len(test_dataset), 219 | n_embeddings=None, out_activation=out_activation, 220 | n_local_blocks=args.n_local_blocks, n_agg_blocks=args.n_agg_blocks, k=args.knn_k, 221 | rep_noise=args.dp_noise, max_rep_norm=args.dp_clip, enable_pe=not args.disable_pe, 222 | enable_dm=not args.disable_dm) 223 | # model = torch.compile(model) 224 | 225 | # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 226 | optimizer = optim.Lamb(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 227 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, verbose=True) 228 | 229 | train_dataset.to_tensor_() 230 | test_dataset.to_tensor_() 231 | if val_dataset is not None: 232 | val_dataset.to_tensor_() 233 | 234 | def is_debug(): 235 | gettrace = getattr(sys, 'gettrace', None) 236 | if gettrace is None: 237 | return False 238 | elif gettrace(): 239 | return True 240 | else: 241 | return False 242 | 243 | if is_debug(): 244 | print("Debug mode, set num_workers to 0") 245 | n_workers = 0 246 | else: 247 | n_workers = 0 # disable multiprocessing for a multi-process bug in pytorch (some times it will freeze) 248 | train_loader = DataLoader(train_dataset, 
batch_size=args.batch_size, shuffle=True, num_workers=n_workers, 249 | drop_last=False) 250 | if val_dataset is not None: 251 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=n_workers, 252 | drop_last=False) 253 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=n_workers, 254 | drop_last=False) 255 | 256 | metric_fn = get_metric_from_str(args.metric) 257 | metric_positive = get_metric_positive_from_str(args.metric) 258 | 259 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 260 | writer = SummaryWriter(log_dir=str(os.path.join(args.log_dir, args.dataset))) 261 | 262 | cache_dir = os.path.join("cache", args.dataset) 263 | os.makedirs(cache_dir, exist_ok=True) 264 | model_path = os.path.join(cache_dir, f"model_{args.dataset}_party{args.n_parties}_knn{args.knn_k}" 265 | f"_{timestamp}.pt") 266 | 267 | test_loss_list, test_score_list = fit(model, optimizer, loss_fn, metric_fn, train_loader, epochs=args.epochs, 268 | gpu_id=args.gpu, 269 | n_classes=args.n_classes, test_loader=test_loader, task=task, 270 | scheduler=scheduler, has_key=True, 271 | val_loader=val_loader, metric_positive=metric_positive, y_scaler=y_scaler, 272 | solo=False, writer=writer, log_timestamp=timestamp, 273 | visualize=False, model_path=model_path, dataset_name=args.dataset) 274 | 275 | if args.result_path is not None: 276 | # save test loss and score to a two-column csv file, each row is for one epoch (with pandas) 277 | test_result = pd.DataFrame({'loss': test_loss_list, 'score': test_score_list}) 278 | test_result.to_csv(args.result_path, index=False) 279 | 280 | print("Done!") 281 | -------------------------------------------------------------------------------- /src/script/train_solo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Callable 4 | import argparse 5 | import warnings 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | 12 | import pandas as pd 13 | from tqdm import tqdm 14 | 15 | # add src to python path 16 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 17 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 18 | from model.Solo import MLP 19 | from dataset.LocalDataset import LocalDataset 20 | from src.utils import get_device_from_gpu_id, get_metric_from_str, PartyPath 21 | from src.utils import get_metric_positive_from_str 22 | from train.Fit import fit 23 | from dataset.VFLRealDataset import VFLRealDataset 24 | from src.preprocess.nytaxi.ny_loader import NYBikeTaxiLoader 25 | from src.preprocess.hdb.hdb_loader import load_both as load_both_hdb 26 | from src.preprocess.ml_dataset.two_party_loader import TwoPartyLoader as FedSimSynLoader 27 | 28 | if __name__ == '__main__': 29 | # arguments 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--gpu', '-g', type=int, default=0, 32 | help="GPU ID. Set to None if you want to use CPU") 33 | 34 | # parameters for dataset 35 | parser.add_argument('--dataset', '-d', type=str, 36 | help="dataset to use.") 37 | parser.add_argument('--n_parties', '-p', type=int, default=4, 38 | help="number of parties. Should be >=2") 39 | parser.add_argument('--primary_party', '-pp', type=int, default=0, 40 | help="primary party. 
Should be in [0, n_parties-1]") 41 | parser.add_argument('--splitter', '-sp', type=str, default='imp') 42 | parser.add_argument('--weights', '-w', type=float, default=1, help="weights for the ImportanceSplitter") 43 | parser.add_argument('--beta', '-b', type=float, default=1, help="beta for the CorrelationSplitter") 44 | parser.add_argument('--key-noise', type=float, default=0.0, help='key noise in FedSim synthetic data') 45 | 46 | # parameters for model 47 | parser.add_argument('--epochs', '-e', type=int, default=100) 48 | parser.add_argument('--lr', '-lr', type=float, default=1e-3) 49 | parser.add_argument('--weight_decay', '-wd', type=float, default=1e-5) 50 | parser.add_argument('--batch_size', '-bs', type=int, default=128) 51 | parser.add_argument('--n_classes', '-c', type=int, default=1, 52 | help="number of classes. 1 for regression, 2 for binary classification," 53 | ">=3 for multi-class classification") 54 | parser.add_argument('--metric', '-m', type=str, default='acc', 55 | help="metric to evaluate the model. Supported metrics: [accuracy, rmse]") 56 | parser.add_argument('--result-path', '-rp', type=str, default=None, 57 | help="path to save the result") 58 | parser.add_argument('--seed', '-s', type=int, default=0, help="random seed") 59 | parser.add_argument('--log-dir', '-ld', type=str, default='log', help='log directory') 60 | parser.add_argument('--data-embed-dim', '-ded', type=int, default=200, help='data embedding dimension') 61 | parser.add_argument('--key-embed-dim', '-ked', type=int, default=200, help='key embedding dimension') 62 | parser.add_argument('--num-heads', '-nh', type=int, default=8, help='number of heads in multi-head attention') 63 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 64 | parser.add_argument('--n-local-blocks', '-nlb', type=int, default=6, help='number of local blocks') 65 | parser.add_argument('--n-agg-blocks', '-nab', type=int, default=6, help='number of aggregation blocks') 66 | parser.add_argument('--knn-k', type=int, default=50, help='k for knn') 67 | parser.add_argument('-v', '--version', type=int, default=3, help='version of the model') 68 | 69 | args = parser.parse_args() 70 | 71 | # print hostname 72 | print(f"Hostname: {os.uname().nodename}") 73 | 74 | syn_root = "data/syn/" 75 | real_root = "data/fedsim-data/" 76 | if args.dataset == 'house': 77 | house_root = f"{real_root}/beijing/" 78 | house_dataset = LocalDataset.from_csv(os.path.join(house_root, "house_clean.csv"), header=1, key_cols=2) 79 | train_dataset, val_dataset, test_dataset = house_dataset.split_train_test( 80 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 81 | elif args.dataset == 'taxi': 82 | key_dim = 4 83 | taxi_root = f"{real_root}/nytaxi/" 84 | bike_path = "bike_201606_clean_sample_2e5.pkl" 85 | taxi_path = "taxi_201606_clean_sample_1e5.pkl" 86 | base_loader = NYBikeTaxiLoader(bike_path=os.path.join(taxi_root, bike_path), 87 | taxi_path=os.path.join(taxi_root, taxi_path), link=True) 88 | [X1, X2], y = base_loader.load_parties() 89 | 90 | taxi_dataset = LocalDataset(X1, y, key=None) 91 | train_dataset, val_dataset, test_dataset = taxi_dataset.split_train_test( 92 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 93 | elif args.dataset == 'hdb': 94 | key_dim = 2 95 | hdb_path = f"{real_root}/hdb/hdb_clean.csv" 96 | school_path = f"{real_root}/hdb/school_clean.csv" 97 | [X1, X2], y = load_both_hdb(hdb_path, school_path, active_party='hdb') 98 | hdb_dataset = LocalDataset(X1, y, key=None) 99 | train_dataset, val_dataset, 
test_dataset = hdb_dataset.split_train_test( 100 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 101 | 102 | elif args.dataset in ("gisette", "mnist"): 103 | # multi_party_dataset for scalability 104 | key_dim = 4 105 | data_path = PartyPath(dataset_path=args.dataset, n_parties=args.n_parties, party_id=args.primary_party, 106 | splitter=args.splitter, weight=args.weights, beta=args.beta, seed=0, 107 | fmt='pkl').data(None) # use seed=0 to get the same data for all parties 108 | syn_dataset_dir = f"data/syn/{args.dataset}/noise{args.key_noise}/" 109 | solo_dataset = LocalDataset.from_pickle( 110 | os.path.join(syn_dataset_dir, data_path) 111 | ) 112 | 113 | train_dataset, val_dataset, test_dataset = solo_dataset.split_train_test( 114 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 115 | 116 | else: 117 | # Note: torch.compile() in torch 2.0 significantly harms the accuracy with little speed up 118 | data_path = PartyPath(dataset_path=args.dataset, n_parties=args.n_parties, party_id=args.primary_party, 119 | splitter=args.splitter, weight=args.weights, beta=args.beta, seed=args.seed, 120 | fmt='pkl').data(None) 121 | solo_dataset = LocalDataset.from_pickle(os.path.join(syn_root, args.dataset, data_path)) 122 | 123 | train_dataset, val_dataset, test_dataset = solo_dataset.split_train_test( 124 | val_ratio=0.1, test_ratio=0.2, random_state=args.seed) 125 | 126 | X_scaler = train_dataset.normalize_() 127 | if val_dataset is not None: 128 | val_dataset.normalize_(scaler=X_scaler) 129 | test_dataset.normalize_(scaler=X_scaler) 130 | 131 | # create the model 132 | scaler = None 133 | if args.n_classes == 1: # regression 134 | task = 'reg' 135 | loss_fn = nn.MSELoss() 136 | out_dim = 1 137 | out_activation = nn.Sigmoid() 138 | if args.metric == 'acc': # if metric is accuracy, change it to rmse 139 | args.metric = 'rmse' 140 | warnings.warn("Metric is changed to rmse for regression task") 141 | # scale the labels to [0, 1] 142 | scaler = train_dataset.scale_y_() 143 | if val_dataset is not None: 144 | val_dataset.scale_y_(scaler=scaler) 145 | test_dataset.scale_y_(scaler=scaler) 146 | elif args.n_classes == 2: # binary classification 147 | task = 'bin-cls' 148 | loss_fn = nn.BCELoss() 149 | out_dim = 1 150 | out_activation = nn.Sigmoid() 151 | # make sure the labels are in [0, 1] 152 | train_dataset.scale_y_() 153 | if val_dataset is not None: 154 | val_dataset.scale_y_() 155 | test_dataset.scale_y_() 156 | else: # multi-class classification 157 | task = 'multi-cls' 158 | loss_fn = nn.CrossEntropyLoss() 159 | out_dim = args.n_classes 160 | out_activation = None # No need for softmax since it is included in CrossEntropyLoss 161 | 162 | # use SplitSum 163 | model = MLP(train_dataset.key_X_dim, [400, 400], out_dim, activation=out_activation) 164 | 165 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 166 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, verbose=True) 167 | 168 | train_dataset.to_tensor_() 169 | test_dataset.to_tensor_() 170 | if val_dataset is not None: 171 | val_dataset.to_tensor_() 172 | n_workers = 0 173 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=n_workers) 174 | if val_dataset is not None: 175 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=n_workers) 176 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=n_workers) 177 | 178 | 
metric_fn = get_metric_from_str(args.metric) 179 | metric_positive = get_metric_positive_from_str(args.metric) 180 | 181 | test_loss_list, test_score_list = fit(model, optimizer, loss_fn, metric_fn, train_loader, epochs=args.epochs, 182 | gpu_id=args.gpu, 183 | n_classes=args.n_classes, test_loader=test_loader, task=task, 184 | scheduler=scheduler, has_key=True, 185 | val_loader=val_loader, metric_positive=metric_positive, y_scaler=scaler, 186 | solo=True) 187 | 188 | if args.result_path is not None: 189 | # save test loss and score to a two-column csv file, each row is for one epoch (with pandas) 190 | test_result = pd.DataFrame({'loss': test_loss_list, 'score': test_score_list}) 191 | test_result.to_csv(args.result_path, index=False) 192 | 193 | print("Done!") 194 | -------------------------------------------------------------------------------- /src/train/Evaluate.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") 3 | from typing import Callable 4 | 5 | import os 6 | import sys 7 | 8 | import torch 9 | 10 | 11 | import pandas as pd 12 | from tqdm import tqdm 13 | 14 | # add src to python path 15 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 16 | 17 | from src.utils.BasicUtils import get_device_from_gpu_id, get_metric_from_str, PartyPath 18 | 19 | 20 | 21 | # evaluate the model on the test set 22 | def evaluate(model, test_loader, metric_fn: Callable, gpu_id=0, n_classes=1): 23 | device = get_device_from_gpu_id(gpu_id) 24 | model.to(device) 25 | model.eval() 26 | with torch.no_grad(): 27 | y_all = torch.zeros([0, 1]).to(device) 28 | y_pred_all = torch.zeros([0, 1]).to(device) 29 | for Xs, y in test_loader: 30 | # to device 31 | Xs = [Xi.to(device) for Xi in Xs] 32 | y = y.to(device).reshape(-1, 1) 33 | y_pred = model(Xs) 34 | if n_classes == 1: 35 | y_pred = y_pred.reshape(-1, 1) 36 | else: 37 | y_pred = torch.argmax(y_pred, dim=1).reshape(-1, 1) 38 | y_pred_all = torch.cat((y_pred_all, y_pred), dim=0) 39 | y_all = torch.cat((y_all, y), dim=0) 40 | y_pred_all = y_pred_all.cpu().numpy() 41 | y_all = y_all.cpu().numpy() 42 | return metric_fn(y_pred_all, y_all) 43 | -------------------------------------------------------------------------------- /src/train/Fit.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import multiprocessing as mp 4 | 5 | warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") 6 | 7 | import os 8 | import sys 9 | import datetime, pytz 10 | 11 | import torch 12 | from torchinfo import summary 13 | from src.model.FeT import FeT 14 | 15 | 16 | import pandas as pd 17 | from tqdm import tqdm 18 | import deprecated 19 | 20 | # add src to python path 21 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 22 | 23 | from src.utils.BasicUtils import get_device_from_gpu_id, get_metric_from_str, PartyPath 24 | 25 | 26 | @deprecated.deprecated(reason="Previously used for handling format, now deprecated") 27 | def preprocess_Xs_y_temp(Xs, y, device, task, solo=False, has_key=False, duplicate_y=None): 28 | Xs, y = preprocess_Xs_y(Xs, y, device, task, solo=solo, has_key=has_key, duplicate_y=duplicate_y) 29 | concated_Xs = [Xi[1].squeeze(1) for Xi in Xs] # shape [128, 1, 2317] --> [128, 2317] 30 | return concated_Xs, y 31 | 32 | 33 | def preprocess_Xs_y(Xs, y, device, task, solo=False, has_key=False, duplicate_y=None): 34 | # 
pdbr.set_trace() 35 | if not has_key: 36 | Xs = [Xi.to(device) for Xi in Xs] 37 | default_keys = torch.arange(Xs[0].shape[0]).reshape(-1, 1).long().to(device) 38 | Xs = [(default_keys, Xi) for Xi in Xs] 39 | else: 40 | if solo: 41 | Xs = (Xs[0].float().to(device), Xs[1].float().to(device)) 42 | else: 43 | Xs = [(Xi[0].float().to(device), Xi[1].float().to(device)) for Xi in Xs] 44 | y = y.to(device) 45 | y = y.flatten() 46 | if duplicate_y is not None: 47 | y = y.repeat(duplicate_y) 48 | y = y.long() if task == 'multi-cls' else y.float() 49 | return Xs, y 50 | 51 | 52 | def summarize_fet_model(model, loader, depth=4): 53 | sample_Xs = next(iter(loader))[0] 54 | sample_Xs_shape = [tuple(torch.cat([Xi[0], Xi[1]], dim=-1).shape) for Xi in sample_Xs] 55 | key_dim = sample_Xs[0][0].shape[-1] 56 | 57 | # wrap the model with a concatenation 58 | class WrapModel(torch.nn.Module): 59 | def __init__(self, model, key_dim): 60 | super().__init__() 61 | self.model = model 62 | self.key_dim = key_dim 63 | 64 | def forward(self, *key_Xs_concat): 65 | key_Xs = [(key_X[:, :, :self.key_dim], key_X[:, :, self.key_dim:]) for key_X in key_Xs_concat] 66 | return self.model(key_Xs) 67 | 68 | wrap_model = WrapModel(model, key_dim) 69 | stats = summary(wrap_model, sample_Xs_shape, depth=depth) 70 | return stats 71 | 72 | 73 | 74 | def fit(model, optimizer, loss_fn, metric_fn, train_loader, test_loader=None, epochs=10, gpu_id=0, n_classes=1, 75 | task='bin-cls', scheduler=None, has_key=False, val_loader=None, metric_positive=True, y_scaler=None, 76 | solo=False, writer=None, log_timestamp=None, visualize=False, model_path=None, dataset_name=None, fig_dir='fig', 77 | average_pe_freq=None): 78 | device = get_device_from_gpu_id(gpu_id) 79 | model.to(device) 80 | 81 | if fig_dir is not None: 82 | os.makedirs(fig_dir, exist_ok=True) 83 | if isinstance(model, FeT): 84 | stats = summarize_fet_model(model, train_loader, depth=4) 85 | 86 | test_loss_list = [] 87 | test_score_list = [] 88 | best_epoch = -1 89 | if metric_positive: 90 | best_train_score = -np.inf 91 | best_val_score = -np.inf 92 | best_test_score = -np.inf 93 | else: 94 | best_train_score = np.inf 95 | best_val_score = np.inf 96 | best_test_score = np.inf 97 | 98 | for epoch in range(epochs): 99 | model.train() 100 | train_pred_y = train_y = torch.zeros([0, 1], device=device) 101 | train_loss = 0 102 | 103 | for Xs, y in tqdm(train_loader): 104 | # pdbr.set_trace() 105 | Xs, y = preprocess_Xs_y(Xs, y, device, task, solo=solo, has_key=has_key) 106 | 107 | optimizer.zero_grad() 108 | y_pred = model(Xs) 109 | 110 | y_pred = y_pred.flatten() if task in ['reg', 'bin-cls'] else y_pred 111 | loss = loss_fn(y_pred, y) 112 | train_loss += loss.item() 113 | 114 | if n_classes == 2: 115 | y_pred = torch.round(y_pred) 116 | elif n_classes > 2: 117 | y_pred = torch.argmax(y_pred, dim=1).reshape(-1, 1) 118 | 119 | y_pred = y_pred.reshape(-1, 1) 120 | train_pred_y = torch.cat([train_pred_y, y_pred], dim=0) 121 | train_y = torch.cat([train_y, y.reshape(-1, 1)], dim=0) 122 | loss.backward() 123 | optimizer.step() 124 | 125 | train_y_array = train_y.data.cpu().numpy() 126 | train_pred_y_array = train_pred_y.data.cpu().numpy() 127 | if y_scaler is not None: 128 | train_y_array = y_scaler.inverse_transform(train_y_array.reshape(-1, 1)).reshape(-1) 129 | train_pred_y_array = y_scaler.inverse_transform(train_pred_y_array.reshape(-1, 1)).reshape(-1) 130 | # pdbr.set_trace() 131 | train_score = metric_fn(train_y_array, train_pred_y_array) 132 | 133 | timestamp_now = 
datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M:%S") 134 | print(timestamp_now, f"Epoch: {epoch}, Train Loss: {train_loss / len(train_loader)}, Train Score: {train_score}") 135 | if hasattr(model, 'comm_logger'): 136 | model.comm_logger.save_log() 137 | if writer is not None: 138 | writer.add_scalars(f"loss", {f'{log_timestamp}/train': train_loss / len(train_loader)}, epoch) 139 | writer.add_scalars(f"score", {f'{log_timestamp}/train': train_score}, epoch) 140 | 141 | if scheduler is not None and val_loader is None: 142 | scheduler.step(loss.item() / len(train_loader)) 143 | 144 | # if train_loader.dataset.cache_need_update and epoch == 0: 145 | # train_loader.dataset.cache_need_update = False 146 | # train_loader.dataset.cache.save_pkl() 147 | 148 | # if visualize: 149 | # model.visualize_positional_encoding() 150 | # pdbr.set_trace() 151 | 152 | # # debug 153 | # if hasattr(model, 'save_pe_inout') and epoch % 20 == 0 and epoch != 0: 154 | # # inference and save the input and output of pe layer 155 | # model.save_pe_inout(val_loader, f"log/{dataset_name}/_pe_inout_independent", device=device) 156 | 157 | if val_loader is not None: 158 | model.eval() 159 | with torch.no_grad(): 160 | val_loss = torch.zeros(1, device=device) 161 | val_y_pred = val_y = torch.zeros([0, 1], device=device) 162 | for Xs, y in tqdm(val_loader): 163 | Xs, y = preprocess_Xs_y(Xs, y, device, task, solo=solo, has_key=has_key) 164 | 165 | y_pred = model(Xs) 166 | y_pred = y_pred.flatten() if task in ['reg', 'bin-cls'] else y_pred 167 | val_loss += loss_fn(y_pred, y) 168 | 169 | if n_classes == 2: 170 | y_pred = torch.round(y_pred) 171 | elif n_classes > 2: 172 | y_pred = torch.argmax(y_pred, dim=1).reshape(-1, 1) 173 | 174 | if y_pred.isnan().any(): 175 | warnings.warn("y_pred has nan") 176 | y_pred[y_pred.isnan()] = 0 177 | 178 | y_pred = y_pred.reshape(-1, 1) 179 | val_y_pred = torch.cat([val_y_pred, y_pred], dim=0) 180 | val_y = torch.cat([val_y, y.reshape(-1, 1)], dim=0) 181 | 182 | val_y_array = val_y.data.cpu().numpy() 183 | val_y_pred_array = val_y_pred.data.cpu().numpy() 184 | 185 | if y_scaler is not None: 186 | scaled_val_y_array = y_scaler.inverse_transform(val_y_array.reshape(-1, 1)).reshape(-1) 187 | scaled_val_y_pred_array = y_scaler.inverse_transform(val_y_pred_array.reshape(-1, 1)).reshape(-1) 188 | else: 189 | scaled_val_y_array = val_y_array 190 | scaled_val_y_pred_array = val_y_pred_array 191 | try: 192 | val_score = metric_fn(scaled_val_y_array, scaled_val_y_pred_array) 193 | except ValueError as e: 194 | print(f"Error: {e}") 195 | raise e 196 | 197 | val_loss_mean = float(val_loss.cpu().numpy() / len(val_loader)) 198 | timestamp_now = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M:%S") 199 | print(timestamp_now, f"Epoch: {epoch}, Val Loss: {val_loss_mean}, Val Score: {val_score}") 200 | 201 | if writer is not None: 202 | writer.add_scalars(f"loss", {f'{log_timestamp}/val': val_loss_mean}, epoch) 203 | writer.add_scalars(f"score", {f'{log_timestamp}/val': val_score}, epoch) 204 | 205 | scheduler.step(val_score) 206 | 207 | # if val_loader.dataset.cache_need_update and epoch == 0: 208 | # val_loader.dataset.cache_need_update = False 209 | # val_loader.dataset.cache.save_pkl() 210 | 211 | if test_loader is not None: 212 | model.eval() 213 | with torch.no_grad(): 214 | test_loss = torch.zeros(1, device=device) 215 | test_y_pred = test_y = torch.zeros([0, 1], device=device) 216 | for Xs, y in tqdm(test_loader): 217 | Xs, y = 
preprocess_Xs_y(Xs, y, device, task, solo=solo, has_key=has_key) 218 | 219 | y_pred = model(Xs) 220 | y_pred = y_pred.flatten() if task in ['reg', 'bin-cls'] else y_pred 221 | test_loss += loss_fn(y_pred, y) 222 | 223 | if n_classes == 2: 224 | y_pred = torch.round(y_pred) 225 | elif n_classes > 2: 226 | y_pred = torch.argmax(y_pred, dim=1).reshape(-1, 1) 227 | 228 | if y_pred.isnan().any(): 229 | warnings.warn("y_pred has nan") 230 | y_pred[y_pred.isnan()] = 0 231 | 232 | y_pred = y_pred.reshape(-1, 1) 233 | test_y_pred = torch.cat([test_y_pred, y_pred], dim=0) 234 | test_y = torch.cat([test_y, y.reshape(-1, 1)], dim=0) 235 | 236 | test_y_array = test_y.data.cpu().numpy() 237 | test_y_pred_array = test_y_pred.data.cpu().numpy() 238 | 239 | if y_scaler is not None: 240 | scaled_test_y_array = y_scaler.inverse_transform(test_y_array.reshape(-1, 1)).reshape(-1) 241 | scaled_test_y_pred_array = y_scaler.inverse_transform(test_y_pred_array.reshape(-1, 1)).reshape(-1) 242 | else: 243 | scaled_test_y_array = test_y_array 244 | scaled_test_y_pred_array = test_y_pred_array 245 | try: 246 | test_score = metric_fn(scaled_test_y_array, scaled_test_y_pred_array) 247 | except ValueError as e: 248 | print(f"Error: {e}") 249 | raise e 250 | test_loss_mean = float(test_loss.cpu().numpy() / len(test_loader)) 251 | timestamp_now = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M:%S") 252 | print(timestamp_now, f"Epoch: {epoch}, Test Loss: {test_loss_mean}, Test Score: {test_score}") 253 | test_loss_list.append(test_loss_mean) 254 | test_score_list.append(test_score) 255 | 256 | if writer is not None: 257 | writer.add_scalars(f"loss", {f'{log_timestamp}/test': test_loss_mean}, epoch) 258 | writer.add_scalars(f"score", {f'{log_timestamp}/test': test_score}, epoch) 259 | 260 | # if test_loader.dataset.cache_need_update and epoch == 0: 261 | # test_loader.dataset.cache_need_update = False 262 | # test_loader.dataset.cache.save_pkl() 263 | 264 | if visualize: 265 | model.visualize_positional_encoding(dataset=dataset_name, device=device, 266 | save_path=os.path.join(fig_dir, f"{dataset_name}_epoch{epoch}.png")) 267 | 268 | if val_loader is not None: 269 | if (metric_positive and val_score > best_val_score) or (not metric_positive and val_score < best_val_score): 270 | best_val_score = val_score 271 | best_train_score = train_score 272 | best_test_score = test_score 273 | best_epoch = epoch 274 | if model_path is not None: 275 | model.save(model_path) 276 | print(f"Best epoch: {best_epoch}, Train Score: {best_train_score}, Val Score: {best_val_score}, Test Score: {best_test_score}") 277 | 278 | # Average the positional encoding layer every average_pe_freq epochs (skip when disabled) 279 | if (hasattr(model, 'average_pe_') and average_pe_freq is not None and average_pe_freq != 0 280 | and epoch % average_pe_freq == 0 and epoch != 0): 281 | model.average_pe_() 282 | return test_loss_list, test_score_list -------------------------------------------------------------------------------- /src/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FeT/836cd91602b3a0fa6379c5b000b7df288bced790/src/train/__init__.py -------------------------------------------------------------------------------- /src/utils/BasicUtils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from typing import Callable 4 | 5 | 6 | import torch 7 | from sklearn.metrics import accuracy_score, 
mean_squared_error, r2_score 8 | 9 | from src.metric.RMSE import RMSE 10 | 11 | 12 | class PartyPath: 13 | def __init__(self, dataset_path, n_parties, party_id=0, splitter='imp', weight=1, beta=1, seed=None, 14 | fmt='pkl', comm_root=None): 15 | self.dataset_path = dataset_path 16 | path = pathlib.Path(self.dataset_path) 17 | self.dataset_name = path.stem 18 | self.n_parties = n_parties 19 | self.party_id = party_id 20 | self.splitter = splitter 21 | self.weight = weight 22 | self.beta = beta 23 | self.seed = seed 24 | self.fmt = fmt 25 | self.comm_root = comm_root # the root of communication log 26 | 27 | def data(self, type='train'): 28 | path = pathlib.Path(self.dataset_path) 29 | if type is None: 30 | type_str = '' 31 | else: 32 | type_str = f"_{type}" 33 | if self.splitter == 'imp': 34 | # insert meta information before the file extension (extension may not be .csv) 35 | path = path.with_name(f"{path.stem}_party{self.n_parties}-{self.party_id}_{self.splitter}" 36 | f"_weight{self.weight:.1f}" 37 | f"{'_seed' + str(self.seed) if self.seed is not None else ''}{type_str}.{self.fmt}") 38 | elif self.splitter == 'corr': 39 | path = path.with_name(f"{path.stem}_party{self.n_parties}-{self.party_id}_{self.splitter}" 40 | f"_beta{self.beta:.1f}" 41 | f"{'_seed' + str(self.seed) if self.seed is not None else ''}{type_str}.{self.fmt}") 42 | else: 43 | raise NotImplementedError(f"Splitter {self.splitter} is not implemented. " 44 | f"Splitter should be in ['imp', 'corr']") 45 | return str(path) 46 | 47 | @property 48 | def train_data(self): 49 | return self.data('train') 50 | 51 | @property 52 | def test_data(self): 53 | return self.data('test') 54 | 55 | @property 56 | def comm_log(self): 57 | if self.comm_root is None: 58 | raise FileNotFoundError("comm_root is None") 59 | comm_dir = os.path.join(self.comm_root, self.dataset_name) 60 | os.makedirs(comm_dir, exist_ok=True) 61 | path = pathlib.Path(comm_dir) 62 | if self.splitter == 'imp': 63 | path = path / (f"{self.dataset_name}_party{self.n_parties}_{self.splitter}_weight{self.weight:.1f}" 64 | f"{'_seed' + str(self.seed) if self.seed is not None else ''}.log") 65 | elif self.splitter == 'corr': 66 | path = path / (f"{self.dataset_name}_party{self.n_parties}_{self.splitter}_beta{self.beta:.1f}" 67 | f"{'_seed' + str(self.seed) if self.seed is not None else ''}.log") 68 | else: 69 | raise NotImplementedError(f"Splitter {self.splitter} is not implemented." 70 | f" splitter should be in ['imp', 'corr']") 71 | return str(path) 72 | 73 | # def party_path(dataset_path, n_parties, party_id, splitter='imp', weight=1, beta=1, seed=None, type='train', 74 | # fmt='pkl') -> str: 75 | # assert type in ['train', 'test'] 76 | # path = pathlib.Path(dataset_path) 77 | # if splitter == 'imp': 78 | # # insert meta information before the file extension (extension may not be .csv) 79 | # path = path.with_name(f"{path.stem}_party{n_parties}-{party_id}_{splitter}" 80 | # f"_weight{weight:.1f}{'_seed' + str(seed) if seed is not None else ''}_{type}.{fmt}") 81 | # elif splitter == 'corr': 82 | # path = path.with_name(f"{path.stem}_party{n_parties}-{party_id}_{splitter}" 83 | # f"_beta{beta:.1f}{'_seed' + str(seed) if seed is not None else ''}_{type}.{fmt}") 84 | # else: 85 | # raise NotImplementedError(f"Splitter {splitter} is not implemented. 
splitter should be in ['imp', 'corr']") 86 | # return str(path) 87 | 88 | 89 | def get_device_from_gpu_id(gpu_id): 90 | if gpu_id is None: 91 | return torch.device('cpu') 92 | else: 93 | return torch.device(f'cuda:{gpu_id}') 94 | 95 | 96 | def get_metric_from_str(metric) -> Callable: 97 | supported_list = ['acc', 'rmse', 'r2'] 98 | assert metric in supported_list 99 | if metric == 'acc': 100 | return accuracy_score 101 | elif metric == 'rmse': 102 | return lambda y_true, y_pred: RMSE()(y_true, y_pred) 103 | elif metric == 'r2': 104 | return r2_score 105 | else: 106 | raise NotImplementedError(f"Metric {metric} is not implemented. metric should be in {supported_list}") 107 | 108 | 109 | def get_metric_positive_from_str(metric) -> bool: 110 | supported_list = ['acc', 'rmse', 'r2'] 111 | assert metric in supported_list 112 | if metric in ['acc', 'r2']: 113 | return True 114 | elif metric in ['rmse']: 115 | return False 116 | else: 117 | raise NotImplementedError(f"Metric {metric} is not implemented. metric should be in {supported_list}") 118 | 119 | 120 | def get_split_points(array, size): 121 | assert size > 1 122 | 123 | prev = array[0] 124 | split_points = [0] 125 | for i in range(1, size): 126 | if prev != array[i]: 127 | prev = array[i] 128 | split_points.append(i) 129 | 130 | split_points.append(size) 131 | return split_points 132 | 133 | 134 | def move_item_to_end_(arr, items): 135 | for item in items: 136 | arr.insert(len(arr), arr.pop(arr.index(item))) 137 | 138 | 139 | def move_item_to_start_(arr, items): 140 | for item in items[::-1]: 141 | arr.insert(0, arr.pop(arr.index(item))) 142 | 143 | 144 | def equal_split(n, k): 145 | if n % k == 0: 146 | return [n // k for _ in range(k)] 147 | else: 148 | return [n // k for _ in range(k - 1)] + [n // k + n % k] 149 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .BasicUtils import PartyPath, get_device_from_gpu_id, get_metric_from_str, get_metric_positive_from_str 2 | from .logger import CommLogger, CommRecord 3 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | class CommRecord: 6 | def __init__(self, from_party_id, to_party_id, size): 7 | """ 8 | Record of communication size 9 | 10 | Parameters 11 | ---------- 12 | from_party_id : int 13 | ID of the party that sends data 14 | to_party_id : int 15 | ID of the party that receives data 16 | size : int 17 | size of data 18 | """ 19 | self.from_party_id = from_party_id 20 | self.to_party_id = to_party_id 21 | self.size = size 22 | 23 | def __str__(self): 24 | return f"{self.from_party_id},{self.to_party_id},{self.size}" 25 | 26 | 27 | class CommLogger: 28 | def __init__(self, n_parties, path=None): 29 | """ 30 | Logger of communication size 31 | 32 | Parameters 33 | ---------- 34 | n_parties : int 35 | number of parties. A server is also included as a party. 36 | """ 37 | self.n_parties = n_parties 38 | self.path = path 39 | self.comm_records = [] 40 | 41 | self.in_comm = [0. for _ in range(n_parties)] # communication size received by each party 42 | self.out_comm = [0. 
for _ in range(n_parties)] # communication size sent by each party 43 | 44 | def comm(self, from_party_id, to_party_id, size: int): 45 | """ 46 | Record communication size in bytes 47 | 48 | Parameters 49 | ---------- 50 | from_party_id : int 51 | ID of the party that sends data 52 | to_party_id : int 53 | ID of the party that receives data 54 | size : int 55 | size of data 56 | """ 57 | self.comm_records.append((from_party_id, to_party_id, size)) 58 | self.in_comm[to_party_id] += size 59 | self.out_comm[from_party_id] += size 60 | 61 | def broadcast(self, from_party_id, size): 62 | """ 63 | Record the communication from a party to all other parties 64 | 65 | Parameters 66 | ---------- 67 | from_party_id : int 68 | ID of the party that sends data 69 | size : int 70 | size of data 71 | """ 72 | for to_party_id in range(self.n_parties): 73 | if to_party_id != from_party_id: 74 | self.comm(from_party_id, to_party_id, size) 75 | 76 | def receive_all(self, to_party_id, size): 77 | """ 78 | Record the communication from all other parties to a party 79 | 80 | Parameters 81 | ---------- 82 | to_party_id : int 83 | ID of the party that receives data 84 | size : int 85 | size of data 86 | """ 87 | for from_party_id in range(self.n_parties): 88 | if from_party_id != to_party_id: 89 | self.comm(from_party_id, to_party_id, size) 90 | 91 | def save_log(self): 92 | """ 93 | Save the log to a csv file 94 | """ 95 | columns = ["From", "To", "Size"] 96 | df = pd.DataFrame(self.comm_records, columns=columns) 97 | df.to_csv(self.path, index=False) 98 | 99 | @classmethod 100 | def load_log(cls, path): 101 | """ 102 | Load the log from a csv file 103 | 104 | Parameters 105 | ---------- 106 | path : str 107 | path of the csv file 108 | """ 109 | data = pd.read_csv(path) 110 | n_parties = max(data["From"].max(), data["To"].max()) + 1 111 | logger = cls(n_parties) 112 | comm_records = data.values 113 | 114 | def add_row(row): 115 | logger.in_comm[row[1]] += row[2] 116 | logger.out_comm[row[0]] += row[2] 117 | np.apply_along_axis(add_row, axis=1, arr=comm_records) 118 | logger.comm_records = comm_records.tolist() 119 | 120 | return logger 121 | 122 | @property 123 | def total_comm_bytes(self): 124 | # each communication is counted twice (one in from_party_id, one in to_party_id) 125 | return (sum(self.in_comm) + sum(self.out_comm)) / 2 126 | 127 | @property 128 | def max_in_comm_bytes(self): 129 | return max(self.in_comm) 130 | 131 | @property 132 | def max_out_comm_bytes(self): 133 | return max(self.out_comm) 134 | 135 | @property 136 | def total_comm_kB(self): 137 | return self.total_comm_bytes / 1024 138 | 139 | @property 140 | def max_in_comm_kB(self): 141 | return self.max_in_comm_bytes / 1024 142 | 143 | @property 144 | def max_out_comm_kB(self): 145 | return self.max_out_comm_bytes / 1024 146 | 147 | @property 148 | def total_comm_MB(self): 149 | return self.total_comm_kB / 1024 150 | 151 | @property 152 | def max_in_comm_MB(self): 153 | return self.max_in_comm_kB / 1024 154 | 155 | @property 156 | def max_out_comm_MB(self): 157 | return self.max_out_comm_kB / 1024 158 | 159 | @property 160 | def total_comm_GB(self): 161 | return self.total_comm_MB / 1024 162 | 163 | @property 164 | def max_in_comm_GB(self): 165 | return self.max_in_comm_MB / 1024 166 | 167 | @property 168 | def max_out_comm_GB(self): 169 | return self.max_out_comm_MB / 1024 170 | 171 | 172 | --------------------------------------------------------------------------------
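A minimal sketch (not part of the repository) of how PartyPath in src/utils/BasicUtils.py composes the per-party file names that script/train_solo.py and script/train_fet.py load. The dataset name "gisette" and the parameter values below are assumptions chosen only for illustration; the expected outputs in the comments follow from the 'imp' branch of PartyPath.data().

    from src.utils.BasicUtils import PartyPath

    # hypothetical synthetic dataset split across 4 parties with the importance splitter
    path = PartyPath(dataset_path="gisette", n_parties=4, party_id=0,
                     splitter='imp', weight=1, seed=0, fmt='pkl')
    print(path.data(None))   # gisette_party4-0_imp_weight1.0_seed0.pkl
    print(path.train_data)   # gisette_party4-0_imp_weight1.0_seed0_train.pkl

script/train_solo.py joins the result of data(None) onto a directory under data/syn/, so the splitter settings are encoded directly in the cached pickle file names.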
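A minimal usage sketch (not part of the repository) for CommLogger in src/utils/logger.py; the party count, byte sizes, and the output file name comm_demo.csv are assumptions for illustration only.

    from src.utils.logger import CommLogger

    logger = CommLogger(n_parties=3, path="comm_demo.csv")

    # parties 1 and 2 each send a 4 KiB message to party 0
    logger.comm(from_party_id=1, to_party_id=0, size=4096)
    logger.comm(from_party_id=2, to_party_id=0, size=4096)

    # party 0 broadcasts a 2 KiB update to every other party
    logger.broadcast(from_party_id=0, size=2048)

    print(logger.total_comm_kB)      # 12.0, i.e. (4096 + 4096 + 2 * 2048) / 1024
    print(logger.max_in_comm_bytes)  # 8192.0, party 0 received the most
    logger.save_log()                # writes From,To,Size rows to comm_demo.csv

Because every transfer is added to both in_comm and out_comm, total_comm_bytes halves the sum of the two lists so that each transfer is counted once.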