├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── FedTree_draft_paper.pdf ├── LICENSE ├── README.md ├── dataset ├── adult │ ├── a9a_horizontal_p0 │ ├── a9a_horizontal_p1 │ ├── a9a_horizontal_test │ ├── a9a_vertical_p0 │ ├── a9a_vertical_p1 │ └── a9a_vertical_test ├── credit │ ├── credit_vertical_p0_withlabel.csv │ └── credit_vertical_p1.csv └── test_dataset.txt ├── docs ├── Makefile ├── make.bat └── source │ ├── Examples.rst │ ├── Experiments.rst │ ├── Frameworks.rst │ ├── Installation.rst │ ├── Parameters.rst │ ├── Quick-Start.rst │ ├── conf.py │ ├── images │ ├── fedtree_archi.png │ ├── fedtree_hori.png │ ├── fedtree_verti.png │ ├── hori_fram.png │ ├── hori_fram_he.png │ └── verti_fram.png │ └── index.rst ├── examples ├── abalone │ ├── a9a_horizontal_p1.conf │ ├── a9a_horizontal_server.conf │ ├── a9a_vertical_p0.conf │ ├── a9a_vertical_p1.conf │ ├── aba_horizontal_p0.conf │ ├── aba_horizontal_p1.conf │ ├── aba_horizontal_server.conf │ ├── aba_vertical_p0.conf │ └── aba_vertical_p1.conf ├── adult │ ├── a9a_horizontal_p0.conf │ ├── a9a_horizontal_p1.conf │ ├── a9a_horizontal_server.conf │ ├── a9a_vertical_p0.conf │ └── a9a_vertical_p1.conf ├── breast_distributed_horizontal_server.conf ├── credit │ ├── credit_vertical_p0.conf │ └── credit_vertical_p1.conf ├── horizontal_example.conf ├── prediction.conf └── vertical_example.conf ├── include └── FedTree │ ├── DP │ ├── differential_privacy.h │ └── noises.h │ ├── Encryption │ ├── diffie_hellman.h │ ├── paillier.h │ ├── paillier_gmp.h │ └── paillier_gpu.h │ ├── FL │ ├── FLparam.h │ ├── FLtrainer.h │ ├── comm_helper.h │ ├── distributed_party.h │ ├── distributed_server.h │ ├── partition.h │ ├── party.h │ └── server.h │ ├── Tree │ ├── GBDTparam.h │ ├── function_builder.h │ ├── gbdt.h │ ├── hist_cut.h │ ├── hist_tree_builder.h │ ├── histogram.h │ ├── splitpoint.h │ ├── tree.h │ └── tree_builder.h │ ├── booster.h │ ├── common.h │ ├── config.h.in │ ├── dataset.h │ ├── metric │ ├── metric.h │ ├── multiclass_metric.h │ ├── 
pointwise_metric.h │ └── ranking_metric.h │ ├── objective │ ├── multiclass_obj.h │ ├── objective_function.h │ ├── ranking_obj.h │ └── regression_obj.h │ ├── parser.h │ ├── predictor.h │ ├── syncarray.h │ ├── syncmem.h │ ├── trainer.h │ └── util │ ├── cub_wrapper.h │ ├── device_lambda.cuh │ ├── dirichlet.h │ ├── log.h │ └── multi_device.h ├── python ├── LICENSE ├── README.md ├── examples │ ├── classifier_example.py │ └── regressor_example.py ├── fedtree │ ├── __init__.py │ └── fedtree.py └── setup.py ├── src ├── FedTree │ ├── CMakeLists.txt │ ├── DP │ │ └── differential_privacy.cpp │ ├── Encryption │ │ ├── diffie_hellman.cpp │ │ ├── paillier.cpp │ │ ├── paillier_gmp.cpp │ │ └── paillier_gpu.cu │ ├── FL │ │ ├── FLtrainer.cpp │ │ ├── partition.cpp │ │ ├── party.cpp │ │ └── server.cpp │ ├── Tree │ │ ├── function_builder.cpp │ │ ├── gbdt.cpp │ │ ├── hist_cut.cpp │ │ ├── hist_tree_builder.cpp │ │ ├── tree.cpp │ │ └── tree_builder.cpp │ ├── booster.cpp │ ├── dataset.cpp │ ├── distributed_party.cpp │ ├── distributed_server.cpp │ ├── fedtree_predict.cpp │ ├── fedtree_train.cpp │ ├── grpc │ │ ├── fedtree.grpc.pb.cc │ │ ├── fedtree.grpc.pb.h │ │ ├── fedtree.pb.cc │ │ ├── fedtree.pb.h │ │ └── fedtree.proto │ ├── metric │ │ ├── metric.cpp │ │ ├── multiclass_metric.cpp │ │ ├── pointwise_metric.cpp │ │ └── rank_metric.cpp │ ├── objective │ │ ├── multiclass_obj.cpp │ │ ├── objective_function.cpp │ │ └── ranking_obj.cpp │ ├── parser.cpp │ ├── predictor.cpp │ ├── scikit_fedtree.cpp │ ├── syncmem.cpp │ └── util │ │ ├── common.cpp │ │ └── log.cpp └── test │ ├── CMakeLists.txt │ ├── test_dataset.cpp │ ├── test_find_feature_range.cpp │ ├── test_gbdt.cpp │ ├── test_gradient.cpp │ ├── test_main.cpp │ ├── test_parser.cpp │ ├── test_partition.cpp │ ├── test_tree.cpp │ └── test_tree_builder.cpp └── utils ├── FedTree_hcut.py └── FedTree_vcut.py /.gitignore: -------------------------------------------------------------------------------- 1 | /CMakeLists.txt 2 | /cmake-build-debug/ 3 | /build 4 
| .DS_Store 5 | .vscode 6 | build.sh 7 | test.sh 8 | .idea/ 9 | tgbm.model 10 | .gradle 11 | /ntl-11.4.4/ 12 | 13 | logs/* 14 | cut_dataset/* 15 | fedtree-datasets/* 16 | dataset/* 17 | cut_dataset/* 18 | dataset/breast-cancer_scale 19 | 20 | dataset/diabetes_scale 21 | 22 | dataset/mnist.scale 23 | 24 | python/FedTree/libFedTree.dylib 25 | python/FedTree/libFedTree.so 26 | python/FedTree/libFedTree.dll 27 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/googletest"] 2 | path = src/test/googletest 3 | url = https://github.com/google/googletest 4 | #[submodule "cub"] 5 | # path = cub 6 | # url = https://github.com/NVlabs/cub.git 7 | [submodule "thrust"] 8 | path = thrust 9 | url = https://github.com/NVIDIA/thrust.git 10 | [submodule "CGBN"] 11 | path = CGBN 12 | url = https://github.com/QinbinLi/CGBN.git 13 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | #cmake_policy(SET CMP0104 NEW) 3 | #cmake_policy(SET CMP0048 NEW) 4 | cmake_policy(SET CMP0042 NEW) 5 | project(FedTree LANGUAGES C CXX) 6 | #enable_language(CUDA) 7 | if(MSVC) 8 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) 9 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) 10 | endif() 11 | 12 | option(BUILD_SHARED_LIBS "Build as a shared library" ON) 13 | option(USE_CUDA "Compile with CUDA for homomorphic encryption" OFF) 14 | option(USE_CUDA_ARRAY "Compile with CUDA for training" OFF) 15 | option(DISTRIBUTED "Build for distributed computing" ON) 16 | option(USE_DOUBLE "Use double as gradient_type" OFF) 17 | set(BUILD_TESTS OFF CACHE BOOL "Build Tests") 18 | set(NTL_PATH "~/usr/local" CACHE STRING "NTL Path") 19 | 20 | #find_package(Threads) 21 | 
find_package(OpenMP REQUIRED) 22 | #find_package(GMP REQUIRED) 23 | #find_package(Python COMPONENTS Interpreter Development REQUIRED) 24 | #find_package(pybind11 CONFIG REQUIRED) 25 | if (NOT CMAKE_BUILD_TYPE) 26 | set(CMAKE_BUILD_TYPE Release) 27 | endif () 28 | if (MSVC AND BUILD_SHARED_LIBS) 29 | set (CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) 30 | endif () 31 | 32 | if (USE_CUDA) 33 | message("Compile with CUDA") 34 | find_package(CUDA REQUIRED QUIET) 35 | 36 | # include(FindCUDA/select_compute_arch) 37 | # CUDA_DETECT_INSTALLED_GPUS(INSTALLED_GPU_CCS_1) 38 | # string(STRIP "${INSTALLED_GPU_CCS_1}" INSTALLED_GPU_CCS_2) 39 | # string(REPLACE " " ";" INSTALLED_GPU_CCS_3 "${INSTALLED_GPU_CCS_2}") 40 | # string(REPLACE "." "" CUDA_ARCH_LIST "${INSTALLED_GPU_CCS_3}") 41 | # SET(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH_LIST}) 42 | 43 | # set(CMAKE_CUDA_ARCHITECTURES 75 CACHE STRING "CUDA architectures" FORCE) 44 | # set(CMAKE_CUDA) 45 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11 -lineinfo --expt-extended-lambda --default-stream per-thread") 46 | # set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11") 47 | include_directories(${PROJECT_SOURCE_DIR}/CGBN/include) 48 | # cuda_include_directories(${PROJECT_SOURCE_DIR}/CGBN/include) 49 | # set(CGBN_HEADER ${PROJECT_SOURCE_DIR}/CGBN/include) 50 | # set_source_files_properties(DIRECTORY ${PROJECT_SOURCE_DIR}/CGBN/include PROPERTIES LANGUAGE CUDA) 51 | add_subdirectory(${PROJECT_SOURCE_DIR}/thrust/) 52 | find_package(Thrust REQUIRED) 53 | thrust_create_target(ThrustOMP HOST CPP DEVICE OMP) 54 | include_directories(${NTL_PATH}/include/) 55 | else () 56 | message("Complie without CUDA") 57 | #set(Thrust_DIR "${PROJECT_SOURCE_DIR}/thrust/cmake/") 58 | add_subdirectory(${PROJECT_SOURCE_DIR}/thrust/) 59 | find_package(Thrust REQUIRED) 60 | thrust_create_target(ThrustOMP HOST CPP DEVICE OMP) 61 | # include_directories(/usr/local/include/) 62 | include_directories(${NTL_PATH}/include/) 63 | endif () 64 | 65 | # 
add_subdirectory(${PROJECT_SOURCE_DIR}/pybind11) 66 | 67 | if (DISTRIBUTED) 68 | include(FetchContent) 69 | FetchContent_Declare( 70 | gRPC 71 | GIT_REPOSITORY https://github.com/grpc/grpc 72 | GIT_TAG v1.35.0 # e.g v1.28.0 73 | ) 74 | set(FETCHCONTENT_QUIET OFF) 75 | FetchContent_MakeAvailable(gRPC) 76 | endif () 77 | 78 | if (CMAKE_VERSION VERSION_LESS "3.1") 79 | add_compile_options("-std=c++11") 80 | else () 81 | set(CMAKE_CXX_STANDARD 11) 82 | endif () 83 | 84 | if (OPENMP_FOUND) 85 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 86 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 87 | endif () 88 | 89 | # for easylogging++ configuration 90 | add_definitions("-DELPP_FEATURE_PERFORMANCE_TRACKING") 91 | add_definitions("-DELPP_THREAD_SAFE") 92 | add_definitions("-DELPP_STL_LOGGING") 93 | add_definitions("-DELPP_NO_LOG_TO_FILE") 94 | 95 | 96 | 97 | # includes 98 | set(COMMON_INCLUDES ${PROJECT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}) 99 | if(USE_CUDA) 100 | list(REMOVE_ITEM COMMON_INCLUDES "${PROJECT_SOURCE_DIR}/include/FedTree/Encryption/paillier.h") 101 | else() 102 | list(REMOVE_ITEM COMMON_INCLUDES "${PROJECT_SOURCE_DIR}/include/FedTree/Encryption/paillier_gpu.h" "${PROJECT_SOURCE_DIR}/include/FedTree/Encryption/paillier_gmp.h") 103 | endif() 104 | 105 | 106 | set(DATASET_DIR ${PROJECT_SOURCE_DIR}/dataset/) 107 | configure_file(include/FedTree/config.h.in config.h) 108 | 109 | include_directories(${COMMON_INCLUDES}) 110 | 111 | 112 | if (USE_CUDA) 113 | include_directories(${PROJECT_SOURCE_DIR}/cub) 114 | endif () 115 | add_subdirectory(${PROJECT_SOURCE_DIR}/src/FedTree) 116 | 117 | if (BUILD_TESTS) 118 | message("Building tests") 119 | add_subdirectory(src/test) 120 | endif () 121 | 122 | # configuration file 123 | 124 | -------------------------------------------------------------------------------- /FedTree_draft_paper.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/FedTree_draft_paper.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [Documentation](https://fedtree.readthedocs.io/en/latest/index.html) 2 | 3 | # Overview 4 | **FedTree** is a federated learning system for tree-based models. It is designed to be highly **efficient**, **effective**, 5 | and **secure**. It has the following features currently. 6 | 7 | - Federated training of gradient boosting decision trees. 8 | - Parallel computing on multi-core CPUs and GPUs. 9 | - Supporting homomorphic encryption, secure aggregation and differential privacy. 10 | - Supporting classification and regression. 11 | 12 | The overall architecture of FedTree is shown below. 13 | ![FedTree_archi](./docs/source/images/fedtree_archi.png) 14 | 15 | # Getting Started 16 | You can refer to our primary documentation [here](https://fedtree.readthedocs.io/en/latest/index.html). 17 | ## Prerequisites 18 | * [CMake](https://cmake.org/) 3.15 or above 19 | * [GMP](https://gmplib.org/) library 20 | * [NTL](https://libntl.org/) 21 | 22 | You can follow the following commands to install NTL library. 23 | 24 | ``` 25 | wget https://libntl.org/ntl-11.4.4.tar.gz 26 | tar -xvf ntl-11.4.4.tar.gz 27 | cd ntl-11.4.4/src 28 | ./configure SHARED=on 29 | make 30 | make check 31 | sudo make install 32 | ``` 33 | 34 | 35 | If you install the NTL library at another location, please pass the location to the `NTL_PATH` when building the library (e.g., `cmake .. -DNTL_PATH="PATH_TO_NTL"`). 
36 | ## Clone and Install submodules 37 | ``` 38 | git clone https://github.com/Xtra-Computing/FedTree.git 39 | cd FedTree 40 | git submodule init 41 | git submodule update 42 | ``` 43 | # Standalone Simulation 44 | 45 | ## Build on Linux 46 | 47 | ```bash 48 | # under the directory of FedTree 49 | mkdir build && cd build 50 | cmake .. 51 | make -j 52 | ``` 53 | 54 | ## Build on MacOS 55 | 56 | ### Build with Apple Clang 57 | 58 | You need to install ```libomp``` for MacOS. 59 | ``` 60 | brew install libomp 61 | ``` 62 | 63 | Install FedTree: 64 | ```bash 65 | # under the directory of FedTree 66 | mkdir build 67 | cd build 68 | cmake -DOpenMP_C_FLAGS="-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include" \ 69 | -DOpenMP_C_LIB_NAMES=omp \ 70 | -DOpenMP_CXX_FLAGS="-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include" \ 71 | -DOpenMP_CXX_LIB_NAMES=omp \ 72 | -DOpenMP_omp_LIBRARY=/usr/local/opt/libomp/lib/libomp.dylib \ 73 | .. 74 | make -j 75 | ``` 76 | 77 | ## Run training 78 | ```bash 79 | # under 'FedTree' directory 80 | ./build/bin/FedTree-train ./examples/vertical_example.conf 81 | ``` 82 | 83 | 84 | # Distributed Setting 85 | For each machine that participates in FL, it needs to build the library first. 86 | ```bash 87 | mkdir build && cd build 88 | cmake .. -DDISTRIBUTED=ON 89 | make -j 90 | ``` 91 | Then, write your configuration file where you should specify the ip address of the server machine (`ip_address=xxx`). Run `FedTree-distributed-server` in the server machine and run `FedTree-distributed-party` in the party machines. 92 | Here are two examples for horizontal FedTree and vertical FedTree. 
93 | 94 | [//]: # (export CPLUS_INCLUDE_PATH=./build/_deps/grpc-src/include/:$CPLUS_INCLUDE_PATH) 95 | [//]: # (export CPLUS_INCLUDE_PATH=./build/_deps/grpc-src/third_party/protobuf/src/:$CPLUS_INCLUDE_PATH) 96 | 97 | ## Distributed Horizontal FedTree 98 | ```bash 99 | # under 'FedTree' directory 100 | # under server machine 101 | ./build/bin/FedTree-distributed-server ./examples/adult/a9a_horizontal_server.conf 102 | # under party machine 0 103 | ./build/bin/FedTree-distributed-party ./examples/adult/a9a_horizontal_p0.conf 0 104 | # under party machine 1 105 | ./build/bin/FedTree-distributed-party ./examples/adult/a9a_horizontal_p1.conf 1 106 | ``` 107 | 108 | ## Distributed Vertical FedTree 109 | ```bash 110 | # under 'FedTree' directory 111 | # under server (i.e., the party with label) machine 0 112 | ./build/bin/FedTree-distributed-server ./examples/credit/credit_vertical_p0.conf 113 | # open a new terminal 114 | ./build/bin/FedTree-distributed-party ./examples/credit/credit_vertical_p0.conf 0 115 | # under party machine 1 116 | ./build/bin/FedTree-distributed-party ./examples/credit/credit_vertical_p1.conf 1 117 | ``` 118 | 119 | # Other information 120 | FedTree is built based on [ThunderGBM](https://github.com/Xtra-Computing/thundergbm), which is a fast GBDTs and Random Forests training system on GPUs. 121 | 122 | # Citation 123 | Please cite our paper if you use FedTree in your work. 
124 | ``` 125 | @misc{fedtree, 126 | title = {FedTree: A Fast, Effective, and Secure Tree-based Federated Learning System}, 127 | author={Li, Qinbin and Cai, Yanzheng and Han, Yuxuan and Yung, Ching Man and Fu, Tianyuan and He, Bingsheng}, 128 | howpublished = {\url{https://github.com/Xtra-Computing/FedTree/blob/main/FedTree_draft_paper.pdf}}, 129 | year={2022} 130 | } 131 | ``` 132 | 133 | 134 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/Examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | Here we present several additional examples of using FedTree. 5 | 6 | Distributed Horizontal FedTree with Secure Aggregation 7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 8 | In the horizontal FedTree, the parties have their local datasets with the same feature space but different sample spaces. 9 | Also, in each machine, a configuration file needs to be prepared. 10 | We take UCI `Adult `_ dataset as an example (partitioned data provided in `here `__). 11 | 12 | In the server machine, the configuration file `server.conf` can be: 13 | 14 | .. code:: 15 | 16 | test_data=./dataset/adult/a9a_horizontal_test 17 | n_parties=2 18 | objective=binary:logistic 19 | mode=horizontal 20 | partition=0 21 | privacy_tech=sa 22 | learning_rate=0.1 23 | max_depth=6 24 | n_trees=50 25 | 26 | In the above configuration file, it needs to specifies number of parties, objective function, mode, privacy techniques, and other parameters for the GBDT model. 27 | The `test_data` specifies the dataset for testing. 28 | 29 | Supposing the ip address of the server is a.b.c.d, in the party machine 1, the configuration file `party1.conf` can be: 30 | 31 | .. 
code:: 32 | 33 | data=./dataset/adult/a9a_horizontal_p0 34 | test_data=./dataset/adult/a9a_horizontal_test 35 | model_path=p1.model 36 | n_parties=2 37 | objective=binary:logistic 38 | mode=horizontal 39 | partition=0 40 | privacy_tech=sa 41 | learning_rate=0.1 42 | max_depth=6 43 | n_trees=50 44 | ip_address=a.b.c.d 45 | 46 | The difference between `party1.conf` and `server.conf` is that `party1.conf` needs to specify the path to the local data and the ip address of the server. 47 | Similarly, we can have a configuration file for each party machine by changing the `data` (and `model_path` if needed). Then, we can run the following commands in the corresponding machines. 48 | 49 | .. code:: 50 | 51 | # under 'FedTree' directory 52 | # under server machine 53 | ./build/bin/FedTree-distributed-server ./server.conf 54 | # under party machine 1 55 | ./build/bin/FedTree-distributed-party ./party1.conf 0 56 | # under party machine 2 57 | ./build/bin/FedTree-distributed-party ./party2.conf 1 58 | ...... 59 | 60 | In the above commands, the party machines need to add an additional input ID starting from 0 as its party ID. 61 | 62 | Distributed Vertical FedTree with Homomorphic Encryption 63 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 64 | In the vertical FedTree, the parties have their local datasets with the same sample space but different feature spaces. 65 | Moreover, at least one party has the labels of the samples. We need to specify one of the parties that has labels as the aggregator. 66 | Suppose party machine 1 is the aggregator. Then, we need to write a server configuration file `server.conf`, e.g., 67 | 68 | .. 
code:: 69 | 70 | data=./dataset/adult/a9a_vertical_p0 71 | test_data=./dataset/adult/a9a_vertical_test 72 | n_parties=2 73 | mode=vertical 74 | partition=0 75 | reorder_label=1 76 | objective=binary:logistic 77 | privacy_tech=he 78 | learning_rate=0.1 79 | max_depth=6 80 | n_trees=50 81 | 82 | For each party machine, supposing the ip address of the aggregator is a.b.c.d, we need to write a configuration file, e.g., `party1.conf` in party 1 83 | 84 | .. code:: 85 | 86 | data=./dataset/adult/a9a_vertical_p0 87 | test_data=./dataset/adult/a9a_vertical_test 88 | model_path=p1.model 89 | n_parties=2 90 | mode=vertical 91 | partition=0 92 | reorder_label=1 93 | objective=binary:logistic 94 | privacy_tech=he 95 | learning_rate=0.1 96 | max_depth=6 97 | n_trees=50 98 | ip_address=a.b.c.d 99 | 100 | Then, we can run the following commands in the corresponding machines: 101 | 102 | .. code:: 103 | 104 | #under aggregator machine (i.e., party machine 1) 105 | ./build/bin/FedTree-distributed-server ./server.conf 106 | #under party machine 1 107 | ./build/bin/FedTree-distributed-party ./party1.conf 0 108 | #under party machine 2 109 | ./build/bin/FedTree-distributed-party ./party2.conf 1 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /docs/source/Experiments.rst: -------------------------------------------------------------------------------- 1 | Experiments 2 | =========== 3 | Here we present some preliminary experimental results. We use two UCI datasets, `adult `__ and `abalone `_ for experiments. 4 | The adult dataset is a classification dataset and the abalone is a regression dataset. We use FedTree-Hori to denote the horizontal FedTree and FedTree-Verti to denote the vertical FedTree. 5 | 6 | Baselines: Homo-SecureBoost and Hetero-SecureBoost. Both approaches are from `FATE `_. 
7 | 8 | 9 | Standalone Simulation 10 | ~~~~~~~~~~~~~~~~~~~~~ 11 | For the standalone simulation, we use a machine with 64*Intel Xeon Gold 6226R CPUs and 8*NVIDIA GeForce RTX 3090 to conduct experiments. 12 | We allocate each experiment with 16 threads. By default, we set the number of parties to 2, the number of trees to 50, learning rate to 0.1, the maximum depth of tree to 6, and the maximum number of bins to 255. 13 | The other parameters of all approaches are set to the default setting of FedTree. 14 | 15 | Effectiveness 16 | ^^^^^^^^^^^^^ 17 | We first compare the accuracy of federated training and centralized training using `XGBoost `_ and `ThunderGBM `_. The results are shown below. 18 | We report AUC for adult and RMSE for abalone. We can observe that the performance of FedTree is same as ThunderGBM. Also, SA and HE do not affect the model performance. 19 | 20 | +----------+---------+------------+--------------+-----------------+---------------+------------------+------------------+--------------------+ 21 | | datasets | XGBoost | ThunderGBM | FedTree-Hori | FedTree-Hori+SA | FedTree-Verti | FedTree-Verti+HE | Homo-SecureBoost | Hetero-SecureBoost | 22 | +----------+---------+------------+--------------+-----------------+---------------+------------------+------------------+--------------------+ 23 | | a9a | 0.914 | 0.914 | 0.914 | 0.914 | 0.914 | 0.914 | 0.912 | 0.914 | 24 | +----------+---------+------------+--------------+-----------------+---------------+------------------+------------------+--------------------+ 25 | | abalone | 1.53 | 1.57 | 1.57 | 1.57 | 1.56 | 1.57 | 1.56 | 0.001 | 26 | +----------+---------+------------+--------------+-----------------+---------------+------------------+------------------+--------------------+ 27 | 28 | Efficiency 29 | ^^^^^^^^^^ 30 | 31 | We compare the efficiency of FedTree-Hori with Homo-SecureBoost of FATE. The results are shown below. We present the trainig time (s) per tree. 
32 | Note that FedTree-Hori+SA achieves the same security guarantee as Homo-SecureBoost. The speedup is the computed by the improvement of FedTree-Hori+SA over Homo-SecureBoost, which is quite significant. 33 | 34 | 35 | 36 | +----------+--------------+-----------------+------------------+---------+ 37 | | datasets | FedTree-Hori | FedTree-Hori+SA | Homo-SecureBoost | Speedup | 38 | +----------+--------------+-----------------+------------------+---------+ 39 | | a9a | 0.09 | 0.098 | 8.76 | 89.4 | 40 | +----------+--------------+-----------------+------------------+---------+ 41 | | abalone | 0.11 | 0.19 | 7.7 | 40.5 | 42 | +----------+--------------+-----------------+------------------+---------+ 43 | 44 | 45 | We compare the efficiency of FedTree-Verti with Hetero-SecureBoost of FATE. 46 | We present the trainig time (s) per tree. Note that FedTree-Verti+HE achieves the same security guarantee as SecureBoost. 47 | The speedup is the improvement of FedTree-Verti + HE (CPU) over FATE. FedTree is still much faster than SecureBoost. Moreover, FedTree can utilize GPU to accelerate the HE computation. 48 | 49 | +----------+---------------+------------------------+------------------------+--------------------+---------+ 50 | | datasets | FedTree-Verti | FedTree-Verti+HE (CPU) | FedTree-Verti+HE (GPU) | Hetero-SecureBoost | Speedup | 51 | +----------+---------------+------------------------+------------------------+--------------------+---------+ 52 | | a9a | 0.11 | 5.25 | 3.24 | 34.02 | 6.48 | 53 | +----------+---------------+------------------------+------------------------+--------------------+---------+ 54 | | abalone | 0.05 | 7.43 | 6.5 | 15.7 | 2.11 | 55 | +----------+---------------+------------------------+------------------------+--------------------+---------+ 56 | 57 | 58 | Distributed Computing 59 | ~~~~~~~~~~~~~~~~~~~~~ 60 | For distributed setting, we use a cluster with 5 machines, where each machine has two Intel Xeon E5-2680 14 core CPUs. 
61 | We set the number of parties to 4, where each party hosts a machine. The results are shown below. Here Homo-SecureBoost (from FATE) and FedTree-Hori+SA have the same security level. 62 | We can observe that both horizontal and vertical FedTree are faster than FATE. 63 | 64 | +----------+------------------+-------------------+---------+-------------+------------------+---------+ 65 | | datasets | Homo-SecureBoost | FedTree-Hori + SA | Speedup | SecureBoost | FedTree-Verti+HE | Speedup | 66 | +----------+------------------+-------------------+---------+-------------+------------------+---------+ 67 | | a9a | 214.7 | 124.4 | 1.7 | 505.4 | 93.2 | 5.4 | 68 | +----------+------------------+-------------------+---------+-------------+------------------+---------+ 69 | | abalone | 256.3 | 156.8 | 1.6 | 299.8 | 143.5 | 2.1 | 70 | +----------+------------------+-------------------+---------+-------------+------------------+---------+ -------------------------------------------------------------------------------- /docs/source/Frameworks.rst: -------------------------------------------------------------------------------- 1 | Frameworks 2 | ========== 3 | 4 | Here is an introduction of FedTree algorithms. 5 | 6 | 7 | 8 | **Contents** 9 | 10 | - `Horizontal Federated GBDTs <#horizontal-federated-gbdts>`__ 11 | 12 | - `Vertical Federated GBDTs <#vertical-federated-gbdts>`__ 13 | 14 | - `Build on Linux <#build-on-linux>`__ 15 | 16 | - `Build on MacOS <#build-on-macos>`__ 17 | 18 | Horizontal Federated GBDTs 19 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 20 | In the horizontal FedTree, the parties have their local datasets with the same feature space but different sample spaces. The framework of horizontal federated GBDTs training is shown below. There are four steps in each round. 21 | 22 | .. image:: ./images/hori_fram.png 23 | :align: center 24 | :target: ./images/hori_fram.png 25 | 26 | 27 | 1. 
The server sends the initialization parameters (#round = 1) or sends the new tree (#round > 1) to the parties. 28 | 29 | 2. The parties update the gradient histogram. 30 | 31 | 3. The parties send the gradient histogram to the server. 32 | 33 | 4. The server merges the histogram and boosts a new tree. 34 | 35 | We repeat the above steps until reaching the given number of trees. 36 | 37 | We provide the option to adopt the `secure aggregation `_ method to protect the exchanged histograms. 38 | In the beginning of training, the clients and the server use Diffie-Hellman key exchange to share a secret key for each pair of clients. 39 | Then, before transferring the gradient histogram, each client generates random noises for each other client, encrypts the noises by the shared key of the corresponding client, and sends the encrypted noises to the server. 40 | Then, the server sends back the encrypted noises to the clients. The clients decrypt the noises with the shared keys. Then, the clients add the generated noises to and subtract the decrypted noises from the local histogram. 41 | The injected noises of each client cancel each other out and the aggregated histogram remains unchanged. 42 | 43 | The detailed algorithm is shown below. 44 | 45 | .. image:: ./images/fedtree_hori.png 46 | :align: center 47 | :target: ./images/fedtree_hori.png 48 | 49 | 50 | If adopting differential privacy, the server will train a differentially private tree in the fourth step using Laplace mechanism and exponential mechanism. 51 | 52 | 53 | 54 | .. If adopting homomorphic encryption, the framework is shown below. There are five steps in each round. 55 | .. image:: ./images/hori_fram_he.png 56 | :align: center 57 | :target: ./images/hori_fram_he.png 58 | 1. The server sends the initialization parameters and the public key (#round = 1) or sends the new tree (#round > 1) to the parties. 59 | 2. The parties update the gradient histogram and encrypt it using the public key. 60 | 3. 
The parties send the encrypted histogram to a selected party. 61 | 4. The party sums the encrypted histogram and sends the merged histogram to the server. 62 | 5. The server decrypts the histogram using its private key and boosts a new tree. 63 | 64 | Vertical Federated GBDTs 65 | ~~~~~~~~~~~~~~~~~~~~~~~~ 66 | In the vertical FedTree, the parties have their local datasets with the same sample space but different feature spaces. 67 | Moreover, at least one party has the labels of the samples. We specify one party that has the labels as the host party (i.e., aggregator). 68 | 69 | The framework of vertical federated GBDTs training is shown below. There are four steps in each round. 70 | 71 | .. image:: ./images/verti_fram.png 72 | :align: center 73 | :target: ./images/verti_fram.png 74 | 75 | 1. The host party (i.e., the party with the labels) updates the gradients and sends the gradients to the other parties. 76 | 77 | For each depth: 78 | 79 | 2. The parties compute the local gradient histograms. 80 | 81 | 3. The parties send their local histograms to the host party. 82 | 83 | 4. The host party aggregates the histograms, computes the best split point, and asks the corresponding party (including itself) to update the node. 84 | 85 | 5. The parties send back the nodes to the host party. 86 | 87 | Here, steps 2-4 are done for each depth of a tree until reaching the given maximum depth. The above steps are repeated until reaching the given number of trees. 88 | If homomorphic encryption is applied, the host party sends the encrypted gradients in the first step and decrypts the histogram in the fourth step. 89 | 90 | We provide the option to adopt `additive homomorphic encryption `_ to protect the exchanged gradients. 91 | Specifically, the host party generates public and private keys before the training. Then, it uses the public key to encrypt the gradients before sending them. 
After receiving local histograms from the parties, the host party uses the private key to decrypt the histograms before further computation. 93 | 94 | The detailed algorithm is shown below. 95 | 96 | .. image:: ./images/fedtree_verti.png 97 | :align: center 98 | :target: ./images/fedtree_verti.png 99 | 100 | 101 | If differential privacy is applied, the host party updates the tree using Laplace mechanism and exponential mechanism. 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /docs/source/Installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Here is the guide for the installation of FedTree. 5 | 6 | 7 | 8 | **Contents** 9 | 10 | - `Prerequisites <#prerequisites>`__ 11 | 12 | - `Install FedTree <#install-fedtree>`__ 13 | 14 | - `Build on Linux <#build-on-linux>`__ 15 | 16 | - `Build on MacOS <#build-on-macos>`__ 17 | 18 | Prerequisites 19 | ~~~~~~~~~~~~~ 20 | 21 | * `CMake `_ 3.15 or above 22 | * `GMP `_ library 23 | * `NTL `_ library 24 | 25 | You can follow the following commands to install the NTL library. 26 | 27 | .. code:: 28 | 29 | wget https://libntl.org/ntl-11.4.4.tar.gz 30 | tar -xvf ntl-11.4.4.tar.gz 31 | cd ntl-11.4.4/src 32 | ./configure 33 | make 34 | make check 35 | sudo make install 36 | 37 | If you install the NTL library at another location, please pass the location to the `NTL_PATH` when building the library (e.g., `cmake .. -DNTL_PATH="PATH_TO_NTL"`). 38 | 39 | Clone and Install submodules 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | 42 | Run the following commands: 43 | 44 | .. code:: 45 | 46 | git clone https://github.com/Xtra-Computing/FedTree 47 | git submodule init 48 | git submodule update 49 | 50 | Build on Linux 51 | ~~~~~~~~~~~~~~ 52 | Run the following commands: 53 | 54 | .. code:: 55 | 56 | # under the directory of FedTree 57 | mkdir build && cd build 58 | cmake .. 
59 | make -j 60 | 61 | Build on MacOS 62 | ~~~~~~~~~~~~~~ 63 | On MacOS, you can use Apple Clang to build FedTree. 64 | 65 | Build with Apple Clang 66 | ^^^^^^^^^^^^^^^^^^^^^^ 67 | Install `libomp` if you haven't: 68 | 69 | .. code:: 70 | 71 | brew install libomp 72 | 73 | Run the following commands: 74 | 75 | .. code:: 76 | 77 | mkdir build 78 | cd build 79 | cmake -DOpenMP_C_FLAGS="-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include" \ 80 | -DOpenMP_C_LIB_NAMES=omp \ 81 | -DOpenMP_CXX_FLAGS="-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include" \ 82 | -DOpenMP_CXX_LIB_NAMES=omp \ 83 | -DOpenMP_omp_LIBRARY=/usr/local/opt/libomp/lib/libomp.dylib \ 84 | .. 85 | make -j 86 | 87 | Building Options 88 | ~~~~~~~~~~~~~~~~ 89 | There are the following building options passing with cmake. 90 | 91 | * ``USE_CUDA`` [default = ``OFF``]: Whether using GPU to accelerate homomorphic encryption or not. 92 | 93 | * ``DISTRIBUTED`` [default = ``ON``]: Whether building distributed version of FedTree or not. 94 | 95 | * ``NTL_PATH`` [default = ``/usr/local``]: The PATH to the NTL library. 96 | 97 | For example, if you want to build a version with GPU acceleration, distributed version with NTL library under /home/NTL directory, you can use the following command. 98 | 99 | .. code:: 100 | 101 | cmake .. -DUSE_CUDA=ON -DDISTRIBUTED=ON -DNTL_PATH="/home/NTL" 102 | make -j 103 | 104 | 105 | -------------------------------------------------------------------------------- /docs/source/Parameters.rst: -------------------------------------------------------------------------------- 1 | APIs/Parameters 2 | =============== 3 | 4 | We provide two kinds of APIs: command-line interface (CLI) and Python interface. For CLI, users only need to prepare a 5 | configuration file specifying the parameters and call FedTree in a one-line command. 
For Python interface, users can define 6 | two classes `FLClassifier` and `FLRegressor` with the parameters and use them in a scikit-learn style (see `here `__). 7 | The parameters are below. 8 | 9 | **Contents** 10 | 11 | - `Parameters for Federated Setting <#parameters-for-federated-setting>`__ 12 | 13 | - `Parameters for GBDTs <#parameters-for-gbdts>`__ 14 | 15 | - `Parameters for Privacy Protection <#parameters-for-privacy-protection>`__ 16 | 17 | Parameters for Federated Setting 18 | -------------------------------- 19 | 20 | * ``mode`` [default = ``horizontal``, type=string] 21 | - ``horizontal``: horizontal federated learning 22 | - ``vertical``: vertical federated learning 23 | 24 | * ``num_parties`` [default = ``10``, type = int, alias: ``num_clients``, ``num_devices``] 25 | - Number of parties 26 | 27 | * ``partition`` [default = ``0``, type = bool] 28 | - ``0``: each party has a prepared local dataset 29 | - ``1``: there is a global dataset and users require FedTree to partition it to multiple subsets to simulate federated setting. 30 | 31 | * ``partition_mode`` [default=``iid``, type=string] 32 | - ``iid``: IID data partitioning 33 | - ``noniid``: non-IID data partitioning 34 | 35 | * ``ip_address`` [default=``localhost``, type=string, alias: ``server_ip_address``] 36 | - The ip address of the server in distributed FedTree. 37 | 38 | * ``data_format`` [default=``libsvm``, type=string] 39 | - ``libsvm``: the input data is in a libsvm format (label feature_id1:feature_value1 feature_id2:feature_value2). See `here `__ for an example. 40 | - ``csv``: the input data is in a csv format (the first row is the header and the other rows are feature values). See `here `__ for an example. 41 | 42 | * ``n_features`` [default=-1, type=int] 43 | - Number of features of the datasets. It needs to be specified when conducting horizontal FedTree with sparse datasets. 
44 | 45 | * ``propose_split`` [default=``server``, type=string] 46 | - ``server``: the server proposes candidate split points according to the range of each feature in horizontal FedTree. 47 | - ``party``: the parties propose possible split points. Then, the server merges them and samples at most num_max_bin candidate split points in horizontal FedTree. 48 | 49 | Parameters for GBDTs 50 | -------------------- 51 | 52 | * ``data`` [default=``../dataset/test_dataset.txt``, type=string, alias: ``path``] 53 | - The path to the training dataset(s). In simulation, if multiple datasets need to be loaded where each dataset represents a party, specify the paths separated with commas. 54 | 55 | * ``model_path`` [default=``fedtree.model``, type=string] 56 | - The path to save/load the model. 57 | 58 | * ``verbose`` [default=1, type=int] 59 | - Printing information: 0 for silence, 1 for key information and 2 for more information. 60 | 61 | * ``depth`` [default=6, type=int] 62 | 63 | - The maximum depth of the decision trees. Shallow trees tend to have better generality, and deep trees are more likely to overfit the training data. 64 | 65 | * ``n_trees`` [default=40, type=int] 66 | 67 | - The number of training iterations. ``n_trees`` equals the number of trees in GBDTs. 68 | 69 | 70 | * ``max_num_bin`` [default=32, type=int] 71 | 72 | - The maximum number of bins in a histogram. The value needs to be smaller than 256. 73 | 74 | * ``learning_rate`` [default=1, type=float, alias: ``eta``] 75 | 76 | - Valid domain: [0,1]. This option is to set the weight of the newly trained tree. Use ``eta < 1`` to mitigate overfitting. 77 | 78 | * ``objective`` [default=``reg:linear``, type=string] 79 | 80 | - Valid options include ``reg:linear``, ``reg:logistic``, ``multi:softprob``, ``multi:softmax``, ``rank:pairwise`` and ``rank:ndcg``. 81 | - ``reg:linear`` is for regression, ``reg:logistic`` and ``binary:logistic`` are for binary classification.
82 | - ``multi:softprob`` and ``multi:softmax`` are for multi-class classification. ``multi:softprob`` outputs probability for each class, and ``multi:softmax`` outputs the label only. 83 | - ``rank:pairwise`` and ``rank:ndcg`` are for ranking problems. 84 | 85 | * ``num_class`` [default=1, type=int] 86 | - Set the number of classes in the multi-class classification. This option is not compulsory. 87 | 88 | * ``min_child_weight`` [default=1, type=float] 89 | 90 | - The minimum sum of instance weight (measured by the second order derivative) needed in a child node. 91 | 92 | * ``lambda`` [default=1, type=float, alias: ``lambda_tgbm`` or ``reg_lambda``] 93 | 94 | - L2 regularization term on weights. 95 | 96 | * ``gamma`` [default=1, type=float, alias: ``min_split_loss``] 97 | 98 | - The minimum loss reduction required to make a further split on a leaf node of the tree. ``gamma`` is used in the pruning stage. 99 | 100 | 101 | Parameters for Privacy Protection 102 | --------------------------------- 103 | 104 | * ``privacy_method`` [default = ``none``, type=string] 105 | - ``none``: no additional method is used to protect the communicated messages (raw data is not transferred). 106 | - ``he``: use homomorphic encryption to protect the communicated messages (for vertical FedTree). 107 | - ``sa``: use secure aggregation to protect the communicated messages (for horizontal FedTree). 108 | - ``dp``: use differential privacy to protect the communicated messages (currently only works for single machine simulation). 109 | 110 | 111 | * ``privacy_budget`` [default=10, type=float] 112 | - Total privacy budget if using differential privacy. 
113 | -------------------------------------------------------------------------------- /docs/source/Quick-Start.rst: -------------------------------------------------------------------------------- 1 | Quick Start 2 | =========== 3 | 4 | Here we present an example to simulate vertical federated learning with FedTree to help you understand the procedure of using FedTree. 5 | 6 | Prepare a dataset / datasets 7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 8 | You can either prepare a global dataset to simulate the federated setting by partitioning in FedTree or prepare a local dataset for each party. 9 | 10 | For the data format, FedTree supports svmlight/libsvm format (each row is an instance with ``label feature_id1:feature_value1 feature_id2:feature_value2 ...``) 11 | and csv format (the first row is the header ``id,label,feature_id1,feature_id2,...`` and the other rows are the corresponding values). 12 | See `here `__ for an example of libsvm format dataset 13 | and `here `__ for an example of csv format dataset. 14 | 15 | For classification task, please ensure that the labels of the dataset are organized as ``0 1 2 ...`` (e.g., use labels 0 and 1 for binary classification). 16 | 17 | Configure the Parameters 18 | ~~~~~~~~~~~~~~~~~~~~~~~~ 19 | You can set the parameters in a file, e.g., ``machine.conf`` under ``dataset`` subdirectory. 20 | For example, we can set the following example parameters to run vertical federated learning using homomorphic encryption to protect the communicated message. 21 | For more details about the parameters, please refer to `here `__. 22 | 23 | .. 
code:: 24 | 25 | data=./dataset/test_dataset.txt 26 | test_data=./dataset/test_dataset.txt 27 | partition_mode=vertical 28 | n_parties=4 29 | mode=vertical 30 | privacy_tech=he 31 | n_trees=40 32 | depth=6 33 | learning_rate=0.2 34 | 35 | Run FedTree 36 | ~~~~~~~~~~~ 37 | After you install FedTree, you can simply run the following commands under ``FedTree`` directory to simulate vertical federated learning in a single machine. 38 | 39 | .. code:: 40 | 41 | ./build/bin/FedTree-train ./examples/vertical_example.conf 42 | ./build/bin/FedTree-predict ./examples/prediction.conf 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | .. _LibSVM: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'FedTree' 21 | copyright = '2021, Xtra Computing' 22 | author = 'Xtra Computing' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. 
They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "sphinx.ext.intersphinx", 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.mathjax", 34 | "sphinx.ext.viewcode", 35 | ] 36 | 37 | 38 | # Add any paths that contain templates here, relative to this directory. 39 | templates_path = ['_templates'] 40 | 41 | # List of patterns, relative to source directory, that match files and 42 | # directories to ignore when looking for source files. 43 | # This pattern also affects html_static_path and html_extra_path. 44 | exclude_patterns = [] 45 | 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = "sphinx_rtd_theme" 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 
57 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/source/images/fedtree_archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/docs/source/images/fedtree_archi.png -------------------------------------------------------------------------------- /docs/source/images/fedtree_hori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/docs/source/images/fedtree_hori.png -------------------------------------------------------------------------------- /docs/source/images/fedtree_verti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/docs/source/images/fedtree_verti.png -------------------------------------------------------------------------------- /docs/source/images/hori_fram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/docs/source/images/hori_fram.png -------------------------------------------------------------------------------- /docs/source/images/hori_fram_he.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/docs/source/images/hori_fram_he.png -------------------------------------------------------------------------------- /docs/source/images/verti_fram.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/docs/source/images/verti_fram.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. FedTree documentation master file, created by 2 | sphinx-quickstart on Mon Apr 19 14:22:49 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to FedTree's documentation! 7 | =================================== 8 | 9 | **FedTree** is a federated learning system for tree-based models. It is designed to be highly **efficient**, **effective**, 10 | and **secure**. It has the following features. 11 | 12 | - Parallel computing on multi-core CPUs and GPUs. 13 | - Stand-alone simulation and distributed learning. 14 | - Support of homomorphic encryption, secure aggregation, and differential privacy. 15 | - Federated training algorithms of gradient boosting decision trees and random forests. 16 | 17 | .. image:: ./images/fedtree_archi.png 18 | :align: center 19 | :target: ./images/fedtree_archi.png 20 | 21 | .. 
toctree:: 22 | :maxdepth: 1 23 | :caption: Contents: 24 | 25 | Installation 26 | Quick Start 27 | APIs/Parameters 28 | Examples 29 | Frameworks 30 | Experiments 31 | 32 | -------------------------------------------------------------------------------- /examples/abalone/a9a_horizontal_p1.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_horizontal_p1 2 | test_data=./dataset/adult/a9a_horizontal_test 3 | n_parties=2 4 | num_class=2 5 | objective=binary:logistic 6 | mode=horizontal 7 | max_num_bin=255 8 | privacy_tech=sa 9 | learning_rate=0.1 10 | max_depth=6 11 | n_trees=50 12 | ip_address=192.168.141.2 13 | -------------------------------------------------------------------------------- /examples/abalone/a9a_horizontal_server.conf: -------------------------------------------------------------------------------- 1 | test_data=./dataset/adult/a9a_horizontal_test 2 | n_parties=2 3 | num_class=2 4 | objective=binary:logistic 5 | mode=horizontal 6 | partition_mode=horizontal 7 | max_num_bin=255 8 | privacy_tech=sa 9 | learning_rate=0.1 10 | max_depth=6 11 | n_trees=50 12 | ip_address=192.168.141.2 13 | -------------------------------------------------------------------------------- /examples/abalone/a9a_vertical_p0.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_vertical_p0 2 | test_data=./dataset/adult/a9a_vertical_test 3 | n_parties=2 4 | mode=vertical 5 | reorder_label=1 6 | objective=binary:logistic 7 | privacy_tech=none 8 | learning_rate=0.1 9 | max_depth=6 10 | n_trees=10 11 | ip_address=localhost -------------------------------------------------------------------------------- /examples/abalone/a9a_vertical_p1.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_vertical_p1_nolabel 2 | test_data=./dataset/adult/a9a_vertical_test 3 | n_parties=2 4 | mode=vertical 5 | 
reorder_label=1 6 | objective=binary:logistic 7 | privacy_tech=none 8 | learning_rate=0.1 9 | max_depth=6 10 | n_trees=10 11 | ip_address=localhost 12 | -------------------------------------------------------------------------------- /examples/abalone/aba_horizontal_p0.conf: -------------------------------------------------------------------------------- 1 | data=./cut_dataset/aba_h_2_0 2 | n_parties=2 3 | objective=reg:linear 4 | mode=horizontal 5 | privacy_tech=sa 6 | learning_rate=0.1 7 | max_num_bin=255 8 | max_depth=6 9 | n_trees=50 10 | ip_address=192.168.141.2 11 | -------------------------------------------------------------------------------- /examples/abalone/aba_horizontal_p1.conf: -------------------------------------------------------------------------------- 1 | data=./cut_dataset/aba_h_2_1 2 | n_parties=2 3 | objective=reg:linear 4 | mode=horizontal 5 | privacy_tech=sa 6 | learning_rate=0.1 7 | max_num_bin=255 8 | max_depth=6 9 | n_trees=50 10 | ip_address=192.168.141.2 11 | -------------------------------------------------------------------------------- /examples/abalone/aba_horizontal_server.conf: -------------------------------------------------------------------------------- 1 | n_parties=2 2 | objective=reg:linear 3 | mode=horizontal 4 | privacy_tech=sa 5 | learning_rate=0.1 6 | max_num_bin=255 7 | max_depth=6 8 | n_trees=50 9 | ip_address=192.168.141.2 10 | -------------------------------------------------------------------------------- /examples/abalone/aba_vertical_p0.conf: -------------------------------------------------------------------------------- 1 | data=./cut_dataset/aba_v_2_0 2 | n_parties=2 3 | mode=vertical 4 | reorder_label=0 5 | objective=reg:linear 6 | privacy_tech=he 7 | learning_rate=0.1 8 | max_depth=6 9 | max_num_bin=255 10 | n_trees=50 11 | ip_address=192.168.141.2 12 | -------------------------------------------------------------------------------- /examples/abalone/aba_vertical_p1.conf: 
-------------------------------------------------------------------------------- 1 | data=./cut_dataset/aba_v_2_1_nolabel 2 | n_parties=2 3 | mode=vertical 4 | reorder_label=0 5 | objective=reg:linear 6 | privacy_tech=he 7 | learning_rate=0.1 8 | max_depth=6 9 | max_num_bin=255 10 | n_trees=50 11 | ip_address=192.168.141.2 12 | -------------------------------------------------------------------------------- /examples/adult/a9a_horizontal_p0.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_horizontal_p0 2 | test_data=./dataset/adult/a9a_horizontal_test 3 | n_parties=2 4 | model_path=p0.model 5 | num_class=2 6 | objective=binary:logistic 7 | mode=horizontal 8 | privacy_tech=sa 9 | learning_rate=0.1 10 | max_num_bin=255 11 | max_depth=6 12 | n_trees=50 13 | ip_address=192.168.141.2 14 | -------------------------------------------------------------------------------- /examples/adult/a9a_horizontal_p1.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_horizontal_p1 2 | test_data=./dataset/adult/a9a_horizontal_test 3 | n_parties=2 4 | num_class=2 5 | model_path=p1.model 6 | objective=binary:logistic 7 | mode=horizontal 8 | max_num_bin=255 9 | privacy_tech=sa 10 | learning_rate=0.1 11 | max_depth=6 12 | n_trees=50 13 | ip_address=192.168.141.2 14 | -------------------------------------------------------------------------------- /examples/adult/a9a_horizontal_server.conf: -------------------------------------------------------------------------------- 1 | test_data=./dataset/adult/a9a_horizontal_test 2 | n_parties=2 3 | num_class=2 4 | objective=binary:logistic 5 | mode=horizontal 6 | partition_mode=horizontal 7 | max_num_bin=255 8 | privacy_tech=sa 9 | learning_rate=0.1 10 | max_depth=6 11 | n_trees=50 12 | ip_address=192.168.141.2 13 | -------------------------------------------------------------------------------- 
/examples/adult/a9a_vertical_p0.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_vertical_p0 2 | test_data=./dataset/adult/a9a_vertical_test 3 | n_parties=2 4 | mode=vertical 5 | reorder_label=1 6 | objective=binary:logistic 7 | privacy_tech=he 8 | learning_rate=0.1 9 | max_depth=6 10 | max_num_bin=255 11 | n_trees=50 12 | ip_address=192.168.141.2 13 | -------------------------------------------------------------------------------- /examples/adult/a9a_vertical_p1.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/adult/a9a_vertical_p1 2 | test_data=./dataset/adult/a9a_vertical_test 3 | n_parties=2 4 | mode=vertical 5 | reorder_label=1 6 | objective=binary:logistic 7 | privacy_tech=none 8 | learning_rate=0.1 9 | max_depth=6 10 | n_trees=50 11 | max_num_bin=255 12 | ip_address=192.168.141.2 13 | -------------------------------------------------------------------------------- /examples/breast_distributed_horizontal_server.conf: -------------------------------------------------------------------------------- 1 | mode=horizontal 2 | partition=0 3 | max_num_bin=32 4 | reorder_label=0 5 | objective=binary:logistic 6 | privacy_tech=none 7 | n_parties=2 8 | num_class=2 9 | learning_rate=0.1 10 | max_depth=6 11 | n_trees=10 12 | ip_address=192.168.141.1 13 | -------------------------------------------------------------------------------- /examples/credit/credit_vertical_p0.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/credit/credit_vertical_p0_withlabel.csv 2 | n_parties=2 3 | num_class=2 4 | mode=vertical 5 | partition=0 6 | data_format=csv 7 | reorder_label=0 8 | max_num_bin=16 9 | objective=binary:logistic 10 | privacy_tech=none 11 | learning_rate=0.1 12 | max_depth=6 13 | n_trees=10 14 | ip_address=192.168.141.2 15 | 
-------------------------------------------------------------------------------- /examples/credit/credit_vertical_p1.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/credit/credit_vertical_p1.csv 2 | n_parties=2 3 | num_class=2 4 | mode=vertical 5 | partition=0 6 | data_format=csv 7 | reorder_label=0 8 | max_num_bin=16 9 | objective=binary:logistic 10 | privacy_tech=none 11 | learning_rate=0.1 12 | max_depth=6 13 | n_trees=10 14 | ip_address=192.168.141.2 15 | -------------------------------------------------------------------------------- /examples/horizontal_example.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/test_dataset.txt 2 | test_data=./dataset/test_dataset.txt 3 | n_parties=2 4 | mode=horizontal 5 | privacy_tech=none 6 | model_path=fedtree.model 7 | learning_rate=0.2 8 | max_depth=4 9 | n_trees=10 10 | partition=1 11 | -------------------------------------------------------------------------------- /examples/prediction.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/test_dataset.txt 2 | model_path=fedtree.model 3 | -------------------------------------------------------------------------------- /examples/vertical_example.conf: -------------------------------------------------------------------------------- 1 | data=./dataset/test_dataset.txt 2 | test_data=./dataset/test_dataset.txt 3 | model_path=fedtree.model 4 | partition_mode=vertical 5 | n_parties=2 6 | mode=vertical 7 | privacy_tech=none 8 | n_trees=40 9 | depth=6 10 | learning_rate=0.2 11 | partition=1 12 | -------------------------------------------------------------------------------- /include/FedTree/DP/differential_privacy.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Tianyuan Fu on 14/3/21. 
3 | // 4 | 5 | #ifndef FEDTREE_DIFFERENTIALPRIVACY_H 6 | #define FEDTREE_DIFFERENTIALPRIVACY_H 7 | 8 | #include 9 | #include "FedTree/FL/FLparam.h" 10 | #include "FedTree/Tree/tree.h" 11 | #include 12 | 13 | using namespace std; 14 | //template 15 | class DifferentialPrivacy { 16 | public: 17 | float max_gradient = 1.0; 18 | float lambda; 19 | float constant_h = 1.0; 20 | float delta_g; 21 | float delta_v; 22 | float privacy_budget; 23 | float privacy_budget_per_tree; 24 | float privacy_budget_leaf_nodes; 25 | float privacy_budget_internal_nodes; 26 | 27 | void init(FLParam fLparam); 28 | 29 | /** 30 | * calculates p value based on gain value for each split point 31 | * @param gain - gain values of all split points in the level 32 | * @param prob - probability masses (Pi) of all split points in the level (not the actual probability) 33 | */ 34 | void compute_split_point_probability(SyncArray &gain, SyncArray &prob_exponent); 35 | 36 | /** 37 | * exponential mechanism: randomly selects split point based on p value 38 | * @param prob - probability masses (Pi) of all split points in the level (not the actual probability) 39 | * @param gain - gain values of all split points in the level 40 | * @param best_idx_gain - mapping from the node index to the gain of split point; containing all the node in the level 41 | */ 42 | void exponential_select_split_point(SyncArray &prob_exponent, SyncArray &gain, 43 | SyncArray &best_idx_gain, int n_nodes_in_level, int n_bins); 44 | 45 | /** 46 | * adds Laplace noise to the data 47 | * @param node - the leaf node which noise are to be added 48 | */ 49 | void laplace_add_noise(Tree::TreeNode &node); 50 | 51 | /** 52 | * clips gradient data into the range [-1, 1] 53 | * @param value - gradient data 54 | */ 55 | template 56 | void clip_gradient_value(T& value) { 57 | // cast the bounds to T: std::min(float, int) would fail template argument deduction 58 | value = max(min(value, (T)1), (T)(-1)); 59 | } 60 | }; 61 | 62 | #endif //FEDTREE_DIFFERENTIALPRIVACY_H 63 | 
/include/FedTree/DP/noises.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Tianyuan Fu on 10/19/20. 3 | // 4 | 5 | #ifndef FEDTREE_NOISES_H 6 | #define FEDTREE_NOISES_H 7 | 8 | #include 9 | 10 | template 11 | class DPnoises { 12 | public: 13 | static void add_gaussian_noise(T& data, float variance){ 14 | std::default_random_engine generator; 15 | std::normal_distribution distribution(0.0, variance); 16 | 17 | double noise = distribution(generator); 18 | data += noise; 19 | } 20 | 21 | static void add_laplacian_noise(T& data, float variance) { 22 | // a r.v. following Laplace(0, b) is equivalent to the difference of 2 i.i.d Exp(1/b) r.v. 23 | double b = sqrt(variance/2); 24 | std::default_random_engine generator; 25 | std::exponential_distribution distribution(1/b); 26 | double noise = distribution(generator) - distribution(generator); 27 | // data is a reference, not a pointer: add in place (was `*data += noise;`, which does not compile for arithmetic T and is inconsistent with add_gaussian_noise) 28 | data += noise; 29 | } 30 | }; 31 | 32 | #endif //FEDTREE_NOISES_H 33 | -------------------------------------------------------------------------------- /include/FedTree/Encryption/diffie_hellman.h: -------------------------------------------------------------------------------- 1 | #ifndef FEDTREE_DIFFIE_HELLMAN_H 2 | #define FEDTREE_DIFFIE_HELLMAN_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "FedTree/common.h" 9 | using namespace NTL; 10 | using namespace std; 11 | class DiffieHellman { 12 | public: 13 | DiffieHellman(); 14 | 15 | DiffieHellman& operator=(DiffieHellman source) { 16 | this->p = source.p; 17 | this->g = source.g; 18 | // this->random = source.random; 19 | return *this; 20 | } 21 | // void primegen(); 22 | 23 | void init_variables(int n_parties); 24 | void generate_public_key(); 25 | void compute_shared_keys(); 26 | void generate_noises(); 27 | void decrypt_noises(); 28 | 29 | ZZ encrypt(float_type &message, int pid); 30 | 31 | float_type decrypt(ZZ &message, int pid); 32 | 33 | NTL::ZZ p, g; 34 | ZZ public_key; 35 | Vec
other_public_keys; 36 | long key_length = 1024; 37 | int pid; 38 | int n_parties; 39 | 40 | 41 | Vec encrypted_noises; 42 | Vec received_encrypted_noises; 43 | //private: 44 | Vec shared_keys; 45 | vector generated_noises; 46 | vector decrypted_noises; 47 | unsigned int secret; 48 | 49 | }; 50 | 51 | 52 | #endif 53 | 54 | -------------------------------------------------------------------------------- /include/FedTree/Encryption/paillier.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | class Paillier { 7 | public: 8 | Paillier(); 9 | 10 | Paillier& operator=(Paillier source) { 11 | this->modulus = source.modulus; 12 | this->generator = source.generator; 13 | this->keyLength = source.keyLength; 14 | 15 | // this->random = source.random; 16 | return *this; 17 | } 18 | void keygen(long keyLength); 19 | 20 | NTL::ZZ encrypt(const NTL::ZZ &message) const; 21 | 22 | NTL::ZZ decrypt(const NTL::ZZ &ciphertext) const; 23 | 24 | NTL::ZZ add(const NTL::ZZ &x, const NTL::ZZ &y) const; 25 | 26 | NTL::ZZ mul(const NTL::ZZ &x, const NTL::ZZ &y) const; 27 | 28 | NTL::ZZ modulus; 29 | NTL::ZZ generator; 30 | long keyLength; 31 | 32 | //private: 33 | NTL::ZZ p, q; 34 | NTL::ZZ lambda; 35 | NTL::ZZ lambda_power; 36 | NTL::ZZ u; 37 | 38 | // NTL::ZZ random; 39 | 40 | NTL::ZZ L_function(const NTL::ZZ &n) const { return (n - 1) / modulus; } 41 | }; 42 | 43 | -------------------------------------------------------------------------------- /include/FedTree/Encryption/paillier_gmp.h: -------------------------------------------------------------------------------- 1 | #ifndef FEDTREE_PAILLIER_GMP_H 2 | #define FEDTREE_PAILLIER_GMP_H 3 | 4 | #include 5 | #include 6 | //#include "FedTree/Encryption/paillier.h" 7 | 8 | class Paillier_GMP { 9 | public: 10 | Paillier_GMP(); 11 | 12 | Paillier_GMP& operator=(Paillier_GMP source) { 13 | //only copy the public key 14 | mpz_set(this->n,source.n); 15 | 
mpz_set(this->n_square, source.n_square); 16 | mpz_set(this->generator, source.generator); 17 | this->key_length = source.key_length; 18 | 19 | // mpz_set(this->r, source.r); 20 | return *this; 21 | } 22 | void keyGen(uint32_t keyLength); 23 | 24 | // void keygen(); 25 | 26 | void encrypt(mpz_t &r, const mpz_t &message) const; 27 | 28 | void decrypt(mpz_t &r, const mpz_t &ciphertext) const; 29 | 30 | void add(mpz_t &r, const mpz_t &x, const mpz_t &y) const; 31 | 32 | void mul(mpz_t &r, const mpz_t &x, const mpz_t &y) const; 33 | 34 | mpz_t n; 35 | mpz_t n_square; 36 | mpz_t generator; 37 | uint32_t key_length; 38 | 39 | 40 | mpz_t p, q; 41 | mpz_t lambda; 42 | mpz_t mu; 43 | 44 | // mpz_t r; 45 | 46 | // Paillier paillier_ntl; 47 | 48 | void L_function(mpz_t &r, mpz_t &input, const mpz_t &n) const; 49 | }; 50 | 51 | 52 | 53 | 54 | 55 | #endif // FEDTREE_PAILLIER_GMP_H 56 | -------------------------------------------------------------------------------- /include/FedTree/Encryption/paillier_gpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef FEDTREE_PAILLIER_GPU_CUH 3 | #define FEDTREE_PAILLIER_GPU_CUH 4 | 5 | 6 | #include 7 | #include 8 | #include "cgbn/cgbn.h" 9 | #include "FedTree/syncarray.h" 10 | #include "FedTree/common.h" 11 | #include "FedTree/Encryption/paillier_gmp.h" 12 | 13 | #define BITS 1024 14 | 15 | void to_mpz(mpz_t r, uint32_t *x, uint32_t count); 16 | 17 | void from_mpz(mpz_t s, uint32_t *x, uint32_t count); 18 | 19 | 20 | template 21 | class cgbn_gh{ 22 | public: 23 | cgbn_mem_t g; 24 | cgbn_mem_t h; 25 | }; 26 | 27 | //template 28 | class Paillier_GPU { 29 | public: 30 | Paillier_GPU(): key_length(BITS) {}; 31 | 32 | Paillier_GPU& operator=(Paillier_GPU source) { 33 | this->paillier_cpu = source.paillier_cpu; 34 | this->key_length = source.key_length; 35 | this->parameters_cpu_to_gpu(); 36 | return *this; 37 | } 38 | 39 | // Paillier_GPU& operator=(Paillier_GPU source) { 40 | // 
this->paillier_cpu = source.paillier_cpu; 41 | // this->generator = source.generator; 42 | // this->keyLength = source.keyLength; 43 | // return *this; 44 | // } 45 | 46 | // void keygen(); 47 | void parameters_cpu_to_gpu(); 48 | void keygen(); 49 | // explicit Paillier_GPU(unit32_t key_length); 50 | void L_function(mpz_t result, mpz_t input, mpz_t N); 51 | 52 | void encrypt(SyncArray &message); 53 | 54 | // void encrypt(GHPair &message); 55 | 56 | void decrypt(SyncArray &ciphertext); 57 | 58 | void decrypt(GHPair &message); 59 | 60 | void add(mpz_t &result, mpz_t &x, mpz_t &y); 61 | void mul(mpz_t result, mpz_t &x, mpz_t &y); 62 | 63 | // cgbn_mem_t add(SyncArray &x, SyncArray &y) const; 64 | 65 | // cgbn_mem_t mul(SyncArray &x, SyncArray &y) const; 66 | 67 | 68 | // cgbn_mem_t modulus; 69 | // cgbn_mem_t generator; 70 | 71 | // mpz_t n; 72 | // mpz_t n_square; 73 | // mpz_t generator; 74 | uint32_t key_length; 75 | 76 | cgbn_mem_t *n_gpu; 77 | cgbn_mem_t *n_square_gpu; 78 | cgbn_mem_t *generator_gpu; 79 | 80 | cgbn_mem_t *lambda_gpu; 81 | cgbn_mem_t *mu_gpu; 82 | 83 | // cgbn_mem_t *random_gpu; 84 | 85 | // cgbn_gh_results* gh_results_gpu; 86 | 87 | Paillier_GMP paillier_cpu; 88 | 89 | private: 90 | // mpz_t p, q; 91 | // mpz_t lambda; 92 | 93 | // mpz_t mu; 94 | }; 95 | 96 | //template 97 | //class cgbn_pailler_encryption_parameters{ 98 | //public: 99 | // cgbn_mem_t n; 100 | // cgbn_mem_t n_square; 101 | // cgbn_mem_t generator; 102 | // cgbn_mem_t random; 103 | //}; 104 | // 105 | //template 106 | //class cgbn_pailler_decryption_parameters{ 107 | //public: 108 | // cgbn_mem_t n; 109 | // cgbn_mem_t n_square; 110 | // cgbn_mem_t lambda; 111 | // cgbn_mem_t mu; 112 | //}; 113 | 114 | 115 | 116 | #endif //FEDTREE_PAILLIER_GPU_CUH -------------------------------------------------------------------------------- /include/FedTree/FL/FLparam.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 
10/13/20. 3 | // 4 | 5 | #ifndef FEDTREE_FLPARAM_H 6 | #define FEDTREE_FLPARAM_H 7 | 8 | #include "FedTree/Tree/GBDTparam.h" 9 | #include "FedTree/common.h" 10 | 11 | // Todo: automatically set partition 12 | class FLParam { 13 | public: 14 | int n_parties; // number of parties 15 | bool partition; // input a single dataset for partitioning or input datasets for each party. 16 | float alpha; //the concentration parameter of Dir based partition approaches. 17 | int n_hori; //the number of horizontal partitioning subsets in hybrid partition. 18 | int n_verti; //the number of vertical partitioning subsets in hybrid partition. 19 | string mode; // "horizontal", "vertical", "hybrid", or "centralized" 20 | string partition_mode; // "horizontal", "vertical" or "hybrid" 21 | string privacy_tech; //"none" or "he" (homomorphic encryption) or "dp" (differential privacy) or "sa" (secure aggregation) 22 | string propose_split; // "server" or "client" 23 | string merge_histogram; // "server" or "client" 24 | float variance; // variance of dp noise if privacy_tech=="dp" 25 | float privacy_budget; // privacy budget for differential privacy 26 | string ip_address; // IP address of the server 27 | float ins_bagging_fraction; // randomly sample subset to train a tree without replacement 28 | int seed; // random seed for partitioning 29 | string data_format; // data format: "libsvm" or "csv" 30 | string label_location; // "server" or "party" for vertical FL 31 | int n_features; //specify the number of features for horizontal FL with sparse datasets 32 | GBDTParam gbdt_param; // parameters for the gbdt training 33 | }; 34 | 35 | 36 | #endif //FEDTREE_FLPARAM_H 37 | -------------------------------------------------------------------------------- /include/FedTree/FL/FLtrainer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 
3 | // 4 | 5 | #ifndef FEDTREE_FLTRAINER_H 6 | #define FEDTREE_FLTRAINER_H 7 | #include "FedTree/common.h" 8 | #include "FedTree/FL/party.h" 9 | #include "FedTree/FL/server.h" 10 | // Todo: different federated training algorithms including horizontal GBDT and vertical GBDT. 11 | 12 | class FLtrainer { 13 | public: 14 | void horizontal_fl_trainer(vector &parties, Server &server, FLParam ¶ms); 15 | 16 | void vertical_fl_trainer(vector &parties, Server &server, FLParam ¶ms); 17 | 18 | void hybrid_fl_trainer(vector &parties, Server &server, FLParam ¶ms); 19 | 20 | void ensemble_trainer(vector &parties, Server &server, FLParam ¶ms); 21 | 22 | void solo_trainer(vector &parties, FLParam ¶ms); 23 | }; 24 | #endif //FEDTREE_FLTRAINER_H 25 | -------------------------------------------------------------------------------- /include/FedTree/FL/comm_helper.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/23/20. 3 | // 4 | 5 | #ifndef FEDTREE_COMM_HELPER_H 6 | #define FEDTREE_COMM_HELPER_H 7 | 8 | #include "party.h" 9 | #include "server.h" 10 | 11 | class Comm { 12 | public: 13 | void send_last_trees_to_server(Party &party, int pid, Server &server) { 14 | if (server.local_trees[pid].trees.size()) 15 | server.local_trees[pid].trees[0] = party.gbdt.trees.back(); 16 | else 17 | server.local_trees[pid].trees.push_back(party.gbdt.trees.back()); 18 | } 19 | 20 | void send_last_global_trees_to_party(Server &server, Party &party) { 21 | party.gbdt.trees.push_back(server.global_trees.trees.back()); 22 | }; 23 | 24 | void send_all_trees_to_server(Party &party, int pid, Server &server) { 25 | server.local_trees[pid].trees = party.gbdt.trees; 26 | } 27 | 28 | template 29 | SyncArray concat_msyncarray(MSyncArray &arrays, int n_nodes_in_level) { 30 | int n_parties = arrays.size(); 31 | int n_total_bins = 0; 32 | vector parties_n_bins(n_parties); 33 | for(int i = 0; i < arrays.size(); i++) { 34 | n_total_bins += 
arrays[i].size(); 35 | parties_n_bins[i] = arrays[i].size()/n_nodes_in_level; 36 | } 37 | int n_bins_sum = n_total_bins / n_nodes_in_level; 38 | // int n_bins_sum = accumulate(parties_n_bins.begin(), parties_n_bins.end(), 0); 39 | // int n_parties = parties_n_bins.size(); 40 | 41 | SyncArray concat_array(n_bins_sum * n_nodes_in_level); 42 | auto concat_array_data = concat_array.host_data(); 43 | for (int i = 0; i < n_nodes_in_level; i++) { 44 | for (int j = 0; j < n_parties; j++) { 45 | auto array_data = arrays[j].host_data(); 46 | for (int k = 0; k < parties_n_bins[j]; k++) { 47 | concat_array_data[i * n_bins_sum + accumulate(parties_n_bins.begin(), parties_n_bins.begin() + j, 0) + k] 48 | = array_data[i * parties_n_bins[j] + k]; 49 | } 50 | } 51 | } 52 | return concat_array; 53 | } 54 | }; 55 | 56 | 57 | #endif //FEDTREE_COMM_HELPER_H 58 | -------------------------------------------------------------------------------- /include/FedTree/FL/distributed_party.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩雨萱 on 9/4/21. 
3 | // 4 | 5 | #ifndef FEDTREE_DISTRIBUTED_PARTY_H 6 | #define FEDTREE_DISTRIBUTED_PARTY_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "../../../src/FedTree/grpc/fedtree.grpc.pb.h" 15 | #include "party.h" 16 | 17 | 18 | class DistributedParty : public Party { 19 | public: 20 | DistributedParty(std::shared_ptr channel) 21 | : stub_(fedtree::FedTree::NewStub(channel)) {}; 22 | 23 | void TriggerUpdateGradients(); 24 | 25 | void TriggerBuildInit(int t); 26 | 27 | void GetGradients(); 28 | 29 | void SendDatasetInfo(int n_bins, int n_columns); 30 | 31 | void SendHistograms(const SyncArray &hist, int type); 32 | 33 | void SendHistFid(const SyncArray &hist_fid); 34 | 35 | bool TriggerAggregate(int n_nodes_in_level); 36 | 37 | void GetBestInfo(vector &bests); 38 | 39 | void SendNode(Tree::TreeNode &node_data); 40 | void SendNodeEnc(Tree::TreeNode &node_data); 41 | 42 | void OrganizeNodesEnc(fedtree::NodeEncArray &nodes, Tree::TreeNode &node_data); 43 | void SendNodesEnc(fedtree::NodeEncArray &nodes); 44 | void OrganizeNodes(fedtree::NodeArray &nodes, Tree::TreeNode &node_data); 45 | void SendNodes(fedtree::NodeArray &nodes); 46 | 47 | void SendIns2NodeID(SyncArray &ins2node_id, int nid); 48 | 49 | void GetNodes(int l); 50 | 51 | void GetIns2NodeID(); 52 | 53 | bool CheckIfContinue(bool cont); 54 | 55 | void TriggerPrune(int t); 56 | 57 | void TriggerPrintScore(); 58 | 59 | void SendRange(const vector>& ranges); 60 | 61 | void SendCutPoints(); 62 | 63 | void GetCutPoints(); 64 | 65 | void TriggerCut(int n_bins); 66 | 67 | void GetRangeAndSet(int n_bins); 68 | 69 | void SendGH(GHPair party_gh); 70 | 71 | void TriggerBuildUsingGH(int k); 72 | 73 | void TriggerCalcTree(int l); 74 | 75 | void GetRootNode(); 76 | 77 | void GetSplitPoints(); 78 | 79 | bool HCheckIfContinue(); 80 | 81 | float GetAvgScore(float score); 82 | 83 | void TriggerHomoInit(); 84 | 85 | void TriggerSAInit(); 86 | 87 | void GetPaillier(); 88 | 89 | void 
SendHistogramsEnc(const SyncArray &hist, int type); 90 | 91 | void SendHistogramBatches(const SyncArray &hist, int type); 92 | 93 | void SendHistFidBatches(const SyncArray &hist); 94 | 95 | void GetIns2NodeIDBatches(); 96 | 97 | void SendIns2NodeIDBatches(SyncArray &ins2node_id, int nid, int l); 98 | 99 | void GetGradientBatches(); 100 | 101 | void GetGradientBatchesEnc(); 102 | 103 | void SendHistogramBatchesEnc(const SyncArray &hist, int type); 104 | 105 | void StopServer(float tot_time); 106 | 107 | void BeginBarrier(); 108 | 109 | void SendDHPubKey(); 110 | void GetDHPubKey(); 111 | void SendNoises(); 112 | void GetNoises(); 113 | double comm_time = 0; 114 | double enc_time = 0; 115 | double comm_size = 0; 116 | int n_parties; 117 | std::chrono::high_resolution_clock timer; 118 | 119 | private: 120 | std::unique_ptr stub_; 121 | }; 122 | 123 | #endif //FEDTREE_DISTRIBUTED_PARTY_H 124 | -------------------------------------------------------------------------------- /include/FedTree/FL/partition.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 
3 | // 4 | #include "FedTree/common.h" 5 | #include "FedTree/dataset.h" 6 | 7 | #ifndef FEDTREE_PARTITION_H 8 | #define FEDTREE_PARTITION_H 9 | 10 | class Partition { 11 | 12 | public: 13 | void homo_partition(const DataSet &dataset, const int n_parties, const bool is_horizontal, vector &subsets, 14 | std::map> &batch_idxs, int seed = 42); 15 | 16 | void hetero_partition(const DataSet &dataset, const int n_parties, const bool is_horizontal, vector &subsets, 17 | const vector alpha = {}, int seed = 42); 18 | 19 | void hybrid_partition(const DataSet &dataset, const int n_parties, vector &alpha, 20 | vector> &feature_map, vector &subsets, 21 | int part_length = 10, int part_width = 10); 22 | 23 | void hybrid_partition_with_test(const DataSet &dataset, const int n_parties, vector &alpha, 24 | vector> &feature_map, vector &train_subsets, 25 | vector &test_subsets, vector &subsets, 26 | int part_length=10, int part_width=10, float train_test_fraction=0.75); 27 | 28 | void horizontal_vertical_dir_partition(const DataSet &dataset, const int n_parties, float alpha, 29 | vector> &feature_map, vector &subsets, 30 | int n_hori = 2, int n_verti = 2); 31 | 32 | void train_test_split(DataSet &dataset, DataSet &train_dataset, DataSet &test_dataset, float train_portion = 0.75); 33 | 34 | }; 35 | 36 | #endif //FEDTREE_PARTITION_H 37 | -------------------------------------------------------------------------------- /include/FedTree/FL/server.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 3 | // 4 | 5 | #ifndef FEDTREE_SERVER_H 6 | #define FEDTREE_SERVER_H 7 | 8 | #include "FedTree/FL/party.h" 9 | #include "FedTree/dataset.h" 10 | #include "FedTree/Tree/tree_builder.h" 11 | //#include "FedTree/Encryption/HE.h" 12 | #include "FedTree/DP/noises.h" 13 | #include "FedTree/Tree/gbdt.h" 14 | #include "omp.h" 15 | 16 | // Todo: the server structure. 
17 | 18 | class Server : public Party { 19 | public: 20 | // void init(FLParam ¶m, int n_total_instances, vector &n_instances_per_party); 21 | 22 | void horizontal_init (FLParam ¶m); 23 | 24 | void vertical_init(FLParam ¶m, int n_total_instances, vector &n_instances_per_party, vector y, 25 | vector label); 26 | 27 | void propose_split_candidates(); 28 | void send_info(string info_type); 29 | // void send_info(vector &parties, AdditivelyHE::PaillierPublicKey serverKey,vectorcandidates); 30 | void sum_histograms(); 31 | void hybrid_merge_trees(); 32 | void ensemble_merge_trees(); 33 | 34 | void sample_data(); 35 | void predict_raw_vertical_jointly_in_training(const GBDTParam &model_param, vector &parties, 36 | SyncArray &y_predict); 37 | GBDT global_trees; 38 | vector local_trees; 39 | GBDTParam model_param; 40 | vector n_instances_per_party; 41 | vector has_label; 42 | 43 | 44 | // AdditivelyHE::PaillierPublicKey publicKey; 45 | // vector pk_vector; 46 | 47 | #ifdef USE_CUDA 48 | Paillier_GPU paillier; 49 | #else 50 | Paillier paillier; 51 | #endif 52 | DiffieHellman dh; 53 | 54 | void send_key(Party &party) { 55 | party.paillier = paillier; 56 | } 57 | 58 | void homo_init() { 59 | #ifdef USE_CUDA 60 | paillier.keygen(); 61 | // pailler_gmp = Pailler(1024); 62 | // paillier = Paillier(paillier_gmp); 63 | // paillier.keygen(); 64 | #else 65 | paillier.keygen(512); 66 | #endif 67 | } 68 | 69 | void decrypt_gh(GHPair &gh) { 70 | #ifdef USE_CUDA 71 | // gh.homo_decrypt(paillier.paillier_cpu); 72 | paillier.decrypt(gh); 73 | gh.encrypted = false; 74 | 75 | #else 76 | gh.homo_decrypt(paillier); 77 | #endif 78 | } 79 | 80 | void decrypt_gh_pairs(SyncArray &encrypted) { 81 | 82 | #ifdef USE_CUDA 83 | paillier.decrypt(encrypted); 84 | auto encrypted_data = encrypted.host_data(); 85 | for(int i = 0; i < encrypted.size(); i++){ 86 | encrypted_data[i].encrypted=false; 87 | } 88 | 89 | // std::cout<<"in decrypt lambda:"< &raw) { 114 | #ifdef USE_CUDA 115 | paillier.encrypt(raw); 
116 | auto raw_data = raw.host_data(); 117 | #pragma omp parallel for 118 | for (int i = 0; i < raw.size(); i++) { 119 | raw_data[i].paillier = paillier.paillier_cpu; 120 | raw_data[i].encrypted = true; 121 | } 122 | 123 | // auto raw_data = raw.host_data(); 124 | // #pragma omp parallel for 125 | // for (int i = 0; i < raw.size(); i++) { 126 | // raw_data[i].homo_encrypt(paillier.paillier_cpu); 127 | // } 128 | #else 129 | auto raw_data = raw.host_data(); 130 | #pragma omp parallel for 131 | for (int i = 0; i < raw.size(); i++) { 132 | raw_data[i].homo_encrypt(paillier); 133 | } 134 | #endif 135 | } 136 | 137 | private: 138 | // std::unique_ptr fbuilder; 139 | DPnoises DP; 140 | }; 141 | 142 | #endif //FEDTREE_SERVER_H 143 | -------------------------------------------------------------------------------- /include/FedTree/Tree/GBDTparam.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 3 | // 4 | 5 | #ifndef FEDTREE_GBDTPARAM_H 6 | #define FEDTREE_GBDTPARAM_H 7 | 8 | #include 9 | #include 10 | 11 | // Todo: gbdt params, refer to ThunderGBM https://github.com/Xtra-Computing/thundergbm/blob/master/include/thundergbm/common.h 12 | struct GBDTParam { 13 | int depth; 14 | int n_trees; 15 | float_type min_child_weight; 16 | float_type lambda; 17 | float_type gamma; 18 | float_type rt_eps; 19 | float column_sampling_rate; 20 | std::string path; 21 | std::string test_path; 22 | vector paths; 23 | string model_path; 24 | int verbose; 25 | bool profiling; 26 | bool bagging; 27 | int n_parallel_trees; 28 | float learning_rate; 29 | std::string objective; 30 | int num_class; 31 | int tree_per_rounds; // #tree of each round, depends on #class 32 | int max_num_bin; // for histogram 33 | float constant_h; // fix h to a constant for DP 34 | int n_device; 35 | std::string tree_method; 36 | std::string metric; 37 | bool reorder_label; 38 | }; 39 | 40 | #endif //FEDTREE_GBDTPARAM_H 41 | 
-------------------------------------------------------------------------------- /include/FedTree/Tree/function_builder.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #ifndef FEDTREE_FUNCTION_BUILDER_H 6 | #define FEDTREE_FUNCTION_BUILDER_H 7 | 8 | 9 | #include "tree.h" 10 | #include "FedTree/common.h" 11 | #include "FedTree/dataset.h" 12 | //#include "FedTree/Encryption/HE.h" 13 | #include "FedTree/Tree/hist_cut.h" 14 | 15 | class FunctionBuilder { 16 | public: 17 | virtual vector build_approximate(const SyncArray &gradients, bool update_y_predict = true) = 0; 18 | 19 | virtual Tree get_tree()= 0; 20 | 21 | virtual void set_tree(Tree tree) = 0; 22 | 23 | virtual void set_y_predict(int k) = 0; 24 | 25 | virtual void build_init(const GHPair sum_gh, int k) = 0; 26 | 27 | virtual void build_init(const SyncArray &gradients, int k) = 0; 28 | 29 | virtual void compute_histogram_in_a_level(int level, int n_max_splits, int n_bins, int n_nodes_in_level, 30 | int *hist_fid_data, SyncArray &missing_gh, 31 | SyncArray &hist) = 0; 32 | 33 | virtual void compute_gain_in_a_level(SyncArray &gain, int n_nodes_in_level, int n_bins, 34 | int *hist_fid_data, SyncArray &missing_gh, 35 | SyncArray &hist) = 0; 36 | 37 | virtual void get_best_gain_in_a_level(SyncArray &gain, SyncArray &best_idx_gain, 38 | int n_nodes_in_level, int n_bins) = 0; 39 | 40 | virtual void get_split_points_in_a_node(int node_id, int best_idx, float best_gain, int n_nodes_in_level, 41 | int *hist_fid, SyncArray &missing_gh, 42 | SyncArray &hist) = 0; 43 | 44 | virtual HistCut get_cut() = 0; 45 | 46 | virtual SyncArray get_hist() = 0; 47 | 48 | virtual void parties_hist_init(int party_size) = 0; 49 | 50 | virtual void append_hist(SyncArray &hist) = 0; 51 | 52 | virtual void append_hist(SyncArray &hist, SyncArray &missing_gh,int n_partition, int n_max_splits, int party_idx) = 0; 53 | 54 | virtual void 
concat_histograms() = 0; 55 | 56 | virtual void init(DataSet &dataset, const GBDTParam ¶m) { 57 | this->param = param; 58 | }; 59 | 60 | virtual void init(const GBDTParam ¶m, int n_instances) { 61 | this->param = param; 62 | } 63 | 64 | virtual SyncArray &get_y_predict() { return y_predict; }; 65 | 66 | virtual ~FunctionBuilder() {}; 67 | 68 | static FunctionBuilder *create(std::string name); 69 | 70 | SyncArray y_predict; 71 | 72 | protected: 73 | 74 | GBDTParam param; 75 | }; 76 | 77 | 78 | #endif //FEDTREE_FUNCTION_BUILDER_H 79 | -------------------------------------------------------------------------------- /include/FedTree/Tree/gbdt.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 3 | // 4 | 5 | #ifndef FEDTREE_GBDT_H 6 | #define FEDTREE_GBDT_H 7 | 8 | #include "tree.h" 9 | #include "FedTree/dataset.h" 10 | 11 | class GBDT { 12 | public: 13 | vector> trees; 14 | 15 | GBDT() = default; 16 | 17 | GBDT(const vector> gbdt){ 18 | trees = gbdt; 19 | } 20 | 21 | void train(GBDTParam ¶m, DataSet &dataset); 22 | 23 | vector predict(const GBDTParam &model_param, const DataSet &dataSet); 24 | 25 | void predict_raw(const GBDTParam &model_param, const DataSet &dataSet, SyncArray &y_predict); 26 | 27 | void predict_raw_vertical(const GBDTParam &model_param, const DataSet &dataSet, SyncArray &y_predict, std::map> &batch_idxs); 28 | 29 | 30 | 31 | float_type predict_score(const GBDTParam &model_param, const DataSet &dataSet); 32 | 33 | float_type predict_score_vertical(const GBDTParam &model_param, const DataSet &dataSet, std::map> &batch_idxs); 34 | }; 35 | 36 | #endif //FEDTREE_GBDT_H 37 | -------------------------------------------------------------------------------- /include/FedTree/Tree/hist_cut.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 11/3/20. 
3 | // 4 | 5 | #ifndef FEDTREE_HIST_CUT_H 6 | #define FEDTREE_HIST_CUT_H 7 | 8 | #include "FedTree/common.h" 9 | #include "FedTree/dataset.h" 10 | #include "tree.h" 11 | 12 | class HistCut { 13 | public: 14 | 15 | // The vales of cut points 16 | SyncArray cut_points_val; 17 | // The number of accumulated cut points for current feature 18 | SyncArray cut_col_ptr; 19 | // The feature id for current cut point 20 | SyncArray cut_fid; 21 | 22 | HistCut() = default; 23 | 24 | HistCut(const HistCut &cut) { 25 | cut_points_val.copy_from(cut.cut_points_val); 26 | cut_col_ptr.copy_from(cut.cut_col_ptr); 27 | cut_fid.copy_from(cut.cut_fid); 28 | } 29 | 30 | // equally divide the feature range to get cut points 31 | // void get_cut_points(float_type feature_min, float_type feature_max, int max_num_bins, int n_instances); 32 | void get_cut_points_by_data_range(DataSet &dataset, int max_num_bins, int n_instances); 33 | void get_cut_points_fast(DataSet &dataset, int max_num_bins, int n_instances); 34 | void get_cut_points_by_n_instance(DataSet &dataset, int max_num_bins); 35 | void get_cut_points_by_feature_range(vector> f_range, int max_num_bins); 36 | void get_cut_points_by_parties_cut_sampling(vector &parties_cut, int max_num_bin); 37 | }; 38 | 39 | 40 | #endif //FEDTREE_HIST_CUT_H 41 | -------------------------------------------------------------------------------- /include/FedTree/Tree/histogram.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 11/3/20. 
3 | // 4 | 5 | #ifndef FEDTREE_HISTOGRAM_H 6 | #define FEDTREE_HISTOGRAM_H 7 | 8 | #include "hist_cut.h" 9 | 10 | class Histogram{ 11 | vector cut; 12 | SyncArray histogram; 13 | }; 14 | 15 | #endif //FEDTREE_HISTOGRAM_H 16 | -------------------------------------------------------------------------------- /include/FedTree/Tree/splitpoint.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Kelly Yung on 2020/11/23. 3 | // 4 | 5 | #ifndef FEDTREE_SPLITPOINT_H 6 | #define FEDTREE_SPLITPOINT_H 7 | 8 | #include "FedTree/Tree/tree.h" 9 | 10 | 11 | class SplitPoint { 12 | public: 13 | float_type gain; 14 | GHPair fea_missing_gh;//missing gh in this segment 15 | GHPair rch_sum_gh;//right child total gh (missing gh included if default2right) 16 | bool default_right; 17 | int nid; 18 | 19 | //split condition 20 | int split_fea_id; 21 | float_type fval;//split on this feature value (for exact) 22 | unsigned char split_bid;//split on this bin id (for hist) 23 | 24 | bool no_split_value_update; //there is no split value update. Used in build_tree_by_predefined_structure. 
25 | 26 | SplitPoint() { 27 | nid = -1; 28 | split_fea_id = -1; 29 | gain = 0; 30 | no_split_value_update=false; 31 | } 32 | 33 | SplitPoint(const SplitPoint& copy){ 34 | gain = copy.gain; 35 | fea_missing_gh.g = copy.fea_missing_gh.g; 36 | fea_missing_gh.h = copy.fea_missing_gh.h; 37 | rch_sum_gh.g = copy.rch_sum_gh.g; 38 | rch_sum_gh.h = copy.rch_sum_gh.h; 39 | default_right = copy.default_right; 40 | nid = copy.nid; 41 | split_fea_id = copy.split_fea_id; 42 | fval = copy.fval; 43 | split_bid = copy.split_bid; 44 | no_split_value_update = copy.no_split_value_update; 45 | } 46 | 47 | friend std::ostream &operator<<(std::ostream &output, const SplitPoint &sp) { 48 | output << sp.gain << "/" << sp.split_fea_id << "/" << sp.nid << "/" << sp.rch_sum_gh; 49 | return output; 50 | } 51 | }; 52 | #endif //FEDTREE_SPLITPOINT_H 53 | -------------------------------------------------------------------------------- /include/FedTree/Tree/tree.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 
3 | // The tree structure is referring to the design of ThunderGBM: https://github.com/Xtra-Computing/thundergbm/blob/master/include/thundergbm/tree.h 4 | // 5 | 6 | #ifndef FEDTREE_TREE_H 7 | #define FEDTREE_TREE_H 8 | 9 | #include "sstream" 10 | #include "FedTree/syncarray.h" 11 | #include "GBDTparam.h" 12 | 13 | 14 | //class SplitPoint { 15 | //public: 16 | // float_type gain; 17 | // GHPair fea_missing_gh;//missing gh in this segment 18 | // GHPair rch_sum_gh;//right child total gh (missing gh included if default2right) 19 | // bool default_right; 20 | // int nid; 21 | // 22 | // //split condition 23 | // int split_fea_id; 24 | // float_type fval;//split on this feature value (for exact) 25 | // unsigned char split_bid;//split on this bin id (for hist) 26 | // 27 | // SplitPoint() { 28 | // nid = -1; 29 | // split_fea_id = -1; 30 | // gain = 0; 31 | // } 32 | // 33 | // friend std::ostream &operator<<(std::ostream &output, const SplitPoint &sp) { 34 | // output << sp.gain << "/" << sp.split_fea_id << "/" << sp.nid << "/" << sp.rch_sum_gh; 35 | // return output; 36 | // } 37 | //}; 38 | 39 | 40 | class Tree{ 41 | public: 42 | struct TreeNode { 43 | int final_id;// node id after pruning, may not equal to node index 44 | int lch_index;// index of left child 45 | int rch_index;// index of right child 46 | int parent_index;// index of parent node 47 | float_type gain;// gain of splitting this node 48 | float_type base_weight; 49 | int split_feature_id; 50 | int pid; 51 | float_type split_value; 52 | unsigned char split_bid; 53 | bool default_right; 54 | bool is_leaf; 55 | bool is_valid;// non-valid nodes are those that are "children" of leaf nodes 56 | bool is_pruned;// pruned after pruning 57 | 58 | GHPair sum_gh_pair; 59 | int n_instances = 0; // number of instances inside the node. 
60 | 61 | friend std::ostream &operator<<(std::ostream &os, 62 | const TreeNode &node); 63 | 64 | HOST_DEVICE void calc_weight(float_type lambda) { 65 | this->base_weight = -sum_gh_pair.g / (sum_gh_pair.h + lambda); 66 | } 67 | 68 | HOST_DEVICE bool splittable() const { 69 | return !is_leaf && is_valid; 70 | } 71 | 72 | HOST_DEVICE TreeNode(const TreeNode& copy){ 73 | final_id = copy.final_id; 74 | lch_index = copy.lch_index; 75 | rch_index = copy.rch_index; 76 | parent_index = copy.parent_index; 77 | gain = copy.gain; 78 | base_weight = copy.base_weight; 79 | split_feature_id = copy.split_feature_id; 80 | pid = copy.pid; 81 | split_value = copy.split_value; 82 | split_bid = copy.split_bid; 83 | default_right = copy.default_right; 84 | is_leaf = copy.is_leaf; 85 | is_valid = copy.is_valid; 86 | is_pruned = copy.is_pruned; 87 | sum_gh_pair.g = copy.sum_gh_pair.g; 88 | sum_gh_pair.h = copy.sum_gh_pair.h; 89 | n_instances = copy.n_instances; 90 | } 91 | 92 | }; 93 | 94 | Tree() = default; 95 | 96 | Tree(const Tree &tree) { 97 | nodes.resize(tree.nodes.size()); 98 | nodes.copy_from(tree.nodes); 99 | n_nodes_level = tree.n_nodes_level; 100 | final_depth = tree.final_depth; 101 | } 102 | 103 | Tree &operator=(const Tree &tree) { 104 | nodes.resize(tree.nodes.size()); 105 | nodes.copy_from(tree.nodes); 106 | n_nodes_level = tree.n_nodes_level; 107 | final_depth = tree.final_depth; 108 | return *this; 109 | } 110 | 111 | void init_CPU(const GHPair sum_gh, const GBDTParam ¶m); 112 | 113 | void init_CPU(const SyncArray &gradients, const GBDTParam ¶m); 114 | 115 | void init_structure(int depth); 116 | 117 | // TODO: GPU initialization 118 | // void init2(const SyncArray &gradients, const GBDTParam ¶m); 119 | 120 | string dump(int depth) const; 121 | 122 | 123 | SyncArray nodes; 124 | //n_nodes_level[i+1] - n_nodes_level[i] stores the number of nodes in level i, for hybrid trainer 125 | vector n_nodes_level; 126 | //for hybrid trainer 127 | int final_depth; 128 | 129 | 130 | 
131 | 132 | void prune_self(float_type gamma); 133 | 134 | void compute_leaf_value(); 135 | 136 | private: 137 | void preorder_traversal(int nid, int max_depth, int depth, string &s) const; 138 | 139 | int try_prune_leaf(int nid, int np, float_type gamma, vector &leaf_child_count); 140 | 141 | void reorder_nid(); 142 | }; 143 | 144 | #endif //FEDTREE_TREE_H 145 | -------------------------------------------------------------------------------- /include/FedTree/Tree/tree_builder.h: -------------------------------------------------------------------------------- 1 | 2 | // 3 | // Created by liqinbin on 10/27/20. 4 | // 5 | 6 | #ifndef FEDTREE_TREE_BUILDER_H 7 | #define FEDTREE_TREE_BUILDER_H 8 | 9 | #include "FedTree/dataset.h" 10 | //#include "FedTree/Encryption/HE.h" 11 | #include "function_builder.h" 12 | #include "tree.h" 13 | #include "splitpoint.h" 14 | #include "hist_cut.h" 15 | 16 | class TreeBuilder : public FunctionBuilder{ 17 | public: 18 | virtual void find_split(int level) = 0; 19 | 20 | virtual void find_split_by_predefined_features(int level) = 0; 21 | 22 | virtual void update_ins2node_id() = 0; 23 | 24 | vector build_approximate(const SyncArray &gradients, bool update_y_predict = true) override; 25 | 26 | void build_tree_by_predefined_structure(const SyncArray &gradients, vector &trees); 27 | 28 | void build_init(const GHPair sum_gh, int k) override; 29 | 30 | void build_init(const SyncArray &gradients, int k) override; 31 | 32 | void init(DataSet &dataset, const GBDTParam ¶m) override; 33 | void init_nosortdataset(DataSet &dataset, const GBDTParam ¶m); 34 | 35 | void update_tree(); 36 | 37 | void update_tree_in_a_node(int node_id); 38 | 39 | Tree get_tree() override { 40 | return this->trees; 41 | } 42 | 43 | void set_tree(Tree tree) override { 44 | trees = Tree(tree); 45 | } 46 | 47 | void set_y_predict(int k) override; 48 | 49 | virtual void update_tree_by_sp_values(); 50 | 51 | void predict_in_training(int k); 52 | 53 | // virtual void 
split_point_all_reduce(int depth); 54 | // Refer to ThunderGBM hist_tree_builder.cu find_split 55 | 56 | // void get_split(int level, int device_id); 57 | 58 | void find_split (SyncArray &sp, int n_nodes_in_level, Tree tree, SyncArray best_idx_gain, int nid_offset, HistCut cut, SyncArray hist, int n_bins); 59 | 60 | void merge_histograms(); 61 | 62 | void update_gradients(SyncArray &gradients, SyncArray &y, SyncArray &y_p); 63 | 64 | 65 | 66 | 67 | // virtual void init(const DataSet &dataset, const GBDTParam ¶m) { 68 | // this->param = param; 69 | // }; 70 | 71 | //for multi-device 72 | // virtual void ins2node_id_all_reduce(int depth); 73 | 74 | // virtual void split_point_all_reduce(int depth); 75 | 76 | virtual ~TreeBuilder(){}; 77 | 78 | SyncArray gradients; 79 | 80 | int n_instances; 81 | Tree trees; 82 | SyncArray ins2node_id; 83 | SyncArray sp; 84 | bool has_split; 85 | 86 | protected: 87 | // vector shards; 88 | // DataSet* dataset; 89 | DataSet sorted_dataset; 90 | }; 91 | 92 | #endif //FEDTREE_TREE_BUILDER_H -------------------------------------------------------------------------------- /include/FedTree/booster.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 
3 | // 4 | 5 | #ifndef FEDTREE_BOOSTER_H 6 | #define FEDTREE_BOOSTER_H 7 | 8 | #include 9 | #include 10 | // function_builder 11 | #include 12 | #include 13 | #include 14 | #include "FedTree/common.h" 15 | #include "FedTree/syncarray.h" 16 | #include "FedTree/Tree/tree.h" 17 | #include "FedTree/DP/noises.h" 18 | 19 | 20 | //#include "row_sampler.h" 21 | 22 | 23 | 24 | class Booster { 25 | public: 26 | void init(DataSet &dataSet, const GBDTParam ¶m, bool get_cut_points = 1); 27 | 28 | void init (const GBDTParam ¶m, int n_instances); 29 | 30 | void reinit(DataSet &dataSet, const GBDTParam ¶m); 31 | 32 | SyncArray get_gradients(); 33 | 34 | void set_gradients(SyncArray &gh); 35 | 36 | // void encrypt_gradients(AdditivelyHE::PaillierPublicKey pk); 37 | // 38 | // void decrypt_gradients(AdditivelyHE::PaillierPrivateKey privateKey); 39 | 40 | void add_noise_to_gradients(float variance); 41 | 42 | void update_gradients(); 43 | 44 | void boost(vector> &boosted_model); 45 | 46 | void boost_without_prediction(vector> &boosted_model); 47 | 48 | GBDTParam param; 49 | std::unique_ptr fbuilder; 50 | SyncArray gradients; 51 | 52 | std::unique_ptr metric; 53 | private: 54 | int n_devices; 55 | std::unique_ptr obj; 56 | SyncArray y; 57 | 58 | // GBDTParam param; 59 | 60 | }; 61 | 62 | 63 | 64 | 65 | #endif //FEDTREE_BOOSTER_H 66 | -------------------------------------------------------------------------------- /include/FedTree/config.h.in: -------------------------------------------------------------------------------- 1 | #cmakedefine DATASET_DIR "@DATASET_DIR@" 2 | #cmakedefine USE_CUDA 3 | #cmakedefine THRUST_IGNORE_DEPRECATED_CPP_DIALECT 4 | #cmakedefine USE_DOUBLE -------------------------------------------------------------------------------- /include/FedTree/dataset.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/13/20. 
3 | // 4 | 5 | #ifndef FEDTREE_DATASET_H 6 | #define FEDTREE_DATASET_H 7 | 8 | 9 | #include "FedTree/FL/FLparam.h" 10 | #include "common.h" 11 | #include "syncarray.h" 12 | 13 | class DataSet{ 14 | public: 15 | ///load dataset from file 16 | // void load_from_file(const string& file_name, FLParam ¶m); 17 | void load_from_file(string file_name, FLParam ¶m); 18 | void load_from_csv(string file_name, FLParam ¶m); 19 | // void load_from_file_dense(string file_name, FLParam ¶m); 20 | void load_from_files(vectorfile_names, FLParam ¶m); 21 | void load_group_file(string file_name); 22 | void group_label(); 23 | void group_label_without_reorder(int n_class); 24 | void load_from_sparse(int n_instances, float *csr_val, int *csr_row_ptr, int *csr_col_idx, float *y, 25 | int *group, int num_group, GBDTParam ¶m); 26 | void load_csc_from_file(string file_name, FLParam ¶m, int const nfeatures=500); 27 | void csr_to_csc(); 28 | void csc_to_csr(); 29 | void get_subset(vector &idx, DataSet &subset); 30 | 31 | size_t n_features() const; 32 | 33 | size_t n_instances() const; 34 | 35 | // vector> dense_mtx; 36 | vector csr_val; 37 | vector csr_row_ptr; 38 | vector csr_col_idx; 39 | vector y; 40 | size_t n_features_; 41 | vector group; 42 | vector label; 43 | std::map label_map; 44 | 45 | 46 | // csc variables 47 | vector csc_val; 48 | vector csc_row_idx; 49 | vector csc_col_ptr; 50 | 51 | //Todo: SyncArray version 52 | // SyncArray csr_val; 53 | // SyncArray csr_row_ptr; 54 | // SyncArray csr_col_idx; 55 | // 56 | // SyncArray csc_val; 57 | // SyncArray csc_row_idx; 58 | // SyncArray csc_col_ptr; 59 | // whether the dataset is to big 60 | bool use_cpu = true; 61 | bool has_csc = false; 62 | bool is_classification = false; 63 | bool has_label = true; 64 | }; 65 | 66 | #endif //FEDTREE_DATASET_H 67 | -------------------------------------------------------------------------------- /include/FedTree/metric/metric.h: 
-------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #ifndef FEDTREE_METRIC_H 6 | #define FEDTREE_METRIC_H 7 | 8 | 9 | #include "FedTree/syncarray.h" 10 | #include "FedTree/dataset.h" 11 | 12 | class Metric { 13 | public: 14 | virtual float_type get_score(const SyncArray &y_p) const = 0; 15 | 16 | virtual void configure(const GBDTParam ¶m, const DataSet &dataset); 17 | 18 | static Metric *create(string name); 19 | 20 | virtual string get_name() const = 0; 21 | 22 | protected: 23 | SyncArray y; 24 | }; 25 | 26 | 27 | #endif //FEDTREE_METRIC_H 28 | -------------------------------------------------------------------------------- /include/FedTree/metric/multiclass_metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #ifndef FEDTREE_MULTICLASS_METRIC_H 6 | #define FEDTREE_MULTICLASS_METRIC_H 7 | 8 | 9 | #include "FedTree/common.h" 10 | #include "metric.h" 11 | 12 | class MulticlassMetric: public Metric { 13 | public: 14 | void configure(const GBDTParam ¶m, const DataSet &dataset) override { 15 | Metric::configure(param, dataset); 16 | num_class = param.num_class; 17 | CHECK_EQ(num_class, dataset.label.size()); 18 | label.resize(num_class); 19 | label.copy_from(dataset.label.data(), num_class); 20 | } 21 | 22 | protected: 23 | int num_class; 24 | SyncArray label; 25 | }; 26 | 27 | class MulticlassAccuracy: public MulticlassMetric { 28 | public: 29 | float_type get_score(const SyncArray &y_p) const override; 30 | 31 | string get_name() const override { return "multi-class accuracy"; } 32 | }; 33 | 34 | class BinaryClassMetric: public MulticlassAccuracy{ 35 | public: 36 | float_type get_score(const SyncArray &y_p) const override; 37 | 38 | string get_name() const override { return "AUC";} 39 | }; 40 | 41 | 42 | #endif //FEDTREE_MULTICLASS_METRIC_H 43 | 
-------------------------------------------------------------------------------- /include/FedTree/metric/pointwise_metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #ifndef FEDTREE_POINTWISE_METRIC_H 6 | #define FEDTREE_POINTWISE_METRIC_H 7 | 8 | #include "metric.h" 9 | 10 | class RMSE : public Metric { 11 | public: 12 | float_type get_score(const SyncArray &y_p) const override; 13 | 14 | string get_name() const override { return "RMSE"; } 15 | }; 16 | 17 | #endif //FEDTREE_POINTWISE_METRIC_H 18 | -------------------------------------------------------------------------------- /include/FedTree/metric/ranking_metric.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #ifndef FEDTREE_RANKING_METRIC_H 6 | #define FEDTREE_RANKING_METRIC_H 7 | 8 | #include "metric.h" 9 | 10 | class RankListMetric : public Metric { 11 | public: 12 | float_type get_score(const SyncArray &y_p) const override; 13 | 14 | void configure(const GBDTParam ¶m, const DataSet &dataset) override; 15 | 16 | static void configure_gptr(const vector &group, vector &gptr); 17 | 18 | protected: 19 | virtual float_type eval_query_group(vector &y, vector &y_p, int group_id) const = 0; 20 | 21 | vector gptr; 22 | int n_group; 23 | int topn; 24 | }; 25 | 26 | 27 | class MAP : public RankListMetric { 28 | public: 29 | string get_name() const override { return "MAP"; } 30 | 31 | protected: 32 | float_type eval_query_group(vector &y, vector &y_p, int group_id) const override; 33 | }; 34 | 35 | class NDCG : public RankListMetric { 36 | public: 37 | string get_name() const override { return "NDCG"; }; 38 | 39 | void configure(const GBDTParam ¶m, const DataSet &dataset) override; 40 | 41 | inline HOST_DEVICE static float_type discounted_gain(int label, int rank) { 42 | return ((1 << label) - 1) / log2f(rank + 1 + 1); 43 | 
} 44 | 45 | static void get_IDCG(const vector &gptr, const vector &y, vector &idcg); 46 | 47 | protected: 48 | float_type eval_query_group(vector &y, vector &y_p, int group_id) const override; 49 | 50 | private: 51 | vector idcg; 52 | }; 53 | 54 | 55 | 56 | #endif //FEDTREE_RANKING_METRIC_H 57 | -------------------------------------------------------------------------------- /include/FedTree/objective/multiclass_obj.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/15/20. 3 | // 4 | 5 | #ifndef FEDTREE_MULTICLASS_OBJ_H 6 | #define FEDTREE_MULTICLASS_OBJ_H 7 | 8 | #include "objective_function.h" 9 | //#include "FedTree/util/device_lambda.h" 10 | 11 | class Softmax : public ObjectiveFunction { 12 | public: 13 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 14 | SyncArray &gh_pair) override; 15 | 16 | void predict_transform(SyncArray &y) override; 17 | 18 | void configure(GBDTParam param, const DataSet &dataset) override; 19 | 20 | string default_metric_name() override { return "macc"; } 21 | 22 | virtual ~Softmax() override = default; 23 | 24 | protected: 25 | int num_class; 26 | SyncArray label; 27 | }; 28 | 29 | 30 | class SoftmaxProb : public Softmax { 31 | public: 32 | void predict_transform(SyncArray &y) override; 33 | 34 | ~SoftmaxProb() override = default; 35 | 36 | }; 37 | 38 | #endif //FEDTREE_MULTICLASS_OBJ_H 39 | -------------------------------------------------------------------------------- /include/FedTree/objective/objective_function.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Kelly Yung on 2020/11/27. 
3 | // 4 | 5 | #ifndef FEDTREE_OBJECTIVE_FUNCTION_H 6 | #define FEDTREE_OBJECTIVE_FUNCTION_H 7 | 8 | #include 9 | #include 10 | 11 | class ObjectiveFunction { 12 | public: 13 | float constant_h = 0.0; 14 | virtual void 15 | get_gradient(const SyncArray &y, const SyncArray &y_p, SyncArray &gh_pair) = 0; 16 | virtual void 17 | predict_transform(SyncArray &y){}; 18 | virtual void configure(GBDTParam param, const DataSet &dataset) {constant_h = param.constant_h;} ; 19 | virtual string default_metric_name() = 0; 20 | 21 | static ObjectiveFunction* create(string name); 22 | 23 | static bool need_load_group_file(string name); 24 | static bool need_group_label(string name); 25 | virtual ~ObjectiveFunction() = default; 26 | }; 27 | 28 | #endif //FEDTREE_OBJECTIVE_FUNCTION_H 29 | -------------------------------------------------------------------------------- /include/FedTree/objective/ranking_obj.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/15/20. 
3 | // 4 | 5 | #ifndef FEDTREE_RANKING_OBJ_H 6 | #define FEDTREE_RANKING_OBJ_H 7 | 8 | #include "objective_function.h" 9 | 10 | /** 11 | * 12 | * https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf 13 | */ 14 | class LambdaRank : public ObjectiveFunction { 15 | public: 16 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 17 | SyncArray &gh_pair) override; 18 | 19 | void configure(GBDTParam param, const DataSet &dataset) override; 20 | 21 | string default_metric_name() override; 22 | 23 | virtual ~LambdaRank() override = default; 24 | 25 | protected: 26 | virtual inline float_type get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) { return 1; }; 27 | 28 | vector gptr;//group start position 29 | int n_group; 30 | 31 | float_type sigma; 32 | }; 33 | 34 | class LambdaRankNDCG : public LambdaRank { 35 | public: 36 | void configure(GBDTParam param, const DataSet &dataset) override; 37 | 38 | string default_metric_name() override; 39 | 40 | protected: 41 | float_type get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) override; 42 | 43 | private: 44 | vector idcg; 45 | }; 46 | 47 | 48 | 49 | #endif //FEDTREE_RANKING_OBJ_H 50 | -------------------------------------------------------------------------------- /include/FedTree/objective/regression_obj.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Kelly Yung on 2020/11/27. 
3 | // 4 | 5 | #ifndef FEDTREE_REGRESSION_OBJ_H 6 | #define FEDTREE_REGRESSION_OBJ_H 7 | 8 | #include "objective_function.h" 9 | //#include "FedTree/util/device_lambda.h" 10 | #include "math.h" 11 | 12 | template class Loss> 13 | class RegressionObj : public ObjectiveFunction { 14 | public: 15 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 16 | SyncArray &gh_pair) override { 17 | // CHECK_EQ(y.size(), y_p.size())< class Loss> 44 | class LogClsObj: public RegressionObj{ 45 | public: 46 | void get_gradient(const SyncArray &y, const SyncArray &y_p, 47 | SyncArray &gh_pair) override { 48 | auto y_data = y.host_data(); 49 | auto y_p_data = y_p.host_data(); 50 | auto gh_pair_data = gh_pair.host_data(); 51 | if(this->constant_h != 0){ 52 | for (int i = 0; i < y.size(); i++){ 53 | gh_pair_data[i] = Loss::gradient(y_data[i], y_p_data[i]); 54 | gh_pair_data[i].h = this->constant_h; 55 | } 56 | } 57 | else{ 58 | for (int i = 0; i < y.size(); i++){ 59 | gh_pair_data[i] = Loss::gradient(y_data[i], y_p_data[i]); 60 | } 61 | } 62 | 63 | } 64 | void predict_transform(SyncArray &y) { 65 | //this method transform y(#class * #instances) into y(#instances) 66 | auto yp_data = y.host_data(); 67 | auto label_data = label.host_data(); 68 | int num_class = this->num_class; 69 | int n_instances = y.size(); 70 | for (int i = 0; i < n_instances; i++) { 71 | int max_k = (yp_data[i] > 0) ? 
1 : 0; 72 | yp_data[i] = label_data[max_k]; 73 | } 74 | SyncArray < float_type > temp_y(n_instances); 75 | temp_y.copy_from(y.host_data(), n_instances); 76 | y.resize(n_instances); 77 | y.copy_from(temp_y); 78 | } 79 | string default_metric_name() override{ 80 | return "error"; 81 | } 82 | void configure(GBDTParam param, const DataSet &dataset) { 83 | this->constant_h = param.constant_h; 84 | num_class = param.num_class; 85 | label.resize(num_class); 86 | CHECK_EQ(dataset.label.size(), num_class)< 11 | #include "dataset.h" 12 | #include "Tree/tree.h" 13 | 14 | // Todo: parse the parameters to FLparam. refer to ThunderGBM parser.h https://github.com/Xtra-Computing/thundergbm/blob/master/include/thundergbm/parser.h 15 | class Parser { 16 | public: 17 | void parse_param(FLParam &fl_param, int argc, char **argv); 18 | void load_model(string model_path, GBDTParam &model_param, vector> &boosted_model); 19 | void save_model(string model_path, GBDTParam &model_param, vector> &boosted_model); 20 | }; 21 | 22 | #endif //FEDTREE_PARSER_H 23 | -------------------------------------------------------------------------------- /include/FedTree/predictor.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Kelly Yung on 2020/12/3. Code taken from ThunderGBM. 3 | // 4 | 5 | #ifndef FEDTREE_PREDICTOR_H 6 | #define FEDTREE_PREDICTOR_H 7 | 8 | #include "FedTree/Tree/tree.h" 9 | #include 10 | 11 | class Predictor{ 12 | public: 13 | void get_y_predict (const GBDTParam &model_param, const vector> &boosted_model, 14 | const DataSet &dataSet, SyncArray &y_predict); 15 | }; 16 | 17 | #endif //FEDTREE_PREDICTOR_H -------------------------------------------------------------------------------- /include/FedTree/syncarray.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/14/20. 
3 | // ThunderGBM syncarray.h: https://github.com/Xtra-Computing/thundergbm/blob/master/include/thundergbm/syncarray.h 4 | // Under Apache-2.0 License 5 | // Copyright (c) 2020 by jiashuai 6 | // 7 | 8 | #ifndef FEDTREE_SYNCARRAY_H 9 | #define FEDTREE_SYNCARRAY_H 10 | 11 | #include "FedTree/util/log.h" 12 | #include "syncmem.h" 13 | 14 | /** 15 | * @brief Wrapper of SyncMem with a type 16 | * @tparam T type of element 17 | */ 18 | template 19 | class SyncArray : public el::Loggable { 20 | public: 21 | /** 22 | * initialize class that can store given count of elements 23 | * @param count the given count 24 | */ 25 | explicit SyncArray(size_t count) : mem(new SyncMem(sizeof(T) * count)), size_(count) { 26 | } 27 | 28 | SyncArray() : mem(nullptr), size_(0) {} 29 | 30 | ~SyncArray() { delete mem; }; 31 | 32 | const T *host_data() const { 33 | to_host(); 34 | return static_cast(mem->host_data()); 35 | }; 36 | 37 | const T *device_data() const { 38 | to_device(); 39 | return static_cast(mem->device_data()); 40 | }; 41 | 42 | T *host_data() { 43 | to_host(); 44 | return static_cast(mem->host_data()); 45 | }; 46 | 47 | T *device_data() { 48 | to_device(); 49 | return static_cast(mem->device_data()); 50 | }; 51 | 52 | T *device_end() { 53 | return device_data() + size(); 54 | }; 55 | 56 | const T *device_end() const { 57 | return device_data() + size(); 58 | }; 59 | 60 | T *host_end() { 61 | return host_data() + size(); 62 | }; 63 | 64 | const T *host_end() const { 65 | return host_data() + size(); 66 | } 67 | 68 | void set_host_data(T *host_ptr) { 69 | mem->set_host_data(host_ptr); 70 | } 71 | 72 | void set_device_data(T *device_ptr) { 73 | mem->set_device_data(device_ptr); 74 | } 75 | 76 | void to_host() const { 77 | CHECK_GT(size_, 0); 78 | mem->to_host(); 79 | } 80 | 81 | void to_device() const { 82 | CHECK_GT(size_, 0); 83 | mem->to_device(); 84 | } 85 | 86 | /** 87 | * copy device data. This will call to_device() implicitly. 
88 | * @param source source data pointer (data can be on host or device) 89 | * @param count the count of elements 90 | */ 91 | void copy_from(const T *source, size_t count) { 92 | 93 | #ifdef USE_CUDA_ARRAY 94 | thunder::device_mem_copy(mem->device_data(), source, sizeof(T) * count); 95 | #else 96 | memcpy(mem->host_data(), source, sizeof(T) * count); 97 | #endif 98 | }; 99 | 100 | void copy_from(const SyncArray &source) { 101 | 102 | CHECK_EQ(size(), source.size()) << "destination and source count doesn't match"; 103 | #ifdef USE_CUDA_ARRAY 104 | if (get_owner_id() == source.get_owner_id()) 105 | copy_from(source.device_data(), source.size()); 106 | else 107 | CUDA_CHECK(cudaMemcpyPeer(mem->device_data(), get_owner_id(), source.device_data(), source.get_owner_id(), 108 | source.mem_size())); 109 | #else 110 | copy_from(source.host_data(), source.size()); 111 | #endif 112 | }; 113 | 114 | /** 115 | * resize to a new size. This will also clear all data. 116 | * @param count 117 | */ 118 | void resize(size_t count) { 119 | if(mem != nullptr || mem != NULL) { 120 | delete mem; 121 | } 122 | mem = new SyncMem(sizeof(T) * count); 123 | this->size_ = count; 124 | }; 125 | 126 | /* 127 | * resize to a new size. This will not clear the origin data. 
128 | * @param count 129 | */ 130 | void resize_without_delete(size_t count) { 131 | // delete mem; 132 | mem = new SyncMem(sizeof(T) * count); 133 | this->size_ = count; 134 | }; 135 | 136 | 137 | size_t mem_size() const {//number of bytes 138 | return mem->size(); 139 | } 140 | 141 | size_t size() const {//number of values 142 | return size_; 143 | } 144 | 145 | SyncMem::HEAD head() const { 146 | return mem->head(); 147 | } 148 | 149 | void log(el::base::type::ostream_t &ostream) const override { 150 | int i; 151 | ostream << "["; 152 | const T *data = host_data(); 153 | for (i = 0; i < size() - 1 && i < el::base::consts::kMaxLogPerContainer - 1; ++i) { 154 | // for (i = 0; i < size() - 1; ++i) { 155 | ostream << data[i] << ","; 156 | } 157 | ostream << host_data()[i]; 158 | if (size() <= el::base::consts::kMaxLogPerContainer) { 159 | ostream << "]"; 160 | } else { 161 | ostream << ", ...(" << size() - el::base::consts::kMaxLogPerContainer << " more)"; 162 | } 163 | }; 164 | #ifdef USE_CUDA_ARRAY 165 | int get_owner_id() const { 166 | return mem->get_owner_id(); 167 | } 168 | #endif 169 | //move constructor 170 | SyncArray(SyncArray &&rhs) noexcept : mem(rhs.mem), size_(rhs.size_) { 171 | rhs.mem = nullptr; 172 | rhs.size_ = 0; 173 | } 174 | 175 | //move assign 176 | SyncArray &operator=(SyncArray &&rhs) noexcept { 177 | delete mem; 178 | mem = rhs.mem; 179 | size_ = rhs.size_; 180 | 181 | rhs.mem = nullptr; 182 | rhs.size_ = 0; 183 | return *this; 184 | } 185 | 186 | SyncArray(const SyncArray &) = delete; 187 | 188 | SyncArray &operator=(const SyncArray &) = delete; 189 | 190 | private: 191 | SyncMem *mem; 192 | size_t size_; 193 | }; 194 | 195 | //SyncArray for multiple devices 196 | template 197 | class MSyncArray : public vector> { 198 | public: 199 | explicit MSyncArray(size_t n_device) : base_class(n_device) {}; 200 | 201 | explicit MSyncArray(size_t n_device, size_t size) : base_class(n_device) { 202 | for (int i = 0; i < n_device; ++i) { 203 | this->at(i) 
= SyncArray(size); 204 | } 205 | }; 206 | 207 | MSyncArray() : base_class() {}; 208 | 209 | //move constructor and assign 210 | MSyncArray(MSyncArray &&) = default; 211 | 212 | MSyncArray &operator=(MSyncArray &&) = default; 213 | 214 | MSyncArray(const MSyncArray &) = delete; 215 | 216 | MSyncArray &operator=(const MSyncArray &) = delete; 217 | 218 | private: 219 | typedef vector> base_class; 220 | }; 221 | #endif //FEDTREE_SYNCARRAY_H 222 | -------------------------------------------------------------------------------- /include/FedTree/trainer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #ifndef FEDTREE_TRAINER_H 6 | #define FEDTREE_TRAINER_H 7 | 8 | #include "FedTree/common.h" 9 | #include "FedTree/Tree/tree.h" 10 | #include "FedTree/dataset.h" 11 | 12 | 13 | class TreeTrainer{ 14 | public: 15 | vector> train (GBDTParam ¶m, const DataSet &dataset); 16 | }; 17 | 18 | #endif //FEDTREE_TRAINER_H 19 | -------------------------------------------------------------------------------- /include/FedTree/util/device_lambda.cuh: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/14/20. 
3 | // ThunderGBM device_lambda.h: https://github.com/Xtra-Computing/thundergbm/blob/master/include/thundergbm/util/device_lambda.h 4 | // Under Apache-2.0 license 5 | // copyright (c) 2020 jiashuai 6 | // 7 | 8 | #ifndef FEDTREE_DEVICE_LAMBDA_H 9 | #define FEDTREE_DEVICE_LAMBDA_H 10 | 11 | #ifdef USE_CUDA 12 | 13 | #include "FedTree/common.h" 14 | 15 | template 16 | __global__ void lambda_kernel(size_t len, L lambda) { 17 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += blockDim.x * gridDim.x) { 18 | lambda(i); 19 | } 20 | } 21 | 22 | template 23 | __global__ void anonymous_kernel_k(L lambda) { 24 | lambda(); 25 | } 26 | 27 | template 28 | __global__ void lambda_2d_sparse_kernel(const int *len2, L lambda) { 29 | int i = blockIdx.x; 30 | int begin = len2[i]; 31 | int end = len2[i + 1]; 32 | for (int j = begin + blockIdx.y * blockDim.x + threadIdx.x; j < end; j += blockDim.x * gridDim.y) { 33 | lambda(i, j); 34 | } 35 | } 36 | 37 | ///p100 has 56 MPs, using 32*56 thread blocks 38 | template 39 | //template 40 | inline void device_loop(int len, L lambda) { 41 | if (len > 0) { 42 | lambda_kernel << < NUM_BLOCK, BLOCK_SIZE >> > (len, lambda); 43 | cudaDeviceSynchronize(); 44 | CUDA_CHECK(cudaPeekAtLastError()); 45 | } 46 | } 47 | 48 | template 49 | inline void anonymous_kernel(L lambda, size_t smem_size = 0, int NUM_BLOCK = 32 * 56, int BLOCK_SIZE = 256) { 50 | anonymous_kernel_k<< < NUM_BLOCK, BLOCK_SIZE, smem_size >> > (lambda); 51 | cudaDeviceSynchronize(); 52 | CUDA_CHECK(cudaPeekAtLastError()); 53 | } 54 | 55 | /** 56 | * @brief: (len1 x NUM_BLOCK) is the total number of blocks; len2 is an array of lengths. 
57 | */ 58 | template 59 | void device_loop_2d(int len1, const int *len2, L lambda, unsigned int NUM_BLOCK = 4 * 56, 60 | unsigned int BLOCK_SIZE = 256) { 61 | if (len1 > 0) { 62 | lambda_2d_sparse_kernel << < dim3(len1, NUM_BLOCK), BLOCK_SIZE >> > (len2, lambda); 63 | cudaDeviceSynchronize(); 64 | CUDA_CHECK(cudaPeekAtLastError()); 65 | } 66 | } 67 | #endif 68 | 69 | #endif //FEDTREE_DEVICE_LAMBDA_H 70 | -------------------------------------------------------------------------------- /include/FedTree/util/dirichlet.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by hanyuxuan on 22/10/20. 3 | // 4 | #include "FedTree/common.h" 5 | #include 6 | 7 | #ifndef FEDTREE_DIRICHLET_H 8 | #define FEDTREE_DIRICHLET_H 9 | 10 | template 11 | class dirichlet_distribution { 12 | public: 13 | dirichlet_distribution(const vector &); 14 | 15 | void set_params(const vector &); 16 | 17 | vector get_params(); 18 | 19 | vector operator()(RNG &); 20 | 21 | private: 22 | vector alpha; 23 | vector> gamma; 24 | }; 25 | 26 | template 27 | dirichlet_distribution::dirichlet_distribution(const vector &alpha) { 28 | set_params(alpha); 29 | } 30 | 31 | template 32 | void dirichlet_distribution::set_params(const vector &new_params) { 33 | alpha = new_params; 34 | vector> new_gamma(alpha.size()); 35 | for (int i = 0; i < alpha.size(); ++i) { 36 | std::gamma_distribution<> temp(alpha[i], 1); 37 | new_gamma[i] = temp; 38 | } 39 | gamma = new_gamma; 40 | } 41 | 42 | template 43 | vector dirichlet_distribution::get_params() { 44 | return alpha; 45 | } 46 | 47 | template 48 | vector dirichlet_distribution::operator()(RNG &generator) { 49 | vector x(alpha.size()); 50 | float sum = 0.0; 51 | for (int i = 0; i < alpha.size(); ++i) { 52 | x[i] = gamma[i](generator); 53 | sum += x[i]; 54 | } 55 | for (float &xi : x) xi = xi / sum; 56 | return x; 57 | } 58 | 59 | #endif //FEDTREE_DIRICHLET_H 60 | 
-------------------------------------------------------------------------------- /include/FedTree/util/multi_device.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 10/14/20. 3 | // ThunderGBM multi_device.h: https://github.com/Xtra-Computing/thundergbm/blob/master/include/thundergbm/util/multi_device.h 4 | // Under Apache-2.0 license 5 | // copyright (c) 2020 jiashuai 6 | // 7 | 8 | #ifndef FEDTREE_MULTI_DEVICE_H 9 | #define FEDTREE_MULTI_DEVICE_H 10 | 11 | #ifdef USE_CUDA 12 | #include "FedTree/common.h" 13 | 14 | //switch to specific device and do something, then switch back to the original device 15 | //FIXME make this macro into a function? 16 | #define DO_ON_DEVICE(device_id, something) \ 17 | do { \ 18 | int org_device_id = 0; \ 19 | CUDA_CHECK(cudaGetDevice(&org_device_id)); \ 20 | CUDA_CHECK(cudaSetDevice(device_id)); \ 21 | something; \ 22 | CUDA_CHECK(cudaSetDevice(org_device_id)); \ 23 | } while (false) 24 | 25 | /** 26 | * Do something on multiple devices, then switch back to the original device 27 | * 28 | * 29 | * example: 30 | * 31 | * DO_ON_MULTI_DEVICES(n_devices, [&](int device_id){ 32 | * //do_something_on_device(device_id); 33 | * }); 34 | */ 35 | 36 | template 37 | void DO_ON_MULTI_DEVICES(int n_devices, L do_something) { 38 | int org_device_id = 0; 39 | CUDA_CHECK(cudaGetDevice(&org_device_id)); 40 | #pragma omp parallel for num_threads(n_devices) 41 | for (int device_id = 0; device_id < n_devices; device_id++) { 42 | CUDA_CHECK(cudaSetDevice(device_id)); 43 | do_something(device_id); 44 | } 45 | CUDA_CHECK(cudaSetDevice(org_device_id)); 46 | 47 | } 48 | 49 | #endif 50 | 51 | #endif //FEDTREE_MULTI_DEVICE_H 52 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | ## Install Python Package 2 | 3 | We provide a scikit-learn wrapper interface. 
Before you use the Python interface, you must [install](https://fedtree.readthedocs.io/en/latest/Installation.html) FedTree first. 4 | Then, you can run the following command to install the Python package from source. 5 | ```bash 6 | python setup.py install 7 | ``` 8 | 9 | ## Class 10 | 11 | We provide two classes, ```FLClassifier``` and ```FLRegressor```, where the first is for classification task and the second is for regression task. 12 | 13 | ### Parameters 14 | Please refer to [here](https://fedtree.readthedocs.io/en/latest/Parameters.html) for the list of parameters. 15 | 16 | 17 | ### Methods 18 | 19 | *fit(X, y)*:\ 20 | Fit the FedTree model according to the given training data. 21 | 22 | *predict(X)*:\ 23 | Perform prediction on samples in X. 24 | 25 | *save_model(model_path)*:\ 26 | Save the FedTree model to model_path. 27 | 28 | *load_model(model_path)*:\ 29 | Load the FedTree model from model_path. 30 | 31 | ## Examples 32 | Users can simply input parameters to these classes, call ```fit()``` and ```predict``` functions like models in scikit-learn. 33 | 34 | ```bash 35 | from fedtree import FLRegressor 36 | from sklearn.metrics import mean_squared_error 37 | from sklearn.datasets import load_svmlight_file 38 | x, y = load_svmlight_file("../dataset/test_dataset.txt") 39 | clf = FLRegressor(n_trees=10, n_parties=2, mode="horizontal", learning_rate=0.2, max_depth=4, objective="reg:linear") 40 | clf.fit(x, y) 41 | y_pred = clf.predict(x) 42 | rmse = mean_squared_error(y, y_pred, squared=False) 43 | print("rmse:", rmse) 44 | ``` 45 | 46 | Under ```examples``` directory, you can find two examples on how to use FedTree with Python. 
47 | -------------------------------------------------------------------------------- /python/examples/classifier_example.py: -------------------------------------------------------------------------------- 1 | from fedtree import FLClassifier 2 | from sklearn.datasets import load_digits 3 | from sklearn.metrics import accuracy_score 4 | 5 | if __name__ == '__main__': 6 | x, y = load_digits(return_X_y=True) 7 | clf = FLClassifier(n_trees=2, mode="horizontal", n_parties=2, num_class=10,objective="multi:softmax") 8 | clf.fit(x, y) 9 | y_pred = clf.predict(x) 10 | y_pred_prob = clf.predict_proba(x) 11 | accuracy = accuracy_score(y, y_pred) 12 | print("accuracy:", accuracy) 13 | -------------------------------------------------------------------------------- /python/examples/regressor_example.py: -------------------------------------------------------------------------------- 1 | from fedtree import FLRegressor 2 | from sklearn.metrics import mean_squared_error 3 | from sklearn.datasets import load_svmlight_file 4 | 5 | if __name__ == '__main__': 6 | x, y = load_svmlight_file("../dataset/test_dataset.txt") 7 | clf = FLRegressor(n_trees=10, n_parties=2, mode="horizontal", learning_rate=0.2, max_depth=4, objective="reg:linear") 8 | clf.fit(x, y) 9 | y_pred = clf.predict(x) 10 | rmse = mean_squared_error(y, y_pred, squared=False) 11 | print("rmse:", rmse) 12 | -------------------------------------------------------------------------------- /python/fedtree/__init__.py: -------------------------------------------------------------------------------- 1 | name = "fedtree" 2 | from .fedtree import * -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import setuptools 3 | from shutil import copyfile 4 | from sys import platform 5 | 6 | dirname = path.dirname(path.abspath(__file__)) 7 | 8 | if platform == "linux" or platform 
== "linux2": 9 | lib_path = path.abspath(path.join(dirname, '../build/lib/libFedTree.so')) 10 | elif platform == "win32": 11 | lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/libFedTree.dll')) 12 | elif platform == "darwin": 13 | lib_path = path.abspath(path.join(dirname, '../build/lib/libFedTree.dylib')) 14 | else: 15 | print("OS not supported!") 16 | exit() 17 | 18 | if not path.exists(path.join(dirname, "fedtree", path.basename(lib_path))): 19 | copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path))) 20 | 21 | setuptools.setup(name="fedtree", 22 | version="1.0.3", 23 | packages=["fedtree"], 24 | package_dir={"python": "fedtree"}, 25 | description="A federated learning library for trees", 26 | license='Apache-2.0', 27 | author='Qinbin Li', 28 | author_email='liqinbin1998@gmail.com', 29 | url='https://github.com/Xtra-Computing/FedTree', 30 | package_data={"fedtree": [path.basename(lib_path)]}, 31 | install_requires=['numpy', 'scipy', 'scikit-learn'], 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "License :: OSI Approved :: Apache Software License", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /src/FedTree/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 2 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 3 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 4 | 5 | file(GLOB SRC util/*.c* DP/*.c* FL/*.c* Tree/*.c* Encryption/*.c* objective/*.c* metric/*.c* *.c*) 6 | if(USE_CUDA) 7 | list(REMOVE_ITEM SRC "${CMAKE_CURRENT_LIST_DIR}/Encryption/paillier.cpp") 8 | else() 9 | list(REMOVE_ITEM SRC "${CMAKE_CURRENT_LIST_DIR}/Encryption/paillier_gpu.cu" "${CMAKE_CURRENT_LIST_DIR}/Encryption/paillier_gmp.cpp") 10 | endif() 11 | 12 | list(REMOVE_ITEM SRC "${CMAKE_CURRENT_LIST_DIR}/fedtree_train.cpp" 
"${CMAKE_CURRENT_LIST_DIR}/fedtree_predict.cpp" "${CMAKE_CURRENT_LIST_DIR}/distributed_party.cpp" "${CMAKE_CURRENT_LIST_DIR}/distributed_server.cpp") 13 | find_library(NTL_LIB ntl ${NTL_PATH}/lib) 14 | if(NOT NTL_LIB) 15 | message(FATAL_ERROR "ntl library not found. Rerun cmake with -DCMAKE_PREFIX_PATH=\";\"") 16 | endif () 17 | find_library(M_LIB m) 18 | if (NOT M_LIB) 19 | message(FATAL_ERROR "m library not found. Rerun cmake with -DCMAKE_PREFIX_PATH=\";\"") 20 | endif () 21 | #find_library(GMP_LIB gmp) 22 | #if (NOT GMP_LIB) 23 | # message(FATAL_ERROR "gmp library not found. Rerun cmake with -DCMAKE_PREFIX_PATH=\";\"") 24 | #endif () 25 | #find_library(GMPXX_LIB gmpxx) 26 | if (DISTRIBUTED) 27 | include_directories(${PROJECT_SOURCE_DIR}/build/_deps/grpc-src/third_party/protobuf/src/) 28 | include_directories(${PROJECT_SOURCE_DIR}/build/_deps/grpc-src/include/) 29 | set(ft_proto_srcs "grpc/fedtree.pb.cc") 30 | set(ft_proto_hdrs "grpc/fedtree.pb.h") 31 | set(ft_grpc_srcs "grpc/fedtree.grpc.pb.cc") 32 | set(ft_grpc_hdrs "grpc/fedtree.grpc.pb.h") 33 | 34 | add_library(ft_grpc_proto ${ft_grpc_srcs} ${ft_grpc_hdrs} ${ft_proto_srcs} ${ft_proto_hdrs}) 35 | target_link_libraries(ft_grpc_proto grpc++) 36 | endif () 37 | 38 | if (USE_CUDA) 39 | cuda_add_library(${PROJECT_NAME} SHARED ${SRC}) 40 | target_link_libraries(${PROJECT_NAME} ThrustOMP ${CUDA_cusparse_LIBRARY} ${NTL_LIB} ${M_LIB} ${GMP_LIB} ${GMPXX_LIB}) 41 | # target_include_directories(${PROJECT_NAME} PUBLIC ${PROJECT_SOURCE_DIR}/CGBN/include) 42 | cuda_add_executable(${PROJECT_NAME}-train fedtree_train.cpp ${COMMON_INCLUDES}) 43 | cuda_add_executable(${PROJECT_NAME}-predict fedtree_predict.cpp ${COMMON_INCLUDES}) 44 | # set_target_properties(${PROJECT_NAME}-train PROPERTIES CUDA_ARCHITECTURES "80") 45 | # set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_ARCHITECTURES "80") 46 | if (DISTRIBUTED) 47 | cuda_add_executable(${PROJECT_NAME}-distributed-party distributed_party.cpp) 48 | 
cuda_add_executable(${PROJECT_NAME}-distributed-server distributed_server.cpp) 49 | endif () 50 | else () 51 | add_library(${PROJECT_NAME} SHARED ${SRC}) 52 | target_link_libraries(${PROJECT_NAME} ThrustOMP ${NTL_LIB} ${M_LIB} ${GMP_LIB} ${GMPXX_LIB}) 53 | add_executable(${PROJECT_NAME}-train fedtree_train.cpp) 54 | add_executable(${PROJECT_NAME}-predict fedtree_predict.cpp) 55 | if (DISTRIBUTED) 56 | add_executable(${PROJECT_NAME}-distributed-party distributed_party.cpp) 57 | add_executable(${PROJECT_NAME}-distributed-server distributed_server.cpp) 58 | endif () 59 | endif () 60 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") 61 | target_link_libraries(${PROJECT_NAME} OpenMP::OpenMP_CXX) 62 | endif () 63 | target_link_libraries(${PROJECT_NAME}-train ${PROJECT_NAME}) 64 | target_link_libraries(${PROJECT_NAME}-predict ${PROJECT_NAME}) 65 | if (DISTRIBUTED) 66 | foreach(_target ${PROJECT_NAME}-distributed-party ${PROJECT_NAME}-distributed-server) 67 | target_link_libraries(${_target} ${PROJECT_NAME}) 68 | target_link_libraries(${_target} ft_grpc_proto) 69 | target_link_libraries(${_target} grpc++) 70 | endforeach() 71 | endif () 72 | -------------------------------------------------------------------------------- /src/FedTree/DP/differential_privacy.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Tianyuan Fu on 14/3/21. 
3 | // 4 | 5 | #include "FedTree/DP/differential_privacy.h" 6 | #include "FedTree/Tree/GBDTparam.h" 7 | #include 8 | #include 9 | #include 10 | 11 | void DifferentialPrivacy::init(FLParam flparam) { 12 | GBDTParam gbdt_param = flparam.gbdt_param; 13 | this->lambda = gbdt_param.lambda; 14 | if(gbdt_param.constant_h != 0) 15 | this->constant_h = gbdt_param.constant_h; 16 | this->delta_g = 3 * this->max_gradient * this->max_gradient / this->constant_h; 17 | this->delta_v = this->max_gradient / (this->constant_h + this->lambda); 18 | 19 | this->privacy_budget = flparam.privacy_budget; 20 | this->privacy_budget_per_tree = this->privacy_budget / (gbdt_param.n_trees/int(1/flparam.ins_bagging_fraction)); 21 | this->privacy_budget_leaf_nodes = this->privacy_budget_per_tree / 2.0; 22 | this->privacy_budget_internal_nodes = this->privacy_budget_per_tree / 2.0 / gbdt_param.depth; 23 | } 24 | 25 | /** 26 | * calculates p value based on gain value for each split point 27 | * @param gain - gain values of all split points in the level 28 | * @param prob_exponent - exponent for the probability mass; the probability mass is exp(prob_exponent[i]) 29 | */ 30 | void DifferentialPrivacy::compute_split_point_probability(SyncArray &gain, SyncArray &prob_exponent) { 31 | auto prob_exponent_data = prob_exponent.host_data(); 32 | auto gain_data = gain.host_data(); 33 | for(int i = 0; i < gain.size(); i ++) { 34 | prob_exponent_data[i] = this->privacy_budget_internal_nodes * gain_data[i] / 2 / delta_g; 35 | // LOG(INFO) << "budget" << this->privacy_budget_internal_nodes; 36 | // LOG(INFO) << "gain" << gain_data[i]; 37 | } 38 | } 39 | 40 | /** 41 | * exponential mechanism: randomly selects split point based on p value 42 | * @param prob_exponent - exponent for the probability mass; the probability mass is exp(prob_exponent[i]) 43 | * @param gain - gain values of all split points in the level 44 | * @param best_idx_gain - mapping from the node index to the gain of split point; containing all 
the node in the level 45 | */ 46 | void DifferentialPrivacy::exponential_select_split_point(SyncArray &prob_exponent, SyncArray &gain, 47 | SyncArray &best_idx_gain, int n_nodes_in_level, 48 | int n_bins) { 49 | // initialize randomization 50 | std::random_device device; 51 | std::mt19937 generator(device()); 52 | std::uniform_real_distribution<> distribution(0.0, 1.0); 53 | 54 | auto prob_exponent_data = prob_exponent.host_data(); 55 | auto gain_data = gain.host_data(); 56 | auto best_idx_gain_data = best_idx_gain.host_data(); 57 | 58 | vector probability(n_bins * n_nodes_in_level); 59 | 60 | for(int i = 0; i < n_nodes_in_level; i ++) { 61 | int start = i * n_bins; 62 | int end = start + n_bins - 1; 63 | 64 | // Given the probability exponent: a, b, c, d 65 | // The probability[0] can be calculated by exp(a)/(exp(a)+exp(b)+exp(c)+exp(d)) 66 | // To avoid overflow, calculation will be done in 1/(exp(a-a)+exp(b-a)+exp(c-a)+exp(d-a)) 67 | // Probability value with respect to the bin will be stored in probability vector 68 | for(int j = start; j <= end; j ++) { 69 | float curr_exponent = prob_exponent_data[j]; 70 | float prob_sum_denominator = 0; 71 | for(int k = start; k <= end; k ++) { 72 | prob_sum_denominator += exp(prob_exponent_data[k] - curr_exponent); 73 | } 74 | probability[j] = 1.0 / prob_sum_denominator; 75 | } 76 | 77 | 78 | float random_sample = distribution(generator); 79 | float partial_sum = 0; 80 | for(int j = start; j <= end; j ++) { 81 | partial_sum += probability[j]; 82 | if(partial_sum > random_sample) { 83 | best_idx_gain_data[i] = thrust::make_tuple(j, gain_data[j]); 84 | break; 85 | } 86 | } 87 | } 88 | } 89 | 90 | /** 91 | * add Laplace noise to the data 92 | * @param node - the leaf node which noise are to be added 93 | */ 94 | void DifferentialPrivacy::laplace_add_noise(Tree::TreeNode &node) { 95 | // a Laplace(0, b) variable can be generated by the difference of two i.i.d Exponential(1/b) variables 96 | float b = 
this->delta_v/privacy_budget_leaf_nodes; 97 | 98 | std::random_device device; 99 | std::mt19937 generator(device()); 100 | std::exponential_distribution distribution(1.0/b); 101 | 102 | double noise = distribution(generator) - distribution(generator); 103 | node.base_weight += noise; 104 | } 105 | -------------------------------------------------------------------------------- /src/FedTree/Encryption/diffie_hellman.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 5/26/22. 3 | // 4 | 5 | #include "FedTree/Encryption/diffie_hellman.h" 6 | #include 7 | /* 8 | bool is_prime(int n) 9 | { 10 | // Corner cases 11 | if (n <= 1) return false; 12 | if (n <= 3) return true; 13 | if (n%2 == 0 || n%3 == 0) return false; 14 | for (int i=5; i*i<=n; i=i+6) 15 | if (n%i == 0 || n%(i+2) == 0) 16 | return false; 17 | return true; 18 | } 19 | 20 | int fast_power(int x, unsigned int y, int p) 21 | { 22 | int res = 1; // Initialize result 23 | 24 | x = x % p; // Update x if it is more than or 25 | // equal to p 26 | 27 | while (y > 0) 28 | { 29 | // If y is odd, multiply x with result 30 | if (y & 1) 31 | res = (res*x) % p; 32 | 33 | // y must be even now 34 | y = y >> 1; // y = y/2 35 | x = (x*x) % p; 36 | } 37 | return res; 38 | } 39 | 40 | // Utility function to store prime factors of a number 41 | void findPrimefactors(unordered_set &s, int n) 42 | { 43 | // Print the number of 2s that divide n 44 | while (n%2 == 0) 45 | { 46 | s.insert(2); 47 | n = n/2; 48 | } 49 | 50 | // n must be odd at this point. 
So we can skip 51 | // one element (Note i = i +2) 52 | for (int i = 3; i <= sqrt(n); i = i+2) 53 | { 54 | // While i divides n, print i and divide n 55 | while (n%i == 0) 56 | { 57 | s.insert(i); 58 | n = n/i; 59 | } 60 | } 61 | 62 | // This condition is to handle the case when 63 | // n is a prime number greater than 2 64 | if (n > 2) 65 | s.insert(n); 66 | } 67 | 68 | // Function to find smallest primitive root of n 69 | int get_primitive(int n) 70 | { 71 | unordered_set s; 72 | 73 | // Check if n is prime or not 74 | if (isPrime(n)==false) 75 | return -1; 76 | 77 | // Find value of Euler Totient function of n 78 | // Since n is a prime number, the value of Euler 79 | // Totient function is n-1 as there are n-1 80 | // relatively prime numbers. 81 | int phi = n-1; 82 | 83 | // Find prime factors of phi and store in a set 84 | findPrimefactors(s, phi); 85 | 86 | // Check for every number from 2 to phi 87 | for (int r=2; r<=phi; r++) 88 | { 89 | // Iterate through all prime factors of phi. 90 | // and check if we found a power with value 1 91 | bool flag = false; 92 | for (auto it = s.begin(); it != s.end(); it++) 93 | { 94 | 95 | // Check if r^((phi)/primefactors) mod n 96 | // is 1 or not 97 | if (power(r, phi/(*it), n) == 1) 98 | { 99 | flag = true; 100 | break; 101 | } 102 | } 103 | 104 | // If there was no power with value 1. 
105 | if (flag == false) 106 | return r; 107 | } 108 | 109 | // If no primitive root found 110 | return -1; 111 | } 112 | */ 113 | 114 | ZZ toDec(char val){ 115 | if (val=='A' || val=='a') return to_ZZ(10); 116 | else if(val=='B' || val=='b') return to_ZZ(11); 117 | else if(val=='C' || val=='c') return to_ZZ(12); 118 | else if(val=='D' || val=='d') return to_ZZ(13); 119 | else if(val=='E' || val=='e') return to_ZZ(14); 120 | else if(val=='F' || val=='f') return to_ZZ(15); 121 | else return to_ZZ(val-'0'); 122 | } 123 | 124 | ZZ hexToZZ(string hexVal){ 125 | ZZ val; 126 | val = to_ZZ(0); //initialise the value to zero 127 | double base = 16; 128 | int j = 0; 129 | //convert the hex string to decimal string 130 | for (int i = ((hexVal.length())-1); i > -1; i--){ 131 | val += toDec(hexVal[i])*(to_ZZ((pow(base, j)))); 132 | j++; 133 | } 134 | //cout << endl << "The value in decimal is " << val << endl; 135 | return val; 136 | } 137 | 138 | //void DiffieHellman::primegen(){ 139 | // std::random_device rd; // obtain a random number from hardware 140 | // std::mt19937 gen(rd()); // seed the generator 141 | // std::uniform_int_distribution<> distr(1e^5, 1e^8); // define the range 142 | // while(true){ 143 | // p = distr(gen); 144 | // if(is_prime(p)) 145 | // break; 146 | // } 147 | // g = get_primitive(p); 148 | 149 | // from https://datatracker.ietf.org/doc/html/rfc2409#page-22 150 | 151 | //} 152 | DiffieHellman::DiffieHellman(){ 153 | //use default value as it does not affect the security 154 | //from https://datatracker.ietf.org/doc/html/rfc2409#page-22 155 | p = hexToZZ("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519" 156 | "B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A89" 157 | "9FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF"); 158 | g = 2; 159 | } 160 | 161 | ZZ DiffieHellman::encrypt(float_type &message, int pid){ 162 | return (message*1e6 + 
shared_keys[pid])%p; 163 | } 164 | 165 | 166 | float_type DiffieHellman::decrypt(ZZ &message, int pid){ 167 | return (float_type) NTL::to_long((message - shared_keys[pid])%p) / 1e6; 168 | } 169 | 170 | void DiffieHellman::generate_public_key(){ 171 | std::random_device rd; 172 | std::mt19937 gen(rd()); 173 | std::uniform_int_distribution<> distr(1, 100); // define the range 174 | secret = distr(gen); 175 | public_key = PowerMod(g, secret, p); 176 | return; 177 | }; 178 | 179 | void DiffieHellman::init_variables(int n_parties){ 180 | this->n_parties = n_parties; 181 | other_public_keys.SetLength(n_parties); 182 | shared_keys.SetLength(n_parties); 183 | encrypted_noises.SetLength(n_parties); 184 | generated_noises.resize(n_parties); 185 | received_encrypted_noises.SetLength(n_parties); 186 | decrypted_noises.resize(n_parties); 187 | }; 188 | 189 | void DiffieHellman::compute_shared_keys(){ 190 | for(int i = 0; i < other_public_keys.length(); i++){ 191 | if(i!=pid) { 192 | shared_keys[i] = PowerMod(other_public_keys[i], secret, p); 193 | } 194 | } 195 | 196 | } 197 | 198 | void DiffieHellman::generate_noises(){ 199 | std::random_device rd; 200 | std::mt19937 gen(rd()); 201 | std::uniform_int_distribution<> distr(1e6, 1e9); // define the range 202 | for(int i = 0; i < n_parties; i++){ 203 | if(i!=pid) { 204 | generated_noises[i] = (float_type) distr(gen) / 1e6; 205 | encrypted_noises[i] = encrypt(generated_noises[i], i); 206 | } 207 | } 208 | return; 209 | } 210 | 211 | void DiffieHellman::decrypt_noises(){ 212 | for(int i = 0; i < n_parties; i++){ 213 | if(i != pid) 214 | decrypted_noises[i] = decrypt(received_encrypted_noises[i], i); 215 | } 216 | return; 217 | } -------------------------------------------------------------------------------- /src/FedTree/Encryption/paillier.cpp: -------------------------------------------------------------------------------- 1 | #include "FedTree/Encryption/paillier.h" 2 | #include 3 | 4 | using namespace std; 5 | 6 | /* Reference: 
Paillier, P. (1999, May). Public-key cryptosystems based on composite degree residuosity classes. */ 7 | 8 | 9 | NTL::ZZ Gen_Coprime(const NTL::ZZ &n) { 10 | /* Coprime generation function. Generates a random coprime number of n. 11 | * 12 | * Parameters 13 | * ========== 14 | * NTL::ZZ n : a prime number 15 | * 16 | * Returns 17 | * ======= 18 | * NTL:ZZ ret : a random coprime number of n 19 | */ 20 | NTL::ZZ ret; 21 | while (true) { 22 | ret = RandomBnd(n); 23 | if (NTL::GCD(ret, n) == 1) { return ret; } 24 | } 25 | } 26 | 27 | NTL::ZZ lcm(const NTL::ZZ &x, const NTL::ZZ &y) { 28 | /* Least common multiple function. Computes the least common multiple of x and y. 29 | * 30 | * Parameters 31 | * ========== 32 | * NTL::ZZ x, y: signed, arbitrary length integers 33 | * 34 | * Returns 35 | * ======= 36 | * NTL:ZZ lcm : the least common multiple of x and y 37 | */ 38 | NTL::ZZ lcm; 39 | lcm = (x * y) / NTL::GCD(x, y); 40 | return lcm; 41 | } 42 | 43 | void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, long keyLength) { 44 | /* Prime pair generation function. Generates a prime pair in same bit length. 45 | * 46 | * Parameters 47 | * ========== 48 | * NTL::ZZ p, q: large primes in same bit length 49 | * long keyLength: the length of the key 50 | */ 51 | while (true) { 52 | long err = 80; 53 | p = NTL::GenPrime_ZZ(keyLength / 2, err); 54 | q = NTL::GenPrime_ZZ(keyLength / 2, err); 55 | while (p == q) { 56 | q = NTL::GenPrime_ZZ(keyLength / 2, err); 57 | } 58 | NTL::ZZ n = p * q; 59 | NTL::ZZ phi = (p - 1) * (q - 1); 60 | if (NTL::GCD(n, phi) == 1) return; 61 | } 62 | } 63 | 64 | Paillier::Paillier() = default; 65 | 66 | void Paillier::keygen(long keyLength) { 67 | /* Paillier parameters generation function. Generates paillier parameters from scrach. 
68 | * 69 | * Parameters 70 | * ========== 71 | * long keyLength: the length of the key 72 | * 73 | * ======= 74 | * public key = (modulus, generator) 75 | * private key = lambda 76 | */ 77 | 78 | // NTL::SetSeed(NTL::ZZ(0)); 79 | 80 | this->keyLength = keyLength; 81 | GenPrimePair(p, q, keyLength); 82 | modulus = p * q; // modulus = p * q 83 | generator = modulus + 1; // generator = modulus + 1 84 | lambda = lcm(p - 1, q - 1); // lamda = lcm(p-1, q-1) 85 | lambda_power = NTL::PowerMod(generator, lambda, modulus * modulus); 86 | u = NTL::InvMod(L_function(lambda_power), 87 | modulus); // u = L((generator^lambda) mod n ^ 2) ) ^ -1 mod modulus 88 | 89 | // random = Gen_Coprime(modulus); 90 | } 91 | 92 | NTL::ZZ Paillier::add(const NTL::ZZ &x, const NTL::ZZ &y) const { 93 | /* Paillier addition function. Computes the sum of x and y. 94 | * 95 | * Parameters 96 | * ========== 97 | * NTL::ZZ x, y: signed, arbitrary length integers 98 | * 99 | * Returns 100 | * ======= 101 | * NTL:ZZ sum: the sum of x and y 102 | */ 103 | NTL::ZZ sum = x * y % (modulus * modulus); 104 | return sum; 105 | } 106 | 107 | NTL::ZZ Paillier::mul(const NTL::ZZ &x, const NTL::ZZ &y) const { 108 | /* Paillier multiplication function. Computes the product of x and y. 109 | * 110 | * Parameters 111 | * ========== 112 | * NTL::ZZ x, y: signed, arbitrary length integers 113 | * 114 | * Returns 115 | * ======= 116 | * NTL:ZZ sum: the product of x and y 117 | */ 118 | NTL::ZZ product = PowerMod(x, y, modulus * modulus); 119 | return product; 120 | } 121 | 122 | NTL::ZZ Paillier::encrypt(const NTL::ZZ &message) const { 123 | /* Paillier encryption function. Takes in a message in F(modulus), and returns a message in F(modulus**2). 124 | * 125 | * Parameters 126 | * ========== 127 | * NTL::ZZ message : The message to be encrypted. 128 | * 129 | * Returns 130 | * ======= 131 | * NTL:ZZ ciphertext : The encyrpted message. 
132 | */ 133 | 134 | NTL::ZZ c = Gen_Coprime(modulus); 135 | NTL::ZZ ciphertext = 136 | NTL::PowerMod(generator, message, modulus * modulus) * NTL::PowerMod(c, modulus, modulus * modulus) % 137 | (modulus * modulus); 138 | return ciphertext; 139 | } 140 | 141 | NTL::ZZ Paillier::decrypt(const NTL::ZZ &ciphertext) const { 142 | /* Paillier decryption function. Takes in a ciphertext in F(modulus**2), and returns a message in F(modulus). 143 | * 144 | * Parameters 145 | * ========== 146 | * NTL::ZZ cipertext : The encrypted message. 147 | * 148 | * Returns 149 | * ======= 150 | * NTL::ZZ message : The original message. 151 | */ 152 | 153 | NTL::ZZ deMasked = NTL::PowerMod(ciphertext, lambda, modulus * modulus); 154 | NTL::ZZ L_deMasked = L_function(deMasked); 155 | NTL::ZZ message = (L_deMasked * u) % modulus; 156 | return message; 157 | } 158 | -------------------------------------------------------------------------------- /src/FedTree/FL/party.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/19/20. 
3 | // 4 | 5 | 6 | #include "FedTree/FL/party.h" 7 | #include 8 | #include 9 | #include 10 | 11 | 12 | void Party::init(int pid, DataSet &dataset, FLParam ¶m) { 13 | this->pid = pid; 14 | this->dataset = dataset; 15 | this->param = param; 16 | this->n_total_instances = dataset.n_instances(); 17 | if (param.ins_bagging_fraction < 1.0){ 18 | this->temp_dataset = dataset; 19 | this->ins_bagging_fraction = param.ins_bagging_fraction; 20 | } 21 | // if (param.partition_mode == "hybrid") { 22 | // this->feature_map.resize(feature_map.size()); 23 | // this->feature_map.copy_from(feature_map.host_data(), feature_map.size()); 24 | // } 25 | booster.init(dataset, param.gbdt_param, (param.mode != "horizontal") || (param.propose_split == "party")); 26 | 27 | }; 28 | 29 | void Party::hybrid_init(int pid, DataSet &dataset, FLParam ¶m, SyncArray &feature_map) { 30 | this->pid = pid; 31 | this->dataset = dataset; 32 | this->param = param; 33 | this->n_total_instances = dataset.n_instances(); 34 | if (param.ins_bagging_fraction < 1.0){ 35 | this->temp_dataset = dataset; 36 | this->ins_bagging_fraction = param.ins_bagging_fraction; 37 | } 38 | if (param.partition_mode == "hybrid") { 39 | this->feature_map.resize(feature_map.size()); 40 | this->feature_map.copy_from(feature_map.host_data(), feature_map.size()); 41 | } 42 | booster.init(dataset, param.gbdt_param, (param.mode != "horizontal") || (param.propose_split == "party")); 43 | 44 | }; 45 | 46 | void Party::vertical_init(int pid, DataSet &dataset, FLParam ¶m) { 47 | // LOG(INFO)<<"in party "<pid = pid; 49 | this->dataset = dataset; 50 | this->has_label = dataset.has_label; 51 | this->param = param; 52 | this->n_total_instances = dataset.n_instances(); 53 | if(has_label) { 54 | booster.init(dataset, param.gbdt_param); 55 | } 56 | else { 57 | booster.param = param.gbdt_param; 58 | booster.fbuilder.reset(new HistTreeBuilder); 59 | booster.fbuilder->init(dataset, param.gbdt_param); //if remove this line, cannot successfully run 60 | 
int n_outputs = param.gbdt_param.num_class * dataset.n_instances(); 61 | booster.gradients.resize(n_outputs); 62 | } 63 | // if (param.gbdt_param.metric == "default") { 64 | // booster.metric.reset(Metric::create(booster.obj->default_metric_name())); 65 | // }else { 66 | // booster.metric.reset(Metric::create(param.gbdt_param.metric)); 67 | // } 68 | // booster.metric->configure(param.gbdt_param, dataset); 69 | // booster.n_devices = param.gbdt_param.n_device; 70 | 71 | // booster.y = SyncArray(dataset.n_instances()); 72 | // booster.y.copy_from(dataset.y.data(), dataset.n_instances()); 73 | 74 | }; 75 | 76 | void Party::bagging_init(int seed){ 77 | 78 | this->bagging_inner_round = 0; 79 | this->shuffle_idx.resize(this->n_total_instances); 80 | thrust::sequence(thrust::host, this->shuffle_idx.data(), this->shuffle_idx.data() + this->n_total_instances); 81 | if(seed == -1) 82 | std::random_shuffle(this->shuffle_idx.begin(), shuffle_idx.end()); 83 | else { 84 | std::default_random_engine e(seed); 85 | std::shuffle(this->shuffle_idx.begin(), shuffle_idx.end(), e); 86 | } 87 | } 88 | 89 | void Party::sample_data(){ 90 | int stride = this->ins_bagging_fraction * this->n_total_instances; 91 | vector batch_idx; 92 | if(this->bagging_inner_round == (int(1/this->ins_bagging_fraction) - 1)){ 93 | batch_idx = vector(this->shuffle_idx.begin()+stride*this->bagging_inner_round, this->shuffle_idx.end()); 94 | } 95 | else { 96 | batch_idx = vector(this->shuffle_idx.begin() + stride * this->bagging_inner_round, 97 | this->shuffle_idx.begin() + stride * (this->bagging_inner_round + 1)); 98 | } 99 | temp_dataset.get_subset(batch_idx, this->dataset); 100 | this->bagging_inner_round++; 101 | } 102 | 103 | 104 | 105 | void Party::correct_trees(){ 106 | // vector &last_trees = gbdt.trees.back(); 107 | //// auto unique_feature_end = thrust::unique_copy(thrust::host, dataset.csr_col_idx.data(), 108 | //// dataset.csr_col_idx.data() + dataset.csr_col_idx.size(), 
unique_feature_ids.host_data()); 109 | //// int unique_len = unique_feature_end - unique_feature_ids.host_data(); 110 | // auto feature_map_data = feature_map.host_data(); 111 | // for(int i = 0; i < last_trees.size(); i++){ 112 | // Tree &tree = last_trees[i]; 113 | // auto tree_nodes = tree.nodes.host_data(); 114 | // for(int nid = 0; nid < tree.nodes.size(); nid++){ 115 | // //if the node is internal node and the party has the corresponding feature id 116 | // if(!tree_nodes[nid].is_leaf){ 117 | // if(feature_map_data[tree_nodes[nid].split_feature_id]) { 118 | // // calculate gain for each possible split point 119 | // HistCut &cut = booster.fbuilder->cut; 120 | // } 121 | // else{ 122 | // //go to next level 123 | // } 124 | // 125 | // } 126 | // else{ 127 | // 128 | // } 129 | // } 130 | // } 131 | // //send gains to the server. 132 | } 133 | 134 | void Party::update_tree_info(){ 135 | 136 | 137 | // HistCut &cut = booster.fbuilder->cut; 138 | // vector &last_trees = gbdt.trees.back(); 139 | // for(int tid = 0; tid < last_trees.size(); tid++){ 140 | // Tree &tree = last_trees[tid]; 141 | // auto root_node = tree.nodes.host_data()[0]; 142 | // root_node.sum_gh_pair = thrust::reduce(thrust::host, booster.fbuilder->gradients.host_data(), 143 | // booster.fbuilder->gradients.host_end()); 144 | // int split_feature_id = root_node.split_feature_id; 145 | // auto csc_col_ptr = dataset.csc_col_ptr.data(); 146 | // auto csc_val_data = dataset.csc_val.data(); 147 | // auto cut_col_ptr = cut.cut_col_ptr.host_data(); 148 | // auto cut_val_data = cut.cut_points_val.host_data(); 149 | // for(int cid = csc_col_ptr[split_feature_id]; cid < csc_col_ptr[split_feature_id+1]; cid++){ 150 | // float_type feature_value = csc_val_data[cid]; 151 | // for(int cut_id = cut_col_ptr[cid]; cut_id < cut_col_ptr[cid+1]; cut_id++){ 152 | // float_type cut_value = cut_val_data[cut_id]; 153 | // } 154 | // } 155 | // for(int nid = 1; nid < tree.nodes.size(); nid++){ 156 | // auto 
tree_node_data = tree.nodes.host_data()[nid]; 157 | // 158 | // } 159 | // } 160 | } 161 | 162 | void Party::compute_leaf_values(){ 163 | 164 | } -------------------------------------------------------------------------------- /src/FedTree/Tree/function_builder.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #include "FedTree/Tree/function_builder.h" 6 | #include "FedTree/Tree/hist_tree_builder.h" 7 | 8 | FunctionBuilder *FunctionBuilder::create(std::string name) { 9 | if (name == "exact") { 10 | std::cout<<"not supported yet"; 11 | exit(1); 12 | } 13 | if (name == "hist") return new HistTreeBuilder; 14 | LOG(FATAL) << "unknown builder " << name; 15 | return nullptr; 16 | } 17 | 18 | -------------------------------------------------------------------------------- /src/FedTree/booster.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/17/20. 
3 | // 4 | #include 5 | #include 6 | #include "FedTree/booster.h" 7 | 8 | //std::mutex mtx; 9 | 10 | //void Booster::init(const GBDTParam ¶m, int n_instances) { 11 | // this -> param = param; 12 | // fbuilder.reset(new HistTreeBuilder); 13 | // fbuilder->init(param, n_instances); 14 | // n_devices = param.n_device; 15 | // int n_outputs = param.num_class * n_instances; 16 | // gradients = SyncArray(n_outputs); 17 | //} 18 | 19 | void Booster::init(DataSet &dataSet, const GBDTParam ¶m, bool get_cut_points) { 20 | this->param = param; 21 | fbuilder.reset(new HistTreeBuilder); 22 | if(get_cut_points) 23 | fbuilder->init(dataSet, param); 24 | else { 25 | fbuilder->init_nocutpoints(dataSet, param); 26 | } 27 | obj.reset(ObjectiveFunction::create(param.objective)); 28 | obj->configure(param, dataSet); 29 | if (param.metric == "default") { 30 | metric.reset(Metric::create(obj->default_metric_name())); 31 | }else { 32 | metric.reset(Metric::create(param.metric)); 33 | } 34 | metric->configure(param, dataSet); 35 | n_devices = param.n_device; 36 | int n_outputs = param.num_class * dataSet.n_instances(); 37 | gradients.resize(n_outputs); 38 | y = SyncArray(dataSet.n_instances()); 39 | y.copy_from(dataSet.y.data(), dataSet.n_instances()); 40 | } 41 | 42 | void Booster::reinit(DataSet &dataSet, const GBDTParam ¶m){ 43 | //todo: horizontal does not need get_cut_points 44 | fbuilder->init(dataSet, param); 45 | int n_outputs = param.num_class * dataSet.n_instances(); 46 | gradients.resize(n_outputs); 47 | y.resize(dataSet.n_instances()); 48 | y.copy_from(dataSet.y.data(), dataSet.n_instances()); 49 | } 50 | 51 | SyncArray Booster::get_gradients() { 52 | SyncArray gh; 53 | gh.resize(gradients.size()); 54 | gh.copy_from(gradients); 55 | return gh; 56 | } 57 | 58 | void Booster::set_gradients(SyncArray &gh) { 59 | gradients.resize(gh.size()); 60 | 61 | // auto gradients_data = gradients.host_data(); 62 | // auto gh_data = gh.host_data(); 63 | // for(int i = 0; i < gh.size(); i++) 64 
| // gradients_data[i] = gh_data[i]; 65 | gradients.copy_from(gh); 66 | } 67 | 68 | //void Booster::encrypt_gradients(AdditivelyHE::PaillierPublicKey pk) { 69 | // auto gradients_data = gradients.host_data(); 70 | // for (int i = 0; i < gradients.size(); i++) 71 | // gradients_data[i].homo_encrypt(pk); 72 | //} 73 | 74 | //void Booster::decrypt_gradients(AdditivelyHE::PaillierPrivateKey privateKey) { 75 | // auto gradients_data = gradients.host_data(); 76 | // for (int i = 0; i < gradients.size(); i++) 77 | // gradients_data[i].homo_decrypt(privateKey); 78 | //} 79 | 80 | void Booster::add_noise_to_gradients(float variance) { 81 | auto gradients_data = gradients.host_data(); 82 | for (int i = 0; i < gradients.size(); i++) { 83 | DPnoises::add_gaussian_noise(gradients_data[i].g, variance); 84 | DPnoises::add_gaussian_noise(gradients_data[i].h, variance); 85 | } 86 | } 87 | 88 | void Booster::update_gradients() { 89 | obj->get_gradient(y, fbuilder->get_y_predict(), gradients); 90 | } 91 | 92 | void Booster::boost(vector> &boosted_model) { 93 | TIMED_FUNC(timerObj); 94 | // std::unique_lock lock(mtx); 95 | //update gradients 96 | obj->get_gradient(y, fbuilder->get_y_predict(), gradients); 97 | 98 | // if (param.bagging) rowSampler.do_bagging(gradients); 99 | PERFORMANCE_CHECKPOINT(timerObj); 100 | //build new model/approximate function 101 | boosted_model.push_back(fbuilder->build_approximate(gradients)); 102 | 103 | PERFORMANCE_CHECKPOINT(timerObj); 104 | //show metric on training set 105 | std::ofstream myfile; 106 | myfile.open ("data.txt", std::ios_base::app); 107 | myfile << metric->get_score(fbuilder->get_y_predict()) << "\n"; 108 | myfile.close(); 109 | LOG(INFO) << metric->get_name() << " = " << metric->get_score(fbuilder->get_y_predict()); 110 | } 111 | 112 | void Booster::boost_without_prediction(vector> &boosted_model) { 113 | TIMED_FUNC(timerObj); 114 | // std::unique_lock lock(mtx); 115 | //update gradients 116 | obj->get_gradient(y, 
fbuilder->get_y_predict(), gradients); 117 | //LOG(INFO)<<"gradients after updated:"<build_approximate(gradients, false)); 123 | 124 | PERFORMANCE_CHECKPOINT(timerObj); 125 | //show metric on training set 126 | LOG(INFO) << metric->get_name() << " = " << metric->get_score(fbuilder->get_y_predict()); 127 | } -------------------------------------------------------------------------------- /src/FedTree/fedtree_predict.cpp: -------------------------------------------------------------------------------- 1 | #include "FedTree/FL/FLparam.h" 2 | #include "FedTree/parser.h" 3 | #include "FedTree/dataset.h" 4 | #include "FedTree/predictor.h" 5 | #include "FedTree/Tree/gbdt.h" 6 | 7 | 8 | #ifdef _WIN32 9 | INITIALIZE_EASYLOGGINGPP 10 | #endif 11 | int main(int argc, char** argv) { 12 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg"); 13 | el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput); 14 | el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat); 15 | 16 | //centralized training test 17 | FLParam fl_param; 18 | Parser parser; 19 | parser.parse_param(fl_param, argc, argv); 20 | GBDTParam &model_param = fl_param.gbdt_param; 21 | if (model_param.verbose == 0) { 22 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 23 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 24 | el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "false"); 25 | } else if (model_param.verbose == 1) { 26 | el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 27 | el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 28 | } 29 | 30 | if (!model_param.profiling) { 31 | el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 32 | } 33 | 34 | DataSet dataset; 35 | vector> boosted_model; 36 | 
parser.load_model(model_param.model_path, model_param, boosted_model); 37 | dataset.load_from_file(model_param.path, fl_param); 38 | 39 | GBDT gbdt(boosted_model); 40 | vector y_pred_vec = gbdt.predict(model_param, dataset); 41 | 42 | return 0; 43 | } -------------------------------------------------------------------------------- /src/FedTree/metric/metric.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 3 | // 4 | 5 | #include "FedTree/metric/metric.h" 6 | #include "FedTree/metric/pointwise_metric.h" 7 | #include "FedTree/metric/ranking_metric.h" 8 | #include "FedTree/metric/multiclass_metric.h" 9 | 10 | Metric *Metric::create(string name) { 11 | if (name == "map") return new MAP; 12 | if (name == "rmse") return new RMSE; 13 | if (name == "ndcg") return new NDCG; 14 | if (name == "macc") return new MulticlassAccuracy; 15 | if (name == "error") return new BinaryClassMetric; 16 | LOG(FATAL) << "unknown metric " << name; 17 | return nullptr; 18 | } 19 | 20 | void Metric::configure(const GBDTParam ¶m, const DataSet &dataset) { 21 | y.resize(dataset.y.size()); 22 | y.copy_from(dataset.y.data(), dataset.n_instances()); 23 | } 24 | -------------------------------------------------------------------------------- /src/FedTree/metric/multiclass_metric.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 
3 | // 4 | 5 | #include "FedTree/metric/multiclass_metric.h" 6 | //#include "FedTree/util/device_lambda.h" 7 | #include "thrust/reduce.h" 8 | #include "thrust/execution_policy.h" 9 | 10 | using namespace std; 11 | float_type MulticlassAccuracy::get_score(const SyncArray &y_p) const { 12 | CHECK_EQ(num_class * y.size(), y_p.size()) << num_class << " * " << y.size() << " != " << y_p.size(); 13 | int n_instances = y.size(); 14 | auto y_data = y.host_data(); 15 | auto yp_data = y_p.host_data(); 16 | SyncArray is_true(n_instances); 17 | auto is_true_data = is_true.host_data(); 18 | int num_class = this->num_class; 19 | #pragma omp parallel for 20 | for (int i = 0; i < n_instances; i++){ 21 | int max_k = 0; 22 | float_type max_p = yp_data[i]; 23 | for (int k = 1; k < num_class; ++k) { 24 | if (yp_data[k * n_instances + i] > max_p) { 25 | max_p = yp_data[k * n_instances + i]; 26 | max_k = k; 27 | } 28 | } 29 | is_true_data[i] = max_k == y_data[i]; 30 | } 31 | 32 | float acc = thrust::reduce(thrust::host, is_true_data, is_true_data + n_instances) / (float) n_instances; 33 | return acc; 34 | } 35 | 36 | float_type BinaryClassMetric::get_score(const SyncArray &y_p) const { 37 | /* 38 | // compute test error 39 | int n_instances = y.size(); 40 | auto y_data = y.host_data(); 41 | auto yp_data = y_p.host_data(); 42 | SyncArray is_true(n_instances); 43 | auto is_true_data = is_true.host_data(); 44 | #pragma omp parallel for 45 | for (int i = 0; i < n_instances; i++){ 46 | // change the threshold to 0 if the classes are -1 and 1 and using regression as the objective. 47 | int max_k = (yp_data[i] > 0.5) ? 
1 : 0; 48 | is_true_data[i] = max_k == y_data[i]; 49 | } 50 | float acc = thrust::reduce(thrust::host, is_true_data, is_true_data + n_instances) / (float) n_instances; 51 | return 1 - acc; 52 | */ 53 | //compute AUC 54 | int n = y.size(); 55 | int pos = 0; 56 | vector> pl; 57 | auto y_data = y.host_data(); 58 | auto yp_data = y_p.host_data(); 59 | for (int i = 0; i < n; i++) { 60 | pos += y_data[i]; 61 | pl.emplace_back(yp_data[i], y_data[i]); 62 | } 63 | sort(pl.begin(), pl.end()); 64 | double pos_sum = 0; 65 | for (int left = 0, right = 0; right < n; left = right) { 66 | float_type sum = 0, cnt = 0; 67 | while (right < n && pl[right].first == pl[left].first) { 68 | cnt += pl[right++].second; 69 | sum += right + 1; 70 | } 71 | pos_sum += sum * cnt / (right - left); 72 | } 73 | return min((pos_sum - (pos * (pos + 1) / 2)) / (pos * (n - pos)), 1.0); 74 | } 75 | /* 76 | float_type BinaryClassMetric::get_auc(const SyncArray& y_p) { 77 | int n = y.size(); 78 | int pos = 0; 79 | vector> pl; 80 | auto y_data = y.host_data(); 81 | auto yp_data = y_p.host_data(); 82 | for (int i = 0; i < n; i++) { 83 | pos += y_data[i]; 84 | pl.emplace_back(yp_data[i], y_data[i]); 85 | } 86 | sort(pl.begin(), pl.end()); 87 | double pos_sum = 0; 88 | for (int left = 0, right = 0; right < n; left = right) { 89 | double sum = 0, cnt = 0; 90 | while (right < n && pl[right].first == pl[left].first) { 91 | cnt += pl[right++].second; 92 | sum += right + 1; 93 | } 94 | pos_sum += sum * cnt / (right - left); 95 | } 96 | return (pos_sum - (pos * (pos + 1) / 2)) / (pos * (n - pos)); 97 | } 98 | */ 99 | -------------------------------------------------------------------------------- /src/FedTree/metric/pointwise_metric.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 
3 | // 4 | 5 | #define THRUST_IGNORE_DEPRECATED_CPP_DIALECT 6 | 7 | #include "thrust/reduce.h" 8 | #include "thrust/execution_policy.h" 9 | //#include "FedTree/util/device_lambda.h" 10 | #include "FedTree/metric/pointwise_metric.h" 11 | 12 | float_type RMSE::get_score(const SyncArray &y_p) const { 13 | CHECK_EQ(y_p.size(), y.size()); 14 | int n_instances = y_p.size(); 15 | SyncArray sq_err(n_instances); 16 | auto sq_err_data = sq_err.host_data(); 17 | const float_type *y_data = y.host_data(); 18 | const float_type *y_predict_data = y_p.host_data(); 19 | #pragma omp parallel for 20 | for (int i = 0; i < n_instances; i++){ 21 | float_type e = y_predict_data[i] - y_data[i]; 22 | sq_err_data[i] = e * e; 23 | } 24 | float_type rmse = 25 | sqrtf(thrust::reduce(thrust::host, sq_err.host_data(), sq_err.host_end()) / n_instances); 26 | return rmse; 27 | } 28 | 29 | -------------------------------------------------------------------------------- /src/FedTree/metric/rank_metric.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/11/20. 
//

