├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake ├── get_project_version.cmake ├── get_python_wheel_tag.cmake ├── mindalpha_shared.cmake └── python_wheel.cmake ├── compile.sh ├── cpp └── mindalpha │ ├── actor_config.cpp │ ├── actor_config.h │ ├── actor_process.cpp │ ├── actor_process.h │ ├── array_hash_map.h │ ├── array_hash_map_reader.h │ ├── array_hash_map_writer.h │ ├── combine_schema.cpp │ ├── combine_schema.h │ ├── data_type.cpp │ ├── data_type.h │ ├── debug.h │ ├── dense_tensor.cpp │ ├── dense_tensor.h │ ├── dense_tensor_meta.cpp │ ├── dense_tensor_meta.h │ ├── dense_tensor_partition.cpp │ ├── dense_tensor_partition.h │ ├── feature_extraction_python_bindings.cpp │ ├── feature_extraction_python_bindings.h │ ├── file_utils.h │ ├── filesys.cpp │ ├── filesys.h │ ├── hash_uniquifier.cpp │ ├── hash_uniquifier.h │ ├── hashtable_helpers.h │ ├── index_batch.cpp │ ├── index_batch.h │ ├── io.cpp │ ├── io.h │ ├── local_filesys.cpp │ ├── local_filesys.h │ ├── logging.h │ ├── map_file_header.cpp │ ├── map_file_header.h │ ├── memory_buffer.h │ ├── message.cpp │ ├── message.h │ ├── message_meta.cpp │ ├── message_meta.h │ ├── message_transport.cpp │ ├── message_transport.h │ ├── ml_ps_python_bindings.cpp │ ├── ml_ps_python_bindings.h │ ├── model_metric_buffer.cpp │ ├── model_metric_buffer.h │ ├── network_utils.cpp │ ├── network_utils.h │ ├── node_control.cpp │ ├── node_control.h │ ├── node_control_command.cpp │ ├── node_control_command.h │ ├── node_encoding.cpp │ ├── node_encoding.h │ ├── node_info.cpp │ ├── node_info.h │ ├── node_manager.cpp │ ├── node_manager.h │ ├── node_role.cpp │ ├── node_role.h │ ├── ps_agent.cpp │ ├── ps_agent.h │ ├── ps_default_agent.cpp │ ├── ps_default_agent.h │ ├── ps_helper.cpp │ ├── ps_helper.h │ ├── ps_runner.cpp │ ├── ps_runner.h │ ├── pybind_utils.cpp │ ├── pybind_utils.h │ ├── s3_sdk_filesys.cpp │ ├── s3_sdk_filesys.h │ ├── smart_array.h │ ├── sparse_tensor.cpp │ ├── sparse_tensor.h │ ├── sparse_tensor_meta.cpp │ ├── 
sparse_tensor_meta.h │ ├── sparse_tensor_partition.cpp │ ├── sparse_tensor_partition.h │ ├── stack_trace_utils.cpp │ ├── stack_trace_utils.h │ ├── string_utils.h │ ├── tensor_partition_store.cpp │ ├── tensor_partition_store.h │ ├── tensor_store_python_bindings.cpp │ ├── tensor_store_python_bindings.h │ ├── tensor_utils.cpp │ ├── tensor_utils.h │ ├── thread_utils.cpp │ ├── thread_utils.h │ ├── vector_utils.h │ ├── zeromq_transport.cpp │ └── zeromq_transport.h ├── docker ├── centos7 │ ├── Dockerfile │ ├── compile.sh │ └── package.sh └── ubuntu20.04 │ ├── Dockerfile │ ├── compile.sh │ └── package.sh ├── examples ├── deep_fm_example.py ├── swing_estimator_example.py └── wide_and_deep_example.py ├── package.sh ├── python ├── mindalpha │ ├── __init__.py │ ├── agent.py │ ├── cast.py │ ├── compat │ │ ├── __init__.py │ │ └── ps │ │ │ └── __init__.py │ ├── demo.py │ ├── distributed_tensor.py │ ├── distributed_trainer.py │ ├── embedding.py │ ├── estimator.py │ ├── experiment.py │ ├── file_utils.py │ ├── initializer.py │ ├── input.py │ ├── job_utils.py │ ├── loss_utils.py │ ├── metric.py │ ├── model.py │ ├── name_utils.py │ ├── network_utils.py │ ├── nn │ │ ├── __init__.py │ │ ├── deep_fm.py │ │ ├── fm.py │ │ ├── normalization.py │ │ └── wide_and_deep.py │ ├── output.py │ ├── patching_pickle.py │ ├── ps_launcher.py │ ├── s3_utils.py │ ├── shell_utils.py │ ├── spark.py │ ├── stack_trace_utils.py │ ├── swing_retrieval.py │ ├── two_tower_ranking.py │ ├── two_tower_retrieval.py │ ├── updater.py │ └── url_utils.py ├── ps │ ├── __init__.py │ └── job.py └── setup.py ├── run_build.sh ├── thrift └── mindalpha │ └── message_meta.thrift └── tutorials ├── mindalpha-getting-started.ipynb ├── mindalpha-tutorial.ipynb └── schema ├── column_name_demo.txt └── combine_schema_demo.txt /.gitignore: -------------------------------------------------------------------------------- 1 | /build/ 2 | /built/ 3 | -------------------------------------------------------------------------------- 
/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | cmake_minimum_required(VERSION 3.14 FATAL_ERROR) 18 | project(mindalpha VERSION 2.0.0.0 LANGUAGES CXX) 19 | 20 | find_package(Git REQUIRED) 21 | find_package(Python REQUIRED COMPONENTS Interpreter Development) 22 | find_package(Boost REQUIRED) 23 | find_package(PkgConfig REQUIRED) 24 | 25 | find_package(spdlog REQUIRED CONFIG) 26 | find_package(pybind11 REQUIRED CONFIG) 27 | find_package(AWSSDK REQUIRED CONFIG COMPONENTS s3) 28 | 29 | # json11, Thrift and ZeroMQ are looked up through their CMake package 30 | # configs first; when no config is installed, fall back to pkg-config (or a 31 | # plain library search) and expose json11_static, thrift::thrift, zmq::libzmq. 32 | find_package(json11 CONFIG) 33 | if(NOT TARGET json11_static) 34 | pkg_search_module(JSON11 REQUIRED IMPORTED_TARGET GLOBAL json11) 35 | add_library(json11_static ALIAS PkgConfig::JSON11) 36 | endif() 37 | 38 | find_package(Thrift CONFIG) 39 | if(NOT TARGET thrift::thrift) 40 | pkg_search_module(THRIFT REQUIRED IMPORTED_TARGET GLOBAL thrift) 41 | add_library(thrift::thrift ALIAS PkgConfig::THRIFT) 42 | endif() 43 | 44 | find_package(ZeroMQ CONFIG) 45 | if(NOT TARGET libzmq-static) 46 | # find_library() may return either libzmq.a or libzmq.so, so the imported 47 | # target is declared UNKNOWN rather than STATIC. 48 | find_library(ZMQ_LIB zmq) 49 | if(NOT ZMQ_LIB) 50 | message(FATAL_ERROR "libzmq not found") 51 | endif() 52 | find_path(ZMQ_HEADER zmq.h) 53 | if(NOT ZMQ_HEADER) 54 | message(FATAL_ERROR "zmq.h not found") 55 | endif() 56 | add_library(zmq::libzmq UNKNOWN IMPORTED GLOBAL) 57 | set_target_properties(zmq::libzmq PROPERTIES 58 | IMPORTED_LOCATION "${ZMQ_LIB}" 59 | INTERFACE_INCLUDE_DIRECTORIES "${ZMQ_HEADER}") 60 | else() 61 | # Before CMake 3.18 an ALIAS may only name a GLOBAL imported target, and 62 | # targets created by find_package() are directory-scoped, so promote it. 63 | if(CMAKE_VERSION VERSION_LESS 3.18) 64 | set_target_properties(libzmq-static PROPERTIES IMPORTED_GLOBAL TRUE) 65 | endif() 66 | add_library(zmq::libzmq ALIAS libzmq-static) 67 | endif() 68 | 69 | include(cmake/get_project_version.cmake) 70 | include(cmake/get_python_wheel_tag.cmake) 71 | include(cmake/mindalpha_shared.cmake) 72 | include(cmake/python_wheel.cmake)
73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MindAlpha 2 | 3 | MindAlpha is a machine learning platform integrating PySpark, PyTorch 4 | and a parameter server implementation. The platform contains native 5 | support for sparse parameters, making it easy for users to develop 6 | large-scale models. Together with MindAlpha Serving, the platform 7 | provides a one-stop solution for data preprocessing, model training and 8 | online prediction. 9 | 10 | ## Features 11 | 12 | * Efficient IO with PySpark. Minibatches read by PySpark as pandas DataFrames 13 | can be fed directly to models. 14 | 15 | * An API similar to PyTorch and Spark MLlib, so users familiar with PyTorch and 16 | PySpark can get started quickly. 17 | 18 | * Custom sparse layers are wrapped as PyTorch modules, making them easy to use. 19 | Those sparse layers can contain billions of parameters. 20 | 21 | * Models can be developed interactively in Jupyter Notebook, and periodic 22 | model training can be scheduled by Airflow. 23 | 24 | * The trained model can be exported via one method call and loaded by MindAlpha 25 | Serving for online prediction. 26 | 27 | ## Build 28 | 29 | First, run the following script to build a Docker image: 30 | 31 | ```shell 32 | sh run_build.sh -i 33 | ``` 34 | 35 | For more details, please refer to [docker/ubuntu20.04/Dockerfile](docker/ubuntu20.04/Dockerfile) 36 | and [docker/centos7/Dockerfile](docker/centos7/Dockerfile).
37 | 38 | Then run the following script to compile the sources (\*.cpp and \*.py) into a shared library (\*.so) and 39 | Python wheel packages (\*.whl), which are written to the **build** directory by default. 40 | 41 | ```shell 42 | sh run_build.sh -m 43 | ``` 44 | 45 | ## Tutorials 46 | 47 | Two tutorials are provided: 48 | 49 | * [MindAlpha Getting Started](tutorials/mindalpha-getting-started.ipynb) briefly introduces the basic API of MindAlpha. 50 | * [MindAlpha Tutorial](tutorials/mindalpha-tutorial.ipynb) shows how to use MindAlpha in a production environment. 51 | -------------------------------------------------------------------------------- /cmake/get_project_version.cmake: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | # get_project_version(var) 18 | # 19 | # Sets var in the caller's scope to "MAJOR.MINOR.PATCH+COMMIT", taking the 20 | # numbers from PROJECT_VERSION_MAJOR/MINOR/PATCH and COMMIT from 21 | # `git rev-parse --short HEAD` run in PROJECT_SOURCE_DIR. ".dirty" is 22 | # appended to COMMIT when `git status --short` prints anything. Stops the 23 | # configure step with FATAL_ERROR if either git command fails or HEAD 24 | # yields no commit id. 25 | function(get_project_version var) 26 | execute_process( 27 | COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD 28 | WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" 29 | RESULT_VARIABLE git_result 30 | OUTPUT_VARIABLE head_commit) 31 | if(NOT "${git_result}" STREQUAL "0" OR "${head_commit}" STREQUAL "") 32 | message(FATAL_ERROR "Can not find commit id of the repository.") 33 | endif() 34 | string(STRIP "${head_commit}" head_commit) 35 | 36 | # Any output from `git status --short` marks the working tree as dirty. 37 | execute_process( 38 | COMMAND ${GIT_EXECUTABLE} status --short 39 | WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" 40 | RESULT_VARIABLE git_result 41 | OUTPUT_VARIABLE worktree_changes) 42 | if(NOT "${git_result}" STREQUAL "0") 43 | message(FATAL_ERROR "Can not check cleanness the repository.") 44 | endif() 45 | string(STRIP "${worktree_changes}" worktree_changes) 46 | if(NOT "${worktree_changes}" STREQUAL "") 47 | string(APPEND head_commit ".dirty") 48 | endif() 49 | 50 | set("${var}" 51 | "${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}+${head_commit}" 52 | PARENT_SCOPE) 53 | endfunction() 54 | -------------------------------------------------------------------------------- /cmake/get_python_wheel_tag.cmake: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | # get_python_wheel_tag(var) 18 | # 19 | # Runs Python_EXECUTABLE and sets var in the caller's scope to the 20 | # "python-abi-platform" tag triple that `pip wheel` puts into the wheel 21 | # file name, e.g. "cp37-cp37m-linux_x86_64" or "cp38-cp38-linux_aarch64". 22 | # 23 | # The ABI part uses the interpreter's own flags instead of guessing from 24 | # the version: sys.abiflags on Python 3 ("m" for pymalloc builds up to 25 | # 3.7, empty from 3.8, plus e.g. "d" for debug builds) and "m"/"mu" on 26 | # Python 2.7, which has no sys.abiflags. The platform part is 27 | # sysconfig.get_platform() with "-" and "." replaced by "_", as bdist_wheel 28 | # does, rather than a hard-coded linux_x86_64. 29 | function(get_python_wheel_tag var) 30 | set(src) 31 | string(APPEND src "import sys, sysconfig; ") 32 | string(APPEND src "ver = '%d%d' % tuple(sys.version_info[:2]); ") 33 | string(APPEND src "py2_flags = 'm' + ('u' if sys.maxunicode == 0x10ffff else ''); ") 34 | string(APPEND src "flags = getattr(sys, 'abiflags', ") 35 | string(APPEND src "py2_flags if sys.version_info[0] == 2 else ''); ") 36 | string(APPEND src "plat = sysconfig.get_platform()") 37 | string(APPEND src ".replace('-', '_').replace('.', '_'); ") 38 | string(APPEND src "print('cp%s-cp%s%s-%s' % (ver, ver, flags, plat))") 39 | execute_process( 40 | COMMAND ${Python_EXECUTABLE} -c "${src}" 41 | RESULT_VARIABLE rc 42 | OUTPUT_VARIABLE wheel_tag) 43 | if(NOT "${rc}" STREQUAL "0" OR "${wheel_tag}" STREQUAL "") 44 | message(FATAL_ERROR "Can not get Python wheel tag.") 45 | endif() 46 | string(STRIP "${wheel_tag}" wheel_tag) 47 | set("${var}" "${wheel_tag}" PARENT_SCOPE) 48 | endfunction() 49 | -------------------------------------------------------------------------------- /cmake/python_wheel.cmake: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | # Builds the mindalpha wheel by running `pip wheel` on python/. The path of 18 | # _mindalpha.so in PROJECT_BINARY_DIR and project_version are exported to 19 | # pip through _MINDALPHA_SO / _MINDALPHA_VERSION (presumably consumed by 20 | # python/setup.py). The wheel file name is predicted from project_version and 21 | # the interpreter's wheel tag so that OUTPUT names the file pip writes. 22 | get_python_wheel_tag(python_wheel_tag) 23 | set(wheel_file_name mindalpha-${project_version}-${python_wheel_tag}.whl) 24 | message(STATUS "python_wheel_tag: ${python_wheel_tag}") 25 | message(STATUS "wheel_file_name: ${wheel_file_name}") 26 | 27 | # Python sources that feed the wheel; editing any of them re-runs pip. 28 | set(python_files 29 | python/setup.py 30 | python/mindalpha/__init__.py 31 | python/mindalpha/initializer.py 32 | python/mindalpha/updater.py 33 | python/mindalpha/model.py 34 | python/mindalpha/distributed_trainer.py 35 | python/mindalpha/distributed_tensor.py 36 | python/mindalpha/agent.py 37 | python/mindalpha/metric.py 38 | python/mindalpha/loss_utils.py 39 | python/mindalpha/embedding.py 40 | python/mindalpha/cast.py 41 | python/mindalpha/input.py 42 | python/mindalpha/output.py 43 | python/mindalpha/url_utils.py 44 | python/mindalpha/s3_utils.py 45 | python/mindalpha/file_utils.py 46 | python/mindalpha/name_utils.py 47 | python/mindalpha/network_utils.py 48 | python/mindalpha/shell_utils.py 49 | python/mindalpha/stack_trace_utils.py 50 | python/mindalpha/ps_launcher.py 51 | python/mindalpha/job_utils.py 52 | python/mindalpha/estimator.py 53 | python/mindalpha/two_tower_ranking.py 54 | python/mindalpha/two_tower_retrieval.py 55 | python/mindalpha/swing_retrieval.py 56 | python/mindalpha/experiment.py 57 | python/mindalpha/spark.py 58 | python/mindalpha/patching_pickle.py 59 | python/mindalpha/demo.py 60 | python/mindalpha/nn/__init__.py 61 | python/mindalpha/nn/normalization.py 62 | python/mindalpha/nn/fm.py 63 | python/mindalpha/nn/wide_and_deep.py 64 | python/mindalpha/nn/deep_fm.py 65 | python/mindalpha/compat/__init__.py 66 | python/mindalpha/compat/ps/__init__.py 67 | python/ps/__init__.py 68 | python/ps/job.py 69 | ) 70 | # `pip wheel` writes the wheel into its working directory; pin that to the 71 | # binary directory a relative OUTPUT path is resolved against. 72 | add_custom_command(OUTPUT ${wheel_file_name} 73 | COMMAND env _MINDALPHA_SO=${PROJECT_BINARY_DIR}/_mindalpha.so 74 | _MINDALPHA_VERSION=${project_version} 75 | ${Python_EXECUTABLE} -m pip wheel ${PROJECT_SOURCE_DIR}/python 76 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 77 | MAIN_DEPENDENCY python/setup.py 78 | DEPENDS mindalpha_shared ${python_files} 79 | COMMENT "Building Python wheel ${wheel_file_name}" 80 | VERBATIM)
81 | add_custom_target(python_wheel ALL DEPENDS ${wheel_file_name}) 82 | -------------------------------------------------------------------------------- /compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Copyright 2021 Mobvista 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | set -e 20 | pushd "$(dirname "${BASH_SOURCE[0]}")" 21 | tag=$(source /etc/os-release; echo ${ID}${VERSION_ID}) 22 | "./docker/${tag}/compile.sh" 23 | popd 24 | -------------------------------------------------------------------------------- /cpp/mindalpha/actor_config.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | 17 | #include 18 | 19 | namespace mindalpha 20 | { 21 | 22 | // No definitions yet: the namespace in this file is empty. 23 | 24 | } 25 | -------------------------------------------------------------------------------- /cpp/mindalpha/actor_process.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | namespace mindalpha 36 | { 37 | // ActorProcess owns the message transport (transport_) and node manager (manager_) and keeps readiness, barrier and node bookkeeping state; see the members at the end of the class. 38 | class ActorProcess 39 | { 40 | friend class PSAgent; // PSAgent may use the private members below. 41 | 42 | public: 43 | explicit ActorProcess(std::shared_ptr config); 44 | 45 | std::shared_ptr GetConfig() const { return config_; } 46 | void SetConfig(std::shared_ptr value) { config_ = std::move(value); } 47 | 48 | int64_t GetMessageId() { return message_counter_++; } // Returns the current counter value and atomically increments it, so ids are unique per ActorProcess object. 49 | void Barrier(int group); 50 | 51 | int64_t Send(const Message& msg); 52 | void Receiving(); 53 | 54 | void Start(); 55 | void Run(); 56 | void Stop(); 57 | 58 | private: // NOTE(review): IsReady/SetIsReady/WaitReady presumably guard ready_ with ready_mutex_ and wait on ready_cv_ — confirm in actor_process.cpp. 59 | bool IsReady(); 60 | void SetIsReady(bool value); 61 | void WaitReady(); 62 | 63 | bool HandleDataMessage(Message&& msg); 64 | // X macro: declares one `bool Handle##n##Message(const Message& msg);` for every command n listed in MINDALPHA_NODE_CONTROL_COMMANDS. 65 | #undef MINDALPHA_NODE_CONTROL_COMMAND_DEF 66 | #define MINDALPHA_NODE_CONTROL_COMMAND_DEF(n) bool Handle##n##Message(const Message& msg);
67 | MINDALPHA_NODE_CONTROL_COMMANDS(MINDALPHA_NODE_CONTROL_COMMAND_DEF) 68 | 69 | std::unordered_set GetDeadNodes(); 70 | void UpdateLocalId(const Message& msg); 71 | void CoordinatorHandleAddNode(const Message& msg); 72 | // State. message_counter_, send_bytes_ and receive_bytes_ are atomics; ready_ is paired with ready_mutex_/ready_cv_ (presumably guarded by them). 73 | std::shared_ptr config_; 74 | std::unique_ptr transport_; 75 | std::unique_ptr manager_; 76 | std::unordered_map connected_nodes_; 77 | std::unordered_map shared_node_mapping_; 78 | std::vector barrier_counter_; 79 | bool ready_{false}; 80 | std::mutex ready_mutex_; 81 | std::condition_variable ready_cv_; 82 | std::atomic message_counter_{0}; // Next value returned by GetMessageId(). 83 | std::atomic send_bytes_{0}; 84 | std::atomic receive_bytes_{0}; 85 | MessageMeta nodes_; 86 | MessageMeta recovery_nodes_; 87 | int num_servers_ = 0; 88 | int num_workers_ = 0; 89 | std::mutex start_mutex_; 90 | std::unique_ptr> receiver_exit_; 91 | int init_stage_ = 0; 92 | NodeInfo coordinator_; 93 | std::shared_ptr agent_; 94 | }; 95 | 96 | } 97 | -------------------------------------------------------------------------------- /cpp/mindalpha/combine_schema.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace mindalpha 28 | { 29 | // CombineSchema holds a column-name list and a combine schema (each loadable from a stream, a source string or a file URI) and combines the columns of an IndexBatch into hashed feature indices plus offsets (see CombineToIndicesAndOffsets). 30 | class CombineSchema 31 | { 32 | public: 33 | void Clear(); 34 | 35 | void LoadColumnNameFromStream(std::istream& stream); 36 | void LoadColumnNameFromSource(const std::string& source); 37 | void LoadColumnNameFromFile(const std::string& uri); 38 | 39 | void LoadCombineSchemaFromStream(std::istream& stream); 40 | void LoadCombineSchemaFromSource(const std::string& source); 41 | void LoadCombineSchemaFromFile(const std::string& uri); 42 | 43 | size_t GetFeatureCount() const // One feature per entry of combine_columns_. 44 | { 45 | return combine_columns_.size(); 46 | } 47 | 48 | const std::string& GetColumnNameSource() const { return column_name_source_; } // Text of the loaded column-name list (presumably kept verbatim). 49 | const std::string& GetCombineSchemaSource() const { return combine_schema_source_; } // Text of the loaded combine schema (presumably kept verbatim). 50 | const std::unordered_map& GetColumnNameMap() const { return column_name_map_; } 51 | 52 | std::tuple, std::vector> 53 | CombineToIndicesAndOffsets(const IndexBatch& batch, bool feature_offset) const; 54 | 55 | static uint64_t ComputeFeatureHash(const std::vector>& feature); 56 | 57 | private: 58 | static constexpr uint64_t CombineOneField(uint64_t name, uint64_t value) // Hash of one name/value pair via CombineHashCodes. 59 | { 60 | return CombineHashCodes(name, value); 61 | } 62 | 63 | static constexpr uint64_t ConcatOneField(uint64_t h, uint64_t name, uint64_t value) // Extends running hash h with the separator '\001', then name, then value. 64 | { 65 | constexpr uint64_t sep = '\001'; 66 | h = CombineHashCodes(h, sep); 67 | h = CombineHashCodes(h, name); 68 | h = CombineHashCodes(h, value); 69 | return h; 70 | } 71 | 72 | static void CombineOneFeature(const std::vector& splits, 73 | const std::vector& names, 74 | const std::vector& name_hashes, 75 | std::vector& combine_hashes, 76 | size_t total_results); 77 | 78 | const StringViewHashVector* 79 | GetCell(const IndexBatch& batch, size_t i, const std::string& column_name) const; 80 | 81 | std::unordered_map column_name_map_; 82 | std::vector> combine_columns_; 83
| std::vector> combine_columns_aliases_; 84 | std::vector> combine_columns_aliases_hashes_; 85 | std::vector column_names_; 86 | std::string column_name_source_; 87 | std::string combine_schema_source_; 88 | }; 89 | 90 | } 91 | -------------------------------------------------------------------------------- /cpp/mindalpha/data_type.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License.
15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | std::string DataTypeToString(DataType type) 26 | { 27 | switch (type) 28 | { 29 | #undef MINDALPHA_DATA_TYPE_DEF 30 | #define MINDALPHA_DATA_TYPE_DEF(t, l, u) case DataType::u: return #l; 31 | MINDALPHA_DATA_TYPES(MINDALPHA_DATA_TYPE_DEF) 32 | default: 33 | std::string serr; 34 | serr.append("Invalid DataType enum value: "); 35 | serr.append(std::to_string(static_cast(type))); 36 | serr.append(".\n\n"); 37 | serr.append(GetStackTrace()); 38 | spdlog::error(serr); 39 | throw std::runtime_error(serr); 40 | } 41 | } 42 | 43 | DataType DataTypeFromString(const std::string& str) 44 | { 45 | #undef MINDALPHA_DATA_TYPE_DEF 46 | #define MINDALPHA_DATA_TYPE_DEF(t, l, u) if (str == #l) return DataType::u; 47 | MINDALPHA_DATA_TYPES(MINDALPHA_DATA_TYPE_DEF) 48 | std::string serr; 49 | serr.append("Invalid DataType enum value: "); 50 | serr.append(str); 51 | serr.append(".\n\n"); 52 | serr.append(GetStackTrace()); 53 | spdlog::error(serr); 54 | throw std::runtime_error(serr); 55 | } 56 | 57 | std::string NullableDataTypeToString(DataType type) 58 | { 59 | if (type == NullDataType) 60 | return NullDataTypeString; 61 | return DataTypeToString(type); 62 | } 63 | 64 | DataType NullableDataTypeFromString(const std::string& str) 65 | { 66 | if (str == NullDataTypeString) 67 | return NullDataType; 68 | return DataTypeFromString(str); 69 | } 70 | 71 | size_t DataTypeToSize(DataType type) 72 | { 73 | switch (type) 74 | { 75 | #undef MINDALPHA_DATA_TYPE_DEF 76 | #define MINDALPHA_DATA_TYPE_DEF(t, l, u) case DataType::u: return sizeof(t); 77 | MINDALPHA_DATA_TYPES(MINDALPHA_DATA_TYPE_DEF) 78 | default: 79 | std::string serr; 80 | serr.append("Invalid DataType enum value: "); 81 | serr.append(std::to_string(static_cast(type))); 82 | serr.append(".\n\n"); 83 | serr.append(GetStackTrace()); 84 | spdlog::error(serr); 85 | throw std::runtime_error(serr); 86 | } 87 | } 88 | 89 | } 
90 | -------------------------------------------------------------------------------- /cpp/mindalpha/data_type.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | // 23 | // ``data_type.h`` defines enum ``DataType`` to represent numeric data types 24 | // and some helper functions to convert ``DataType`` values. 25 | // 26 | 27 | namespace mindalpha 28 | { 29 | 30 | // 31 | // Use the X Macro technique to simplify code. 
See the following page 32 | // for more information about X Macros: 33 | // 34 | // https://en.wikipedia.org/wiki/X_Macro 35 | // 36 | // Each entry has the form X(C++ type, lower-case name, DataType enumerator). 37 | #define MINDALPHA_INTEGRAL_DATA_TYPES(X) \ 38 | X(int8_t, int8, Int8) \ 39 | X(int16_t, int16, Int16) \ 40 | X(int32_t, int32, Int32) \ 41 | X(int64_t, int64, Int64) \ 42 | X(uint8_t, uint8, UInt8) \ 43 | X(uint16_t, uint16, UInt16) \ 44 | X(uint32_t, uint32, UInt32) \ 45 | X(uint64_t, uint64, UInt64) \ 46 | /**/ 47 | 48 | #define MINDALPHA_FLOATING_DATA_TYPES(X) \ 49 | X(float, float32, Float32) \ 50 | X(double, float64, Float64) \ 51 | /**/ 52 | 53 | #define MINDALPHA_DATA_TYPES(X) \ 54 | MINDALPHA_INTEGRAL_DATA_TYPES(X) \ 55 | MINDALPHA_FLOATING_DATA_TYPES(X) \ 56 | /**/ 57 | // Enumerators follow MINDALPHA_DATA_TYPES order (Int8 == 0, ..., Float64 == 9). NOTE(review): reordering the lists changes these values — check whether they are persisted or sent over the wire. 58 | enum class DataType 59 | { 60 | #undef MINDALPHA_DATA_TYPE_DEF 61 | #define MINDALPHA_DATA_TYPE_DEF(t, l, u) u, 62 | MINDALPHA_DATA_TYPES(MINDALPHA_DATA_TYPE_DEF) 63 | }; 64 | 65 | // A missing ``DataType`` is represented by ``DataType(-1)``; its string form is "null". 66 | constexpr DataType NullDataType = static_cast(-1); 67 | constexpr const char* NullDataTypeString = "null"; 68 | 69 | // Functions to convert ``DataType`` to and from its lower-case name (e.g. "float32"); they throw std::runtime_error for unknown input. 70 | std::string DataTypeToString(DataType type); 71 | DataType DataTypeFromString(const std::string& str); 72 | // The Nullable variants additionally map NullDataType to and from "null". 73 | std::string NullableDataTypeToString(DataType type); 74 | DataType NullableDataTypeFromString(const std::string& str); 75 | 76 | // ``DataTypeToCode`` maps a C++ numeric type to its ``DataType`` code through its static ``value`` member. The primary template is only declared; the specializations below cover exactly the types in MINDALPHA_DATA_TYPES. 77 | template 78 | struct DataTypeToCode; 79 | 80 | #undef MINDALPHA_DATA_TYPE_DEF 81 | #define MINDALPHA_DATA_TYPE_DEF(t, l, u) \ 82 | template<> \ 83 | struct DataTypeToCode \ 84 | { \ 85 | static constexpr DataType value = DataType::u; \ 86 | }; \ 87 | /**/ 88 | MINDALPHA_DATA_TYPES(MINDALPHA_DATA_TYPE_DEF) 89 | 90 | // Returns the size in bytes of a value of ``type``; throws std::runtime_error for values outside the enum (including NullDataType).
91 | size_t DataTypeToSize(DataType type); 92 | 93 | // This function template and its two overloads make sure ``value`` is 94 | // printed as a number: writing ``int8_t``/``uint8_t`` straight to 95 | // ``std::ostream`` prints a character instead, because they are 96 | // character types. 97 | template 98 | inline T AsNumber(T value) { return value; } 99 | 100 | inline int32_t AsNumber(int8_t value) { return static_cast(value); } 101 | inline uint32_t AsNumber(uint8_t value) { return static_cast(value); } 102 | 103 | } 104 | -------------------------------------------------------------------------------- /cpp/mindalpha/debug.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | #include 19 | #include 20 | -------------------------------------------------------------------------------- /cpp/mindalpha/dense_tensor.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | // DenseTensor: a dense tensor described by meta_ (a DenseTensorMeta) together with an agent held in agent_ (presumably a PSAgent). NOTE(review): the operations below take a cb callback and look asynchronous — confirm the callback contract in dense_tensor.cpp. 28 | class DenseTensor 29 | { 30 | public: 31 | DenseTensorMeta& GetMeta() { return meta_; } 32 | const DenseTensorMeta& GetMeta() const { return meta_; } 33 | void SetMeta(DenseTensorMeta value) { meta_ = std::move(value); } 34 | 35 | std::shared_ptr GetAgent() const { return agent_; } 36 | void SetAgent(std::shared_ptr value) { agent_ = std::move(value); } 37 | 38 | void Init(std::function cb); 39 | void Dispose(std::function cb); 40 | void Push(SmartArray in, std::function cb, bool is_value = false, bool is_state = false); // NOTE(review): is_value/is_state presumably select whether in is treated as a gradient, a value or optimizer state — confirm. 41 | void Pull(std::function out)> cb, bool is_state = false); 42 | void PushMeta(const DenseTensorMeta& meta, std::function cb); 43 | void PullMeta(std::function cb); 44 | void Load(const std::string& dir_path, std::function cb, bool keep_meta = false); // keep_meta presumably keeps the current meta_ instead of the one stored under dir_path — confirm. 45 | void Save(const std::string& dir_path, std::function cb); 46 | 47 | private: // Paths of the meta/data/state files under dir_path, presumably shared by Save and Load. 48 | std::string GetDenseMetaPath(const std::string& dir_path) const; 49 | std::string GetDenseDataPath(const std::string& dir_path) const; 50 | std::string GetDenseStatePath(const std::string& dir_path) const; 51 | 52 | DenseTensorMeta meta_; 53 | std::shared_ptr agent_; 54 | }; 55 | 56 | } 57 | -------------------------------------------------------------------------------- /cpp/mindalpha/dense_tensor_meta.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021
Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 | #include
25 | #include
26 | #include
27 | #include
28 |
29 | namespace mindalpha
30 | {
31 | // Invoked as initializer(name, data, meta) by DenseTensorPartition::AllocateDataBlock to fill a newly allocated partition (template arguments are missing from this dump).
32 | using DenseInitializer = std::function data,
34 |     const class DenseTensorMeta& meta)>;
35 | // Invoked as updater(name, param, grad, state, meta) by DenseTensorPartition::HandlePush to apply a pushed gradient to the parameters.
36 | using DenseUpdater = std::function param,
38 |     SmartArray grad,
39 |     SmartArray state,
40 |     const class DenseTensorMeta& meta)>;
41 | // DenseTensorMeta: name, data type, data/state shapes, initializer/updater callbacks and partition count describing one dense tensor.
42 | class DenseTensorMeta
43 | {
44 | public:
45 |     const std::string& GetName() const { return name_; }
46 |     void SetName(std::string value) { name_ = std::move(value); }
47 |
48 |     DataType GetDataType() const { return data_type_; }
49 |     void SetDataType(DataType value) { data_type_ = value; }
50 |
51 |     const std::vector& GetDataShape() const { return data_shape_; }
52 |     void SetDataShape(std::vector value) { data_shape_ = std::move(value); }
53 |
54 |     const std::vector& GetStateShape() const { return state_shape_; }
55 |     void SetStateShape(std::vector value) { state_shape_ = std::move(value); }
56 |
57 |     DenseInitializer GetInitializer() const { return initializer_; } // returns a copy of the callable
58 |     void SetInitializer(DenseInitializer value) { initializer_ = std::move(value); }
59 |
60 |     DenseUpdater GetUpdater() const { return updater_; } // returns a copy of the callable
61 |     void SetUpdater(DenseUpdater value) { updater_ = std::move(value); }
62 |
63 |     int GetPartitionCount() const { return partition_count_; } // defaults to -1 (see partition_count_)
64 |     void SetPartitionCount(int value) { partition_count_ = value; }
65 |
66 |     void CheckDenseTensorMeta(int index) const; // called by AllocateDataBlock before ComputePartitionShapes; presumably validates the meta for partition `index`
67 |     // begin/end and the two shape vectors are outputs describing partition `index`.
68 |     void ComputePartitionShapesWithHash(size_t hash, int index, size_t& begin, size_t& end,
69 |                                         std::vector* partition_data_shape,
70 |                                         std::vector* partition_state_shape) const;
71 |     // Same outputs without an explicit hash; this is the overload AllocateDataBlock uses. NOTE(review): presumably derives the hash internally (e.g. GetNameHash) — confirm.
72 |     void ComputePartitionShapes(int index, size_t& begin, size_t& end,
73 |                                 std::vector* partition_data_shape,
74 |                                 std::vector* partition_state_shape) const;
75 |
76 |     size_t GetNameHash() const { return BKDRHash(name_); } // BKDR string hash of the tensor name
77 |     // Serialized form of the initializer/updater; HandlePushMeta copies them between metas via Get*AsData -> Set*ByData.
78 |     void SetInitializerByData(std::string data);
79 |     void SetUpdaterByData(std::string data);
80 |
81 |     std::string GetInitializerAsData() const;
82 |     std::string GetUpdaterAsData() const;
83 |
84 |     std::string ToString() const;
85 |     std::string ToJsonString() const;
86 |     json11::Json to_json() const; // lower-case name, presumably so json11::Json can convert a DenseTensorMeta implicitly
87 |
88 |     static DenseTensorMeta FromJsonString(const std::string& str);
89 |     static DenseTensorMeta FromJson(json11::Json json);
90 |     // Checked by DenseTensorPartition::HandlePushMeta before it replaces the initializer and updater.
91 |     bool IsCompatible(const DenseTensorMeta& rhs) const;
92 |
93 |     bool operator==(const DenseTensorMeta& rhs) const;
94 |     bool operator!=(const DenseTensorMeta& rhs) const { return !(*this == rhs); }
95 |
96 | private:
97 |     std::string name_;
98 |     DataType data_type_ = NullDataType;
99 |     std::vector data_shape_;
100 |     std::vector state_shape_;
101 |     DenseInitializer initializer_;
102 |     DenseUpdater updater_;
103 |     std::any initializer_object_; // NOTE(review): presumably keeps the object behind initializer_ alive — confirm in dense_tensor_meta.cpp
104 |     std::any updater_object_; // NOTE(review): presumably the same for updater_ — confirm
105 |     int partition_count_ = -1;
106 | };
107 |
108 | }
109 |
--------------------------------------------------------------------------------
/cpp/mindalpha/dense_tensor_partition.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 |
24 | namespace mindalpha
25 | {
26 | // Sizes data_ (and state_, when the meta gives this partition a non-empty state shape) for partition GetPartitionIndex(); with init, data_ gets the initializer (or zeros) and state_ is zeroed.
27 | void DenseTensorPartition::AllocateDataBlock(bool init)
28 | {
29 |     size_t begin = 0; // begin/end are filled by ComputePartitionShapes but not used further here
30 |     size_t end = 0;
31 |     const int index = GetPartitionIndex();
32 |     GetMeta().CheckDenseTensorMeta(index);
33 |     GetMeta().ComputePartitionShapes(index, begin, end,
34 |                                      &partition_data_shape_, &partition_state_shape_);
35 |     const size_t item_size = DataTypeToSize(GetMeta().GetDataType()); // bytes per element
36 |     const size_t data_size = item_size * TotalElements(GetPartitionDataShape()); // byte count, used for Reset and memset
37 |     data_.Reset(data_size);
38 |     if (init) // without init the buffers are left as Reset() produced them (presumably to be overwritten by a push or load)
39 |     {
40 |         DenseInitializer initializer = GetMeta().GetInitializer();
41 |         if (!initializer)
42 |             memset(data_.data(), 0, data_size); // no initializer configured: zero-fill
43 |         else
44 |             initializer(GetMeta().GetName(), data_, GetMeta());
45 |     }
46 |     if (!GetPartitionStateShape().empty()) // NOTE(review): with an empty state shape, state_ is not reset and keeps any previous contents
47 |     {
48 |         const size_t state_size = item_size * TotalElements(GetPartitionStateShape());
49 |         state_.Reset(state_size);
50 |         if (init)
51 |             memset(state_.data(), 0, state_size); // state is always zero-initialized; the initializer only sees data_
52 |     }
53 | }
54 | // Applies a push: is_state overwrites state_; is_value (or a meta without updater) overwrites data_; otherwise the updater combines data_, the pushed `in` (as grad) and state_.
55 | void DenseTensorPartition::HandlePush(SmartArray in, bool is_value, bool is_state)
56 | {
57 |     DenseUpdater updater = GetMeta().GetUpdater();
58 |     if (is_state)
59 |         state_.CopyFrom(in);
60 |     else if (!updater || is_value)
61 |         data_.CopyFrom(in);
62 |     else
63 |         updater(GetMeta().GetName(), data_, in, state_, GetMeta());
64 | }
65 | // Returns state_ when is_state is true, otherwise data_ (returned by value; whether that shares or copies the buffer depends on SmartArray, not shown).
66 | SmartArray DenseTensorPartition::HandlePull(bool is_state)
67 | {
68 |     return is_state ? state_ : data_;
69 | }
70 | // Replaces only the initializer and updater, copied through their serialized form; logs and throws std::runtime_error (message includes a stack trace) when `meta` is not compatible with the current meta.
71 | void DenseTensorPartition::HandlePushMeta(const DenseTensorMeta& meta)
72 | {
73 |     if (!meta.IsCompatible(meta_))
74 |     {
75 |         std::string serr;
76 |         serr.append("Incompatible meta detected, can not update initializer and updater");
77 |         serr.append(" of dense tensor '");
78 |         serr.append(GetMeta().GetName());
79 |         serr.append("'.\n\n");
80 |         serr.append(GetStackTrace());
81 |         spdlog::error(serr);
82 |         throw std::runtime_error(serr);
83 |     }
84 |     meta_.SetInitializerByData(meta.GetInitializerAsData());
85 |     meta_.SetUpdaterByData(meta.GetUpdaterAsData());
86 | }
87 | // Returns a reference to this partition's own meta.
88 | const DenseTensorMeta& DenseTensorPartition::HandlePullMeta()
89 | {
90 |     return meta_;
91 | }
92 |
93 | }
94 |
--------------------------------------------------------------------------------
/cpp/mindalpha/dense_tensor_partition.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 |
21 | namespace mindalpha
22 | {
23 | // One partition of a dense tensor: its own DenseTensorMeta, the partition's data/state shapes, index, offset and the data_/state_ buffers (methods defined in dense_tensor_partition.cpp).
24 | class DenseTensorPartition
25 | {
26 | public:
27 |     DenseTensorMeta& GetMeta() { return meta_; }
28 |     const DenseTensorMeta& GetMeta() const { return meta_; }
29 |     void SetMeta(DenseTensorMeta value) { meta_ = std::move(value); }
30 |
31 |     std::vector& GetPartitionDataShape() { return partition_data_shape_; }
32 |     const std::vector& GetPartitionDataShape() const { return partition_data_shape_; }
33 |     void SetPartitionDataShape(std::vector value) { partition_data_shape_ = std::move(value); }
34 |
35 |     std::vector& GetPartitionStateShape() { return partition_state_shape_; }
36 |     const std::vector& GetPartitionStateShape() const { return partition_state_shape_; }
37 |     void SetPartitionStateShape(std::vector value) { partition_state_shape_ = std::move(value); }
38 |
39 |     int GetPartitionIndex() const { return partition_index_; }
40 |     void SetPartitionIndex(int value) { partition_index_ = value; }
41 |
42 |     size_t GetOffset() const { return offset_; }
43 |     void SetOffset(size_t value) { offset_ = value; }
44 |
45 |     void AllocateDataBlock(bool init); // sizes data_/state_ from the meta; init also fills them
46 |     void HandlePush(SmartArray in, bool is_value, bool is_state); // overwrite state_/data_ or apply the updater
47 |     SmartArray HandlePull(bool is_state); // state_ if is_state, else data_
48 |     void HandlePushMeta(const DenseTensorMeta& meta); // replaces initializer/updater; throws std::runtime_error if incompatible
49 |     const DenseTensorMeta& HandlePullMeta();
50 |
51 | private:
52 |     DenseTensorMeta meta_;
53 |     std::vector partition_data_shape_; // filled by ComputePartitionShapes in AllocateDataBlock
54 |     std::vector partition_state_shape_;
55 |     int partition_index_ = -1; // defaults to -1
56 |     size_t offset_ = size_t(-1); // defaults to the maximum size_t value
57 |     SmartArray data_;
58 |     SmartArray state_;
59 | };
60 |
61 | }
62 |
--------------------------------------------------------------------------------
/cpp/mindalpha/feature_extraction_python_bindings.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file
except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 |
21 | namespace mindalpha
22 | {
23 | // Presumably registers the feature-extraction Python bindings on pybind11 module `m`; the definition is not shown here.
24 | void DefineFeatureExtractionBindings(pybind11::module& m);
25 |
26 | }
27 |
--------------------------------------------------------------------------------
/cpp/mindalpha/file_utils.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // 16 | 17 | #pragma once 18 | 19 | #include "mindalpha/smart_array.h" 20 | #include "mindalpha/io.h" 21 | #include "mindalpha/debug.h" 22 | #include "mindalpha/logging.h" 23 | 24 | namespace mindalpha 25 | { 26 | 27 | struct FileHeader { 28 | uint32_t magic; 29 | uint32_t patch; 30 | uint64_t size; 31 | static const uint32_t kMagicNum = 0xffffeeee; 32 | bool Check() { 33 | return magic == kMagicNum && 34 | size > 0; 35 | } 36 | inline uint64_t Size() { 37 | return size; 38 | } 39 | }; 40 | 41 | template 42 | int LoadAsSArray(const std::string& path, SmartArray* array) { 43 | if (array->empty()) { 44 | // ignore empty range on this server 45 | LOG(INFO) << "Ignoring empty range for " << path; 46 | return 0; 47 | } 48 | std::unique_ptr stream(Stream::Create(path.c_str(), "r", true)); 49 | if (!stream) { 50 | return -1; 51 | } 52 | const size_t nread = stream->Read(array->data(), array->size()); 53 | if (nread != array->size()) 54 | return -1; 55 | return 0; 56 | } 57 | 58 | template 59 | int SaveAsSArray(const std::string&path, const SmartArray& array) { 60 | if (array.empty()) { 61 | // ignore empty range on this server 62 | LOG(INFO) << "Ignoring empty range for " << path; 63 | return 0; 64 | } 65 | std::unique_ptr stream(Stream::Create(path.c_str(), "w", true)); 66 | if (!stream) { 67 | return -1; 68 | } 69 | stream->Write(array.data(), array.size()); 70 | return 0; 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /cpp/mindalpha/filesys.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | 20 | namespace mindalpha 21 | { 22 | 23 | void FileSystem::ListDirectoryRecursive(const URI &path, std::vector *out_list) { 24 | std::queue queue; 25 | queue.push(path); 26 | while (!queue.empty()) { 27 | std::vector dfiles; 28 | ListDirectory(queue.front(), &dfiles); 29 | queue.pop(); 30 | for (auto dfile : dfiles) { 31 | if (dfile.type == kDirectory) { 32 | queue.push(dfile.path); 33 | } 34 | else { 35 | out_list->push_back(dfile); 36 | } 37 | } 38 | } 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /cpp/mindalpha/hash_uniquifier.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace mindalpha 22 | { 23 | 24 | std::vector HashUniquifier::Uniquify(uint64_t* items, size_t count) 25 | { 26 | const uint64_t capacity = GetHashCapacity(count); 27 | const uint64_t size = capacity * 2 / 3; 28 | std::vector entries; 29 | std::vector buckets; 30 | entries.reserve(size); 31 | buckets.assign(capacity, -1); 32 | for (size_t i = 0; i < count; i++) 33 | { 34 | uint64_t offset; 35 | const uint64_t key = items[i]; 36 | InsertHashEntry(key, offset, entries, buckets); 37 | items[i] = offset; 38 | } 39 | return std::move(entries); 40 | } 41 | 42 | std::vector HashUniquifier::Uniquify(std::vector& items) 43 | { 44 | return Uniquify(items.data(), items.size()); 45 | } 46 | 47 | int32_t HashUniquifier::FindEntryAndBucket(uint64_t key, uint64_t hashCode, 48 | const std::vector& entries, 49 | const std::vector& buckets, 50 | uint64_t& bucket) 51 | { 52 | const uint64_t mask = buckets.size() - 1; 53 | uint64_t perturb = hashCode; 54 | bucket = hashCode & mask; 55 | for (;;) 56 | { 57 | const int32_t i = buckets.at(bucket); 58 | if (i == -1) 59 | return -1; 60 | if (i >= 0 && entries.at(i) == key) 61 | return i; 62 | perturb >>= 5; 63 | bucket = (bucket * 5 + 1 + perturb) & mask; 64 | } 65 | } 66 | 67 | uint64_t HashUniquifier::GetHashCapacity(uint64_t minSize) 68 | { 69 | const uint64_t cap = (3 * minSize + 1) / 2; 70 | return HashtableHelpers::GetPowerBucketCount(cap); 71 | } 72 | 73 | bool HashUniquifier::InsertHashEntry(uint64_t key, uint64_t& offset, 74 | std::vector& entries, 75 | std::vector& buckets) 76 | { 77 | uint64_t bucket; 78 | const uint64_t hashCode = HashtableHelpers::FastModulo(key); 79 | const int32_t n = FindEntryAndBucket(key, hashCode, entries, buckets, bucket); 80 | if (n == -1) 81 | { 82 | offset = static_cast(entries.size()); 83 | buckets.at(bucket) = static_cast(offset); 84 | entries.push_back(key); 85 | return true; 86 | } 87 | offset = static_cast(n); 88 | return false; 
89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /cpp/mindalpha/hash_uniquifier.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | // 24 | // ``hash_uniquifier.h`` defines class ``HashUniquifier``. Its static 25 | // method ``Uniquify`` can be used to uniquify feature hash codes produced 26 | // by ``EmbeddingOperator`` so that communication can be reduced. 
27 | //
28 |
29 | namespace mindalpha
30 | {
31 | // Static-only helper; the open-addressing implementation is in hash_uniquifier.cpp.
32 | class HashUniquifier
33 | {
34 | public:
35 |     static std::vector Uniquify(uint64_t* items, size_t count); // rewrites items in place with indices into the returned distinct-key vector
36 |     static std::vector Uniquify(std::vector& items); // same, for a whole vector
37 |
38 | private:
39 |     static int32_t FindEntryAndBucket(uint64_t key, uint64_t hashCode, // entry index of key, or -1 with `bucket` set to a free slot
40 |                                       const std::vector& entries,
41 |                                       const std::vector& buckets,
42 |                                       uint64_t& bucket);
43 |
44 |     static uint64_t GetHashCapacity(uint64_t minSize); // power-of-two bucket count for minSize keys
45 |
46 |     static bool InsertHashEntry(uint64_t key, uint64_t& offset, // true if key was newly added; offset = its entry index
47 |                                 std::vector& entries,
48 |                                 std::vector& buckets);
49 | };
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/cpp/mindalpha/hashtable_helpers.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 |
21 | //
22 | // ``hashtable_helpers.h`` contains helper functions useful for
23 | // implementing hashtables.
24 | //
25 |
26 | namespace mindalpha
27 | {
28 | // Stateless constexpr helpers for hash-table sizing and hashing.
29 | class HashtableHelpers
30 | {
31 | public:
32 |     static constexpr uint64_t GetPowerBucketCount(uint64_t v) // smallest power of two >= v for 1 <= v <= 2^63; yields 0 for v == 0 or v > 2^63 (wrap-around)
33 |     {
34 |         v--;
35 |         v |= v >> 1; // smear the highest set bit into every lower bit
36 |         v |= v >> 2;
37 |         v |= v >> 4;
38 |         v |= v >> 8;
39 |         v |= v >> 16;
40 |         v |= v >> 32;
41 |         v++;
42 |         return v;
43 |     }
44 |     // One folding step of reduction modulo the Mersenne prime 2^31 - 1 (since 2^31 = 1 mod 2^31 - 1). The result is congruent to `a` but is NOT always below the prime for large `a` (e.g. a = 2^63 gives 2^31 + 1), so treat it as a hash rather than an exact modulo. NOTE(review): changing it would move keys to different buckets in any table built with it — check other users (e.g. persisted hash maps, not shown) before touching.
45 |     static constexpr uint64_t FastModulo(uint64_t a)
46 |     {
47 |         constexpr int q = 31;
48 |         constexpr uint64_t prime = (UINT64_C(1) << q) - 1;
49 |         uint64_t r = (a & prime) + (a >> q);
50 |         if (r >= prime)
51 |             r -= prime;
52 |         return r;
53 |     }
54 | };
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/cpp/mindalpha/index_batch.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 | #include
21 | #include
22 |
23 | namespace mindalpha
24 | {
25 | // A batch of columns received from Python as a pybind11::list, each presumably split on `delimiters` into string-view cells. Hidden visibility keeps the type (which holds pybind11 members) out of the exported symbols.
26 | class __attribute__((visibility("hidden"))) IndexBatch
27 | {
28 | public:
29 |     IndexBatch(pybind11::list columns, const std::string& delimiters);
30 |
31 |     const StringViewHashVector& GetCell(size_t i, size_t j, const std::string& column_name) const; // NOTE(review): (i, j) presumably row/column; column_name presumably only for error messages — confirm
32 |
33 |     pybind11::list ToList() const;
34 |
35 |     size_t GetRows() const { return rows_; }
36 |     size_t GetColumns() const { return split_columns_.size(); } // one entry of split_columns_ per column
37 |
38 |     std::string ToString() const;
39 |
40 | private:
41 |     struct __attribute__((visibility("hidden"))) string_view_cell
42 |     {
43 |         StringViewHashVector items_;
44 |         pybind11::object obj_; // NOTE(review): presumably keeps the Python string backing items_' views alive — confirm
45 |     };
46 |
47 |     using StringViewColumn = std::vector;
48 |
49 |     static StringViewColumn SplitColumn(const pybind11::array& column, std::string_view delims);
50 |
51 |     std::vector split_columns_;
52 |     size_t rows_; // no default member initializer; presumably set by the constructor
53 | };
54 |
55 | }
56 |
--------------------------------------------------------------------------------
/cpp/mindalpha/io.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 | #include
21 | #include
22 |
23 | namespace mindalpha
24 | {
25 | // Abstract byte stream; concrete streams are obtained from the Create factory below.
26 | class Stream { // NOLINT(*)
27 | public:
28 |     /*!
29 |      * \brief reads data from a stream
30 |      * \param ptr pointer to a memory buffer
31 |      * \param size maximum number of bytes to read
32 |      * \return the number of bytes actually read (may be less than size)
33 |      */
34 |     virtual size_t Read(void *ptr, size_t size) = 0;
35 |     /*!
36 |      * \brief writes data to a stream
37 |      * \param ptr pointer to a memory buffer
38 |      * \param size number of bytes to write (no status is returned)
39 |      */
40 |     virtual void Write(const void *ptr, size_t size) = 0;
41 |     /*! \brief virtual destructor */
42 |     virtual ~Stream(void) {
43 |     }
44 |     /*!
45 |      * \brief generic factory function
46 |      * create a stream; the stream will close the underlying files upon deletion
47 |      *
48 |      * \param uri the uri of the input; currently we support
49 |      * hdfs://, s3://, and file://; file:// is used by default (NOTE(review): no HDFS file system appears in this repository's file list — confirm hdfs:// support)
50 |      * \param flag can be "w", "r", "a"
51 |      * \param allow_null whether NULL can be returned, or directly report error
52 |      * \return the created stream, can be NULL when allow_null == true and the file does not exist
53 |      */
54 |     static Stream *Create(const char *uri, const char *const flag, bool allow_null = false); // caller owns the result (file_utils.h wraps it in std::unique_ptr)
55 | };
56 |
57 | /*! \brief interface of an i/o stream that supports seeking */
58 | class SeekStream : public Stream {
59 | public:
60 |     // virtual destructor
61 |     virtual ~SeekStream(void) {
62 |     }
63 |     /*! \brief seek to a certain position of the file */
64 |     virtual void Seek(size_t pos) = 0;
65 |     /*! \brief tell the current position of the stream */
66 |     virtual size_t Tell(void) = 0;
67 |     /*!
68 |      * \brief generic factory function
69 |      * create a SeekStream for read only,
70 |      * the stream will close the underlying files upon deletion
71 |      * unless allow_null is true, an error will be reported and the system will exit when creation fails
72 |      * \param uri the uri of the input; currently we support
73 |      * hdfs://, s3://, and file://; file:// is used by default
74 |      * \param allow_null whether NULL can be returned, or directly report error
75 |      * \return the created stream, can be NULL when allow_null == true and the file does not exist
76 |      */
77 |     static SeekStream *CreateForRead(const char *uri, bool allow_null = false);
78 | };
79 | // Owning read wrapper around a Stream for `url` (presumably opened in the constructor; defined in io.cpp, not shown).
80 | class InputStream
81 | {
82 | public:
83 |     explicit InputStream(const std::string& url);
84 |     size_t Read(void* buffer, size_t size);
85 |
86 | private:
87 |     std::unique_ptr stream_;
88 | };
89 | // Owning write wrapper around a Stream for `url` (presumably opened in the constructor; defined in io.cpp, not shown).
90 | class OutputStream
91 | {
92 | public:
93 |     explicit OutputStream(const std::string& url);
94 |     void Write(const void* buffer, size_t size);
95 |
96 | private:
97 |     std::unique_ptr stream_;
98 | };
99 | // Whole-content read/write helpers addressed by URL (definitions not shown).
100 | void StreamWriteAll(const std::string& url, const char* data, size_t size);
101 | void StreamWriteAll(const std::string& url, const std::string& data);
102 | void StreamReadAll(const std::string&url, char* data, size_t size);
103 | std::string StreamReadAll(const std::string& url);
104 | // Local-directory and path-string helpers (definitions not shown).
105 | void MakeLocalDirectories(const std::string& path, mode_t mode);
106 | void EnsureLocalDirectory(const std::string& dir_path);
107 | std::string DirName(const std::string& path);
108 | std::string JoinPath(const std::string& dir_path, const std::string& file_name);
109 |
110 | }
111 |
--------------------------------------------------------------------------------
/cpp/mindalpha/local_filesys.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | //
16 |
17 | #pragma once
18 |
19 | #include
20 | #include
21 |
22 | namespace mindalpha
23 | {
24 |
25 | /*! \brief local file system */
26 | class LocalFileSystem : public FileSystem {
27 | public:
28 |     /*! \brief destructor */
29 |     virtual ~LocalFileSystem() {
30 |     }
31 |     /*!
32 |      * \brief get information about a path
33 |      * \param path the path to the file
34 |      * \return the information about the file
35 |      */
36 |     virtual FileInfo GetPathInfo(const URI &path);
37 |     /*!
38 |      * \brief list files in a directory
39 |      * \param path the directory to list
40 |      * \param out_list the output information about the files
41 |      */
42 |     virtual void ListDirectory(const URI &path, std::vector *out_list);
43 |     /*!
44 |      * \brief open a stream, will report error and exit if something goes wrong
45 |      * NOTE: the returned stream can continue to work even after the file system object is destroyed
46 |      * \param path path to file
47 |      * \param flag open mode flag (Stream::Create documents "w", "r", "a")
48 |      * \param allow_null whether NULL can be returned, or directly report error
49 |      * \return the created stream, can be NULL when allow_null == true and the file does not exist
50 |      */
51 |     virtual SeekStream *Open(const URI &path, const char *const flag, bool allow_null);
52 |     /*!
53 |      * \brief open a seekable stream for read
54 |      * \param path the path to the file
55 |      * \param allow_null whether NULL can be returned, or directly report error
56 |      * \return the created stream, can be NULL when allow_null == true and the file does not exist
57 |      */
58 |     virtual SeekStream *OpenForRead(const URI &path, bool allow_null);
59 |     /*!
60 |      * \brief get a singleton of LocalFileSystem when needed
61 |      * \return a singleton instance (a function-local static, so its initialization is thread-safe since C++11)
62 |      */
63 |     inline static LocalFileSystem *GetInstance(void) {
64 |         static LocalFileSystem instance;
65 |         return &instance;
66 |     }
67 |     LocalFileSystem() { // public: instances other than the GetInstance() singleton can also be created
68 |     }
69 |
70 | private:
71 | };
72 |
73 | }
74 |
--------------------------------------------------------------------------------
/cpp/mindalpha/logging.h:
--------------------------------------------------------------------------------
1 | //
2 | // Copyright 2021 Mobvista
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // 16 | 17 | #pragma once 18 | 19 | #include "spdlog/spdlog.h" 20 | #include 21 | #include 22 | #include 23 | 24 | namespace mindalpha 25 | { 26 | 27 | class LogMessage { 28 | public: 29 | LogMessage(const char *file, int line, bool is_fatal, spdlog::level::level_enum log_level) 30 | : is_fatal(is_fatal) 31 | , log_level(log_level) { 32 | buf.reserve(1024); 33 | using namespace std::string_view_literals; 34 | fmt::format_to(buf, "{}:{} -- "sv, file, line); 35 | } 36 | ~LogMessage() { 37 | spdlog::log(log_level, "{}", str()); 38 | if (is_fatal) { 39 | abort(); 40 | } 41 | } 42 | template 43 | inline LogMessage &operator<<(const T &val) { 44 | using namespace std::string_view_literals; 45 | fmt::format_to(buf, "{}"sv, val); 46 | return *this; 47 | } 48 | // for std::endl 49 | inline LogMessage &operator<<(std::ostream &(*f)(std::ostream &)) { 50 | buf.push_back('\n'); 51 | return *this; 52 | } 53 | inline std::string_view str() const { 54 | return std::string_view(buf.data(), buf.size()); 55 | } 56 | inline LogMessage &stream() { 57 | return *this; 58 | } 59 | 60 | private: 61 | fmt::memory_buffer buf; 62 | bool is_fatal = false; 63 | spdlog::level::level_enum log_level = spdlog::level::info; 64 | }; 65 | // Always-on checking 66 | #define CHECK(x) \ 67 | if (!(x)) \ 68 | LogMessage(__FILE__, __LINE__, true, spdlog::level::err).stream() << "Check " \ 69 | "failed: " #x \ 70 | << ' ' 71 | #define CHECK_LT(x, y) CHECK((x) < (y)) 72 | #define CHECK_GT(x, y) CHECK((x) > (y)) 73 | #define CHECK_LE(x, y) CHECK((x) <= (y)) 74 | #define CHECK_GE(x, y) CHECK((x) >= (y)) 75 | #define CHECK_EQ(x, y) CHECK((x) == (y)) 76 | #define CHECK_NE(x, y) CHECK((x) != (y)) 77 | #define CHECK_NOTNULL(x) \ 78 | ((x) == NULL ? 
LogMessage(__FILE__, __LINE__, true, spdlog::level::err).stream() \ 79 | << "notnull: " #x << ' ', (x) \ 80 | : (x)) // NOLINT(*) 81 | 82 | #define LOG_INFO LogMessage(__FILE__, __LINE__, false, spdlog::level::info) 83 | #define LOG_ERROR LogMessage(__FILE__, __LINE__, false, spdlog::level::err) 84 | #define LOG_WARNING LogMessage(__FILE__, __LINE__, false, spdlog::level::warn) 85 | #define LOG_FATAL LogMessage(__FILE__, __LINE__, true, spdlog::level::critical) 86 | #define LOG_QFATAL LOG_FATAL 87 | 88 | // Poor man version of VLOG 89 | #define VLOG(x) LOG_INFO.stream() 90 | 91 | #define LOG(severity) LOG_##severity.stream() 92 | #define LG LOG_INFO.stream() 93 | #define LOG_IF(severity, condition) !(condition) ? (void)0 : LOG(severity) 94 | 95 | } 96 | -------------------------------------------------------------------------------- /cpp/mindalpha/map_file_header.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | // 24 | // ``map_file_header.h`` defines the binary structure of ArrayHashMap 25 | // file header. For simplicity and efficiency, we assume little endian 26 | // and do not consider portability. 
27 | // 28 | 29 | namespace mindalpha 30 | { 31 | 32 | const uint64_t map_file_signature_size = 32; 33 | 34 | struct MapFileHeader 35 | { 36 | char signature[map_file_signature_size]; 37 | uint64_t version; 38 | uint64_t reserved_; // for backward compatibility 39 | uint64_t key_type; 40 | uint64_t value_type; 41 | uint64_t key_count; 42 | uint64_t bucket_count; 43 | uint64_t value_count; 44 | uint64_t value_count_per_key; 45 | 46 | void FillBasicFields(); 47 | bool IsSignatureValid() const; 48 | void Validate(const std::string& hint) const; 49 | }; 50 | 51 | const uint64_t map_file_header_size = sizeof(MapFileHeader); 52 | 53 | extern const char map_file_signature[map_file_signature_size]; 54 | extern const uint64_t map_file_version; 55 | 56 | } 57 | -------------------------------------------------------------------------------- /cpp/mindalpha/memory_buffer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | // 25 | // ``memory_buffer.h`` defines class ``MemoryBuffer`` which simplifies 26 | // the calling of C memory management functions in ``ArrayHashMap``. 
27 | // 28 | 29 | namespace mindalpha 30 | { 31 | 32 | class MemoryBuffer 33 | { 34 | public: 35 | MemoryBuffer() 36 | { 37 | ptr_ = nullptr; 38 | size_ = 0; 39 | } 40 | 41 | explicit MemoryBuffer(uint64_t size) 42 | { 43 | if (size == 0) 44 | { 45 | ptr_ = nullptr; 46 | size_ = 0; 47 | } 48 | else 49 | { 50 | ptr_ = malloc(size); 51 | if (!ptr_) 52 | throw std::bad_alloc(); 53 | size_ = size; 54 | } 55 | } 56 | 57 | MemoryBuffer(MemoryBuffer&& rhs) 58 | { 59 | ptr_ = rhs.ptr_; 60 | size_ = rhs.size_; 61 | rhs.ptr_ = nullptr; 62 | rhs.size_ = 0; 63 | } 64 | 65 | ~MemoryBuffer() 66 | { 67 | if (ptr_) 68 | { 69 | free(ptr_); 70 | ptr_ = nullptr; 71 | } 72 | size_ = 0; 73 | } 74 | 75 | void Swap(MemoryBuffer& other) 76 | { 77 | std::swap(ptr_, other.ptr_); 78 | std::swap(size_, other.size_); 79 | } 80 | 81 | void* GetPointer() const 82 | { 83 | return ptr_; 84 | } 85 | 86 | uint64_t GetSize() const 87 | { 88 | return size_; 89 | } 90 | 91 | void Deallocate() 92 | { 93 | MemoryBuffer buf; 94 | Swap(buf); 95 | } 96 | 97 | void Reallocate(uint64_t size) 98 | { 99 | if (size == 0) 100 | Deallocate(); 101 | else 102 | { 103 | void* new_ptr = realloc(ptr_, size); 104 | if (!new_ptr) 105 | throw std::bad_alloc(); 106 | ptr_ = new_ptr; 107 | size_ = size; 108 | } 109 | } 110 | 111 | private: 112 | void* ptr_; 113 | uint64_t size_; 114 | }; 115 | 116 | } 117 | -------------------------------------------------------------------------------- /cpp/mindalpha/message.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | void Message::AddTypedSlice(SmartArray slice, DataType dataType) 26 | { 27 | AddSlice(std::move(slice)); 28 | message_meta_.AddSliceDataType(dataType); 29 | } 30 | 31 | SmartArray Message::GetSlice(size_t i) const 32 | { 33 | if (i >= slices_.size()) 34 | { 35 | std::string serr; 36 | serr.append("GetSlice failed as slice index is out of range. i: "); 37 | serr.append(std::to_string(i)); 38 | serr.append(", slices_.size(): "); 39 | serr.append(std::to_string(slices_.size())); 40 | serr.append(".\n\n"); 41 | serr.append(GetStackTrace()); 42 | spdlog::error(serr); 43 | throw std::runtime_error(serr); 44 | } 45 | return slices_.at(i); 46 | } 47 | 48 | SmartArray Message::GetTypedSlice(size_t i, DataType dataType) const 49 | { 50 | const std::vector& sliceDataTypes = message_meta_.GetSliceDataTypes(); 51 | if (i >= sliceDataTypes.size()) 52 | { 53 | std::string serr; 54 | serr.append("GetTypedSlice failed as slice index is out of range. i: "); 55 | serr.append(std::to_string(i)); 56 | serr.append(", sliceDataTypes.size(): "); 57 | serr.append(std::to_string(sliceDataTypes.size())); 58 | serr.append(".\n\n"); 59 | serr.append(GetStackTrace()); 60 | spdlog::error(serr); 61 | throw std::runtime_error(serr); 62 | } 63 | if (i >= slices_.size()) 64 | { 65 | std::string serr; 66 | serr.append("GetTypedSlice failed as slice index is out of range. 
i: "); 67 | serr.append(std::to_string(i)); 68 | serr.append(", slices_.size(): "); 69 | serr.append(std::to_string(slices_.size())); 70 | serr.append(".\n\n"); 71 | serr.append(GetStackTrace()); 72 | spdlog::error(serr); 73 | throw std::runtime_error(serr); 74 | } 75 | if (dataType != sliceDataTypes.at(i)) 76 | { 77 | std::string serr; 78 | serr.append("GetTypedSlice failed as data types mismatch. i: "); 79 | serr.append(std::to_string(i)); 80 | serr.append(", dataType: "); 81 | serr.append(DataTypeToString(dataType)); 82 | serr.append(", sliceDataTypes.at(i): "); 83 | serr.append(DataTypeToString(sliceDataTypes.at(i))); 84 | serr.append(".\n\n"); 85 | serr.append(GetStackTrace()); 86 | spdlog::error(serr); 87 | throw std::runtime_error(serr); 88 | } 89 | return slices_.at(i); 90 | } 91 | 92 | std::string Message::ToString() const 93 | { 94 | return message_meta_.ToString(); 95 | } 96 | 97 | std::string Message::ToJsonString() const 98 | { 99 | return to_json().dump(); 100 | } 101 | 102 | json11::Json Message::to_json() const 103 | { 104 | std::vector slices; 105 | slices.reserve(slices.size()); 106 | for (auto&& slice: slices_) 107 | slices.push_back(slice.to_json()); 108 | return json11::Json::object 109 | { 110 | { "message_meta", message_meta_ }, 111 | { "slices", std::move(slices) }, 112 | }; 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /cpp/mindalpha/message.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | // 23 | // ``message.h`` defines class ``Message`` which represents 24 | // messages sent between Parameter Server nodes. 25 | // 26 | // ``Message`` consists of a metadata part and zero or more 27 | // typed data slices. 28 | // 29 | 30 | namespace mindalpha 31 | { 32 | 33 | class Message 34 | { 35 | public: 36 | MessageMeta& GetMessageMeta() { return message_meta_; } 37 | const MessageMeta& GetMessageMeta() const { return message_meta_; } 38 | void SetMessageMeta(MessageMeta value) { message_meta_ = std::move(value); } 39 | 40 | const std::vector>& GetSlices() const { return slices_; } 41 | void SetSlices(std::vector> value) { slices_ = std::move(value); } 42 | void ClearSlices() { slices_.clear(); } 43 | void ClearSlicesAndDataTypes() { ClearSlices(); message_meta_.ClearSliceDataTypes(); } 44 | void AddSlice(SmartArray value) { slices_.push_back(std::move(value)); } 45 | 46 | void AddTypedSlice(SmartArray slice, DataType dataType); 47 | 48 | template 49 | void AddTypedSlice(SmartArray slice) 50 | { 51 | auto sa = slice.template Cast(); 52 | AddTypedSlice(std::move(sa), DataTypeToCode::value); 53 | } 54 | 55 | SmartArray GetSlice(size_t i) const; 56 | SmartArray GetTypedSlice(size_t i, DataType dataType) const; 57 | 58 | template 59 | SmartArray GetTypedSlice(size_t i) const 60 | { 61 | SmartArray slice = GetTypedSlice(i, DataTypeToCode::value); 62 | return slice.Cast(); 63 | } 64 | 65 | std::shared_ptr Copy() const { return 
std::make_shared(*this); } 66 | 67 | std::string ToString() const; 68 | std::string ToJsonString() const; 69 | json11::Json to_json() const; 70 | 71 | private: 72 | MessageMeta message_meta_; 73 | std::vector> slices_; 74 | }; 75 | 76 | } 77 | -------------------------------------------------------------------------------- /cpp/mindalpha/message_meta.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | // 25 | // ``message_meta.h`` defines class ``MessageMeta`` which stores 26 | // metadata of messages sent between Parameter Server nodes. 
27 | // 28 | 29 | namespace mindalpha 30 | { 31 | 32 | class MessageMeta 33 | { 34 | public: 35 | int GetMessageId() const { return message_id_; } 36 | void SetMessageId(int value) { message_id_ = value; } 37 | 38 | int GetSender() const { return sender_; } 39 | void SetSender(int value) { sender_ = value; } 40 | 41 | int GetReceiver() const { return receiver_; } 42 | void SetReceiver(int value) { receiver_ = value; } 43 | 44 | bool IsRequest() const { return is_request_; } 45 | void SetIsRequest(bool value) { is_request_ = value; } 46 | 47 | bool IsException() const { return is_exception_; } 48 | void SetIsException(bool value) { is_exception_ = value; } 49 | 50 | const std::string& GetBody() const { return body_; } 51 | void SetBody(std::string value) { body_ = std::move(value); } 52 | 53 | const std::vector& GetSliceDataTypes() const { return slice_data_types_; } 54 | void SetSliceDataTypes(std::vector value) { slice_data_types_ = std::move(value); } 55 | void ClearSliceDataTypes() { slice_data_types_.clear(); } 56 | void AddSliceDataType(DataType value) { slice_data_types_.push_back(value); } 57 | 58 | NodeControl& GetNodeControl() { return node_control_; } 59 | const NodeControl& GetNodeControl() const { return node_control_ ; } 60 | void SetNodeControl(NodeControl value) { node_control_ = std::move(value); } 61 | 62 | std::string ToString() const; 63 | std::string ToJsonString() const; 64 | json11::Json to_json() const; 65 | 66 | TMessageMeta PackAsThriftObject() const; 67 | std::string PackAsThriftJson() const; 68 | SmartArray PackAsThriftBuffer() const; 69 | 70 | void UnpackFromThriftObject(TMessageMeta&& meta); 71 | void UnpackFromThriftJson(const std::string& str); 72 | void UnpackFromThriftBuffer(const uint8_t* ptr, size_t size); 73 | void UnpackFromThriftBuffer(const SmartArray& buf); 74 | void UnpackFromThriftBuffer(const std::string_view& buf); 75 | 76 | private: 77 | int message_id_ = -1; 78 | int sender_ = -1; 79 | int receiver_ = -1; 80 | bool 
is_request_ = false; 81 | bool is_exception_ = false; 82 | std::string body_; 83 | std::vector slice_data_types_; 84 | NodeControl node_control_; 85 | }; 86 | 87 | } 88 | -------------------------------------------------------------------------------- /cpp/mindalpha/message_transport.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | namespace mindalpha 24 | { 25 | 26 | MessageTransport::MessageTransport(std::shared_ptr config) 27 | : config_(std::move(config)) 28 | { 29 | } 30 | 31 | std::unique_ptr MessageTransport::Create(std::shared_ptr config) 32 | { 33 | const std::string& type = config->GetTransportType(); 34 | if (type == "ZeroMQ") 35 | return std::make_unique(std::move(config)); 36 | else 37 | { 38 | std::string serr; 39 | serr.append("MessageTransport type '"); 40 | serr.append(type); 41 | serr.append("' is not supported.\n\n"); 42 | serr.append(GetStackTrace()); 43 | spdlog::error(serr); 44 | throw std::runtime_error(serr); 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /cpp/mindalpha/message_transport.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | class MessageTransport 29 | { 30 | public: 31 | explicit MessageTransport(std::shared_ptr config); 32 | virtual ~MessageTransport() { } 33 | 34 | std::shared_ptr GetConfig() const { return config_; } 35 | void SetConfig(std::shared_ptr value) { config_ = std::move(value); } 36 | 37 | virtual void Start() = 0; 38 | virtual void Stop() = 0; 39 | virtual int Bind(const NodeInfo& node, int maxRetry) = 0; 40 | virtual void Connect(const NodeInfo& node) = 0; 41 | virtual int64_t SendMessage(const Message& msg) = 0; 42 | virtual int64_t ReceiveMessage(Message& msg) = 0; 43 | 44 | static std::unique_ptr Create(std::shared_ptr config); 45 | 46 | private: 47 | std::shared_ptr config_; 48 | }; 49 | 50 | } 51 | -------------------------------------------------------------------------------- /cpp/mindalpha/ml_ps_python_bindings.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | namespace mindalpha 22 | { 23 | 24 | template 25 | class PyPSAgent : public PSAgentBase 26 | { 27 | public: 28 | using PSAgentBase::PSAgentBase; 29 | 30 | void Run() override 31 | { 32 | PYBIND11_OVERLOAD_NAME( 33 | void, 34 | PSAgentBase, 35 | "run", 36 | Run, 37 | ); 38 | } 39 | 40 | void HandleRequest(mindalpha::PSMessage req) override 41 | { 42 | PYBIND11_OVERLOAD_NAME( 43 | void, 44 | PSAgentBase, 45 | "handle_request", 46 | HandleRequest, 47 | req 48 | ); 49 | } 50 | 51 | void Finalize() override 52 | { 53 | PYBIND11_OVERLOAD_NAME( 54 | void, 55 | PSAgentBase, 56 | "finalize", 57 | Finalize, 58 | ); 59 | } 60 | }; 61 | 62 | } 63 | -------------------------------------------------------------------------------- /cpp/mindalpha/model_metric_buffer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | 19 | namespace mindalpha 20 | { 21 | 22 | void ModelMetricBuffer::UpdateBuffer(pybind11::array_t positive_buffer, 23 | pybind11::array_t negative_buffer, 24 | pybind11::array_t predictions, 25 | pybind11::array_t labels) 26 | { 27 | const size_t buffer_size = positive_buffer.size(); 28 | const size_t instance_count = labels.size(); 29 | double* const pos_buf = positive_buffer.mutable_data(); 30 | double* const neg_buf = negative_buffer.mutable_data(); 31 | const float* const preds = predictions.data(); 32 | const float* const labs = labels.data(); 33 | for (size_t i = 0; i < instance_count; i++) 34 | { 35 | const float pred = preds[i]; 36 | const float lab = labs[i]; 37 | const int64_t bucket = static_cast(pred * (buffer_size - 1)); 38 | if (lab > 0.0) 39 | pos_buf[bucket] += lab; 40 | else if (lab < 0.0) 41 | neg_buf[bucket] += -lab; 42 | else 43 | neg_buf[bucket] += 1.0; 44 | } 45 | } 46 | 47 | double ModelMetricBuffer::ComputeAUC(pybind11::array_t positive_buffer, 48 | pybind11::array_t negative_buffer) 49 | { 50 | const size_t buffer_size = positive_buffer.size(); 51 | const double* const pos_buf = positive_buffer.mutable_data(); 52 | const double* const neg_buf = negative_buffer.mutable_data(); 53 | double auc = 0.0; 54 | double prev_pos_sum = 0; 55 | double pos_sum = 0; 56 | double neg_sum = 0; 57 | for (size_t i = 0; i < buffer_size; i++) { 58 | const double pos = pos_buf[i]; 59 | const double neg = neg_buf[i]; 60 | prev_pos_sum = pos_sum; 61 | pos_sum += pos; 62 | neg_sum += neg; 63 | auc += 0.5 * (prev_pos_sum + pos_sum) * neg; 64 | } 65 | if (pos_sum == 0) 66 | auc = (neg_sum > 0) ? 
0.0 : 1.0; 67 | else if (neg_sum == 0) 68 | auc = 1.0; 69 | else 70 | auc = 1.0 - auc / pos_sum / neg_sum; 71 | return auc; 72 | } 73 | 74 | } 75 | -------------------------------------------------------------------------------- /cpp/mindalpha/model_metric_buffer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | class ModelMetricBuffer 26 | { 27 | public: 28 | static void UpdateBuffer(pybind11::array_t positive_buffer, 29 | pybind11::array_t negative_buffer, 30 | pybind11::array_t predictions, 31 | pybind11::array_t labels); 32 | 33 | static double ComputeAUC(pybind11::array_t positive_buffer, 34 | pybind11::array_t negative_buffer); 35 | }; 36 | 37 | } 38 | -------------------------------------------------------------------------------- /cpp/mindalpha/network_utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace mindalpha::network_utils 28 | { 29 | 30 | std::string get_ip(const std::string& interface) 31 | { 32 | struct ifaddrs* ifas = nullptr; 33 | getifaddrs(&ifas); 34 | std::unique_ptr ifas_guard(ifas, &freeifaddrs); 35 | for (auto ifa = ifas; ifa != nullptr; ifa = ifa->ifa_next) 36 | { 37 | if (ifa->ifa_addr == nullptr) 38 | continue; 39 | if (ifa->ifa_addr->sa_family != AF_INET) 40 | continue; 41 | if (interface != ifa->ifa_name) 42 | continue; 43 | auto addr = reinterpret_cast(ifa->ifa_addr); 44 | void* temp_addr_ptr = &addr->sin_addr; 45 | char address_buffer[INET_ADDRSTRLEN]; 46 | inet_ntop(AF_INET, temp_addr_ptr, address_buffer, INET_ADDRSTRLEN); 47 | return address_buffer; 48 | } 49 | return {}; 50 | } 51 | 52 | std::string get_interface_and_ip(std::string& interface) 53 | { 54 | struct ifaddrs* ifas = nullptr; 55 | getifaddrs(&ifas); 56 | std::unique_ptr ifas_guard(ifas, &freeifaddrs); 57 | for (auto ifa = ifas; ifa != nullptr; ifa = ifa->ifa_next) 58 | { 59 | if (ifa->ifa_addr == nullptr) 60 | continue; 61 | if (ifa->ifa_addr->sa_family != AF_INET) 62 | continue; 63 | if (ifa->ifa_flags & IFF_LOOPBACK) 64 | continue; 65 | auto addr = reinterpret_cast(ifa->ifa_addr); 66 | void* temp_addr_ptr = &addr->sin_addr; 67 | char address_buffer[INET_ADDRSTRLEN]; 68 | inet_ntop(AF_INET, temp_addr_ptr, address_buffer, INET_ADDRSTRLEN); 69 | interface = ifa->ifa_name; 70 
| return address_buffer; 71 | } 72 | return {}; 73 | } 74 | 75 | int get_available_port() 76 | { 77 | struct sockaddr_in addr; 78 | addr.sin_port = htons(0); 79 | addr.sin_family = AF_INET; 80 | addr.sin_addr.s_addr = htonl(INADDR_ANY); 81 | const int sock = socket(AF_INET, SOCK_STREAM, 0); 82 | if (bind(sock, (struct sockaddr*)&addr, sizeof(struct sockaddr_in)) != 0) 83 | { 84 | perror("bind():"); 85 | return 0; 86 | } 87 | socklen_t addr_len = sizeof(struct sockaddr_in); 88 | if (getsockname(sock, (struct sockaddr*)&addr, &addr_len) != 0) 89 | { 90 | perror("getsockname():"); 91 | return 0; 92 | } 93 | const int port = ntohs(addr.sin_port); 94 | close(sock); 95 | return port; 96 | } 97 | 98 | } 99 | -------------------------------------------------------------------------------- /cpp/mindalpha/network_utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | namespace mindalpha::network_utils 22 | { 23 | 24 | std::string get_ip(const std::string& interface); 25 | std::string get_interface_and_ip(std::string& interface); 26 | int get_available_port(); 27 | 28 | } 29 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_control.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | 19 | namespace mindalpha 20 | { 21 | 22 | std::string NodeControl::ToString() const 23 | { 24 | return ToJsonString(); 25 | } 26 | 27 | std::string NodeControl::ToJsonString() const 28 | { 29 | return to_json().dump(); 30 | } 31 | 32 | json11::Json NodeControl::to_json() const 33 | { 34 | std::vector group; 35 | if (BarrierGroupContainsCoordinator()) 36 | group.push_back("Coordinator"); 37 | if (BarrierGroupContainsServers()) 38 | group.push_back("Servers"); 39 | if (BarrierGroupContainsWorkers()) 40 | group.push_back("Workers"); 41 | return json11::Json::object 42 | { 43 | { "command", NullableNodeControlCommandToString(command_) }, 44 | { "nodes", nodes_ }, 45 | { "barrier_group", std::move(group) }, 46 | }; 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_control.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | // 27 | // ``node_control.h`` defines class ``NodeControl`` which contains information 28 | // about node control messages. 
29 | // 30 | 31 | namespace mindalpha 32 | { 33 | 34 | class NodeControl 35 | { 36 | public: 37 | bool IsEmpty() const { return command_ == NullNodeControlCommand; } 38 | 39 | NodeControlCommand GetCommand() const { return command_; } 40 | void SetCommand(NodeControlCommand value) { command_ = value; } 41 | 42 | // 43 | // Methods related to node info. 44 | // 45 | std::vector& GetNodes() { return nodes_; } 46 | const std::vector& GetNodes() const { return nodes_; } 47 | void SetNodes(std::vector value) { nodes_ = std::move(value); } 48 | 49 | void ClearNodes() { nodes_.clear(); } 50 | void AddNode(NodeInfo value) { nodes_.push_back(std::move(value)); } 51 | 52 | // 53 | // Methods related to barrier group. 54 | // 55 | int GetBarrierGroup() const { return barrier_group_; } 56 | void SetBarrierGroup(int value) { barrier_group_ = value; } 57 | 58 | void ClearBarrierGroup() { barrier_group_ = 0; } 59 | void AddCoordinatorToBarrierGroup() { barrier_group_ |= CoordinatorGroup; } 60 | void AddServersToBarrierGroup() { barrier_group_ |= ServerGroup; } 61 | void AddWorkersToBarrierGroup() { barrier_group_ |= WorkerGroup; } 62 | void RemoveCoordinatorFromBarrierGroup() { barrier_group_ &= ~CoordinatorGroup; } 63 | void RemoveServersFromBarrierGroup() { barrier_group_ &= ~ServerGroup; } 64 | void RemoveWorkersFromBarrierGroup() { barrier_group_ &= ~WorkerGroup; } 65 | bool BarrierGroupContainsCoordinator() const { return (barrier_group_ & CoordinatorGroup) != 0; } 66 | bool BarrierGroupContainsServers() const { return (barrier_group_ & ServerGroup) != 0; } 67 | bool BarrierGroupContainsWorkers() const { return (barrier_group_ & WorkerGroup) != 0; } 68 | 69 | std::string ToString() const; 70 | std::string ToJsonString() const; 71 | json11::Json to_json() const; 72 | 73 | private: 74 | NodeControlCommand command_ = NullNodeControlCommand; 75 | std::vector nodes_; 76 | int barrier_group_ = 0; 77 | }; 78 | 79 | } 80 | 
-------------------------------------------------------------------------------- /cpp/mindalpha/node_control_command.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | std::string NodeControlCommandToString(NodeControlCommand command) 26 | { 27 | switch (command) 28 | { 29 | #undef MINDALPHA_NODE_CONTROL_COMMAND_DEF 30 | #define MINDALPHA_NODE_CONTROL_COMMAND_DEF(n) case NodeControlCommand::n: return #n; 31 | MINDALPHA_NODE_CONTROL_COMMANDS(MINDALPHA_NODE_CONTROL_COMMAND_DEF) 32 | default: 33 | std::string serr; 34 | serr.append("Invalid NodeControlCommand enum value: "); 35 | serr.append(std::to_string(static_cast(command))); 36 | serr.append(".\n\n"); 37 | serr.append(GetStackTrace()); 38 | spdlog::error(serr); 39 | throw std::runtime_error(serr); 40 | } 41 | } 42 | 43 | NodeControlCommand NodeControlCommandFromString(const std::string& str) 44 | { 45 | #undef MINDALPHA_NODE_CONTROL_COMMAND_DEF 46 | #define MINDALPHA_NODE_CONTROL_COMMAND_DEF(n) if (str == #n) return NodeControlCommand::n; 47 | MINDALPHA_NODE_CONTROL_COMMANDS(MINDALPHA_NODE_CONTROL_COMMAND_DEF) 48 | std::string serr; 49 | serr.append("Invalid NodeControlCommand enum value: "); 50 | serr.append(str); 51 | 
serr.append(".\n\n"); 52 | serr.append(GetStackTrace()); 53 | spdlog::error(serr); 54 | throw std::runtime_error(serr); 55 | } 56 | 57 | std::string NullableNodeControlCommandToString(NodeControlCommand command) 58 | { 59 | if (command == NullNodeControlCommand) 60 | return NullNodeControlCommandString; 61 | return NodeControlCommandToString(command); 62 | } 63 | 64 | NodeControlCommand NullableNodeControlCommandFromString(const std::string& str) 65 | { 66 | if (str == NullNodeControlCommandString) 67 | return NullNodeControlCommand; 68 | return NodeControlCommandFromString(str); 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_control_command.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | // 22 | // ``node_contrl_command.h`` defines enum ``NodeControlCommand`` to represent 23 | // control commands sent between Parameter Server nodes. 24 | // 25 | 26 | namespace mindalpha 27 | { 28 | 29 | // 30 | // Use the X Macro technique to simplify code. 
See the following page 31 | // for more information about X Macros: 32 | // 33 | // https://en.wikipedia.org/wiki/X_Macro 34 | // 35 | 36 | #define MINDALPHA_NODE_CONTROL_COMMANDS(X) \ 37 | X(Terminate) \ 38 | X(AddNode) \ 39 | X(Barrier) \ 40 | /**/ 41 | 42 | enum class NodeControlCommand 43 | { 44 | #undef MINDALPHA_NODE_CONTROL_COMMAND_DEF 45 | #define MINDALPHA_NODE_CONTROL_COMMAND_DEF(n) n, 46 | MINDALPHA_NODE_CONTROL_COMMANDS(MINDALPHA_NODE_CONTROL_COMMAND_DEF) 47 | }; 48 | 49 | // A missing ``NodeControlCommand`` is represented by ``NodeControlCommand(-1)``. 50 | constexpr NodeControlCommand NullNodeControlCommand = static_cast(-1); 51 | constexpr const char* NullNodeControlCommandString = "null"; 52 | 53 | // Functions to convert ``NodeControlCommand`` to and from strings. 54 | std::string NodeControlCommandToString(NodeControlCommand command); 55 | NodeControlCommand NodeControlCommandFromString(const std::string& str); 56 | 57 | std::string NullableNodeControlCommandToString(NodeControlCommand command); 58 | NodeControlCommand NullableNodeControlCommandFromString(const std::string& str); 59 | 60 | } 61 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_encoding.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | #include 19 | 20 | namespace mindalpha 21 | { 22 | 23 | std::string NodeIdToString(int node_id) 24 | { 25 | std::ostringstream sout; 26 | if (node_id & CoordinatorGroup) 27 | sout << "C"; 28 | else if (node_id & ServerGroup) 29 | sout << "S"; 30 | else if (node_id & WorkerGroup) 31 | sout << "W"; 32 | else 33 | sout << "?"; 34 | if (node_id & SingleNodeIdTag) 35 | sout << "[" << NodeIdToRank(node_id) << "]"; 36 | else 37 | sout << "*"; 38 | sout << ":" << node_id; 39 | return sout.str(); 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_encoding.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | // 22 | // ``node_encoding.h`` defines integer encodings of nodes and 23 | // node groups and related functions. 24 | // 25 | 26 | namespace mindalpha 27 | { 28 | 29 | // Integer encodings of node group of the same role. 30 | #undef MINDALPHA_NODE_ROLE_DEF 31 | #define MINDALPHA_NODE_ROLE_DEF(n) constexpr int n##Group = 1 << static_cast(NodeRole::n); 32 | MINDALPHA_NODE_ROLES(MINDALPHA_NODE_ROLE_DEF) 33 | 34 | // Tag value specify that this integer encoding identify a single node. 
35 | constexpr int SingleNodeIdTag = 1 << (static_cast(NodeRole::Worker) + 1); 36 | 37 | // Tag value specify that this integer encoding identify a node of the specific role. 38 | #undef MINDALPHA_NODE_ROLE_DEF 39 | #define MINDALPHA_NODE_ROLE_DEF(n) constexpr int n##NodeIdTag = 1 << static_cast(NodeRole::n) | SingleNodeIdTag; 40 | MINDALPHA_NODE_ROLES(MINDALPHA_NODE_ROLE_DEF) 41 | 42 | // Encode node numbered ``rank`` as a node id, which is an integer. ``rank`` is zero-based. 43 | #undef MINDALPHA_NODE_ROLE_DEF 44 | #define MINDALPHA_NODE_ROLE_DEF(n) constexpr int n##RankToNodeId(int rank) { return rank << 4 | n##NodeIdTag; } 45 | MINDALPHA_NODE_ROLES(MINDALPHA_NODE_ROLE_DEF) 46 | 47 | // Get the zero-based ``rank`` from node id ``id``. 48 | constexpr int NodeIdToRank(int id) { return id >> 4; } 49 | 50 | // Node id of the coordinator node. Since there is a single coordinator node 51 | // in the Parameter Server system, its node id can be pre-computed. 52 | constexpr int CoordinatorNodeId = CoordinatorRankToNodeId(0); 53 | 54 | // Convert integer node id to a descriptive string. 55 | std::string NodeIdToString(int node_id); 56 | 57 | } 58 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_info.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | namespace mindalpha 22 | { 23 | 24 | std::string NodeInfo::ToString() const 25 | { 26 | return ToJsonString(); 27 | } 28 | 29 | std::string NodeInfo::ToShortString() const 30 | { 31 | std::ostringstream sout; 32 | switch (role_) 33 | { 34 | case NodeRole::Coordinator: 35 | sout << "C"; 36 | break; 37 | case NodeRole::Server: 38 | sout << "S"; 39 | break; 40 | case NodeRole::Worker: 41 | sout << "W"; 42 | break; 43 | default: 44 | sout << "?"; 45 | break; 46 | } 47 | if (node_id_ != -1) 48 | sout << "[" << NodeIdToRank(node_id_) << "]"; 49 | sout << ":" << node_id_; 50 | return sout.str(); 51 | } 52 | 53 | std::string NodeInfo::ToJsonString() const 54 | { 55 | return to_json().dump(); 56 | } 57 | 58 | json11::Json NodeInfo::to_json() const 59 | { 60 | return json11::Json::object 61 | { 62 | { "role", NullableNodeRoleToString(role_) }, 63 | { "node_id", node_id_ }, 64 | { "host_name", host_name_ }, 65 | { "port", port_ }, 66 | }; 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_info.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | // 25 | // ``node_info.h`` defines class ``NodeInfo`` which stores 26 | // information about nodes in the Parameter Server system. 27 | // 28 | 29 | namespace mindalpha 30 | { 31 | 32 | class NodeInfo 33 | { 34 | public: 35 | NodeRole GetRole() const { return role_; } 36 | void SetRole(NodeRole value) { role_ = value; } 37 | 38 | int GetNodeId() const { return node_id_; } 39 | void SetNodeId(int value) { node_id_ = value; } 40 | 41 | const std::string& GetHostName() const { return host_name_; } 42 | void SetHostName(std::string value) { host_name_ = std::move(value); } 43 | 44 | int GetPort() const { return port_; } 45 | void SetPort(int value) { port_ = value; } 46 | 47 | std::string GetAddress() const { return host_name_ + ":" + std::to_string(port_); } 48 | 49 | std::string ToString() const; 50 | std::string ToShortString() const; 51 | std::string ToJsonString() const; 52 | json11::Json to_json() const; 53 | 54 | private: 55 | NodeRole role_ = NullNodeRole; 56 | int node_id_ = -1; 57 | std::string host_name_; 58 | int port_ = -1; 59 | }; 60 | 61 | } 62 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_manager.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | namespace mindalpha 30 | { 31 | 32 | class ActorProcess; 33 | 34 | class NodeManager 35 | { 36 | public: 37 | explicit NodeManager(std::shared_ptr config); 38 | 39 | std::shared_ptr GetConfig() const { return config_; } 40 | void SetConfig(std::shared_ptr value) { config_ = std::move(value); } 41 | 42 | const std::vector& GetNodeIds(int group) const; 43 | void Barrier(int group, ActorProcess& process); 44 | void NotifyBarrierDone(const Message& msg); 45 | void UpdateHeartbeat(int nodeId, time_t t); 46 | std::vector GetDeadNodes(int timeout); 47 | 48 | private: 49 | void InitNodeIds(); 50 | 51 | std::shared_ptr config_; 52 | std::mutex start_mutex_; 53 | time_t start_time_ = 0; 54 | std::unordered_map> node_ids_; 55 | std::mutex barrier_mutex_; 56 | std::condition_variable barrier_cv_; 57 | bool barrier_done_; 58 | std::mutex heartbeat_mutex_; 59 | std::unordered_map heartbeats_; 60 | }; 61 | 62 | } 63 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_role.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | std::string NodeRoleToString(NodeRole role) 26 | { 27 | switch (role) 28 | { 29 | #undef MINDALPHA_NODE_ROLE_DEF 30 | #define MINDALPHA_NODE_ROLE_DEF(n) case NodeRole::n: return #n; 31 | MINDALPHA_NODE_ROLES(MINDALPHA_NODE_ROLE_DEF) 32 | default: 33 | std::string serr; 34 | serr.append("Invalid NodeRole enum value: "); 35 | serr.append(std::to_string(static_cast(role))); 36 | serr.append(".\n\n"); 37 | serr.append(GetStackTrace()); 38 | spdlog::error(serr); 39 | throw std::runtime_error(serr); 40 | } 41 | } 42 | 43 | NodeRole NodeRoleFromString(const std::string& str) 44 | { 45 | #undef MINDALPHA_NODE_ROLE_DEF 46 | #define MINDALPHA_NODE_ROLE_DEF(n) if (str == #n) return NodeRole::n; 47 | MINDALPHA_NODE_ROLES(MINDALPHA_NODE_ROLE_DEF) 48 | std::string serr; 49 | serr.append("Invalid NodeRole enum value: "); 50 | serr.append(str); 51 | serr.append(".\n\n"); 52 | serr.append(GetStackTrace()); 53 | spdlog::error(serr); 54 | throw std::runtime_error(serr); 55 | } 56 | 57 | std::string NullableNodeRoleToString(NodeRole role) 58 | { 59 | if (role == NullNodeRole) 60 | return NullNodeRoleString; 61 | return NodeRoleToString(role); 62 | } 63 | 64 | NodeRole NullableNodeRoleFromString(const std::string& str) 65 | { 66 | if (str == NullNodeRoleString) 67 | return NullNodeRole; 68 | return NodeRoleFromString(str); 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /cpp/mindalpha/node_role.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | // 22 | // ``node_role.h`` defines enum ``NodeRole`` to represent Parameter Server 23 | // node roles and some helper functions to convert ``NodeRole`` values. 24 | // 25 | // 26 | 27 | namespace mindalpha 28 | { 29 | 30 | // 31 | // Use the X Macro technique to simplify code. See the following page 32 | // for more information about X Macros: 33 | // 34 | // https://en.wikipedia.org/wiki/X_Macro 35 | // 36 | 37 | #define MINDALPHA_NODE_ROLES(X) \ 38 | X(Coordinator) \ 39 | X(Server) \ 40 | X(Worker) \ 41 | /**/ 42 | 43 | enum class NodeRole 44 | { 45 | #undef MINDALPHA_NODE_ROLE_DEF 46 | #define MINDALPHA_NODE_ROLE_DEF(n) n, 47 | MINDALPHA_NODE_ROLES(MINDALPHA_NODE_ROLE_DEF) 48 | }; 49 | 50 | // A missing ``NodeRole`` is represented by ``NodeRole(-1)``. 51 | constexpr NodeRole NullNodeRole = static_cast(-1); 52 | constexpr const char* NullNodeRoleString = "null"; 53 | 54 | // Functions to convert ``NodeRole`` to and from strings. 
55 | std::string NodeRoleToString(NodeRole role); 56 | NodeRole NodeRoleFromString(const std::string& str); 57 | 58 | std::string NullableNodeRoleToString(NodeRole role); 59 | NodeRole NullableNodeRoleFromString(const std::string& str); 60 | 61 | } 62 | -------------------------------------------------------------------------------- /cpp/mindalpha/ps_agent.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | namespace mindalpha 30 | { 31 | 32 | using PSMessage = std::shared_ptr; 33 | using SingleCallback = std::function; 34 | using MultipleCallback = std::function reqs, std::vector ress)>; 35 | using BroadcastCallback = std::function ress)>; 36 | using PSAgentCreator = std::function()>; 37 | 38 | class PSAgent : public std::enable_shared_from_this 39 | { 40 | friend class ActorProcess; 41 | 42 | public: 43 | virtual ~PSAgent() { } 44 | 45 | virtual void Run() { } 46 | virtual void HandleRequest(PSMessage req); 47 | virtual void Finalize() { } 48 | 49 | bool IsCoordinator() const { return is_coordinator_; } 50 | bool IsServer() const { return is_server_; } 51 | bool IsWorker() const { return is_worker_; } 52 | 53 | int GetServerCount() const { return server_count_; } 54 | int GetWorkerCount() const { return worker_count_; } 55 | int GetAgentRank() const; 56 | 57 | void Barrier(int group); 58 | void Shutdown(); 59 | 60 | void SendRequest(PSMessage req, SingleCallback cb); 61 | void SendAllRequests(std::vector reqs, MultipleCallback cb); 62 | void BroadcastRequest(PSMessage req, BroadcastCallback cb); 63 | void SendResponse(PSMessage req, PSMessage res); 64 | void HandleMessage(PSMessage msg); 65 | 66 | std::string ToString() const; 67 | 68 | private: 69 | class ActorProcess* actor_process_ = nullptr; 70 | 71 | struct TrackerEntry 72 | { 73 | int total = 0; 74 | std::vector responses; 75 | 76 | void Clear() 77 | { 78 | responses.clear(); 79 | } 80 | }; 81 | 82 | std::mutex tracker_mutex_; 83 | std::condition_variable tracker_cv_; 84 | std::unordered_map tracker_; 85 | 86 | bool is_coordinator_ = false; 87 | bool is_server_ = false; 88 | bool is_worker_ = false; 89 | 90 | int server_count_ = 0; 91 | int worker_count_ = 0; 92 | }; 93 | 94 | } 95 | 
-------------------------------------------------------------------------------- /cpp/mindalpha/ps_default_agent.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | namespace mindalpha 24 | { 25 | 26 | class __attribute__((visibility("hidden"))) PSDefaultAgent : public PSAgent 27 | { 28 | public: 29 | pybind11::object GetPyAgent() const { return py_agent_; } 30 | void SetPyAgent(pybind11::object value) { py_agent_ = std::move(value); } 31 | 32 | void Run() override; 33 | void HandleRequest(PSMessage req) override; 34 | void Finalize() override; 35 | 36 | private: 37 | pybind11::object py_agent_; 38 | std::unique_ptr store_; 39 | }; 40 | 41 | } 42 | -------------------------------------------------------------------------------- /cpp/mindalpha/ps_helper.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | namespace mindalpha { 24 | 25 | std::shared_ptr 26 | GetLocalConfig(const std::string &role, PSAgentCreator agent_creator) 27 | { 28 | auto config = std::make_shared(); 29 | config->SetRootUri("localhost"); 30 | config->SetRootPort(network_utils::get_available_port()); 31 | config->SetServerCount(2); 32 | config->SetWorkerCount(2); 33 | //config->SetIsMessageDumpingEnabled(true); 34 | config->SetAgentCreator(std::move(agent_creator)); 35 | if (role.empty()) { 36 | config->SetIsLocalMode(true); 37 | } 38 | else { 39 | if (role == "C") { 40 | config->SetNodeRole(NodeRole::Coordinator); 41 | } 42 | else if (role == "S") { 43 | config->SetNodeRole(NodeRole::Server); 44 | } 45 | else if (role == "W") { 46 | config->SetNodeRole(NodeRole::Worker); 47 | } 48 | else { 49 | std::cerr << "role must be [C | S | W]"; 50 | exit(-1); 51 | } 52 | } 53 | return config; 54 | } 55 | 56 | std::shared_ptr GetLocalConfig(const std::string &role) { 57 | return GetLocalConfig(role); 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /cpp/mindalpha/ps_helper.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | std::shared_ptr 26 | GetLocalConfig(const std::string& role, PSAgentCreator agent_creator); 27 | 28 | template 29 | std::shared_ptr 30 | GetLocalConfig(const std::string& role) 31 | { 32 | return GetLocalConfig(role, [] { 33 | std::shared_ptr agent = std::make_shared(); 34 | return agent; 35 | }); 36 | } 37 | 38 | std::shared_ptr 39 | GetLocalConfig(const std::string& role); 40 | 41 | } 42 | -------------------------------------------------------------------------------- /cpp/mindalpha/ps_runner.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | namespace mindalpha 22 | { 23 | 24 | class PSRunner 25 | { 26 | public: 27 | static void RunPS(std::shared_ptr config); 28 | 29 | private: 30 | static void RunPSCoordinator(std::shared_ptr config); 31 | static void RunPSServer(std::shared_ptr config); 32 | static void RunPSWorker(std::shared_ptr config); 33 | }; 34 | 35 | } 36 | -------------------------------------------------------------------------------- /cpp/mindalpha/pybind_utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | std::shared_ptr make_shared_pyobject(pybind11::object obj) 26 | { 27 | std::shared_ptr obj_ptr( 28 | new pybind11::object(std::move(obj)), 29 | [](pybind11::object* ptr) { 30 | pybind11::gil_scoped_acquire gil; 31 | //pybind11::print("Python object", *ptr, "is deleted."); 32 | delete ptr; 33 | }); 34 | return std::move(obj_ptr); 35 | } 36 | 37 | std::string serialize_pyobject(pybind11::object obj) 38 | { 39 | if (obj.is_none()) 40 | return {}; 41 | pybind11::module base64 = pybind11::module::import("base64"); 42 | pybind11::module cloudpickle = pybind11::module::import("cloudpickle"); 43 | pybind11::bytes data = cloudpickle.attr("dumps")(obj); 44 | pybind11::str result = base64.attr("b64encode")(data).attr("decode")("ascii"); 45 | return result; 46 | } 47 | 48 | pybind11::object deserialize_pyobject(const std::string& data) 49 | { 50 | if (data.empty()) 51 | return {}; 52 | pybind11::module base64 = pybind11::module::import("base64"); 53 | pybind11::module pickle = pybind11::module::import("pickle"); 54 | pybind11::bytes buffer = base64.attr("b64decode")(pybind11::bytes(data)); 55 | pybind11::object obj = pickle.attr("loads")(buffer); 56 | return obj; 57 | } 58 | 59 | void fixup_attributes(pybind11::object obj) 60 | { 61 | pybind11::module compat = pybind11::module::import("mindalpha.compat"); 62 | compat.attr("fixup_attributes")(obj); 63 | } 64 | 65 | pybind11::array make_numpy_array(SmartArray data, DataType dtype) 66 | { 67 | namespace py = pybind11; 68 | py::object obj = py::cast(data); 69 | py::buffer buf(obj); 70 | py::array arr(buf); 71 | py::array result; 72 | switch (dtype) 73 | { 74 | #undef MINDALPHA_DATA_TYPE_DEF 75 | #define MINDALPHA_DATA_TYPE_DEF(t, l, u) \ 76 | case mindalpha::DataType::u: \ 77 | result = arr.attr("view")(#l); \ 78 | break; \ 79 | /**/ 80 | MINDALPHA_DATA_TYPES(MINDALPHA_DATA_TYPE_DEF) 81 | } 82 | return result; 83 | } 
84 | 85 | std::tuple make_string_object_tuple(pybind11::bytes obj) 86 | { 87 | using namespace std::string_view_literals; 88 | const size_t length = PyBytes_Size(obj.ptr()); 89 | const char* p = PyBytes_AsString(obj.ptr()); 90 | if (p == nullptr || length == 0 || *p == '\0') 91 | return std::make_tuple("none"sv, std::move(obj)); 92 | else 93 | return std::make_tuple(std::string_view{p, length}, std::move(obj)); 94 | } 95 | 96 | std::tuple get_string_object_tuple(pybind11::object obj) 97 | { 98 | using namespace std::string_view_literals; 99 | if (obj.is_none()) 100 | return std::make_tuple("none"sv, std::move(obj)); 101 | else if (pybind11::isinstance(obj)) 102 | return make_string_object_tuple(obj.cast()); 103 | else if (pybind11::isinstance(obj)) 104 | return make_string_object_tuple(obj.attr("encode")("utf-8").cast()); 105 | else 106 | { 107 | std::string serr; 108 | serr.append("None, bytes or str expected\n\n"); 109 | serr.append(GetStackTrace()); 110 | spdlog::error(serr); 111 | throw std::runtime_error(serr); 112 | } 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /cpp/mindalpha/pybind_utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | std::shared_ptr make_shared_pyobject(pybind11::object obj); 29 | 30 | template 31 | std::shared_ptr extract_shared_pyobject(pybind11::object obj) 32 | { 33 | std::shared_ptr obj_ptr = make_shared_pyobject(std::move(obj)); 34 | std::shared_ptr ptr1 = obj_ptr->cast>(); 35 | std::shared_ptr ptr2(obj_ptr, ptr1.get()); 36 | return std::move(ptr2); 37 | } 38 | 39 | std::string serialize_pyobject(pybind11::object obj); 40 | pybind11::object deserialize_pyobject(const std::string& data); 41 | 42 | void fixup_attributes(pybind11::object obj); 43 | 44 | pybind11::array make_numpy_array(SmartArray data, DataType dtype); 45 | 46 | template 47 | inline pybind11::array make_numpy_array(SmartArray data) 48 | { 49 | SmartArray data_u8 = data.template Cast(); 50 | DataType dtype = DataTypeToCode::value; 51 | return make_numpy_array(data_u8, dtype); 52 | } 53 | 54 | template 55 | inline pybind11::array to_numpy_array(std::vector data) 56 | { 57 | auto data_arr = SmartArray::Wrap(std::move(data)); 58 | return make_numpy_array(data_arr); 59 | } 60 | 61 | template 62 | pybind11::tuple make_python_tuple(const std::vector& vec) 63 | { 64 | pybind11::list result(vec.size()); 65 | for (size_t i = 0; i < vec.size(); i++) 66 | result[i] = vec.at(i); 67 | return result; 68 | } 69 | 70 | template 71 | pybind11::list make_python_list(const std::vector& vec) 72 | { 73 | pybind11::list result(vec.size()); 74 | for (size_t i = 0; i < vec.size(); i++) 75 | result[i] = vec.at(i); 76 | return result; 77 | } 78 | 79 | template 80 | std::vector make_cpp_vector(pybind11::object obj) 81 | { 82 | std::vector result; 83 | for (pybind11::handle item : obj) 84 | { 85 | T t = item.cast(); 86 | result.push_back(std::move(t)); 87 | } 88 | return result; 89 | } 90 | 91 | std::tuple make_string_object_tuple(pybind11::bytes obj); 92 | std::tuple 
get_string_object_tuple(pybind11::object obj); 93 | 94 | } 95 | -------------------------------------------------------------------------------- /cpp/mindalpha/s3_sdk_filesys.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | /*! 18 | * Copyright (c) 2015 by Contributors 19 | * \file s3_sdk_filesys.h 20 | * \brief S3 access module 21 | * \author Tianqi Chen 22 | */ 23 | #ifndef DMLC_IO_S3_FILESYS_H_ 24 | #define DMLC_IO_S3_FILESYS_H_ 25 | 26 | #include 27 | #include "mindalpha/filesys.h" 28 | 29 | namespace mindalpha { 30 | /*! \brief AWS S3 filesystem */ 31 | class S3FileSystem : public FileSystem { 32 | public: 33 | /*! \brief destructor */ 34 | virtual ~S3FileSystem() {} 35 | /*! 36 | * \brief get information about a path 37 | * \param path the path to the file 38 | * \return the information about the file 39 | */ 40 | virtual FileInfo GetPathInfo(const URI &path) override; 41 | /*! 42 | * \brief list files in a directory 43 | * \param path to the file 44 | * \param out_list the output information about the files 45 | */ 46 | virtual void ListDirectory(const URI &path, std::vector *out_list) override; 47 | /*! 
48 | * \brief open a stream, will report error and exit if bad thing happens 49 | * NOTE: the Stream can continue to work even when filesystem was destructed 50 | * \param path path to file 51 | * \param uri the uri of the input 52 | * \param flag can be "w", "r", "a" 53 | * \param allow_null whether NULL can be returned, or directly report error 54 | * \return the created stream, can be NULL when allow_null == true and file do not exist 55 | */ 56 | virtual Stream *Open(const URI &path, const char* const flag, bool allow_null) override; 57 | /*! 58 | * \brief open a seekable stream for read 59 | * \param path the path to the file 60 | * \param allow_null whether NULL can be returned, or directly report error 61 | * \return the created stream, can be NULL when allow_null == true and file do not exist 62 | */ 63 | virtual SeekStream *OpenForRead(const URI &path, bool allow_null) override; 64 | /*! 65 | * \brief get a singleton of S3FileSystem when needed 66 | * \return a singleton instance 67 | */ 68 | static S3FileSystem *GetInstance(void); 69 | 70 | private: 71 | /*! \brief constructor */ 72 | S3FileSystem(); 73 | }; 74 | } // namespace mindalpha 75 | #endif // DMLC_IO_S3_FILESYS_H_ 76 | -------------------------------------------------------------------------------- /cpp/mindalpha/sparse_tensor.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | namespace mindalpha 27 | { 28 | 29 | class SparseTensor 30 | { 31 | public: 32 | SparseTensorMeta& GetMeta() { return meta_; } 33 | const SparseTensorMeta& GetMeta() const { return meta_; } 34 | void SetMeta(SparseTensorMeta value) { meta_ = std::move(value); } 35 | 36 | std::shared_ptr GetAgent() const { return agent_; } 37 | void SetAgent(std::shared_ptr value) { agent_ = std::move(value); } 38 | 39 | void Init(std::function cb); 40 | void Dispose(std::function cb); 41 | void Clear(std::function cb); 42 | void Push(SmartArray keys, SmartArray in, std::function cb, 43 | bool is_value = false); 44 | void Pull(SmartArray keys, std::function out)> cb, 45 | bool read_only = false, bool nan_fill = false); 46 | void PushPartition(ArrayHashMap& data, std::function cb, 47 | bool data_only = false, bool skip_existing = false); 48 | void PullPartition(ArrayHashMap& data, std::function cb, 49 | bool data_only = false, int index = -1, int count = -1); 50 | void PushMeta(const SparseTensorMeta& meta, std::function cb); 51 | void PullMeta(std::function cb); 52 | void Load(const std::string& dir_path, std::function cb, bool keep_meta = false); 53 | void Save(const std::string& dir_path, std::function cb, bool text_mode = false); 54 | void Export(const std::string& dir_path, std::function cb); 55 | void ImportFrom(const std::string& meta_file_path, std::function cb, 56 | bool data_only = false, bool skip_existing = false, 57 | bool transform_key = false, const std::string& feature_name = ""); 58 | void PruneSmall(double epsilon, std::function cb); 59 | void PruneOld(int max_age, std::function cb); 60 | 61 | private: 62 | std::string GetSparseMetaPath(const std::string& dir_path) const; 63 | static std::string GetSparsePath(const 
std::string& dir_path, const SparseTensorMeta& meta, int index); 64 | 65 | SparseTensorMeta meta_; 66 | std::shared_ptr agent_; 67 | }; 68 | 69 | } 70 | -------------------------------------------------------------------------------- /cpp/mindalpha/sparse_tensor_partition.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | namespace mindalpha 23 | { 24 | 25 | class SparseTensorPartition 26 | { 27 | public: 28 | SparseTensorMeta& GetMeta() { return meta_; } 29 | const SparseTensorMeta& GetMeta() const { return meta_; } 30 | void SetMeta(SparseTensorMeta value) { meta_ = std::move(value); } 31 | 32 | int GetPartitionIndex() const { return partition_index_; } 33 | void SetPartitionIndex(int value) { partition_index_ = value; } 34 | 35 | void AllocateHashMap(); 36 | void Clear(); 37 | void HandlePush(SmartArray keys, SmartArray in, bool is_value); 38 | SmartArray HandlePull(SmartArray keys, bool read_only, bool nan_fill); 39 | void HandlePushPartition(SmartArray keys, SmartArray in, bool data_only, bool skip_existing); 40 | SmartArray HandlePullPartition(bool data_only, int index, int count, SmartArray& keys); 41 | void HandlePushMeta(const SparseTensorMeta& meta); 42 | const SparseTensorMeta& HandlePullMeta(); 43 | void Load(const std::string& dir_path); 44 | void Save(const std::string& dir_path, bool text_mode); 45 | void Export(const std::string& dir_path); 46 | void PruneSmall(double epsilon); 47 | void PruneOld(int max_age); 48 | 49 | private: 50 | template 51 | void DoPruneSmall(double epsilon); 52 | 53 | void TransformIndices(SmartArray keys, bool pull, bool read_only); 54 | std::string GetSparsePath(const std::string& dir_path) const; 55 | std::string GetSparseExportPath(const std::string& dir_path) const; 56 | 57 | static constexpr uint64_t kPaddingKey = 0; 58 | static constexpr uint64_t kPaddingIndex = uint64_t(-2); 59 | static constexpr uint64_t kNotFoundIndex = uint64_t(-1); 60 | SparseTensorMeta meta_; 61 | int partition_index_ = -1; 62 | ArrayHashMap data_; 63 | }; 64 | 65 | } 66 | -------------------------------------------------------------------------------- /cpp/mindalpha/stack_trace_utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // 
Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | namespace 29 | { 30 | 31 | std::vector GetStackTraceAddresses() 32 | { 33 | std::vector vec(1024); 34 | for (;;) 35 | { 36 | const int size = backtrace(&vec[0], (int)vec.size()); 37 | if (size == vec.size()) 38 | vec.resize(vec.size() * 2); 39 | else 40 | { 41 | vec.resize(size); 42 | return vec; 43 | } 44 | } 45 | } 46 | 47 | std::vector GetStackTraceSymbols() 48 | { 49 | std::vector addresses = GetStackTraceAddresses(); 50 | char** symbols = backtrace_symbols(&addresses[0], (int)addresses.size()); 51 | std::unique_ptr symbols_guard(symbols, &free); 52 | std::vector result; 53 | if (symbols) 54 | result.insert(result.end(), symbols, symbols + addresses.size()); 55 | return result; 56 | } 57 | 58 | bool DecodeStackTraceSymbol(const std::string& symbol, 59 | std::string& file_name, 60 | std::string& function_name, 61 | uintptr_t& offset, 62 | uintptr_t& address) 63 | { 64 | static const std::regex re("(.+)\\(([^+]+)\\+0x([0-9a-f]+)\\) \\[0x([0-9a-f]+)\\]"); 65 | std::smatch m; 66 | if (!std::regex_match(symbol, m, re)) 67 | return false; 68 | file_name = m[1].str(); 69 | function_name = m[2].str(); 70 | offset = std::stoull(m[3].str(), nullptr, 16); 71 | address = std::stoull(m[4].str(), 
nullptr, 16); 72 | int status = 0; 73 | char* demangled = abi::__cxa_demangle(function_name.c_str(), NULL, NULL, &status); 74 | std::unique_ptr demangled_guard(demangled, &free); 75 | if (!demangled) 76 | return false; 77 | function_name = demangled; 78 | return true; 79 | } 80 | 81 | std::string DemangleStackTraceSymbol(const std::string& symbol) 82 | { 83 | std::string file_name; 84 | std::string function_name; 85 | uintptr_t offset; 86 | uintptr_t address; 87 | if (!DecodeStackTraceSymbol(symbol, file_name, function_name, offset, address)) 88 | return symbol; 89 | std::ostringstream sout; 90 | sout << file_name << "(" << function_name << "+0x" << std::hex << offset << ")"; 91 | sout << " [0x" << address << "]"; 92 | return sout.str(); 93 | } 94 | 95 | } 96 | 97 | std::string GetStackTrace() 98 | { 99 | const int offset = 3; 100 | std::ostringstream sout; 101 | std::vector symbols = GetStackTraceSymbols(); 102 | sout << "Stack trace returned " << (symbols.size() - offset) << " entries:"; 103 | for (size_t i = offset; i < symbols.size(); i++) 104 | sout << "\n[bt] (" << (i - offset) << ") " << DemangleStackTraceSymbol(symbols[i]); 105 | return sout.str(); 106 | } 107 | 108 | } 109 | -------------------------------------------------------------------------------- /cpp/mindalpha/stack_trace_utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | // 22 | // ``stack_trace_utils.h`` defines utility functions for stack traces. 23 | // 24 | 25 | namespace mindalpha 26 | { 27 | 28 | // Return function call stack trace as a string, which can be included 29 | // in exception and logging messages for debug purpose. 30 | std::string GetStackTrace(); 31 | 32 | } 33 | -------------------------------------------------------------------------------- /cpp/mindalpha/tensor_partition_store.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | class TensorPartitionStore 29 | { 30 | public: 31 | int GetPartitionCount() const { return partition_count_; } 32 | void SetPartitionCount(int value) { partition_count_ = value; } 33 | 34 | int GetPartitionIndex() const { return partition_index_; } 35 | void SetPartitionIndex(int value) { partition_index_ = value; } 36 | 37 | void DenseInit(const DenseTensorMeta& meta); 38 | void DenseDispose(const std::string& name); 39 | void DensePush(const std::string& name, PSMessage req, bool is_value, bool is_state); 40 | PSMessage DensePull(const std::string& name, bool is_state); 41 | void DensePushMeta(const std::string& name, const DenseTensorMeta& meta); 42 | PSMessage DensePullMeta(const std::string& name); 43 | 44 | void SparseInit(const SparseTensorMeta& meta); 45 | void SparseDispose(const std::string& name); 46 | void SparseClear(const std::string& name); 47 | void SparsePush(const std::string& name, PSMessage req, bool is_value); 48 | PSMessage SparsePull(const std::string& name, PSMessage req, bool read_only, bool nan_fill); 49 | void SparsePushPartition(const std::string& name, PSMessage req, bool data_only, bool skip_existing); 50 | PSMessage SparsePullPartition(const std::string& name, bool data_only, int index, int count); 51 | void SparsePushMeta(const std::string& name, const SparseTensorMeta& meta); 52 | PSMessage SparsePullMeta(const std::string& name); 53 | void SparseLoad(const std::string& name, const std::string& dir_path); 54 | void SparseSave(const std::string& name, const std::string& dir_path, bool text_mode); 55 | void SparseExport(const std::string& name, const std::string& dir_path); 56 | void SparsePruneSmall(const std::string& name, double epsilon); 57 | void SparsePruneOld(const std::string& name, int max_age); 58 | 59 | private: 60 | int partition_count_ = -1; 61 | int partition_index_ = 
-1; 62 | std::unordered_map dense_store_; 63 | std::unordered_map sparse_store_; 64 | }; 65 | 66 | } 67 | -------------------------------------------------------------------------------- /cpp/mindalpha/tensor_store_python_bindings.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | namespace mindalpha 24 | { 25 | 26 | template 27 | class PyPSDefaultAgent : public PyPSAgent { }; 28 | 29 | void DefineTensorStoreBindings(pybind11::module& m); 30 | 31 | } 32 | -------------------------------------------------------------------------------- /cpp/mindalpha/tensor_utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | size_t SliceElements(const std::vector& shape) 29 | { 30 | if (shape.empty()) 31 | return 0; 32 | size_t n = 1; 33 | for (size_t i = 1; i < shape.size(); i++) 34 | n *= shape[i]; 35 | return n; 36 | } 37 | 38 | size_t TotalElements(const std::vector& shape) 39 | { 40 | if (shape.empty()) 41 | return 0; 42 | size_t n = 1; 43 | for (size_t i = 0; i < shape.size(); i++) 44 | n *= shape[i]; 45 | return n; 46 | } 47 | 48 | std::string ShapeToString(const std::vector& shape) 49 | { 50 | std::ostringstream sout; 51 | for (size_t i = 0; i < shape.size(); i++) 52 | sout << (i ? " " : "") << shape.at(i); 53 | return sout.str(); 54 | } 55 | 56 | std::vector ShapeFromString(const std::string& str) 57 | { 58 | std::vector shape; 59 | std::istringstream sin(str); 60 | size_t dim; 61 | while (sin >> dim) 62 | shape.push_back(dim); 63 | return shape; 64 | } 65 | 66 | template 67 | void FillNaNValues(uint8_t* buffer, size_t size) 68 | { 69 | if (size % sizeof(T) != 0) 70 | { 71 | std::string serr; 72 | serr.append("Buffer size "); 73 | serr.append(std::to_string(size)); 74 | serr.append(" is not a multiple of sizeof("); 75 | serr.append(DataTypeToString(DataTypeToCode::value)); 76 | serr.append(".\n\n"); 77 | serr.append(GetStackTrace()); 78 | spdlog::error(serr); 79 | throw std::runtime_error(serr); 80 | } 81 | T* buf = reinterpret_cast(buffer); 82 | const size_t n = size / sizeof(T); 83 | for (size_t i = 0; i < n; i++) 84 | buf[i] = std::numeric_limits::quiet_NaN(); 85 | } 86 | 87 | void FillNaN(uint8_t* buffer, size_t size, DataType type) 88 | { 89 | switch (type) 90 | { 91 | case DataType::Float32: 92 | FillNaNValues(buffer, size); 93 | break; 94 | case DataType::Float64: 95 | FillNaNValues(buffer, size); 96 | 
break; 97 | default: 98 | std::string serr; 99 | serr.append("DataType must be float32 or float64 to fill NaN values; "); 100 | serr.append(DataTypeToString(type)); 101 | serr.append(" is invalid.\n\n"); 102 | serr.append(GetStackTrace()); 103 | spdlog::error(serr); 104 | throw std::runtime_error(serr); 105 | } 106 | } 107 | 108 | void MakeInitializerReady(pybind11::object initializer) 109 | { 110 | fixup_attributes(initializer); 111 | } 112 | 113 | void MakeUpdaterReady(pybind11::object updater) 114 | { 115 | fixup_attributes(updater); 116 | } 117 | 118 | } 119 | -------------------------------------------------------------------------------- /cpp/mindalpha/tensor_utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | size_t SliceElements(const std::vector& shape); 29 | size_t TotalElements(const std::vector& shape); 30 | std::string ShapeToString(const std::vector& shape); 31 | std::vector ShapeFromString(const std::string& str); 32 | void FillNaN(uint8_t* buffer, size_t size, DataType type); 33 | void MakeInitializerReady(pybind11::object initializer); 34 | void MakeUpdaterReady(pybind11::object udpater); 35 | 36 | } 37 | -------------------------------------------------------------------------------- /cpp/mindalpha/thread_utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace mindalpha 26 | { 27 | 28 | std::string GetThreadIdentifier() 29 | { 30 | std::ostringstream sout; 31 | sout << "pid: " << getpid() << ", "; 32 | sout << "tid: " << syscall(SYS_gettid) << ", "; 33 | sout << "thread: 0x" << std::hex << static_cast(pthread_self()); 34 | return sout.str(); 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /cpp/mindalpha/thread_utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | 21 | // 22 | // ``thread_utils.h`` defines utility functions for threads. 23 | // 24 | 25 | namespace mindalpha 26 | { 27 | 28 | // Return identifier of current thread as a string, which can be 29 | // included in exception and logging messages for debug purpose. 
30 | std::string GetThreadIdentifier(); 31 | 32 | } 33 | -------------------------------------------------------------------------------- /cpp/mindalpha/vector_utils.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | 22 | // 23 | // ``vector_utils.h`` defines utility functions for ``std::vector``. 
24 | // 25 | 26 | namespace mindalpha 27 | { 28 | 29 | template 30 | inline void VectorAppend(std::vector& v, InputIterator first, InputIterator last) 31 | { 32 | v.insert(v.end(), first, last); 33 | } 34 | 35 | template 36 | inline void VectorAppend(std::vector& v, RandomAccessIterator first, size_t count) 37 | { 38 | VectorAppend(v, first, first + count); 39 | } 40 | 41 | template 42 | struct VectorBase 43 | { 44 | T* start; 45 | T* finish; 46 | T* end_of_storage; 47 | 48 | VectorBase() 49 | : start() 50 | , finish() 51 | , end_of_storage() 52 | { 53 | } 54 | }; 55 | 56 | template 57 | inline void VectorDetachBuffer(std::vector& v, T*& data, size_t& size, size_t& capacity) 58 | { 59 | VectorBase base; 60 | std::vector& fake = reinterpret_cast&>(base); 61 | fake.swap(v); 62 | data = base.start; 63 | size = base.finish - base.start; 64 | capacity = base.end_of_storage - base.start; 65 | } 66 | 67 | template 68 | inline void VectorAttachBuffer(std::vector& v, T* data, size_t size, size_t capacity) 69 | { 70 | VectorBase base; 71 | base.start = data; 72 | base.finish = data + size; 73 | base.end_of_storage = data + capacity; 74 | std::vector& fake = reinterpret_cast&>(base); 75 | std::vector t; 76 | fake.swap(t); 77 | t.swap(v); 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /cpp/mindalpha/zeromq_transport.h: -------------------------------------------------------------------------------- 1 | // 2 | // Copyright 2021 Mobvista 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | namespace mindalpha 24 | { 25 | 26 | class ZeroMQTransport : public MessageTransport 27 | { 28 | public: 29 | explicit ZeroMQTransport(std::shared_ptr config); 30 | 31 | void Start() override; 32 | void Stop() override; 33 | int Bind(const NodeInfo& node, int maxRetry) override; 34 | void Connect(const NodeInfo& node) override; 35 | int64_t SendMessage(const Message& msg) override; 36 | int64_t ReceiveMessage(Message& msg) override; 37 | 38 | private: 39 | std::string FormatActorAddress(const NodeInfo& node, int port, bool forServer) const; 40 | std::string FormatActorIdentity(const NodeInfo& node) const; 41 | int ParseActorIdentity(const char* buf, size_t size) const; 42 | 43 | std::mutex mutex_; 44 | void* context_ = nullptr; 45 | void* receiver_ = nullptr; 46 | std::unordered_map senders_; 47 | }; 48 | 49 | } 50 | -------------------------------------------------------------------------------- /docker/centos7/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Copyright 2021 Mobvista 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | set -e 20 | pushd $(dirname ${BASH_SOURCE[0]})/../.. 
21 | rm -rf build built # start from clean build/ and built/ trees 22 | path=/usr/local/gcc-7.3.0/bin:${PATH} 23 | path=/usr/local/cmake-3.20.3/bin:${path} 24 | path=/usr/local/ninja-1.10.2/bin:${path} 25 | path=/usr/local/thrift-0.14.1/bin:${path} 26 | prefix="/usr/local/spdlog-1.8.5" # ';'-separated list passed as CMAKE_PREFIX_PATH below 27 | prefix="${prefix};/usr/local/pybind11-2.6.2" 28 | prefix="${prefix};/usr/local/aws-sdk-cpp-1.7.108" 29 | prefix="${prefix};/usr/local/boost-1.76.0" 30 | prefix="${prefix};/usr/local/thrift-0.14.1" 31 | prefix="${prefix};/usr/local/zeromq-4.3.4" 32 | env PATH=${path} \ 33 | PKG_CONFIG_PATH=/usr/local/json11-1.0.0/lib/pkgconfig \ 34 | LD_LIBRARY_PATH=/usr/local/gcc-7.3.0/lib64 \ 35 | /usr/local/cmake-3.20.3/bin/cmake \ 36 | -Wno-dev \ 37 | -G Ninja \ 38 | -DCMAKE_BUILD_TYPE=RelWithDebInfo \ 39 | -DCMAKE_INSTALL_PREFIX=built \ 40 | -DBUILD_SHARED_LIBS=OFF \ 41 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ 42 | -DCMAKE_CXX_COMPILER=/usr/local/gcc-7.3.0/bin/g++ \ 43 | -DCMAKE_MAKE_PROGRAM=/usr/local/ninja-1.10.2/bin/ninja \ 44 | -DCMAKE_PREFIX_PATH="${prefix}" \ 45 | -DCMAKE_CXX_FLAGS=-I/usr/local/dbg-macro-0.4.0/include \ 46 | -DPython_ROOT_DIR=/usr/local/python-3.7.7 \ 47 | -H. \ 48 | -Bbuild 49 | env PATH=${path} \ 50 | PKG_CONFIG_PATH=/usr/local/json11-1.0.0/lib/pkgconfig \ 51 | LD_LIBRARY_PATH=/usr/local/gcc-7.3.0/lib64 \ 52 | /usr/local/cmake-3.20.3/bin/cmake \ 53 | --build build 54 | popd 55 | echo OK 56 | -------------------------------------------------------------------------------- /docker/centos7/package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Copyright 2021 Mobvista 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | set -e # abort on the first failing command 20 | pushd "$(dirname "${BASH_SOURCE[0]}")/../.." # run from the repository root; quoted so checkout paths containing spaces work 21 | rm -rf build/python-env 22 | rm -rf build/python-env.tgz 23 | mkdir -p build/python-env 24 | tar -xf /usr/local/python-env-3.7.7.tgz -C build/python-env 25 | build/python-env/bin/python3.7 -m pip install --upgrade build/mindalpha-2.0.0+*-cp37-cp37m-linux_x86_64.whl pip 26 | find build/python-env/bin -type f -exec sed -i -e 's@^#!.\+/bin/python\(3\(\.7\)\?\)\?$@#!/usr/bin/env python3.7@' {} \; # rewrite '#!<dir>/bin/python[3[.7]]' shebangs to '#!/usr/bin/env python3.7' 27 | tar -czf build/python-env.tgz -C build/python-env $(ls build/python-env) # archive the env's top-level entries (no leading directory) 28 | rm -rf build/python-env 29 | popd 30 | echo OK 31 | -------------------------------------------------------------------------------- /docker/ubuntu20.04/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Copyright 2021 Mobvista 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | set -e # abort on the first failing command 20 | pushd "$(dirname "${BASH_SOURCE[0]}")/../.." # run from the repository root; quoted so checkout paths containing spaces work
21 | rm -rf build built # start from clean build/ and built/ trees 22 | mkdir -p build 23 | cd build 24 | cmake -G Ninja -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_INSTALL_PREFIX=../built .. 25 | ninja 26 | popd 27 | echo OK 28 | -------------------------------------------------------------------------------- /docker/ubuntu20.04/package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Copyright 2021 Mobvista 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | set -e # abort on the first failing command 20 | pushd "$(dirname "${BASH_SOURCE[0]}")/../.." # run from the repository root; quoted so checkout paths containing spaces work
21 | rm -rf build/python-env 22 | rm -rf build/python-env.tgz 23 | mkdir -p build/python-env 24 | tar -xf /usr/local/python-env-3.8.5.tgz -C build/python-env 25 | build/python-env/bin/python3.8 -m pip install --upgrade build/mindalpha-2.0.0+*-cp38-cp38-linux_x86_64.whl pip 26 | find build/python-env/bin -type f -exec sed -i -e 's@^#!.\+/bin/python\(3\(\.8\)\?\)\?$@#!/usr/bin/env python3.8@' {} \; # rewrite '#!<dir>/bin/python[3[.8]]' shebangs to '#!/usr/bin/env python3.8' 27 | tar -czf build/python-env.tgz -C build/python-env $(ls build/python-env) # archive the env's top-level entries (no leading directory) 28 | rm -rf build/python-env 29 | popd 30 | echo OK 31 | -------------------------------------------------------------------------------- /examples/deep_fm_example.py: -------------------------------------------------------------------------------- 1 | # 2 | # To run locally, execute: 3 | # 4 | # spark-submit --master local[2] deep_fm_example.py 5 | # 6 | 7 | S3_ROOT_DIR = 's3://{YOUR_S3_BUCKET}/{YOUR_S3_PATH}/' # placeholder: set to your own bucket and prefix 8 | 9 | batch_size = 100 10 | worker_count = 1 11 | server_count = 1 12 | 13 | import mindalpha as ma 14 | spark = ma.spark.get_session(batch_size=batch_size, 15 | worker_count=worker_count, 16 | server_count=server_count, 17 | ) 18 | sc = spark.sparkContext 19 | 20 | with spark: 21 | module = ma.nn.DeepFMModule( 22 | wide_column_name_path=S3_ROOT_DIR + 'demo/schema/column_name_demo.txt', 23 | wide_combine_schema_path=S3_ROOT_DIR + 'demo/schema/combine_schema_demo.txt', 24 | cross_sparse_column_name_path=S3_ROOT_DIR + 'demo/schema/column_name_demo.txt', 25 | cross_sparse_combine_schema_path=S3_ROOT_DIR + 'demo/schema/combine_schema_demo.txt', 26 | deep_sparse_column_name_path=S3_ROOT_DIR + 'demo/schema/column_name_demo.txt', 27 | deep_sparse_combine_schema_path=S3_ROOT_DIR + 'demo/schema/combine_schema_demo.txt', 28 | ) 29 | 30 | model_out_path = S3_ROOT_DIR + 'demo/output/dev/model_out/' 31 | estimator = ma.PyTorchEstimator(module=module, 32 | worker_count=worker_count, 33 | server_count=server_count, 34 | model_out_path=model_out_path, 35 | input_label_column_index=0) 36 | 37 |
train_dataset_path = S3_ROOT_DIR + 'demo/data/train/day_0_0.001_train.csv' 38 | train_dataset = ma.input.read_s3_csv(spark, train_dataset_path, delimiter='\t') 39 | model = estimator.fit(train_dataset) 40 | 41 | test_dataset_path = S3_ROOT_DIR + 'demo/data/test/day_0_0.001_test.csv' 42 | test_dataset = ma.input.read_s3_csv(spark, test_dataset_path, delimiter='\t') 43 | result = model.transform(test_dataset) 44 | result.show(5) 45 | 46 | import pyspark 47 | evaluator = pyspark.ml.evaluation.BinaryClassificationEvaluator() # default metric is areaUnderROC 48 | test_auc = evaluator.evaluate(result) 49 | print('test_auc: %g' % test_auc) 50 | -------------------------------------------------------------------------------- /examples/swing_estimator_example.py: -------------------------------------------------------------------------------- 1 | import mindalpha as ma 2 | spark = ma.spark.get_session() 3 | sc = spark.sparkContext 4 | 5 | with spark: 6 | input_path = "s3://{YOUR_S3_BUCKET}/{YOUR_S3_PATH}/example.csv" # placeholder: set to your own bucket and prefix 7 | estimator = ma.SwingEstimator(user_id_column_name='_c0', 8 | item_id_column_name='_c1', 9 | behavior_column_name='_c3', 10 | behavior_filter_value='buy', 11 | cassandra_catalog='mycatalog', 12 | cassandra_host_ip='172.17.0.5', # example address; point at your Cassandra host 13 | cassandra_port=9042, 14 | cassandra_db_name='testks', 15 | cassandra_table_name='recdb', 16 | ) 17 | dataset = ma.input.read_s3_csv(spark, input_path, delimiter=',') 18 | model = estimator.fit(dataset) 19 | model.transform(dataset).show() 20 | model = model.stringify() 21 | model.df.show() 22 | model.publish() 23 | -------------------------------------------------------------------------------- /examples/wide_and_deep_example.py: -------------------------------------------------------------------------------- 1 | # 2 | # To run locally, execute: 3 | # 4 | # spark-submit --master local[2] wide_and_deep_example.py 5 | # 6 | 7 | S3_ROOT_DIR = 's3://{YOUR_S3_BUCKET}/{YOUR_S3_PATH}/' # placeholder: set to your own bucket and prefix 8 | 9 | batch_size = 100 10 | worker_count = 1 11 | server_count = 1 12 |
import mindalpha as ma 14 | spark = ma.spark.get_session(batch_size=batch_size, 15 | worker_count=worker_count, 16 | server_count=server_count, 17 | ) 18 | sc = spark.sparkContext 19 | 20 | with spark: 21 | module = ma.nn.WideAndDeepModule( 22 | wide_column_name_path=S3_ROOT_DIR + 'demo/schema/column_name_demo.txt', 23 | wide_combine_schema_path=S3_ROOT_DIR + 'demo/schema/combine_schema_demo.txt', 24 | deep_sparse_column_name_path=S3_ROOT_DIR + 'demo/schema/column_name_demo.txt', 25 | deep_sparse_combine_schema_path=S3_ROOT_DIR + 'demo/schema/combine_schema_demo.txt', 26 | ) 27 | 28 | model_out_path = S3_ROOT_DIR + 'demo/output/dev/model_out/' 29 | estimator = ma.PyTorchEstimator(module=module, 30 | worker_count=worker_count, 31 | server_count=server_count, 32 | model_out_path=model_out_path, 33 | input_label_column_index=0) 34 | 35 | train_dataset_path = S3_ROOT_DIR + 'demo/data/train/day_0_0.001_train.csv' 36 | train_dataset = ma.input.read_s3_csv(spark, train_dataset_path, delimiter='\t') 37 | model = estimator.fit(train_dataset) 38 | 39 | test_dataset_path = S3_ROOT_DIR + 'demo/data/test/day_0_0.001_test.csv' 40 | test_dataset = ma.input.read_s3_csv(spark, test_dataset_path, delimiter='\t') 41 | result = model.transform(test_dataset) 42 | result.show(5) 43 | 44 | import pyspark 45 | evaluator = pyspark.ml.evaluation.BinaryClassificationEvaluator() # default metric is areaUnderROC 46 | test_auc = evaluator.evaluate(result) 47 | print('test_auc: %g' % test_auc) 48 | -------------------------------------------------------------------------------- /package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Copyright 2021 Mobvista 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | set -e # abort on the first failing command 20 | pushd "$(dirname "${BASH_SOURCE[0]}")" # run from the repository root; quoted so checkout paths containing spaces work 21 | tag=$(source /etc/os-release; echo ${ID}${VERSION_ID}) # e.g. centos7 or ubuntu20.04, matching a docker/<tag>/ directory 22 | "./docker/${tag}/package.sh" # delegate to the platform-specific packaging script 23 | popd 24 | -------------------------------------------------------------------------------- /python/mindalpha/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | from ._mindalpha import NodeRole 18 | from ._mindalpha import ActorConfig 19 | from ._mindalpha import PSRunner 20 | 21 | from .embedding import EmbeddingSumConcat 22 | from .embedding import EmbeddingRangeSum 23 | from .embedding import EmbeddingLookup 24 | 25 | from .cast import Cast 26 | 27 | from .initializer import TensorInitializer 28 | from .initializer import DefaultTensorInitializer 29 | from .initializer import ZeroTensorInitializer 30 | from .initializer import OneTensorInitializer 31 | from .initializer import NormalTensorInitializer 32 | from .initializer import XavierTensorInitializer 33 | 34 | from .updater import TensorUpdater 35 | from .updater import NoOpUpdater 36 | from .updater import SGDTensorUpdater 37 | from .updater import AdaGradTensorUpdater 38 | from .updater import AdamTensorUpdater 39 | from .updater import FTRLTensorUpdater 40 | from .updater import EMATensorUpdater 41 | 42 | from .agent import Agent 43 | from .model import Model 44 | from .model import SparseModel 45 | from .metric import ModelMetric 46 | from .distributed_trainer import DistributedTrainer 47 | from .experiment import Experiment 48 | 49 | try: 50 | import pyspark 51 | except ImportError: 52 | # Use findspark to simplify running job locally. 53 | try: 54 | import findspark 55 | findspark.init() 56 | except: 57 | pass 58 | 59 | try: 60 | import pyspark 61 | except ImportError: 62 | pass 63 | else: 64 | # PySpark may not be available at this point, 65 | # we import the classes in estimator only when 66 | # PySpark is ready. 
67 | 68 | from .estimator import PyTorchAgent 69 | from .estimator import PyTorchLauncher 70 | from .estimator import PyTorchModel 71 | from .estimator import PyTorchEstimator 72 | 73 | from .two_tower_ranking import TwoTowerRankingModule 74 | from .two_tower_ranking import TwoTowerRankingAgent 75 | from .two_tower_ranking import TwoTowerRankingLauncher 76 | from .two_tower_ranking import TwoTowerRankingModel 77 | from .two_tower_ranking import TwoTowerRankingEstimator 78 | 79 | from .swing_retrieval import SwingModel 80 | from .swing_retrieval import SwingEstimator 81 | 82 | try: 83 | import pyspark 84 | import faiss 85 | except ImportError: 86 | pass 87 | else: 88 | from .two_tower_retrieval import TwoTowerRetrievalModule 89 | from .two_tower_retrieval import FaissIndexBuildingAgent 90 | from .two_tower_retrieval import FaissIndexRetrievalAgent 91 | from .two_tower_retrieval import TwoTowerRetrievalModel 92 | from .two_tower_retrieval import TwoTowerRetrievalEstimator 93 | 94 | from ._mindalpha import get_mindalpha_version 95 | __version__ = get_mindalpha_version() 96 | del get_mindalpha_version 97 | 98 | from . import nn 99 | from . import input 100 | from . import output 101 | from . import spark 102 | from . import patching_pickle 103 | from . import demo 104 | 105 | patching_pickle._patch_lookup_module_and_qualname() 106 | patching_pickle._patch_getsourcelines() 107 | -------------------------------------------------------------------------------- /python/mindalpha/compat/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from . import ps 18 | 19 | import sys 20 | sys.modules['ps.initializer'] = ps.initializer # register the legacy dotted module names so 'ps.initializer' / 'ps.updater' resolve to mindalpha's modules 21 | sys.modules['ps.updater'] = ps.updater 22 | 23 | def fixup_attributes(obj): # Rename name-mangled attributes '_<Class>__<attr>' on obj to '_<attr>' (dunder names are skipped). 24 | names = dir(obj) 25 | for name in names: 26 | if not name.startswith('_'): 27 | continue 28 | if name.endswith('__'): 29 | continue 30 | i = name.find('__') 31 | if i == -1: 32 | continue 33 | new_name = name[i + 1:] # keep one leading underscore: '_Agent__foo' -> '_foo' 34 | value = getattr(obj, name) 35 | setattr(obj, new_name, value) 36 | delattr(obj, name) # NOTE(review): dir() also lists inherited names; delattr raises AttributeError for those — confirm callers only pass objects that own their attributes 37 | 38 | ps.Agent._criterion = ps.Agent._metric # legacy 'criterion' API names alias the current 'metric' methods 39 | ps.Agent.update_criterion = ps.Agent.update_metric 40 | ps.Agent.push_criterion = ps.Agent.push_metric 41 | ps.Agent.clear_criterion = ps.Agent.clear_metric 42 | -------------------------------------------------------------------------------- /python/mindalpha/compat/ps/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | from mindalpha import NodeRole as ActorRole 18 | from mindalpha import ActorConfig 19 | from mindalpha import PSRunner 20 | 21 | from mindalpha import EmbeddingSumConcat 22 | from mindalpha import EmbeddingRangeSum 23 | from mindalpha import EmbeddingLookup 24 | 25 | from mindalpha import TensorInitializer 26 | from mindalpha import DefaultTensorInitializer 27 | from mindalpha import ZeroTensorInitializer 28 | from mindalpha import OneTensorInitializer 29 | from mindalpha import NormalTensorInitializer 30 | from mindalpha import XavierTensorInitializer 31 | 32 | from mindalpha import TensorUpdater 33 | from mindalpha import NoOpUpdater 34 | from mindalpha import SGDTensorUpdater 35 | from mindalpha import AdaGradTensorUpdater 36 | from mindalpha import AdamTensorUpdater 37 | from mindalpha import FTRLTensorUpdater 38 | from mindalpha import EMATensorUpdater 39 | 40 | from mindalpha import Agent 41 | from mindalpha import Model 42 | from mindalpha import SparseModel 43 | from mindalpha import ModelMetric as ModelCriterion 44 | from mindalpha import DistributedTrainer 45 | 46 | try: 47 | import pyspark 48 | except ImportError: 49 | pass 50 | else: 51 | from mindalpha import PyTorchAgent 52 | from mindalpha import PyTorchLauncher 53 | from mindalpha import PyTorchModel 54 | from mindalpha import PyTorchEstimator 55 | 56 | from mindalpha import __version__ 57 | from mindalpha import _mindalpha as _ps 58 | 59 | from mindalpha import nn 60 | from mindalpha import input 61 | from mindalpha import spark 62 | 63 | from mindalpha import initializer 64 | from mindalpha import updater 65 | -------------------------------------------------------------------------------- /python/mindalpha/demo.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | def download_dataset(): 18 | import glob 19 | import subprocess 20 | GLOB_PATTERN = 'data/**/*.csv' 21 | NUM_FILES = 24 + 24 22 | if len(glob.glob(GLOB_PATTERN)) == NUM_FILES: 23 | print('MindAlpha demo dataset already downloaded') 24 | return 25 | string = "rm -rf data && " 26 | string += "mkdir -p data/train && " 27 | string += "cd data/train && " 28 | string += "curl -L -O https://mob-emr-test.s3.amazonaws.com/ml-platform/ml-ranking/data/criteo/0.001/train/day_{$(seq -s ',' 0 23)}_0.001_train.csv && " 29 | string += "cd ../.. && " 30 | string += "mkdir -p data/test && " 31 | string += "cd data/test && " 32 | string += "curl -L -O https://mob-emr-test.s3.amazonaws.com/ml-platform/ml-ranking/data/criteo/0.001/test/day_{$(seq -s ',' 0 23)}_0.001_test.csv && " 33 | string += "cd ../.. 
&& " 34 | string += "echo OK: criteo" 35 | args = string, 36 | subprocess.check_call(args, shell=True, stderr=subprocess.PIPE) 37 | if len(glob.glob(GLOB_PATTERN)) == NUM_FILES: 38 | print('MindAlpha demo dataset downloaded') 39 | else: 40 | message = "fail to download the MindAlpha demo dataset; " 41 | message += "see https://mob-emr-test.s3.amazonaws.com/ml-platform/ml-ranking/data/criteo/0.001/index.html " 42 | message += "for more information" 43 | raise RuntimeError(message) 44 | -------------------------------------------------------------------------------- /python/mindalpha/file_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | def file_exists(url): # True if url names an existing file; s3:// and s3a:// URLs are checked through s3_utils. 18 | import os 19 | from .s3_utils import s3_file_exists 20 | if url.startswith('s3://') or url.startswith('s3a://'): 21 | return s3_file_exists(url) 22 | else: 23 | return os.path.isfile(url) 24 | 25 | def dir_exists(url): # True if url names an existing directory; for S3 it is treated as existing when get_s3_dir_size reports a size > 0. 26 | import os 27 | from .s3_utils import get_s3_dir_size 28 | if url.startswith('s3://') or url.startswith('s3a://'): 29 | return get_s3_dir_size(url) > 0 # NOTE(review): presumably a prefix holding only empty objects reports False — confirm 30 | else: 31 | return os.path.isdir(url) 32 | 33 | def delete_dir(url): # Remove a directory tree (local or S3); a missing local directory is ignored. 34 | import os 35 | import shutil 36 | from .s3_utils import delete_s3_dir 37 | if url.startswith('s3://') or url.startswith('s3a://'): 38 | delete_s3_dir(url) 39 | else: 40 | if os.path.isdir(url): 41 | shutil.rmtree(url) 42 | 43 | def delete_file(url): # Remove a single file (local or S3); a missing local file is ignored. 44 | import os 45 | from .s3_utils import delete_s3_file 46 | if url.startswith('s3://') or url.startswith('s3a://'): 47 | delete_s3_file(url) 48 | else: 49 | if os.path.isfile(url): 50 | os.remove(url) 51 | 52 | def copy_dir(src_url, dst_url): # Copy a directory between any combination of local paths and S3 URLs. 53 | import shutil 54 | from .s3_utils import copy_s3_dir 55 | from .s3_utils import download_s3_dir 56 | from .s3_utils import upload_s3_dir 57 | if src_url.startswith('s3://') or src_url.startswith('s3a://'): 58 | if dst_url.startswith('s3://') or dst_url.startswith('s3a://'): 59 | copy_s3_dir(src_url, dst_url) 60 | else: 61 | download_s3_dir(src_url, dst_url) 62 | else: 63 | if dst_url.startswith('s3://') or dst_url.startswith('s3a://'): 64 | upload_s3_dir(src_url, dst_url) 65 | else: 66 | shutil.copytree(src_url, dst_url) # raises if dst_url already exists (no dirs_exist_ok) 67 | -------------------------------------------------------------------------------- /python/mindalpha/input.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | def shuffle_df(df, num_workers): 18 | from pyspark.sql import functions as F 19 | df = df.withColumn('srand', F.rand()) 20 | df = df.repartition(2 * num_workers, 'srand') 21 | print('shuffle df to partitions {}'.format(df.rdd.getNumPartitions())) 22 | df = df.sortWithinPartitions('srand') 23 | df = df.drop('srand') 24 | return df 25 | 26 | def read_kudu(spark_session, url, column_name, sql=None, condition_select_conf='', shuffle=False, num_workers=1): 27 | from pyspark.sql import SQLContext 28 | from pyspark.sql import DataFrame 29 | 30 | queryCols=[] 31 | has_rand_column = False 32 | with open(column_name, 'r') as f_column_name: 33 | for line in f_column_name: 34 | line = line.split(" ")[-1].strip() 35 | if line == 'rand': 36 | has_rand_column = True 37 | queryCols.append(line) 38 | if not has_rand_column: 39 | print('append rand column by default') 40 | queryCols.append('rand') 41 | print('total cols from kudu: {}'.format(len(queryCols))) 42 | 43 | sc = spark_session.sparkContext 44 | ssqlContext = SQLContext(sc)._ssql_ctx 45 | jsparkSession = spark_session._jsparkSession 46 | if condition_select_conf == '': 47 | queryKudu = sc._jvm.com.mobvista.dataflow.apis.kuduUtils.QueryKudu.readKudu(jsparkSession, url, queryCols) 48 | else: 49 | print('use condition select conf: {}'.format(condition_select_conf)) 50 | queryKudu = sc._jvm.com.mobvista.dataflow.apis.kuduUtils.QueryKudu.readKudu(jsparkSession, url, queryCols, condition_select_conf) 51 | kudu_df_tmp = DataFrame(queryKudu, ssqlContext) 52 | 
kudu_df_tmp.createOrReplaceTempView("kudu_df_tmp") 53 | if sql is not None: 54 | kudu_df = spark_session.sql(sql) 55 | else: 56 | kudu_df = spark_session.sql('select * from kudu_df_tmp') 57 | 58 | if shuffle and num_workers > 1: 59 | kudu_df = shuffle_df(kudu_df, num_workers) 60 | else: 61 | print("ignore shuffle") 62 | if not has_rand_column: 63 | kudu_df = kudu_df.drop('rand') 64 | return kudu_df 65 | 66 | def read_s3_csv(spark_session, url, shuffle=False, num_workers=1, 67 | header=False, nullable=False, delimiter="\002", encoding="UTF-8"): 68 | from .url_utils import use_s3a 69 | df = (spark_session 70 | .read 71 | .format('csv') 72 | .option("header", str(bool(header)).lower()) 73 | .option("nullable", str(bool(nullable)).lower()) 74 | .option("delimiter", delimiter) 75 | .option("encoding", encoding) 76 | .load(use_s3a(url))) 77 | if shuffle and num_workers > 1: 78 | df = shuffle_df(df, num_workers) 79 | else: 80 | print("ignore shuffle") 81 | return df 82 | 83 | def read_s3_image(spark_session, url): 84 | from .url_utils import use_s3a 85 | df = spark_session.read.format('image').option('dropInvalid', 'true').load(use_s3a(url)) 86 | return df 87 | -------------------------------------------------------------------------------- /python/mindalpha/job_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | def normalize_storage_size(size): 18 | import re 19 | if not isinstance(size, str) or not re.match(r'\d+[MG]', size): 20 | message = "'size' must be a string like 4096M or 4G; " 21 | message += "%r is invalid" % size 22 | raise ValueError(message) 23 | value = int(size[:-1]) 24 | unit = size[-1] 25 | if unit == 'G': 26 | value *= 1024 27 | return value 28 | 29 | def merge_storage_size(worker_memory, server_memory): 30 | mem1 = normalize_storage_size(worker_memory) 31 | mem2 = normalize_storage_size(server_memory) 32 | mem = max(mem1, mem2) 33 | if mem % 1024 == 0: 34 | mem = '%dG' % (mem // 1024) 35 | else: 36 | mem = '%dM' % mem 37 | return mem 38 | -------------------------------------------------------------------------------- /python/mindalpha/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import torch 18 | 19 | def nansum(x): 20 | return torch.where(torch.isnan(x), torch.zeros_like(x), x).sum() 21 | 22 | def log_loss(yhat, y): 23 | return nansum(-(y * (yhat + 1e-12).log() + (1 - y) * (1 - yhat + 1e-12).log())) 24 | -------------------------------------------------------------------------------- /python/mindalpha/name_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import re 18 | 19 | def is_valid_qualified_name(name): 20 | pattern = r'^[A-Za-z_.\-][A-Za-z0-9_.\-]*$' 21 | match = re.match(pattern, name) 22 | return match is not None 23 | -------------------------------------------------------------------------------- /python/mindalpha/network_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | def get_host_ip(): 18 | import socket 19 | host_name = socket.gethostname() 20 | host_ip = socket.gethostbyname(host_name) 21 | return host_ip 22 | 23 | def get_available_endpoint(): 24 | import socket 25 | import random 26 | host_name = socket.gethostname() 27 | host_ip = socket.gethostbyname(host_name) 28 | addr_info = socket.getaddrinfo(host_ip, None) 29 | ip_family = addr_info[0][0] 30 | with socket.socket(ip_family, socket.SOCK_STREAM) as sock: 31 | try: 32 | sock.bind(('', 0)) 33 | _, port = sock.getsockname() 34 | return host_ip, port 35 | except socket.error as e: 36 | message = "can not find bindable port " 37 | message += "on host %s(%s)" % (host_name, host_ip) 38 | raise RuntimeError(message) from e 39 | -------------------------------------------------------------------------------- /python/mindalpha/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | from .normalization import Normalization 18 | from .fm import FMModule 19 | from .wide_and_deep import WideAndDeepModule 20 | from .deep_fm import DeepFMModule 21 | -------------------------------------------------------------------------------- /python/mindalpha/nn/deep_fm.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import torch 18 | from ..updater import FTRLTensorUpdater 19 | from ..initializer import NormalTensorInitializer 20 | from ..embedding import EmbeddingSumConcat 21 | from .fm import FMModule 22 | from .wide_and_deep import WideAndDeepModule 23 | 24 | class DeepFMModule(WideAndDeepModule): # Wide & Deep model plus an FM pairwise-interaction term over a separate 'cross' sparse embedding. 25 | def __init__(self, 26 | cross_sparse_embedding_size=16, 27 | cross_sparse_column_name_path=None, 28 | cross_sparse_combine_schema_path=None, 29 | cross_sparse_updater=None, 30 | cross_sparse_initializer=None, 31 | **kwargs # forwarded to WideAndDeepModule (wide/deep configuration) 32 | ): 33 | super().__init__(**kwargs) 34 | if cross_sparse_column_name_path is None: 35 | raise ValueError("cross_sparse_column_name_path is required") 36 | if cross_sparse_combine_schema_path is None: 37 | raise ValueError("cross_sparse_combine_schema_path is required") 38 | if cross_sparse_updater is None: 39 | cross_sparse_updater = FTRLTensorUpdater() 40 | if cross_sparse_initializer is None: 41 | cross_sparse_initializer = NormalTensorInitializer(var=0.01) 42 | self._cross_sparse_embedding_size = cross_sparse_embedding_size 43 | self._cross_sparse_column_name_path = cross_sparse_column_name_path 44 | self._cross_sparse_combine_schema_path = cross_sparse_combine_schema_path 45 | self._cross_sparse = EmbeddingSumConcat(self._cross_sparse_embedding_size, 46 | self._cross_sparse_column_name_path, 47 | self._cross_sparse_combine_schema_path) 48 | self._cross_sparse.updater = cross_sparse_updater 49 | self._cross_sparse.initializer = cross_sparse_initializer 50 | self._cross_sparse_feature_count = self._cross_sparse.feature_count # used in forward() to split the flat embedding output per feature 51 | self._fm = FMModule() 52 | 53 | def forward(self, inputs): # Sum of wide, FM and deep outputs passed through a sigmoid. 54 | wide_outputs = self._wide(inputs) 55 | wide_outputs = torch.sum(wide_outputs, dim=1, keepdim=True) 56 | cross_sparse_outputs = self._cross_sparse(inputs) 57 | cross_sparse_outputs = cross_sparse_outputs.reshape( # -> (batch, feature_count, embedding_size) 58 | -1, 59 | self._cross_sparse_feature_count, 60 | self._cross_sparse_embedding_size) 61 | fm_outputs = self._fm(cross_sparse_outputs)
62 | deep_sparse_outputs = self._deep_sparse(inputs) 63 | deep_outputs = self._deep_dense(deep_sparse_outputs) 64 | return torch.sigmoid(wide_outputs + fm_outputs + deep_outputs) 65 | -------------------------------------------------------------------------------- /python/mindalpha/nn/fm.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | 19 | class FMModule(torch.nn.Module): # Factorization-machine second-order interaction term. 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def forward(self, inputs): # assumes inputs is (batch, fields, dim), as DeepFMModule passes; returns (batch, 1) 24 | square_of_sum = torch.pow(torch.sum(inputs, dim=1, keepdim=True), 2) 25 | sum_of_square = torch.sum(inputs * inputs, dim=1, keepdim=True) 26 | cross_term = square_of_sum - sum_of_square 27 | cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False) # 0.5 * sum_k[(sum_i v_ik)^2 - sum_i v_ik^2] = sum over field pairs i<j of <v_i, v_j> 28 | return cross_term 29 | -------------------------------------------------------------------------------- /python/mindalpha/nn/normalization.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2021 Mobvista 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import torch.nn.functional as F

class Normalization(torch.nn.modules.batchnorm._BatchNorm):
    """Batch-norm variant that normalizes with the *running* statistics.

    In training mode a batch is normalized with ``running_mean`` /
    ``running_var`` as they stood before the batch; the running buffers are
    then overwritten (not blended via ``momentum``) with this batch's mean
    and variance.  In eval mode it defers to ``F.batch_norm`` with the stored
    running statistics.  Accepts 2-D or 3-D input.
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2-D or 3-D."""
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)

        if not self.training:
            # Pass eps explicitly: F.batch_norm otherwise falls back to its own
            # default (1e-5), which disagrees with the training path below
            # whenever the module was constructed with a custom eps.
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, training=False, eps=self.eps)

        if self.track_running_stats and self.num_batches_tracked is not None:
            # Counter maintained as in torch's _BatchNorm; it is not used by
            # the computation below.
            self.num_batches_tracked = self.num_batches_tracked + 1

        # NOTE(review): the variance is measured around the previous running
        # mean, not around batch_mean (kept from the original implementation).
        # For 3-D input the running stats broadcast over the LAST dim here,
        # whereas F.batch_norm (eval) treats dim 1 as channels -- confirm the
        # intended layout.
        centered = input - self.running_mean
        batch_mean = input.mean(dim=0)
        batch_var = (centered * centered).mean(dim=0)
        output = centered / (self.running_var + self.eps).sqrt()
        with torch.no_grad():
            self.running_mean[...] = batch_mean
            self.running_var[...] = batch_var
        # weight/bias are None when the module was built with affine=False.
        if self.weight is not None:
            output = output * self.weight + self.bias
        return output

