├── CMakeLists.txt ├── README.md ├── base ├── CMakeLists.txt ├── class_register.h ├── class_register_test.cc ├── class_register_test_helper.cc ├── class_register_test_helper.h ├── common.h ├── cvector.h ├── cvector_test.cc ├── logging.cc ├── logging.h ├── random.cc ├── random.h ├── scoped_ptr.h ├── stdint_msvc.h ├── stl-util.h ├── stl-util_test.cc ├── stream_wrapper.cc ├── stream_wrapper.h ├── varint32.cc ├── varint32.h └── varint32_test.cc ├── hash ├── CMakeLists.txt ├── md5_hash.cc ├── md5_hash.h ├── md5_hash_test.cc ├── simple_hash.cc └── simple_hash.h ├── mrml-lasso ├── CMakeLists.txt ├── Makefile.cygwin ├── Makefile.rules ├── Makefile.tencent ├── Makefile.ubuntu ├── command_line_options.cc ├── command_line_options.h ├── dense_vector_tmpl.h ├── dense_vector_tmpl_test.cc ├── dump_learner_states.cc ├── learner.cc ├── learner.h ├── learner_dense_impl.h ├── learner_sparse_impl.h ├── learner_states.cc ├── learner_states.h ├── learner_states_test.cc ├── learner_test.cc ├── logistic_regression.proto ├── mr_assign_feature_id.cc ├── mr_assign_feature_id.h ├── mr_convert_data_format.cc ├── mr_convert_data_format.h ├── mrml_mappers_and_reducers.cc ├── mrml_mappers_and_reducers.h ├── prediction_engine.cc ├── prediction_engine.h ├── run.sh ├── sparse_vector_tmpl.h ├── sparse_vector_tmpl_test.cc ├── tags ├── termination_flag.h ├── termination_flag_test.cc ├── test_utils.h ├── testdata │ ├── input-00000-of-00001 │ └── tiny ├── train.cc ├── vector_types.cc ├── vector_types.h └── vector_types_test.cc ├── mrml ├── CMakeLists.txt ├── codex.cc ├── mr.h ├── mrml.cc ├── mrml.h ├── mrml.proto ├── mrml_filesystem.cc ├── mrml_filesystem.h ├── mrml_filesystem_test.cc ├── mrml_main.cc ├── mrml_reader.cc ├── mrml_reader.h ├── mrml_recordio.cc ├── mrml_recordio.h ├── mrml_recordio_test.cc └── testdata │ ├── input-00000-of-00002 │ └── input-00001-of-00002 ├── sorted_buffer ├── CMakeLists.txt ├── memory_allocator.cc ├── memory_allocator.h ├── memory_allocator_test.cc ├── memory_piece.cc ├── memory_piece.h ├── memory_piece_io_test.cc ├── memory_piece_less_than_test.cc ├── memory_piece_test.cc ├── sorted_buffer.cc ├── sorted_buffer.h ├── sorted_buffer_iterator.cc ├── sorted_buffer_iterator.h ├── sorted_buffer_iterator_test.cc ├── sorted_buffer_regression_test.cc └── sorted_buffer_test.cc ├── strutil ├── CMakeLists.txt ├── Makefile ├── join_strings.h ├── join_strings_test.cc ├── split_string.cc ├── split_string.h ├── split_string_test.cc ├── strcodec.cc ├── strcodec.h ├── strcodec_test.cc ├── stringprintf.cc ├── stringprintf.h └── stringprintf_test.cc └── system ├── CMakeLists.txt ├── condition_variable.cc ├── condition_variable.h ├── condition_variable_test.cc ├── filepattern.cc ├── filepattern.h ├── filepattern_test.cc ├── mutex.h ├── mutex_test.cc └── scoped_locker.h /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # To build the whole paralgo project, you can type the following shell 2 | # commands: 3 | # 4 | # $> mkdir build 5 | # $> cd build 6 | # $> cmake .. 7 | # $> make 8 | # 9 | # Thus you check out the paralgo project and build it in a 10 | # subdirectory ``build''. If you want further to install the built 11 | # project, you can modify the default installation directory: 12 | # 13 | # set(CMAKE_INSTALL_PREFIX "/home/public/paralgo") 14 | # 15 | # and type the command 16 | # 17 | # $> make install 18 | # 19 | # If you want to use distcc for a distributed build, substitute above 20 | # command 21 | # cmake ../paralgo 22 | # by 23 | # CC=distcc cmake ../paralgo 24 | # 25 | project ("paralgo") 26 | 27 | cmake_minimum_required(VERSION 2.8) # Requires 2.8 for protobuf support. 28 | 29 | #------------------------------------------------------------------------------ 30 | # Add protobuf compilation support 31 | #------------------------------------------------------------------------------ 32 | include("FindProtobuf") 33 | find_package(Protobuf REQUIRED) 34 | 35 | #------------------------------------------------------------------------------ 36 | # Take almost all warnining; 37 | # Take warnings as errors; 38 | # Do not generate debug symbols; 39 | # Optimization level 2; 40 | #------------------------------------------------------------------------------ 41 | add_definitions(" -Wall -Wno-sign-compare -Werror -O2 ") 42 | 43 | #------------------------------------------------------------------------------ 44 | # Declare where our project will be installed. 45 | #------------------------------------------------------------------------------ 46 | set(CMAKE_INSTALL_PREFIX "/home/public/paralgo") 47 | 48 | #------------------------------------------------------------------------------ 49 | # The following flags ensure that executables are statically linked 50 | # with libraries. This makes it easy to deploy your executable across 51 | # a computer cluster. However, MacOS X does not support these flags, 52 | # so you might want to comment them out if you develop on MacOS X. 53 | # ------------------------------------------------------------------------------ 54 | # set(CMAKE_EXE_LINKER_FLAGS "-static -static-libgcc") 55 | 56 | #------------------------------------------------------------------------------ 57 | # Declare where the third party libraries were installed. 58 | # 59 | # If you are building on and for Linux, it is recommended to install 60 | # all the following dependents by yourself, using --enable-static and 61 | # --disabel-shared flags with the configure script. This ensures that 62 | # your binary links statistically with all dependents and is the only 63 | # file you need to deploy. 64 | # 65 | # MRML-lasso depends on the following thridparty pacakges: 66 | # 67 | # - protobuf 68 | # - boost 69 | # - gflags 70 | # - openssl 71 | # - libssh2 72 | # - mpich2 73 | # 74 | # However, if you are building on MacOS X. It would be much easier to 75 | # install above package using Homebrew, and you will find all headers 76 | # files at /usr/local/include and all libraries in /usr/local/lib. 77 | # ------------------------------------------------------------------------------ 78 | # set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../thirdparty") 79 | # set(PROTOBUF_DIR "${THIRD_PARTY_DIR}/protobuf") 80 | # set(BOOST_DIR "${THIRD_PARTY_DIR}/boost") 81 | # set(GFLAGS_DIR "${THIRD_PARTY_DIR}/gflags") 82 | # set(MPICH2_DIR "${THIRD_PARTY_DIR}/mpich") 83 | # set(OPENSSL_DIR "${THIRD_PARTY_DIR}/openssl") 84 | # set(LIBSSH2_DIR "${THIRD_PARTY_DIR}/libssh2") 85 | 86 | #------------------------------------------------------------------------------ 87 | # Set include paths. 88 | # Add new lines below if you installed more 3rd-party libs. 89 | #------------------------------------------------------------------------------ 90 | include_directories( 91 | "${PROJECT_SOURCE_DIR}" 92 | "${PROJECT_BINARY_DIR}" 93 | "${PROJECT_SOURCE_DIR}/gtest/include" 94 | # "${PROTOBUF_DIR}/include"; 95 | # "${MPICH2_DIR}/include"; 96 | # "${OPENSSL_DIR}/include"; 97 | # "${LIBSSH2_DIR}/include"; 98 | # "${BOOST_DIR}"; 99 | # "${GFLAGS_DIR}/include"; 100 | "/usr/local/include" 101 | ) 102 | 103 | #------------------------------------------------------------------------------ 104 | # Set libray paths. 105 | # Add new lines below if you add new packages in paralgo project, or installed 106 | # more 3rd-party libs. 107 | #------------------------------------------------------------------------------ 108 | link_directories( 109 | "${PROJECT_BINARY_DIR}/base" 110 | "${PROJECT_BINARY_DIR}/strutil" 111 | "${PROJECT_BINARY_DIR}/hash" 112 | "${PROJECT_BINARY_DIR}/sorted_buffer" 113 | "${PROJECT_BINARY_DIR}/mrml" 114 | # "${PROTOBUF_DIR}/lib"; 115 | # "${MPICH2_DIR}/lib"; 116 | # "${OPENSSL_DIR}/lib"; 117 | # "${LIBSSH2_DIR}/lib"; 118 | # "${BOOST_DIR}/stage/lib"; 119 | # "${GFLAGS_DIR}/lib"; 120 | "/usr/local/lib" 121 | ) 122 | 123 | #------------------------------------------------------------------------------ 124 | # Declare packages in paralgo project. 125 | # Add new lines below if you add new packages in paralgo project. 126 | #------------------------------------------------------------------------------ 127 | add_subdirectory(gtest) 128 | add_subdirectory(base) 129 | add_subdirectory(strutil) 130 | add_subdirectory(hash) 131 | add_subdirectory(sorted_buffer) 132 | add_subdirectory(mrml) 133 | add_subdirectory(mrml-lasso) 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LASSO 2 | 3 | ## Introduction 4 | 5 | LASSO is a parallel machine learning system that learns a regression 6 | model from large data. It works in either of two modes: 7 | 8 | 1. IPM-mode. In this mode, you start multiple training processes 9 | running the `mrml-lasso/train` program on one or more computers. 10 | Each process learns a model from its local part of data. After all 11 | processes are finished, these models are aggregated into one using 12 | the Iterative Parameter Mixtures (IPM) technology. 13 | 14 | 1. MPI-mode. In this model, you start a process running the 15 | `mrml-lasso/mrml-lasso` program, which will start more processes 16 | using MPI. After every iteration, these processes exchange their 17 | opinions and update the model. Since MPI-mode induces more data 18 | exchanges than IPM-mode, it is less scalable. 19 | 20 | In either mode, LASSO learns a logistic regression model with 21 | L1-regularization using the OWL-QN training algorithm. For more 22 | details about this algorithm, please refer to: 23 | 24 | - http://en.wikipedia.org/wiki/Limited-memory_BFGS, or 25 | - http://research.microsoft.com/en-us/downloads/b1eb1016-1738-4bd5-83a9-370c9d498a03/default.aspx 26 | 27 | ## Motivation 28 | 29 | This project serves as a baseline training system for the grand 30 | challenge in IEEE ICME 2014. To win this challenge, you need to be 31 | able to handle large training corpus generated from real Internet 32 | services. You can develop your own system, or try this one. 33 | 34 | ## Installation 35 | 36 | LASSO was developed and tested on MacOS X and Linux. It should be able to run on FreeBSD. 37 | 38 | ### Dependents 39 | 40 | LASSO depends on the following thirdparty libraries: 41 | 42 | 1. protobuf 43 | 1. boost 44 | 1. gflags 45 | 1. openssl 46 | 1. libssh2 47 | 1. mpich2 48 | 49 | On MacOS X, it is recommended to install these packages using Homebrew. Homebrew makes sure that all header files come to folder `/usr/loca/include` and all libraries come to `/usr/local/lib`. 50 | 51 | On Linux, you can install these packages using package management systems or build your own copy from source code. In this case, you might need to edit the `CMakeLists.txt` file to tell `cmake` where these packages are installed. Please refer to comments in `CMakeLists.txt` as a guide on how to edit it. 52 | 53 | To make it easy to deploy LASSO on many computers, we prefer *static linking* to above libraries and the GCC runtiem library during building. This can be controlled by adding the following line to the `CMakeLists.txt` file: 54 | 55 | set(CMAKE_EXE_LINKER_FLAGS "-static -static-libgcc") 56 | 57 | With `-static-libgcc`, you should not need to worry that all computers in your cluster run the same version of GCC runtime. 58 | 59 | Notice that above linker flags are not supported on MacOS X. It is reasonable anyway as MacOS X is a desktop system, and it is efficient for desktop applications sharing common components as shared libraries. 60 | 61 | ### Checkout and Build 62 | 63 | With above dependents installed, you can simply checkout the code and build it using `cmake`. 64 | 65 | cd ~ 66 | git clone https://github.com/wangkuiyi/lasso 67 | cd /tmp 68 | cmake ~/lasso 69 | make 70 | make install 71 | 72 | The `make install` commmand copies built software to a directory specified in `CMakeLists.txt` by the directive 73 | 74 | set(CMAKE_INSTALL_PREFIX "/home/public/paralgo") 75 | -------------------------------------------------------------------------------- /base/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build library strutil. 2 | add_library(base logging.cc random.cc varint32.cc stream_wrapper.cc) 3 | 4 | # Build unittests. 5 | set(LIBS base protobuf gflags gtest pthread) 6 | 7 | add_executable(class_register_test class_register_test.cc class_register_test_helper.cc) 8 | target_link_libraries(class_register_test gtest_main ${LIBS}) 9 | 10 | add_executable(cvector_test cvector_test.cc) 11 | target_link_libraries(cvector_test gtest_main ${LIBS}) 12 | 13 | add_executable(stl-util_test stl-util_test.cc) 14 | target_link_libraries(stl-util_test gtest_main ${LIBS}) 15 | 16 | add_executable(varint32_test varint32_test.cc) 17 | target_link_libraries(varint32_test gtest_main ${LIBS}) 18 | 19 | # Install library and header files 20 | install(TARGETS base DESTINATION bin/base) 21 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 22 | install(FILES ${HEADER_FILES} DESTINATION include/base) 23 | -------------------------------------------------------------------------------- /base/class_register_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | #include "base/class_register.h" 6 | #include "base/class_register_test_helper.h" 7 | #include "base/scoped_ptr.h" 8 | #include "gtest/gtest.h" 9 | 10 | 11 | TEST(ClassRegister, CreateMapper) { 12 | scoped_ptr mapper; 13 | mapper.reset(CREATE_MAPPER("")); 14 | EXPECT_TRUE(mapper.get() == NULL); 15 | 16 | mapper.reset(CREATE_MAPPER("HelloMapper ")); 17 | EXPECT_TRUE(mapper.get() == NULL); 18 | 19 | mapper.reset(CREATE_MAPPER("HelloWorldMapper")); 20 | EXPECT_TRUE(mapper.get() == NULL); 21 | 22 | mapper.reset(CREATE_MAPPER("HelloReducer")); 23 | EXPECT_TRUE(mapper.get() == NULL); 24 | 25 | mapper.reset(CREATE_MAPPER("WorldReducer")); 26 | EXPECT_TRUE(mapper.get() == NULL); 27 | 28 | mapper.reset(CREATE_MAPPER("SecondaryMapper")); 29 | EXPECT_TRUE(mapper.get() == NULL); 30 | 31 | mapper.reset(CREATE_MAPPER("HelloMapper")); 32 | ASSERT_TRUE(mapper.get() != NULL); 33 | EXPECT_EQ("HelloMapper", mapper->GetMapperName()); 34 | 35 | mapper.reset(CREATE_MAPPER("WorldMapper")); 36 | ASSERT_TRUE(mapper.get() != NULL); 37 | EXPECT_EQ("WorldMapper", mapper->GetMapperName()); 38 | } 39 | 40 | TEST(ClassRegister, CreateSecondaryMapper) { 41 | scoped_ptr mapper; 42 | mapper.reset(CREATE_SECONDARY_MAPPER("")); 43 | EXPECT_TRUE(mapper.get() == NULL); 44 | 45 | mapper.reset(CREATE_SECONDARY_MAPPER("SecondaryMapper ")); 46 | EXPECT_TRUE(mapper.get() == NULL); 47 | 48 | mapper.reset(CREATE_SECONDARY_MAPPER("HelloWorldMapper")); 49 | EXPECT_TRUE(mapper.get() == NULL); 50 | 51 | mapper.reset(CREATE_SECONDARY_MAPPER("HelloReducer")); 52 | EXPECT_TRUE(mapper.get() == NULL); 53 | 54 | mapper.reset(CREATE_SECONDARY_MAPPER("WorldReducer")); 55 | EXPECT_TRUE(mapper.get() == NULL); 56 | 57 | mapper.reset(CREATE_SECONDARY_MAPPER("HelloMapper")); 58 | EXPECT_TRUE(mapper.get() == NULL); 59 | 60 | mapper.reset(CREATE_SECONDARY_MAPPER("WorldMapper")); 61 | EXPECT_TRUE(mapper.get() == NULL); 62 | 63 | mapper.reset(CREATE_SECONDARY_MAPPER("SecondaryMapper")); 64 | ASSERT_TRUE(mapper.get() != NULL); 65 | EXPECT_EQ("SecondaryMapper", mapper->GetMapperName()); 66 | } 67 | 68 | TEST(ClassRegister, CreateReducer) { 69 | scoped_ptr reducer; 70 | reducer.reset(CREATE_REDUCER("")); 71 | EXPECT_TRUE(reducer.get() == NULL); 72 | 73 | reducer.reset(CREATE_REDUCER("HelloReducer ")); 74 | EXPECT_TRUE(reducer.get() == NULL); 75 | 76 | reducer.reset(CREATE_REDUCER("HelloWorldReducer")); 77 | EXPECT_TRUE(reducer.get() == NULL); 78 | 79 | reducer.reset(CREATE_REDUCER("HelloMapper")); 80 | EXPECT_TRUE(reducer.get() == NULL); 81 | 82 | reducer.reset(CREATE_REDUCER("WorldMapper")); 83 | EXPECT_TRUE(reducer.get() == NULL); 84 | 85 | reducer.reset(CREATE_REDUCER("HelloReducer")); 86 | ASSERT_TRUE(reducer.get() != NULL); 87 | EXPECT_EQ("HelloReducer", reducer->GetReducerName()); 88 | 89 | reducer.reset(CREATE_REDUCER("WorldReducer")); 90 | ASSERT_TRUE(reducer.get() != NULL); 91 | EXPECT_EQ("WorldReducer", reducer->GetReducerName()); 92 | } 93 | 94 | TEST(ClassRegister, CreateFileImpl) { 95 | scoped_ptr file_impl; 96 | file_impl.reset(CREATE_FILE_IMPL("/mem")); 97 | ASSERT_TRUE(file_impl.get() != NULL); 98 | EXPECT_EQ("MemFileImpl", file_impl->GetFileImplName()); 99 | 100 | file_impl.reset(CREATE_FILE_IMPL("/nfs")); 101 | ASSERT_TRUE(file_impl.get() != NULL); 102 | EXPECT_EQ("NetworkFileImpl", file_impl->GetFileImplName()); 103 | 104 | file_impl.reset(CREATE_FILE_IMPL("/local")); 105 | ASSERT_TRUE(file_impl.get() != NULL); 106 | EXPECT_EQ("LocalFileImpl", file_impl->GetFileImplName()); 107 | 108 | file_impl.reset(CREATE_FILE_IMPL("/")); 109 | ASSERT_TRUE(file_impl.get() != NULL); 110 | EXPECT_EQ("LocalFileImpl", file_impl->GetFileImplName()); 111 | 112 | file_impl.reset(CREATE_FILE_IMPL("")); 113 | ASSERT_TRUE(file_impl.get() != NULL); 114 | EXPECT_EQ("LocalFileImpl", file_impl->GetFileImplName()); 115 | 116 | file_impl.reset(CREATE_FILE_IMPL("/mem2")); 117 | ASSERT_TRUE(file_impl.get() != NULL); 118 | EXPECT_EQ("LocalFileImpl", file_impl->GetFileImplName()); 119 | 120 | file_impl.reset(CREATE_FILE_IMPL("/mem/")); 121 | ASSERT_TRUE(file_impl.get() != NULL); 122 | EXPECT_EQ("LocalFileImpl", file_impl->GetFileImplName()); 123 | 124 | file_impl.reset(CREATE_FILE_IMPL("/nfs2/")); 125 | ASSERT_TRUE(file_impl.get() != NULL); 126 | EXPECT_EQ("LocalFileImpl", file_impl->GetFileImplName()); 127 | } 128 | -------------------------------------------------------------------------------- /base/class_register_test_helper.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | #include "base/class_register_test_helper.h" 6 | 7 | CLASS_REGISTER_IMPLEMENT_REGISTRY(mapper_register, Mapper); 8 | CLASS_REGISTER_IMPLEMENT_REGISTRY(second_mapper_register, Mapper); 9 | CLASS_REGISTER_IMPLEMENT_REGISTRY(reducer_register, Reducer); 10 | CLASS_REGISTER_IMPLEMENT_REGISTRY(file_impl_register, FileImpl); 11 | 12 | 13 | class HelloMapper : public Mapper { 14 | virtual std::string GetMapperName() const { 15 | return "HelloMapper"; 16 | } 17 | }; 18 | REGISTER_MAPPER(HelloMapper); 19 | 20 | class WorldMapper : public Mapper { 21 | virtual std::string GetMapperName() const { 22 | return "WorldMapper"; 23 | } 24 | }; 25 | REGISTER_MAPPER(WorldMapper); 26 | 27 | class SecondaryMapper : public Mapper { 28 | virtual std::string GetMapperName() const { 29 | return "SecondaryMapper"; 30 | } 31 | }; 32 | REGISTER_SECONDARY_MAPPER(SecondaryMapper); 33 | 34 | 35 | class HelloReducer : public Reducer { 36 | virtual std::string GetReducerName() const { 37 | return "HelloReducer"; 38 | } 39 | }; 40 | REGISTER_REDUCER(HelloReducer); 41 | 42 | class WorldReducer : public Reducer { 43 | virtual std::string GetReducerName() const { 44 | return "WorldReducer"; 45 | } 46 | }; 47 | REGISTER_REDUCER(WorldReducer); 48 | 49 | 50 | class LocalFileImpl : public FileImpl { 51 | virtual std::string GetFileImplName() const { 52 | return "LocalFileImpl"; 53 | } 54 | }; 55 | REGISTER_DEFAULT_FILE_IMPL(LocalFileImpl); 56 | REGISTER_FILE_IMPL("/local", LocalFileImpl); 57 | 58 | class MemFileImpl : public FileImpl { 59 | virtual std::string GetFileImplName() const { 60 | return "MemFileImpl"; 61 | } 62 | }; 63 | REGISTER_FILE_IMPL("/mem", MemFileImpl); 64 | 65 | class NetworkFileImpl : public FileImpl { 66 | virtual std::string GetFileImplName() const { 67 | return "NetworkFileImpl"; 68 | } 69 | }; 70 | REGISTER_FILE_IMPL("/nfs", NetworkFileImpl); 71 | -------------------------------------------------------------------------------- /base/class_register_test_helper.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | // 5 | // Defines several base class and registers for testing. We intentionally 6 | // define them in a separate file as some compilers don't correctly support to 7 | // define static variable in inline function, they create a separate copy each 8 | // time it's included. We want to make sure it doesn't happen to us. 9 | 10 | #ifndef COMMON_BASE_CLASS_REGISTER_HELPER_H_ 11 | #define COMMON_BASE_CLASS_REGISTER_HELPER_H_ 12 | 13 | #include 14 | #include "base/class_register.h" 15 | 16 | class Mapper { 17 | public: 18 | Mapper() {} 19 | virtual ~Mapper() {} 20 | 21 | virtual std::string GetMapperName() const = 0; 22 | }; 23 | 24 | CLASS_REGISTER_DEFINE_REGISTRY(mapper_register, Mapper); 25 | 26 | #define REGISTER_MAPPER(mapper_name) \ 27 | CLASS_REGISTER_OBJECT_CREATOR( \ 28 | mapper_register, Mapper, #mapper_name, mapper_name) 29 | 30 | #define CREATE_MAPPER(mapper_name_as_string) \ 31 | CLASS_REGISTER_CREATE_OBJECT(mapper_register, mapper_name_as_string) 32 | 33 | CLASS_REGISTER_DEFINE_REGISTRY(second_mapper_register, Mapper); 34 | 35 | #define REGISTER_SECONDARY_MAPPER(mapper_name) \ 36 | CLASS_REGISTER_OBJECT_CREATOR( \ 37 | second_mapper_register, Mapper, #mapper_name, mapper_name) 38 | 39 | #define CREATE_SECONDARY_MAPPER(mapper_name_as_string) \ 40 | CLASS_REGISTER_CREATE_OBJECT(second_mapper_register, \ 41 | mapper_name_as_string) 42 | 43 | 44 | class Reducer { 45 | public: 46 | Reducer() {} 47 | virtual ~Reducer() {} 48 | 49 | virtual std::string GetReducerName() const = 0; 50 | }; 51 | 52 | CLASS_REGISTER_DEFINE_REGISTRY(reducer_register, Reducer); 53 | 54 | #define REGISTER_REDUCER(reducer_name) \ 55 | CLASS_REGISTER_OBJECT_CREATOR( \ 56 | reducer_register, Reducer, #reducer_name, reducer_name) 57 | 58 | #define CREATE_REDUCER(reducer_name_as_string) \ 59 | CLASS_REGISTER_CREATE_OBJECT(reducer_register, reducer_name_as_string) 60 | 61 | 62 | 63 | class FileImpl { 64 | public: 65 | FileImpl() {} 66 | virtual ~FileImpl() {} 67 | 68 | virtual std::string GetFileImplName() const = 0; 69 | }; 70 | 71 | CLASS_REGISTER_DEFINE_REGISTRY(file_impl_register, FileImpl); 72 | 73 | #define REGISTER_DEFAULT_FILE_IMPL(file_impl_name) \ 74 | CLASS_REGISTER_DEFAULT_OBJECT_CREATOR( \ 75 | file_impl_register, FileImpl, file_impl_name) 76 | 77 | #define REGISTER_FILE_IMPL(path_prefix_as_string, file_impl_name) \ 78 | CLASS_REGISTER_OBJECT_CREATOR( \ 79 | file_impl_register, FileImpl, path_prefix_as_string, file_impl_name) 80 | 81 | #define CREATE_FILE_IMPL(path_prefix_as_string) \ 82 | CLASS_REGISTER_CREATE_OBJECT(file_impl_register, path_prefix_as_string) 83 | 84 | #endif // COMMON_BASE_CLASS_REGISTER_HELPER_H_ 85 | -------------------------------------------------------------------------------- /base/cvector.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define class templates CVector and CVector 5 | // 6 | // CVector is designed to encapsulate C-style vectors 7 | // (including malloc, free and memcpy-type initialization) with a 8 | // minimum interface. Compared with std::vector, CVector ensures that 9 | // elements are stored successively in memory layout, and provides an 10 | // interface to access the elements in C-style. 11 | // 12 | // CVector is a partial specialization of CVector. 13 | // Whenever you assign a pointer to any element in CVector, 14 | // CVector takes the ownership of your pointer, and will 15 | // free it in CVector's destructure using C++ keyword delete. 16 | // 17 | // Example: 18 | // 19 | // Many C libraries like GNU Scientific Library (GSL) operates C-style 20 | // vectors. The following example shows how to draw a sample 21 | // (parameters of a multinomial distribution) from a symmetric 22 | // Dirichlet distribution: 23 | /* 24 | const int kDim = 10; 25 | const double kParam = 0.1 26 | CVector dirichlet_param(kDim, kParam); 27 | CVector multinomial_param(kDim); 28 | gsl_ran_dirichlet(gsl_rng, 29 | dirichlet_param.size(), dirichlet_param.data(), 30 | multinomial_param.data()); 31 | */ 32 | // 33 | // Using CVector and CVector, it is easy to define 34 | // a matrix (2D array). For example: 35 | /* 36 | template 37 | class Matrix { 38 | public: 39 | void resize(int n, int m) { 40 | rows_.resize(n); 41 | for (int i = 0; i < n; ++i) { 42 | rows_.data()[i] = new CVector(m, 0); 43 | } 44 | } 45 | private: 46 | CVector rows_; 47 | }; 48 | */ 49 | 50 | #ifndef BASE_CVECTOR_H_ 51 | #define BASE_CVECTOR_H_ 52 | 53 | #include 54 | 55 | #include "base/common.h" 56 | 57 | template 58 | class CVector { 59 | public: 60 | CVector(int size, const Element& init) { 61 | Allocate(size); 62 | Initialize(init); 63 | } 64 | 65 | explicit CVector(int size) { Allocate(size); } 66 | 67 | CVector() { 68 | data_ = NULL; 69 | size_ = 0; 70 | } 71 | 72 | ~CVector() { Deallocate(); } 73 | 74 | void resize(int size, const Element& init) { 75 | Deallocate(); 76 | Allocate(size); 77 | Initialize(init); 78 | } 79 | 80 | void resize(int size) { 81 | Deallocate(); 82 | Allocate(size); 83 | } 84 | 85 | int size() const { return size_; } 86 | 87 | const Element* data() const { return data_; } 88 | Element* data() { return data_; } 89 | 90 | protected: 91 | void Allocate(int size) { 92 | size_ = size; 93 | data_ = new Element[size]; 94 | } 95 | 96 | void Deallocate() { 97 | delete [] data_; 98 | data_ = NULL; 99 | } 100 | 101 | void Initialize(const Element& init) { 102 | for (int i = 0; i < size_; ++i) { 103 | data_[i] = init; 104 | } 105 | } 106 | 107 | Element* data_; 108 | int size_; 109 | 110 | private: 111 | DISALLOW_COPY_AND_ASSIGN(CVector); 112 | }; 113 | 114 | 115 | template 116 | class CVector { 117 | public: 118 | explicit CVector(int size) { 119 | Allocate(size); 120 | Initialize(NULL); 121 | } 122 | 123 | CVector() { 124 | data_ = NULL; 125 | size_ = 0; 126 | } 127 | 128 | ~CVector() { Deallocate(); } 129 | 130 | void resize(int size, const Element* init) { 131 | Deallocate(); 132 | Allocate(size); 133 | Initialize(init); 134 | } 135 | 136 | void resize(int size) { 137 | Deallocate(); 138 | Allocate(size); 139 | } 140 | 141 | int size() const { return size_; } 142 | 143 | const Element** data() const { return data_; } 144 | Element** data() { return data_; } 145 | 146 | protected: 147 | void Allocate(int size) { 148 | size_ = size; 149 | data_ = new Element*[size]; 150 | } 151 | 152 | void Deallocate() { 153 | for (int i = 0; i < size_; ++i) { 154 | if (data_[i] != NULL) { 155 | delete data_[i]; 156 | } 157 | } 158 | delete [] data_; 159 | data_ = NULL; 160 | } 161 | 162 | void Initialize(Element* init) { 163 | for (int i = 0; i < size_; ++i) { 164 | data_[i] = init; 165 | } 166 | } 167 | 168 | Element** data_; 169 | int size_; 170 | 171 | private: 172 | DISALLOW_COPY_AND_ASSIGN(CVector); 173 | }; 174 | 175 | #endif // BASE_CVECTOR_H_ 176 | 177 | -------------------------------------------------------------------------------- /base/cvector_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "gtest/gtest.h" 5 | 6 | #include "base/common.h" 7 | #include "base/cvector.h" 8 | 9 | class DestructDetector { 10 | public: 11 | explicit DestructDetector(bool* flag) { 12 | flag_ = flag; 13 | } 14 | ~DestructDetector() { 15 | *flag_ = true; 16 | } 17 | private: 18 | bool* flag_; 19 | }; 20 | 21 | 22 | TEST(CVectorTest, NumericalValueElements) { 23 | CVector vd(3, 0.1); 24 | EXPECT_EQ(vd.size(), 3); 25 | for (int i = 0; i < vd.size(); ++i) { 26 | EXPECT_EQ(vd.data()[i], 0.1); 27 | } 28 | 29 | // Note: uncomment the following line to test 30 | // DISALLOW_COPY_AND_ASSIGN of CVector. 31 | // CVector vcd = vd; 32 | 33 | vd.resize(4, 0.2); 34 | EXPECT_EQ(vd.size(), 4); 35 | for (int i = 0; i < vd.size(); ++i) { 36 | EXPECT_EQ(vd.data()[i], 0.2); 37 | } 38 | } 39 | 40 | TEST(CVectorTest, PointerValueElements) { 41 | CVector vv(3); 42 | EXPECT_EQ(vv.size(), 3); 43 | for (int i = 0; i < vv.size(); ++i) { 44 | DestructDetector* d = vv.data()[i]; 45 | EXPECT_EQ(static_cast(NULL), d); 46 | } 47 | 48 | // Note: uncomment the following line to test 49 | // DISALLOW_COPY_AND_ASSIGN of CVector. 50 | // CVector vvcd = vv; 51 | 52 | bool destruct_flag = false; 53 | vv.data()[1] = new DestructDetector(&destruct_flag); 54 | 55 | vv.resize(0); 56 | EXPECT_EQ(vv.size(), 0); 57 | EXPECT_EQ(destruct_flag, true); 58 | } 59 | -------------------------------------------------------------------------------- /base/logging.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "base/logging.h" 5 | 6 | #include 7 | #include 8 | 9 | std::ofstream Logger::info_log_file_; 10 | std::ofstream Logger::warn_log_file_; 11 | std::ofstream Logger::erro_log_file_; 12 | 13 | void InitializeLogger(const std::string& info_log_filename, 14 | const std::string& warn_log_filename, 15 | const std::string& erro_log_filename) { 16 | Logger::info_log_file_.open(info_log_filename.c_str()); 17 | Logger::warn_log_file_.open(warn_log_filename.c_str()); 18 | Logger::erro_log_file_.open(erro_log_filename.c_str()); 19 | } 20 | 21 | /*static*/ 22 | std::ostream& Logger::GetStream(LogSeverity severity) { 23 | return (severity == INFO) ? 24 | (info_log_file_.is_open() ? info_log_file_ : std::cout) : 25 | (severity == WARNING ? 26 | (warn_log_file_.is_open() ? warn_log_file_ : std::cerr) : 27 | (erro_log_file_.is_open() ? erro_log_file_ : std::cerr)); 28 | } 29 | 30 | /*static*/ 31 | std::ostream& Logger::Start(LogSeverity severity, 32 | const std::string& file, 33 | int line, 34 | const std::string& function) { 35 | time_t tm; 36 | time(&tm); 37 | char time_string[128]; 38 | ctime_r(&tm, time_string); 39 | return GetStream(severity) << time_string 40 | << " " << file << ":" << line 41 | << " (" << function << ") " << std::flush; 42 | } 43 | 44 | Logger::~Logger() { 45 | GetStream(severity_) << "\n" << std::flush; 46 | 47 | if (severity_ == FATAL) { 48 | info_log_file_.close(); 49 | warn_log_file_.close(); 50 | erro_log_file_.close(); 51 | abort(); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /base/logging.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Provide logging facilities that treat log messages by their 5 | // severities. If function |InitializeLogger| was invoked and was 6 | // able to open files specified by the parameters, log messages of 7 | // various severity will be written into corresponding files. 8 | // Otherwise, all log messages will be written to stderr. 9 | // 10 | // Example: 11 | /* 12 | int main() { 13 | InitializeLogger("/tmp/info.log", "/tmp/warn.log", "/tmp/erro.log"); 14 | LOG(INFO) << "An info message going into /tmp/info.log"; 15 | LOG(WARNING) << "An warn message going into /tmp/warn.log"; 16 | LOG(ERROR) << "An erro message going into /tmp/erro.log"; 17 | LOG(FATAL) << "An fatal message going into /tmp/erro.log, " 18 | << "and kills current process by a segmentation fault."; 19 | return 0; 20 | } 21 | */ 22 | #ifndef _BASE_LOGGING_H_ 23 | #define _BASE_LOGGING_H_ 24 | 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | void InitializeLogger(const std::string& info_log_filename, 31 | const std::string& warn_log_filename, 32 | const std::string& erro_log_filename); 33 | 34 | enum LogSeverity { INFO, WARNING, ERROR, FATAL }; 35 | 36 | class Logger { 37 | friend void InitializeLogger(const std::string& info_log_filename, 38 | const std::string& warn_log_filename, 39 | const std::string& erro_log_filename); 40 | public: 41 | Logger(LogSeverity s) : severity_(s) {} 42 | ~Logger(); 43 | 44 | static std::ostream& GetStream(LogSeverity severity); 45 | static std::ostream& Start(LogSeverity severity, 46 | const std::string& file, 47 | int line, 48 | const std::string& function); 49 | 50 | private: 51 | static std::ofstream info_log_file_; 52 | static std::ofstream warn_log_file_; 53 | static std::ofstream erro_log_file_; 54 | LogSeverity severity_; 55 | }; 56 | 57 | 58 | // The basic mechanism of logging.{h,cc} is as follows: 59 | // - LOG(severity) defines a Logger instance, which records the severity. 60 | // - LOG(severity) then invokes Logger::Start(), which invokes Logger::Stream 61 | // to choose an output stream, outputs a message head into the stream and 62 | // flush. 63 | // - The std::ostream reference returned by LoggerStart() is then passed to 64 | // user-specific output operators (<<), which writes the log message body. 65 | // - When the Logger instance is destructed, the destructor appends flush. 66 | // If severity is FATAL, the destructor causes SEGFAULT and core dump. 67 | // 68 | // It is important to flush in Logger::Start() after outputing message 69 | // head. This is because that the time when the destructor is invoked 70 | // depends on how/where the caller code defines the Logger instance. 71 | // If the caller code crashes before the Logger instance is properly 72 | // destructed, the destructor might not have the chance to append its 73 | // flush flags. Without flush in Logger::Start(), this may cause the 74 | // lose of the last few messages. However, given flush in Start(), 75 | // program crashing between invocations to Logger::Start() and 76 | // destructor only causes the lose of the last message body; while the 77 | // message head will be there. 78 | // 79 | #define LOG(severity) \ 80 | Logger(severity).Start(severity, __FILE__, __LINE__, __FUNCTION__) 81 | 82 | 83 | #endif //_BASE_LOGGING_H_ 84 | -------------------------------------------------------------------------------- /base/random.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | // 5 | #include "base/random.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | /*static*/ 13 | uint32 Random::GetTickCount() { 14 | struct timeval t; 15 | gettimeofday(&t, NULL); 16 | t.tv_sec %= (24 * 60 * 60); // one day ticks 24*60*60 17 | uint32 tick_count = t.tv_sec * 1000 + t.tv_usec / 1000; 18 | return tick_count; 19 | } 20 | 21 | void CRuntimeRandom::SeedRNG(int seed) { 22 | if (seed >= 0) { 23 | seed_ = seed; 24 | } else { 25 | seed_ = getpid() * time(NULL); // BUG(yiwang): should also times thread id 26 | } 27 | } 28 | 29 | void MTRandom::SeedRNG(int seed) { 30 | if (seed >= 0) { 31 | uniform_01_rng_.base().seed(seed); 32 | } else { 33 | uniform_01_rng_.base().seed(GetTickCount()); 34 | } 35 | } 36 | 37 | -------------------------------------------------------------------------------- /base/random.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | // 5 | // Wrappers for several pseudo-random number generators used in LDA, including 6 | // the default unix RNG(random number generator), and RNGs provided in Boost 7 | // library. We suggest not using unix RNG due to its poor randomness in 8 | // application. 9 | // 10 | #ifndef BASE_RANDOM_H_ 11 | #define BASE_RANDOM_H_ 12 | 13 | #include "base/common.h" 14 | #include "boost/random.hpp" 15 | 16 | // The RNG wrapper interface 17 | class Random { 18 | public: 19 | Random() {} 20 | 21 | virtual ~Random() {} 22 | 23 | // Seed the RNG using specified seed or current time(if seed < 0). 24 | // In order to achieve maximum randomness we use current time in 25 | // millisecond as the seed. Note that it is not a good idea to 26 | // seed with current time in second when multiple random number 27 | // sequences are required, which usually produces correlated number 28 | // sequences and results in poor randomness. 29 | virtual void SeedRNG(int seed) = 0; 30 | 31 | // Generate a random float value in the range of [0,1) from the 32 | // uniform distribution. 33 | virtual double RandDouble() = 0; 34 | 35 | // Generate a random integer value in the range of [0,bound) from the 36 | // uniform distribution. 37 | virtual int RandInt(int bound) { 38 | return static_cast(RandDouble() * bound); 39 | } 40 | 41 | // Get tick count of the day, used as random seed 42 | static uint32 GetTickCount(); 43 | }; 44 | 45 | 46 | // Wrapper for default C-runtime random number generator 47 | class CRuntimeRandom : public Random { 48 | public: 49 | CRuntimeRandom() { seed_ = 0; } 50 | 51 | virtual ~CRuntimeRandom() {} 52 | 53 | virtual void SeedRNG(int seed); 54 | 55 | virtual double RandDouble() { 56 | // rand() returns a pseudo-random integral number in the range 57 | // [0,RAND_MAX]. original code will generate a random float in [0,1], not 58 | // [0, 1) WARNING : RAND_MAX is the largest integer, so we should cast it 59 | // to double before we do RADND_MAX+1 60 | return rand_r(&seed_) / (static_cast(RAND_MAX) + 1); 61 | } 62 | 63 | private: 64 | unsigned int seed_; 65 | }; 66 | 67 | 68 | // The RNG(random number generator) wrapper using Boost mt19937. 69 | // Please refer to [http://www.boost.org/doc/libs/1_44_0/doc/html/ 70 | // boost_random/reference.html#boost_random.reference.generators] 71 | // for details about mt19937 generator and uniform_01 distribution. 72 | class MTRandom : public Random { 73 | public: 74 | MTRandom() 75 | : uniform_01_rng_(boost::mt19937()) {} 76 | 77 | virtual ~MTRandom() {} 78 | 79 | virtual void SeedRNG(int seed); 80 | 81 | virtual double RandDouble() { 82 | return uniform_01_rng_(); 83 | } 84 | 85 | private: 86 | boost::uniform_01 uniform_01_rng_; 87 | }; 88 | 89 | #endif // BASE_RANDOM_H_ 90 | -------------------------------------------------------------------------------- /base/stl-util.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This file contains facilities that enhance the STL. 5 | // 6 | #ifndef BASE_STL_UTIL_H_ 7 | #define BASE_STL_UTIL_H_ 8 | 9 | // Delete elements (in pointer type) in a STL container like vector, 10 | // list, and deque. 11 | template 12 | void STLDeleteElementsAndClear(Container* c) { 13 | for (typename Container::iterator iter = c->begin(); 14 | iter != c->end(); ++iter) { 15 | if (*iter != NULL) { 16 | delete *iter; 17 | } 18 | } 19 | c->clear(); 20 | } 21 | 22 | // Delete elements (in pointer type) in a STL associative container 23 | // like map and hash_map. 24 | template 25 | void STLDeleteValuesAndClear(AssocContainer* c) { 26 | for (typename AssocContainer::iterator iter = c->begin(); 27 | iter != c->end(); ++iter) { 28 | if (iter->second != NULL) { 29 | delete iter->second; 30 | } 31 | } 32 | c->clear(); 33 | } 34 | 35 | #endif // BASE_STL_UTIL_H_ 36 | -------------------------------------------------------------------------------- /base/stl-util_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | #include 6 | 7 | #include "gtest/gtest.h" 8 | 9 | #include "base/common.h" 10 | #include "base/stl-util.h" 11 | 12 | TEST(STLUtilTest, DeleteElementsInVector) { 13 | std::vector v; 14 | v.push_back(new int(10)); 15 | v.push_back(new int(20)); 16 | 17 | EXPECT_EQ(2, v.size()); 18 | EXPECT_EQ(10, *(v[0])); 19 | EXPECT_EQ(20, *(v[1])); 20 | 21 | STLDeleteElementsAndClear(&v); 22 | 23 | EXPECT_EQ(0, v.size()); 24 | } 25 | 26 | TEST(STLUtilTest, DeleteElementsInMap) { 27 | std::map m; 28 | m[100] = new int(10); 29 | m[200] = new int(20); 30 | 31 | EXPECT_EQ(2, m.size()); 32 | EXPECT_EQ(10, *(m[100])); 33 | EXPECT_EQ(20, *(m[200])); 34 | 35 | STLDeleteValuesAndClear(&m); 36 | 37 | EXPECT_EQ(0, m.size()); 38 | } 39 | -------------------------------------------------------------------------------- /base/stream_wrapper.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | // 5 | #include "base/stream_wrapper.h" 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | namespace stream_wrapper { 12 | 13 | ostream_wrapper::ostream_wrapper(const char* filename) 14 | : output_stream_(0) { 15 | if (std::strcmp(filename, "-") == 0) 16 | output_stream_ = &std::cout; 17 | else 18 | output_stream_ = new std::ofstream(filename); 19 | } 20 | 21 | ostream_wrapper::~ostream_wrapper() { 22 | if (output_stream_ != &std::cout) delete output_stream_; 23 | } 24 | 25 | istream_wrapper::istream_wrapper(const char* filename) 26 | : input_stream_(0) { 27 | if (std::strcmp(filename, "-") == 0) 28 | input_stream_ = &std::cin; 29 | else 30 | input_stream_ = new std::ifstream(filename, std::ios::binary|std::ios::in); 31 | } 32 | 33 | istream_wrapper::~istream_wrapper() { 34 | if (input_stream_ != &std::cin) delete input_stream_; 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /base/stream_wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | // 5 | // Wrappers for iostream, use stdin and stdout if filename is "-" 6 | // 7 | #ifndef BASE_STREAM_WRAPPER_H_ 8 | #define BASE_STREAM_WRAPPER_H_ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace stream_wrapper { 15 | 16 | /** 17 | * @brief 包装输出流, 使得标准输出和普通文件输出接口一致 18 | */ 19 | class ostream_wrapper 20 | { 21 | public: 22 | std::ostream &operator*() const { return *output_stream_; } 23 | std::ostream *operator->() const { return output_stream_; } 24 | 25 | explicit ostream_wrapper(const char* filename); 26 | virtual ~ostream_wrapper(); 27 | 28 | private: 29 | std::ostream* output_stream_; 30 | }; 31 | 32 | /** 33 | * @brief 包装输入流, 使得标准输入和普通文件输入接口一致 34 | */ 35 | class istream_wrapper { 36 | public: 37 | std::istream &operator*() const { return *input_stream_; } 38 | std::istream *operator->() const { return input_stream_; } 39 | 40 | explicit istream_wrapper(const char* filename); 41 | virtual ~istream_wrapper(); 42 | 43 | private: 44 | std::istream* input_stream_; 45 | }; 46 | 47 | } // namespace 48 | #endif // BASE_STREAM_WRAPPER_H_ 49 | -------------------------------------------------------------------------------- /base/varint32.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "base/varint32.h" 5 | 6 | #include 7 | 8 | #include "base/common.h" 9 | #include "google/protobuf/io/coded_stream.h" 10 | 11 | namespace { 12 | inline uint32 GetByte(FILE* input) { 13 | return static_cast(getc(input)); 14 | } 15 | } // namespace 16 | 17 | // The following varint32 decoding code snippet is copied from an 18 | // internal function of Google protobuf. It is compliant with the 19 | // varint32 codec specification at 20 | // http://code.google.com/apis/protocolbuffers/docs/encoding.html 21 | bool ReadVarint32(FILE* input, uint32* value) { 22 | static const int kMaxVarintBytes = 10; 23 | static const int kMaxVarint32Bytes = 5; 24 | 25 | if (ferror(input) || feof(input)) 26 | return false; 27 | 28 | uint32 b; 29 | uint32 result; 30 | 31 | b = GetByte(input); 32 | result = (b & 0x7F); 33 | if (!(b & 0x80)) goto done; 34 | 35 | b = GetByte(input); 36 | result |= (b & 0x7F) << 7; 37 | if (!(b & 0x80)) goto done; 38 | 39 | b = GetByte(input); 40 | result |= (b & 0x7F) << 14; 41 | if (!(b & 0x80)) goto done; 42 | 43 | b = GetByte(input); 44 | result |= (b & 0x7F) << 21; 45 | if (!(b & 0x80)) goto done; 46 | 47 | b = GetByte(input); 48 | result |= b << 28; 49 | if (!(b & 0x80)) goto done; 50 | 51 | // If the input is larger than 32 bits, we still need to read it all 52 | // and discard the high-order bits. 53 | for (int i = 0; i < kMaxVarintBytes - kMaxVarint32Bytes; i++) { 54 | b = GetByte(input); 55 | if (!(b & 0x80)) goto done; 56 | } 57 | 58 | // We have overrun the maximum size of a varint (10 bytes). Assume 59 | // the data is corrupt. 60 | return false; 61 | 62 | done: 63 | *value = result; 64 | return true; 65 | } 66 | 67 | 68 | bool WriteVarint32(FILE* output, uint32 value) { 69 | using google::protobuf::io::CodedOutputStream; 70 | uint8 buffer[4]; 71 | uint8* end = CodedOutputStream::WriteVarint32ToArray(value, buffer); 72 | return fwrite(buffer, 1, end - buffer, output) == (end - buffer); 73 | } 74 | 75 | 76 | -------------------------------------------------------------------------------- /base/varint32.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Implementation the codec of varint32 using code from Google Protobuf. 5 | // 6 | #ifndef BASE_VARINT32_H_ 7 | #define BASE_VARINT32_H_ 8 | 9 | #include "base/common.h" 10 | 11 | bool ReadVarint32(FILE* input, uint32* value); 12 | bool WriteVarint32(FILE* output, uint32 value); 13 | 14 | #endif // BASE_VARINT32_H_ 15 | -------------------------------------------------------------------------------- /base/varint32_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "base/varint32.h" 5 | 6 | #include 7 | 8 | #include "gtest/gtest.h" 9 | 10 | #include "base/common.h" 11 | 12 | TEST(Varint32Test, WriteAndReadVarint32) { 13 | static const char* kTmpFile = "/tmp/varint32_test.tmp"; 14 | 15 | uint32 kTestValues[] = { 0, 1, 0xff, 0xffff, 0xffffffff }; 16 | 17 | FILE* output = fopen(kTmpFile, "w+"); 18 | for (int i = 0; i < sizeof(kTestValues)/sizeof(kTestValues[0]); ++i) { 19 | if (!WriteVarint32(output, kTestValues[i])) { 20 | LOG(FATAL) << "Error on WriteVarint32 with value= " << kTestValues[i]; 21 | } 22 | } 23 | fclose(output); 24 | 25 | FILE* input = fopen(kTmpFile, "r"); 26 | for (int i = 0; i < sizeof(kTestValues)/sizeof(kTestValues[0]); ++i) { 27 | uint32 value; 28 | if (!ReadVarint32(input, &value)) { 29 | LOG(FATAL) << "Error on ReadVarint32 with value = " << kTestValues[i]; 30 | } 31 | EXPECT_EQ(kTestValues[i], value); 32 | } 33 | fclose(input); 34 | } 35 | -------------------------------------------------------------------------------- /hash/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build library strutil. 2 | add_library(hash md5_hash.cc simple_hash.cc) 3 | 4 | # Build unittests. 5 | set(LIBS base hash gtest pthread) 6 | 7 | add_executable(md5_hash_test md5_hash_test.cc) 8 | target_link_libraries(md5_hash_test gtest_main ${LIBS}) 9 | 10 | # Install library and header files 11 | install(TARGETS hash DESTINATION bin/hash) 12 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 13 | install(FILES ${HEADER_FILES} DESTINATION include/hash) 14 | -------------------------------------------------------------------------------- /hash/md5_hash.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This file exports the MD5 hashing algorithm. 5 | // 6 | #ifndef HASH_MD5_HASH_H_ 7 | #define HASH_MD5_HASH_H_ 8 | 9 | #include 10 | 11 | #include "base/common.h" 12 | 13 | uint64 MD5Hash(const unsigned char *s, const unsigned int len); 14 | uint64 MD5Hash(const std::string& s); 15 | 16 | #endif // HASH_MD5_HASH_H_ 17 | -------------------------------------------------------------------------------- /hash/md5_hash_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This unittest uses grouth-truth MD5 results from Wikipedia: 5 | // http://en.wikipedia.org/wiki/MD5#MD5_hashes. 6 | 7 | #include 8 | 9 | #include "gtest/gtest.h" 10 | 11 | #include "base/common.h" 12 | #include "hash/md5_hash.h" 13 | 14 | TEST(MD5Test, AsGroundTruthOnWikipedia) { 15 | EXPECT_EQ(MD5Hash("The quick brown fox jumps over the lazy dog"), 16 | 0x82b62b379d7d109eLLU); 17 | EXPECT_EQ(MD5Hash("The quick brown fox jumps over the lazy dog."), 18 | 0x1cfbd090c209d9e4LLU); 19 | EXPECT_EQ(MD5Hash(""), 0x04b2008fd98c1dd4LLU); 20 | } 21 | -------------------------------------------------------------------------------- /hash/simple_hash.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This source file is mainly copied from http://www.partow.net, with 5 | // the following Copyright information. 6 | /* 7 | ************************************************************************** 8 | * * 9 | * General Purpose Hash Function Algorithms Library * 10 | * * 11 | * Author: Arash Partow - 2002 * 12 | * URL: http://www.partow.net * 13 | * URL: http://www.partow.net/programming/hashfunctions/index.html * 14 | * * 15 | * Copyright notice: * 16 | * Free use of the General Purpose Hash Function Algorithms Library is * 17 | * permitted under the guidelines and in accordance with the most current * 18 | * version of the Common Public License. * 19 | * http://www.opensource.org/licenses/cpl1.0.php * 20 | * * 21 | ************************************************************************** 22 | */ 23 | #include "hash/simple_hash.h" 24 | 25 | #include 26 | 27 | unsigned int RSHash(const std::string& str) { 28 | unsigned int b = 378551; 29 | unsigned int a = 63689; 30 | unsigned int hash = 0; 31 | 32 | for (std::size_t i = 0; i < str.length(); i++) { 33 | hash = hash * a + str[i]; 34 | a = a * b; 35 | } 36 | return hash; 37 | } 38 | /* End Of RS Hash Function */ 39 | 40 | 41 | unsigned int JSHash(const std::string& str) { 42 | unsigned int hash = 1315423911; 43 | 44 | for (std::size_t i = 0; i < str.length(); i++) { 45 | hash ^= ((hash << 5) + str[i] + (hash >> 2)); 46 | } 47 | return hash; 48 | } 49 | /* End Of JS Hash Function */ 50 | 51 | 52 | unsigned int PJWHash(const std::string& str) { 53 | unsigned int BitsInUnsignedInt = (unsigned int)(sizeof(unsigned int) * 8); 54 | unsigned int ThreeQuarters = (unsigned int)((BitsInUnsignedInt * 3) / 4); 55 | unsigned int OneEighth = (unsigned int)(BitsInUnsignedInt / 8); 56 | unsigned int HighBits = 57 | (unsigned int)(0xFFFFFFFF) << (BitsInUnsignedInt - OneEighth); 58 | unsigned int hash = 0; 59 | unsigned int test = 0; 60 | 61 | for (std::size_t i = 0; i < str.length(); i++) { 62 | hash = (hash << OneEighth) + str[i]; 63 | 64 | if ((test = hash & HighBits) != 0) { 65 | hash = ((hash ^ (test >> ThreeQuarters)) & (~HighBits)); 66 | } 67 | } 68 | 69 | return hash; 70 | } 71 | /* End Of P. J. Weinberger Hash Function */ 72 | 73 | 74 | unsigned int ELFHash(const std::string& str) { 75 | unsigned int hash = 0; 76 | unsigned int x = 0; 77 | 78 | for (std::size_t i = 0; i < str.length(); i++) { 79 | hash = (hash << 4) + str[i]; 80 | if ((x = hash & 0xF0000000L) != 0) { 81 | hash ^= (x >> 24); 82 | } 83 | hash &= ~x; 84 | } 85 | 86 | return hash; 87 | } 88 | /* End Of ELF Hash Function */ 89 | 90 | 91 | unsigned int BKDRHash(const std::string& str) { 92 | unsigned int seed = 131; // 31 131 1313 13131 131313 etc.. 93 | unsigned int hash = 0; 94 | 95 | for (std::size_t i = 0; i < str.length(); i++) { 96 | hash = (hash * seed) + str[i]; 97 | } 98 | 99 | return hash; 100 | } 101 | /* End Of BKDR Hash Function */ 102 | 103 | 104 | unsigned int SDBMHash(const std::string& str) { 105 | unsigned int hash = 0; 106 | 107 | for (std::size_t i = 0; i < str.length(); i++) { 108 | hash = str[i] + (hash << 6) + (hash << 16) - hash; 109 | } 110 | 111 | return hash; 112 | } 113 | /* End Of SDBM Hash Function */ 114 | 115 | 116 | unsigned int DJBHash(const std::string& str) { 117 | unsigned int hash = 5381; 118 | 119 | for (std::size_t i = 0; i < str.length(); i++) { 120 | hash = ((hash << 5) + hash) + str[i]; 121 | } 122 | 123 | return hash; 124 | } 125 | /* End Of DJB Hash Function */ 126 | 127 | 128 | unsigned int DEKHash(const std::string& str) { 129 | unsigned int hash = static_cast(str.length()); 130 | 131 | for (std::size_t i = 0; i < str.length(); i++) { 132 | hash = ((hash << 5) ^ (hash >> 27)) ^ str[i]; 133 | } 134 | 135 | return hash; 136 | } 137 | /* End Of DEK Hash Function */ 138 | 139 | 140 | unsigned int BPHash(const std::string& str) { 141 | unsigned int hash = 0; 142 | for (std::size_t i = 0; i < str.length(); i++) { 143 | hash = hash << 7 ^ str[i]; 144 | } 145 | 146 | return hash; 147 | } 148 | /* End Of BP Hash Function */ 149 | 150 | 151 | unsigned int FNVHash(const std::string& str) { 152 | const unsigned int fnv_prime = 0x811C9DC5; 153 | unsigned int hash = 0; 154 | for (std::size_t i = 0; i < str.length(); i++) { 155 | hash *= fnv_prime; 156 | hash ^= str[i]; 157 | } 158 | 159 | return hash; 160 | } 161 | /* End Of FNV Hash Function */ 162 | 163 | 164 | unsigned int APHash(const std::string& str) { 165 | unsigned int hash = 0xAAAAAAAA; 166 | 167 | for (std::size_t i = 0; i < str.length(); i++) { 168 | hash ^= ((i & 1) == 0) ? ((hash << 7) ^ str[i] * (hash >> 3)) : 169 | (~((hash << 11) + (str[i] ^ (hash >> 5)))); 170 | } 171 | 172 | return hash; 173 | } 174 | /* End Of AP Hash Function */ 175 | -------------------------------------------------------------------------------- /hash/simple_hash.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This source file is mainly copied from http://www.partow.net, with 5 | // the following Copyright information. 6 | /* 7 | ************************************************************************** 8 | * * 9 | * General Purpose Hash Function Algorithms Library * 10 | * * 11 | * Author: Arash Partow - 2002 * 12 | * URL: http://www.partow.net * 13 | * URL: http://www.partow.net/programming/hashfunctions/index.html * 14 | * * 15 | * Copyright notice: * 16 | * Free use of the General Purpose Hash Function Algorithms Library is * 17 | * permitted under the guidelines and in accordance with the most current * 18 | * version of the Common Public License. * 19 | * http://www.opensource.org/licenses/cpl1.0.php * 20 | * * 21 | ************************************************************************** 22 | */ 23 | #ifndef HASH_SIMPLE_HASH_H_ 24 | #define HASH_SIMPLE_HASH_H_ 25 | 26 | #include 27 | 28 | #include "base/common.h" 29 | 30 | typedef unsigned int (*HashFunction)(const std::string&); 31 | 32 | unsigned int RSHash(const std::string& str); 33 | unsigned int JSHash(const std::string& str); 34 | unsigned int PJWHash(const std::string& str); 35 | unsigned int ELFHash(const std::string& str); 36 | unsigned int BKDRHash(const std::string& str); 37 | unsigned int SDBMHash(const std::string& str); 38 | unsigned int DJBHash(const std::string& str); 39 | unsigned int DEKHash(const std::string& str); 40 | unsigned int BPHash(const std::string& str); 41 | unsigned int FNVHash(const std::string& str); 42 | unsigned int APHash(const std::string& str); 43 | 44 | #endif // HASH_SIMPLE_HASH_H_ 45 | -------------------------------------------------------------------------------- /mrml-lasso/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build protobuf 2 | protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS logistic_regression.proto) 3 | 4 | # Build library mrml. 5 | add_library(lasso ${PROTO_SRCS} vector_types.cc learner_states.cc learner.cc) 6 | add_library(lasso-predict prediction_engine.cc) 7 | 8 | # Build unittests. 9 | set(LIBS lasso mrml sorted_buffer strutil hash base mpichcxx mpich opa ssh2 ssl crypto z dl boost_program_options boost_regex boost_filesystem boost_system protobuf gflags gtest pthread) 10 | 11 | add_executable(sparse_vector_tmpl_test sparse_vector_tmpl_test.cc) 12 | target_link_libraries(sparse_vector_tmpl_test gtest_main ${LIBS}) 13 | 14 | add_executable(dense_vector_tmpl_test dense_vector_tmpl_test.cc) 15 | target_link_libraries(dense_vector_tmpl_test gtest_main ${LIBS}) 16 | 17 | add_executable(vector_types_test vector_types_test.cc) 18 | target_link_libraries(vector_types_test gtest_main ${LIBS}) 19 | 20 | add_executable(learner_states_test learner_states_test.cc) 21 | target_link_libraries(learner_states_test gtest_main ${LIBS}) 22 | 23 | add_executable(learner_test learner_test.cc) 24 | target_link_libraries(learner_test gtest_main ${LIBS}) 25 | 26 | add_executable(termination_flag_test termination_flag_test.cc) 27 | target_link_libraries(termination_flag_test ${LIBS}) 28 | 29 | # Build regression test 30 | add_executable(train train.cc) 31 | target_link_libraries(train ${LIBS}) 32 | 33 | # Build MapReduce binaries 34 | add_executable(mrml-lasso mrml_mappers_and_reducers.cc mr_assign_feature_id.cc mr_convert_data_format.cc command_line_options.cc) 35 | target_link_libraries(mrml-lasso mrml-main ${LIBS}) 36 | 37 | # Build dump_learner_states 38 | add_executable(dump_learner_states dump_learner_states.cc) 39 | target_link_libraries(dump_learner_states ${LIBS}) 40 | 41 | # Install library and header files 42 | install(TARGETS mrml-lasso DESTINATION bin/mrml-lasso) 43 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 44 | install(FILES ${HEADER_FILES} DESTINATION include/mrml-lasso) 45 | -------------------------------------------------------------------------------- /mrml-lasso/Makefile.cygwin: -------------------------------------------------------------------------------- 1 | PROTOBUF_BASE_DIR = /home/yiwang/3rd-party/protobuf-2.3.0 2 | 3 | CXX=g++ -Wall -Wno-sign-compare -O2 4 | AR=ar rcs 5 | PROTOC = $(PROTOBUF_BASE_DIR)/bin/protoc 6 | 7 | PROTOBUF_CXX_FLAGS = -I$(PROTOBUF_BASE_DIR)/include/ 8 | PROTOBUF_LD_FLAGS = -L$(PROTOBUF_BASE_DIR)/lib/ -lprotobuf 9 | 10 | BOOST_BASE_DIR = /home/yiwang/3rd-party/boost-1.43.0 11 | BOOST_CXX_FLAGS=-I$(BOOST_BASE_DIR)/include 12 | BOOST_LD_FLAGS=-L$(BOOST_BASE_DIR)/lib/ -lboost_program_options -lboost_regex -lboost_filesystem -lboost_system 13 | 14 | MPI_HOME_DIR=/home/yiwang/3rd-party/mpich2-1.2.1p1 15 | MPI_CXX_FLAGS=-I$(MPI_HOME_DIR)/include 16 | MPI_LD_FLAGS=-L$(MPI_HOME_DIR)/lib -lmpichcxx -lpmpich -lmpich 17 | 18 | SSL_SSH2_CXX_FLAGS= 19 | SSL_SSH2_LD_FLAGS=-lssh2 -lcrypto -lssl -lz 20 | 21 | CXXFLAGS= $(BOOST_CXX_FLAGS) $(MPI_CXX_FLAGS) $(PROTOBUF_CXX_FLAGS) $(SSL_SSH2_CXX_FLAGS) 22 | LDFLAGS=-static -static-libgcc $(BOOST_LD_FLAGS) $(MPI_LD_FLAGS) $(PROTOBUF_LD_FLAGS) $(SSL_SSH2_LD_FLAGS) 23 | 24 | include Makefile.rules 25 | -------------------------------------------------------------------------------- /mrml-lasso/Makefile.rules: -------------------------------------------------------------------------------- 1 | EXTERNAL_LIBS = ../mrml/libmrml.a ../mrml/libmrml-main.a ../strutil/libstrutil.a ../base/libbase.a 2 | UNIT_TESTS = sparse_vector_tmpl_test dense_vector_tmpl_test vector_types_test learner_states_test learner_test termination_flag_test 3 | BUILD_TARGETS = $(UNIT_TESTS) liblr.a train dump_learner_states mrml-lr liblr-predict.a 4 | DIST_DIR = ../../distribution/plr 5 | 6 | all : $(BUILD_TARGETS) 7 | 8 | #------------------------------------------------------------------------------- 9 | # protocol buffer 10 | #------------------------------------------------------------------------------- 11 | logistic_regression.pb.cc : logistic_regression.pb.h 12 | 13 | logistic_regression.pb.h : logistic_regression.proto 14 | $(PROTOC) --cpp_out=./ logistic_regression.proto 15 | 16 | logistic_regression.pb.o : logistic_regression.pb.cc 17 | $(CXX) $(CXXFLAGS) -c logistic_regression.pb.cc 18 | 19 | learner_states.o : learner_states.cc learner_states.hh logistic_regression.pb.h 20 | $(CXX) $(CXXFLAGS) -c learner_states.cc 21 | 22 | #------------------------------------------------------------------------------- 23 | # basic components 24 | #------------------------------------------------------------------------------- 25 | 26 | vector_types.hh : sparse_vector_tmpl.hh dense_vector_tmpl.hh 27 | 28 | vector_types.o : vector_types.cc vector_types.hh sparse_vector_tmpl.hh 29 | 30 | learner_states.o : learner_states.cc learner_states.hh learner_states.hh 31 | 32 | learner.o : learner.cc learner.hh termination_flag.hh 33 | 34 | prediction_engine.o : prediction_engine.cc prediction_engine.hh 35 | 36 | command_line_options.o : command_line_options.cc command_line_options.hh 37 | 38 | mr_convert_data_format.o : mr_convert_data_format.cc mr_convert_data_format.hh command_line_options.hh 39 | 40 | mr_assign_feature_id.o : mr_assign_feature_id.cc mr_assign_feature_id.hh command_line_options.hh 41 | 42 | #------------------------------------------------------------------------------- 43 | # core librar 44 | #------------------------------------------------------------------------------- 45 | liblr.a : logistic_regression.pb.o vector_types.o learner_states.o learner.o 46 | $(AR) $@ $^ 47 | 48 | liblr-predict.a : prediction_engine.o 49 | $(AR) $@ $^ 50 | #------------------------------------------------------------------------------- 51 | # unit tests 52 | #------------------------------------------------------------------------------- 53 | sparse_vector_tmpl_test : sparse_vector_tmpl.hh sparse_vector_tmpl_test.cc ../base/libbase.a test_utils.hh 54 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 55 | 56 | dense_vector_tmpl_test : dense_vector_tmpl.hh dense_vector_tmpl_test.cc ../base/libbase.a test_utils.hh 57 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 58 | 59 | vector_types_test : vector_types_test.cc vector_types.hh dense_vector_tmpl.hh sparse_vector_tmpl.hh logistic_regression.pb.h liblr.a test_utils.hh ../base/libbase.a 60 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 61 | 62 | learner_states_test : learner_states_test.cc learner_states.hh liblr.a ../mrml/libmrml.a ../strutil/libstrutil.a ../base/libbase.a test_utils.hh 63 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 64 | 65 | learner_test : learner_test.cc learner.hh learner_states.hh sparse_vector_tmpl.hh liblr.a ../mrml/libmrml.a ../base/libbase.a test_utils.hh 66 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 67 | 68 | termination_flag_test : termination_flag_test.cc termination_flag.hh ../base/libbase.a test_utils.hh 69 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 70 | 71 | #------------------------------------------------------------------------------- 72 | # regression test 73 | #------------------------------------------------------------------------------- 74 | dump_learner_states.o : dump_learner_states.cc logistic_regression.pb.h learner_states.hh vector_types.hh 75 | 76 | train.o : train.cc learner_sparse_impl.hh learner_dense_impl.hh learner.hh vector_types.hh learner_states.hh sparse_vector_tmpl.hh 77 | 78 | train : train.o ./liblr.a ../mrml/libmrml.a ../base/libbase.a 79 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 80 | 81 | dump_learner_states : dump_learner_states.o liblr.a ../mrml/libmrml.a ../strutil/libstrutil.a ../base/libbase.a 82 | $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) 83 | 84 | #------------------------------------------------------------------------------- 85 | # mrml-lr binary 86 | #------------------------------------------------------------------------------- 87 | 88 | mrml_mappers_and_reducers.o : mrml_mappers_and_reducers.hh mrml_mappers_and_reducers.cc ../base/common.hh ../strutil/split_string.hh ../base/common.hh mrml_mappers_and_reducers.hh sparse_vector_tmpl.hh command_line_options.hh 89 | 90 | mrml-lr : mrml_mappers_and_reducers.o mr_assign_feature_id.o mr_convert_data_format.o command_line_options.o liblr.a $(EXTERNAL_LIBS) 91 | $(CXX) $(CXXFLAGS) -o $@ $+ $(LDFLAGS) 92 | 93 | #------------------------------------------------------------------------------- 94 | # make distribution 95 | #------------------------------------------------------------------------------- 96 | 97 | dist : $(BUILD_TARGETS) 98 | rm -rf $(DIST_DIR) 99 | mkdir -p $(DIST_DIR) 100 | cp $(BUILD_TARGETS) $(DIST_DIR) 101 | cp *.hh $(DIST_DIR) 102 | cp *.h $(DIST_DIR) 103 | 104 | clean : 105 | rm -rf $(BUILD_TARGETS) *.o *.a *.exe *~ *.stackdump *.core *.dSYM 106 | -------------------------------------------------------------------------------- /mrml-lasso/Makefile.tencent: -------------------------------------------------------------------------------- 1 | PROTOBUF_BASE_DIR = ../../3rd-party/protobuf-2.3.0 2 | 3 | CXX = g++ -Wall -Wno-sign-compare -O3 4 | AR=ar rcs 5 | PROTOC = $(PROTOBUF_BASE_DIR)/bin/protoc 6 | 7 | PROTOBUF_CXX_FLAGS = -I$(PROTOBUF_BASE_DIR)/include/ 8 | PROTOBUF_LD_FLAGS = -L$(PROTOBUF_BASE_DIR)/lib/ -lprotobuf -pthread 9 | 10 | BOOST_BASE_DIR = ../../3rd-party/boost-1.43.0 11 | BOOST_CXX_FLAGS=-I$(BOOST_BASE_DIR)/include 12 | BOOST_LD_FLAGS=-L$(BOOST_BASE_DIR)/lib/ -lboost_program_options -lboost_regex -lboost_filesystem -lboost_system 13 | 14 | MPI_HOME_DIR=../../3rd-party/mpich2-1.2.1p1 15 | MPI_CXX_FLAGS=-I$(MPI_HOME_DIR)/include 16 | MPI_LD_FLAGS=-L$(MPI_HOME_DIR)/lib -lmpichcxx -lmpich 17 | 18 | SSL_BASE_DIR = ../../3rd-party/openssl-0.9.8o 19 | SSL_CXX_FLAGS = -I$(SSL_BASE_DIR)/include 20 | SSL_LD_FLAGS = -L$(SSL_BASE_DIR)/lib -lssl -lcrypto -lz -ldl 21 | 22 | SSH2_BASE_DIR = ../../3rd-party/libssh2-1.2.6 23 | SSH2_CXX_FLAGS = -I$(SSH2_BASE_DIR)/include 24 | SSH2_LD_FLAGS = -L$(SSH2_BASE_DIR)/lib -lssh2 25 | 26 | CXXFLAGS= -Wall -Wno-sign-compare -O2 $(BOOST_CXX_FLAGS) $(MPI_CXX_FLAGS) $(SSH2_CXX_FLAGS) $(SSL_CXX_FLAGS) $(PROTOBUF_CXX_FLAGS) 27 | LDFLAGS=-static -static-libgcc $(BOOST_LD_FLAGS) $(MPI_LD_FLAGS) $(SSH2_LD_FLAGS) $(SSL_LD_FLAGS) $(PROTOBUF_LD_FLAGS) 28 | 29 | include Makefile.rules 30 | -------------------------------------------------------------------------------- /mrml-lasso/Makefile.ubuntu: -------------------------------------------------------------------------------- 1 | PROTOBUF_BASE_DIR = /home/yiwang/3rd-party/protobuf-2.3.0 2 | 3 | CXX=g++ -Wall -Wno-sign-compare -O2 -DDEBUG_PRINT_TRACE -DDEBUG_PRINT_VARS 4 | AR=ar rcs 5 | PROTOC = $(PROTOBUF_BASE_DIR)/bin/protoc 6 | 7 | PROTOBUF_CXX_FLAGS = -I$(PROTOBUF_BASE_DIR)/include/ 8 | PROTOBUF_LD_FLAGS = -L$(PROTOBUF_BASE_DIR)/lib/ -lprotobuf 9 | 10 | BOOST_BASE_DIR = /home/yiwang/3rd-party/boost-1.43.0 11 | BOOST_CXX_FLAGS=-I$(BOOST_BASE_DIR)/include 12 | BOOST_LD_FLAGS=-L$(BOOST_BASE_DIR)/lib/ -lboost_program_options -lboost_regex -lboost_filesystem -lboost_system 13 | 14 | MPI_HOME_DIR=/home/yiwang/3rd-party/mpich2-1.2.1p1 15 | MPI_CXX_FLAGS=-I$(MPI_HOME_DIR)/include 16 | MPI_LD_FLAGS=-L$(MPI_HOME_DIR)/lib -lmpichcxx -lmpich -lopa 17 | 18 | SSL_SSH2_CXX_FLAGS= 19 | SSL_SSH2_LD_FLAGS=-L/usr/lib -lssh2 -lssl -lgcrypt -lgpg-error -lz 20 | 21 | CXXFLAGS= $(BOOST_CXX_FLAGS) $(MPI_CXX_FLAGS) $(PROTOBUF_CXX_FLAGS) $(SSL_SSH2_CXX_FLAGS) 22 | LDFLAGS=-static -static-libgcc $(BOOST_LD_FLAGS) $(MPI_LD_FLAGS) $(PROTOBUF_LD_FLAGS) $(SSL_SSH2_LD_FLAGS) -L/usr/lib -lpthread 23 | 24 | include Makefile.rules 25 | -------------------------------------------------------------------------------- /mrml-lasso/command_line_options.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "base/common.h" 11 | #include "mrml/mrml_filesystem.h" 12 | #include "mrml/mrml_recordio.h" 13 | 14 | #include "mrml-lasso/command_line_options.h" 15 | 16 | namespace logistic_regression { 17 | 18 | using std::string; 19 | 20 | void CommandLineOptions::Parse(const std::vector& cmdline) { 21 | namespace po = boost::program_options; 22 | po::options_description desc("LR-MPI initialization options"); 23 | desc.add_options() 24 | ("feature_id_file", 25 | po::value(&feature_id_file), 26 | "# feature to id mapping file") 27 | ("states_file_dir", 28 | po::value(&base_dir), 29 | "# where the states files are located") 30 | ("states_filebase", 31 | po::value(&states_filebase), 32 | "# states_filebase") 33 | ("flag_file", 34 | po::value(&flag_file), 35 | "flag_file") 36 | ("memory_size", 37 | po::value(&memory_size)->default_value(10), 38 | "# lr.memory_size") 39 | ("l1weight", 40 | po::value(&l1weight)->default_value(1), 41 | "# lr.l1weight") 42 | ("max_line_search_steps", 43 | po::value(&max_line_search_steps)->default_value(20), 44 | "# lr.max_line_search_steps") 45 | ("max_iterations", 46 | po::value(&max_iterations)->default_value(120), 47 | "# lr.max_iterations") 48 | ("convergence_tolerance", 49 | po::value(&convergence_tolerance)->default_value(1e-4), 50 | "# lr.convergence_tolerance") 51 | ("max_feature_number", 52 | po::value(&max_feature_number)->default_value(0), 53 | "# features, required when learning a dense model"); 54 | po::parsed_options parsed = 55 | po::command_line_parser(cmdline).options(desc).allow_unregistered(). 56 | run(); 57 | po::variables_map vm; 58 | po::store(parsed, vm); 59 | po::notify(vm); 60 | 61 | LOG(INFO) << "command line:" \ 62 | << "\tfeature_id_file:" << feature_id_file \ 63 | << "\tstates_file_dir:" << base_dir \ 64 | << "\tstates_filebase:" << states_filebase \ 65 | << "\tflag_file:" << flag_file \ 66 | << "\tmemory_size:" << memory_size \ 67 | << "\tl1weight:" << l1weight \ 68 | << "\tmax_line_search_steps:" << max_line_search_steps \ 69 | << "\tmax_iterations:" << max_iterations \ 70 | << "\tconvergence_tolerance:" << convergence_tolerance \ 71 | << "\tmax_feature_number:" << max_feature_number; 72 | 73 | CHECK_LT(0, memory_size); 74 | CHECK_LE(0, l1weight); 75 | CHECK_LT(1, max_line_search_steps); 76 | CHECK_LT(1, max_iterations); 77 | CHECK_LT(0, convergence_tolerance); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /mrml-lasso/command_line_options.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Encapsulate command line options used by all the MR classes. 5 | // 6 | #ifndef MRML_LASSO_COMMAND_LINE_OPTIONS_H_ 7 | #define MRML_LASSO_COMMAND_LINE_OPTIONS_H_ 8 | 9 | #include 10 | #include 11 | 12 | namespace logistic_regression { 13 | 14 | class CommandLineOptions { 15 | public: 16 | std::string feature_id_file; 17 | std::string base_dir; 18 | std::string states_filebase; 19 | std::string flag_file; 20 | 21 | int memory_size; 22 | double l1weight; 23 | int max_line_search_steps; 24 | int max_iterations; 25 | double convergence_tolerance; 26 | int max_feature_number; // only valid in "leanrer==dense" situation 27 | 28 | void Parse(const std::vector& cmdline); 29 | }; 30 | 31 | } // namespace logistic_regression 32 | 33 | #endif // MRML_LASSO_COMMAND_LINE_OPTIONS_H_ 34 | -------------------------------------------------------------------------------- /mrml-lasso/dense_vector_tmpl.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define the class template DenseVectorImpl and operations required 5 | // by class template Learner. 6 | // 7 | #ifndef MRML_LASSO_DENSE_VECTOR_TMPL_H_ 8 | #define MRML_LASSO_DENSE_VECTOR_TMPL_H_ 9 | 10 | #include 11 | 12 | namespace logistic_regression { 13 | using std::ostream; 14 | using std::vector; 15 | 16 | template 17 | class DenseVectorTmpl : public vector { 18 | public: 19 | typedef typename vector::const_iterator const_iterator; 20 | typedef typename vector::iterator iterator; 21 | 22 | DenseVectorTmpl(size_t size, const ValueType& init) 23 | : vector(size, init) {} 24 | 25 | DenseVectorTmpl() 26 | : vector() {} 27 | }; 28 | 29 | // Scale(v,c) : v <- v * c 30 | template 31 | void Scale(DenseVectorTmpl* v, 32 | const ScaleType& c) { 33 | for (size_t i = 0; i < v->size(); ++i) { 34 | (*v)[i] *= c; 35 | } 36 | } 37 | 38 | // ScaleInto(u,v,c) : u <- v * c 39 | template 40 | void ScaleInto(DenseVectorTmpl* u, 41 | const DenseVectorTmpl& v, 42 | const ScaleType& c) { 43 | CHECK_EQ(v.size(), u->size()); 44 | CHECK_LT(0, v.size()); 45 | for (size_t i = 0; i < v.size(); ++i) { 46 | (*u)[i] = v[i] * c; 47 | } 48 | } 49 | 50 | // AddScaled(u,v,c) : u <- u + v * c 51 | template 52 | void AddScaled(DenseVectorTmpl* u, 53 | const DenseVectorTmpl& v, 54 | const ScaleType& c) { 55 | CHECK_EQ(v.size(), u->size()); 56 | CHECK_LT(0, v.size()); 57 | for (size_t i = 0; i < v.size(); ++i) { 58 | (*u)[i] += v[i] * c; 59 | } 60 | } 61 | 62 | // AddScaledInto(w,u,v,c) : w <- u + v * c 63 | template 64 | void AddScaledInto(DenseVectorTmpl* w, 65 | const DenseVectorTmpl& u, 66 | const DenseVectorTmpl& v, 67 | const ScaleType& c) { 68 | CHECK_EQ(u.size(), v.size()); 69 | CHECK_EQ(u.size(), w->size()); 70 | CHECK_LT(0, u.size()); 71 | for (size_t i = 0; i < u.size(); ++i) { 72 | (*w)[i] = u[i] + v[i] * c; 73 | } 74 | } 75 | 76 | // DotProduct(u,v) : r <- dot(u, v) 77 | template 78 | ValueType DotProduct(const DenseVectorTmpl& v1, 79 | const DenseVectorTmpl& v2) { 80 | CHECK_EQ(v1.size(), v2.size()); 81 | ValueType ret = 0; 82 | for (size_t i = 0; i < v1.size(); ++i) { 83 | ret += v1[i] * v2[i]; 84 | } 85 | return ret; 86 | } 87 | 88 | // Output a sparse vector in human readable format. 89 | template 90 | ostream& operator<< (ostream& output, 91 | const DenseVectorTmpl& vec) { 92 | output << "[ "; 93 | for (size_t i = 0; i < vec.size(); ++i) { 94 | if (vec[i] != 0) // to keep the format the same with sparse 95 | output << i << ":" < 5 | 6 | #include "base/common.h" 7 | #include "gtest/gtest.h" 8 | #include "mrml-lasso/dense_vector_tmpl.h" 9 | 10 | using logistic_regression::DenseVectorTmpl; 11 | 12 | typedef DenseVectorTmpl RealVector; 13 | 14 | TEST(DenseVectorTmplTest, Scale) { 15 | RealVector v; 16 | v.push_back(2); 17 | v.push_back(4); 18 | Scale(&v, 0.5); 19 | EXPECT_EQ(v.size(), 2); 20 | EXPECT_EQ(v[0], 1); 21 | EXPECT_EQ(v[1], 2); 22 | } 23 | 24 | TEST(DenseVectorTmplTest, ScaleInto) { 25 | RealVector u, v; 26 | u.push_back(2); 27 | u.push_back(2); 28 | v.push_back(2); 29 | v.push_back(4); 30 | ScaleInto(&u, v, 0.5); 31 | EXPECT_EQ(u.size(), 2); 32 | EXPECT_EQ(u[0], 1); 33 | EXPECT_EQ(u[1], 2); 34 | } 35 | 36 | TEST(DenseVectorTmplTest, AddScaled) { 37 | RealVector u, v; 38 | u.push_back(2); 39 | u.push_back(0); 40 | u.push_back(0); 41 | v.push_back(0); 42 | v.push_back(2); 43 | v.push_back(4); 44 | AddScaled(&u, v, 0.5); 45 | EXPECT_EQ(u.size(), 3); 46 | EXPECT_EQ(u[0], 2); 47 | EXPECT_EQ(u[1], 1); 48 | EXPECT_EQ(u[2], 2); 49 | } 50 | 51 | TEST(DenseVectorTmplTest, AddScaledInto) { 52 | RealVector w, u, v; 53 | w.resize(3, 1); 54 | u.push_back(2); 55 | u.push_back(4); 56 | u.push_back(6); 57 | v.push_back(2); 58 | v.push_back(4); 59 | v.push_back(0); 60 | AddScaledInto(&w, u, v, 0.5); 61 | EXPECT_EQ(w.size(), 3); 62 | EXPECT_EQ(w[0], 3); 63 | EXPECT_EQ(w[1], 6); 64 | EXPECT_EQ(w[2], 6); 65 | } 66 | 67 | TEST(DenseVectorTmplTest, DotProduct) { 68 | RealVector v, u, w; 69 | v.push_back(1); 70 | v.push_back(0); 71 | u.push_back(0); 72 | u.push_back(1); 73 | w.push_back(1); 74 | w.push_back(1); 75 | EXPECT_EQ(DotProduct(v, u), 0); 76 | EXPECT_EQ(DotProduct(u, v), 0); 77 | EXPECT_EQ(DotProduct(v, w), 1); 78 | EXPECT_EQ(DotProduct(u, w), 1); 79 | } 80 | -------------------------------------------------------------------------------- /mrml-lasso/dump_learner_states.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // A utility program that dump a learner states saved in a RecordIO 5 | // file. 6 | // 7 | #include // NOLINT: TODO(yiwang): add C-style print facility 8 | // for sparse vectors. 9 | 10 | #include "boost/program_options/option.hpp" 11 | #include "boost/program_options/options_description.hpp" 12 | #include "boost/program_options/variables_map.hpp" 13 | #include "boost/program_options/parsers.hpp" 14 | 15 | #include "base/common.h" 16 | #include "mrml/mrml_filesystem.h" 17 | #include "mrml/mrml_recordio.h" 18 | 19 | #include "mrml-lasso/logistic_regression.pb.h" 20 | #include "mrml-lasso/learner_states.h" 21 | #include "mrml-lasso/vector_types.h" 22 | 23 | int main(int argc, char* argv[]) { 24 | using std::cout; 25 | using std::string; 26 | using logistic_regression::LearnerStates; 27 | using logistic_regression::SparseRealVector; 28 | using logistic_regression::DenseRealVector; 29 | namespace po = boost::program_options; 30 | 31 | po::options_description desc("Supported options"); 32 | desc.add_options() 33 | ("learner_states_file", po::value(), "the LearnerStatesPB file name") 34 | ("model_only", po::value(), "to dump model only or not"); 35 | po::parsed_options parsed = 36 | po::command_line_parser(argc, argv).options(desc).allow_unregistered(). 37 | run(); 38 | po::variables_map vm; 39 | po::store(parsed, vm); 40 | po::notify(vm); 41 | CHECK(vm.count("learner_states_file")); 42 | CHECK(vm.count("model_only")); 43 | 44 | MRMLFS_File file(vm["learner_states_file"].as(), true); 45 | LearnerStates states; 46 | states.LoadFromRecordFile(&file); 47 | if (vm["model_only"].as()) 48 | cout << states.new_x(); 49 | else 50 | cout << states; 51 | return 0; 52 | } 53 | -------------------------------------------------------------------------------- /mrml-lasso/learner.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | 5 | namespace logistic_regression { 6 | 7 | const char* kErrorNonDescentDirection = 8 | "ERROR: UpdateDir chose a non-descent direction, " 9 | "the line search will break, so we stop here. The " 10 | "likely reason is bug in gradient computation."; 11 | 12 | const char* kEnoughLongLineSearch = 13 | "WARNING: We have done enough number of steps in " 14 | "line search, and have to stop."; 15 | 16 | const char* kEnoughNumberOfIterations = 17 | "WARNING: We have done enough number of iterations."; 18 | 19 | } // namespace logistic_regression 20 | -------------------------------------------------------------------------------- /mrml-lasso/learner_states.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include 7 | 8 | #include "base/common.h" 9 | #include "mrml/mrml.h" 10 | #include "mrml/mrml_filesystem.h" 11 | #include "mrml/mrml_recordio.h" 12 | 13 | #include "mrml-lasso/learner_states.h" 14 | 15 | namespace logistic_regression { 16 | 17 | void SerializeDoubleToProtoBuf(const double& v, DoublePB* pb) { 18 | pb->set_value(v); 19 | } 20 | 21 | void ParseDoubleFromProtoBuf(const DoublePB& pb, double* v) { 22 | *v = pb.value(); 23 | } 24 | 25 | void SerializeInt32ToProtoBuf(int32 v, Int32PB* pb) { 26 | pb->set_value(v); 27 | } 28 | 29 | void ParseInt32FromProtoBuf(const Int32PB& pb, int32* v) { 30 | *v = pb.value(); 31 | } 32 | 33 | void SerializeDequeToProtoBuf(const std::deque& deque, 34 | DoubleSequencePB* pb) { 35 | pb->Clear(); 36 | for (int i = 0; i < deque.size(); ++i) { 37 | pb->add_value(deque[i]); 38 | } 39 | } 40 | 41 | void ParseDequeFromProtoBuf(const DoubleSequencePB& pb, 42 | std::deque* deque) { 43 | deque->clear(); 44 | for (int i = 0; i < pb.value_size(); ++i) { 45 | deque->push_back(pb.value(i)); 46 | } 47 | } 48 | 49 | ostream& operator<<(ostream& out, const deque& double_deque) { 50 | for (size_t s = 0; s < double_deque.size(); ++s) { 51 | out << s << ":" << double_deque[s] << " "; 52 | } 53 | return out; 54 | } 55 | 56 | ostream& operator<<(ostream& out, const ImprovementFilter& filter) { 57 | out << filter.value_history_; 58 | return out; 59 | } 60 | 61 | void ImprovementFilter::SerializeToProtoBuf(DoubleSequencePB* pb) const { 62 | SerializeDequeToProtoBuf(value_history_, pb); 63 | } 64 | 65 | void ImprovementFilter::ParseFromProtoBuf(const DoubleSequencePB& pb) { 66 | ParseDequeFromProtoBuf(pb, &value_history_); 67 | } 68 | 69 | double ImprovementFilter::GetImprovement(double new_value) { 70 | double ret = std::numeric_limits::infinity(); 71 | 72 | if (value_history_.size() > kNumIterationsToAverage) { 73 | double previous_value = value_history_.front(); 74 | if (value_history_.size() == 2 * kNumIterationsToAverage) { 75 | value_history_.pop_front(); 76 | } 77 | double average_improvement = 78 | (previous_value - new_value) / value_history_.size(); 79 | double relative_average_improvement = 80 | average_improvement / fabs(new_value); 81 | ret = relative_average_improvement; 82 | } 83 | 84 | value_history_.push_back(new_value); 85 | 86 | return ret; 87 | } 88 | 89 | } // namespace logsitic_regression 90 | -------------------------------------------------------------------------------- /mrml-lasso/learner_states_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include "base/common.h" 7 | #include "gtest/gtest.h" 8 | #include "mrml/mrml.h" 9 | #include "mrml/mrml_filesystem.h" 10 | #include "mrml/mrml_recordio.h" 11 | 12 | #include "mrml-lasso/learner_states.h" 13 | #include "mrml-lasso/test_utils.h" 14 | 15 | namespace logistic_regression { 16 | 17 | using std::istringstream; 18 | using std::ostringstream; 19 | 20 | template 21 | void testLearnerStatesSerialization() { 22 | std::cout << "Run " << __FUNCTION__ << "\n"; 23 | 24 | LearnerStatesTestUtil u; 25 | static const char* kTempFile = "/tmp/testLearnerStatesSerialization"; 26 | { 27 | LearnerStates states; 28 | u.Construct(&states); 29 | MRMLFS_File out(kTempFile, false); 30 | CHECK(out.IsOpen()); 31 | states.SaveIntoRecordFile(&out); 32 | } 33 | { 34 | LearnerStates states; 35 | MRMLFS_File in(kTempFile, true); 36 | states.LoadFromRecordFile(&in); 37 | u.Check(states); 38 | } 39 | } 40 | 41 | } // namespace logsitic_regression 42 | 43 | TEST(LearnerStatesSerializationTest, DenseRealVector) { 44 | using logistic_regression::DenseRealVector; 45 | logistic_regression::testLearnerStatesSerialization(); 46 | } 47 | 48 | TEST(LearnerStatesSerializationTest, SparseRealVector) { 49 | using logistic_regression::SparseRealVector; 50 | logistic_regression::testLearnerStatesSerialization(); 51 | } 52 | -------------------------------------------------------------------------------- /mrml-lasso/logistic_regression.proto: -------------------------------------------------------------------------------- 1 | package logistic_regression; 2 | 3 | // RealVectorPB is the external data structure of both DenseRealVector 4 | // and SparseRealVector (defined in vector_types.hh). 5 | message RealVectorPB { 6 | repeated group Element = 1 { 7 | required uint64 index = 2; 8 | required double value = 3; 9 | } 10 | optional uint32 dim = 4; // Used for DenseRealVector. 11 | } 12 | 13 | message RealVectorPtrDequePB { 14 | required int32 real_size = 1; 15 | repeated group Element = 2 { 16 | required int32 index = 3; 17 | required RealVectorPB vector = 4; 18 | } 19 | } 20 | 21 | message InstancePB { 22 | message Feature { 23 | optional string name = 1; // name := group_name + ':' + feature_name 24 | optional int32 id = 2; // a zero-based successive integer 25 | required double value = 3; 26 | } 27 | required float num_positive = 1; // negative value means 'unlabeled' 28 | required float num_appearance = 2; // float-type allows instance weighting 29 | repeated Feature feature = 3; // the feature vector 30 | } 31 | 32 | message DoublePB { 33 | required double value = 1; 34 | } 35 | 36 | message Int32PB { 37 | required int32 value = 1; 38 | } 39 | 40 | message DoubleSequencePB { 41 | repeated double value = 1; 42 | } 43 | 44 | message ComputeGradientMapperOutputPB { 45 | required double partial_loss = 1; 46 | required RealVectorPB partial_gradient = 2; 47 | } 48 | -------------------------------------------------------------------------------- /mrml-lasso/mr_assign_feature_id.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | #include 6 | #include 7 | 8 | #include // NOLINT. TODO(leostarzhou): Use SplitString in /strutil 9 | // to substitute the usage of istringstream. 10 | #include 11 | 12 | #include "boost/filesystem.hpp" 13 | #include "boost/program_options/option.hpp" 14 | #include "boost/program_options/options_description.hpp" 15 | #include "boost/program_options/variables_map.hpp" 16 | #include "boost/program_options/parsers.hpp" 17 | 18 | #include "base/common.h" 19 | #include "mrml/mrml_filesystem.h" 20 | #include "mrml/mrml_recordio.h" 21 | 22 | #include "mrml-lasso/logistic_regression.pb.h" 23 | #include "mrml-lasso/mr_assign_feature_id.h" 24 | 25 | using std::map; 26 | using std::string; 27 | using std::istringstream; 28 | using std::ostringstream; 29 | using std::stringstream; 30 | using std::setw; 31 | using std::setfill; 32 | 33 | namespace logistic_regression { 34 | 35 | REGISTER_MAPPER(AssignFeatureIDMapper); 36 | REGISTER_REDUCER(AssignFeatureIDReducer); 37 | 38 | AssignFeatureIDMapper::AssignFeatureIDMapper() { 39 | options_.Parse(GetConfig()); 40 | feature_dict_.clear(); 41 | } 42 | 43 | void AssignFeatureIDMapper::Map(const std::string& key, 44 | const std::string& value) { 45 | float num_positive; 46 | float num_appearance; 47 | string feature_name; 48 | double feature_value; 49 | int feature_num = 0; 50 | InstancePB instance; 51 | 52 | if (GetInputFormat() == Text) { 53 | istringstream line_parser(value); 54 | line_parser >> num_positive >> num_appearance; 55 | while (line_parser >> feature_name >> feature_value) { 56 | if (!feature_dict_.count(feature_name)) { 57 | ++feature_num; 58 | feature_dict_[feature_name] = feature_num; 59 | } 60 | } 61 | } else if (GetInputFormat() == RecordIO) { 62 | CHECK(instance.ParseFromString(value)); 63 | for (int i = 0; i < instance.feature_size(); ++i) { 64 | feature_name = instance.feature(i).name(); 65 | if (!feature_dict_.count(feature_name)) { 66 | ++feature_num; 67 | feature_dict_[feature_name] = feature_num; 68 | } 69 | } 70 | } 71 | } 72 | 73 | void AssignFeatureIDMapper::Flush() { 74 | map::const_iterator map_it = feature_dict_.begin(); 75 | while (map_it != feature_dict_.end()) { 76 | Output(kUniqueKey, map_it->first); 77 | ++map_it; 78 | } 79 | } 80 | 81 | AssignFeatureIDReducer::AssignFeatureIDReducer() { 82 | options_.Parse(GetConfig()); 83 | feature_dict_.clear(); 84 | feature_num_ = 0; 85 | } 86 | 87 | void* AssignFeatureIDReducer::BeginReduce(const std::string& key, 88 | const std::string& value) { 89 | if (!feature_dict_.count(value)) { 90 | ++feature_num_; 91 | feature_dict_[value] = feature_num_; 92 | } 93 | return NULL; // No intermediate result. 94 | } 95 | 96 | void AssignFeatureIDReducer::PartialReduce(const std::string& key, 97 | const std::string& value, 98 | void* partial_result) { 99 | if (!feature_dict_.count(value)) { 100 | ++feature_num_; 101 | feature_dict_[value] = feature_num_; 102 | } 103 | } 104 | 105 | void AssignFeatureIDReducer::EndReduce(const std::string& key, 106 | void* partial_result) { 107 | map::const_iterator map_it = feature_dict_.begin(); 108 | while (map_it != feature_dict_.end()) { 109 | ostringstream oss; 110 | if (GetOutputFormat() == RecordIO) { 111 | oss << map_it->second; 112 | Output(map_it->first,oss.str()); 113 | } else { 114 | oss << map_it->first << "\t" << map_it->second; 115 | Output("", oss.str()); 116 | } 117 | ++map_it; 118 | } 119 | } 120 | 121 | } // namespace logistic_regression 122 | -------------------------------------------------------------------------------- /mrml-lasso/mr_assign_feature_id.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define the mappers and reducers to assign id for each feature. 5 | // Mappers input: the training data. (data format: text or recordio) 6 | // Reducers output: the feature_id mapping file. (data format: recordio) 7 | 8 | #ifndef MRML_LASSO_MR_ASSIGN_FEATURE_ID_H_ 9 | #define MRML_LASSO_MR_ASSIGN_FEATURE_ID_H_ 10 | 11 | #include 12 | #include 13 | 14 | #include "base/common.h" 15 | #include "strutil/split_string.h" 16 | #include "mrml/mrml.h" 17 | #include "mrml-lasso/command_line_options.h" 18 | 19 | namespace logistic_regression { 20 | 21 | extern const char* kUniqueKey; 22 | 23 | // AssignFeatureIDMapper generates the "feature to id" mapping file. 24 | // It is necessary to run AssignFeatureID and ConvertDataFormat 25 | // MR tasks before mrml-lr starts running. 26 | class AssignFeatureIDMapper : public MRML_Mapper { 27 | public: 28 | AssignFeatureIDMapper(); 29 | void Map(const std::string& key, const std::string& value); 30 | void Flush(); 31 | 32 | private: 33 | std::map feature_dict_; 34 | CommandLineOptions options_; 35 | }; 36 | 37 | class AssignFeatureIDReducer : public MRML_Reducer { 38 | public: 39 | AssignFeatureIDReducer(); 40 | void* BeginReduce(const std::string& key, const std::string& value); 41 | void PartialReduce(const std::string& key, const std::string& value, 42 | void* partial_result); 43 | void EndReduce(const std::string& key, void* partial_result); 44 | 45 | private: 46 | std::map feature_dict_; 47 | CommandLineOptions options_; 48 | int feature_num_; 49 | }; 50 | 51 | } // namespace logistic_regression 52 | 53 | #endif // MRML_LASSO_MR_ASSIGN_FEATURE_ID_H_ 54 | -------------------------------------------------------------------------------- /mrml-lasso/mr_convert_data_format.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include 5 | #include 6 | 7 | #include // NOLINT. TODO(leostarzhou): Use SplitString in /strutil 8 | // to substitute the use of istringstream. 9 | #include 10 | 11 | #include "boost/filesystem.hpp" 12 | #include "boost/program_options/option.hpp" 13 | #include "boost/program_options/options_description.hpp" 14 | #include "boost/program_options/variables_map.hpp" 15 | #include "boost/program_options/parsers.hpp" 16 | 17 | #include "base/common.h" 18 | #include "mrml/mrml_filesystem.h" 19 | #include "mrml/mrml_recordio.h" 20 | 21 | #include "mrml-lasso/logistic_regression.pb.h" 22 | #include "mrml-lasso/mr_convert_data_format.h" 23 | 24 | using std::map; 25 | using std::string; 26 | using std::istringstream; 27 | using std::ostringstream; 28 | using std::stringstream; 29 | using std::setw; 30 | using std::setfill; 31 | 32 | namespace logistic_regression { 33 | 34 | REGISTER_MAPPER(ConvertDataFormatMapper); 35 | 36 | ConvertDataFormatMapper::ConvertDataFormatMapper() { 37 | options_.Parse(GetConfig()); 38 | feature_dict_.clear(); 39 | MRMLFS_File file(options_.feature_id_file, true); 40 | string feature_name; 41 | string feature_id; 42 | while ( true == MRML_ReadRecord(&file, &feature_name, &feature_id) ) { 43 | feature_dict_[feature_name] = atoi(feature_id.c_str()); 44 | } 45 | } 46 | 47 | void ConvertDataFormatMapper::Map(const std::string& key, 48 | const std::string& value) { 49 | float num_positive; 50 | float num_appearance; 51 | string feature_name; 52 | double feature_value; 53 | InstancePB instance; 54 | ostringstream oss; 55 | InstancePB pb; 56 | 57 | if (GetInputFormat() == Text) { 58 | istringstream line_parser(value); 59 | line_parser >> num_positive >> num_appearance; 60 | pb.set_num_positive(num_positive); 61 | pb.set_num_appearance(num_appearance); 62 | while (line_parser >> feature_name >> feature_value) { 63 | if (!feature_dict_.count(feature_name)) { 64 | LOG(ERROR) << "could not find feature[" << feature_name << "]'s id."; 65 | return; 66 | } 67 | InstancePB::Feature* e = pb.add_feature(); 68 | e->set_name(feature_name); 69 | e->set_id(feature_dict_[feature_name]); 70 | e->set_value(feature_value); 71 | } 72 | string output_buffer; 73 | pb.SerializeToString(&output_buffer); 74 | Output(kUniqueKey, output_buffer); 75 | } else if (GetInputFormat() == RecordIO) { 76 | CHECK(instance.ParseFromString(value)); 77 | pb.set_num_positive(instance.num_positive()); 78 | pb.set_num_appearance(instance.num_appearance()); 79 | for (int i = 0; i < instance.feature_size(); ++i) { 80 | feature_name = instance.feature(i).name(); 81 | if (!feature_dict_.count(feature_name)) { 82 | LOG(ERROR) << "could not find feature[" << feature_name << "]'s id."; 83 | return; 84 | } 85 | InstancePB::Feature* e = pb.add_feature(); 86 | e->set_name(feature_name); 87 | e->set_id(feature_dict_[feature_name]); 88 | e->set_value(instance.feature(i).value()); 89 | } 90 | string output_buffer; 91 | pb.SerializeToString(&output_buffer); 92 | Output(kUniqueKey, output_buffer); 93 | } 94 | } 95 | 96 | } // namespace logistic_regression 97 | -------------------------------------------------------------------------------- /mrml-lasso/mr_convert_data_format.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define the mappers only to convert features to id in traing data. 5 | // Mappers input: the training data. (data format: text or recordio) 6 | // Mappers output: the new training data in which features has been 7 | // converted from string to id. (data format: recordio) 8 | // External input: the feature_id mapping file. (data format: recordio) 9 | 10 | #ifndef MRML_LASSO_MR_CONVERT_DATA_FORMAT_H_ 11 | #define MRML_LASSO_MR_CONVERT_DATA_FORMAT_H_ 12 | 13 | #include 14 | #include 15 | 16 | #include "base/common.h" 17 | #include "strutil/split_string.h" 18 | #include "mrml/mrml.h" 19 | #include "mrml-lasso/command_line_options.h" 20 | 21 | namespace logistic_regression { 22 | 23 | extern const char* kUniqueKey; 24 | 25 | class ConvertDataFormatMapper : public MRML_Mapper { 26 | public: 27 | ConvertDataFormatMapper(); 28 | void Map(const std::string& key, const std::string& value); 29 | void Flush() {} 30 | 31 | private: 32 | std::map feature_dict_; 33 | CommandLineOptions options_; 34 | }; 35 | 36 | } // namespace logistic_regression 37 | 38 | #endif // MRML_LASSO_MR_CONVERT_DATA_FORMAT_H_ 39 | -------------------------------------------------------------------------------- /mrml-lasso/mrml_mappers_and_reducers.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define the mappers and reducers using class template Learner. 5 | // 6 | #ifndef MRML_LASSO_MRML_MAPPERS_AND_REDUCERS_H_ 7 | #define MRML_LASSO_MRML_MAPPERS_AND_REDUCERS_H_ 8 | 9 | #include 10 | 11 | #include "base/common.h" 12 | #include "strutil/split_string.h" 13 | #include "mrml/mrml.h" 14 | #include "mrml-lasso/learner.h" 15 | #include "mrml-lasso/learner_sparse_impl.h" 16 | #include "mrml-lasso/learner_dense_impl.h" 17 | #include "mrml-lasso/command_line_options.h" 18 | 19 | namespace logistic_regression { 20 | 21 | // All map outputs have the same key, and will be reduced by a 22 | // unique reduce worker. This ensures the gradient and value is 23 | // computed by summation over all training instances. 24 | extern const char* kUniqueKey; 25 | 26 | // ComputeGradientMapper computes the value and the gradient of the 27 | // logistic loss function (without the regularization term). 28 | template 29 | class ComputeGradientMapper : public MRML_Mapper { 30 | public: 31 | ComputeGradientMapper(); 32 | void Map(const std::string& key, const std::string& value); 33 | void Flush(); 34 | 35 | private: 36 | RealVector feature_weights_; // The model parameters. 37 | DenseRealVector combined_gradient_; 38 | double combined_loss_; 39 | 40 | SparseRealVector partial_gradient_; 41 | LearnerStates states_; 42 | 43 | CommandLineOptions options_; 44 | }; 45 | 46 | 47 | // Given the value and gradient of the logistic loss function computed 48 | // by ComputeGradientMapper, UpdateModelReducer either (i) initializes 49 | // the model, (ii) determine a gradient descent direction, or (iii) 50 | // does a line search prob step. 51 | template 52 | class UpdateModelReducer : public MRML_Reducer { 53 | public: 54 | UpdateModelReducer() { options_.Parse(GetConfig()); } 55 | void* BeginReduce(const std::string& key, const std::string& value); 56 | void PartialReduce(const std::string& key, const std::string& value, 57 | void* partial_result); 58 | void EndReduce(const std::string& key, void* partial_result); 59 | 60 | private: 61 | struct PartialReduceInfo { 62 | std::string word_; 63 | Learner *learner; 64 | double value; 65 | RealVector gradient; 66 | }; 67 | 68 | RealVector initial_x_; 69 | RealVector partial_gradient_; 70 | CommandLineOptions options_; 71 | }; 72 | 73 | class ComputeDenseGradientMapper 74 | : public ComputeGradientMapper { 75 | }; 76 | 77 | class UpdateDenseModelReducer 78 | : public UpdateModelReducer { 79 | }; 80 | 81 | class ComputeSparseGradientMapper 82 | : public ComputeGradientMapper { 83 | }; 84 | 85 | class UpdateSparseModelReducer 86 | : public UpdateModelReducer { 87 | }; 88 | 89 | } // namespace logistic_regression 90 | 91 | #endif // MRML_LASSO_MRML_MAPPERS_AND_REDUCERS_H_ 92 | -------------------------------------------------------------------------------- /mrml-lasso/prediction_engine.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include // NOLINT. TODO(leostarzhou): Use fopen API. 9 | #include 10 | #include 11 | #include 12 | 13 | #include "mrml-lasso/prediction_engine.h" 14 | 15 | using std::map; 16 | using std::vector; 17 | using std::string; 18 | using std::ifstream; 19 | using std::istringstream; 20 | 21 | namespace logistic_regression { 22 | 23 | PredictionEngine::PredictionEngine(const string& model_file_name) { 24 | ifstream input(model_file_name.c_str()); 25 | if (!input.is_open()) { 26 | printf("Fatal:model file name is valid!\n"); 27 | return; 28 | }; 29 | 30 | feature_weights_.clear(); 31 | string line; 32 | string value; 33 | getline(input, line); 34 | istringstream line_parser(line); 35 | size_t pos = 0; 36 | while (line_parser >> value) { 37 | if ((pos = value.find(":", 0)) != string::npos) { 38 | string feature_name = value.substr(0, pos); 39 | string feature_value = value.substr(pos+1, value.size()+1-pos); 40 | int index = atoi(feature_name.c_str()); 41 | double val = atof(feature_value.c_str()); 42 | feature_weights_[index] = val; 43 | } 44 | } 45 | } 46 | 47 | PredictionEngine::~PredictionEngine() { 48 | feature_weights_.clear(); 49 | } 50 | 51 | int PredictionEngine::Predict(const vector& feature_names, 52 | const vector& feature_values, 53 | double& ret) { 54 | if (feature_names.size() != feature_values.size()) { 55 | printf("Fatal: input parameters are invalid!\n"); 56 | return -1; 57 | } 58 | ret = 0; 59 | vector::const_iterator iter1 = feature_names.begin(); 60 | vector::const_iterator iter2 = feature_values.begin(); 61 | while (iter1 != feature_names.end()) { 62 | ret += *iter2 * feature_weights_[*iter1]; 63 | ++iter1; 64 | ++iter2; 65 | } 66 | ret = 1.0/(1 + exp(-ret)); 67 | return 0; 68 | } 69 | 70 | } // namespace logistic_regression 71 | -------------------------------------------------------------------------------- /mrml-lasso/prediction_engine.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangkuiyi/lasso/7cb55b55e34643e9ffbe868f503711be86ef1b26/mrml-lasso/prediction_engine.h -------------------------------------------------------------------------------- /mrml-lasso/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #mpd --listenport=55555 & 5 | #cd PROCALL 6 | #./proc_cmd_run.sh < tlist.log "mv /home/relmlr/leostarzhou/data/output/output-*-of-00005 /home/relmlr/leostarzhou/data/input" 7 | #cd .. 8 | #exit 9 | 10 | #scp input-00000-of-00005 relmlr@172.26.3.109#36000:/home/relmlr/leostarzhou/data/input 11 | #scp input-00001-of-00005 relmlr@172.26.3.107#36000:/home/relmlr/leostarzhou/data/input 12 | #scp input-00002-of-00005 relmlr@172.26.3.108#36000:/home/relmlr/leostarzhou/data/input 13 | #scp input-00003-of-00005 relmlr@172.26.3.110#36000:/home/relmlr/leostarzhou/data/input 14 | #scp input-00004-of-00005 relmlr@172.26.3.111#36000:/home/relmlr/leostarzhou/data/input 15 | #exit 16 | 17 | rm /home/relmlr/leostarzhou/data/base_dir/* 18 | cd PROCALL 19 | ./proc_cmd_run.sh < tlist.log "rm /home/relmlr/leostarzhou/data/base_dir/*" 20 | cd .. 21 | date_start=$(date +%s) 22 | 23 | rm flag_file 24 | [[ -e flag_file ]] &&{ 25 | rm flag_file 26 | } || { 27 | 28 | while [ ! -e flag_file ] 29 | do 30 | 31 | file=`ls -th /home/relmlr/leostarzhou/data/base_dir/ | head -1 | awk '{print $NF}'` 32 | echo $file 33 | ip_set="172.26.3.110 172.26.3.111 172.26.3.108 172.26.3.107" 34 | for ip in $ip_set 35 | do 36 | scp /home/relmlr/leostarzhou/data/base_dir/${file} relmlr@${ip}#36000:/home/relmlr/leostarzhou/data/base_dir 37 | done 38 | 39 | mpiexec -machinefile ./machine-list -np 6 mrml_mappers_and_reducers \ 40 | --mrml_num_map_workers=5 \ 41 | --mrml_num_reduce_workers=1 \ 42 | --mrml_input_format=text \ 43 | --mrml_output_format=recordio \ 44 | --mrml_input_filebase=/home/relmlr/leostarzhou/data/input/input \ 45 | --mrml_output_filebase=/home/relmlr/leostarzhou/data/output/output \ 46 | --mrml_mapper_class=ComputeGradientMapper_dense \ 47 | --mrml_reducer_class=UpdateModelReducer_dense \ 48 | --mrml_log_filebase=/home/relmlr/leostarzhou/data/lr-log \ 49 | --states_file_dir=/home/relmlr/leostarzhou/data/base_dir \ 50 | --flag_file=flag_file \ 51 | --states_filebase=states_filebase \ 52 | --max_fea_num=100001 53 | done 54 | } 55 | 56 | date_end=$(date +%s) 57 | echo "using time :" 58 | echo $(($date_end-$date_start)) 59 | -------------------------------------------------------------------------------- /mrml-lasso/sparse_vector_tmpl.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define the class template DenseVectorImpl and operations required 5 | // by class template Learner. 6 | // 7 | #ifndef MRML_LASSO_SPARSE_VECTOR_TMPL_H_ 8 | #define MRML_LASSO_SPARSE_VECTOR_TMPL_H_ 9 | 10 | #include 11 | 12 | namespace logistic_regression { 13 | 14 | using std::map; 15 | using std::pair; 16 | using std::ostream; 17 | 18 | // We use std::map as the container of a sparse vector. Because 19 | // std::map implements a sorting tree (RB-tree), so its iterator 20 | // accesses elements in known order of keys. This is important for 21 | // sparse vector operations like dot-product and add-multi-into. 22 | // The ValueType must be a numerical type supporting const 0. 23 | template 24 | class SparseVectorTmpl : public map { 25 | public: 26 | typedef typename map::const_iterator const_iterator; 27 | typedef typename map::iterator iterator; 28 | 29 | // We constrain operator[] a read-only operation to prevent 30 | // accidential insert of elements. 31 | const ValueType& operator[](const KeyType& key) const { 32 | const_iterator iter = this->find(key); 33 | if (iter == this->end()) { 34 | return zero_; 35 | } 36 | return iter->second; 37 | } 38 | 39 | // Set a value at given key. If value==0, an exisiting key-value 40 | // pair is removed. If value!=0, the value is set or inserted. 41 | // This function also serves as a convenient form of insert(), no 42 | // need to use std::pair. 43 | void set(const KeyType& key, const ValueType& value) { 44 | iterator iter = this->find(key); 45 | if (iter != this->end()) { 46 | if (IsZero(value)) { 47 | this->erase(iter); 48 | } else { 49 | iter->second = value; 50 | } 51 | } else { 52 | if (!IsZero(value)) { 53 | this->insert(pair(key, value)); 54 | } 55 | } 56 | } 57 | 58 | bool has(const KeyType& key) const { 59 | return this->find(key) != this->end(); 60 | } 61 | 62 | protected: 63 | static const ValueType zero_; 64 | 65 | static bool IsZero(const ValueType& value) { 66 | // Once, we used zero-judgement like: 67 | // static double kEpsilon = 1e-12; 68 | // return (value - zero_) * (value - zero_) < kEpsilon; 69 | // however, this does not work as exactly as (value==0). 70 | return value == 0; 71 | } 72 | }; 73 | 74 | 75 | template 76 | const ValueType SparseVectorTmpl::zero_(0); 77 | 78 | 79 | // Scale(v,c) : v <- v * c 80 | template 81 | void Scale(SparseVectorTmpl* v, 82 | const ScaleType& c) { 83 | typedef SparseVectorTmpl SV; 84 | for (typename SV::iterator i = v->begin(); i != v->end(); ++i) { 85 | i->second *= c; 86 | } 87 | } 88 | 89 | // ScaleInto(u,v,c) : u <- v * c 90 | template 91 | void ScaleInto(SparseVectorTmpl* u, 92 | const SparseVectorTmpl& v, 93 | const ScaleType& c) { 94 | typedef SparseVectorTmpl SV; 95 | u->clear(); 96 | for (typename SV::const_iterator i = v.begin(); i != v.end(); ++i) { 97 | u->set(i->first, i->second * c); 98 | } 99 | } 100 | 101 | // AddScaled(u,v,c) : u <- u + v * c 102 | template 103 | void AddScaled(SparseVectorTmpl* u, 104 | const SparseVectorTmpl& v, 105 | const ScaleType& c) { 106 | typedef SparseVectorTmpl SV; 107 | for (typename SV::const_iterator i = v.begin(); i != v.end(); ++i) { 108 | u->set(i->first, (*u)[i->first] + i->second * c); 109 | } 110 | } 111 | 112 | // AddScaledInto(w,u,v,c) : w <- u + v * c 113 | template 114 | void AddScaledInto(SparseVectorTmpl* w, 115 | const SparseVectorTmpl& u, 116 | const SparseVectorTmpl& v, 117 | const ScaleType& c) { 118 | typedef SparseVectorTmpl SV; 119 | w->clear(); 120 | typename SV::const_iterator i = u.begin(); 121 | typename SV::const_iterator j = v.begin(); 122 | while (i != u.end() && j != v.end()) { 123 | if (i->first == j->first) { 124 | w->set(i->first, i->second + j->second * c); 125 | ++i; 126 | ++j; 127 | } else if (i->first < j->first) { 128 | w->set(i->first, i->second); 129 | ++i; 130 | } else { 131 | w->set(j->first, j->second * c); 132 | ++j; 133 | } 134 | } 135 | while (i != u.end()) { 136 | w->set(i->first, i->second); 137 | ++i; 138 | } 139 | while (j != v.end()) { 140 | w->set(j->first, j->second * c); 141 | ++j; 142 | } 143 | } 144 | 145 | // DotProduct(u,v) : r <- dot(u, v) 146 | template 147 | ValueType DotProduct(const SparseVectorTmpl& v1, 148 | const SparseVectorTmpl& v2) { 149 | typedef SparseVectorTmpl SV; 150 | typename SV::const_iterator i = v1.begin(); 151 | typename SV::const_iterator j = v2.begin(); 152 | ValueType ret = 0; 153 | while (i != v1.end() && j != v2.end()) { 154 | if (i->first == j->first) { 155 | ret += i->second * j->second; 156 | ++i; 157 | ++j; 158 | } else if (i->first < j->first) { 159 | ++i; 160 | } else { 161 | ++j; 162 | } 163 | } 164 | return ret; 165 | } 166 | 167 | // Output a sparse vector in human readable format. 168 | template 169 | ostream& operator<<(ostream& output, 170 | const SparseVectorTmpl& vec) { 171 | typedef SparseVectorTmpl SV; 172 | output << "[ "; 173 | for (typename SV::const_iterator i = vec.begin(); i != vec.end(); ++i) { 174 | output << i->first << ":" << i->second << " "; 175 | } 176 | output << "]"; 177 | return output; 178 | } 179 | 180 | } // namespace logistic_regression 181 | 182 | #endif // MRML_LASSO_SPARSE_VECTOR_TMPL_H_ 183 | -------------------------------------------------------------------------------- /mrml-lasso/sparse_vector_tmpl_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include "base/common.h" 7 | #include "gtest/gtest.h" 8 | #include "mrml-lasso/sparse_vector_tmpl.h" 9 | 10 | using logistic_regression::SparseVectorTmpl; 11 | 12 | typedef SparseVectorTmpl RealVector; 13 | 14 | TEST(SparseVectorTmplTest, SquareBrackets) { 15 | RealVector v; 16 | v.set(101, 1); 17 | EXPECT_EQ(v[101], 1); 18 | EXPECT_EQ(v[102], 0); 19 | EXPECT_EQ(v.has(101), true); 20 | EXPECT_EQ(v.has(102), false); 21 | } 22 | 23 | TEST(SparseVectorTmplTest, Set) { 24 | RealVector v; 25 | EXPECT_EQ(v.size(), 0); 26 | v.set(101, 0); 27 | EXPECT_EQ(v.size(), 0); 28 | EXPECT_EQ(v.has(101), false); 29 | v.set(101, 1); 30 | EXPECT_EQ(v.size(), 1); 31 | EXPECT_EQ(v.has(101), true); 32 | EXPECT_EQ(v[101], 1); 33 | v.set(101, 2); 34 | EXPECT_EQ(v.size(), 1); 35 | EXPECT_EQ(v.has(101), true); 36 | EXPECT_EQ(v[101], 2); 37 | v.set(101, 0); 38 | EXPECT_EQ(v.size(), 0); 39 | EXPECT_EQ(v.has(101), false); 40 | } 41 | 42 | TEST(SparseVectorTmplTest, Scale) { 43 | RealVector v; 44 | v.set(101, 2); 45 | v.set(102, 4); 46 | Scale(&v, 0.5); 47 | EXPECT_EQ(v.size(), 2); 48 | EXPECT_EQ(v[101], 1); 49 | EXPECT_EQ(v[102], 2); 50 | } 51 | 52 | TEST(SparseVectorTmplTest, ScaleInto) { 53 | RealVector u, v; 54 | u.set(200, 2); 55 | v.set(101, 2); 56 | v.set(102, 4); 57 | ScaleInto(&u, v, 0.5); 58 | EXPECT_EQ(u.size(), 2); 59 | EXPECT_EQ(u[101], 1); 60 | EXPECT_EQ(u[102], 2); 61 | } 62 | 63 | TEST(SparseVectorTmplTest, AddScaled) { 64 | RealVector u, v; 65 | u.set(200, 2); 66 | v.set(101, 2); 67 | v.set(102, 4); 68 | AddScaled(&u, v, 0.5); 69 | EXPECT_EQ(u.size(), 3); 70 | EXPECT_EQ(u[200], 2); 71 | EXPECT_EQ(u[101], 1); 72 | EXPECT_EQ(u[102], 2); 73 | } 74 | 75 | TEST(SparseVectorTmplTest, AddScaledInto) { 76 | RealVector w, u, v; 77 | w.set(200, 100); 78 | u.set(101, 2); 79 | u.set(102, 4); 80 | u.set(301, 8); 81 | u.set(302, 100); 82 | v.set(101, 2); 83 | v.set(103, 6); 84 | v.set(301, 8); 85 | AddScaledInto(&w, u, v, 0.5); 86 | EXPECT_EQ(w.size(), 5); 87 | EXPECT_EQ(w[101], 3); 88 | EXPECT_EQ(w[102], 4); 89 | EXPECT_EQ(w[103], 3); 90 | EXPECT_EQ(w[301], 12); 91 | EXPECT_EQ(w[302], 100); 92 | } 93 | 94 | TEST(SparseVectorTmplTest, DotProduct) { 95 | RealVector v, u, w; 96 | v.set(101, 2); 97 | v.set(102, 4); 98 | v.set(301, 9); 99 | v.set(302, 100); 100 | u.set(101, 2); 101 | u.set(103, 6); 102 | u.set(301, 9); 103 | w.set(200, 10); 104 | EXPECT_EQ(DotProduct(v, u), 85); 105 | EXPECT_EQ(DotProduct(u, v), 85); 106 | EXPECT_EQ(DotProduct(v, w), 0); 107 | EXPECT_EQ(DotProduct(u, w), 0); 108 | } 109 | 110 | 111 | -------------------------------------------------------------------------------- /mrml-lasso/tags: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangkuiyi/lasso/7cb55b55e34643e9ffbe868f503711be86ef1b26/mrml-lasso/tags -------------------------------------------------------------------------------- /mrml-lasso/termination_flag.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #ifndef MRML_LASSO_TERMINATION_FLAG_H_ 5 | #define MRML_LASSO_TERMINATION_FLAG_H_ 6 | 7 | #include // NOLINT: TODO(yiwang) use fopen to create the flag file. 8 | #include "mrml-lasso/learner_states.h" 9 | 10 | namespace logistic_regression { 11 | 12 | // TerminationFlag creates a text file on local filesystem to 13 | // indicate the termination of a training procedure. The content of 14 | // the file is the model parameters. 15 | template 16 | class TerminationFlag { 17 | public: 18 | static bool SetLocally(const char* termination_filename, 19 | const char* termination_reason, 20 | const LearnerStates* states) { 21 | std::ofstream output(termination_filename); 22 | if (!output.is_open()) { 23 | LOG(ERROR) << "Cannot create termination flag file: " 24 | << termination_filename; 25 | return false; 26 | } 27 | 28 | output << termination_reason << "\n"; 29 | if (states != NULL) { 30 | output << "x = " << states->x() << "\n" 31 | << "new_x = " << states->new_x() << "\n"; 32 | } 33 | return true; 34 | } 35 | }; 36 | 37 | } // namespace logistic_regression 38 | 39 | #endif // MRML_LASSO_TERMINATION_FLAG_H_ 40 | -------------------------------------------------------------------------------- /mrml-lasso/termination_flag_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "base/common.h" 10 | #include "gtest/gtest.h" 11 | #include "mrml-lasso/learner_states.h" 12 | #include "mrml-lasso/termination_flag.h" 13 | #include "mrml-lasso/vector_types.h" 14 | 15 | using std::string; 16 | using std::cout; 17 | 18 | namespace logistic_regression { 19 | 20 | template 21 | void testTerminationFlag() { 22 | typedef TerminationFlag > TermFlag; 23 | 24 | static const char* kTempFilename = "/tmp/tmp_term_flag"; 25 | 26 | cout << "Run " << __FUNCTION__ << "\n"; 27 | 28 | // Remove the flag file to ensure that it does not exist. 29 | string cmd_remove_file = string("rm ") + kTempFilename; 30 | system(cmd_remove_file.c_str()); 31 | 32 | // Create the flag file. 33 | EXPECT_EQ(true, 34 | TermFlag::SetLocally(kTempFilename, 35 | "Somewhat reasons.\n", 36 | NULL)); 37 | 38 | // Check the existence of the flag file. 39 | string cmd_test_file = string("test -e ") + kTempFilename; 40 | int result = system(cmd_test_file.c_str()); 41 | EXPECT_EQ(WEXITSTATUS(result), 0); 42 | 43 | // Remove the temp file 44 | system(cmd_remove_file.c_str()); 45 | } 46 | 47 | } // namespace logstic_regression 48 | 49 | int main(int argc, char** argv) { 50 | using logistic_regression::testTerminationFlag; 51 | using logistic_regression::SparseRealVector; 52 | using logistic_regression::DenseRealVector; 53 | 54 | testTerminationFlag(); 55 | testTerminationFlag(); 56 | return 0; 57 | } 58 | -------------------------------------------------------------------------------- /mrml-lasso/testdata/tiny: -------------------------------------------------------------------------------- 1 | 10 10 2 1 5 0 2 | 0 10 2 0 5 1 3 | -------------------------------------------------------------------------------- /mrml-lasso/vector_types.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Here defines classes DenseRealVector and SparseRealVector, which 5 | // will be used in class templates: RealVectorPtrDeque and 6 | // LearnerStates. 7 | // 8 | // Note that both classes serialize to the same protocol message 9 | // RealVectorPB; this allows LearnerStates and 10 | // LearnerStates share the same protocol message 11 | // LearnerStatesPB. 12 | // 13 | #ifndef MRML_LASSO_VECTOR_TYPES_H_ 14 | #define MRML_LASSO_VECTOR_TYPES_H_ 15 | 16 | #include 17 | 18 | #include "base/common.h" 19 | #include "mrml/mrml_filesystem.h" 20 | #include "mrml/mrml_recordio.h" 21 | #include "mrml-lasso/sparse_vector_tmpl.h" 22 | #include "mrml-lasso/dense_vector_tmpl.h" 23 | 24 | namespace logistic_regression { 25 | 26 | const int kMessageSize = 4000000; 27 | 28 | class RealVectorPB; 29 | 30 | typedef uint32 IndexType; // The index type in SparseRealVector 31 | 32 | //--------------------------------------------------------------------------- 33 | // DenseRealVector, a vector realization of DenseVectorTmpl. 34 | //--------------------------------------------------------------------------- 35 | class DenseRealVector : public DenseVectorTmpl { 36 | public: 37 | typedef std::vector::const_iterator const_iterator; 38 | typedef std::vector::iterator iterator; 39 | 40 | DenseRealVector(size_t size, const double& init) 41 | : DenseVectorTmpl(size, init) {} 42 | DenseRealVector() 43 | : DenseVectorTmpl() {} 44 | 45 | void SerializeToProtoBuf(RealVectorPB* pb) const; 46 | void SerializeToRecordIO(MRMLFS_File* file, 47 | const std::string& key_base) const; 48 | void ParseFromProtoBuf(const RealVectorPB& pb); 49 | void ParseFromRecordIO(MRMLFS_File* file, 50 | const std::string& key_base, int32& vec_size); 51 | }; 52 | 53 | //--------------------------------------------------------------------------- 54 | // SparseRealVector, a map realization of SparseVectorTmpl. 55 | //--------------------------------------------------------------------------- 56 | class SparseRealVector : public SparseVectorTmpl { 57 | public: 58 | typedef SparseVectorTmpl::const_iterator const_iterator; 59 | typedef SparseVectorTmpl::iterator iterator; 60 | void SerializeToProtoBuf(RealVectorPB* pb) const; 61 | void SerializeToRecordIO(MRMLFS_File* file, 62 | const std::string& key_base) const; 63 | void ParseFromProtoBuf(const RealVectorPB& pb); 64 | void ParseFromRecordIO(MRMLFS_File* file, 65 | const std::string& key_base, int32& vec_size); 66 | }; 67 | 68 | 69 | //--------------------------------------------------------------------------- 70 | // Vector operations that accepts a dense vector and a sparse 71 | // vector. Note that operations defined in 72 | // sparse/dense-vector-impl.h accept either sparse or dense operands. 73 | //--------------------------------------------------------------------------- 74 | double DotProduct(const SparseRealVector& sv, const DenseRealVector& dv); 75 | void AddScaled(DenseRealVector* dv, const SparseRealVector& sv, double f); 76 | 77 | } // namespace logistic_regression 78 | 79 | #endif // MRML_LASSO_VECTOR_TYPES_H_ 80 | -------------------------------------------------------------------------------- /mrml-lasso/vector_types_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "mrml-lasso/logistic_regression.pb.h" 5 | #include "mrml-lasso/test_utils.h" 6 | #include "mrml-lasso/vector_types.h" 7 | 8 | using logistic_regression::RealVectorTestUtil; 9 | using logistic_regression::DenseRealVector; 10 | using logistic_regression::SparseRealVector; 11 | using logistic_regression::RealVectorPB; 12 | 13 | TEST(VectorTypesTest, DenseVectorSerialization) { 14 | RealVectorTestUtil u; 15 | DenseRealVector v; 16 | u.Construct(&v); 17 | 18 | RealVectorPB pb; 19 | v.SerializeToProtoBuf(&pb); 20 | u.CheckProtoBuf(pb); 21 | 22 | v.clear(); 23 | v.ParseFromProtoBuf(pb); 24 | u.Check(v); 25 | } 26 | 27 | TEST(VectorTypesTest, SparseVectorSerialization) { 28 | RealVectorTestUtil u; 29 | SparseRealVector v; 30 | u.Construct(&v); 31 | 32 | RealVectorPB pb; 33 | v.SerializeToProtoBuf(&pb); 34 | u.CheckProtoBuf(pb); 35 | 36 | v.ParseFromProtoBuf(pb); 37 | u.Check(v); 38 | } 39 | 40 | TEST(VectorTypesTest, HybridDotProduct) { 41 | RealVectorTestUtil su; 42 | SparseRealVector sv; 43 | su.Construct(&sv); 44 | 45 | RealVectorTestUtil du; 46 | DenseRealVector dv; 47 | du.Construct(&dv); 48 | 49 | EXPECT_EQ(0, DotProduct(sv, dv)); 50 | } 51 | 52 | TEST(VectorTypesTest, HybridAddScaled) { 53 | RealVectorTestUtil su; 54 | SparseRealVector sv; 55 | su.Construct(&sv); 56 | 57 | RealVectorTestUtil du; 58 | DenseRealVector dv; 59 | du.Construct(&dv); 60 | 61 | AddScaled(&dv, sv, 2); 62 | 63 | EXPECT_EQ(4, dv.size()); 64 | EXPECT_EQ(10, dv[0]); 65 | EXPECT_EQ(20, dv[1]); 66 | EXPECT_EQ(30, dv[2]); 67 | EXPECT_EQ(60, dv[3]); 68 | } 69 | -------------------------------------------------------------------------------- /mrml/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build protobuf 2 | protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS mrml.proto) 3 | 4 | # Build library mrml. 5 | add_library(mrml mrml_filesystem.cc mrml_reader.cc mrml.cc ${PROTO_SRCS} mrml_recordio.cc) 6 | add_library(mrml-main mrml_main.cc) 7 | 8 | # Build unittests. 9 | set(LIBS mrml sorted_buffer strutil hash base mpichcxx mpich opa ssh2 ssl crypto z dl boost_program_options boost_regex boost_filesystem boost_system protobuf gflags gtest pthread) 10 | 11 | add_executable(mrml_recordio_test mrml_recordio_test.cc) 12 | target_link_libraries(mrml_recordio_test gtest_main ${LIBS}) 13 | 14 | add_executable(mrml_filesystem_test mrml_filesystem_test.cc) 15 | target_link_libraries(mrml_filesystem_test gtest_main ${LIBS}) 16 | 17 | # Build utility codex 18 | add_executable(codex codex.cc) 19 | target_link_libraries(codex ${LIBS}) 20 | 21 | # Install library and header files 22 | install(TARGETS mrml DESTINATION bin/mrml) 23 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 24 | install(FILES ${HEADER_FILES} DESTINATION include/mrml) 25 | -------------------------------------------------------------------------------- /mrml/mr.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #ifndef MRML_MR_H_ 5 | #define MRML_MR_H_ 6 | 7 | #include 8 | 9 | #include "mrml/mrml.h" 10 | #include "sorted_buffer/sorted_buffer_iterator.h" 11 | 12 | //----------------------------------------------------------------------------- 13 | // 14 | // In addition to the MRML_Mapper, MRML_Reducer API, MR_Mapper and 15 | // MR_Reducer API provides the ability of large scale processing; in 16 | // particular, without the limitation on the number of unique map 17 | // output keys. 18 | // 19 | // From the perspective of API, the ability of incremental reduce is 20 | // removed, and the API is therefore identical to the standard Google 21 | // MapReduce API). 22 | // 23 | // *** Batch Reduction *** 24 | // 25 | // The initial design of MRML is to provide a way which makes it 26 | // possible for map workers and reduce workers work simultaneously. 27 | // The design was realized by MRML_Reducer class. However, this 28 | // design has a limit on the number of unique map output keys, which 29 | // would becomes a servere problem in applications like parallel 30 | // training of language models. Therefore, we add MR_Redcuer, a 31 | // reducer API which is identical to that published in Google papers. 32 | // 33 | // If you derive your reducer class from MR_Reducer, instead of 34 | // MRML_Reducer, please remember to register it using 35 | // REGISTER_MR_REDUCER instead of REGISTER_REDUCER. 36 | // 37 | //----------------------------------------------------------------------------- 38 | 39 | // MR_Mapper is exactly MRML_Mapper. 40 | typedef MRML_Mapper MR_Mapper; 41 | 42 | typedef sorted_buffer::SortedBufferIterator ReduceInputIterator; 43 | 44 | class MR_Reducer : public MRML_Reducer { 45 | public: 46 | virtual ~MR_Reducer() {} 47 | 48 | // The new API: 49 | virtual void Start() {} 50 | virtual void Reduce(const string& key, ReduceInputIterator* values) = 0; 51 | virtual void Flush() {} 52 | 53 | private: 54 | // Forbids the old API: 55 | virtual void* BeginReduce(const string&, const string&) { return NULL; } 56 | virtual void PartialReduce(const string&, const string&, void*) {} 57 | virtual void EndReduce(const string&, void*) {} 58 | }; 59 | 60 | //----------------------------------------------------------------------------- 61 | // Mapper/reducer registering and creating mechanism. 62 | //----------------------------------------------------------------------------- 63 | 64 | typedef MR_Mapper* (*MR_MapperCreator)(); 65 | typedef MR_Reducer* (*MR_ReducerCreator)(); 66 | 67 | class MR_MapperRegisterer { 68 | public: 69 | MR_MapperRegisterer(const string& class_name, MR_MapperCreator p); 70 | }; 71 | 72 | class MR_ReducerRegisterer { 73 | public: 74 | MR_ReducerRegisterer(const string& class_name, MR_ReducerCreator p); 75 | }; 76 | 77 | #define REGISTER_MR_MAPPER(mapper_name) \ 78 | MR_Mapper* mapper_name##_creator() { return new mapper_name; } \ 79 | MR_MapperRegisterer g_mapper_reg##mapper_name(#mapper_name, \ 80 | mapper_name##_creator) 81 | 82 | #define REGISTER_MR_REDUCER(reducer_name) \ 83 | MR_Reducer* reducer_name##_creator() { return new reducer_name; } \ 84 | MR_ReducerRegisterer g_reducer_reg##reducer_name(#reducer_name, \ 85 | reducer_name##_creator) 86 | 87 | #endif // MRML_MR_H_ 88 | -------------------------------------------------------------------------------- /mrml/mrml.proto: -------------------------------------------------------------------------------- 1 | message KeyValuePair { 2 | optional bytes key = 1; 3 | optional bytes value = 2; 4 | } 5 | 6 | // A map worker either 7 | // 1. send a map output to a reduce worker, or 8 | // 2. send a finished flag to all reduce workers. 9 | // If map_worker is set, it is case 2; otherwise, it is case 1. 10 | message MapOutput { 11 | optional int32 map_worker = 1; 12 | optional bytes key = 2; 13 | optional bytes value = 3; 14 | } 15 | -------------------------------------------------------------------------------- /mrml/mrml_filesystem.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This file provides MRMLFS_File, an interface to remote file access 5 | // through SFTP protocol. 6 | // 7 | #ifndef MRML_MRML_FILESYSTEM_H_ 8 | #define MRML_MRML_FILESYSTEM_H_ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | 18 | class MRMLFS_File { 19 | friend class MRMLFS_FileTest; 20 | 21 | public: 22 | enum Type { Local, SFTP }; 23 | 24 | MRMLFS_File(); 25 | MRMLFS_File(const std::string& filename, bool for_read); 26 | virtual ~MRMLFS_File(); 27 | 28 | // This function opens a local file or a remote file (through SFTP). 29 | // The filename must be in either one of the following formats: 30 | // 31 | // 1. sftp://:@: 32 | // 33 | // This format denotes a remote SFTP file protocol. If 34 | // is "localhost" or "127.0.0.1", the filename will be deprecated 35 | // into a local file, i.e., only local-path is reserved. 36 | // 37 | // 2. 38 | // 39 | // This format denotes a local file. 40 | // 41 | // If |for_read| is ture, the file will be opened in read-only 42 | // mode. Otherwise, it will be truncated and overwritten. 43 | // 44 | bool Open(const std::string& filename, bool for_read); 45 | 46 | bool IsOpen() const; 47 | 48 | // Generally, this function returns the total number of bytes 49 | // successfully read. Particularly, for SFTP file, it invokes 50 | // libssh2_sftp_read directly, and for local file, it invokes fread 51 | // directly. 52 | size_t Read(char* buffer, size_t size); 53 | 54 | // Generally, this function returns the actual number of bytes 55 | // written or negative on failure. If this number differs from the 56 | // count parameter, it indicates an error. Particularly, for SFTP 57 | // file, it invokes libssh2_sftp_write directly, and for local file, 58 | // it invokes fwrite directly. 59 | size_t Write(const char* buffer, size_t size); 60 | 61 | void Close(); 62 | 63 | protected: 64 | // Fields consisting of a filename. 65 | struct FilenameFields { 66 | std::string protocol; 67 | std::string username; 68 | std::string password; 69 | std::string hostname; 70 | int port; 71 | std::string path; 72 | 73 | FilenameFields() : port(0) {} 74 | ~FilenameFields() { Clear(); } 75 | void Clear(); 76 | }; 77 | 78 | // Fields for accessing a local file: 79 | FILE* local_file_; 80 | 81 | // Fields for accessing a SFTP file: 82 | int socket_fd_; 83 | LIBSSH2_SESSION* ssh2_session_; 84 | LIBSSH2_SFTP* sftp_session_; 85 | LIBSSH2_SFTP_HANDLE* sftp_handle_; 86 | 87 | static bool ParseFilename(const std::string& filename, FilenameFields* f); 88 | bool OpenLocalFile(const FilenameFields& f, bool for_read); 89 | bool OpenSFTPFile(const FilenameFields& f, bool for_read); 90 | }; 91 | 92 | #endif // MRML_MRML_FILESYSTEM_H_ 93 | -------------------------------------------------------------------------------- /mrml/mrml_filesystem_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include "gtest/gtest.h" 7 | 8 | #include "base/common.h" 9 | #include "mrml/mrml_filesystem.h" 10 | 11 | using std::string; 12 | 13 | DEFINE_string(remote_path, "", 14 | "The path specified for testing CreateAndReadSFTPFile"); 15 | 16 | class MRMLFS_FileTest : public ::testing::Test { 17 | public: 18 | static void ParseLocalFile(); 19 | static void ParseSFTPFile(); 20 | static void ParseLocalSFTPFile(); 21 | static void CreateAndReadLocalFile(); 22 | static void CreateAndReadSFTPFile(); 23 | }; 24 | 25 | void MRMLFS_FileTest::ParseLocalFile() { 26 | static const char* kFilename = "/tmp/a.h"; 27 | static const char* kFilenameWithProtocol = "file:///tmp/a.h"; 28 | 29 | MRMLFS_File::FilenameFields fields; 30 | 31 | MRMLFS_File::ParseFilename(kFilename, &fields); 32 | EXPECT_EQ(fields.protocol, "file://"); 33 | EXPECT_EQ(fields.path, kFilename); 34 | 35 | MRMLFS_File::ParseFilename(kFilenameWithProtocol, &fields); 36 | EXPECT_EQ(fields.protocol, "file://"); 37 | EXPECT_EQ(fields.path, kFilename); 38 | } 39 | 40 | void MRMLFS_FileTest::ParseSFTPFile() { 41 | static const char* kFilename1 = "sftp://user:password@host:/tmp/a.h"; 42 | static const char* kFilename2 = "sftp://user:password@host#36000:/tmp/a.h"; 43 | 44 | MRMLFS_File::FilenameFields fields; 45 | MRMLFS_File::ParseFilename(kFilename1, &fields); 46 | EXPECT_EQ(fields.protocol, "sftp://"); 47 | EXPECT_EQ(fields.username, "user"); 48 | EXPECT_EQ(fields.password, "password"); 49 | EXPECT_EQ(fields.hostname, "host"); 50 | EXPECT_EQ(fields.port, 22); 51 | EXPECT_EQ(fields.path, "/tmp/a.h"); 52 | 53 | MRMLFS_File::ParseFilename(kFilename2, &fields); 54 | EXPECT_EQ(fields.protocol, "sftp://"); 55 | EXPECT_EQ(fields.username, "user"); 56 | EXPECT_EQ(fields.password, "password"); 57 | EXPECT_EQ(fields.hostname, "host"); 58 | EXPECT_EQ(fields.port, 36000); 59 | EXPECT_EQ(fields.path, "/tmp/a.h"); 60 | } 61 | 62 | 63 | void MRMLFS_FileTest::ParseLocalSFTPFile() { 64 | static const char* kFilename1 = "sftp://user:password@127.0.0.1:/tmp/a.h"; 65 | static const char* kFilename2 = "sftp://user:password@localhost:/tmp/a.h"; 66 | 67 | MRMLFS_File::FilenameFields fields; 68 | 69 | MRMLFS_File::ParseFilename(kFilename1, &fields); 70 | EXPECT_EQ(fields.protocol, "file://"); 71 | EXPECT_EQ(fields.path, "/tmp/a.h"); 72 | 73 | MRMLFS_File::ParseFilename(kFilename2, &fields); 74 | EXPECT_EQ(fields.protocol, "file://"); 75 | EXPECT_EQ(fields.path, "/tmp/a.h"); 76 | } 77 | 78 | void MRMLFS_FileTest::CreateAndReadLocalFile() { 79 | static const char* kFilename = "/tmp/testCreateAndReadLocalFile"; 80 | static const char kContent[] = "apple"; 81 | 82 | MRMLFS_File file; 83 | CHECK(file.Open(kFilename, false)); 84 | file.Write(kContent, sizeof(kContent)); 85 | file.Close(); 86 | 87 | CHECK(file.Open(kFilename, true)); 88 | char buffer[sizeof(kContent) + 1]; 89 | EXPECT_EQ(sizeof(kContent), file.Read(buffer, sizeof(kContent))); 90 | EXPECT_EQ(string(kContent), buffer); 91 | file.Close(); 92 | } 93 | 94 | void MRMLFS_FileTest::CreateAndReadSFTPFile() { 95 | CHECK(!FLAGS_remote_path.empty()); 96 | 97 | static const char kContent[] = "apple"; 98 | 99 | MRMLFS_File file; 100 | CHECK(file.Open(FLAGS_remote_path, false)); 101 | file.Write(kContent, sizeof(kContent)); 102 | file.Close(); 103 | 104 | CHECK(file.Open(FLAGS_remote_path, true)); 105 | char buffer[sizeof(kContent) + 1]; 106 | EXPECT_EQ(sizeof(kContent), file.Read(buffer, sizeof(kContent))); 107 | EXPECT_EQ(string(kContent), buffer); 108 | file.Close(); 109 | } 110 | 111 | TEST_F(MRMLFS_FileTest, ParseLocalFile) { 112 | MRMLFS_FileTest::ParseLocalFile(); 113 | } 114 | 115 | TEST_F(MRMLFS_FileTest, ParseSFTPFile) { 116 | MRMLFS_FileTest::ParseSFTPFile(); 117 | } 118 | 119 | TEST_F(MRMLFS_FileTest, ParseLocalSFTPFile) { 120 | MRMLFS_FileTest::ParseLocalSFTPFile(); 121 | } 122 | 123 | TEST_F(MRMLFS_FileTest, CreateAndReadLocalFile) { 124 | MRMLFS_FileTest::CreateAndReadLocalFile(); 125 | } 126 | 127 | TEST_F(MRMLFS_FileTest, CreateAndReadSFTPFile) { 128 | MRMLFS_FileTest::CreateAndReadSFTPFile(); 129 | } 130 | -------------------------------------------------------------------------------- /mrml/mrml_main.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "base/common.h" 5 | #include "mrml/mrml.h" 6 | 7 | extern bool MRML_Initialize(int argc, char** argv); 8 | extern bool MRML_AmIMapWorker(); 9 | extern void MRML_MapWork(); 10 | extern void MRML_ReduceWork(); 11 | extern void MRML_Finalize(); 12 | 13 | //----------------------------------------------------------------------------- 14 | // The pre-defined main function 15 | //----------------------------------------------------------------------------- 16 | 17 | int main(int argc, char** argv) { 18 | if (!MRML_Initialize(argc, argv)) { 19 | return -1; 20 | } 21 | 22 | LOG(INFO) << "I am a " 23 | << (MRML_AmIMapWorker() ? 24 | string("map worker") : string("reduce worker")); 25 | 26 | if (MRML_AmIMapWorker()) { 27 | MRML_MapWork(); 28 | } else { 29 | MRML_ReduceWork(); 30 | } 31 | 32 | MRML_Finalize(); 33 | return 0; 34 | } 35 | -------------------------------------------------------------------------------- /mrml/mrml_reader.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "mrml/mrml_reader.h" 5 | 6 | #include 7 | 8 | #include "base/common.h" 9 | #include "base/logging.h" 10 | #include "mrml/mrml_recordio.h" 11 | #include "strutil/stringprintf.h" 12 | 13 | static void OpenFileOrDie(const std::string& filename, FILE** input_stream) { 14 | *input_stream = fopen(filename.c_str(), "r"); 15 | if (*input_stream == NULL) { 16 | LOG(FATAL) << "Cannot open file: " << filename; 17 | } 18 | } 19 | 20 | //----------------------------------------------------------------------------- 21 | // Implementation of MRML_TextReader 22 | //----------------------------------------------------------------------------- 23 | 24 | MRML_TextReader::MRML_TextReader(const std::string& filename, 25 | int max_line_length) 26 | : max_line_length_(max_line_length), 27 | line_num_(0), 28 | reading_a_long_line_(false), 29 | input_filename_(filename) { 30 | OpenFileOrDie(filename, &input_stream_); 31 | try { 32 | CHECK_LT(1, max_line_length_); // At least 1 for '\0' appended by fgets. 33 | line_ = new char[max_line_length_]; 34 | } catch(std::bad_alloc&) { 35 | line_ = NULL; 36 | LOG(FATAL) << "Cannot allocate line input buffer."; 37 | } 38 | } 39 | 40 | MRML_TextReader::~MRML_TextReader() { 41 | if (input_stream_ != NULL) { 42 | fclose(input_stream_); 43 | input_stream_ = NULL; 44 | } 45 | if (line_ != NULL) { 46 | delete [] line_; 47 | line_ = NULL; 48 | } 49 | } 50 | 51 | bool MRML_TextReader::Read(std::string* key, std::string* value) { 52 | SStringPrintf(key, "%s-%010lld", 53 | input_filename_.c_str(), ftell(input_stream_)); 54 | value->clear(); 55 | 56 | if (fgets(line_, max_line_length_, input_stream_) == NULL) { 57 | return false; // Either ferror or feof. Anyway, returns false to 58 | // notify the caller no further reading operations. 59 | } 60 | 61 | int read_size = strlen(line_); 62 | if (line_[read_size - 1] != '\n') { 63 | LOG(ERROR) << "Encountered a too-long line (line_num = " << line_num_ 64 | << "). May return one or more empty values while skipping " 65 | << " this long line."; 66 | reading_a_long_line_ = true; 67 | return true; // Skip the current part of a long line. 68 | } else { 69 | ++line_num_; 70 | if (reading_a_long_line_) { 71 | reading_a_long_line_ = false; 72 | return true; // Skip the last part of a long line. 73 | } 74 | } 75 | 76 | if (line_[read_size - 1] == '\n') { 77 | line_[read_size - 1] = '\0'; 78 | if (line_[read_size - 2] == '\r') { // Handle DOS text format. 79 | line_[read_size - 2] = '\0'; 80 | } 81 | } 82 | value->assign(line_); 83 | return true; 84 | } 85 | 86 | //----------------------------------------------------------------------------- 87 | // Implementation of MRML_RecordReader 88 | //----------------------------------------------------------------------------- 89 | 90 | MRML_RecordReader::MRML_RecordReader(const std::string& filename) 91 | : input_filename_(filename) { 92 | OpenFileOrDie(filename, &input_stream_); 93 | } 94 | 95 | MRML_RecordReader::~MRML_RecordReader() { 96 | if (input_stream_ != NULL) { 97 | fclose(input_stream_); 98 | input_stream_ = NULL; 99 | } 100 | } 101 | 102 | bool MRML_RecordReader::Read(std::string* key, std::string* value) { 103 | return MRML_ReadRecord(input_stream_, key, value); 104 | } 105 | -------------------------------------------------------------------------------- /mrml/mrml_reader.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Define the interface of Reader and two standard readers: TextReader 5 | // and RecordReader. 6 | // 7 | #ifndef MRML_MRML_READER_H_ 8 | #define MRML_MRML_READER_H_ 9 | 10 | #include 11 | 12 | #include 13 | 14 | // The interface implemented by ``real'' readers. 15 | class MRML_Reader { 16 | public: 17 | virtual ~MRML_Reader() {} 18 | 19 | // Returns false to indicate that the current read failed and no 20 | // further reading operations should be performed. 21 | virtual bool Read(std::string* key, std::string* value) = 0; 22 | 23 | protected: 24 | FILE* input_stream_; 25 | }; 26 | 27 | // Read from a text file, using stdio.h API. 28 | // - The key returned by Read() is "filename-offset", the value 29 | // returned by Read is the content of a line. 30 | // - The value might be empty if it is reading a too long line. 31 | // - The '\r' (if there is any) and '\n' at the end of a line are 32 | // removed. 33 | class MRML_TextReader : public MRML_Reader { 34 | public: 35 | explicit MRML_TextReader(const std::string& filename, 36 | int max_line_length); 37 | virtual ~MRML_TextReader(); 38 | virtual bool Read(std::string* key, std::string* value); 39 | 40 | private: 41 | int max_line_length_; 42 | char* line_; // input line buffer 43 | int line_num_; // count line number 44 | bool reading_a_long_line_; // is reading a lone line 45 | std::string input_filename_; 46 | FILE* input_stream_; 47 | }; 48 | 49 | // Read from a MRML RecordIO file, using MRML_RecordIO API. 50 | class MRML_RecordReader : public MRML_Reader { 51 | public: 52 | explicit MRML_RecordReader(const std::string& filename); 53 | virtual ~MRML_RecordReader(); 54 | virtual bool Read(std::string* key, std::string* value); 55 | 56 | private: 57 | std::string input_filename_; 58 | FILE* input_stream_; 59 | }; 60 | 61 | #endif // MRML_MRML_READER_H_ 62 | -------------------------------------------------------------------------------- /mrml/mrml_recordio.h: -------------------------------------------------------------------------------- 1 | 2 | // 3 | // The interface to accessing MRML RecordIO files. 4 | // 5 | #ifndef MRML_MRML_RECORDIO_H_ 6 | #define MRML_MRML_RECORDIO_H_ 7 | 8 | #include 9 | 10 | #include 11 | 12 | namespace google { 13 | namespace protobuf { 14 | class Message; 15 | } 16 | } 17 | 18 | class MRMLFS_File; 19 | 20 | bool MRML_ReadRecord(FILE* input, 21 | std::string* key, 22 | std::string* value); 23 | bool MRML_ReadRecord(FILE* input, 24 | std::string* key, 25 | ::google::protobuf::Message* value); 26 | bool MRML_ReadRecord(MRMLFS_File* input, 27 | std::string* key, 28 | std::string* value_pb); 29 | bool MRML_ReadRecord(MRMLFS_File* input, 30 | std::string* key, 31 | ::google::protobuf::Message* value_pb); 32 | 33 | bool MRML_WriteRecord(FILE* output, 34 | const std::string& key, 35 | const std::string& value); 36 | bool MRML_WriteRecord(FILE* output, 37 | const std::string& key, 38 | const ::google::protobuf::Message& value); 39 | bool MRML_WriteRecord(MRMLFS_File* output, 40 | const std::string& key, 41 | const std::string& value); 42 | bool MRML_WriteRecord(MRMLFS_File* output, 43 | const std::string& key, 44 | const ::google::protobuf::Message& value); 45 | 46 | #endif // MRML_MRML_RECORDIO_H_ 47 | -------------------------------------------------------------------------------- /mrml/mrml_recordio_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | #include 6 | 7 | #include "gtest/gtest.h" 8 | 9 | #include "base/common.h" 10 | #include "mrml/mrml_filesystem.h" 11 | #include "mrml/mrml_recordio.h" 12 | #include "mrml/mrml.pb.h" 13 | 14 | using std::string; 15 | 16 | DEFINE_string(username, "", "Username for test IO of a remote SFTP file."); 17 | DEFINE_string(password, "", "Password for test IO of a remote SFTP file."); 18 | DEFINE_string(hostname, "", "Hostname for test IO of a remote SFTP file."); 19 | DEFINE_string(path, "", "Path for test IO of a remote SFTP file."); 20 | 21 | static const char* kTestKey = "a key"; 22 | static const char* kTestValue = "a value"; 23 | 24 | static void CheckWriteReadConsistency(const string& filename) { 25 | KeyValuePair pair; 26 | pair.set_key(kTestKey); 27 | pair.set_value(kTestValue); 28 | 29 | MRMLFS_File file(filename, false); 30 | CHECK(file.IsOpen()); 31 | MRML_WriteRecord(&file, kTestKey, kTestValue); 32 | MRML_WriteRecord(&file, kTestKey, pair); 33 | file.Close(); 34 | 35 | CHECK(file.Open(filename, true)); 36 | string key, value; 37 | CHECK(MRML_ReadRecord(&file, &key, &value)); 38 | EXPECT_EQ(key, kTestKey); 39 | EXPECT_EQ(value, kTestValue); 40 | 41 | pair.Clear(); 42 | CHECK(MRML_ReadRecord(&file, &key, &pair)); 43 | EXPECT_EQ(pair.key(), kTestKey); 44 | EXPECT_EQ(pair.value(), kTestValue); 45 | 46 | EXPECT_TRUE(!MRML_ReadRecord(&file, &key, &value)); 47 | } 48 | 49 | TEST(MRMLRecordIOTest, LocalRecordIO) { 50 | static const char* kFilename = "/tmp/testLocalRecordIO"; 51 | CheckWriteReadConsistency(kFilename); 52 | } 53 | 54 | TEST(MRMLRecordIOTest, SFTPRecordIO) { 55 | CHECK(!FLAGS_username.empty()); 56 | CHECK(!FLAGS_password.empty()); 57 | CHECK(!FLAGS_hostname.empty()); 58 | CHECK(!FLAGS_path.empty()); 59 | string filename = "sftp://" + FLAGS_username + ":" + FLAGS_password + "@" + 60 | FLAGS_hostname + ":" + FLAGS_path; 61 | CheckWriteReadConsistency(filename); 62 | } 63 | -------------------------------------------------------------------------------- /mrml/testdata/input-00000-of-00002: -------------------------------------------------------------------------------- 1 | apple banana orange 2 | 3 | banana orange apple 4 | -------------------------------------------------------------------------------- /mrml/testdata/input-00001-of-00002: -------------------------------------------------------------------------------- 1 | apple orange 2 | orange banana 3 | banana apple 4 | -------------------------------------------------------------------------------- /sorted_buffer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build library strutil. 2 | add_library(sorted_buffer memory_allocator.cc memory_piece.cc sorted_buffer.cc sorted_buffer_iterator.cc) 3 | 4 | # Build unittests. 5 | set(LIBS sorted_buffer strutil base protobuf boost_program_options boost_regex boost_filesystem boost_system gtest pthread) 6 | 7 | add_executable(memory_allocator_test memory_allocator_test.cc) 8 | target_link_libraries(memory_allocator_test gtest_main ${LIBS}) 9 | 10 | add_executable(memory_piece_less_than_test memory_piece_less_than_test.cc) 11 | target_link_libraries(memory_piece_less_than_test gtest_main ${LIBS}) 12 | 13 | add_executable(sorted_buffer_test sorted_buffer_test.cc) 14 | target_link_libraries(sorted_buffer_test gtest_main ${LIBS}) 15 | 16 | add_executable(memory_piece_io_test memory_piece_io_test.cc) 17 | target_link_libraries(memory_piece_io_test gtest_main ${LIBS}) 18 | 19 | add_executable(memory_piece_test memory_piece_test.cc) 20 | target_link_libraries(memory_piece_test gtest_main ${LIBS}) 21 | 22 | add_executable(sorted_buffer_iterator_test sorted_buffer_iterator_test.cc) 23 | target_link_libraries(sorted_buffer_iterator_test gtest_main ${LIBS}) 24 | 25 | add_executable(sorted_buffer_regression_test sorted_buffer_regression_test.cc) 26 | target_link_libraries(sorted_buffer_regression_test gtest_main ${LIBS}) 27 | 28 | # Install library and header files 29 | install(TARGETS sorted_buffer DESTINATION bin/sorted_buffer) 30 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 31 | install(FILES ${HEADER_FILES} DESTINATION include/sorted_buffer) 32 | -------------------------------------------------------------------------------- /sorted_buffer/memory_allocator.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "sorted_buffer/memory_allocator.h" 5 | 6 | #include "base/common.h" 7 | #include "sorted_buffer/memory_piece.h" 8 | 9 | namespace sorted_buffer { 10 | 11 | //----------------------------------------------------------------------------- 12 | // Implementation of NaiveMemoryAllocator 13 | //----------------------------------------------------------------------------- 14 | 15 | NaiveMemoryAllocator::NaiveMemoryAllocator( 16 | const int pool_size) 17 | : pool_size_(pool_size), 18 | allocated_size_(0) { 19 | CHECK_LT(0, pool_size); 20 | try { 21 | pool_ = new char[pool_size]; 22 | } catch(std::bad_alloc&) { 23 | pool_ = NULL; 24 | LOG(FATAL) << "Insufficient memory to initialize NaiveMemoryAlloctor with " 25 | << "pool size = " << pool_size; 26 | } 27 | } 28 | 29 | NaiveMemoryAllocator::~NaiveMemoryAllocator() { 30 | if (pool_ != NULL) { 31 | delete [] pool_; 32 | } 33 | pool_ = NULL; 34 | pool_size_ = 0; 35 | allocated_size_ = 0; 36 | } 37 | 38 | bool NaiveMemoryAllocator::Allocate(PieceSize size, 39 | MemoryPiece* piece) { 40 | CHECK(IsInitialized()); 41 | if (Have(size)) { 42 | piece->Set(pool_ + allocated_size_, size); 43 | allocated_size_ += size + sizeof(PieceSize); 44 | return true; 45 | } 46 | piece->Clear(); 47 | return false; 48 | } 49 | 50 | bool NaiveMemoryAllocator::Have(PieceSize size) const { 51 | return size + sizeof(PieceSize) + allocated_size_ <= pool_size_; 52 | } 53 | 54 | bool NaiveMemoryAllocator::Have(PieceSize key_length, PieceSize value_length) { 55 | return allocated_size_ + key_length + value_length + 2 * sizeof(PieceSize) 56 | <= pool_size_; 57 | } 58 | 59 | void NaiveMemoryAllocator::Reset() { 60 | allocated_size_ = 0; 61 | } 62 | 63 | std::ostream& operator<< (std::ostream& output, const MemoryPiece& p) { 64 | output << "(" << p.Size() << ") "; 65 | if (p.IsSet()) { 66 | for (int i = 0; i < p.Size(); ++i) { 67 | output << p.Data()[i]; 68 | } 69 | } else { 70 | output << "[not set]"; 71 | } 72 | return output; 73 | } 74 | 75 | } // namespace sorted_buffer 76 | -------------------------------------------------------------------------------- /sorted_buffer/memory_allocator.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // NaiveMemoryAllocator implements a simple memroy allocator, which 5 | // allocates a big memory block during initialization, and 6 | // successively allocates variable-length pieces in the pool to 7 | // applicaitons. It does not support ``free'' nor ``reallocate''; but 8 | // supports ``reset'', which reclaims all allocated pieces for a new 9 | // round of allocations. This allocator is designed for using in 10 | // InMemoryBuffer. 11 | // 12 | // NOTE: the max size of each piece is 4G, so the size of each piece 13 | // can be represented by 4 bytes. 14 | // 15 | #ifndef SORTED_BUFFER_MEMORY_ALLOCATOR_H_ 16 | #define SORTED_BUFFER_MEMORY_ALLOCATOR_H_ 17 | 18 | #include "sorted_buffer/memory_piece.h" 19 | 20 | namespace sorted_buffer { 21 | 22 | class MemoryPiece; 23 | 24 | class NaiveMemoryAllocator { 25 | public: 26 | explicit NaiveMemoryAllocator(const int pool_size); 27 | ~NaiveMemoryAllocator(); 28 | 29 | // Returns false for insufficiency memory. 30 | bool Allocate(PieceSize size, MemoryPiece* piece); 31 | // Check if there is sufficient memory to hold a string. 32 | bool Have(PieceSize length) const; 33 | // Check if there is sufficient memory to hold two strings. 34 | bool Have(PieceSize key_length, PieceSize value_length); 35 | // Reclaims all allocated blocks for the next round of allocations. 36 | void Reset(); 37 | 38 | const char* Pool() { return pool_; } // For test only. 39 | size_t PoolSize() const { return pool_size_; } 40 | size_t AllocatedSize() const { return allocated_size_; } 41 | bool IsInitialized() const { return pool_ != NULL; } 42 | 43 | private: 44 | char* pool_; 45 | size_t pool_size_; 46 | size_t allocated_size_; 47 | 48 | DISALLOW_COPY_AND_ASSIGN(NaiveMemoryAllocator); 49 | }; 50 | 51 | } // namespace sorted_buffer 52 | 53 | #endif // SORTED_BUFFER_MEMORY_ALLOCATOR_H_ 54 | -------------------------------------------------------------------------------- /sorted_buffer/memory_allocator_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "sorted_buffer/memory_allocator.h" 5 | 6 | #include "gtest/gtest.h" 7 | 8 | namespace sorted_buffer { 9 | 10 | class NaiveMemoryAllocatorTest : public ::testing::Test {}; 11 | 12 | TEST_F(NaiveMemoryAllocatorTest, NaiveMemoryAllocator) { 13 | NaiveMemoryAllocator a(100); 14 | CHECK(a.IsInitialized()); 15 | CHECK_EQ(a.PoolSize(), 100); 16 | CHECK_EQ(a.AllocatedSize(), 0); 17 | 18 | MemoryPiece p; 19 | CHECK(a.Allocate(50, &p)); 20 | CHECK_EQ(p.Piece(), a.Pool()); 21 | CHECK_EQ(p.Data(), static_cast(a.Pool()) + sizeof(PieceSize)); 22 | CHECK_EQ(p.Size(), 50); 23 | CHECK_EQ(a.AllocatedSize(), 50 + sizeof(PieceSize)); 24 | CHECK_EQ(a.PoolSize(), 100); // not changed due to allocation 25 | 26 | CHECK(!a.Allocate(50, &p)); 27 | CHECK(p.Piece() == NULL); 28 | CHECK(p.Data() == NULL); 29 | CHECK_EQ(p.Size(), 0); 30 | CHECK_EQ(a.AllocatedSize(), 50 + sizeof(PieceSize)); 31 | CHECK_EQ(a.PoolSize(), 100); // not changed due to allocation 32 | 33 | CHECK(a.Allocate(50 - 2 * sizeof(PieceSize), &p)); 34 | CHECK_EQ(p.Piece(), a.Pool() + 50 + sizeof(PieceSize)); 35 | CHECK_EQ(p.Data(), a.Pool() + 50 + 2 * sizeof(PieceSize)); 36 | CHECK_EQ(p.Size(), 50 - 2 * sizeof(PieceSize)); 37 | CHECK_EQ(a.AllocatedSize(), a.PoolSize()); 38 | CHECK_EQ(a.PoolSize(), 100); // not changed due to allocation 39 | } 40 | 41 | } // namespace sorted_buffer 42 | -------------------------------------------------------------------------------- /sorted_buffer/memory_piece.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include "sorted_buffer/memory_piece.h" 7 | 8 | #include "base/common.h" 9 | #include "base/varint32.h" 10 | 11 | namespace sorted_buffer { 12 | 13 | bool MemoryPieceLessThan::operator() (const MemoryPiece& x, 14 | const MemoryPiece& y) const { 15 | typedef unsigned char byte; 16 | const byte* xdata = reinterpret_cast(x.Data()); 17 | const byte* ydata = reinterpret_cast(y.Data()); 18 | for (int i = 0; i < std::min(x.Size(), y.Size()); ++i) { 19 | if (xdata[i] < ydata[i]) { 20 | return true; 21 | } else if (xdata[i] > ydata[i]) { 22 | return false; 23 | } 24 | } 25 | return x.Size() < y.Size(); 26 | } 27 | 28 | bool MemoryPieceEqual(const MemoryPiece& x, const MemoryPiece& y) { 29 | typedef unsigned char byte; 30 | const byte* xdata = reinterpret_cast(x.Data()); 31 | const byte* ydata = reinterpret_cast(y.Data()); 32 | for (int i = 0; i < std::min(x.Size(), y.Size()); ++i) { 33 | if (xdata[i] != ydata[i]) { 34 | return false; 35 | } 36 | } 37 | return x.Size() == y.Size(); 38 | } 39 | 40 | bool WriteMemoryPiece(FILE* output, const MemoryPiece& piece) { 41 | CHECK(piece.IsSet()); 42 | return WriteVarint32(output, piece.Size()) && 43 | ((piece.Size() > 0) ? 44 | (fwrite(piece.Data(), 1, piece.Size(), output) == piece.Size()) : 45 | true); 46 | } 47 | 48 | bool ReadMemoryPiece(FILE* input, std::string* piece) { 49 | PieceSize size; 50 | if (!ReadVarint32(input, &size)) { 51 | return false; 52 | } 53 | // TODO(rickjin): make this re-entrant 54 | static const int kMaxBufferSize = 32 * 1024 * 1024; 55 | static char buffer[kMaxBufferSize]; 56 | if (size >= kMaxBufferSize) { 57 | LOG(FATAL) << "The size of string exceeds kMaxBufferSize " 58 | << kMaxBufferSize; 59 | } 60 | piece->resize(size); 61 | if (size > 0) { 62 | if (fread(buffer, 1, size, input) < size) { 63 | return false; 64 | } 65 | piece->assign(buffer, size); 66 | } 67 | return true; 68 | } 69 | 70 | } // namespace sorted_buffer 71 | 72 | 73 | -------------------------------------------------------------------------------- /sorted_buffer/memory_piece.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // MemoryPiece encapsulate either a pointer to a std::string, or a 5 | // pointer to a memory block with size = sizeof(PieceSize) + 6 | // block_size, where the first sizeof(PieceSize) bytes saves the value 7 | // of block_size. The main purpose of MemoryPiece is to save string 8 | // in raw memory allocated from user-defined memory pool. The reason 9 | // of encapsulating std::string is to make it possible to compare a 10 | // std::string with a MemoryPiece that encapsulate a memory-block. 11 | // 12 | // MemoryPieceLessThan is a binary comparator for sorting MemoryPieces 13 | // in lexical order. 14 | // 15 | // ReadMemoryPiece and WriteMemoryPiece supports (local) file IO of 16 | // MemoryPieces. 17 | // 18 | #ifndef SORTED_BUFFER_MEMORY_PIECE_H_ 19 | #define SORTED_BUFFER_MEMORY_PIECE_H_ 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | #include "base/common.h" 26 | 27 | namespace sorted_buffer { 28 | 29 | typedef uint32 PieceSize; 30 | 31 | 32 | // Represent either a piece of memory, which is prepended by a 33 | // PieceSize, or a std::string object. 34 | class MemoryPiece { 35 | friend std::ostream& operator<< (std::ostream&, const MemoryPiece& p); 36 | 37 | public: 38 | MemoryPiece() : piece_(NULL), string_(NULL) {} 39 | MemoryPiece(char* piece, PieceSize size) { Set(piece, size); } 40 | explicit MemoryPiece(std::string* string) { Set(string); } 41 | 42 | void Set(char* piece, PieceSize size) { 43 | CHECK_LE(0, size); 44 | CHECK_NOTNULL(piece); 45 | // TODO(charlieyan): fix bug here 46 | piece_ = piece; 47 | *reinterpret_cast(piece_) = size; 48 | string_ = NULL; 49 | } 50 | 51 | void Set(std::string* string) { 52 | CHECK_NOTNULL(string); 53 | string_ = string; 54 | piece_ = NULL; 55 | } 56 | 57 | void Clear() { 58 | piece_ = NULL; 59 | string_ = NULL; 60 | } 61 | 62 | bool IsSet() const { return IsString() || IsPiece(); } 63 | bool IsString() const { return string_ != NULL; } 64 | bool IsPiece() const { return piece_ != NULL; } 65 | 66 | const char* Piece() const { return piece_; } 67 | 68 | char* Data() { 69 | return IsPiece() ? piece_ + sizeof(PieceSize) : 70 | (IsString() ? const_cast(string_->data()) : NULL); 71 | } 72 | 73 | const char* Data() const { 74 | return IsPiece() ? piece_ + sizeof(PieceSize) : 75 | (IsString() ? string_->data() : NULL); 76 | } 77 | 78 | size_t Size() const { 79 | return IsPiece() ? *reinterpret_cast(piece_) : 80 | (IsString() ? string_->size() : 0); 81 | } 82 | 83 | private: 84 | char* piece_; 85 | std::string* string_; 86 | }; 87 | 88 | 89 | // Compare two MemoryPiece objects in lexical order. 90 | struct MemoryPieceLessThan : public std::binary_function { 93 | bool operator() (const MemoryPiece& x, const MemoryPiece& y) const; 94 | }; 95 | 96 | bool MemoryPieceEqual(const MemoryPiece& x, const MemoryPiece& y); 97 | 98 | bool WriteMemoryPiece(FILE* output, const MemoryPiece& piece); 99 | bool ReadMemoryPiece(FILE* input, std::string* piece); 100 | 101 | } // namespace sorted_buffer 102 | 103 | #endif // SORTED_BUFFER_MEMORY_PIECE_H_ 104 | -------------------------------------------------------------------------------- /sorted_buffer/memory_piece_io_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "gtest/gtest.h" 10 | 11 | #include "sorted_buffer/memory_piece.h" 12 | 13 | namespace sorted_buffer { 14 | 15 | TEST(MemoryPieceIOTest, MemoryPieceIO) { 16 | static const char* kTmpFile = "/tmp/testMemoryPieceIO"; 17 | 18 | std::string s("apple"); 19 | std::string empty(""); 20 | char buffer[] = "1234orange"; 21 | char buffer2[sizeof(PieceSize)]; 22 | 23 | MemoryPiece p0(&empty); 24 | MemoryPiece p1(&s); 25 | MemoryPiece p2(buffer2, 0); 26 | MemoryPiece p3(buffer, strlen("orange")); 27 | 28 | FILE* output = fopen(kTmpFile, "w+"); 29 | CHECK(output != NULL); 30 | EXPECT_TRUE(WriteMemoryPiece(output, p0)); 31 | EXPECT_TRUE(WriteMemoryPiece(output, p1)); 32 | EXPECT_TRUE(WriteMemoryPiece(output, p2)); 33 | EXPECT_TRUE(WriteMemoryPiece(output, p3)); 34 | fclose(output); 35 | 36 | FILE* input = fopen(kTmpFile, "r"); 37 | CHECK(input != NULL); 38 | std::string p; 39 | EXPECT_TRUE(ReadMemoryPiece(input, &p)); 40 | EXPECT_TRUE(p.empty()); 41 | EXPECT_TRUE(ReadMemoryPiece(input, &p)); 42 | EXPECT_EQ(p, "apple"); 43 | EXPECT_TRUE(ReadMemoryPiece(input, &p)); 44 | EXPECT_TRUE(p.empty()); 45 | EXPECT_TRUE(ReadMemoryPiece(input, &p)); 46 | EXPECT_EQ(p, "orange"); 47 | EXPECT_TRUE(!ReadMemoryPiece(input, &p)); 48 | fclose(input); 49 | } 50 | 51 | TEST(MemoryPieceIOTest, ReadFromEmptyFile) { 52 | static const char* kTmpFile = "/tmp/testReadFromEmptyFile"; 53 | 54 | // Create an empty file. 55 | FILE* output = fopen(kTmpFile, "w+"); 56 | CHECK(output != NULL); 57 | fclose(output); 58 | 59 | FILE* input = fopen(kTmpFile, "r"); 60 | CHECK(input != NULL); 61 | std::string p; 62 | EXPECT_TRUE(!ReadMemoryPiece(input, &p)); 63 | fclose(input); 64 | } 65 | 66 | } // namespace sorted_buffer 67 | -------------------------------------------------------------------------------- /sorted_buffer/memory_piece_less_than_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include 7 | 8 | #include "gtest/gtest.h" 9 | 10 | #include "sorted_buffer/sorted_buffer.h" 11 | 12 | namespace sorted_buffer { 13 | 14 | TEST(MemoryPieceLessThanTest, LessThan) { 15 | MemoryPieceLessThan lt; 16 | 17 | char buffer1[] = "1234apple"; 18 | MemoryPiece p1(buffer1, strlen(buffer1) - sizeof(PieceSize)); 19 | 20 | char buffer2[] = "1234applee"; 21 | MemoryPiece p2(buffer2, strlen(buffer2) - sizeof(PieceSize)); 22 | 23 | EXPECT_TRUE(lt(p1, p2)); 24 | EXPECT_TRUE(!lt(p2, p1)); 25 | } 26 | 27 | TEST(MemoryPieceLessThanTest, LessThanNULLPiece) { 28 | MemoryPieceLessThan lt; 29 | 30 | char buffer1[sizeof(PieceSize)]; 31 | MemoryPiece p1(buffer1, 0); 32 | 33 | char buffer2[sizeof(PieceSize)]; 34 | MemoryPiece p2(buffer2, 0); 35 | 36 | char buffer3[] = "1234applee"; 37 | MemoryPiece p3(buffer3, strlen(buffer3) - sizeof(PieceSize)); 38 | 39 | EXPECT_TRUE(!lt(p1, p2)); 40 | EXPECT_TRUE(!lt(p2, p1)); 41 | EXPECT_TRUE(lt(p1, p3)); 42 | EXPECT_TRUE(lt(p2, p3)); 43 | EXPECT_TRUE(!lt(p3, p3)); 44 | } 45 | 46 | TEST(MemoryPieceLessThanTest, LessThanInSTLContainer) { 47 | typedef std::map MemoryPieceMap; 50 | MemoryPieceMap m; 51 | 52 | char buffer1[] = "1234apple"; 53 | MemoryPiece p(buffer1, strlen(buffer1) - sizeof(PieceSize)); 54 | 55 | m[p] = 2; 56 | CHECK(m.find(p) != m.end()); 57 | CHECK_EQ(m.find(p)->second, 2); 58 | 59 | char buffer2[] = "1234apple"; 60 | p.Set(buffer2, strlen(buffer2) - sizeof(PieceSize)); 61 | CHECK(m.find(p) != m.end()); 62 | CHECK_EQ(m.find(p)->second, 2); 63 | 64 | std::string s("apple"); 65 | p.Set(&s); 66 | CHECK(m.find(p) != m.end()); 67 | CHECK_EQ(m.find(p)->second, 2); 68 | 69 | MemoryPiece p1(&s); 70 | CHECK(m.find(p1) != m.end()); 71 | CHECK_EQ(m.find(p1)->second, 2); 72 | } 73 | 74 | } // namespace sorted_buffer 75 | -------------------------------------------------------------------------------- /sorted_buffer/memory_piece_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | 6 | #include "sorted_buffer/memory_piece.h" 7 | 8 | using sorted_buffer::MemoryPiece; 9 | using sorted_buffer::PieceSize; 10 | 11 | TEST(MemoryPieceTest, SetMemoryPiece) { 12 | MemoryPiece p; 13 | CHECK(!p.IsSet()); 14 | 15 | char buffer[1024]; 16 | p.Set(buffer, 100); 17 | CHECK(p.IsSet()); 18 | CHECK_EQ(p.Piece(), buffer); 19 | CHECK_EQ(p.Size(), 100); 20 | CHECK_EQ(p.Data(), buffer + sizeof(PieceSize)); 21 | } 22 | 23 | TEST(MemoryPieceTest, ConstructNULLPiece) { 24 | char buffer[sizeof(PieceSize)]; 25 | MemoryPiece p(buffer, 0); 26 | CHECK(p.IsSet()); 27 | CHECK_EQ(p.Piece(), buffer); 28 | CHECK_EQ(p.Size(), 0); 29 | CHECK_EQ(p.Data(), buffer + sizeof(PieceSize)); 30 | } 31 | 32 | TEST(MemoryPieceTest, ConstructMemoryPiece) { 33 | char buffer[1024]; 34 | MemoryPiece p(buffer, 100); 35 | CHECK(p.IsSet()); 36 | CHECK_EQ(p.Piece(), buffer); 37 | CHECK_EQ(p.Size(), 100); 38 | CHECK_EQ(p.Data(), buffer + sizeof(PieceSize)); 39 | } 40 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "sorted_buffer/sorted_buffer.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "base/common.h" 11 | #include "base/varint32.h" 12 | #include "strutil/stringprintf.h" 13 | #include "sorted_buffer/sorted_buffer_iterator.h" 14 | 15 | namespace sorted_buffer { 16 | 17 | /*static*/ 18 | std::string SortedBuffer::SortedFilename(const std::string filebase, 19 | int index) { 20 | return StringPrintf("%s-%010d", filebase.c_str(), index); 21 | } 22 | 23 | SortedBuffer::SortedBuffer(const std::string& filebase, 24 | int in_memory_buffer_size) 25 | : filebase_(filebase), 26 | allocator_(new NaiveMemoryAllocator(in_memory_buffer_size)), 27 | count_files_(0) { 28 | CHECK(allocator_->IsInitialized()); // Ensure the memory pool is allocated. 29 | } 30 | 31 | SortedBuffer::~SortedBuffer() { 32 | Flush(); 33 | } 34 | 35 | void SortedBuffer::Insert(const std::string& key, 36 | const std::string& value) { 37 | CHECK_LE(0, key.size()); 38 | CHECK_LE(0, value.size()); 39 | 40 | if (!allocator_->Have(key.size(), value.size())) { 41 | Flush(); 42 | if (!allocator_->Have(key.size(), value.size())) { 43 | LOG(FATAL) << "The memory pool has insufficient space to hold incoming " 44 | << "key-value pair: " << key << " : " << value; 45 | } 46 | } 47 | MemoryPiece key_piece; 48 | CHECK(allocator_->Allocate(key.size(), &key_piece)); 49 | memcpy(key_piece.Data(), key.data(), key.size()); 50 | 51 | MemoryPiece value_piece; 52 | CHECK(allocator_->Allocate(value.size(), &value_piece)); 53 | memcpy(value_piece.Data(), value.data(), value.size()); 54 | 55 | key_value_list_.push_back(KeyValuePair(key_piece, value_piece)); 56 | } 57 | 58 | bool SortedBuffer::KeyValuePairLessThan(const KeyValuePair& x, 59 | const KeyValuePair& y) { 60 | return MemoryPieceLessThan()(x.key, y.key); 61 | } 62 | 63 | bool SortedBuffer::KeyValuePairEqual(const KeyValuePair& x, 64 | const KeyValuePair& y) { 65 | return MemoryPieceEqual(x.key, y.key); 66 | } 67 | 68 | void SortedBuffer::Flush() { 69 | if (!allocator_->IsInitialized() || allocator_->AllocatedSize() == 0) 70 | return; 71 | 72 | FILE* output = fopen(SortedFilename(filebase_, count_files_).c_str(), "w+"); 73 | if (output == NULL) { 74 | LOG(FATAL) << "Cannot open disk swap file: " 75 | << SortedFilename(filebase_, count_files_); 76 | } 77 | 78 | ++count_files_; 79 | 80 | std::sort(key_value_list_.begin(), key_value_list_.end(), 81 | KeyValuePairLessThan); 82 | 83 | uint32 current_index = 0; 84 | while (current_index < key_value_list_.size()) { 85 | uint32 next_index = current_index + 1; 86 | while (next_index < key_value_list_.size() && 87 | KeyValuePairEqual(key_value_list_[current_index], 88 | key_value_list_[next_index])) { 89 | ++next_index; 90 | } 91 | 92 | WriteMemoryPiece(output, key_value_list_[current_index].key); // key 93 | CHECK_LT(next_index - current_index, kInt32Max); 94 | WriteVarint32(output, next_index - current_index); 95 | while (current_index < next_index) { // values 96 | WriteMemoryPiece(output, key_value_list_[current_index].value); 97 | ++current_index; 98 | } 99 | } 100 | 101 | fclose(output); 102 | key_value_list_.clear(); 103 | allocator_->Reset(); 104 | } 105 | 106 | SortedBufferIterator* SortedBuffer::CreateIterator() const { 107 | if (allocator_->AllocatedSize() > 0) { 108 | LOG(FATAL) << "You must invoke Flush before CreateIterator."; 109 | } 110 | return new SortedBufferIteratorImpl(filebase_, count_files_); 111 | } 112 | 113 | void SortedBuffer::RemoveBufferFiles() const { 114 | if (allocator_->AllocatedSize() > 0) { 115 | LOG(FATAL) << "You must invoke Flush before RemoveBufferFiles."; 116 | } 117 | for (int i = 0; i < count_files_; ++i) { 118 | std::string filename = SortedFilename(filebase_, i); 119 | LOG(INFO) << "Removing : " << filename; 120 | if (remove(filename.c_str()) < 0) { 121 | LOG(ERROR) << "Cannot remove file: " << filename; 122 | } 123 | } 124 | } 125 | 126 | } // namespace sorted_buffer 127 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #ifndef SORTED_BUFFER_SORTED_BUFFER_H_ 5 | #define SORTED_BUFFER_SORTED_BUFFER_H_ 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "boost/scoped_ptr.hpp" 13 | 14 | #include "base/common.h" 15 | #include "sorted_buffer/memory_piece.h" 16 | #include "sorted_buffer/memory_allocator.h" 17 | 18 | namespace sorted_buffer { 19 | 20 | class SortedBufferIterator; 21 | 22 | // To buffer a massive set of map outputs (key-value pairs) sorted by 23 | // key. Once the buffer is close to full, the content is output into 24 | // a disk file and the buffer is cleared. This ensures that key-value 25 | // pairs in each file are sorted. This gives SortedBufferIterator the 26 | // chance to traverse all files for sorted map outputs. 27 | class SortedBuffer { 28 | public: 29 | SortedBuffer(const std::string& disk_file_base, 30 | int in_memory_buffer_size); 31 | ~SortedBuffer(); 32 | 33 | void Insert(const std::string& key, const std::string& value); 34 | 35 | void Flush(); 36 | 37 | // The caller is responsible to delete the iterator. 38 | SortedBufferIterator* CreateIterator() const; 39 | 40 | // Remove buffer files generated by Flush(). 41 | void RemoveBufferFiles() const; 42 | 43 | static std::string SortedFilename(const std::string filebase, int index); 44 | 45 | NaiveMemoryAllocator* Allocator() { return allocator_.get(); } 46 | int NumFiles() { return count_files_; } 47 | 48 | private: 49 | struct KeyValuePair { 50 | MemoryPiece key; 51 | MemoryPiece value; 52 | KeyValuePair(const MemoryPiece& k, 53 | const MemoryPiece& v) 54 | : key(k), value(v) {} 55 | }; 56 | typedef std::vector KeyValueList; 57 | 58 | static bool KeyValuePairLessThan(const KeyValuePair& x, 59 | const KeyValuePair& y); 60 | static bool KeyValuePairEqual(const KeyValuePair& x, 61 | const KeyValuePair& y); 62 | 63 | KeyValueList key_value_list_; 64 | std::string filebase_; 65 | boost::scoped_ptr allocator_; 66 | int count_files_; 67 | 68 | DISALLOW_COPY_AND_ASSIGN(SortedBuffer); 69 | }; 70 | 71 | } // namespace sorted_buffer 72 | 73 | #endif // SORTED_BUFFER_SORTED_BUFFER_H_ 74 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer_iterator.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "sorted_buffer/sorted_buffer_iterator.h" 5 | 6 | #include "base/varint32.h" 7 | #include "sorted_buffer/memory_piece.h" 8 | #include "sorted_buffer/sorted_buffer.h" 9 | 10 | namespace sorted_buffer { 11 | 12 | SortedBufferIteratorImpl::SortedBufferIteratorImpl(const std::string& filebase, 13 | int num_files) { 14 | Initialize(filebase, num_files); 15 | } 16 | 17 | SortedBufferIteratorImpl::~SortedBufferIteratorImpl() { 18 | Clear(); 19 | } 20 | 21 | void SortedBufferIteratorImpl::Initialize(const std::string& filebase, 22 | int num_files) { 23 | CHECK_LE(0, num_files); 24 | filebase_ = filebase; 25 | 26 | for (int i = 0; i < num_files; ++i) { 27 | SortedStringFile* file = new SortedStringFile; 28 | files_.push_back(file); 29 | 30 | file->index = i; 31 | file->input = 32 | fopen(SortedBuffer::SortedFilename(filebase, i).c_str(), "r"); 33 | if (file->input == NULL) { 34 | LOG(FATAL) << "Cannot open file: " 35 | << SortedBuffer::SortedFilename(filebase, i); 36 | } 37 | CHECK(LoadKey(file)); 38 | CHECK(LoadValue(file)); 39 | } 40 | 41 | RelocateMergeSource(); 42 | } 43 | 44 | const std::string& SortedBufferIteratorImpl::key() const { 45 | return current_key_; 46 | } 47 | 48 | const std::string& SortedBufferIteratorImpl::value() const { 49 | return merge_source_->top_value; 50 | } 51 | 52 | void SortedBufferIteratorImpl::Next() { 53 | if (!LoadValue(merge_source_)) { 54 | SortedStringFile* equal = FindNextMergeSourceWithEqualKey(); 55 | if (equal != NULL) { 56 | if (LoadKey(merge_source_)) { 57 | LoadValue(merge_source_); 58 | } 59 | merge_source_ = equal; 60 | } else { 61 | --(merge_source_->num_rest_values); 62 | } 63 | } 64 | } 65 | 66 | void SortedBufferIteratorImpl::DiscardRestValues() { 67 | while (merge_source_->num_rest_values >= 0) { 68 | Next(); 69 | } 70 | } 71 | 72 | bool SortedBufferIteratorImpl::Done() const { 73 | return merge_source_->num_rest_values < 0; 74 | } 75 | 76 | void SortedBufferIteratorImpl::NextKey() { 77 | DiscardRestValues(); 78 | CHECK(Done()); 79 | if (LoadKey(merge_source_)) { 80 | LoadValue(merge_source_); 81 | } 82 | RelocateMergeSource(); 83 | } 84 | 85 | bool SortedBufferIteratorImpl::FinishedAll() const { 86 | return merge_source_ == NULL; 87 | } 88 | 89 | bool SortedBufferIteratorImpl::LoadValue(SortedStringFile* file) { 90 | if (file->num_rest_values > 0) { 91 | --(file->num_rest_values); 92 | if (!ReadMemoryPiece(file->input, &(file->top_value))) { 93 | LOG(FATAL) << "Error loading value for " 94 | << "key = " << file->top_key << " file = " 95 | << SortedBuffer::SortedFilename(filebase_, file->index); 96 | } 97 | return true; 98 | } 99 | return false; 100 | } 101 | 102 | bool SortedBufferIteratorImpl::LoadKey(SortedStringFile* file) { 103 | if (!ReadMemoryPiece(file->input, &(file->top_key))) { 104 | --(file->num_rest_values); // Negative value means "end-of-sorted_buffer". 105 | return false; 106 | } 107 | if (!ReadVarint32(file->input, 108 | reinterpret_cast(&(file->num_rest_values)))) { 109 | LOG(FATAL) << "Error load num_rest_values from: " 110 | << SortedBuffer::SortedFilename(filebase_, file->index); 111 | } 112 | if (file->num_rest_values <= 0) { 113 | LOG(FATAL) << "Zero num_rest_values loaded from " 114 | << SortedBuffer::SortedFilename(filebase_, file->index); 115 | } 116 | return true; 117 | } 118 | 119 | void SortedBufferIteratorImpl::RelocateMergeSource() { 120 | // If two files have the same top_key, the one with smaller index 121 | // value is located. 122 | merge_source_ = NULL; 123 | for (SSFileList::iterator i = files_.begin(); i != files_.end(); ++i) { 124 | if ((merge_source_ == NULL || (*i)->top_key < merge_source_->top_key) && 125 | (*i)->num_rest_values >= 0) { 126 | merge_source_ = *i; 127 | } 128 | } 129 | if (merge_source_ != NULL) { 130 | current_key_ = merge_source_->top_key; 131 | } 132 | } 133 | 134 | SortedBufferIteratorImpl::SortedStringFile* 135 | SortedBufferIteratorImpl::FindNextMergeSourceWithEqualKey() { 136 | for (SSFileList::iterator i = files_.begin(); i != files_.end(); ++i) { 137 | if (*i != merge_source_ && 138 | (*i)->num_rest_values >= 0 && 139 | (*i)->top_key == merge_source_->top_key) { 140 | return *i; 141 | } 142 | } 143 | return NULL; 144 | } 145 | 146 | void SortedBufferIteratorImpl::Clear() { 147 | for (SSFileList::iterator i = files_.begin(); i != files_.end(); ++i) { 148 | fclose((*i)->input); 149 | delete *i; 150 | } 151 | files_.clear(); 152 | } 153 | 154 | } // namespace sorted_buffer 155 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer_iterator.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #ifndef SORTED_BUFFER_SORTED_BUFFER_ITERATOR_H_ 5 | #define SORTED_BUFFER_SORTED_BUFFER_ITERATOR_H_ 6 | 7 | #include 8 | #include 9 | 10 | #include "base/common.h" 11 | #include "sorted_buffer/memory_piece.h" 12 | 13 | namespace sorted_buffer { 14 | 15 | // The interface of iterator. 16 | class SortedBufferIterator { 17 | public: 18 | virtual ~SortedBufferIterator() {} 19 | virtual const std::string& key() const = 0; 20 | virtual const std::string& value() const = 0; 21 | virtual bool Done() const = 0; // Done with values of current key. 22 | virtual void Next() = 0; // Jump to the next value 23 | virtual void DiscardRestValues() = 0; // Jump until all values are skipped. 24 | }; 25 | 26 | 27 | // Traverse disk files generated by SortedBuffer for sorted map outputs. 28 | class SortedBufferIteratorImpl : public SortedBufferIterator { 29 | public: 30 | SortedBufferIteratorImpl(const std::string& filebase, 31 | int num_files); 32 | virtual ~SortedBufferIteratorImpl(); 33 | 34 | virtual const std::string& key() const; 35 | virtual const std::string& value() const; 36 | virtual bool Done() const; // Done with values of current key. 37 | virtual void Next(); // Jump to the next value of current key 38 | virtual void DiscardRestValues(); // Jump skip all values of current key. 39 | 40 | void NextKey(); // Jump to the next reduce input (key). 41 | bool FinishedAll() const; // Done with all keys and values. 42 | 43 | private: 44 | struct SortedStringFile { 45 | FILE* input; 46 | int index; 47 | std::string top_key; 48 | std::string top_value; 49 | int32 num_rest_values; // number of values of top_key left in current 50 | // file. 0 means no value for the key on disk 51 | // but might be one in top_key. Negative 52 | // value means "end-of-sorted_buffer". 53 | }; 54 | 55 | typedef std::list SSFileList; 56 | 57 | std::string current_key_; 58 | std::string filebase_; 59 | SSFileList files_; 60 | SortedStringFile* merge_source_; // The file with the minimum top_key. 61 | // merge_source_==NULL means no file is 62 | // valid and Finished() should be true. 63 | 64 | // Invoked by ctor. Open all block files (specified by filebase and 65 | // num_files). Requires that each file contains at least one key-value pair. 66 | void Initialize(const std::string& filebase, int num_files); 67 | 68 | // Invoked by dtor. 69 | void Clear(); 70 | 71 | // Returns false if file->num_rest_values <= 0 72 | bool LoadValue(SortedStringFile* file); 73 | 74 | // If no more keys exist in the file, mark it as end-of-sorted_buffer 75 | // and returns false. 76 | bool LoadKey(SortedStringFile* file); 77 | 78 | // Find among valid files (not end-of-sorted_buffer) for the one with 79 | // the minimum top_key, and point it by merge_source_. 80 | void RelocateMergeSource(); 81 | 82 | // Find a valid file (not end-of-sorted_buffer) whose top_key is LESS 83 | // than that of merge_source_. 84 | SortedStringFile* FindNextMergeSourceWithEqualKey(); 85 | }; 86 | 87 | } // namespace sorted_buffer 88 | 89 | #endif // SORTED_BUFFER_SORTED_BUFFER_ITERATOR_H_ 90 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer_iterator_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "sorted_buffer/sorted_buffer_iterator.h" 5 | 6 | #include "gtest/gtest.h" 7 | 8 | #include "base/common.h" 9 | #include "base/varint32.h" 10 | #include "sorted_buffer/sorted_buffer.h" 11 | 12 | namespace sorted_buffer { 13 | 14 | TEST(SortedBufferIteratorTest, SortedBufferIterator) { 15 | // The following code snippet that generates a series of two disk 16 | // block files are copied from sorted_buffer_test.cc 17 | // 18 | static const std::string kTmpFilebase("/tmp/testSortedBufferIterator"); 19 | static const int kInMemBufferSize = 40; // Can hold two key-value pairs 20 | static const std::string kSomeStrings[] = { 21 | "applee", "applee", "applee", "papaya" }; 22 | static const std::string kValue("123456"); 23 | { 24 | SortedBuffer buffer(kTmpFilebase, kInMemBufferSize); 25 | for (int k = 0; k < sizeof(kSomeStrings)/sizeof(kSomeStrings[0]); ++k) { 26 | buffer.Insert(kSomeStrings[k], kValue); 27 | } 28 | buffer.Flush(); 29 | } 30 | 31 | int i = 0; 32 | for (SortedBufferIteratorImpl iter(kTmpFilebase, 2); !iter.FinishedAll(); 33 | iter.NextKey()) { 34 | for (; !iter.Done(); iter.Next()) { 35 | EXPECT_EQ(iter.key(), kSomeStrings[i]); 36 | EXPECT_EQ(iter.value(), kValue); 37 | ++i; 38 | } 39 | } 40 | } 41 | 42 | } // namespace sorted_buffer 43 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer_regression_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | #include 6 | 7 | #include "base/random.h" 8 | #include "gtest/gtest.h" 9 | #include "sorted_buffer/sorted_buffer_iterator.h" 10 | #include "sorted_buffer/sorted_buffer.h" 11 | 12 | using std::map; 13 | using std::string; 14 | using sorted_buffer::SortedBufferIteratorImpl; 15 | using sorted_buffer::SortedBuffer; 16 | 17 | TEST(SortedBufferTest, SortedBuffer) { 18 | LOG(INFO) << "Running ..."; 19 | 20 | static const char* kTmpFile = "/tmp/sorted_buffer_regtext_file"; 21 | static const int kMaxFileSize = 1024 * 1024; 22 | static const int kNumMaxOutputs = 5 * 1024 * 1024; 23 | 24 | static const string kVocabulary[] = 25 | { "BBM:", "Bayesian", "Browsing", "Model", "from", 26 | "Petabyte", "scale", "Data"}; 27 | static const int kVocabularySize = 28 | sizeof(kVocabulary) / sizeof(kVocabulary[0]); 29 | 30 | MTRandom rng; 31 | 32 | std::map ground_truth; 33 | SortedBuffer buffer(kTmpFile, kMaxFileSize); 34 | for (int i = 0; i < kNumMaxOutputs; ++i) { 35 | const string& word = kVocabulary[rng.RandInt(kVocabularySize)]; 36 | buffer.Insert(word, "1"); 37 | ++ground_truth[word]; 38 | } 39 | buffer.Flush(); 40 | 41 | map::const_iterator i = ground_truth.begin(); 42 | SortedBufferIteratorImpl* iter = 43 | reinterpret_cast(buffer.CreateIterator()); 44 | for (; !(iter->FinishedAll()); iter->NextKey()) { 45 | int count = 0; 46 | for (; !(iter->Done()); iter->Next()) { 47 | ++count; 48 | } 49 | EXPECT_EQ(iter->key(), i->first); 50 | EXPECT_EQ(count, i->second); 51 | ++i; 52 | } 53 | delete iter; 54 | } 55 | -------------------------------------------------------------------------------- /sorted_buffer/sorted_buffer_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "sorted_buffer/sorted_buffer.h" 5 | 6 | #include "base/common.h" 7 | #include "base/varint32.h" 8 | #include "gtest/gtest.h" 9 | 10 | namespace sorted_buffer { 11 | 12 | class SortedBufferTest : public ::testing::Test {}; 13 | 14 | TEST_F(SortedBufferTest, OneFlushFile) { 15 | static const std::string kTmpFilebase("/tmp/testOneFlushFile"); 16 | static const int kInMemBufferSize = 1024; 17 | static const std::string kSomeStrings[] = { 18 | "applee", "banana", "orange", "papaya" }; 19 | 20 | { 21 | SortedBuffer buffer(kTmpFilebase, kInMemBufferSize); 22 | for (int k = 0; k < sizeof(kSomeStrings)/sizeof(kSomeStrings[0]); ++k) { 23 | for (int v = 0; v <= k; ++v) { 24 | buffer.Insert(kSomeStrings[k], kSomeStrings[v]); 25 | } 26 | } 27 | EXPECT_EQ(200, buffer.Allocator()->AllocatedSize()); 28 | buffer.Flush(); 29 | EXPECT_EQ(buffer.NumFiles(), 1); 30 | } 31 | 32 | std::string filename = kTmpFilebase + "-0000000000"; 33 | FILE* input = fopen(filename.c_str(), "r"); 34 | CHECK(input != NULL); 35 | 36 | std::string piece; 37 | uint32 num_values; 38 | for (int k = 0; k < sizeof(kSomeStrings)/sizeof(kSomeStrings[0]); ++k) { 39 | EXPECT_TRUE(ReadMemoryPiece(input, &piece)) << "k = " << k; 40 | EXPECT_EQ(piece, kSomeStrings[k]); 41 | EXPECT_TRUE(ReadVarint32(input, &num_values)); 42 | EXPECT_EQ(num_values, k + 1); 43 | 44 | for (int v = 0; v <= k; ++v) { 45 | EXPECT_TRUE(ReadMemoryPiece(input, &piece)); 46 | EXPECT_EQ(piece, kSomeStrings[v]); 47 | } 48 | } 49 | 50 | fclose(input); 51 | } 52 | 53 | TEST_F(SortedBufferTest, MultipleFlushFiles) { 54 | static const std::string kTmpFilebase("/tmp/testMultipleFlushFiles"); 55 | static const int kInMemBufferSize = 40; // Can hold two key-value pairs 56 | static const std::string kSomeStrings[] = { 57 | "applee", "banana", "applee", "papaya" }; 58 | static const std::string kValue("123456"); 59 | 60 | { 61 | SortedBuffer buffer(kTmpFilebase, kInMemBufferSize); 62 | for (int k = 0; k < sizeof(kSomeStrings)/sizeof(kSomeStrings[0]); ++k) { 63 | buffer.Insert(kSomeStrings[k], kValue); 64 | } 65 | buffer.Flush(); 66 | EXPECT_EQ(buffer.NumFiles(), 2); 67 | } 68 | } 69 | 70 | } // namespace sorted_buffer 71 | -------------------------------------------------------------------------------- /strutil/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build library strutil. 2 | add_library(strutil strcodec.cc split_string.cc stringprintf.cc) 3 | 4 | # Build unittests. 5 | set(LIBS strutil base gflags gtest pthread) 6 | 7 | add_executable(strcodec_test strcodec_test.cc) 8 | target_link_libraries(strcodec_test base gtest_main ${LIBS}) 9 | 10 | add_executable(split_string_test split_string_test.cc) 11 | target_link_libraries(split_string_test gtest_main ${LIBS}) 12 | 13 | add_executable(stringprintf_test stringprintf_test.cc) 14 | target_link_libraries(stringprintf_test gtest_main ${LIBS}) 15 | 16 | add_executable(join_strings_test join_strings_test.cc) 17 | target_link_libraries(join_strings_test gtest_main ${LIBS}) 18 | 19 | # Install library and header files 20 | install(TARGETS strutil DESTINATION bin/strutil) 21 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 22 | install(FILES ${HEADER_FILES} DESTINATION include/strutil) 23 | -------------------------------------------------------------------------------- /strutil/Makefile: -------------------------------------------------------------------------------- 1 | CXX=g++ -Wall -Wno-sign-compare -Werror -O2 2 | AR=ar rcs 3 | 4 | 5 | BUILD_TARGETS = libstrutil.a split_string_test strcodec_test stringprintf_test 6 | DIST_DIR = ../../distribution/strutil 7 | 8 | 9 | all : $(BUILD_TARGETS) 10 | 11 | libstrutil.a : split_string.o strcodec.o stringprintf.o 12 | $(AR) libstrutil.a $+ 13 | 14 | split_string_test : split_string_test.o libstrutil.a ../base/libbase.a 15 | $(CXX) -o split_string_test split_string_test.o -L. -lstrutil -L../base -lbase 16 | 17 | strcodec_test : strcodec_test.o libstrutil.a ../base/libbase.a 18 | $(CXX) -o strcodec_test strcodec_test.o -L. -lstrutil -L../base -lbase 19 | 20 | stringprintf_test : stringprintf_test.o ../base/libbase.a 21 | $(CXX) -o stringprintf_test stringprintf_test.o -L. -lstrutil -L../base -lbase 22 | 23 | dist : $(BUILD_TARGETS) 24 | rm -rf $(DIST_DIR) 25 | mkdir -p $(DIST_DIR) 26 | cp $(BUILD_TARGETS) $(DIST_DIR) 27 | cp *.hh $(DIST_DIR) 28 | 29 | clean : 30 | rm -rf $(BUILD_TARGETS) *.o *.a *.exe *~ *.stackdump *.dSYM 31 | -------------------------------------------------------------------------------- /strutil/join_strings.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #ifndef STRUTIL_JOIN_STRINGS_H_ 5 | #define STRUTIL_JOIN_STRINGS_H_ 6 | 7 | #include 8 | 9 | template 10 | void JoinStrings(const ConstForwardIterator& begin, 11 | const ConstForwardIterator& end, 12 | const std::string& delimiter, 13 | std::string* output) { 14 | output->clear(); 15 | for (ConstForwardIterator iter = begin; iter != end; ++iter) { 16 | if (iter != begin) { 17 | output->append(delimiter); 18 | } 19 | output->append(*iter); 20 | } 21 | } 22 | 23 | template 24 | std::string JoinStrings(const ConstForwardIterator& begin, 25 | const ConstForwardIterator& end, 26 | const std::string& delimiter) { 27 | std::string output; 28 | JoinStrings(begin, end, delimiter, &output); 29 | return output; 30 | } 31 | 32 | template 33 | std::string JoinStrings(const Container& container, 34 | const std::string& delimiter = " ") { 35 | return JoinStrings(container.begin(), container.end(), delimiter); 36 | } 37 | 38 | #endif // STRUTIL_JOIN_STRINGS_H_ 39 | -------------------------------------------------------------------------------- /strutil/join_strings_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "strutil/join_strings.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "gtest/gtest.h" 12 | 13 | TEST(JoinStringsTest, JoinStringsInVector) { 14 | std::vector vector; 15 | vector.push_back("apple"); 16 | vector.push_back("banana"); 17 | vector.push_back("orange"); 18 | 19 | std::string output; 20 | JoinStrings(vector.begin(), vector.end(), ",", &output); 21 | EXPECT_EQ("apple,banana,orange", output); 22 | EXPECT_EQ("apple,banana,orange", 23 | JoinStrings(vector.begin(), vector.end(), ",")); 24 | } 25 | 26 | TEST(JoinStringsTest, JoinStringsInList) { 27 | std::list list; 28 | list.push_back("apple"); 29 | list.push_back("banana"); 30 | list.push_back("orange"); 31 | 32 | std::string output; 33 | JoinStrings(list.begin(), list.end(), ",", &output); 34 | EXPECT_EQ("apple,banana,orange", output); 35 | EXPECT_EQ("apple,banana,orange", 36 | JoinStrings(list.begin(), list.end(), ",")); 37 | } 38 | 39 | TEST(JoinStringsTest, JoinStringsInSet) { 40 | std::set set; 41 | set.insert("apple"); 42 | set.insert("banana"); 43 | set.insert("orange"); 44 | 45 | std::string output; 46 | JoinStrings(set.begin(), set.end(), ",", &output); 47 | EXPECT_EQ("apple,banana,orange", output); 48 | EXPECT_EQ("apple,banana,orange", 49 | JoinStrings(set.begin(), set.end(), ",")); 50 | } 51 | -------------------------------------------------------------------------------- /strutil/split_string.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "strutil/split_string.h" 5 | 6 | #include "base/common.h" 7 | 8 | using std::string; 9 | using std::vector; 10 | using std::set; 11 | 12 | // In most cases, delim contains only one character. In this case, we 13 | // use CalculateReserveForVector to count the number of elements 14 | // should be reserved in result vector, and thus optimize SplitStringUsing. 15 | static int CalculateReserveForVector(const string& full, const char* delim) { 16 | int count = 0; 17 | if (delim[0] != '\0' && delim[1] == '\0') { 18 | // Optimize the common case where delim is a single character. 19 | char c = delim[0]; 20 | const char* p = full.data(); 21 | const char* end = p + full.size(); 22 | while (p != end) { 23 | if (*p == c) { // This could be optimized with hasless(v,1) trick. 24 | ++p; 25 | } else { 26 | while (++p != end && *p != c) { 27 | // Skip to the next occurence of the delimiter. 28 | } 29 | ++count; 30 | } 31 | } 32 | } 33 | return count; 34 | } 35 | 36 | void SplitStringUsing(const string& full, 37 | const char* delim, 38 | vector* result) { 39 | CHECK(delim != NULL); 40 | CHECK(result != NULL); 41 | result->reserve(CalculateReserveForVector(full, delim)); 42 | back_insert_iterator< vector > it(*result); 43 | SplitStringToIteratorUsing(full, delim, it); 44 | } 45 | 46 | void SplitStringToSetUsing(const string& full, 47 | const char* delim, 48 | set* result) { 49 | CHECK(delim != NULL); 50 | CHECK(result != NULL); 51 | simple_insert_iterator > it(result); 52 | SplitStringToIteratorUsing(full, delim, it); 53 | } 54 | -------------------------------------------------------------------------------- /strutil/split_string.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This file declares string splitting utilities. 5 | // 6 | #ifndef STRUTIL_SPLIT_STRING_H_ 7 | #define STRUTIL_SPLIT_STRING_H_ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | // Subdivide string |full| into substrings according to delimitors 14 | // given in |delim|. |delim| should pointing to a string including 15 | // one or more characters. Each character is considerred a possible 16 | // delimitor. For example, 17 | // vector substrings; 18 | // SplitStringUsing("apple orange\tbanana", "\t ", &substrings); 19 | // results in three substrings: 20 | // substrings.size() == 3 21 | // substrings[0] == "apple" 22 | // substrings[1] == "orange" 23 | // substrings[2] == "banana" 24 | void SplitStringUsing(const std::string& full, 25 | const char* delim, 26 | std::vector* result); 27 | 28 | // This function has the same semnatic as SplitStringUsing. Results 29 | // are saved in an STL set container. 30 | void SplitStringToSetUsing(const std::string& full, 31 | const char* delim, 32 | std::set* result); 33 | 34 | 35 | template 36 | struct simple_insert_iterator { 37 | explicit simple_insert_iterator(T* t) : t_(t) { } 38 | 39 | simple_insert_iterator& operator=(const typename T::value_type& value) { 40 | t_->insert(value); 41 | return *this; 42 | } 43 | 44 | simple_insert_iterator& operator*() { return *this; } 45 | simple_insert_iterator& operator++() { return *this; } 46 | simple_insert_iterator& operator++(int placeholder) { return *this; } 47 | 48 | T* t_; 49 | }; 50 | 51 | template 52 | struct back_insert_iterator { 53 | explicit back_insert_iterator(T& t) : t_(t) {} 54 | 55 | back_insert_iterator& operator=(const typename T::value_type& value) { 56 | t_.push_back(value); 57 | return *this; 58 | } 59 | 60 | back_insert_iterator& operator*() { return *this; } 61 | back_insert_iterator& operator++() { return *this; } 62 | back_insert_iterator operator++(int placeholder) { return *this; } 63 | 64 | T& t_; 65 | }; 66 | 67 | 68 | template 69 | static inline 70 | void SplitStringToIteratorUsing(const StringType& full, 71 | const char* delim, 72 | ITR& result) { 73 | // Optimize the common case where delim is a single character. 74 | if (delim[0] != '\0' && delim[1] == '\0') { 75 | char c = delim[0]; 76 | const char* p = full.data(); 77 | const char* end = p + full.size(); 78 | while (p != end) { 79 | if (*p == c) { 80 | ++p; 81 | } else { 82 | const char* start = p; 83 | while (++p != end && *p != c) { 84 | // Skip to the next occurence of the delimiter. 85 | } 86 | *result++ = StringType(start, p - start); 87 | } 88 | } 89 | return; 90 | } 91 | 92 | std::string::size_type begin_index, end_index; 93 | begin_index = full.find_first_not_of(delim); 94 | while (begin_index != std::string::npos) { 95 | end_index = full.find_first_of(delim, begin_index); 96 | if (end_index == std::string::npos) { 97 | *result++ = full.substr(begin_index); 98 | return; 99 | } 100 | *result++ = full.substr(begin_index, (end_index - begin_index)); 101 | begin_index = full.find_first_not_of(delim, end_index); 102 | } 103 | } 104 | 105 | #endif // STRUTIL_SPLIT_STRING_H_ 106 | -------------------------------------------------------------------------------- /strutil/split_string_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include 5 | #include 6 | 7 | #include "gtest/gtest.h" 8 | 9 | #include "base/common.h" 10 | #include "strutil/split_string.h" 11 | 12 | TEST(SplitStringTest, SplitStringUsingCompoundDelim) { 13 | std::string full(" apple \torange "); 14 | std::vector subs; 15 | SplitStringUsing(full, " \t", &subs); 16 | EXPECT_EQ(subs.size(), 2); 17 | EXPECT_EQ(subs[0], std::string("apple")); 18 | EXPECT_EQ(subs[1], std::string("orange")); 19 | } 20 | 21 | TEST(SplitStringTest, testSplitStringUsingSingleDelim) { 22 | std::string full(" apple orange "); 23 | std::vector subs; 24 | SplitStringUsing(full, " ", &subs); 25 | EXPECT_EQ(subs.size(), 2); 26 | EXPECT_EQ(subs[0], std::string("apple")); 27 | EXPECT_EQ(subs[1], std::string("orange")); 28 | } 29 | 30 | TEST(SplitStringTest, testSplitingNoDelimString) { 31 | std::string full("apple"); 32 | std::vector subs; 33 | SplitStringUsing(full, " ", &subs); 34 | EXPECT_EQ(subs.size(), 1); 35 | EXPECT_EQ(subs[0], std::string("apple")); 36 | } 37 | -------------------------------------------------------------------------------- /strutil/strcodec.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include // for memcpy 5 | 6 | #include 7 | #include 8 | 9 | #include "base/common.h" 10 | #include "strutil/strcodec.h" 11 | #include "strutil/stringprintf.h" 12 | 13 | // |XXXToKey| functions will generate a string with length 14 | // kNumericValueFillSize containing human readable format of a 15 | // numerical value, prefixed by '0's. 16 | const int kNumericValueFillSize = 10; 17 | 18 | //----------------------------------------------------------------------------- 19 | // Template implementation of human-readable encoding. 20 | //----------------------------------------------------------------------------- 21 | template 22 | void NumericValueToKey(const ValueType& value, std::string* key) { 23 | using std::ostringstream; 24 | CHECK_GE(value, 0); 25 | ostringstream o; 26 | o << std::setfill('0') << std::setw(kNumericValueFillSize) << value; 27 | *key = o.str(); 28 | } 29 | 30 | template 31 | std::string NumericValueToKey(const ValueType& value) { 32 | using std::ostringstream; 33 | CHECK_GE(value, 0); 34 | ostringstream o; 35 | o << std::setfill('0') << std::setw(kNumericValueFillSize) << value; 36 | return o.str(); 37 | } 38 | 39 | template 40 | ValueType KeyToNumericValue(const std::string& key) { 41 | using std::istringstream; 42 | istringstream i(key); 43 | ValueType v; 44 | i >> v; 45 | return v; 46 | } 47 | 48 | //----------------------------------------------------------------------------- 49 | // Realizations of fast human-readable encoding. 50 | //----------------------------------------------------------------------------- 51 | 52 | void Int32ToKey(int32 value, std::string* str) { 53 | NumericValueToKey(value, str); 54 | } 55 | 56 | std::string Int32ToKey(int32 value) { 57 | return NumericValueToKey(value); 58 | } 59 | 60 | void Uint32ToKey(uint32 value, std::string* str) { 61 | NumericValueToKey(value, str); 62 | } 63 | 64 | std::string Uint32ToKey(uint32 value) { 65 | return NumericValueToKey(value); 66 | } 67 | 68 | void Int64ToKey(int64 value, std::string* str) { 69 | NumericValueToKey(value, str); 70 | } 71 | 72 | std::string Int64ToKey(int64 value) { 73 | return NumericValueToKey(value); 74 | } 75 | 76 | void Uint64ToKey(uint64 value, std::string* str) { 77 | NumericValueToKey(value, str); 78 | } 79 | 80 | std::string Uint64ToKey(uint64 value) { 81 | return NumericValueToKey(value); 82 | } 83 | 84 | int32 KeyToInt32(const std::string& key) { 85 | return KeyToNumericValue(key); 86 | } 87 | 88 | uint32 KeyToUint32(const std::string& key) { 89 | return KeyToNumericValue(key); 90 | } 91 | 92 | int64 KeyToInt64(const std::string& key) { 93 | return KeyToNumericValue(key); 94 | } 95 | 96 | uint64 KeyToUint64(const std::string& key) { 97 | return KeyToNumericValue(key); 98 | } 99 | 100 | //----------------------------------------------------------------------------- 101 | // Template implementation of fast encoding/decoding. 102 | //----------------------------------------------------------------------------- 103 | template 104 | void EncodeNumericValue(const ValueType& value, std::string* str) { 105 | str->resize(sizeof(ValueType)); 106 | memcpy(const_cast(str->data()), &value, sizeof(ValueType)); 107 | } 108 | 109 | template 110 | std::string EncodeNumericValue(const ValueType& value) { 111 | std::string str; 112 | str.resize(sizeof(ValueType)); 113 | memcpy(const_cast(str.data()), &value, sizeof(ValueType)); 114 | return str; 115 | } 116 | 117 | template 118 | ValueType DecodeNumericValue(const std::string& str) { 119 | ValueType ret; 120 | CHECK_EQ(str.size(), sizeof(ValueType)); 121 | memcpy(&ret, str.data(), sizeof(ValueType)); 122 | return ret; 123 | } 124 | 125 | //----------------------------------------------------------------------------- 126 | // Realizations of fast encoding/decoding. 127 | //----------------------------------------------------------------------------- 128 | void EncodeInt32(int32 value, std::string* str) { 129 | EncodeNumericValue(value, str); 130 | } 131 | 132 | std::string EncodeInt32(int32 value) { 133 | return EncodeNumericValue(value); 134 | } 135 | 136 | int32 DecodeInt32(const std::string& str) { 137 | return DecodeNumericValue(str); 138 | } 139 | 140 | void EncodeUint32(uint32 value, std::string* str) { 141 | EncodeNumericValue(value, str); 142 | } 143 | 144 | std::string EncodeUint32(uint32 value) { 145 | return EncodeNumericValue(value); 146 | } 147 | 148 | uint32 DecodeUint32(const std::string& str) { 149 | return DecodeNumericValue(str); 150 | } 151 | 152 | void EncodeInt64(int64 value, std::string* str) { 153 | EncodeNumericValue(value, str); 154 | } 155 | 156 | std::string EncodeInt64(int64 value) { 157 | return EncodeNumericValue(value); 158 | } 159 | 160 | int64 DecodeInt64(const std::string& str) { 161 | return DecodeNumericValue(str); 162 | } 163 | 164 | void EncodeUint64(uint64 value, std::string* str) { 165 | EncodeNumericValue(value, str); 166 | } 167 | 168 | std::string EncodeUint64(uint64 value) { 169 | return EncodeNumericValue(value); 170 | } 171 | 172 | uint64 DecodeUint64(const std::string& str) { 173 | return DecodeNumericValue(str); 174 | } 175 | -------------------------------------------------------------------------------- /strutil/strcodec.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This file declares functions that converts numerical types 5 | // from/into binary forms. An example usage of these functions is 6 | // generating/parsing MapReduce keys from/as numerical types. 7 | // 8 | #ifndef STRUTIL_STRCODEC_H_ 9 | #define STRUTIL_STRCODEC_H_ 10 | 11 | #include 12 | #include "base/common.h" 13 | 14 | // Following functions encode numerical values in human-readable format. 15 | 16 | void Int32ToKey(int32 value, std::string* str); 17 | std::string Int32ToKey(int32 value); 18 | 19 | void Uint32ToKey(uint32 value, std::string* str); 20 | std::string Uint32ToKey(uint32 value); 21 | 22 | void Int64ToKey(int64 value, std::string* str); 23 | std::string Int64ToKey(int64 value); 24 | 25 | void Uint64ToKey(uint64 value, std::string* str); 26 | std::string Uint64ToKey(uint64 value); 27 | 28 | int32 KeyToInt32(const std::string& key); 29 | uint32 KeyToUint32(const std::string& key); 30 | int64 KeyToInt64(const std::string& key); 31 | uint64 KeyToUint64(const std::string& key); 32 | 33 | // Following functions does fast encoding/decoding of numerical types. 34 | // NOTE: The fast encoding of a numerical value is just its memory 35 | // mirro, and MUST NOT be passed cross machines using different endian 36 | // styles. 37 | 38 | void EncodeInt32(int32 value, std::string* str); 39 | std::string EncodeInt32(int32 value); 40 | 41 | int32 DecodeInt32(const std::string& str); 42 | 43 | void EncodeUint32(uint32 value, std::string* str); 44 | std::string EncodeUint32(uint32 value); 45 | 46 | uint32 DecodeUint32(const std::string& str); 47 | 48 | void EncodeInt64(int64 value, std::string* str); 49 | std::string EncodeInt64(int64 value); 50 | 51 | int64 DecodeInt64(const std::string& str); 52 | 53 | void EncodeUint64(uint64 value, std::string* str); 54 | std::string EncodeUint64(uint64 value); 55 | 56 | uint64 DecodeUint64(const std::string& str); 57 | 58 | #endif // STRUTIL_STRCODEC_H_ 59 | -------------------------------------------------------------------------------- /strutil/strcodec_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "gtest/gtest.h" 5 | 6 | #include "strutil/strcodec.h" 7 | 8 | TEST(StrCodecTest, testFastCodecInt32) { 9 | std::string str; 10 | EncodeInt32(0, &str); 11 | EXPECT_EQ(0, DecodeInt32(str)); 12 | 13 | EncodeInt32(1, &str); 14 | EXPECT_EQ(1, DecodeInt32(str)); 15 | 16 | EncodeInt32(-1, &str); 17 | EXPECT_EQ(-1, DecodeInt32(str)); 18 | 19 | EncodeInt32(0xffffffff, &str); 20 | EXPECT_EQ(0xffffffff, DecodeInt32(str)); 21 | } 22 | 23 | TEST(StrCodecTest, testFastCodecUint64) { 24 | std::string str; 25 | EncodeUint64(0, &str); 26 | EXPECT_EQ(0, DecodeUint64(str)); 27 | 28 | EncodeUint64(1, &str); 29 | EXPECT_EQ(1, DecodeUint64(str)); 30 | 31 | EncodeUint64(0xffffffffffffffffLLU, &str); 32 | EXPECT_EQ(0xffffffffffffffffLLU, DecodeUint64(str)); 33 | } 34 | 35 | TEST(StrCodecTest, testInt32ToKey) { 36 | std::string key; 37 | Int32ToKey(0, &key); 38 | EXPECT_EQ(key, "0000000000"); 39 | 40 | Int32ToKey(1, &key); 41 | EXPECT_EQ(key, "0000000001"); 42 | } 43 | 44 | TEST(StrCodecTest, testKeyToInt32) { 45 | std::string key; 46 | Int32ToKey(0, &key); 47 | EXPECT_EQ(KeyToInt32(key), 0); 48 | 49 | Int32ToKey(1, &key); 50 | EXPECT_EQ(KeyToInt32(key), 1); 51 | } 52 | -------------------------------------------------------------------------------- /strutil/stringprintf.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This code comes from the re2 project host on Google Code 5 | // (http://code.google.com/p/re2/), in particular, the following source file 6 | // http://code.google.com/p/re2/source/browse/util/stringprintf.cc 7 | 8 | #include "strutil/stringprintf.h" 9 | 10 | #include 11 | #include 12 | 13 | using std::string; 14 | 15 | static void StringAppendV(string* dst, const char* format, va_list ap) { 16 | // First try with a small fixed size buffer 17 | char space[1024]; 18 | 19 | // It's possible for methods that use a va_list to invalidate 20 | // the data in it upon use. The fix is to make a copy 21 | // of the structure before using it and use that copy instead. 22 | va_list backup_ap; 23 | va_copy(backup_ap, ap); 24 | int result = vsnprintf(space, sizeof(space), format, backup_ap); 25 | va_end(backup_ap); 26 | 27 | if ((result >= 0) && (result < sizeof(space))) { 28 | // It fit 29 | dst->append(space, result); 30 | return; 31 | } 32 | 33 | // Repeatedly increase buffer size until it fits 34 | int length = sizeof(space); 35 | while (true) { 36 | if (result < 0) { 37 | // Older behavior: just try doubling the buffer size 38 | length *= 2; 39 | } else { 40 | // We need exactly "result+1" characters 41 | length = result + 1; 42 | } 43 | char* buf = new char[length]; 44 | 45 | // Restore the va_list before we use it again 46 | va_copy(backup_ap, ap); 47 | result = vsnprintf(buf, length, format, backup_ap); 48 | va_end(backup_ap); 49 | 50 | if ((result >= 0) && (result < length)) { 51 | // It fit 52 | dst->append(buf, result); 53 | delete[] buf; 54 | return; 55 | } 56 | delete[] buf; 57 | } 58 | } 59 | 60 | string StringPrintf(const char* format, ...) { 61 | va_list ap; 62 | va_start(ap, format); 63 | string result; 64 | StringAppendV(&result, format, ap); 65 | va_end(ap); 66 | return result; 67 | } 68 | 69 | void SStringPrintf(string* dst, const char* format, ...) { 70 | va_list ap; 71 | va_start(ap, format); 72 | dst->clear(); 73 | StringAppendV(dst, format, ap); 74 | va_end(ap); 75 | } 76 | 77 | void StringAppendF(string* dst, const char* format, ...) { 78 | va_list ap; 79 | va_start(ap, format); 80 | StringAppendV(dst, format, ap); 81 | va_end(ap); 82 | } 83 | -------------------------------------------------------------------------------- /strutil/stringprintf.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This code comes from the re2 project host on Google Code 5 | // (http://code.google.com/p/re2/), in particular, the following source file 6 | // http://code.google.com/p/re2/source/browse/util/stringprintf.cc 7 | // 8 | #ifndef STRUTIL_STRINGPRINTF_H_ 9 | #define STRUTIL_STRINGPRINTF_H_ 10 | 11 | #include 12 | 13 | std::string StringPrintf(const char* format, ...); 14 | void SStringPrintf(std::string* dst, const char* format, ...); 15 | void StringAppendF(std::string* dst, const char* format, ...); 16 | 17 | #endif // STRUTIL_STRINGPRINTF_H_ 18 | -------------------------------------------------------------------------------- /strutil/stringprintf_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | #include "strutil/stringprintf.h" 5 | 6 | #include "gtest/gtest.h" 7 | 8 | TEST(StringPrintf, normal) { 9 | using std::string; 10 | EXPECT_EQ(StringPrintf("%d", 1), string("1")); 11 | string target; 12 | SStringPrintf(&target, "%d", 1); 13 | EXPECT_EQ(target, string("1")); 14 | StringAppendF(&target, "%d", 2); 15 | EXPECT_EQ(target, string("12")); 16 | } 17 | -------------------------------------------------------------------------------- /system/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Build library strutil. 2 | add_library(system condition_variable.cc filepattern.cc) 3 | 4 | # Build unittests. 5 | set(LIBS system base strutil gtest pthread) 6 | 7 | add_executable(condition_variable_test condition_variable_test.cc) 8 | target_link_libraries(condition_variable_test gtest_main ${LIBS}) 9 | 10 | add_executable(mutex_test mutex_test.cc) 11 | target_link_libraries(mutex_test gtest_main ${LIBS}) 12 | 13 | add_executable(filepattern_test filepattern_test.cc) 14 | target_link_libraries(filepattern_test gtest_main ${LIBS}) 15 | 16 | # Install library and header files 17 | install(TARGETS system DESTINATION bin/system) 18 | FILE(GLOB HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/*.h") 19 | install(FILES ${HEADER_FILES} DESTINATION include/system) 20 | -------------------------------------------------------------------------------- /system/condition_variable.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This is a copy from 5 | 6 | // trunk/src/common/system/concurrency/condition_variable.cpp 7 | // 8 | #include "system/condition_variable.h" 9 | 10 | #include 11 | #if __unix__ 12 | #include 13 | #endif 14 | 15 | #include 16 | #include 17 | 18 | 19 | #ifdef _WIN32 20 | 21 | ConditionVariable::ConditionVariable() { 22 | m_hCondition = ::CreateEvent(NULL, FALSE, FALSE, NULL); 23 | m_nWaitCount = 0; 24 | assert(m_hCondition != NULL); 25 | } 26 | 27 | ConditionVariable::~ConditionVariable() { 28 | ::CloseHandle(m_hCondition); 29 | } 30 | 31 | void ConditionVariable::Wait(Mutex* inMutex) { 32 | inMutex->Unlock(); 33 | m_nWaitCount++; 34 | DWORD theErr = ::WaitForSingleObject(m_hCondition, INFINITE); 35 | m_nWaitCount--; 36 | assert((theErr == WAIT_OBJECT_0) || (theErr == WAIT_TIMEOUT)); 37 | inMutex->Lock(); 38 | 39 | if (theErr != WAIT_OBJECT_0) 40 | throw std::runtime_error("ConditionVariable::Wait"); 41 | } 42 | 43 | bool ConditionVariable::Wait(Mutex* inMutex, int inTimeoutInMilSecs) { 44 | inMutex->Unlock(); 45 | m_nWaitCount++; 46 | DWORD theErr = ::WaitForSingleObject(m_hCondition, inTimeoutInMilSecs); 47 | m_nWaitCount--; 48 | assert((theErr == WAIT_OBJECT_0) || (theErr == WAIT_TIMEOUT)); 49 | inMutex->Lock(); 50 | 51 | if (theErr == WAIT_OBJECT_0) 52 | return true; 53 | else if (theErr == WAIT_TIMEOUT) 54 | return false; 55 | else 56 | throw std::runtime_error("ConditionVariable::Wait"); 57 | } 58 | 59 | void ConditionVariable::Signal() { 60 | if (!::SetEvent(m_hCondition)) 61 | throw std::runtime_error("ConditionVariable::Signal"); 62 | } 63 | 64 | void ConditionVariable::Broadcast() { 65 | // There doesn't seem like any more elegant way to 66 | // implement Broadcast using events in Win32. 67 | // This will work, it may generate spurious wakeups, 68 | // but condition variables are allowed to generate 69 | // spurious wakeups 70 | unsigned int waitCount = m_nWaitCount; 71 | for (unsigned int x = 0; x < waitCount; x++) { 72 | if (!::SetEvent(m_hCondition)) 73 | throw std::runtime_error("ConditionVariable::Broadcast"); 74 | } 75 | } 76 | 77 | #elif defined __unix__ 78 | 79 | void ConditionVariable::CheckError(const char* context, int error) { 80 | if (error != 0) { 81 | std::string msg = context; 82 | msg += " error: "; 83 | msg += strerror(error); 84 | throw std::runtime_error(msg); 85 | } 86 | } 87 | 88 | ConditionVariable::ConditionVariable() { 89 | pthread_condattr_t cond_attr; 90 | pthread_condattr_init(&cond_attr); 91 | int ret = pthread_cond_init(&m_hCondition, &cond_attr); 92 | pthread_condattr_destroy(&cond_attr); 93 | CheckError("ConditionVariable::ConditionVariable", ret); 94 | } 95 | 96 | ConditionVariable::~ConditionVariable() { 97 | pthread_cond_destroy(&m_hCondition); 98 | } 99 | 100 | void ConditionVariable::Signal() { 101 | CheckError("ConditionVariable::Signal", 102 | pthread_cond_signal(&m_hCondition)); 103 | } 104 | 105 | void ConditionVariable::Broadcast() { 106 | CheckError("ConditionVariable::Broadcast", 107 | pthread_cond_broadcast(&m_hCondition)); 108 | } 109 | 110 | void ConditionVariable::Wait(Mutex* inMutex) { 111 | CheckError("ConditionVariable::Wait", 112 | pthread_cond_wait(&m_hCondition, &inMutex->m_Mutex)); 113 | } 114 | 115 | bool ConditionVariable::Wait(Mutex* inMutex, int inTimeoutInMilSecs) { 116 | if (inTimeoutInMilSecs < 0) { 117 | Wait(inMutex); // wait forever 118 | return true; 119 | } 120 | 121 | // get current absolate time 122 | struct timeval tv; 123 | gettimeofday(&tv, NULL); 124 | 125 | // add timeout 126 | tv.tv_sec += inTimeoutInMilSecs / 1000; 127 | tv.tv_usec += (inTimeoutInMilSecs % 1000) * 1000; 128 | 129 | int million = 1000000; 130 | if (tv.tv_usec >= million) { 131 | tv.tv_sec += tv.tv_usec / million; 132 | tv.tv_usec %= million; 133 | } 134 | 135 | // convert timeval to timespec 136 | struct timespec ts; 137 | ts.tv_sec = tv.tv_sec; 138 | ts.tv_nsec = tv.tv_usec * 1000; 139 | int error = pthread_cond_timedwait(&m_hCondition, &inMutex->m_Mutex, &ts); 140 | 141 | if (error == ETIMEDOUT) 142 | return false; 143 | else 144 | CheckError("ConditionVariable::Wait", error); 145 | return true; 146 | } 147 | 148 | #endif 149 | 150 | -------------------------------------------------------------------------------- /system/condition_variable.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This is a copy from 5 | 6 | // trunk/src/common/system/concurrency/condition_variable.hpp 7 | // 8 | #ifndef SYSTEM_CONDITION_VARIABLE_H_ 9 | #define SYSTEM_CONDITION_VARIABLE_H_ 10 | 11 | #ifndef _WIN32 12 | #if __unix__ 13 | #include 14 | #endif 15 | #endif 16 | 17 | #include 18 | #include "system/mutex.h" 19 | 20 | class ConditionVariable { 21 | public: 22 | ConditionVariable(); 23 | ~ConditionVariable(); 24 | 25 | void Signal(); 26 | void Broadcast(); 27 | 28 | bool Wait(Mutex* inMutex, int inTimeoutInMilSecs); 29 | void Wait(Mutex* inMutex); 30 | 31 | private: 32 | #ifdef _WIN32 33 | HANDLE m_hCondition; 34 | unsigned int m_nWaitCount; 35 | #elif __unix__ 36 | pthread_cond_t m_hCondition; 37 | #endif 38 | static void CheckError(const char* context, int error); 39 | }; 40 | 41 | #endif // SYSTEM_CONDITION_VARIABLE_H_ 42 | 43 | -------------------------------------------------------------------------------- /system/condition_variable_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This is a copy from 5 | 6 | // trunk/src/common/system/concurrency/condition_variable_test.cpp 7 | // 8 | #include "gtest/gtest.h" 9 | #include "system/condition_variable.h" 10 | 11 | TEST(ConditionVariable, Init) { 12 | ConditionVariable cond; 13 | } 14 | 15 | TEST(ConditionVariable, Wait) { 16 | ConditionVariable event; 17 | event.Signal(); 18 | } 19 | 20 | TEST(ConditionVariable, Release) { 21 | } 22 | 23 | -------------------------------------------------------------------------------- /system/filepattern.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "system/filepattern.h" 5 | 6 | #include "base/common.h" 7 | 8 | FilepatternMatcher::FilepatternMatcher(const std::string& filepattern) { 9 | glob_return_ = glob(filepattern.c_str(), 10 | GLOB_MARK | // Append a slash to each path which 11 | // corresponds to a directory. 12 | GLOB_TILDE_CHECK | // Carry out tilde expansion. If 13 | // the username is invalid, or the 14 | // home directory cannot be 15 | // determined, glob() returns 16 | // GLOB_NOMATCH to indicate an error. 17 | GLOB_BRACE, // Enable brace expressions. 18 | FilepatternMatcher::ErrorFunc, 19 | &glob_result_); 20 | 21 | if (glob_return_ != 0) { 22 | switch (glob_return_) { 23 | case GLOB_NOSPACE: 24 | LOG(ERROR) << "Run out of memory in file pattern matching."; 25 | break; 26 | case GLOB_ABORTED: 27 | LOG(ERROR) << "Encouterred read error in file pattern matching."; 28 | break; 29 | case GLOB_NOMATCH: 30 | LOG(ERROR) << "Filepattern " << filepattern << " matches no file."; 31 | break; 32 | } 33 | } 34 | } 35 | 36 | bool FilepatternMatcher::NoError() const { 37 | return glob_return_ == 0; 38 | } 39 | 40 | int FilepatternMatcher::NumMatched() const { 41 | return glob_result_.gl_pathc; 42 | } 43 | 44 | const char* FilepatternMatcher::Matched(int i) const { 45 | CHECK_LE(0, i); 46 | CHECK_LT(i, NumMatched()); 47 | return glob_result_.gl_pathv[i]; 48 | } 49 | 50 | FilepatternMatcher::~FilepatternMatcher() { 51 | globfree(&glob_result_); 52 | } 53 | 54 | /*static*/ 55 | int FilepatternMatcher::ErrorFunc(const char* epath, int eerrno) { 56 | LOG(ERROR) << "Failed at matching: " << epath << " due to error: " << eerrno; 57 | return 0; // Retuns 0 to let glob() attempts carry on despite errors. 58 | } 59 | -------------------------------------------------------------------------------- /system/filepattern.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // Operations with file patterns. 5 | // 6 | #ifndef SYSTEM_FILEPATTERN_H_ 7 | #define SYSTEM_FILEPATTERN_H_ 8 | 9 | #include 10 | #include 11 | 12 | // Given a filepattern as a string, this class find mached files. 13 | // Usage: 14 | /* 15 | const string filepattern = "/usr/include/g*.h"; 16 | FilepatternMatcher m(filepattern); 17 | for (int i = 0; i < m.NumMatched(); ++i) { 18 | printf("Found %s\n", m.Matched()); 19 | } 20 | */ 21 | class FilepatternMatcher { 22 | public: 23 | explicit FilepatternMatcher(const std::string& filepattern); 24 | ~FilepatternMatcher(); 25 | 26 | bool NoError() const; 27 | int NumMatched() const; 28 | const char* Matched(int i) const; 29 | 30 | private: 31 | static int ErrorFunc(const char* epath, int eerrno); // The error handler. 32 | int glob_return_; 33 | glob_t glob_result_; 34 | }; 35 | 36 | #endif // SYSTEM_FILEPATTERN_H_ 37 | -------------------------------------------------------------------------------- /system/filepattern_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include // For fopen() and remove() 5 | 6 | #include 7 | 8 | #include "gtest/gtest.h" 9 | #include "strutil/stringprintf.h" 10 | #include "system/filepattern.h" 11 | 12 | static const int kNumTestFiles = 5; 13 | static const char* kTestFilebase = "/tmp/filepattern-test"; 14 | 15 | class FilepatternMatcherTest : public ::testing::Test { 16 | protected: 17 | virtual void SetUp() { 18 | FILE* file = NULL; 19 | for (int i = 0; i < kNumTestFiles; ++i) { 20 | std::string filename; 21 | SStringPrintf(&filename, "%s-%05d-of-%05d", kTestFilebase, 22 | i, kNumTestFiles); 23 | ASSERT_TRUE((file = fopen(filename.c_str(), "w+")) != NULL); 24 | } 25 | } 26 | 27 | virtual void TearDown() { 28 | for (int i = 0; i < kNumTestFiles; ++i) { 29 | std::string filename; 30 | SStringPrintf(&filename, "%s-%05d-of-%05d", kTestFilebase, 31 | i, kNumTestFiles); 32 | ASSERT_EQ(remove(filename.c_str()), 0); 33 | } 34 | } 35 | }; 36 | 37 | TEST_F(FilepatternMatcherTest, MatchUsingAsterisk) { 38 | FilepatternMatcher m( 39 | StringPrintf("%s-*-of-%05d", kTestFilebase, kNumTestFiles)); 40 | EXPECT_EQ(m.NumMatched(), 5); 41 | EXPECT_TRUE(m.NoError()); 42 | } 43 | 44 | TEST_F(FilepatternMatcherTest, MatchUsingQuestionMark) { 45 | FilepatternMatcher m( 46 | StringPrintf("%s-0000?-of-%05d", kTestFilebase, kNumTestFiles)); 47 | EXPECT_EQ(m.NumMatched(), 5); 48 | EXPECT_TRUE(m.NoError()); 49 | } 50 | 51 | TEST_F(FilepatternMatcherTest, MatchUsingSpecifiedRange) { 52 | FilepatternMatcher m( 53 | StringPrintf("%s-0000[0-2]-of-%05d", kTestFilebase, kNumTestFiles)); 54 | EXPECT_EQ(m.NumMatched(), 3); 55 | EXPECT_TRUE(m.NoError()); 56 | } 57 | 58 | TEST_F(FilepatternMatcherTest, MatchUsingBrace) { 59 | FilepatternMatcher m( 60 | StringPrintf("%s-000{00,01,02}-of-%05d", kTestFilebase, kNumTestFiles)); 61 | EXPECT_EQ(m.NumMatched(), 3); 62 | EXPECT_TRUE(m.NoError()); 63 | } 64 | 65 | TEST_F(FilepatternMatcherTest, MatchASpecificFile) { 66 | FilepatternMatcher m( 67 | StringPrintf("%s-00000-of-%05d", kTestFilebase, kNumTestFiles)); 68 | EXPECT_EQ(m.NumMatched(), 1); 69 | EXPECT_TRUE(m.NoError()); 70 | } 71 | 72 | TEST_F(FilepatternMatcherTest, MatchNotExisting) { 73 | FilepatternMatcher m("/tmp/somthing-that-does-not-seem-possible-to-exist"); 74 | EXPECT_EQ(m.NumMatched(), 0); 75 | EXPECT_FALSE(m.NoError()); 76 | } 77 | 78 | -------------------------------------------------------------------------------- /system/mutex.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This is a copy from 5 | 6 | // trunk/src/common/system/concurrency/mutex.hpp 7 | // 8 | #ifndef SYSTEM_MUTEX_H_ 9 | #define SYSTEM_MUTEX_H_ 10 | 11 | #if defined _WIN32 12 | # ifndef WIN32_LEAN_AND_MEAN 13 | # define WIN32_LEAN_AND_MEAN 14 | # endif 15 | # ifndef _WIN32_WINNT 16 | # define _WIN32_WINNT 0x0501 // windows xp 17 | # endif 18 | # define NOMINMAX 1 19 | # include 20 | #elif defined __unix__ 21 | # include 22 | #else 23 | # error Unknown platform 24 | #endif 25 | 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include "system/scoped_locker.h" 33 | 34 | class ConditionVariable; 35 | 36 | #if defined _WIN32 37 | 38 | // if _WIN32_WINNT not defined, TryEnterCriticalSection will not be declared 39 | // in windows.h 40 | extern "C" WINBASEAPI 41 | BOOL WINAPI TryEnterCriticalSection( 42 | __inout LPCRITICAL_SECTION lpCriticalSection); 43 | 44 | class Mutex { 45 | public: 46 | typedef ScopedLocker Locker; 47 | public: 48 | explicit Mutex(bool recursive = true) { 49 | ::InitializeCriticalSection(&m_Mutex); 50 | } 51 | 52 | ~Mutex() { 53 | ::DeleteCriticalSection(&m_Mutex); 54 | } 55 | 56 | void Lock() { 57 | ::EnterCriticalSection(&m_Mutex); 58 | assert(IsLocked()); 59 | } 60 | 61 | bool TryLock() { 62 | return ::TryEnterCriticalSection(&m_Mutex) != FALSE; 63 | } 64 | 65 | void Unlock() { 66 | assert(IsLocked()); 67 | ::LeaveCriticalSection(&m_Mutex); 68 | } 69 | 70 | bool IsLocked() const { 71 | if (IsNewBehavior()) // after win2k3 sp1 72 | return (m_Mutex.LockCount & 1) == 0; 73 | else 74 | return m_Mutex.LockCount >= 0; 75 | } 76 | 77 | private: 78 | // In Microsoft Windows Server 2003 Service Pack 1 and later versions of 79 | // Windows, the LockCount field is parsed as follows: 80 | // * The lowest bit shows the lock status. If this bit is 0, the critical 81 | // section is locked; if it is 1, the critical section is not locked. 82 | // * The next bit shows whether a thread has been woken for this lock. 83 | // If this bit is 0, then a thread has been woken for this lock; if it 84 | // is 1, no thread has been woken. 85 | // * The remaining bits are the ones-complement of the number of threads 86 | // waiting for the lock. 87 | static bool IsNewBehavior() { 88 | static bool result = DoIsNewBehavior(); 89 | return result; 90 | } 91 | static bool DoIsNewBehavior() { 92 | Mutex mutex; 93 | int old_lock_count = mutex.m_Mutex.LockCount; 94 | mutex.Lock(); 95 | int new_lock_count = mutex.m_Mutex.LockCount; 96 | return new_lock_count < old_lock_count; 97 | } 98 | private: 99 | Mutex(const Mutex& right); 100 | Mutex& operator = (const Mutex& right); 101 | CRITICAL_SECTION m_Mutex; 102 | friend class Cond; 103 | }; 104 | 105 | #elif defined __unix__ 106 | 107 | class Mutex { 108 | public: 109 | typedef ScopedLocker Locker; 110 | public: 111 | explicit Mutex(bool recursive = true) { 112 | int n; 113 | if (recursive) { 114 | pthread_mutexattr_t attr; 115 | n = pthread_mutexattr_init(&attr); 116 | if (n == 0) { 117 | n = pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE); 118 | if (n == 0) 119 | n = pthread_mutex_init(&m_Mutex, &attr); 120 | n = pthread_mutexattr_destroy(&attr); 121 | } 122 | } else { 123 | n = pthread_mutex_init(&m_Mutex, NULL); 124 | } 125 | CheckError("Mutex::Mutex", n); 126 | } 127 | ~Mutex() { 128 | pthread_mutex_destroy(&m_Mutex); 129 | } 130 | 131 | void Lock() { 132 | CheckError("Mutex::Lock", pthread_mutex_lock(&m_Mutex)); 133 | assert(IsLocked()); 134 | } 135 | 136 | bool TryLock() { 137 | int n = pthread_mutex_trylock(&m_Mutex); 138 | if (n == EBUSY) { 139 | return false; 140 | } else { 141 | CheckError("Mutex::Lock", n); 142 | return true; 143 | } 144 | } 145 | 146 | // by inspect internal data 147 | bool IsLocked() const { 148 | return m_Mutex.__data.__lock > 0; 149 | } 150 | 151 | void Unlock() { 152 | assert(IsLocked()); 153 | CheckError("Mutex::Unlock", pthread_mutex_unlock(&m_Mutex)); 154 | // NOTE: can't check unlocked here, maybe already locked by 155 | // other thread 156 | } 157 | 158 | private: 159 | static void CheckError(const char* context, int error) { 160 | if (error != 0) { 161 | std::string msg = context; 162 | msg += " error: "; 163 | msg += strerror(error); 164 | throw std::runtime_error(msg); 165 | } 166 | } 167 | private: 168 | Mutex(const Mutex& right); 169 | Mutex& operator = (const Mutex& right); 170 | private: 171 | pthread_mutex_t m_Mutex; 172 | friend class ConditionVariable; 173 | }; 174 | 175 | #endif 176 | 177 | typedef Mutex::Locker MutexLocker; 178 | 179 | /// null mutex for template mutex param placeholder 180 | /// NOTE: don't make this class uncopyable 181 | class NullMutex { 182 | public: 183 | typedef ScopedLocker Locker; 184 | public: 185 | NullMutex() : m_locked(false) { 186 | } 187 | 188 | void Lock() { 189 | m_locked = true; 190 | } 191 | 192 | bool TryLock() { 193 | m_locked = true; 194 | return true; 195 | } 196 | 197 | // by inspect internal data 198 | bool IsLocked() const { 199 | return m_locked; 200 | } 201 | 202 | void Unlock() { 203 | m_locked = false; 204 | } 205 | private: 206 | bool m_locked; 207 | }; 208 | 209 | #endif // SYSTEM_MUTEX_H_ 210 | 211 | -------------------------------------------------------------------------------- /system/mutex_test.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This is a copy from 5 | 6 | // trunk/src/common/system/concurrency/mutex_test.cpp 7 | // 8 | #include "gtest/gtest.h" 9 | #include "system/mutex.h" 10 | 11 | TEST(Mutex, Lock) { 12 | Mutex mutex; 13 | ASSERT_FALSE(mutex.IsLocked()); 14 | mutex.Lock(); 15 | ASSERT_TRUE(mutex.IsLocked()); 16 | mutex.Unlock(); 17 | ASSERT_FALSE(mutex.IsLocked()); 18 | } 19 | 20 | TEST(Mutex, Locker) { 21 | Mutex mutex; 22 | { 23 | ASSERT_FALSE(mutex.IsLocked()); 24 | MutexLocker locker(&mutex); 25 | ASSERT_TRUE(mutex.IsLocked()); 26 | } 27 | ASSERT_FALSE(mutex.IsLocked()); 28 | } 29 | 30 | TEST(Mutex, LockerWithException) { 31 | Mutex mutex; 32 | try { 33 | ASSERT_FALSE(mutex.IsLocked()); 34 | MutexLocker locker(&mutex); 35 | ASSERT_TRUE(mutex.IsLocked()) << "after locked constructed"; 36 | throw 0; 37 | } catch(...) { 38 | // ignore 39 | } 40 | ASSERT_FALSE(mutex.IsLocked()) << "after exception thrown"; 41 | } 42 | -------------------------------------------------------------------------------- /system/scoped_locker.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // 4 | // This implementation is a copy from 5 | 6 | // trunk/src/common/system/concurrency/scoped_locker.hpp 7 | 8 | #ifndef SYSTEM_SCOPED_LOCKER_H_ 9 | #define SYSTEM_SCOPED_LOCKER_H_ 10 | 11 | #include "base/common.h" 12 | 13 | template 14 | class ScopedLocker { 15 | public: 16 | explicit ScopedLocker(LockType* lock) : m_lock(lock) { 17 | m_lock->Lock(); 18 | } 19 | ~ScopedLocker() { 20 | m_lock->Unlock(); 21 | } 22 | private: 23 | LockType* m_lock; 24 | }; 25 | 26 | template 27 | class ScopedReaderLocker { 28 | public: 29 | explicit ScopedReaderLocker(LockType* lock) : m_lock(lock) { 30 | m_lock->ReaderLock(); 31 | } 32 | ~ScopedReaderLocker() { 33 | m_lock->ReaderUnlock(); 34 | } 35 | private: 36 | LockType* m_lock; 37 | }; 38 | 39 | template 40 | class ScopedWriterLocker { 41 | public: 42 | explicit ScopedWriterLocker(LockType* lock) : m_lock(*lock) { 43 | m_lock.WriterLock(); 44 | } 45 | ~ScopedWriterLocker() { 46 | m_lock.WriterUnlock(); 47 | } 48 | private: 49 | LockType& m_lock; 50 | }; 51 | 52 | #endif // SYSTEM_SCOPED_LOCKER_H_ 53 | --------------------------------------------------------------------------------