// NOTE(review): the include target was lost in this dump (presumably
// "FedTree/metric/ranking_metric.h") — restore before compiling.
#include
#ifdef __unix__
#include "parallel/algorithm"
#endif

// Mean of per-query scores: slices y / y_p by the group pointer `gptr` and
// averages eval_query_group() over the n_group queries.
// NOTE(review): SyncArray/vector element types were stripped from this dump.
float_type RankListMetric::get_score(const SyncArray &y_p) const {
    TIMED_FUNC(obj);
    float_type sum_score = 0;
    auto y_data0 = y.host_data();
    auto yp_data0 = y_p.host_data();
#pragma omp parallel for schedule(static) reduction(+:sum_score)
    for (int k = 0; k < n_group; ++k) {
        int group_start = gptr[k];
        int len = gptr[k + 1] - group_start;
        // copy this query's labels/predictions into thread-private buffers
        vector query_y(len);
        vector query_yp(len);
        memcpy(query_y.data(), y_data0 + group_start, len * sizeof(float_type));
        memcpy(query_yp.data(), yp_data0 + group_start, len * sizeof(float_type));
        sum_score += this->eval_query_group(query_y, query_yp, k);
    }
    return sum_score / n_group;
}

void RankListMetric::configure(const GBDTParam &param, const DataSet &dataset) {
    Metric::configure(param, dataset);

    //init gptr
    n_group = dataset.group.size();
    configure_gptr(dataset.group, gptr);

    //TODO parse from param
    // NOTE(review): the <int> of numeric_limits was likely lost in this dump
    topn = (std::numeric_limits::max)();
}