# --------------------------------------------------------------------------------
# /python/mindalpha/output.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def write_s3_csv(df, url, mode="overwrite",
                 header=False, delimiter="\002", encoding="UTF-8"):
    """Write DataFrame ``df`` as delimited text to the S3 location ``url``.

    ``url`` is rewritten to the ``s3a://`` scheme before writing; the default
    field delimiter is the control character ``\\002``.
    """
    from .url_utils import use_s3a
    df.write.csv(use_s3a(url), mode=mode, header=header, sep=delimiter, encoding=encoding)

def config_cassandra(spark_session, catalog, host_ip, port=9042, user_name=None, password=None):
    """Register Spark SQL catalog ``catalog`` backed by the Cassandra connector.

    Sets the catalog class and connection host/port on ``spark_session.conf``;
    the username/password keys are only set when given.
    """
    catalog_key = f'spark.sql.catalog.{catalog}'
    catalog_value = 'com.datastax.spark.connector.datasource.CassandraCatalog'
    host_key = f'{catalog_key}.spark.cassandra.connection.host'
    port_key = f'{catalog_key}.spark.cassandra.connection.port'
    user_name_key = f'{catalog_key}.spark.cassandra.auth.username'
    password_key = f'{catalog_key}.spark.cassandra.auth.password'
    spark_session.conf.set(catalog_key, catalog_value)
    spark_session.conf.set(host_key, host_ip)
    spark_session.conf.set(port_key, str(port))
    if user_name is not None:
        spark_session.conf.set(user_name_key, user_name)
    if password is not None:
        spark_session.conf.set(password_key, password)

