├── README.md
├── app
│   ├── lr
│   │   ├── Makefile
│   │   ├── run_lr.py
│   │   └── lr.cpp
│   └── svm
│       ├── Makefile
│       ├── run_svm.py
│       └── svm.cpp
├── partitiondata.sh
├── partitiondata.c
└── src
    ├── comm
    │   ├── include_comm.hpp
    │   ├── protocol.hpp
    │   └── Comm.hpp
    ├── util
    │   ├── include_util.hpp
    │   ├── assist_func.hpp
    │   └── vector_operation.hpp
    ├── model
    │   ├── include_model.hpp
    │   ├── Model.hpp
    │   ├── LRModel.hpp
    │   └── SVMModel.hpp
    ├── trainer
    │   ├── include_trainer.hpp
    │   ├── Trainer.hpp
    │   ├── Coordinator.hpp
    │   ├── Server.hpp
    │   └── Worker.hpp
    ├── storage
    │   ├── include_storage.hpp
    │   ├── DataPoint.hpp
    │   ├── Gradient.hpp
    │   ├── Parameter.hpp
    │   └── DataSet.hpp
    └── include_ps.hpp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LIBBLE-PS
LIBBLE-PS is a library for big data machine learning.

Please visit http://www.libble.ml for more details.

--------------------------------------------------------------------------------
/app/lr/Makefile:
--------------------------------------------------------------------------------
COMPILE_ARG = -std=c++11 -lpthread -D_GLIBCXX_USE_NANOSLEEP -O3 -march=native

lr: lr.cpp
	mpic++ lr.cpp -o lr $(COMPILE_ARG)

clean:
	rm lr

--------------------------------------------------------------------------------
/app/svm/Makefile:
--------------------------------------------------------------------------------
COMPILE_ARG = -std=c++11 -lpthread -D_GLIBCXX_USE_NANOSLEEP -O3 -march=native

svm: svm.cpp
	mpic++ svm.cpp -o svm $(COMPILE_ARG)

clean:
	rm svm

--------------------------------------------------------------------------------
/partitiondata.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Usage: ./partitiondata.sh [file] [num] [host]
# [file] must be an absolute path
# Warning: this overwrites the directory named [file]"_"
# Warning: the host file should not contain the local machine itself
# [num] is the number of data partitions; it must be greater than the number of LIBBLE-PS workers
rm -r $1"_"
mkdir $1"_"
g++ partitiondata.c -o partitiondata -std=c++11
echo "partition data ..."
./partitiondata $1 $2
pssh -h $3 rm -r $1"_"
str=${1%/*}
echo "distribute data ..."
pscp -h $3 -r $1"_" $str
--------------------------------------------------------------------------------
/src/comm/include_comm.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _INCLUDE_COMM_HPP_
#define _INCLUDE_COMM_HPP_

#include "Comm.hpp"

#endif

--------------------------------------------------------------------------------
/src/util/include_util.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _INCLUDE_UTIL_HPP_
#define _INCLUDE_UTIL_HPP_

#include "assist_func.hpp"
#include "vector_operation.hpp"

#endif

--------------------------------------------------------------------------------
/src/model/include_model.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _INCLUDE_MODEL_HPP_
#define _INCLUDE_MODEL_HPP_

#include "LRModel.hpp"
#include "SVMModel.hpp"
#include "Model.hpp"

#endif

--------------------------------------------------------------------------------
/src/trainer/include_trainer.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _INCLUDE_TRAINER_HPP_
#define _INCLUDE_TRAINER_HPP_

#include "Coordinator.hpp"
#include "Server.hpp"
#include "Trainer.hpp"
#include "Worker.hpp"

#endif
--------------------------------------------------------------------------------
/src/storage/include_storage.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _INCLUDE_STORAGE_HPP_
#define _INCLUDE_STORAGE_HPP_

#include "DataPoint.hpp"
#include "DataSet.hpp"
#include "Gradient.hpp"
#include "Parameter.hpp"

#endif

--------------------------------------------------------------------------------
/src/include_ps.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _INCLUDE_PS_HPP_
#define _INCLUDE_PS_HPP_

#include "comm/include_comm.hpp"
#include "model/include_model.hpp"
#include "storage/include_storage.hpp"
#include "trainer/include_trainer.hpp"
#include "util/include_util.hpp"

#endif
--------------------------------------------------------------------------------
/src/storage/DataPoint.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _DATA_POINT_HPP_
#define _DATA_POINT_HPP_

#include <vector>
#include "../util/include_util.hpp"

/* one sparse example: a label plus parallel index/value arrays */
class DataPoint {
public:
    double label;
    std::vector<int> key;
    std::vector<double> value;

    DataPoint() {}

    ~DataPoint() {}

    DataPoint(const DataPoint &d) = delete;

    DataPoint &operator=(const DataPoint &d) = delete;
};

#endif

--------------------------------------------------------------------------------
/app/lr/run_lr.py:
--------------------------------------------------------------------------------
#!/usr/bin/python

import os
import sys

#--------------------------------------
# modify these arguments to run the program
machine_file = ' mach '    # hosts file
test_data_file = ' null '  # test data path; 'null' means no test file
train_data_file = ' /data/webspam_wc_normalized_trigram.svm '  # train data path; the program reads the partition files in the directory 'train_data_file'+'_/'
n_cols = ' 16609143 '  # number of features in the train data
n_rows = ' 350000 '    # number of instances in the train data
n_servers = ' 2 '
n_workers = ' 16 '
n_epoches = ' 1 '
n_iters = ' 10 '
rate = ' 1 '        # step size
lam = ' 0.0001 '    # regularization hyperparameter
param_init = ' 0 '  # parameter initialization: 0 -- all zero, 1 -- uniform random in [0,1]
#--------------------------------------

n_trainers = str(1 + int(n_servers) + int(n_workers))

os.system('mpirun -n ' + n_trainers + ' -f ' + machine_file + ' ./lr ' + ' -n_servers ' + n_servers
          + ' -n_workers ' + n_workers + ' -n_epoches ' + n_epoches + ' -n_iters ' + n_iters
          + ' -n_cols ' + n_cols + ' -n_rows ' + n_rows + ' -test_data_file ' + test_data_file
          + ' -train_data_file ' + train_data_file + ' -rate ' + rate + ' -lambda ' + lam
          + ' -param_init ' + param_init)
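With the settings above, run_lr.py launches 1 + n_servers + n_workers = 19 MPI processes; the assembled command has the form `mpirun -n 19 -f mach ./lr -n_servers 2 -n_workers 16 -n_epoches 1 -n_iters 10 -n_cols 16609143 -n_rows 350000 -test_data_file null -train_data_file /data/webspam_wc_normalized_trigram.svm -rate 1 -lambda 0.0001 -param_init 0`. Note that `-f` is the hosts-file flag of MPICH-style mpirun; under Open MPI the equivalent flag is `--hostfile`.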
--------------------------------------------------------------------------------
/app/svm/run_svm.py:
--------------------------------------------------------------------------------
#!/usr/bin/python

import os
import sys

#--------------------------------------
# modify these arguments to run the program
machine_file = ' mach '    # hosts file
test_data_file = ' /data/rcv1_test.binary '    # test data path; 'null' means no test file
train_data_file = ' /data/rcv1_train.binary '  # train data path; the program reads the partition files in the directory 'train_data_file'+'_/'
n_cols = ' 47236 '   # number of features in the train data
n_rows = ' 20242 '   # number of instances in the train data
n_servers = ' 2 '
n_workers = ' 16 '
n_epoches = ' 1 '
n_iters = ' 100 '
rate = ' 0.1 '      # step size
lam = ' 0.0001 '    # regularization hyperparameter
param_init = ' 0 '  # parameter initialization: 0 -- all zero, 1 -- uniform random in [0,1]
#--------------------------------------

n_trainers = str(1 + int(n_servers) + int(n_workers))

os.system('mpirun -n ' + n_trainers + ' -f ' + machine_file + ' ./svm ' + ' -n_servers ' + n_servers
          + ' -n_workers ' + n_workers + ' -n_epoches ' + n_epoches + ' -n_iters ' + n_iters
          + ' -n_cols ' + n_cols + ' -n_rows ' + n_rows + ' -test_data_file ' + test_data_file
          + ' -train_data_file ' + train_data_file + ' -rate ' + rate + ' -lambda ' + lam
          + ' -param_init ' + param_init)

--------------------------------------------------------------------------------
/partitiondata.c:
--------------------------------------------------------------------------------
/*********************
g++ **.c -o ** -std=c++11
./partitiondata [filename] [number]
attention: it writes into the directory named [filename]+"_"
**********************/

#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>

int main(int argc, char **argv) {
    if (argc != 3) {
        std::cout << "argv error!\n";
        return -1;
    }

    std::string data_files = argv[1];
    int number = atoi(argv[2]);

    std::ifstream fin(data_files.c_str());
    if (!fin) {
        std::cout << "file open error!\n";
        return -1;
    }

    std::ofstream fout[number];
    for (int i = 0; i < number; i++) {
        // partition files live in the directory created by partitiondata.sh
        std::string part_file = data_files + "_/part" + std::to_string(i);
        fout[i].open(part_file.c_str());
    }

    // distribute the input lines over the partitions round-robin
    std::string buf;
    long count = 0;
    while (getline(fin, buf)) {
        fout[count % number] << buf << "\n";
        count++;
    }

    for (int i = 0; i < number; i++) fout[i].close();
    fin.close();

    return 0;
}
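A concrete run of the partitioning pipeline (paths hypothetical): `./partitiondata.sh /data/rcv1_train.binary 17 mach` compiles the partitioner, splits the input line by line into `/data/rcv1_train.binary_/part0` ... `part16`, and copies the `_` directory to every host in `mach` via pscp. Worker i (1-based) then reads part i, part i+W, part i+2W, ... for W workers, which is why the partition count must be greater than the worker count.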
--------------------------------------------------------------------------------
/src/storage/Gradient.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _GRADIENT_HPP_
#define _GRADIENT_HPP_

#include <vector>
#include "../util/include_util.hpp"

class Gradient_Dense {
public:
    std::vector<double> gradient;

    Gradient_Dense() {}

    Gradient_Dense(const Gradient_Dense &g) = delete;

    Gradient_Dense &operator=(const Gradient_Dense &g) = delete;

    void resize(int s) { gradient.resize(s); }

    void reset() {
        for (auto &x : gradient) x = 0;
    }
};

// sparse gradient, to do
class Gradient_Sparse {
public:
    std::vector<int> key;
    std::vector<double> value;

    void resize(int s) {
        key.resize(s);
        value.resize(s);
    }

    // to do
};

#endif

--------------------------------------------------------------------------------
/src/trainer/Trainer.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _TRAINER_HPP_
#define _TRAINER_HPP_

#include <string>
#include "../comm/include_comm.hpp"
#include "../model/include_model.hpp"
#define PRINT_ITER 10  // for SGD

class Trainer {
protected:
    int num_servers, num_workers;  // number of servers and workers in this system
    int num_cols;                  // number of features
    int num_of_all_data;           // total number of data points
    int num_epoches;               // number of epochs in the training process
    int num_iters;                 // number of iterations in SCOPE (the outer loop)
    std::string data_file;         // file name of the dataset
    Model *model_ptr;
    Comm *comm_ptr;
    int mode;

public:
    Trainer(int n_ser, int n_wor, int n_c, int n_r, int n_e, int n_i, int mode_, std::string f,
            Model *model_p, Comm *comm_p)
        : num_servers(n_ser),
          num_workers(n_wor),
          num_cols(n_c),
          num_of_all_data(n_r),
          num_epoches(n_e),
          num_iters(n_i),
          data_file(f),
          model_ptr(model_p),
          comm_ptr(comm_p),
          mode(mode_) {}

    // this function contains the whole working process of each participant
    virtual void work() = 0;
};

#endif

--------------------------------------------------------------------------------
/src/model/Model.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _MODEL_HPP_
#define _MODEL_HPP_

/*
The class Model is the base class for every
machine learning application implemented in
our PS.
*/

#include <cmath>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "../storage/include_storage.hpp"

class Model {
public:
    Model() {}

    // loss of the local data shard under the current parameters
    virtual double compute_loss(const DataSet &ds, const Parameter &params, const int num_of_all_data,
                                const int num_workers, const double lambda) = 0;

    // this worker's part of the full (batch) gradient
    virtual void compute_full_gradient(const DataSet &ds, const Parameter &params,
                                       Gradient_Dense &g, const int num_of_all_data) = 0;

    // local stochastic updates with variance reduction
    virtual void update(const DataSet &ds, std::uniform_int_distribution<> &u,
                        std::default_random_engine &e, Parameter &params,
                        const Gradient_Dense &full_grad, const double lambda,
                        const int num_epoches, const double rate, const int recover_index,
                        const int num_of_all_data, const int num_workers) = 0;
};

#endif
--------------------------------------------------------------------------------
/src/util/assist_func.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _ASSIST_FUNC_HPP_
#define _ASSIST_FUNC_HPP_

#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#define AT __FILE__ ":" TOSTRING(__LINE__)

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>

void error(const char *loc, const char *msg) {
    std::cout << loc << ": " << msg << std::endl;
    exit(-1);
}

/* parse the argument list: return the position of str in argv, or -1 if absent */
int arg_parser(std::string str, int argc, char **argv) {
    int pos;
    for (pos = 0; pos < argc; pos++) {
        if (str.compare(argv[pos]) == 0) {
            return pos;
        }
    }
    return -1;
}

/* tells how many parameters a certain server holds; server_id is 1-based */
int get_local_params_size(const int &n_cols, const int &n_servers, const int &server_id) {
    int x = n_cols / n_servers, y = n_cols % n_servers;
    if (y > 0) {
        if (server_id <= y)
            return x + 1;
        else
            return x;
    } else {
        return x;
    }
}

void write_file(std::string data_file, std::string info, double loss, double accuracy) {
    std::string output_file = data_file;
    std::ofstream fout(output_file.c_str(), std::ios::out | std::ios::app);
    fout.precision(15);
    fout << info << loss << " " << accuracy << std::endl;
    fout.close();
}

#endif
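As a standalone illustration of how the parameter vector is sharded across servers (the numbers here are made up), the sketch below reuses the logic of get_local_params_size: the first n_cols % n_servers servers each hold one extra parameter.

#include <iostream>

// same splitting rule as get_local_params_size in assist_func.hpp
// (server_id is 1-based, matching the MPI ranks of the servers)
static int shard_size(int n_cols, int n_servers, int server_id) {
    int x = n_cols / n_servers, y = n_cols % n_servers;
    return (y > 0 && server_id <= y) ? x + 1 : x;
}

int main() {
    // e.g. 10 parameters over 3 servers -> 4, 3, 3
    for (int s = 1; s <= 3; s++)
        std::cout << "server " << s << " holds " << shard_size(10, 3, s) << " parameters\n";
    return 0;
}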
--------------------------------------------------------------------------------
/src/util/vector_operation.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _VECTOR_OPERATION_HPP_
#define _VECTOR_OPERATION_HPP_

#include <vector>

// add vec2 onto vec1
void vector_add(std::vector<double> &vec1, const std::vector<double> &vec2) {
    for (int i = 0; i < vec1.size(); i++) vec1[i] += vec2[i];
}

void vector_add(std::vector<int> &vec1, const std::vector<int> &vec2) {
    for (int i = 0; i < vec1.size(); i++) vec1[i] += vec2[i];
}

// vec /= x, elementwise
void vector_divi(std::vector<double> &vec, const double &x) {
    for (int i = 0; i < vec.size(); i++) vec[i] /= x;
}

// vec1 = x * vec1 + y * vec2, elementwise
void vector_multi_add(std::vector<double> &vec1, const double &x, const std::vector<double> &vec2,
                      const double &y) {
    for (int i = 0; i < vec1.size(); i++) vec1[i] = vec1[i] * x + vec2[i] * y;
}

// vec1 = vec1 / x + y * vec2, elementwise
void vector_divi_add(std::vector<double> &vec1, const double &x, const std::vector<double> &vec2,
                     const double &y) {
    for (int i = 0; i < vec1.size(); i++) vec1[i] = vec1[i] / x + vec2[i] * y;
}

// subtract vec2 from vec1
void vector_sub(std::vector<double> &vec1, const std::vector<double> &vec2) {
    for (int i = 0; i < vec1.size(); i++) vec1[i] -= vec2[i];
}

// dot product
double vector_multi(const std::vector<double> &vec1, const std::vector<double> &vec2) {
    double result = 0;
    for (int i = 0; i < vec1.size(); i++) result += vec1[i] * vec2[i];
    return result;
}

#endif
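These helpers are what the lazy parameter updates bottom out in: vector_multi_add(v1, x, v2, y) computes v1 <- x*v1 + y*v2 elementwise, which is exactly how Server::pull_params and the model update() loops materialize the lazily scaled parameters a*w + b*g (see the note after Server.hpp below).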
--------------------------------------------------------------------------------
/src/comm/protocol.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _PROTOCOL_HPP_
#define _PROTOCOL_HPP_

/* this file defines the MPI tag numbers used for sending different kinds of messages */

// from coordinator
#define CW_PARAMS 100     // coordinator sends parameters to worker
#define CW_INFO 101       // coordinator sends info to worker
#define CW_GRAD 102       // coordinator sends full grad to worker
#define CS_INFO 103       // coordinator sends info to server
#define CSWPULL_INFO 104  // coordinator sends pull w_id info to server
#define CSWPUSH_INFO 105  // coordinator sends push w_id info to server

// from server
#define SW_PARAMS 200  // server sends parameters to worker
#define SC_EPOCH 201   // server sends epoch to coordinator to count time
#define SC_PARAMS 202  // server sends parameters to coordinator in the end
#define SW_C 203       // server sends c to worker
#define SW_GRAD 204    // server sends full grad to worker

// from worker
#define WS_GRADS 300   // worker sends gradients to server
#define WC_LOSS 301    // worker sends loss to coordinator
#define WC_GRAD 302    // worker sends its part of the full grad to coordinator
#define WC_PARAMS 303  // worker sends parameters to coordinator
#define WS_PARAMS 304  // worker sends parameters to server
#define WCP_INFO 305   // worker sends pull info to coordinator
#define WCG_INFO 306   // worker sends push info to coordinator
#define WS_C 307       // worker sends c to server
#define WC_ACCU 308    // worker sends accuracy to coordinator

#endif
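Each tag pairs an MPI_Send on one side with an MPI_Recv carrying the same tag on the other. A minimal sketch of the WC_LOSS exchange (the function names here are illustrative, not part of the library; the rank layout of coordinator 0, servers 1..S, workers S+1..S+W follows Comm.hpp):

#include <mpi.h>
#define WC_LOSS 301  // same value as in protocol.hpp

// worker side: report this worker's partial loss to the coordinator (rank 0)
void send_loss_to_coordinator(double partial_loss) {
    MPI_Send(&partial_loss, 1, MPI_DOUBLE, 0, WC_LOSS, MPI_COMM_WORLD);
}

// coordinator side: sum the partial losses over all workers,
// mirroring Comm::C_recv_loss_from_all_W
double recv_total_loss(int n_servers, int n_workers) {
    double total = 0, part = 0;
    for (int w = n_servers + 1; w <= n_servers + n_workers; w++) {
        MPI_Recv(&part, 1, MPI_DOUBLE, w, WC_LOSS, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        total += part;
    }
    return total;
}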
--------------------------------------------------------------------------------
/src/storage/Parameter.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _PARAMETER_HPP_
#define _PARAMETER_HPP_

#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "../util/include_util.hpp"
#include "Gradient.hpp"

class Parameter {
public:
    std::vector<double> parameter;

    Parameter() {}
    Parameter(const std::vector<double> &other_params) { parameter = other_params; }

    Parameter(const Parameter &p) = delete;

    Parameter &operator=(const Parameter &p) = delete;

    void resize(int s) { parameter.resize(s); }

    void reset() {
        for (auto &x : parameter) x = 0;
    }

    void parameter_random_init() {
        std::random_device rd;
        std::default_random_engine e(rd());
        std::uniform_real_distribution<> u(0, 1);
        for (auto &x : parameter) x = u(e);
    }

    // one gradient-descent step: parameter -= rate * gradient
    void subs_gradient(const Gradient_Dense &g, const double &rate) {
        for (int i = 0; i < parameter.size(); i++) {
            parameter[i] -= g.gradient[i] * rate;
        }
    }

    // coordinatewise soft thresholding with threshold z
    void soft_threshold(double z) {
        for (auto &x : parameter) {
            if (x > z)
                x -= z;
            else if (x < -z)
                x += z;
            else
                x = 0;
        }
    }

    // get a slice of parameters in [s, e)
    std::vector<double> slice(int s, int e) {
        std::vector<double> slice_parameter;
        for (int i = s; i < e; i++) {
            slice_parameter.push_back(parameter[i]);
        }

        return slice_parameter;
    }

    std::vector<double> get_parameter() { return parameter; }

    void save_into_file(std::string data_file) {
        std::string output_file = data_file + "_output";
        std::ofstream fout(output_file.c_str());
        for (int i = 0; i < parameter.size(); i++) {
            fout << parameter[i] << " ";
        }
        fout << std::endl;
        fout.close();
    }
};

#endif
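soft_threshold above is the coordinatewise soft-thresholding (proximal) operator of the L1 norm: each entry $x$ is mapped to $\operatorname{sign}(x)\,\max(|x| - z,\, 0)$, shrinking all coordinates toward zero by $z$ and zeroing those smaller than $z$ in magnitude.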
--------------------------------------------------------------------------------
/src/trainer/Coordinator.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _COORDINATOR_HPP_
#define _COORDINATOR_HPP_

#include <chrono>
#include <iostream>
#include <string>
#include <vector>
#include "../util/include_util.hpp"
#include "Trainer.hpp"

class Coordinator : public Trainer {
private:
    char info;
    Parameter params;

public:
    Coordinator(int n_ser, int n_wor, int n_c, int n_r, int n_e, int n_i, int mode_, std::string f,
                Model *model_p, Comm *comm_p)
        : Trainer(n_ser, n_wor, n_c, n_r, n_e, n_i, mode_, f, model_p, comm_p) {
        params.resize(num_cols);
    }

    void work() override {
        std::chrono::duration<double> total_time = std::chrono::duration<double>::zero();
        double i_loss = gather_loss();
        double accuracy = receive_accuracy();
        std::cout.precision(15);
        if (accuracy != -1) {
            std::cout << "[0.000000s] iter 0's loss is " << i_loss
                      << ", accuracy is " << accuracy << std::endl;
        } else {
            std::cout << "[0.000000s] iter 0's loss is " << i_loss << std::endl;
        }
        for (int i = 0; i < num_iters; i++) {
            MPI_Barrier(MPI_COMM_WORLD);  // start
            auto start = std::chrono::steady_clock::now();
            MPI_Barrier(MPI_COMM_WORLD);  // end
            auto end = std::chrono::steady_clock::now();
            std::chrono::duration<double> time = end - start;
            total_time += time;
            double loss = gather_loss();
            double accuracy = receive_accuracy();
            /************ print ***************/
            if (accuracy != -1) {
                std::cout << "[" << total_time.count() << "s] iter " << i + 1 << "'s loss is "
                          << loss << ", accuracy is " << accuracy << std::endl;
            } else {
                std::cout << "[" << total_time.count() << "s] iter " << i + 1 << "'s loss is "
                          << loss << std::endl;
            }
            std::string file = data_file + "_info";
            std::string info = std::to_string(i) + " " + std::to_string(total_time.count()) + " ";
            write_file(file, info, loss, accuracy);
        }

        recv_params_from_servers_and_save();
        // std::cout << "coordinator done" << std::endl;
    }

    // gather loss from workers
    double gather_loss() {
        double loss = comm_ptr->C_recv_loss_from_all_W();
        return loss;
    }

    double receive_accuracy() {
        double accuracy = comm_ptr->C_recv_accuracy_from_W();
        return accuracy;
    }

    // receive parameters from servers and save to file
    void recv_params_from_servers_and_save() {
        comm_ptr->C_recv_params_from_all_S(params);
        // save params to file
        params.save_into_file(data_file);
        std::cout << "Already saved parameters into file." << std::endl;
    }
};

#endif
--------------------------------------------------------------------------------
/src/trainer/Server.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _SERVER_HPP_
#define _SERVER_HPP_

#include <cmath>
#include "../storage/include_storage.hpp"
#include "../util/include_util.hpp"
#include "Trainer.hpp"

class Server : public Trainer {
private:
    char info;
    int server_id;
    Parameter params;
    Gradient_Dense grad;
    double rate;
    double lambda;
    int param_init;
    double MIN = pow(0.1, 290);
    int recover_index = 0;

public:
    Server(int n_ser, int n_wor, int n_c, int n_r, int n_e, int n_i, int mode_, std::string f,
           Model *model_p, Comm *comm_p, int proc_id, double lambda_, double r, int param_i)
        : Trainer(n_ser, n_wor, n_c, n_r, n_e, n_i, mode_, f, model_p, comm_p),
          server_id(proc_id),
          rate(r),
          lambda(lambda_),
          param_init(param_i) {
        int s = get_local_params_size(num_cols, num_servers, server_id);
        params.resize(s);
        grad.resize(s);

        // parameter initialization
        if (param_init == 0)
            params.reset();
        else
            params.parameter_random_init();
    }

    void work() override {
        push();
        // find the first iteration at which the scaling factor
        // a = (1 - rate*lambda)^i would underflow below MIN
        double check_a = 1;
        for (int i = 0; i < num_epoches * (num_of_all_data / num_workers); i++) {
            check_a *= (1 - rate * lambda);
            if (check_a < MIN) {
                recover_index = i;
                break;
            }
        }

        for (int i = 0; i < num_iters; i++) {
            MPI_Barrier(MPI_COMM_WORLD);  // start

            pull_part_full_grad();

            push_full_grad();

            pull_params();

            push();
            MPI_Barrier(MPI_COMM_WORLD);  // end
        }

        send_params_to_coordinator();
        // std::cout << "server " << server_id << " done" << std::endl;
    }

    void push() { comm_ptr->S_send_params_to_all_W(params); }

    void push_full_grad() { comm_ptr->S_send_grads_to_all_W(grad); }

    void pull_part_full_grad() { comm_ptr->S_recv_grads_from_all_W(grad); }

    void pull_params() {
        comm_ptr->S_recv_params_from_all_W(params);
        // replay the scalar recurrence for the steps since the last flush, then
        // materialize params = (a / num_workers) * params + b * grad
        double a = 1, b = 0;
        int steps = num_epoches * (num_of_all_data / num_workers);
        int r_i = (recover_index == 0) ? steps : (steps % recover_index);
        for (int i = 0; i < r_i; i++) {
            a = (1 - lambda * rate) * a;
            b = (1 - lambda * rate) * b - rate;
        }
        vector_multi_add(params.parameter, a / num_workers, grad.gradient, b);
    }

    // send parameters to coordinator for saving
    void send_params_to_coordinator() { comm_ptr->S_send_params_to_C(params); }
};

#endif
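The a/b bookkeeping above (and the matching loops in LRModel::update and SVMModel::update below) is a lazy-scaling trick for the L2 term; the following reading of it is a sketch inferred from the code, not documentation from the authors. One variance-reduced step with step size $\eta$ (rate) and regularizer $\lambda$ has the shape

$$w_{t+1} = (1 - \lambda\eta)\,w_t - \eta\,g_t.$$

Instead of rescaling the whole vector at every step, the iterate is carried as $w_t = a_t\,w + b_t\,\tilde g$ with

$$a_{t+1} = (1 - \lambda\eta)\,a_t, \qquad b_{t+1} = (1 - \lambda\eta)\,b_t - \eta, \qquad a_0 = 1,\; b_0 = 0,$$

so after $T$ steps the dense part collapses into the single vector_multi_add call in pull_params (the extra factor $1/\text{num\_workers}$ averages the parameter vectors pushed by the workers). Since $a_T = (1 - \lambda\eta)^T$ underflows for large $T$, recover_index records the first step at which it would drop below MIN $= 0.1^{290}$, and the update loops flush $a, b$ back to $1, 0$ with that period.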
--------------------------------------------------------------------------------
/src/model/LRModel.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _LRMODEL_HPP_
#define _LRMODEL_HPP_

#include <cmath>
#include "../storage/include_storage.hpp"
#include "Model.hpp"

class LRModel : public Model {
public:
    LRModel() {}

    double compute_loss(const DataSet &ds, const Parameter &params, const int num_of_all_data,
                        const int num_workers, const double lambda) override {
        double loss = 0;
        for (int i = 0; i < ds.num_rows; i++) {
            DataPoint &d = ds.data[i];
            double z = 0;
            for (int j = 0; j < d.key.size(); j++) {
                z += params.parameter[d.key[j]] * d.value[j];
            }
            loss += log(1 + exp(-d.label * z));
        }
        loss /= (double)num_of_all_data;
        // each worker contributes 1/num_workers of the L2 term
        double index = 0.5 * lambda / ((double)num_workers);
        for (int i = 0; i < params.parameter.size(); i++) {
            loss += index * pow(params.parameter[i], 2);
        }
        return loss;
    }

    void compute_full_gradient(const DataSet &ds, const Parameter &params, Gradient_Dense &g,
                               const int num_of_all_data) override {
        g.reset();
        for (int i = 0; i < ds.num_rows; i++) {
            DataPoint &d = ds.data[i];
            double z = 0;
            for (int j = 0; j < d.key.size(); j++) {
                z += params.parameter[d.key[j]] * d.value[j];
            }
            z = -d.label * (1 - 1 / (1 + exp(-d.label * z))) / num_of_all_data;
            for (int j = 0; j < d.key.size(); j++) {
                g.gradient[d.key[j]] += z * d.value[j];
            }
        }
    }

    void update(const DataSet &ds, std::uniform_int_distribution<> &u,
                std::default_random_engine &e, Parameter &params,
                const Gradient_Dense &full_grad, const double lambda,
                const int num_epoches, const double rate, const int recover_index,
                const int num_of_all_data, const int num_workers) override {
        const std::vector<double> old_params = params.parameter;  // snapshot for variance reduction
        double a = 1, b = 0;
        for (int i = 0; i < num_epoches * (num_of_all_data / num_workers); i++) {
            // flush the lazy scalars before a would underflow
            if (recover_index != 0 && i % recover_index == 0) {
                vector_multi_add(params.parameter, a, full_grad.gradient, b);
                a = 1;
                b = 0;
            }
            double z, z1 = 0, z2 = 0;
            const DataPoint &d = ds.data[u(e)];
            for (int j = 0; j < d.key.size(); j++) {
                z1 += (a * params.parameter[d.key[j]] + b * full_grad.gradient[d.key[j]]) *
                      d.value[j];
                z2 += old_params[d.key[j]] * d.value[j];
            }
            b = (1 - lambda * rate) * b - rate;
            a = (1 - lambda * rate) * a;
            z = rate * d.label * (1 / (1 + exp(-d.label * z1)) - 1 / (1 + exp(-d.label * z2))) / a;
            for (int j = 0; j < d.key.size(); j++) {
                params.parameter[d.key[j]] -= z * d.value[j];
            }
        }
    }
};

#endif
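For reference, compute_loss and compute_full_gradient above implement the L2-regularized logistic loss and the data part of its gradient,

$$L(w) = \frac{1}{n}\sum_{i=1}^{n} \log\!\left(1 + e^{-y_i x_i^\top w}\right) + \frac{\lambda}{2}\lVert w\rVert_2^2, \qquad g(w) = -\frac{1}{n}\sum_{i=1}^{n}\left(1 - \frac{1}{1 + e^{-y_i x_i^\top w}}\right) y_i x_i,$$

where each worker sums over its own shard: the regularization term is split as $\tfrac{1}{2}\lambda/\text{num\_workers}$ per worker so that the reported losses add up to $L(w)$, and the $\lambda w$ part of the gradient is not in $g(w)$ but is folded into the $(1 - \lambda\cdot\text{rate})$ scaling of the update.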
--------------------------------------------------------------------------------
/app/lr/lr.cpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>
#include "mpi.h"

#include "../../src/include_ps.hpp"

using namespace std;

int n_servers = 2, n_workers = 2;
int n_epoches = 100;
int n_iters = 10;
int n_cols = 50;  // num of features of a single data point
int n_rows = 1000;
string test_data_file = "";
string data_file = "";
int batch_size = 10;  // the size of the sub-dataset a worker samples for an epoch
double rate = 0.01;
double lambda = 0.0001;
int mode = 1;
int param_init = 0;

int main(int argc, char **argv) {
    int proc_id, num_procs, provided;

    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
    if (provided != MPI_THREAD_MULTIPLE) {
        printf("MPI does not support MPI_THREAD_MULTIPLE\n");
        exit(0);
    }
    MPI_Comm_rank(MPI_COMM_WORLD, &proc_id);
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);

    int pos;
    if ((pos = arg_parser("-n_servers", argc, argv)) > 0) n_servers = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_workers", argc, argv)) > 0) n_workers = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_epoches", argc, argv)) > 0) n_epoches = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_iters", argc, argv)) > 0) n_iters = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_cols", argc, argv)) > 0) n_cols = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_rows", argc, argv)) > 0) n_rows = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-test_data_file", argc, argv)) > 0) test_data_file = argv[pos + 1];
    if ((pos = arg_parser("-train_data_file", argc, argv)) > 0) data_file = argv[pos + 1];
    if ((pos = arg_parser("-batch_size", argc, argv)) > 0) batch_size = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-rate", argc, argv)) > 0) rate = atof(argv[pos + 1]);
    if ((pos = arg_parser("-lambda", argc, argv)) > 0) lambda = atof(argv[pos + 1]);
    if ((pos = arg_parser("-mode", argc, argv)) > 0) mode = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-param_init", argc, argv)) > 0) param_init = atoi(argv[pos + 1]);

    n_cols += 1;  // add the bias term

    assert(n_servers + n_workers + 1 == num_procs);

    Model *model_ptr = new LRModel();
    Comm *comm_ptr = new Comm(n_servers, n_workers, n_cols);

    // rank 0 is the coordinator, ranks 1..n_servers are servers, the rest are workers
    Trainer *trainer_ptr = nullptr;
    if (proc_id == 0) {
        trainer_ptr = new Coordinator(n_servers, n_workers, n_cols, n_rows, n_epoches, n_iters,
                                      mode, data_file, model_ptr, comm_ptr);
    } else if (proc_id <= n_servers) {
        trainer_ptr = new Server(n_servers, n_workers, n_cols, n_rows, n_epoches, n_iters, mode,
                                 data_file, model_ptr, comm_ptr, proc_id, lambda, rate, param_init);
    } else {
        trainer_ptr =
            new Worker(n_servers, n_workers, n_cols, n_rows, n_epoches, n_iters, mode, data_file,
                       model_ptr, comm_ptr, proc_id - n_servers, batch_size, lambda, rate,
                       test_data_file);
    }

    trainer_ptr->work();  // start working

    MPI_Finalize();

    delete model_ptr;
    delete trainer_ptr;
    delete comm_ptr;
    return 0;
}
--------------------------------------------------------------------------------
/app/svm/svm.cpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>
#include "mpi.h"

#include "../../src/include_ps.hpp"

using namespace std;

int n_servers = 2, n_workers = 2;
int n_epoches = 100;
int n_iters = 10;
int n_cols = 50;  // num of features of a single data point
int n_rows = 1000;
string test_data_file = "";
string data_file = "";
int batch_size = 10;  // the size of the sub-dataset a worker samples for an epoch
double rate = 0.01;
double lambda = 0.0001;
int mode = 1;
int param_init = 0;

int main(int argc, char **argv) {
    int proc_id, num_procs, provided;

    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
    if (provided != MPI_THREAD_MULTIPLE) {
        printf("MPI does not support MPI_THREAD_MULTIPLE\n");
        exit(0);
    }
    MPI_Comm_rank(MPI_COMM_WORLD, &proc_id);
    MPI_Comm_size(MPI_COMM_WORLD, &num_procs);

    int pos;
    if ((pos = arg_parser("-n_servers", argc, argv)) > 0) n_servers = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_workers", argc, argv)) > 0) n_workers = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_epoches", argc, argv)) > 0) n_epoches = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_iters", argc, argv)) > 0) n_iters = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_cols", argc, argv)) > 0) n_cols = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-n_rows", argc, argv)) > 0) n_rows = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-test_data_file", argc, argv)) > 0) test_data_file = argv[pos + 1];
    if ((pos = arg_parser("-train_data_file", argc, argv)) > 0) data_file = argv[pos + 1];
    if ((pos = arg_parser("-batch_size", argc, argv)) > 0) batch_size = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-rate", argc, argv)) > 0) rate = atof(argv[pos + 1]);
    if ((pos = arg_parser("-lambda", argc, argv)) > 0) lambda = atof(argv[pos + 1]);
    if ((pos = arg_parser("-mode", argc, argv)) > 0) mode = atoi(argv[pos + 1]);
    if ((pos = arg_parser("-param_init", argc, argv)) > 0) param_init = atoi(argv[pos + 1]);

    n_cols += 1;  // add the bias term

    assert(n_servers + n_workers + 1 == num_procs);

    Model *model_ptr = new SVMModel();
    Comm *comm_ptr = new Comm(n_servers, n_workers, n_cols);

    // rank 0 is the coordinator, ranks 1..n_servers are servers, the rest are workers
    Trainer *trainer_ptr = nullptr;
    if (proc_id == 0) {
        trainer_ptr = new Coordinator(n_servers, n_workers, n_cols, n_rows, n_epoches, n_iters,
                                      mode, data_file, model_ptr, comm_ptr);
    } else if (proc_id <= n_servers) {
        trainer_ptr = new Server(n_servers, n_workers, n_cols, n_rows, n_epoches, n_iters, mode,
                                 data_file, model_ptr, comm_ptr, proc_id, lambda, rate, param_init);
    } else {
        trainer_ptr =
            new Worker(n_servers, n_workers, n_cols, n_rows, n_epoches, n_iters, mode, data_file,
                       model_ptr, comm_ptr, proc_id - n_servers, batch_size, lambda, rate,
                       test_data_file);
    }

    trainer_ptr->work();  // start working

    MPI_Finalize();

    delete model_ptr;
    delete trainer_ptr;
    delete comm_ptr;
    return 0;
}
--------------------------------------------------------------------------------
/src/model/SVMModel.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _SVMMODEL_HPP_
#define _SVMMODEL_HPP_

#include <cmath>
#include "../storage/include_storage.hpp"
#include "Model.hpp"

class SVMModel : public Model {
public:
    SVMModel() {}

    double compute_loss(const DataSet &ds, const Parameter &params, const int num_of_all_data,
                        const int num_workers, const double lambda) override {
        double loss = 0;
        for (int i = 0; i < ds.num_rows; i++) {
            DataPoint &d = ds.data[i];
            double z = 0;
            for (int j = 0; j < d.key.size(); j++) {
                z += params.parameter[d.key[j]] * d.value[j];
            }
            if (d.label * z < 1) {
                loss += (1 - d.label * z);
            }
        }
        loss /= (double)num_of_all_data;
        // each worker contributes 1/num_workers of the L2 term
        double index = 0.5 * lambda / ((double)num_workers);
        for (int i = 0; i < params.parameter.size(); i++) {
            loss += index * pow(params.parameter[i], 2);
        }
        return loss;
    }

    void compute_full_gradient(const DataSet &ds, const Parameter &params, Gradient_Dense &g,
                               const int num_of_all_data) override {
        g.reset();
        for (int i = 0; i < ds.num_rows; i++) {
            DataPoint &d = ds.data[i];
            double z = 0;
            for (int j = 0; j < d.key.size(); j++) {
                z += params.parameter[d.key[j]] * d.value[j];
            }
            if (d.label * z < 1) {
                for (int j = 0; j < d.key.size(); j++) {
                    g.gradient[d.key[j]] -= d.label * d.value[j];
                }
            }
        }
        vector_divi(g.gradient, num_of_all_data);
    }

    void update(const DataSet &ds, std::uniform_int_distribution<> &u,
                std::default_random_engine &e, Parameter &params,
                const Gradient_Dense &full_grad, const double lambda,
                const int num_epoches, const double rate, const int recover_index,
                const int num_of_all_data, const int num_workers) override {
        const std::vector<double> old_params = params.parameter;  // snapshot for variance reduction
        double a = 1, b = 0;
        for (int i = 0; i < num_epoches * (num_of_all_data / num_workers); i++) {
            // flush the lazy scalars before a would underflow
            if (recover_index != 0 && i % recover_index == 0) {
                vector_multi_add(params.parameter, a, full_grad.gradient, b);
                a = 1;
                b = 0;
            }
            double z1 = 0, z2 = 0;
            const DataPoint &d = ds.data[u(e)];
            for (int j = 0; j < d.key.size(); j++) {
                z1 += (a * params.parameter[d.key[j]] + b * full_grad.gradient[d.key[j]]) *
                      d.value[j];
                z2 += old_params[d.key[j]] * d.value[j];
            }
            b = (1 - lambda * rate) * b - rate;
            a = (1 - lambda * rate) * a;

            // the stochastic term and the snapshot correction cancel unless
            // exactly one of the two iterates is inside the margin
            if (d.label * z1 > 1 && d.label * z2 < 1) {
                for (int j = 0; j < d.key.size(); j++) {
                    params.parameter[d.key[j]] -= rate * d.label * d.value[j] / a;
                }
            } else if (d.label * z1 < 1 && d.label * z2 > 1) {
                for (int j = 0; j < d.key.size(); j++) {
                    params.parameter[d.key[j]] += rate * d.label * d.value[j] / a;
                }
            }
        }
    }
};

#endif
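SVMModel is the hinge-loss analogue of LRModel:

$$L(w) = \frac{1}{n}\sum_{i=1}^{n} \max\!\left(0,\, 1 - y_i x_i^\top w\right) + \frac{\lambda}{2}\lVert w\rVert_2^2,$$

with subgradient $-y_i x_i$ when $y_i x_i^\top w < 1$ and $0$ otherwise. In update(), the stochastic term and the snapshot correction are hinge subgradients at the same point, so they cancel unless exactly one of $z_1$ (current iterate) and $z_2$ (snapshot) lies inside the margin, which is what the two branches test.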
--------------------------------------------------------------------------------
/src/trainer/Worker.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _WORKER_HPP_
#define _WORKER_HPP_

#include <cmath>
#include <random>
#include <string>
#include <vector>
#include "../storage/include_storage.hpp"
#include "../util/include_util.hpp"
#include "Trainer.hpp"

class Worker : public Trainer {
private:
    int worker_id;
    int batch_size;
    double lambda;
    char info;
    double rate;
    double MIN = pow(0.1, 290);
    int recover_index = 0;
    DataSet dataset;
    Parameter params;
    Gradient_Dense grad;
    Gradient_Dense full_grad;
    std::string test_data_file;
    DataSet test_dataset;

public:
    Worker(int n_ser, int n_wor, int n_c, int n_r, int n_e, int n_i, int mode_, std::string f,
           Model *model_p, Comm *comm_p, int proc_id, int b_s, double lambda_, double r,
           std::string t_f)
        : Trainer(n_ser, n_wor, n_c, n_r, n_e, n_i, mode_, f, model_p, comm_p),
          worker_id(proc_id),
          batch_size(b_s),
          lambda(lambda_),
          rate(r),
          test_data_file(t_f) {
        params.resize(num_cols);
        grad.resize(num_cols);
        full_grad.resize(num_cols);
    }

    void work() override {
        read_data();
        // find the first iteration at which the scaling factor
        // a = (1 - rate*lambda)^i would underflow below MIN
        double check_a = 1;
        for (int i = 0; i < num_epoches * (num_of_all_data / num_workers); i++) {
            check_a *= (1 - rate * lambda);
            if (check_a < MIN) {
                recover_index = i;
                break;
            }
        }
        if (test_data_file != "null" && worker_id == 1) { read_test_data(); }
        std::random_device rd;
        std::default_random_engine e(rd());
        std::uniform_int_distribution<> u(0, dataset.get_num_rows() - 1);

        pull();
        double i_loss = calculate_loss();
        report_loss(i_loss);
        if (worker_id == 1) {
            report_accuracy();
        }
        for (int i = 0; i < num_iters; i++) {
            MPI_Barrier(MPI_COMM_WORLD);  // start

            calculate_part_full_gradient();

            push();

            pull_full_grad();

            local_update_sparse(u, e);

            scope_push();

            pull();
            MPI_Barrier(MPI_COMM_WORLD);  // end
            double loss = calculate_loss();
            report_loss(loss);
            if (worker_id == 1) {
                report_accuracy();
            }
        }

        // std::cout << "worker " << worker_id << " done" << std::endl;
    }

    // uniformly pick batch_size distinct row ids via selection sampling
    void sample_data(std::vector<int> &sample_ids) {
        // assert(sample_ids.size() == 0);
        int num_rows = dataset.get_num_rows(), left = batch_size;
        std::default_random_engine e(std::random_device{}());
        for (int i = 0; i < num_rows; i++) {
            int x = e() % (num_rows - i);
            if (x < left) {
                sample_ids.push_back(i);
                left--;
                if (left == 0) break;
            }
        }
    }

    void pull() { comm_ptr->W_recv_params_from_all_S(params); }

    void pull_full_grad() { comm_ptr->W_recv_full_grad_from_all_S(full_grad); }

    void push() { comm_ptr->W_send_grads_to_all_S(grad); }

    void scope_push() { comm_ptr->W_send_params_to_all_S(params); }

    // read data from files
    void read_data() { dataset.read_from_file(data_file, worker_id, num_workers, num_cols); }

    // calculate loss over the local data
    double calculate_loss() {
        return model_ptr->compute_loss(dataset, params, num_of_all_data, num_workers, lambda);
    }

    // calculate the local part of the full gradient
    void calculate_part_full_gradient() {
        model_ptr->compute_full_gradient(dataset, params, grad, num_of_all_data);
    }

    void local_update_sparse(std::uniform_int_distribution<> &u, std::default_random_engine &e) {
        model_ptr->update(dataset, u, e, params, full_grad, lambda, num_epoches, rate,
                          recover_index, num_of_all_data, num_workers);
    }

    // report loss to coordinator
    void report_loss(double loss) { comm_ptr->W_send_loss_to_C(loss); }

    void read_test_data() { test_dataset.read_from_test_file(test_data_file, num_cols); }

    void report_accuracy() {
        if (test_data_file != "null") {
            double accuracy = 0;
            for (int i = 0; i < test_dataset.num_rows; i++) {
                double result = 0;
                DataPoint &d = test_dataset.data[i];
                for (int j = 0; j < d.key.size(); j++) {
                    result += params.parameter[d.key[j]] * d.value[j];
                }
                if (result * d.label > 0)
                    accuracy++;
                else if (result * d.label == 0)
                    accuracy += 0.5;
            }
            accuracy /= test_dataset.num_rows;
            comm_ptr->W_send_accuracy_to_C(accuracy);
        } else {
            comm_ptr->W_send_accuracy_to_C(-1);
        }
    }
};

#endif
--------------------------------------------------------------------------------
/src/storage/DataSet.hpp:
--------------------------------------------------------------------------------
/**
 * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
 * All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#ifndef _DATASET_HPP_
#define _DATASET_HPP_

#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "../util/include_util.hpp"
#include "DataPoint.hpp"

class DataSet {
public:
    DataPoint *data;
    int num_rows;
    int num_cols;

    DataSet() {
        data = NULL;
        num_rows = -1;
        num_cols = -1;
    }

    ~DataSet() {
        if (data != NULL) delete[] data;
    }

    DataSet(const DataSet &d) = delete;

    DataSet &operator=(const DataSet &d) = delete;

    int get_num_rows() { return num_rows; }

    int get_num_cols() { return num_cols; }

    int read_from_file(std::string data_file, int id, int worker_num, int real_num_cols) {
        // first pass: count the rows in all partition files assigned to this worker
        num_rows = 0;
        std::string buf;
        int file_count = 0;
        while (++file_count) {
            std::string data_files =
                data_file + "_/part" + std::to_string(id + ((file_count - 1) * worker_num));
            std::ifstream fin(data_files.c_str());
            if (!fin) break;
            while (getline(fin, buf)) {
                num_rows++;
            }
            fin.close();
        }

        data = new DataPoint[num_rows];
        file_count = 0;
        int row_count = 0;
        num_cols = 0;

        // second pass: parse the LIBSVM-format lines
        while (++file_count) {
            std::string data_files =
                data_file + "_/part" + std::to_string(id + ((file_count - 1) * worker_num));
            std::ifstream fin(data_files.c_str());
            if (!fin) break;
            while (getline(fin, buf)) {
                char str0[] = " :";

                char *result = strtok(&buf[0], str0);

                if ((strcmp(result, "1") == 0) || (strcmp(result, "+1") == 0) ||
                    (strcmp(result, "1.0") == 0) || (strcmp(result, "+1.0") == 0))
                    data[row_count].label = 1.0;
                else
                    data[row_count].label = -1.0;

                while ((result = strtok(NULL, str0))) {
                    // keys start from 0
                    data[row_count].key.push_back(atoi(result) - 1);
                    result = strtok(NULL, str0);
                    data[row_count].value.push_back(atof(result));
                }
                if (data[row_count].key[data[row_count].key.size() - 1] > num_cols) {
                    num_cols = data[row_count].key[data[row_count].key.size() - 1];
                }

                row_count++;
            }
            fin.close();
        }
        // keys start from 0, so add 1
        num_cols++;
        std::cout << "Worker " << id << ": examples:" << num_rows << ",features:" << num_cols
                  << "(" << real_num_cols - 1 << ")" << std::endl;
        num_cols = real_num_cols;
        // append the constant bias feature
        for (int i = 0; i < num_rows; i++) {
            data[i].key.push_back(num_cols - 1);
            data[i].value.push_back(1.0);
        }

        return 0;
    }
15 | 
16 | #ifndef _DATASET_HPP_
17 | #define _DATASET_HPP_
18 | 
19 | #include <cstdlib>
20 | #include <cstring>
21 | #include <fstream>
22 | #include <iostream>
23 | #include <sstream>
24 | #include <string>
25 | #include <vector>
26 | #include "../util/include_util.hpp"
27 | #include "DataPoint.hpp"
28 | 
29 | class DataSet {
30 | public:
31 |     DataPoint *data;
32 |     int num_rows;
33 |     int num_cols;
34 | 
35 |     DataSet() {
36 |         data = NULL;
37 |         num_rows = -1;
38 |         num_cols = -1;
39 |     }
40 | 
41 |     ~DataSet() {
42 |         if (data != NULL) delete[] data;
43 |     }
44 | 
45 |     DataSet(const DataSet &d) = delete;
46 | 
47 |     DataSet &operator=(const DataSet &d) = delete;
48 | 
49 |     int get_num_rows() { return num_rows; }
50 | 
51 |     int get_num_cols() { return num_cols; }
52 | 
53 |     int read_from_file(std::string data_file, int id, int worker_num, int real_num_cols) {
54 |         num_rows = 0;
55 |         std::string buf;
56 |         int file_count = 0;
57 |         while (++file_count) {  // first pass: count rows across all of this worker's part files
58 |             std::string data_files =
59 |                 data_file + "_/part" + std::to_string(id + ((file_count - 1) * worker_num));
60 |             std::ifstream fin(data_files.c_str());
61 |             if (!fin) break;
62 |             while (getline(fin, buf)) {
63 |                 num_rows++;
64 |             }
65 |             fin.close();
66 |         }
67 | 
68 |         data = new DataPoint[num_rows];
69 |         file_count = 0;
70 |         int row_count = 0;
71 |         num_cols = 0;
72 | 
73 |         while (++file_count) {  // second pass: parse LIBSVM-format lines ("label key:value ...")
74 |             std::string data_files =
75 |                 data_file + "_/part" + std::to_string(id + ((file_count - 1) * worker_num));
76 |             std::ifstream fin(data_files.c_str());
77 |             if (!fin) break;
78 |             while (getline(fin, buf)) {
79 |                 char str0[] = " :";
80 | 
81 |                 char *result = strtok(&buf[0], str0);  // tokenize in place; writing through c_str() would be undefined behavior
82 | 
83 |                 if ((strcmp(result, "1") == 0) || (strcmp(result, "+1") == 0) ||
84 |                     (strcmp(result, "1.0") == 0) || (strcmp(result, "+1.0") == 0))
85 |                     data[row_count].label = 1.0;
86 |                 else
87 |                     data[row_count].label = -1.0;
88 | 
89 |                 while ((result = strtok(NULL, str0)) != NULL) {
90 |                     // file keys are 1-based; store them 0-based
91 |                     data[row_count].key.push_back(atoi(result) - 1);
92 |                     result = strtok(NULL, str0);
93 |                     data[row_count].value.push_back(atof(result));
94 |                 }
95 |                 if (data[row_count].key[data[row_count].key.size() - 1] > num_cols) {
96 |                     num_cols = data[row_count].key[data[row_count].key.size() - 1];
97 |                 }
98 | 
99 |                 row_count++;
100 |             }
101 |             fin.close();
102 |         }
103 |         // keys start from 0, so add 1 to get the feature count
104 |         num_cols++;
105 |         std::cout << "Worker " << id << ": examples:" << num_rows << ",features:" << num_cols << "("
106 |                   << real_num_cols - 1 << ")" << std::endl;
107 |         num_cols = real_num_cols;
108 |         for (int i = 0; i < num_rows; i++) {  // append a constant bias feature in the last column
109 |             data[i].key.push_back(num_cols - 1);
110 |             data[i].value.push_back(1.0);
111 |         }
112 | 
113 |         return 0;
114 |     }
115 | 
116 |     int read_from_test_file(std::string data_file, int real_num_cols) {
117 |         num_rows = 0;
118 |         std::string buf;
119 |         std::ifstream fin(data_file.c_str());
120 |         if (!fin) { error(AT, "cannot open test data file"); }
121 |         while (getline(fin, buf)) {
122 |             num_rows++;
123 |         }
124 |         fin.close();
125 | 
126 |         data = new DataPoint[num_rows];
127 |         int row_count = 0;
128 |         num_cols = 0;
129 |         fin.open(data_file.c_str());
130 |         if (!fin) error(AT, "cannot open test data file");
131 |         while (getline(fin, buf)) {
132 |             char str0[] = " :";
133 |             char *result = strtok(&buf[0], str0);
134 |             if ((strcmp(result, "1") == 0) || (strcmp(result, "+1") == 0) ||
135 |                 (strcmp(result, "1.0") == 0) || (strcmp(result, "+1.0") == 0))
136 |                 data[row_count].label = 1.0;
137 |             else
138 |                 data[row_count].label = -1.0;
139 | 
140 |             while ((result = strtok(NULL, str0)) != NULL) {
141 |                 // file keys are 1-based; store them 0-based
142 |                 data[row_count].key.push_back(atoi(result) - 1);
143 |                 result = strtok(NULL, str0);
144 |                 data[row_count].value.push_back(atof(result));
145 |             }
146 |             if (data[row_count].key[data[row_count].key.size() - 1] > num_cols) {
147 |                 num_cols = data[row_count].key[data[row_count].key.size() - 1];
148 |             }
149 |             row_count++;
150 |         }
151 |         fin.close();
152 | 
153 |         // keys start from 0, so add 1 to get the feature count
154 |         num_cols++;
155 |         std::cout << "test_data: examples:" << num_rows << ",features:" << num_cols << "("
156 |                   << real_num_cols - 1 << ")" << std::endl;
157 |         num_cols = real_num_cols;
158 |         for (int i = 0; i < num_rows; i++) {  // append the same constant bias feature as the training data
159 |             data[i].key.push_back(num_cols - 1);
160 |             data[i].value.push_back(1.0);
161 |         }
162 |         return 0;
163 |     }
164 | 
165 |     void count_c_num(std::vector<int> &c) {  // count how many data points touch each coordinate
166 |         for (int i = 0; i < num_rows; i++) {
167 |             for (auto &x : data[i].key) c[x]++;
168 |         }
169 |     }
170 | };
171 | 
172 | #endif
--------------------------------------------------------------------------------
/src/comm/Comm.hpp:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2017 LIBBLE team supervised by Dr. Wu-Jun LI at Nanjing University.
3 |  * All Rights Reserved.
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *     http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License. */
15 | 
16 | #ifndef _COMM_HPP_
17 | #define _COMM_HPP_
18 | 
19 | #include <vector>
20 | #include "mpi.h"
21 | 
22 | #include "../storage/include_storage.hpp"
23 | #include "../util/include_util.hpp"
24 | #include "protocol.hpp"
25 | 
26 | /* the only class for doing communications: every MPI call in the system lives here */
27 | 
28 | class Comm {
29 | private:
30 |     char info;
31 |     int num_servers, num_workers, num_cols;
32 |     std::vector<int> server_list, worker_list;
33 |     std::vector<double> buffer;  // reusable receive buffer for dense vectors
34 |     std::vector<int> buffer_int;
35 | 
36 | public:
37 |     Comm(int n_servers, int n_workers, int n_cols)
38 |         : num_servers(n_servers), num_workers(n_workers), num_cols(n_cols) {
39 |         for (int i = 1; i <= num_servers; i++) server_list.push_back(i);  // servers are ranks 1..num_servers
40 |         for (int i = 1; i <= num_workers; i++) worker_list.push_back(num_servers + i);  // workers follow; rank 0 is the coordinator
41 |         buffer.resize(num_cols);
42 |         buffer_int.resize(num_cols);
43 |     }
44 | 
45 |     std::vector<int> get_server_list() { return server_list; }
46 | 
47 |     //--------------------coordinator-send
48 | 
49 |     //--------------------coordinator-receive
50 |     double C_recv_loss_from_all_W() {
51 |         double total_loss = 0, partial_loss = 0;
52 |         for (int w_id : worker_list) {
53 |             MPI_Recv(&partial_loss, 1, MPI_DOUBLE, w_id, WC_LOSS, MPI_COMM_WORLD,
54 |                      MPI_STATUS_IGNORE);
55 |             total_loss += partial_loss;
56 |         }
57 |         return total_loss;
58 |     }
59 | 
60 |     double C_recv_accuracy_from_W() {
61 |         double accuracy;
62 |         MPI_Recv(&accuracy, 1, MPI_DOUBLE, worker_list[0], WC_ACCU, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
63 |         return accuracy;
64 |     }
65 | 
66 |     void C_recv_params_from_all_S(Parameter &params) {
67 |         MPI_Status status;
68 |         int pos = 0, recv_num = 0;
69 |         for (int s_id : server_list) {
70 |             /* recv params from each server and concatenate */
71 |             MPI_Recv(&params.parameter[pos], params.parameter.size() - pos, MPI_DOUBLE, s_id,
72 |                      SC_PARAMS, MPI_COMM_WORLD, &status);
73 |             MPI_Get_count(&status, MPI_DOUBLE, &recv_num);  // actual length of this server's slice
74 |             pos += recv_num;
75 |         }
76 |     }
77 | 
78 |     //--------------------server-send
79 |     void S_send_grads_to_all_W(const Gradient_Dense &g) {
80 |         for (int w_id : worker_list) {
81 |             MPI_Send(&g.gradient[0], g.gradient.size(), MPI_DOUBLE, w_id, SW_GRAD, MPI_COMM_WORLD);
82 |         }
83 |     }
84 | 
85 |     void S_send_params_to_all_W(Parameter &params) {
86 |         const std::vector<double> &v = params.parameter;
87 |         for (int w_id : worker_list) {
88 |             MPI_Send(&v[0], v.size(), MPI_DOUBLE, w_id, SW_PARAMS, MPI_COMM_WORLD);
89 |         }
90 |     }
91 | 
92 |     void S_send_params_to_C(Parameter &params) {
93 |         const std::vector<double> &v = params.parameter;
94 |         MPI_Send(&v[0], v.size(), MPI_DOUBLE, 0, SC_PARAMS, MPI_COMM_WORLD);
95 |     }
96 | 
97 |     //--------------------server-receive
98 |     void S_recv_grads_from_all_W(Gradient_Dense &g) {
99 |         g.reset();
100 |         for (int i = 0; i < num_workers; i++) {
101 |             // recv each worker's gradient (in arrival order) and accumulate into g.gradient
102 |             MPI_Recv(&buffer[0], buffer.size(), MPI_DOUBLE, MPI_ANY_SOURCE, WS_GRADS,
103 |                      MPI_COMM_WORLD, MPI_STATUS_IGNORE);
104 |             vector_add(g.gradient, buffer);
105 |         }
106 |     }
107 | 
108 |     void S_recv_params_from_all_W(Parameter &params) {
109 |         params.reset();
110 |         for (int i = 0; i < num_workers; i++) {
111 |             MPI_Recv(&buffer[0], buffer.size(), MPI_DOUBLE, MPI_ANY_SOURCE, WS_PARAMS,
112 |                      MPI_COMM_WORLD, MPI_STATUS_IGNORE);
113 |             vector_add(params.parameter, buffer);
114 |         }
115 |     }
116 | 
117 |     //--------------------worker-send
118 | 
119 |     void W_send_loss_to_C(double loss) {
120 |         MPI_Send(&loss, 1, MPI_DOUBLE, 0, WC_LOSS, MPI_COMM_WORLD);
121 |     }
122 | 
123 |     void W_send_accuracy_to_C(double accuracy) {
124 |         MPI_Send(&accuracy, 1, MPI_DOUBLE, 0, WC_ACCU, MPI_COMM_WORLD);
125 |     }
126 | 
127 |     void W_send_grads_to_all_S(const Gradient_Dense &grad) {
128 |         /* split the gradient according to the slice of parameters each server owns */
129 |         int pos = 0;
130 |         for (int s_id : server_list) {
131 |             int len = get_local_params_size(num_cols, num_servers, s_id);
132 |             MPI_Send(&grad.gradient[pos], len, MPI_DOUBLE, s_id, WS_GRADS, MPI_COMM_WORLD);
133 |             pos += len;
134 |         }
135 |     }
136 | 
137 |     void W_send_params_to_all_S(const Parameter &params) {
138 |         int pos = 0;
139 |         for (int s_id : server_list) {
140 |             int len = get_local_params_size(num_cols, num_servers, s_id);
141 |             MPI_Send(&params.parameter[pos], len, MPI_DOUBLE, s_id, WS_PARAMS, MPI_COMM_WORLD);
142 |             pos += len;
143 |         }
144 |     }
145 | 
146 |     //--------------------worker-receive
147 |     void W_recv_params_from_all_S(Parameter &params) {
148 |         int pos = 0;
149 |         for (int s_id : server_list) { // may optimize to recv unordered
150 |             int len = get_local_params_size(num_cols, num_servers, s_id);
151 |             MPI_Recv(&params.parameter[pos], len, MPI_DOUBLE, s_id, SW_PARAMS, MPI_COMM_WORLD,
152 |                      MPI_STATUS_IGNORE);
153 |             pos += len;
154 |         }
155 |     }
156 | 
157 |     void W_recv_full_grad_from_all_S(Gradient_Dense &grad) {
158 |         int pos = 0;
159 |         for (int s_id : server_list) {
160 |             int len = get_local_params_size(num_cols, num_servers, s_id);
161 |             MPI_Recv(&grad.gradient[pos], len, MPI_DOUBLE, s_id, SW_GRAD, MPI_COMM_WORLD,
162 |                      MPI_STATUS_IGNORE);
163 |             pos += len;
164 |         }
165 |     }
166 | };
167 | 
168 | #endif
--------------------------------------------------------------------------------