// Exclusive prefix sum of per-group sizes: gptr[k] is the index of the first
// instance of group k; gptr.back() is the total instance count.
void RankListMetric::configure_gptr(const vector &group, vector &gptr) {
    gptr = vector(group.size() + 1, 0);
    for (int i = 1; i < gptr.size(); ++i) {
        gptr[i] = gptr[i - 1] + group[i - 1];
    }
}

// Average precision of one query: documents ranked by descending prediction;
// labels != 0 count as relevant.
float_type MAP::eval_query_group(vector &y, vector &y_p, int group_id) const {
    auto y_data = y.data();
    auto yp_data = y_p.data();
    int len = y.size();
    vector idx(len);
    for (int i = 0; i < len; ++i) {
        idx[i] = i;
    }
#ifdef __unix__
    __gnu_parallel::sort(idx.begin(), idx.end(), [=](int a, int b) { return yp_data[a] > yp_data[b]; });
#else
    std::sort(idx.begin(), idx.end(), [=](int a, int b) { return yp_data[a] > yp_data[b]; });
#endif
    int nhits = 0;
    double sum_ap = 0;
    for (int i = 0; i < len; ++i) {
        if (y_data[idx[i]] != 0) {
            nhits++;
            if (i < topn) {
                sum_ap += (double) nhits / (i + 1);  // precision@rank for each hit in the top-n
            }
        }
    }

    // a query with no relevant documents scores a perfect 1
    if (nhits != 0)
        return sum_ap / nhits;
    else return 1;
}