def ensure_cassandra_db(spark_session, catalog, db_name,
                        db_properties="class='SimpleStrategy', replication_factor='1'"):
    """Create database ``catalog.db_name`` if missing, with ``db_properties``."""
    # NOTE: catalog, db_name and db_properties are interpolated into the SQL
    # text; they must come from trusted configuration, never from user input.
    spark_session.sql(f'CREATE DATABASE IF NOT EXISTS {catalog}.{db_name} '
                      f'WITH DBPROPERTIES ({db_properties})')

def write_cassandra(df, catalog, db_name, table_name, partition_key='key', mode='overwrite'):
    """Save ``df`` as table ``catalog.db_name.table_name`` partitioned by ``partition_key``."""
    table = f'{catalog}.{db_name}.{table_name}'
    df.write.partitionBy(partition_key).mode(mode).saveAsTable(table)

# --------------------------------------------------------------------------------
# /python/mindalpha/s3_utils.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def parse_s3_url(s3_url):
    """Split ``s3://bucket/key`` (or ``s3a://``) into ``(bucket, key)``.

    The key is returned without leading slashes.

    Raises:
        ValueError: if the scheme is neither ``s3`` nor ``s3a``.
    """
    from urllib.parse import urlparse
    r = urlparse(s3_url, allow_fragments=False)
    if r.scheme not in ('s3', 's3a'):
        message = "invalid s3 url: %r" % (s3_url,)
        raise ValueError(message)
    path = r.path.lstrip('/')
    return r.netloc, path