void NDCG::configure(const GBDTParam &param, const DataSet &dataset) {
    RankListMetric::configure(param, dataset);
    get_IDCG(gptr, dataset.y, idcg);   // per-group ideal DCG, used as the normaliser
}

// DCG of the predicted ordering divided by the ideal DCG of this group.
float_type NDCG::eval_query_group(vector &y, vector &y_p, int group_id) const {
    CHECK_EQ(y.size(), y_p.size());
    if (idcg[group_id] == 0) return 1;   // no gain possible: treat as perfect
    int len = y.size();
    vector idx(len);
    for (int i = 0; i < len; ++i) {
        idx[i] = i;
    }
    auto label = y.data();
    auto score = y_p.data();
#ifdef __unix__
    __gnu_parallel::sort(idx.begin(), idx.end(), [=](int a, int b) { return score[a] > score[b]; });
#else
    std::sort(idx.begin(), idx.end(), [=](int a, int b) { return score[a] > score[b]; });
#endif

    float_type dcg = 0;
    for (int i = 0; i < len; ++i) {
        // NOTE(review): static_cast target (likely int) was lost in this dump
        dcg += discounted_gain(static_cast(label[idx[i]]), i);
    }
    return dcg / idcg[group_id];
}

// Ideal DCG per group: sort each group's labels descending and accumulate the
// discounted gains in that ideal order.
void NDCG::get_IDCG(const vector &gptr, const vector &y, vector &idcg) {
    int n_group = gptr.size() - 1;
    idcg.clear();
    idcg.resize(n_group);
    //calculate IDCG
#pragma omp parallel for schedule(static)
    for (int k = 0; k < n_group; ++k) {
        int group_start = gptr[k];
        int len = gptr[k + 1] - group_start;
        vector sorted_label(len);
        memcpy(sorted_label.data(), y.data() + group_start, len * sizeof(float_type));
#ifdef __unix__
        __gnu_parallel::sort(sorted_label.begin(), sorted_label.end(), std::greater());
#else
        std::sort(sorted_label.begin(), sorted_label.end(), std::greater());
#endif
        for (int i = 0; i < sorted_label.size(); ++i) {
            //assume labels are int
            idcg[k] += NDCG::discounted_gain(static_cast(sorted_label[i]), i);
        }
    }
}
--------------------------------------------------------------------------------
/src/FedTree/objective/multiclass_obj.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/15/20. 3 | // 4 | 5 | 6 | #include "FedTree/objective/multiclass_obj.h" 7 | 8 | void Softmax::configure(GBDTParam param, const DataSet &dataset) { 9 | constant_h = param.constant_h; 10 | num_class = param.num_class; 11 | label.resize(num_class); 12 | CHECK_EQ(dataset.label.size(), num_class)< &y) { 74 | auto yp_data = y.host_data(); 75 | int num_class = this->num_class; 76 | int n_instances = y.size() / num_class; 77 | #pragma omp parallel for 78 | for(int i = 0; i < n_instances; i++){ 79 | float_type max = yp_data[i]; 80 | for (int k = 1; k < num_class; ++k) { 81 | max = fmaxf(max, yp_data[k * n_instances + i]); 82 | } 83 | float_type sum = 0; 84 | for (int k = 0; k < num_class; ++k) { 85 | //-max to avoid numerical issue 86 | yp_data[k * n_instances + i] = expf(yp_data[k * n_instances + i] - max); 87 | sum += yp_data[k * n_instances + i]; 88 | } 89 | for (int k = 0; k < num_class; ++k) { 90 | yp_data[k * n_instances + i] /= sum; 91 | } 92 | } 93 | } 94 | 95 | -------------------------------------------------------------------------------- /src/FedTree/objective/objective_function.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by liqinbin on 12/15/20. 
//

// NOTE(review): the include target was lost in this dump — restore before compiling.
#include
#include "FedTree/objective/regression_obj.h"
#include "FedTree/objective/multiclass_obj.h"
#include "FedTree/objective/ranking_obj.h"

// Factory: map an objective name to a freshly allocated objective object.
// The caller owns (and must delete) the returned pointer.
// NOTE(review): template-argument lists were stripped from this dump — the
// two RegressionObj lines presumably carried different loss template
// arguments (square loss vs. logistic loss); confirm against regression_obj.h.
ObjectiveFunction *ObjectiveFunction::create(string name) {
    if (name == "reg:linear") return new RegressionObj;
    if (name == "reg:logistic") return new RegressionObj;
    if (name == "binary:logistic") return new LogClsObj;
    if (name == "multi:softprob") return new SoftmaxProb;
    if (name == "multi:softmax") return new Softmax;
    if (name == "rank:pairwise") return new LambdaRank;
    if (name == "rank:ndcg") return new LambdaRankNDCG;
    LOG(FATAL) << "undefined objective " << name;
    return nullptr;   // unreachable after LOG(FATAL); kept to satisfy the signature
}

// Ranking objectives require the query-group file to be loaded.
bool ObjectiveFunction::need_load_group_file(string name) {
    return name == "rank:ndcg" || name == "rank:pairwise";
}