def parse_s3_dir_url(s3_url):
    """Like ``parse_s3_url``, but the returned key always ends with ``/``."""
    bucket, path = parse_s3_url(s3_url)
    if not path.endswith('/'):
        path += '/'
    return bucket, path

def get_aws_endpoint():
    """Return ``$AWS_ENDPOINT`` (``http://`` prepended if it has no scheme), or None."""
    import os
    endpoint = os.environ.get('AWS_ENDPOINT')
    if endpoint is not None:
        if not endpoint.startswith('http://') and not endpoint.startswith('https://'):
            endpoint = 'http://' + endpoint
    return endpoint

def get_s3_client():
    """Create a boto3 S3 client using the endpoint from ``get_aws_endpoint()``."""
    import boto3
    endpoint = get_aws_endpoint()
    s3 = boto3.client('s3', endpoint_url=endpoint)
    return s3

def get_s3_resource():
    """Create a boto3 S3 resource using the endpoint from ``get_aws_endpoint()``."""
    import boto3
    endpoint = get_aws_endpoint()
    s3 = boto3.resource('s3', endpoint_url=endpoint)
    return s3

def get_s3_dir_size(dir_path):
    """Return the total size in bytes of all objects under ``dir_path``.

    A single list request returns at most 1000 keys, so the listing is
    paginated to include every object under the prefix.
    """
    bucket, path = parse_s3_dir_url(dir_path)
    s3 = get_s3_client()
    paginator = s3.get_paginator('list_objects')
    size = 0
    for page in paginator.paginate(Bucket=bucket, Prefix=path):
        # Pages for an empty prefix carry no 'Contents' key.
        for obj in page.get('Contents', ()):
            size += obj['Size']
    return size

def s3_file_exists(file_path):
    """Return True if ``head_object`` succeeds for ``file_path``, else False."""
    bucket, path = parse_s3_url(file_path)
    s3 = get_s3_client()
    try:
        s3.head_object(Bucket=bucket, Key=path)
    except Exception:
        # Best effort: any failure (missing key, access denied, network) is
        # reported as "does not exist".  Catch Exception rather than using a
        # bare except so KeyboardInterrupt/SystemExit still propagate.
        return False
    return True

def delete_s3_dir(dir_path):
    """Delete every object whose key starts with the directory prefix."""
    bucket, path = parse_s3_dir_url(dir_path)
    s3 = get_s3_resource()
    s3.Bucket(bucket).objects.filter(Prefix=path).delete()

def delete_s3_file(file_path):
    """Delete the single object at ``file_path``."""
    bucket, path = parse_s3_url(file_path)
    s3 = get_s3_resource()
    s3.Object(bucket, path).delete()

def copy_s3_dir(src_dir_path, dst_dir_path):
    """Copy every object under ``src_dir_path`` to the same relative key under ``dst_dir_path``."""
    src_bucket, src_dir = parse_s3_dir_url(src_dir_path)
    dst_bucket, dst_dir = parse_s3_dir_url(dst_dir_path)
    s3 = get_s3_resource()
    bucket = s3.Bucket(dst_bucket)
    for item in s3.Bucket(src_bucket).objects.filter(Prefix=src_dir):
        src = { 'Bucket' : item.bucket_name, 'Key' : item.key }
        dst = dst_dir + item.key[len(src_dir):]
        bucket.copy(src, dst)