// Classification objectives require labels grouped/remapped to class ids.
bool ObjectiveFunction::need_group_label(string name) {
    return name == "multi:softprob" || name == "multi:softmax" || name == "binary:logistic";
}
--------------------------------------------------------------------------------
/src/FedTree/objective/ranking_obj.cpp:
--------------------------------------------------------------------------------
//
// Created by liqinbin on 12/15/20.
//

// NOTE(review): two include targets were lost in this dump; the code below
// uses std::sort (<algorithm>) and std::mt19937 / uniform_int_distribution
// (<random>) — restore before compiling.
#include
#include "FedTree/metric/ranking_metric.h"
//#ifndef _WIN32
//#include
//#endif
#include

// Set up the pairwise ranking objective: record the constant-hessian option,
// the sigmoid scale, and the per-query group pointer gptr.
void LambdaRank::configure(GBDTParam param, const DataSet &dataset) {
    constant_h = param.constant_h;
    sigma = 1;

    //init gptr
    n_group = dataset.group.size();
    RankListMetric::configure_gptr(dataset.group, gptr);
    CHECK_EQ(gptr.back(), dataset.n_instances());
}

// LambdaRank gradients: within each query group, every document is paired
// with one randomly sampled document of a different label; each pair adds
// +/-lambda (and a hessian) to the higher/lower-labelled document.
// NOTE(review): SyncArray/vector element types were stripped from this dump.
void
LambdaRank::get_gradient(const SyncArray &y, const SyncArray &y_p, SyncArray &gh_pair) {
    TIMED_FUNC(obj);
    {
        // zero the accumulator before scattering pairwise contributions
        auto gh_data = gh_pair.host_data();
#pragma omp parallel for schedule(static)
        for (int i = 0; i < gh_pair.size(); ++i) {
            gh_data[i] = 0;
        }
    }
    GHPair *gh_data0 = gh_pair.host_data();
    const float_type *score0 = y_p.host_data();
    const float_type *label_data0 = y.host_data();
    PERFORMANCE_CHECKPOINT_WITH_ID(obj, "copy and init");
#pragma omp parallel for schedule(static)
    for (int k = 0; k < n_group; ++k) {
        int group_start = gptr[k];
        int len = gptr[k + 1] - group_start;
        GHPair *gh_data = gh_data0 + group_start;
        const float_type *score = score0 + group_start;
        const float_type *label_data = label_data0 + group_start;
        // rank this group's documents by descending prediction score
        vector idx(len);
        for (int i = 0; i < len; ++i) { idx[i] = i; }
        std::sort(idx.begin(), idx.end(), [=](int a, int b) { return score[a] > score[b]; });
        vector> label_idx(len);
        for (int i = 0; i < len; ++i) {
            label_idx[i].first = label_data[idx[i]];
            label_idx[i].second = idx[i];
        }
        // sort by label DESCENDING (the original comment said "ascending",
        // but the comparator is a.first > b.first)
        std::sort(label_idx.begin(), label_idx.end(),
                  [](std::pair a, std::pair b) { return a.first > b.first; });

        std::mt19937 gen(std::rand());
        for (int i = 0; i < len; ++i) {
            // [i, j) is a bucket of documents sharing the same label
            int j = i + 1;
            while (j < len && label_idx[i].first == label_idx[j].first) j++;
            int nleft = i;          // documents with a strictly higher label
            int nright = len - j;   // documents with a strictly lower label
            //if not all are same label
            if (nleft + nright != 0) {
                // bucket in [i,j), get a sample outside bucket
                std::uniform_int_distribution<> dis(0, nleft + nright - 1);
                for (int l = i; l < j; ++l) {
                    int m = dis(gen);
                    int high = 0;
                    int low = 0;
                    if (m < nleft) {
                        high = m;   // sampled partner carries the higher label
                        low = l;
                    } else {
                        high = l;   // current document carries the higher label
                        low = m + j - i;   // shift the sample past the bucket
                    }
                    float_type high_label = label_idx[high].first;
                    float_type low_label = label_idx[low].first;
                    int high_idx = label_idx[high].second;
                    int low_idx = label_idx[low].second;

                    float_type abs_delta_z = fabsf(get_delta_z(high_label, low_label, high, low, k));
                    // rho = sigmoid(-(s_high - s_low))
                    float_type rhoIJ = 1 / (1 + expf((score[high_idx] - score[low_idx])));
                    float_type lambda = -abs_delta_z * rhoIJ;
                    float_type hessian = constant_h == 0 ? 2 * fmaxf(abs_delta_z * rhoIJ * (1 - rhoIJ), 1e-16f) : constant_h;
                    gh_data[high_idx] = gh_data[high_idx] + GHPair(lambda, hessian);
                    gh_data[low_idx] = gh_data[low_idx] + GHPair(-lambda, hessian);
                }
            }
            // NOTE(review): combined with the loop's ++i this makes the next
            // bucket start at j + 1, so the document at index j is never a
            // bucket member itself — confirm this skip is intended.
            i = j;
        }
    }
//    y_p.to_device();
}

string LambdaRank::default_metric_name() { return "map"; }

void LambdaRankNDCG::configure(GBDTParam param, const DataSet &dataset) {
    constant_h = param.constant_h;
    LambdaRank::configure(param, dataset);
    NDCG::get_IDCG(gptr, dataset.y, idcg);   // ideal DCG per group for delta-NDCG
}


// |delta NDCG| of swapping the documents at rankI / rankJ (labels labelI /
// labelJ), normalised by the group's ideal DCG. Returns 0 when the group has
// no ideal gain at all.
float_type
LambdaRankNDCG::get_delta_z(float_type labelI, float_type labelJ, int rankI, int rankJ, int group_id) {
    if (idcg[group_id] == 0) return 0;
    // NOTE(review): static_cast targets (likely int) were lost in this dump
    float_type dgI1 = NDCG::discounted_gain(static_cast(labelI), rankI);
    float_type dgJ1 = NDCG::discounted_gain(static_cast(labelJ), rankJ);
    float_type dgI2 = NDCG::discounted_gain(static_cast(labelI), rankJ);
    float_type dgJ2 = NDCG::discounted_gain(static_cast(labelJ), rankI);
    return (dgI1 + dgJ1 - dgI2 - dgJ2) / idcg[group_id];
}

string LambdaRankNDCG::default_metric_name() { return "ndcg"; }
--------------------------------------------------------------------------------
/src/FedTree/predictor.cpp:
--------------------------------------------------------------------------------
//
// Created by Kelly Yung on 2020/12/3. Code taken reference from ThunderGBM/predictor.cu
//

#include "FedTree/predictor.h"
//#include "FedTree/util/device_lambda.h"
#include "FedTree/objective/objective_function.h"

// Run every tree of the boosted model over every instance of dataSet and
// accumulate learning-rate-scaled leaf weights into y_predict, laid out
// class-major: y_predict[t * n_instances + iid].
// NOTE(review): template-argument lists were stripped from this dump
// (boosted_model is presumably vector<vector<Tree>>; the SyncArrays below
// carry Tree::TreeNode / int / float_type payloads) — confirm against
// predictor.h before compiling.
void Predictor::get_y_predict (const GBDTParam &model_param, const vector> &boosted_model,
                               const DataSet &dataSet, SyncArray &y_predict) {
    int n_instances = dataSet.n_instances();
    int n_features = dataSet.n_features();

    //the whole model to an array
    int num_iter = boosted_model.size();
    int num_class = boosted_model.front().size();
    // assumes every tree has the same node count as tree [0][0]
    int num_node = boosted_model[0][0].nodes.size();
    int total_num_node = num_iter * num_class * num_node;
    y_predict.resize(n_instances * num_class);

    // flatten the model iteration-major then class-major: tree (iter, t)
    // starts at offset (iter * num_class + t) * num_node
    SyncArray model(total_num_node);
    auto model_data = model.host_data();
    int tree_cnt = 0;
    for (auto &vtree:boosted_model) {
        for (auto &t:vtree) {
            memcpy(model_data + num_node * tree_cnt, t.nodes.host_data(), sizeof(Tree::TreeNode) * num_node);
            tree_cnt++;
        }
    }

    // copy the dataset's CSR representation into SyncArrays
    SyncArray csr_col_idx(dataSet.csr_col_idx.size());
    SyncArray csr_val(dataSet.csr_val.size());
    SyncArray csr_row_ptr(dataSet.csr_row_ptr.size());
    csr_col_idx.copy_from(dataSet.csr_col_idx.data(), dataSet.csr_col_idx.size());
    csr_val.copy_from(dataSet.csr_val.data(), dataSet.csr_val.size());
    csr_row_ptr.copy_from(dataSet.csr_row_ptr.data(), dataSet.csr_row_ptr.size());

    //do prediction
    auto model_host_data = model.host_data();
    auto predict_data = y_predict.host_data();
    auto csr_col_idx_data = csr_col_idx.host_data();
    auto csr_val_data = csr_val.host_data();
    auto csr_row_ptr_data = csr_row_ptr.host_data();
    auto lr = model_param.learning_rate;

    //use sparse format and binary search
    for (int iid = 0; iid < n_instances; iid++) {

        // left child when the feature value is below the split, else right
        auto get_next_child = [&](Tree::TreeNode node, float_type feaValue) {
            return feaValue < node.split_value ? node.lch_index : node.rch_index;
        };

        // binary-search a feature id inside one CSR row; sets *is_missing
        // when the instance has no value for that feature
        auto get_val = [&](const int *row_idx, const float_type *row_val, int row_len, int idx,
                           bool *is_missing) -> float_type {
            //binary search to get feature value
            const int *left = row_idx;
            const int *right = row_idx + row_len;

            while (left != right) {
                const int *mid = left + (right - left) / 2;
                if (*mid == idx) {
                    *is_missing = false;
                    return row_val[mid - row_idx];
                }
                if (*mid > idx)
                    right = mid;
                else left = mid + 1;
            }
            *is_missing = true;
            return 0;
        };

        // this instance's slice of the CSR arrays
        int *col_idx = csr_col_idx_data + csr_row_ptr_data[iid];
        float_type *row_val = csr_val_data + csr_row_ptr_data[iid];
        int row_len = csr_row_ptr_data[iid + 1] - csr_row_ptr_data[iid];
        for (int t = 0; t < num_class; t++) {
            auto predict_data_class = predict_data + t * n_instances;
            float_type sum = 0;
            for (int iter = 0; iter < num_iter; iter++) {
                const Tree::TreeNode *node_data = model_host_data + iter * num_class * num_node + t * num_node;
                Tree::TreeNode curNode = node_data[0];
                int cur_nid = 0; //node id
                // walk from the root to a leaf, following the default
                // direction when the split feature is missing
                while (!curNode.is_leaf) {
                    int fid = curNode.split_feature_id;
                    bool is_missing;
                    float_type fval = get_val(col_idx, row_val, row_len, fid, &is_missing);
                    if (!is_missing)
                        cur_nid = get_next_child(curNode, fval);
                    else if (curNode.default_right)
                        cur_nid = curNode.rch_index;
                    else
                        cur_nid = curNode.lch_index;
                    curNode = node_data[cur_nid];
                }
                sum += lr * node_data[cur_nid].base_weight;
            }
            predict_data_class[iid] += sum;
        }
    }
}
--------------------------------------------------------------------------------
/src/FedTree/util/common.cpp:
--------------------------------------------------------------------------------
//
// Created by liqinbin on 10/14/20.
// ThunderGBM common.cpp: https://github.com/Xtra-Computing/thundergbm/blob/master/src/thundergbm/util/common.cpp
// Under Apache-2.0 license
// copyright (c) 2020 jiashuai
//

#include "FedTree/common.h"
INITIALIZE_EASYLOGGINGPP

// Pretty-print an (int, float) thrust tuple as "d/f" for logging output.
std::ostream &operator<<(std::ostream &os, const int_float &rhs) {
    os << string_format("%d/%f", thrust::get<0>(rhs), thrust::get<1>(rhs));
    return os;
}
--------------------------------------------------------------------------------
/src/test/CMakeLists.txt:
--------------------------------------------------------------------------------
# Build the unit-test binary; gtest lives in a vendored submodule.
add_subdirectory(googletest)

include_directories(googletest/googletest/include)
include_directories(googletest/googlemock/include)

file(GLOB TEST_SRC *)

if (USE_CUDA)
    cuda_add_executable(${PROJECT_NAME}-test ${TEST_SRC} ${COMMON_INCLUDES})
else ()
    add_executable(${PROJECT_NAME}-test ${TEST_SRC} ${COMMON_INCLUDES})
endif ()
target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME} gtest)
--------------------------------------------------------------------------------
/src/test/test_dataset.cpp:
--------------------------------------------------------------------------------
//
// Created by liqinbin on 10/19/20.
//
#include "gtest/gtest.h"
#include "FedTree/dataset.h"

// Fixture: loads the bundled test dataset in CSR form.
class DatasetTest : public ::testing::Test {
public:
    FLParam fl_param;
    DataSet dataset;
protected:
    void SetUp() override {
        dataset.load_from_file(DATASET_DIR "test_dataset.txt", fl_param);
    }
};

// Fixture: loads the same dataset in CSC form with 119 features.
class DatasetLoadCscTest : public ::testing::Test {
public:
    FLParam fl_param;
    DataSet dataset;
protected:
    void SetUp() override {
        dataset.load_csc_from_file(DATASET_DIR "test_dataset.txt", fl_param, 119);
    }
};

// Known shape/content of test_dataset.txt: 1605 instances, 119 features.
TEST_F(DatasetTest, load_from_file){
    printf("### Dataset: test_dataset.txt, num_instances: %zu, num_features: %zu, get_cut_points finished. ###\n",
           dataset.n_instances(),
           dataset.n_features());
    EXPECT_EQ(dataset.n_instances(), 1605);
    EXPECT_EQ(dataset.n_features_, 119);
    EXPECT_EQ(dataset.label[0], -1);
    EXPECT_EQ(dataset.csr_val[1], 1);
}

TEST_F(DatasetLoadCscTest, load_csc_from_file){
    printf("### Dataset: test_dataset.txt, num_instances: %zu, num_features: %zu, get_cut_points finished. ###\n",
           dataset.n_instances(),
           dataset.n_features());
    EXPECT_EQ(dataset.n_instances(), 1605);
    EXPECT_EQ(dataset.n_features_, 119);
    EXPECT_EQ(dataset.label[0], -1);
    EXPECT_EQ(dataset.csc_val[1], 1);
}

//TEST(DatasetTest, load_dataset){
//    DataSet dataset;
//    FLParams param;
//    dataset.load_from_file(DATASET_DIR "test_dataset.txt", param);
//    printf("### Dataset: %s, num_instances: %d, num_features: %d, get_cut_points finished. ###\n",
//           param.path.c_str(),
//           dataset.n_instances(),
//           dataset.n_features());
//    EXPECT_EQ(dataset.n_instances(), 1605);
//    EXPECT_EQ(dataset.n_features_, 119);
//    EXPECT_EQ(dataset.label[0], -1);
//    EXPECT_EQ(dataset.csr_val[1], 1);
//}
--------------------------------------------------------------------------------
/src/test/test_find_feature_range.cpp:
--------------------------------------------------------------------------------
//
// Created by Kelly Yung on 2021/2/25.
//

#include "FedTree/FL/party.h"
#include "gtest/gtest.h"
// NOTE(review): the include target was lost in this dump — restore it.
#include


// Fixture: installs a small hand-built CSC matrix (6 columns) directly into
// a Party's dataset so feature ranges can be checked without file I/O.
// NOTE(review): vector element types were stripped from this dump.
class FeatureRangeTest: public ::testing::Test {
public:

    vector csc_val = {10, 20, 30, 50, 40, 60, 70, 80};
    vector csc_row_idx = {0, 0, 1, 1, 2, 2, 2, 4};
    vector csc_col_ptr = {0, 1, 3, 4, 6, 7, 8};
    Party p;

protected:
    void SetUp() override {
        p.dataset.csc_row_idx = csc_row_idx;
        p.dataset.csc_col_ptr = csc_col_ptr;
        p.dataset.csc_val = csc_val;
    }
};

// Feature 0 holds the single value 10, so min == max == 10.
TEST_F(FeatureRangeTest, find_feature_range_by_index_single_value){
    vector result = p.get_feature_range_by_feature_index(0);
    LOG(INFO) << "Result:" << result;
    EXPECT_EQ(result[0], 10);
    EXPECT_EQ(result[1], 10);
}

// Feature 1 holds {20, 30}: range is [20, 30].
TEST_F(FeatureRangeTest, find_feature_range_by_index_multi_value){
    vector result = p.get_feature_range_by_feature_index(1);
    EXPECT_EQ(result[0], 20);
    EXPECT_EQ(result[1], 30);
}

// The last feature holds the single value 80.
TEST_F(FeatureRangeTest, find_feature_range_by_last_index){
    vector result = p.get_feature_range_by_feature_index(5);
    EXPECT_EQ(result[0], 80);
    EXPECT_EQ(result[1], 80);
}
--------------------------------------------------------------------------------
/src/test/test_gbdt.cpp:
--------------------------------------------------------------------------------
////
//// Created by Kelly Yung on 2020/12/8.
3 | //// 4 | // 5 | //#include "gtest/gtest.h" 6 | //#include "FedTree/FL/FLtrainer.h" 7 | //#include "FedTree/predictor.h" 8 | //#include "FedTree/dataset.h" 9 | //#include "FedTree/Tree/tree.h" 10 | // 11 | //class GBDTTest: public ::testing::Test { 12 | //public: 13 | // GBDTParam param; 14 | // FLParam flParam; 15 | // 16 | //protected: 17 | // void SetUp() override { 18 | // // set GBDTParam 19 | // param.depth = 0; 20 | // param.n_trees = 5; 21 | // param.n_device = 1; 22 | // param.min_child_weight = 1; 23 | // param.lambda = 1; 24 | // param.gamma = 1; 25 | // param.rt_eps = 1e-6; 26 | // param.max_num_bin = 255; 27 | // param.verbose = false; 28 | // param.profiling = false; 29 | // param.column_sampling_rate = 1; 30 | // param.bagging = false; 31 | // param.n_parallel_trees = 1; 32 | // param.learning_rate = 1; 33 | // param.objective = "reg:linear"; 34 | // param.num_class = 1; 35 | // param.path = "../dataset/test_dataset.txt"; 36 | // param.tree_method = "hist"; 37 | // param.tree_per_rounds = 1; 38 | // if (!param.verbose) { 39 | // el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false"); 40 | // el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false"); 41 | // el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True"); 42 | // } 43 | // if (!param.profiling) { 44 | // el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false"); 45 | // } 46 | // // set FLParam 47 | // flParam.gbdt_param = param; 48 | // flParam.n_parties = 1; 49 | // flParam.mode = "horizontal"; 50 | // flParam.privacy_tech = "none"; 51 | // } 52 | //}; 53 | // 54 | //// test find split 55 | //TEST_F (GBDTTest, test_find_split) { 56 | // 57 | // SyncArray sp; 58 | // SyncArray best_idx_gain; 59 | // // assume first level of tree 60 | // int n_nodes_in_level = 1; 61 | // int nid_offset = 0; 62 | // HistCut cut; 63 | // SyncArray hist; 64 | // 65 | // 
// construct tree with single node 66 | // Tree tree; 67 | // SyncArray nodes; 68 | // Tree::TreeNode node; 69 | // node.isValid = true; 70 | // auto nodesArray = nodes.host_data(); 71 | // nodesArray[0] = node; 72 | // 73 | // // instantiate HistCut and num_bins 74 | // HisCut cut; 75 | // SyncArray cut_points_val; 76 | // 77 | // SyncArray cut_row_ptr; 78 | // SyncArray cut_fid; 79 | // int n_bins = 2; 80 | // SyncArray hist; 81 | // 82 | // //test find_split 83 | // find_split(sp, n_nodes_in_level, tree, best_idx_gain, nid_offset, cut, hist, n_bins); 84 | // 85 | // // verify 86 | // EXPECT_EQ(sp.size(), n_nodes_in_level); 87 | //} 88 | // 89 | //// test update_tree 90 | //TEST_F (GBDTTest, test_update_tree) { 91 | // 92 | //} 93 | // 94 | // 95 | //// test predictor 96 | //TEST_F(GBDTTest, test_predictor) { 97 | // vector> boosted_model; 98 | // // construct a vector of vector of tree! 99 | // Tree tree; 100 | // SyncArray nodes; 101 | // 102 | // Tree::TreeNode node; 103 | // 104 | // //test 105 | // DataSet test_dataset; 106 | // test_dataset.load_from_file(flParam.gbdt_param.path, flParam); 107 | // Predictor predictor; 108 | // SyncArray y_predict; 109 | // predictor.get_y_predict(param, boosted_model, test_dataset, y_predict); 110 | //} -------------------------------------------------------------------------------- /src/test/test_gradient.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Kelly Yung on 2020/12/9. 
//