def download_s3_dir(src_dir_path, dst_dir_path):
    """Download every object under ``src_dir_path`` into local ``dst_dir_path``.

    Local parent directories are created as needed.  Keys whose relative name
    is empty or ends with ``/`` (directory placeholder objects) are skipped,
    since they would otherwise be downloaded onto a directory path.
    """
    import os
    from . import _mindalpha
    src_bucket, src_dir = parse_s3_dir_url(src_dir_path)
    s3 = get_s3_resource()
    bucket = s3.Bucket(src_bucket)
    for item in bucket.objects.filter(Prefix=src_dir):
        relative = item.key[len(src_dir):]
        if not relative or relative.endswith('/'):
            continue
        src = item.key
        dst = os.path.join(dst_dir_path, relative)
        _mindalpha.ensure_local_directory(os.path.dirname(dst))
        bucket.download_file(src, dst)

def upload_s3_dir(src_dir_path, dst_dir_path):
    """Upload every file below local ``src_dir_path`` under the S3 prefix ``dst_dir_path``."""
    import os
    if not src_dir_path.endswith('/'):
        src_dir_path += '/'
    dst_bucket, dst_dir = parse_s3_dir_url(dst_dir_path)
    s3 = get_s3_resource()
    bucket = s3.Bucket(dst_bucket)
    for dirpath, dirnames, filenames in os.walk(src_dir_path):
        for filename in filenames:
            src = os.path.join(dirpath, filename)
            dst = dst_dir + src[len(src_dir_path):]
            bucket.upload_file(src, dst)