#include "gtest/gtest.h"
#include "FedTree/objective/regression_obj.h"
#include "FedTree/dataset.h"
#include "FedTree/syncarray.h"

// Fixture: a standard single-party horizontal configuration used by the
// gradient tests below.
class GradientTest: public ::testing::Test {
public:
    FLParam fl_param;
    GBDTParam param;
protected:
    void SetUp() override {
        param.depth = 6;
        param.n_trees = 40;
        param.n_device = 1;
        param.min_child_weight = 1;
        param.lambda = 1;
        param.gamma = 1;
        param.rt_eps = 1e-6;
        param.max_num_bin = 255;
        param.verbose = false;
        param.profiling = false;
        param.column_sampling_rate = 1;
        param.bagging = false;
        param.n_parallel_trees = 1;
        param.learning_rate = 1;
        param.objective = "reg:linear";
        param.num_class = 2;
        param.path = "../dataset/test_dataset.txt";
        param.tree_method = "hist";
        // silence debug/trace logging unless explicitly requested
        if (!param.verbose) {
            el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
            el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
            el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True");
        }
        if (!param.profiling) {
            el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
        }
        fl_param.gbdt_param = param;
        fl_param.privacy_tech = "none";
        fl_param.mode = "horizontal";
        fl_param.n_parties = 1;
    }
};

// Test regression function
// Squared loss with y_true = 0 and y_pred = -i gives gradient -i, hessian 1.
// NOTE(review): SyncArray element types were stripped from this dump.
TEST_F(GradientTest, test_regression_obj) {
    DataSet dataset;
    dataset.load_from_file(param.path, fl_param);
    RegressionObj rmse;
    SyncArray y_true(4);
    SyncArray y_pred(4);
    auto y_pred_data = y_pred.host_data();
    for(int i = 0; i < 4; i++)
        y_pred_data[i] = -i;
    SyncArray gh_pair(4);
    EXPECT_EQ(rmse.default_metric_name(), "rmse");
    rmse.get_gradient(y_true, y_pred, gh_pair);
    auto gh_pair_data = gh_pair.host_data();
    EXPECT_EQ(gh_pair_data[0], GHPair(0.0, 1.0));
    EXPECT_EQ(gh_pair_data[1], GHPair(-1.0, 1.0));
    EXPECT_EQ(gh_pair_data[2], GHPair(-2.0, 1.0));
    EXPECT_EQ(gh_pair_data[3], GHPair(-3.0, 1.0));
}

// NOTE(review): the expected pairs here are identical to the squared-loss
// gradients above, but logistic-loss gradients of these inputs would differ
// (sigmoid(yp) - y with hessian sigmoid*(1-sigmoid)) — confirm LogClsObj's
// intended behaviour before relying on these expectations.
TEST_F(GradientTest, test_logcls_obj) {
    DataSet dataset;
    dataset.load_from_file(param.path, fl_param);
    LogClsObj logcls;
    SyncArray y_true(4);
    SyncArray y_pred(4);
    auto y_pred_data = y_pred.host_data();
    for(int i = 0; i < 4; i++)
        y_pred_data[i] = -i;
    SyncArray gh_pair(4);
    EXPECT_EQ(logcls.default_metric_name(), "error");
    logcls.get_gradient(y_true, y_pred, gh_pair);
    std::cout << gh_pair;
    auto gh_pair_data = gh_pair.host_data();
    EXPECT_EQ(gh_pair_data[0], GHPair(0.0, 1.0));
    EXPECT_EQ(gh_pair_data[1], GHPair(-1.0, 1.0));
    EXPECT_EQ(gh_pair_data[2], GHPair(-2.0, 1.0));
    EXPECT_EQ(gh_pair_data[3], GHPair(-3.0, 1.0));
}
--------------------------------------------------------------------------------
/src/test/test_main.cpp:
--------------------------------------------------------------------------------
//
// Created by liqinbin on 10/19/20.
//

#include "FedTree/parser.h"
#include "gtest/gtest.h"

// Test entry point: configure easylogging++ format/flags, then run gtest.
int main(int argc, char **argv) {
    ::testing::InitGoogleTest(&argc, argv);
    el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Format, "%datetime %level %fbase:%line : %msg");
    el::Loggers::addFlag(el::LoggingFlag::ColoredTerminalOutput);
    el::Loggers::addFlag(el::LoggingFlag::FixedTimeFormat);
//    el::Loggers::reconfigureAllLoggers(el::ConfigurationType::Enabled, "false");
    return RUN_ALL_TESTS();
}
--------------------------------------------------------------------------------
/src/test/test_parser.cpp:
--------------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "FedTree/parser.h"
#include "FedTree/dataset.h"
#include "FedTree/Tree/tree.h"

// Fixture: a horizontal 2-party HE configuration over the bundled dataset.
// NOTE(review): vector element types were stripped from this dump.
class ParserTest: public ::testing::Test {
public:
    FLParam fl_param;
    GBDTParam param;
    vector csr_val;
    vector csr_row_ptr;
    vector csr_col_idx;
    vector y;
    size_t n_features_;
    vector label;
protected:
    void SetUp() override {
        fl_param.mode = "horizontal";
        fl_param.n_parties = 2;
        fl_param.privacy_tech = "he";

        param.depth = 6;
        param.n_trees = 40;
        param.n_device = 1;
        param.min_child_weight = 1;
        param.lambda = 1;
        param.gamma = 1;
        param.rt_eps = 1e-6;
        param.max_num_bin = 255;
        param.verbose = false;
        param.profiling = false;
        param.column_sampling_rate = 1;
        param.bagging = false;
        param.n_parallel_trees = 1;
        param.learning_rate = 1;
        param.objective = "reg:linear";
        param.num_class = 1;
        param.path = DATASET_DIR "test_dataset.txt";
        param.tree_method = "auto";
        // silence debug/trace logging unless explicitly requested
        if (!param.verbose) {
            el::Loggers::reconfigureAllLoggers(el::Level::Debug, el::ConfigurationType::Enabled, "false");
            el::Loggers::reconfigureAllLoggers(el::Level::Trace, el::ConfigurationType::Enabled, "false");
            el::Loggers::reconfigureAllLoggers(el::Level::Info, el::ConfigurationType::Enabled, "True");
        }
        if (!param.profiling) {
            el::Loggers::reconfigureAllLoggers(el::ConfigurationType::PerformanceTracking, "false");
        }
        fl_param.gbdt_param = param;
    }
};

// Sanity-check that SetUp's parameters survived the copy into fl_param.
TEST_F(ParserTest, test_parser){
    EXPECT_EQ(fl_param.gbdt_param.depth, 6);
    EXPECT_EQ(fl_param.gbdt_param.gamma, 1);
    EXPECT_EQ(fl_param.gbdt_param.learning_rate, 1);
    EXPECT_EQ(fl_param.gbdt_param.num_class, 1);
    EXPECT_EQ(fl_param.gbdt_param.tree_method, "auto");
    EXPECT_EQ(fl_param.gbdt_param.max_num_bin, 255);
}

// Round-trip part 1: saving an empty model must not crash.
TEST_F(ParserTest, test_save_model) {
    string model_path = "tgbm.model";
    vector> boosted_model;
    DataSet dataset;
    dataset.load_from_file(fl_param.gbdt_param.path, fl_param);
    Parser parser;
    parser.save_model(model_path, fl_param.gbdt_param, boosted_model, dataset);
}

// Round-trip part 2: loading the model saved above yields an empty ensemble.
TEST_F(ParserTest, test_load_model) {
    string model_path = "tgbm.model";
    vector> boosted_model;
    DataSet dataset;
    dataset.load_from_file(fl_param.gbdt_param.path, fl_param);
    Parser parser;
    parser.load_model(model_path, fl_param.gbdt_param, boosted_model, dataset);
    // the size of he boosted model should be zero
    EXPECT_EQ(boosted_model.size(), 0);
}
--------------------------------------------------------------------------------
/src/test/test_partition.cpp:
--------------------------------------------------------------------------------
//
// Created by hanyuxuan on 22/10/20.
3 | // 4 | 5 | #include "gtest/gtest.h" 6 | #include "FedTree/dataset.h" 7 | #include "FedTree/FL/FLparam.h" 8 | #include "FedTree/FL/partition.h" 9 | 10 | 11 | class PartitionTest : public ::testing::Test { 12 | public: 13 | FLParam fl_param; 14 | DataSet dataset; 15 | protected: 16 | void SetUp() override { 17 | dataset.load_from_file(DATASET_DIR 18 | "test_dataset.txt", fl_param); 19 | } 20 | }; 21 | 22 | TEST_F(PartitionTest, homo_partition) { 23 | printf("### Dataset: test_dataset.txt, num_instances: %zu, num_features: %zu, get_cut_points finished. ###\n", 24 | dataset.n_instances(), 25 | dataset.n_features()); 26 | EXPECT_EQ(dataset.n_instances(), 1605); 27 | EXPECT_EQ(dataset.n_features_, 119); 28 | EXPECT_EQ(dataset.label[0], -1); 29 | EXPECT_EQ(dataset.csr_val[1], 1); 30 | 31 | printf("### Test Partition ###\n"); 32 | Partition partition; 33 | // TODO 34 | // std::map> batch_idxs = partition.homo_partition(dataset, 5, true); 35 | // EXPECT_EQ(batch_idxs.size(), 5); 36 | // int count = 0; 37 | // for (auto const &x : batch_idxs) EXPECT_EQ(x.second.size(), 1605/5); 38 | } 39 | 40 | TEST_F(PartitionTest, hetero_partition) { 41 | printf("### Dataset: test_dataset.txt, num_instances: %zu, num_features: %zu, get_cut_points finished. 
###\n", 42 | dataset.n_instances(), 43 | dataset.n_features()); 44 | EXPECT_EQ(dataset.n_instances(), 1605); 45 | EXPECT_EQ(dataset.n_features_, 119); 46 | EXPECT_EQ(dataset.label[0], -1); 47 | EXPECT_EQ(dataset.csr_val[1], 1); 48 | 49 | printf("### Test Partition ###\n"); 50 | Partition partition; 51 | // TODO: test values of subsets 52 | // vector subsets(5); 53 | // std::map> batch_idxs = partition.hetero_partition(dataset, 5, false, subsets); 54 | // EXPECT_EQ(batch_idxs.size(), 5); 55 | // int count = 0; 56 | // for (auto const &x : batch_idxs) count += x.second.size(); 57 | // EXPECT_EQ(count, dataset.n_features_); 58 | } 59 | -------------------------------------------------------------------------------- /src/test/test_tree.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xtra-Computing/FedTree/ba4ea5eb4a75742a8b1bc986868e3d310260dc97/src/test/test_tree.cpp -------------------------------------------------------------------------------- /src/test/test_tree_builder.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by hanyuxuan on 28/10/20. 
3 | // 4 | 5 | #include "gtest/gtest.h" 6 | #include "FedTree/Tree/tree_builder.h" 7 | #include "FedTree/Tree/hist_tree_builder.h" 8 | #include "FedTree/Tree/hist_cut.h" 9 | #include "FedTree/Tree/GBDTparam.h" 10 | #include 11 | 12 | class TreeBuilderTest : public ::testing::Test { 13 | public: 14 | 15 | GHPair father; 16 | GHPair lch; 17 | GHPair rch; 18 | HistTreeBuilder treeBuilder; 19 | 20 | protected: 21 | void SetUp() override { 22 | father = GHPair(5.1, 5); 23 | lch = GHPair(-2.1, 2); 24 | rch = GHPair(9.61, 3); 25 | } 26 | }; 27 | 28 | //TEST_F(TreeBuilderTest, compute_gain) { 29 | // EXPECT_FLOAT_EQ(treeBuilder.compute_gain_in_a_node(father, lch, rch, -5, 0.1), 26.791); 30 | //} 31 | 32 | TEST_F(TreeBuilderTest, gain_per_level) { 33 | SyncArray gradients(4); 34 | const vector gradients_vec = {GHPair(0.1, 0.2), GHPair(0.3, 0.4), GHPair(0.5, 0.6), GHPair(0.7, 0.8)}; 35 | gradients.copy_from(&gradients_vec[0], gradients_vec.size()); 36 | HistTreeBuilder htb; 37 | Tree tree; 38 | GBDTParam param = GBDTParam(); 39 | param.depth = 2; 40 | param.min_child_weight = 0.0; 41 | param.lambda = 0.2; 42 | tree.init_CPU(gradients, param); 43 | SyncArray hist(2); 44 | const vector hist_vec = {GHPair(0.2, 0.2), GHPair(0.5, 0.5)}; 45 | hist.copy_from(&hist_vec[0], hist_vec.size()); 46 | auto result = htb.gain(tree, hist, 0, 2); 47 | EXPECT_EQ(result.size(), 2); 48 | EXPECT_FLOAT_EQ(result.host_data()[0], 0); 49 | EXPECT_FLOAT_EQ(result.host_data()[1], 0); 50 | } 51 | 52 | TEST_F(TreeBuilderTest, compute_histogram) { 53 | printf("### Test compute_histogram ###\n"); 54 | int n_instances = 4; 55 | int n_columns = 2; 56 | 57 | SyncArray gradients(4); 58 | const vector gradients_vec = {GHPair(0.1, 0.2), GHPair(0.3, 0.4), GHPair(0.5, 0.6), GHPair(0.7, 0.8)}; 59 | gradients.copy_from(&gradients_vec[0], gradients_vec.size()); 60 | 61 | HistCut cut; 62 | cut.cut_col_ptr = SyncArray(3); 63 | const vector cut_col_ptr_vec = {0, 1, 3}; 64 | cut.cut_col_ptr.copy_from(&cut_col_ptr_vec[0], 
cut_col_ptr_vec.size()); 65 | 66 | SyncArray dense_bin_id(8); 67 | const vector bin_vec = {0, 0, 0, 1, 1, 1, 1, 2}; 68 | dense_bin_id.copy_from(&bin_vec[0], bin_vec.size()); 69 | 70 | SyncArray hist(5); 71 | HistTreeBuilder htb; 72 | htb.compute_histogram_in_a_node(gradients, cut, dense_bin_id); 73 | hist.copy_from(htb.get_hist()); 74 | auto hist_data = hist.host_data(); 75 | EXPECT_NEAR(hist_data[0].g, 0.4, 1e-5); 76 | EXPECT_NEAR(hist_data[0].h, 0.6, 1e-5); 77 | EXPECT_NEAR(hist_data[1].g, 1.2, 1e-5); 78 | EXPECT_NEAR(hist_data[1].h, 1.4, 1e-5); 79 | EXPECT_NEAR(hist_data[2].g, 0.1, 1e-5); 80 | EXPECT_NEAR(hist_data[2].h, 0.2, 1e-5); 81 | EXPECT_NEAR(hist_data[3].g, 0.8, 1e-5); 82 | EXPECT_NEAR(hist_data[3].h, 1.0, 1e-5); 83 | EXPECT_NEAR(hist_data[4].g, 0.7, 1e-5); 84 | EXPECT_NEAR(hist_data[4].h, 0.8, 1e-5); 85 | 86 | // vector histogram = TreeBuilder().compute_histogram(gradients, splits); 87 | // EXPECT_EQ(histogram.size(), splits.size() + 1); 88 | // EXPECT_NEAR(histogram[0], 0.1, 1e-5); 89 | // EXPECT_NEAR(histogram[1], 0.3, 1e-5); 90 | // EXPECT_NEAR(histogram[2], 0.4, 1e-5); 91 | } 92 | 93 | TEST_F(TreeBuilderTest, merge_histogram_server) { 94 | printf("### Test merge_histogram ###\n"); 95 | 96 | const vector hist1_vec = {GHPair(0.1, 0.2), GHPair(0.3, 0.4), GHPair(0.5, 0.6)}; 97 | const vector hist2_vec = {GHPair(0.11, 0.22), GHPair(0.33, 0.44), GHPair(0.55, 0.66)}; 98 | 99 | MSyncArray hists(2, 3); 100 | hists[0].copy_from(&hist1_vec[0], hist1_vec.size()); 101 | hists[1].copy_from(&hist2_vec[0], hist2_vec.size()); 102 | 103 | SyncArray merged_hist(3); 104 | HistTreeBuilder htb; 105 | htb.parties_hist_init(2); 106 | htb.append_hist(hists[0]); 107 | htb.append_hist(hists[1]); 108 | htb.merge_histograms_server_propose(merged_hist, merged_hist); 109 | merged_hist.copy_from(htb.get_hist()); 110 | auto hist_data = merged_hist.host_data(); 111 | EXPECT_NEAR(hist_data[0].g, 0.21, 1e-5); 112 | EXPECT_NEAR(hist_data[0].h, 0.42, 1e-5); 113 | 
EXPECT_NEAR(hist_data[1].g, 0.63, 1e-5); 114 | EXPECT_NEAR(hist_data[1].h, 0.84, 1e-5); 115 | EXPECT_NEAR(hist_data[2].g, 1.05, 1e-5); 116 | EXPECT_NEAR(hist_data[2].h, 1.26, 1e-5); 117 | } 118 | 119 | /* 120 | TEST_F(TreeBuilderTest, merge_histogram_clients) { 121 | printf("### Test merge_histogram clients###\n"); 122 | 123 | vector hist1_vec; 124 | vector hist2_vec; 125 | for (int i = 0; i < 14; i++) { 126 | hist1_vec.push_back(GHPair(1, 1)); 127 | hist2_vec.push_back(GHPair(1, 1)); 128 | } 129 | 130 | MSyncArray hists(2, 14); 131 | hists[0].copy_from(&hist1_vec[0], hist1_vec.size()); 132 | hists[1].copy_from(&hist2_vec[0], hist2_vec.size()); 133 | 134 | 135 | const vector cut_points_val_vec1 = {0.1, 0.3, 5, 7, 9, 15, 25, 35, 10, 11}; 136 | const vector cut_ptr_vec1 = {0, 2, 5, 8, 10}; 137 | const vector cut_points_val_vec2 = {0.4, 0.5, 0.6, 4, 8, 30, 50, 9, 12, 15}; 138 | const vector cut_ptr_vec2 = {0, 3, 5, 7, 10}; 139 | 140 | vector cuts(2); 141 | cuts[0].cut_col_ptr = SyncArray(5); 142 | cuts[0].cut_col_ptr.copy_from(&cut_ptr_vec1[0], cut_ptr_vec1.size()); 143 | cuts[0].cut_points_val = SyncArray(10); 144 | cuts[0].cut_points_val.copy_from(&cut_points_val_vec1[0], cut_points_val_vec1.size()); 145 | cuts[1].cut_col_ptr = SyncArray(5); 146 | cuts[1].cut_col_ptr.copy_from(&cut_ptr_vec2[0], cut_ptr_vec2.size()); 147 | cuts[1].cut_points_val = SyncArray(10); 148 | cuts[1].cut_points_val.copy_from(&cut_points_val_vec2[0], cut_points_val_vec2.size()); 149 | 150 | // HistTreeBuilder htb; 151 | // EXPECT_FLOAT_EQ(htb.merge_histograms_client_propose(hists, cuts)[0], -0.1); 152 | // EXPECT_FLOAT_EQ(htb.merge_histograms_client_propose(hists, cuts)[7], 0.6); 153 | // EXPECT_FLOAT_EQ(htb.merge_histograms_client_propose(hists, cuts)[8], 0); 154 | // EXPECT_FLOAT_EQ(htb.merge_histograms_client_propose(hists, cuts)[9], 2); 155 | 156 | SyncArray merged_hist(31); 157 | HistTreeBuilder htb; 158 | htb.merge_histograms_client_propose(hists, cuts, false); 159 | 
# cut dataset for FedTree horizontal distributed
import argparse
import random
import os
import copy


def _clean_line(raw):
    # Remove the trailing newline, an optional Windows carriage return, and
    # trailing spaces.  (The previous "line[:-1]" slicing dropped the last
    # character of a final line that has no terminating newline.)
    return raw.rstrip('\n').rstrip('\r').rstrip(' ')


def main(args):
    """Split a libsvm-format dataset horizontally (by instance).

    Writes ``<prefix>_h_<n>_<i>`` for each party ``i``; with ``-fate`` it
    also writes ``<prefix>_FATE_h_<n>_<i>`` CSV shards whose rows are
    ``instance_id,label,feat:val,...`` with 0-based feature ids.
    """
    # ---- pass 1: dataset statistics -------------------------------------
    total_features = 0          # max feature index seen (libsvm is 1-based)
    size = 0                    # number of non-empty instances
    label_set = set()
    with open(args.input_file, 'r') as fin:
        for raw in fin:
            line = _clean_line(raw)
            if not line:
                continue        # tolerate blank lines instead of crashing
            ele = line.split(' ')
            label_set.add(int(ele[0]))
            if len(ele) > 1:    # guard: a label-only line has no features
                total_features = max(total_features, int(ele[-1].split(':')[0]))
            size += 1

    # Label map for FATE: binary classification -> {0, 1}; otherwise keep
    # the labels as-is (treated as regression).
    label_dict = {}
    if len(label_set) == 2:  # classification
        tmp_label_list = sorted(label_set)
        label_dict = {str(item): str(i) for i, item in enumerate(tmp_label_list)}
    else:  # regression
        label_dict = {str(item): str(item) for item in label_set}

    # Deterministically shuffle instance ids, then carve them into
    # near-equal contiguous ranges (the first size % n parties receive one
    # extra instance each).
    l = [i for i in range(size)]
    random.seed(42)
    random.shuffle(l)
    BLOCK = size // args.num_parties
    offset = [i * BLOCK + min(i, size % args.num_parties) for i in range(args.num_parties)] + [size]

    # Sort each party's ids so the sequential sweep below can match them in
    # increasing global order with a single cursor per party.
    for i in range(args.num_parties):
        l[offset[i]:offset[i + 1]] = sorted(l[offset[i]:offset[i + 1]])

    # ---- pass 2: FedTree shards -----------------------------------------
    with open(args.input_file, 'r') as fin:
        fouts = [open(args.prefix + "_h_{}_{}".format(args.num_parties, i), 'w')
                 for i in range(args.num_parties)]
        glob_idx = 0
        p_off = copy.deepcopy(offset)   # per-party cursor into l
        for raw in fin:
            line = _clean_line(raw)
            if not line:
                continue
            for i in range(args.num_parties):
                if p_off[i] < offset[i + 1] and l[p_off[i]] == glob_idx:
                    fouts[i].write(line + "\n")
                    p_off[i] += 1
            glob_idx += 1
        for item in fouts:
            item.close()

    if not args.fate:
        return

    # ---- pass 3 (optional): FATE CSV shards -----------------------------
    with open(args.input_file, 'r') as fin:
        fouts = [open(args.prefix + "_FATE_h_{}_{}".format(args.num_parties, i), 'w')
                 for i in range(args.num_parties)]
        glob_idx = 0
        p_off = copy.deepcopy(offset)
        for raw in fin:
            line = _clean_line(raw)
            if not line:
                continue
            for i in range(args.num_parties):
                if p_off[i] < offset[i + 1] and l[p_off[i]] == glob_idx:
                    ele = line.split(' ')
                    # Row: instance id, mapped label, then 0-based feat:val.
                    f_ele = [str(glob_idx), label_dict[str(int(ele[0]))]]
                    f_ele += ["{}:{}".format(-1 + int(item.split(':')[0]), item.split(':')[1])
                              for item in ele[1:]]
                    fouts[i].write(','.join(f_ele) + '\n')
                    p_off[i] += 1
            glob_idx += 1
        for item in fouts:
            item.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--num_parties",
        type=int
    )
    parser.add_argument(
        "-f",
        "--input_file",
        type=str
    )
    parser.add_argument(
        "-p",
        "--prefix",
        type=str
    )
    parser.add_argument(
        "-fate",
        action='store_true'
    )

    args = parser.parse_args()

    main(args)
# cut dataset for FedTree vertical distributed and FATE hetero boost

import argparse
import random


def _clean_line(raw):
    # Remove the trailing newline, an optional Windows carriage return, and
    # trailing spaces.  (The previous "line[:-1]" slicing dropped the last
    # character of a final line that has no terminating newline.)
    return raw.rstrip('\n').rstrip('\r').rstrip(' ')


def main(args):
    """Split a libsvm-format dataset vertically (by feature).

    Outputs, per party ``i``:
      ``<prefix>_v_<n>_<i>``       FedTree shard (1-based local feature ids)
      ``<prefix>_FATE_v_<n>_<i>``  with ``-fate``: ``id,label,feat:val`` CSV
    and, when ``-t/--test_file`` is given, a single combined
    ``<prefix>_v_<n>_test`` file with feature ids re-numbered into the
    concatenated global layout.
    """
    random.seed(42)

    # ---- pass 1: largest 1-based feature index in the dataset -----------
    num_features = 0
    with open(args.input_file, 'r') as f:
        for raw in f:
            line = _clean_line(raw)
            if not line:
                continue        # tolerate blank lines instead of crashing
            l = line.split(' ')
            if len(l) > 1:      # guard: a label-only line has no features
                num_features = max(num_features, int(l[-1].split(':')[0]))

    print(num_features)

    # Shuffle feature ids and give each party a contiguous slice; the last
    # party absorbs the division remainder.
    features = [i + 1 for i in range(num_features)]
    random.shuffle(features)
    BLOCK = num_features // args.num_parties
    offset = [i * BLOCK for i in range(args.num_parties)] + [num_features]

    id2party = {}   # global feature id -> owning party
    id2local = {}   # global feature id -> 0-based slot within that party
    for i in range(args.num_parties):
        features[offset[i]: offset[i + 1]] = sorted(features[offset[i]: offset[i + 1]])
        for j in range(offset[i], offset[i + 1]):
            id2party[features[j]] = i
            id2local[features[j]] = j - offset[i]
    print(features)

    # ---- pass 2: FedTree vertical shards (every party keeps the label) ---
    with open(args.input_file, 'r') as f:
        outfs = [open(args.prefix + "_v_" + str(args.num_parties) + "_" + str(i), 'w')
                 for i in range(args.num_parties)]
        for raw in f:
            line = _clean_line(raw)
            if not line:
                continue
            l = line.split(' ')
            outls = [[l[0]] for _ in range(args.num_parties)]
            for item in l[1:]:
                k, v = item.split(":")
                outls[id2party[int(k)]].append("{}:{}".format(id2local[int(k)] + 1, v))
            for i in range(args.num_parties):
                outfs[i].write(" ".join(outls[i]) + '\n')
        for item in outfs:
            item.close()

    if args.test_file:
        # Combined test file: shift each local id by its party's offset so
        # all parties' columns coexist in one global index space.
        with open(args.test_file, 'r') as f:
            with open(args.prefix + "_v_" + str(args.num_parties) + "_test", 'w') as out_tfs:
                for raw in f:
                    line = _clean_line(raw)
                    if not line:
                        continue
                    l = line.split(' ')
                    outls = [[] for _ in range(args.num_parties)]
                    for item in l[1:]:
                        k, v = item.split(":")
                        outls[id2party[int(k)]].append(
                            "{}:{}".format(id2local[int(k)] + 1 + offset[id2party[int(k)]], v))
                    tmpl = [l[0]]
                    for i in range(args.num_parties):
                        tmpl.extend(outls[i])
                    out_tfs.write(" ".join(tmpl) + '\n')

    if args.fate:
        # FATE hetero-boost CSV shards: "instance_id,label,feat:val,..."
        # with 0-based local feature ids.
        with open(args.input_file, 'r') as f:
            outfs = [open(args.prefix + "_FATE_v_" + str(args.num_parties) + "_" + str(i), 'w')
                     for i in range(args.num_parties)]
            instances = 0
            for raw in f:
                line = _clean_line(raw)
                if not line:
                    continue
                l = line.split(' ')
                outls = [[str(instances), l[0]] for _ in range(args.num_parties)]
                for item in l[1:]:
                    k, v = item.split(":")
                    outls[id2party[int(k)]].append("{}:{}".format(id2local[int(k)], v))
                for i in range(args.num_parties):
                    outfs[i].write(",".join(outls[i]) + '\n')
                instances += 1
            for item in outfs:
                item.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--num_parties",
        type=int
    )
    parser.add_argument(
        "-f",
        "--input_file",
        type=str
    )
    parser.add_argument(
        "-t",
        "--test_file",
        type=str
    )
    parser.add_argument(
        "-p",
        "--prefix",
        type=str
    )
    parser.add_argument(
        "-fate",
        action='store_true'
    )
    args = parser.parse_args()

    main(args)