# --------------------------------------------------------------------------------
# /python/mindalpha/shell_utils.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import string

# Characters that never need quoting inside a bash word.
BASH_SAFE_CHARS = frozenset(string.ascii_letters + string.digits + '+-_=/%@:,.')

def check_bash_string(item):
    """Coerce ``item`` to ``str`` for shell escaping.

    ``str`` is returned as is, ``bytes`` is decoded with the default codec,
    and ``int``/``float`` are formatted with ``str()``.

    Raises:
        TypeError: for any other type.
    """
    if isinstance(item, str):
        return item
    elif isinstance(item, bytes):
        return item.decode()
    elif isinstance(item, (int, float)):
        return str(item)
    else:
        message = "'item' must be string or number; "
        message += "%r is not supported" % type(item)
        raise TypeError(message)

def escape_bash_string(item):
    """Return ``item`` quoted so that bash reads it back as exactly one word.

    Runs of non-quote characters are wrapped in single quotes and every single
    quote is emitted as ``"'"``, e.g. ``a'b`` -> ``'a'"'"'b'``.  Strings made
    only of safe characters are returned unchanged, and a safe ``NAME=``
    prefix stays outside the quotes (``x=a b`` -> ``x='a b'``).
    """
    item = check_bash_string(item)
    if not item:
        return "''"
    if all(c in BASH_SAFE_CHARS for c in item):
        return item
    if len(item) == 1:
        c = item[0]
        return '"\'"' if c == "'" else "'%c'" % c
    result = ''
    index = item.find('=')
    if index != -1 and all(c in BASH_SAFE_CHARS for c in item[:index]):
        index += 1
        result = item[:index]
        item = item[index:]
    in_quotes = False
    for c in item:
        if c == "'":
            if in_quotes:
                # Close the open single-quoted run before emitting the quote;
                # appending this to the input (as before) produced a command
                # with an unterminated quote.
                in_quotes = False
                result += "'"
            result += '"\'"'
        else:
            if not in_quotes:
                in_quotes = True
                result += "'"
            result += c
    if in_quotes:
        result += "'"
    return result

def escape_bash_command(command):
    """Escape each word of ``command`` (a non-empty list/tuple) and join with spaces.

    Raises:
        TypeError: if ``command`` is not a list or tuple.
        ValueError: if ``command`` is empty.
    """
    if not isinstance(command, (list, tuple)):
        message = "'command' must be list or tuple; "
        message += "%r is not supported" % type(command)
        raise TypeError(message)
    if not command:
        message = "'command' can not be empty"
        raise ValueError(message)
    if len(command) == 1:
        return escape_bash_string(command[0])
    return ' '.join(escape_bash_string(x) for x in command)

def bash_escape(args):
    """Render ``args`` as bash source.

    A scalar becomes one escaped word, a flat list one command, and a list of
    lists a sequence of commands joined with ``'; '``.
    """
    if not isinstance(args, (list, tuple)):
        return escape_bash_string(args)
    elif all(not isinstance(x, (list, tuple)) for x in args):
        return escape_bash_command(args)
    else:
        return '; '.join(escape_bash_command(x) for x in args)

def wrap_message(color, message, *, check_stderr=False):
    """Wrap ``message`` in the ANSI SGR code ``color`` if the checked stream is a TTY."""
    stream = sys.stderr if check_stderr else sys.stdout
    is_atty = getattr(stream, 'isatty', None)
    if is_atty and is_atty():
        message = '\033[%sm%s\033[m' % (color, message)
    return message

def log_message(color, message):
    """Print ``message`` to stderr, colored when stderr is a TTY."""
    message = wrap_message(color, message, check_stderr=True)
    print(message, file=sys.stderr)

def log_error(message):
    log_message('38;5;196', message)

def log_warning(message):
    log_message('38;5;051', message)

def log_info(message):
    log_message('38;5;231', message)

def log_debug(message):
    log_message('38;5;240', message)

def log_trace(message):
    log_message('38;5;046', message)

def log_command(args, color=None):
    """Log ``args`` rendered by ``bash_escape``; debug color unless ``color`` is given."""
    string = bash_escape(args)
    if color is None:
        log_debug(string)
    else:
        log_message(color, string)

# --------------------------------------------------------------------------------
# /python/mindalpha/stack_trace_utils.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import threading
import ctypes

def gettid():
    """Return the kernel thread id of the calling thread.

    Uses ``threading.get_native_id()`` (Python 3.8+) when available.  The
    fallback issues the raw ``gettid`` syscall through libc; syscall number
    186 is the x86_64 value, so the fallback is only valid on x86_64 Linux.
    """
    get_native_id = getattr(threading, 'get_native_id', None)
    if get_native_id is not None:
        return get_native_id()
    SYS_gettid = 186  # x86_64-specific syscall number
    libc = ctypes.cdll.LoadLibrary('libc.so.6')
    return libc.syscall(SYS_gettid)

def get_thread_identifier():
    """Return ``'pid: <pid>, tid: <tid>, thread: 0x<python thread ident>'``."""
    string = 'pid: %d, ' % os.getpid()
    string += 'tid: %d, ' % gettid()
    string += 'thread: 0x%x' % threading.current_thread().ident
    return string

# --------------------------------------------------------------------------------
# /python/mindalpha/url_utils.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def use_s3(url):
    """Replace every ``s3a://`` in ``url`` with ``s3://``."""
    return url.replace('s3a://', 's3://')

def use_s3a(url):
    """Replace every ``s3://`` in ``url`` with ``s3a://`` (``s3a://`` is left untouched)."""
    return url.replace('s3://', 's3a://')

# --------------------------------------------------------------------------------
# /python/ps/__init__.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Compatibility package: re-export ``mindalpha.compat.ps`` under the name
# ``ps``.  A star import skips names starting with an underscore (unless they
# are listed in ``__all__``), so ``__version__`` and ``_ps`` are imported
# explicitly.
from mindalpha.compat.ps import *
from mindalpha.compat.ps import __version__, _ps

# --------------------------------------------------------------------------------
# /python/setup.py:
# --------------------------------------------------------------------------------
#
# Copyright 2021 Mobvista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | # 16 | 17 | # 18 | # This script derives from the following link: 19 | # 20 | # https://stackoverflow.com/questions/42585210/extending-setuptools-extension-to-use-cmake-in-setup-py 21 | # 22 | 23 | from setuptools import setup 24 | from setuptools import Extension 25 | from setuptools.command.build_ext import build_ext 26 | 27 | class MindAlphaExtension(Extension): 28 | def __init__(self, name): 29 | super().__init__(name, sources=[]) 30 | 31 | class mindalpha_build_ext(build_ext): 32 | def run(self): 33 | for ext in self.extensions: 34 | self.build_mindalpha(ext) 35 | 36 | def get_mindalpha_so_path(self): 37 | import os 38 | key = '_MINDALPHA_SO' 39 | path = os.environ.get(key) 40 | if path is None: 41 | message = "environment variable %r is not set; " % key 42 | message += "can not find path of '_mindalpha.so'" 43 | raise RuntimeError(message) 44 | if not os.path.isfile(path): 45 | message = "'_mindalpha.so' is not found at %r" % path 46 | raise RuntimeError(message) 47 | return path 48 | 49 | def build_mindalpha(self, ext): 50 | import shutil 51 | mindalpha_so_path = self.get_mindalpha_so_path() 52 | ext_so_path = self.get_ext_fullpath(ext.name) 53 | shutil.copy(mindalpha_so_path, ext_so_path) 54 | 55 | def get_mindalpha_version(): 56 | import os 57 | key = '_MINDALPHA_VERSION' 58 | mindalpha_version = os.environ.get(key) 59 | if mindalpha_version is None: 60 | message = "environment variable %r is not set; " % key 61 | message += "can not get MindAlpha wheel version" 62 | raise RuntimeError(message) 63 | return mindalpha_version 64 | 65 | setup(name='mindalpha', 66 | version=get_mindalpha_version(), 67 | description="MindAlpha machine learning platform.", 68 | packages=['mindalpha', 'mindalpha.nn', 'mindalpha.compat', 'mindalpha.compat.ps', 'ps'], 69 | ext_modules=[MindAlphaExtension('mindalpha/_mindalpha')], 70 | cmdclass={ 'build_ext': mindalpha_build_ext }, 71 | install_requires=['numpy>=1.20.1', 72 | 'pandas>=1.2.3', 73 | 'nest_asyncio>=1.5.1', 74 | 
'cloudpickle>=1.6.0', 75 | 'pyarrow>=3.0.0', 76 | 'PyYAML>=5.3.1', 77 | 'boto3>=1.17.41', 78 | 'python-consul>=1.1.0', 79 | 'findspark>=1.4.2']) 80 | -------------------------------------------------------------------------------- /run_build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function build_images() { 4 | local name=$1 5 | local images_no_cache=$2 6 | local build_for_centos=$3 7 | if [[ $build_for_centos -eq 1 ]]; then 8 | build_dir=docker/centos7 9 | else 10 | build_dir=docker/ubuntu20.04 11 | fi 12 | dockerfile=$build_dir/Dockerfile 13 | docker build $images_no_cache -f $dockerfile -t $name $build_dir 14 | } 15 | 16 | function build_mindalpha() { 17 | local image_name=$1 18 | local running_image_name=$2 19 | local package_mindalpha=$3 20 | docker ps | grep $running_image_name 21 | if [[ $? -ne 0 ]]; then 22 | docker run -dt --net=host --name $running_image_name \ 23 | --cap-add=SYS_PTRACE --cap-add=SYS_NICE --security-opt seccomp=unconfined \ 24 | -e TERM=xterm-256color -e COLUMNS="`tput cols`" -e LINES="`tput lines`" \ 25 | -v /home:/home \ 26 | $image_name 27 | else 28 | docker start $running_image_name 29 | fi 30 | l_base_dir=$(pwd) 31 | build_dir="${l_base_dir}/build" 32 | if [[ ! 
-d $build_dir ]]; then 33 | mkdir $build_dir 34 | fi 35 | if [[ $package_mindalpha -eq 1 ]]; then 36 | docker exec -t -w ${build_dir} ${running_image_name} /bin/bash \ 37 | -c "source ~/.bashrc && cd $l_base_dir && bash compile.sh && bash package.sh" 38 | else 39 | docker exec -t -w ${build_dir} ${running_image_name} /bin/bash \ 40 | -c "source ~/.bashrc && cd $l_base_dir && bash compile.sh" 41 | fi 42 | } 43 | 44 | function print_help() { 45 | echo "usage $0 -n tagname -u usertag -i(build_images) -c(for_centos) -C(no_cache) -m(build_mindalpha) -p(package_mindalpha) -h(help)" 46 | exit -1 47 | } 48 | 49 | function main() { 50 | default_ubuntu_tags_name="mindalpha-build-ubuntu20.04:v1.0" 51 | default_centos_tags_name="mindalpha-build-centos7:v1.0" 52 | tags_name="" 53 | user_tag=$(whoami) 54 | 55 | images=0 56 | build_for_centos=0 57 | images_no_cache="" 58 | build_mindalpha=0 59 | package_mindalpha=0 60 | 61 | while getopts nu:icCmph OPTION 62 | do 63 | case ${OPTION} in 64 | h) 65 | print_help 66 | ;; 67 | i) 68 | images=1 69 | ;; 70 | c) 71 | build_for_centos=1 72 | ;; 73 | C) 74 | images_no_cache="--no-cache" 75 | ;; 76 | m) 77 | build_mindalpha=1 78 | ;; 79 | p) 80 | package_mindalpha=1 81 | ;; 82 | n) 83 | tags_name=${OPTARG} 84 | ;; 85 | u) 86 | user_tag=${OPTARG} 87 | ;; 88 | esac 89 | done 90 | if [[ -z "$tags_name" ]]; then 91 | if [[ $build_for_centos -eq 1 ]]; then 92 | tags_name=$default_centos_tags_name 93 | else 94 | tags_name=$default_ubuntu_tags_name 95 | fi 96 | fi 97 | images_name=$(echo $tags_name | sed 's/:/-/g') 98 | running_image_name=$user_tag-$images_name-env 99 | if [[ $images -eq 1 ]]; then 100 | build_images $tags_name "$images_no_cache" $build_for_centos 101 | fi 102 | if [[ $build_mindalpha -eq 1 ]]; then 103 | build_mindalpha $tags_name $running_image_name $package_mindalpha 104 | fi 105 | } 106 | 107 | main $* 108 | -------------------------------------------------------------------------------- /thrift/mindalpha/message_meta.thrift: 
# --------------------------------------------------------------------------------
//
// Copyright 2021 Mobvista
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

//
// ``message_meta.thrift`` defines the Thrift representation of ``MessageMeta``.
//
// It is used to serialize ``MessageMeta`` to and from a byte buffer. Currently,
// only Thrift's serialization mechanism is used.
//

namespace cpp mindalpha

enum TNodeRole
{
    Null = -1;
    Coordinator = 0;
    Server = 1;
    Worker = 2;
}

struct TNodeInfo
{
    1: required TNodeRole role;
    2: required i32 nodeId;
    3: required string hostName;
    4: required i32 port;
}

enum TNodeControlCommand
{
    Null = -1;
    Terminate = 0;
    AddNode = 1;
    Barrier = 2;
}

struct TNodeControl
{
    1: required TNodeControlCommand command;
    // Thrift containers need an element type; TNodeInfo is the node record
    // defined above. NOTE(review): confirm against node_control.cpp.
    2: required list<TNodeInfo> nodes;
    3: required i32 barrierGroup;
}

enum TDataType
{
    Null = -1;
    Int8 = 0;
    Int16 = 1;
    Int32 = 2;
    Int64 = 3;
    UInt8 = 4;
    UInt16 = 5;
    UInt32 = 6;
    UInt64 = 7;
    Float32 = 8;
    Float64 = 9;
}

struct TMessageMeta
{
    1: required i32 messageId;
    2: required i32 sender;
    3: required i32 receiver;
    4: required bool isRequest;
    5: required bool isException;
    6: required string body;
    // One TDataType per slice. NOTE(review): confirm against message_meta.cpp.
    7: required list<TDataType> sliceDataTypes;
    8: required TNodeControl control;
}

# --------------------------------------------------------------------------------
# /tutorials/schema/column_name_demo.txt:
# --------------------------------------------------------------------------------
0 label
1 integer_feature_1
2 integer_feature_2
3 integer_feature_3
4 integer_feature_4
5 integer_feature_5
6 integer_feature_6
7 integer_feature_7
8 integer_feature_8
9 integer_feature_9
10 integer_feature_10
11 integer_feature_11
12 integer_feature_12
13 integer_feature_13
14 categorical_feature_1
15 categorical_feature_2
16 categorical_feature_3
17 categorical_feature_4
18 categorical_feature_5
19 categorical_feature_6
20 categorical_feature_7
21 categorical_feature_8
22 categorical_feature_9
23 categorical_feature_10
24 categorical_feature_11
25 categorical_feature_12
26 categorical_feature_13
27 categorical_feature_14
28 categorical_feature_15
29 categorical_feature_16
30 categorical_feature_17
31 categorical_feature_18
32 categorical_feature_19
33 categorical_feature_20
34 categorical_feature_21
35 categorical_feature_22
36 categorical_feature_23
37 categorical_feature_24
38 categorical_feature_25
39 categorical_feature_26

# --------------------------------------------------------------------------------
# /tutorials/schema/combine_schema_demo.txt:
# --------------------------------------------------------------------------------
integer_feature_1
integer_feature_2
integer_feature_3
integer_feature_4
integer_feature_5
integer_feature_6
integer_feature_7
integer_feature_8
integer_feature_9
integer_feature_10
integer_feature_11
integer_feature_12
integer_feature_13
categorical_feature_1
categorical_feature_2
categorical_feature_3
categorical_feature_4
categorical_feature_5
categorical_feature_6
categorical_feature_7
categorical_feature_8
categorical_feature_9
categorical_feature_10
categorical_feature_11
categorical_feature_12
categorical_feature_13
categorical_feature_14
categorical_feature_15
categorical_feature_16
categorical_feature_17
categorical_feature_18
categorical_feature_19
categorical_feature_20
categorical_feature_21
categorical_feature_22
categorical_feature_23
categorical_feature_24
categorical_feature_25
categorical_feature_26
integer_feature_1#categorical_feature_2
integer_feature_5#categorical_feature_10#categorical_feature_5

# --------------------------------------------------------------------------------