├── doc ├── requirements.txt ├── .gitignore ├── README.md └── index.rst ├── tests ├── matlab │ ├── data.mat │ ├── localizer_test.m │ ├── fm_loss_l2.m │ ├── batch_iter_test.m │ ├── fm_loss.m │ ├── fm_loss_test.m │ ├── sgd_test.m │ ├── lr_bcd_test.m │ └── lbfgs.m ├── cpp │ ├── main.cc │ ├── find_position_test.cc │ ├── compressed_row_block_test.cc │ ├── spmt_test.cc │ ├── test.mk │ ├── sgd_learner_test.cc │ ├── batch_reader_test.cc │ ├── bcd_learner_test.cc │ ├── data_store_test.cc │ ├── localizer_test.cc │ ├── spmv_perf.cc │ ├── kv_union_test.cc │ ├── lbfgs_twoloop_test.cc │ ├── fm_loss_test.cc │ ├── spmv_test.cc │ ├── spmm_test.cc │ ├── logit_loss_delta_test.cc │ ├── kv_match_test.cc │ ├── lbfgs_learner_test.cc │ ├── utils.h │ └── spmv_test.h ├── travis │ ├── run_test.sh │ └── setup.sh └── README.md ├── example ├── rcv1_lbfgs.conf ├── ctra_bcd.conf ├── rcv1_bcd.conf ├── rcv1_sgd.conf ├── README.md ├── criteo_sgd.conf ├── ctra_sgd.conf ├── criteo_lbfgs.conf └── ctra_lbfgs.conf ├── .gitmodules ├── src ├── updater.cc ├── store │ ├── store.cc │ └── store_local.h ├── tracker │ ├── dist_tracker.h │ ├── tracker.cc │ ├── local_tracker.h │ └── async_local_tracker.h ├── reporter │ ├── reporter.cc │ └── local_reporter.h ├── loss │ ├── loss.cc │ ├── fm_loss_delta.h │ ├── logit_loss.h │ └── bin_class_metric.h ├── learner.cc ├── common │ ├── parallel_sort.h │ ├── arg_parser.h │ ├── learner_utils.h │ ├── find_position.h │ ├── range.h │ ├── thread_pool.h │ ├── spmt.h │ ├── kv_union.h │ └── kv_match-inl.h ├── reader │ ├── crb_parser.h │ ├── match_file.h │ ├── reader.h │ ├── batch_reader.h │ ├── adfea_parser.h │ ├── batch_reader.cc │ ├── criteo_parser.h │ └── converter.h ├── main.cc ├── bcd │ ├── bcd_param.h │ ├── bcd_learner.h │ └── bcd_utils.h ├── sgd │ ├── sgd_utils.h │ ├── sgd_updater.h │ ├── sgd_learner.h │ ├── sgd_param.h │ └── sgd_updater.cc ├── data │ ├── data_store_impl.h │ ├── localizer.h │ ├── localizer.cc │ ├── shared_row_block_container.h │ └── compressed_row_block.h └── 
lbfgs │ ├── lbfgs_learner.h │ ├── lbfgs_utils.h │ ├── lbfgs_param.h │ └── lbfgs_twoloop.h ├── .gitignore ├── tools ├── copyright.sh └── download.sh ├── LICENSE ├── .travis.yml ├── include └── difacto │ ├── node_id.h │ ├── README.md │ ├── sarray.h │ ├── reporter.h │ ├── learner.h │ ├── updater.h │ ├── loss.h │ ├── store.h │ ├── tracker.h │ └── base.h ├── README.md └── Makefile /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | breathe 2 | -------------------------------------------------------------------------------- /doc/.gitignore: -------------------------------------------------------------------------------- 1 | _* 2 | html 3 | xml 4 | -------------------------------------------------------------------------------- /tests/matlab/data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmlc/difacto/HEAD/tests/matlab/data.mat -------------------------------------------------------------------------------- /tests/matlab/localizer_test.m: -------------------------------------------------------------------------------- 1 | load data 2 | 3 | %% 4 | 5 | [a,b,c] = find(x); 6 | sum(unique(b)) 7 | sum(unique(mod(unique(b), 1000))) 8 | -------------------------------------------------------------------------------- /example/rcv1_lbfgs.conf: -------------------------------------------------------------------------------- 1 | task = train 2 | data_in = data/rcv1_train.binary 3 | l2 = 1 4 | learner = lbfgs 5 | max_num_epochs = 5 6 | V_dim = 0 7 | V_l2 = .01 8 | -------------------------------------------------------------------------------- /example/ctra_bcd.conf: -------------------------------------------------------------------------------- 1 | task = train 2 | data_in = ../data/ctra_train 3 | # data_val = data/ctra_test 4 | l1 = 1 5 | lr = .2 6 | learner = bcd 7 | block_ratio = 1 8 | 
-------------------------------------------------------------------------------- /example/rcv1_bcd.conf: -------------------------------------------------------------------------------- 1 | task = train 2 | data_in = ../data/rcv1_train.binary 3 | l1 = 1 4 | lr = .5 5 | learner = bcd 6 | block_ratio = 1 7 | 8 | max_num_epochs = 10 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dmlc-core"] 2 | path = dmlc-core 3 | url = https://github.com/dmlc/dmlc-core 4 | [submodule "ps-lite"] 5 | path = ps-lite 6 | url = https://github.com/dmlc/ps-lite 7 | -------------------------------------------------------------------------------- /example/rcv1_sgd.conf: -------------------------------------------------------------------------------- 1 | # data 2 | task = train 3 | data_in = ../data/rcv1_train.binary 4 | l1 = 1 5 | lr = .1 6 | learner = sgd 7 | max_num_epochs = 10 8 | batch_size = 100 9 | 10 | # embedding term 11 | V_dim = 0 12 | -------------------------------------------------------------------------------- /tests/cpp/main.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "gtest/gtest.h" 5 | int main(int argc, char **argv) { 6 | ::testing::InitGoogleTest(&argc, argv); 7 | return RUN_ALL_TESTS(); 8 | } 9 | -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | # How to build the document 2 | 3 | Requirements: 4 | - Doxygen 5 | - Sphinx 6 | - Breathe 7 | 8 | 9 | First `doxygen`, then `make html`. The results will be at [_build/html/index.html](_build/html/index.html). 
10 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This folder contains example configurations, which can be passed via 4 | `argfile=`. For example, 5 | 6 | ```bash 7 | tools/download.sh criteo 8 | build/difacto argfile=example/criteo_lbfgs.conf 9 | ``` 10 | -------------------------------------------------------------------------------- /tests/matlab/fm_loss_l2.m: -------------------------------------------------------------------------------- 1 | function [objv, gw, gV] = fm_loss_l2(y, X, w, X2, V, l2_w, l2_V) 2 | 3 | [objv, gw, gV] = fm_loss(y, X, w, X2, V); 4 | gw = gw + l2_w * w; 5 | gV = gV + l2_V * V; 6 | objv = objv + .5 * l2_w * sum(w(:).^2) + .5 * l2_V * sum(V(:).^2); 7 | end 8 | -------------------------------------------------------------------------------- /tests/travis/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ${TASK} == "lint" ]; then 4 | make lint 5 | exit $? 6 | fi 7 | 8 | if [ ${TASK} == "cpp-test" ]; then 9 | make -j4 test CXX=g++-4.8 ADD_CFLAGS=-coverage 10 | cd build; ./difacto_tests 11 | exit $?
12 | fi 13 | -------------------------------------------------------------------------------- /src/updater.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/updater.h" 5 | #include "./sgd/sgd_updater.h" 6 | #include "./bcd/bcd_updater.h" 7 | namespace difacto { 8 | 9 | DMLC_REGISTER_PARAMETER(SGDUpdaterParam); 10 | 11 | } // namespace difacto 12 | -------------------------------------------------------------------------------- /example/criteo_sgd.conf: -------------------------------------------------------------------------------- 1 | # data 2 | data_in = data/criteo_kaggle/criteo_train.rec 3 | data_val = data/criteo_kaggle/criteo_val.rec 4 | data_format = rec 5 | 6 | # learner 7 | task = train 8 | learner = sgd 9 | max_num_epochs = 10 10 | batch_size = 10000 11 | 12 | # linear term 13 | l1 = 10 14 | l2 = 10 15 | 16 | # embedding term 17 | V_dim = 10 18 | V_threshold = 10 19 | V_l2 = 10 20 | -------------------------------------------------------------------------------- /src/store/store.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/store.h" 5 | #include "./store_local.h" 6 | namespace difacto { 7 | 8 | Store* Store::Create() { 9 | if (IsDistributed()) { 10 | LOG(FATAL) << "not implemented"; 11 | return nullptr; 12 | } else { 13 | return new StoreLocal(); 14 | } 15 | } 16 | 17 | } // namespace difacto 18 | -------------------------------------------------------------------------------- /example/ctra_sgd.conf: -------------------------------------------------------------------------------- 1 | # data 2 | task = train 3 | data_in = data/ctra/ctra_train.rec 4 | data_val = data/ctra/ctra_val.rec 5 | data_format = rec 6 | 7 | # learner 8 | learner = sgd 9 | max_num_epochs = 10 10 | batch_size = 10000 11 | lr = .1 12 | 13 | # linear 14 | l1 = 
10 15 | l2 = 10 16 | # tail_feature_filter = 4 17 | 18 | # embedding 19 | V_dim = 10 20 | V_l2 = 10 21 | V_threshold = 10 22 | -------------------------------------------------------------------------------- /src/tracker/dist_tracker.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_TRACKER_DIST_TRACKER_H_ 5 | #define DIFACTO_TRACKER_DIST_TRACKER_H_ 6 | namespace difacto { 7 | /** 8 | * \brief a tracker which runs over multiple machines 9 | */ 10 | class DistTracker : public Tracker { 11 | }; 12 | } // namespace difacto 13 | #endif // DIFACTO_TRACKER_DIST_TRACKER_H_ 14 | -------------------------------------------------------------------------------- /example/criteo_lbfgs.conf: -------------------------------------------------------------------------------- 1 | task = train 2 | 3 | # algo 4 | learner = lbfgs 5 | max_num_epochs = 500 6 | m = 10 7 | 8 | # data 9 | data_in = data/criteo_kaggle/criteo_train.rec 10 | data_val = data/criteo_kaggle/criteo_val.rec 11 | data_format = rec 12 | 13 | # linear term 14 | tail_feature_filter = 4 15 | l2 = 100 16 | 17 | # embedding term 18 | V_dim = 10 19 | V_threshold = 10 20 | V_l2 = 10 21 | -------------------------------------------------------------------------------- /src/tracker/tracker.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/tracker.h" 5 | #include "./local_tracker.h" 6 | namespace difacto { 7 | 8 | Tracker* Tracker::Create() { 9 | if (IsDistributed()) { 10 | LOG(FATAL) << "not implemented"; 11 | return nullptr; 12 | } else { 13 | return new LocalTracker(); 14 | } 15 | } 16 | 17 | } // namespace difacto 18 | -------------------------------------------------------------------------------- /src/reporter/reporter.cc: -------------------------------------------------------------------------------- 1
| /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/reporter.h" 5 | #include "./local_reporter.h" 6 | namespace difacto { 7 | 8 | Reporter* Reporter::Create() { 9 | if (IsDistributed()) { 10 | LOG(FATAL) << "not implemented"; 11 | return nullptr; 12 | } else { 13 | return new LocalReporter(); 14 | } 15 | } 16 | 17 | } // namespace difacto 18 | -------------------------------------------------------------------------------- /example/ctra_lbfgs.conf: -------------------------------------------------------------------------------- 1 | # data 2 | task = train 3 | data_in = data/ctra/ctra_train.rec 4 | data_val = data/ctra/ctra_val.rec 5 | data_format = rec 6 | 7 | # learner 8 | learner = lbfgs 9 | m = 10 10 | max_num_epochs = 100 11 | max_num_linesearchs = 20 12 | stop_val_auc = 1e-5 13 | 14 | # linear 15 | l2 = 500 16 | tail_feature_filter = 4 17 | 18 | # embedding 19 | V_dim = 40 20 | V_l2 = 100 21 | V_threshold = 40 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | 19 | # Compiled Static libraries 20 | *.lai 21 | *.la 22 | *.a 23 | *.lib 24 | 25 | # Executables 26 | *.exe 27 | *.out 28 | *.app 29 | 30 | deps/* 31 | build/* 32 | data/* 33 | *.gc* 34 | -------------------------------------------------------------------------------- /tests/matlab/batch_iter_test.m: -------------------------------------------------------------------------------- 1 | load data 2 | %% 3 | batch = 37 4 | ix = [1 : 37 : 100, 101] 5 | 6 | re = []; 7 | for i = 1 : length(ix) - 1 8 | j = ix(i) : ix(i+1)-1; 9 | v = x(j,:); 10 | [a,b,c] = find(v'); 11 | c(1:3) 12 | w = full(sparse(a,b,1)); 13 | re = [re; 
[length(j), sum(y(j)), sum(cumsum(sum(w))), sum(abs(a)), sum(abs(b-1)), sum(c.*c)]]; 14 | end 15 | 16 | for i = 1 : size(re, 2) 17 | re(:,i)' 18 | end 19 | -------------------------------------------------------------------------------- /tools/copyright.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # add copyright 3 | dir=`pwd`/`dirname $0`/..; cd $dir 4 | copyright=/tmp/copyright 5 | 6 | cat >$copyright <$file.new && mv $file.new $file 15 | fi 16 | done 17 | -------------------------------------------------------------------------------- /tests/matlab/fm_loss.m: -------------------------------------------------------------------------------- 1 | function [objv, gw, gV] = fm_loss(y, X, w, X2, V) 2 | % X, w for the linear term 3 | % X2, V for the embedding term 4 | 5 | py = X * w; 6 | if ~isempty(X2) 7 | py = py + .5 * sum((X2*V).^2 - (X2.*X2)*(V.*V), 2); 8 | end 9 | 10 | objv = sum(log(1+exp(-y .* py))); 11 | p = - y ./ (1 + exp (y .* py)); 12 | gw = X' * p; 13 | 14 | if ~isempty(X2) 15 | gV = X2' * bsxfun(@times, p, X2*V) - bsxfun(@times, (X2.*X2)'*p, V); 16 | else 17 | gV = []; 18 | end 19 | 20 | 21 | end 22 | -------------------------------------------------------------------------------- /tests/matlab/fm_loss_test.m: -------------------------------------------------------------------------------- 1 | load data 2 | 3 | %% logit 4 | 5 | w = ((1:size(x,2))/5e4); 6 | % w = ones(1,size(x,2)); 7 | 8 | sum(log(1 + exp ( - y .* (x * w')))) 9 | 10 | g = full(x' * (-y ./ (1 + exp( y .* (x * w'))))); 11 | sum(g.*g) 12 | 13 | tau = 1 ./ (1 + exp( y .* (x * w'))); 14 | h = full((x.*x)' * (tau .* (1-tau))); 15 | sum(h.*h) 16 | 17 | 18 | %% fm 19 | 20 | V_dim = 5; 21 | 22 | w = (1:size(x,2))'/5e4; 23 | V = w * (1:V_dim) / 10; 24 | 25 | [objv, gw, gV] = fm_loss(y, x, w, x, V); 26 | 27 | objv 28 | sum(gw.^2) + sum(gV(:).^2) 29 | -------------------------------------------------------------------------------- 
/doc/index.rst: -------------------------------------------------------------------------------- 1 | .. difacto documentation master file, created by 2 | sphinx-quickstart on Mon Dec 14 18:03:09 2015. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to difacto's documentation! 7 | =================================== 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | .. doxygenstruct:: difacto::DiFactoParam 15 | :project: difacto 16 | :members: 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 by Contributors 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /tests/travis/setup.sh: -------------------------------------------------------------------------------- 1 | if [ ${TASK} == "lint" ]; then 2 | pip install cpplint pylint --user `whoami` 3 | fi 4 | 5 | # setup cache prefix 6 | export CACHE_PREFIX=${HOME}/.cache/usr 7 | export CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:${CACHE_PREFIX}/include 8 | export C_INCLUDE_PATH=${C_INCLUDE_PATH}:${CACHE_PREFIX}/include 9 | export LIBRARY_PATH=${LIBRARY_PATH}:${CACHE_PREFIX}/lib 10 | export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CACHE_PREFIX}/lib 11 | export DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}:${CACHE_PREFIX}/lib 12 | 13 | if [ ${TASK} == "cpp-test" ]; then 14 | make -f dmlc-core/scripts/packages.mk gtest 15 | fi 16 | -------------------------------------------------------------------------------- /tests/cpp/find_position_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include 6 | #include "./utils.h" 7 | #include "common/find_position.h" 8 | 9 | using namespace difacto; 10 | 11 | TEST(FindPosition, Basic) { 12 | SArray a = {3, 5, 7}; 13 | SArray b = {1, 3, 4, 7, 8}; 14 | SArray pos, pos2 = {-1, 0, -1, 2, -1}; 15 | FindPosition(a, b, &pos); 16 | for (size_t i = 0; i < pos2.size(); ++i) EXPECT_EQ(pos2[i], pos[i]); 17 | 18 | SArray pos3, pos4 = {1, -1, 3}; 19 | FindPosition(b, a, &pos3); 20 | for (size_t i = 0; i < pos4.size(); ++i) EXPECT_EQ(pos4[i], pos3[i]); 21 | } 22 | -------------------------------------------------------------------------------- /tests/cpp/compressed_row_block_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "data/compressed_row_block.h" 6 | #include "reader/batch_reader.h" 7 | #include "./utils.h" 8 | 9 | using namespace difacto; 10 | 
11 | TEST(CompressedRowBlock, Basic) { 12 | BatchReader reader("../tests/data", "libsvm", 0, 1, 100); 13 | CHECK(reader.Next()); 14 | auto A = reader.Value(); 15 | 16 | std::string out; 17 | CompressedRowBlock crb; 18 | crb.Compress(A, &out); 19 | 20 | dmlc::data::RowBlockContainer container; 21 | crb.Decompress(out, &container); 22 | auto B = container.GetBlock(); 23 | 24 | check_equal(A, B); 25 | } 26 | -------------------------------------------------------------------------------- /src/loss/loss.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/loss.h" 5 | #include "./fm_loss.h" 6 | #include "./logit_loss_delta.h" 7 | #include "./logit_loss.h" 8 | namespace difacto { 9 | 10 | DMLC_REGISTER_PARAMETER(FMLossParam); 11 | DMLC_REGISTER_PARAMETER(LogitLossDeltaParam); 12 | 13 | Loss* Loss::Create(const std::string& type, int nthreads) { 14 | Loss* loss = nullptr; 15 | if (type == "fm") { 16 | loss = new FMLoss(); 17 | } else if (type == "logit") { 18 | loss = new LogitLoss(); 19 | } else if (type == "logit_delta") { 20 | loss = new LogitLossDelta(); 21 | } else { 22 | LOG(FATAL) << "unknown loss type"; 23 | } 24 | loss->set_nthreads(nthreads); 25 | return loss; 26 | } 27 | 28 | } // namespace difacto 29 | -------------------------------------------------------------------------------- /src/reporter/local_reporter.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #ifndef DIFACTO_REPORTER_LOCAL_REPORTER_H_ 6 | #define DIFACTO_REPORTER_LOCAL_REPORTER_H_ 7 | namespace difacto { 8 | 9 | class LocalReporter : public Reporter { 10 | public: 11 | LocalReporter() { } 12 | virtual ~LocalReporter() { } 13 | 14 | KWArgs Init(const KWArgs& kwargs) override { return kwargs; } 15 | 16 | void SetMonitor(const Monitor& monitor) override { 17 | monitor_ = monitor; 18 
| } 19 | 20 | int Report(const std::string& report) { 21 | monitor_(-1, report); return 0; 22 | } 23 | 24 | void Wait(int timestamp) { } 25 | 26 | private: 27 | Monitor monitor_; 28 | }; 29 | } // namespace difacto 30 | #endif // DIFACTO_REPORTER_LOCAL_REPORTER_H_ 31 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | 3 | language: cpp 4 | 5 | os: 6 | - linux 7 | 8 | env: 9 | - TASK=lint 10 | # - TASK=doc 11 | - TASK=cpp-test 12 | 13 | # dependent apt packages 14 | addons: 15 | apt: 16 | sources: 17 | - ubuntu-toolchain-r-test 18 | packages: 19 | - doxygen 20 | - git 21 | - gcc-4.8 22 | - g++-4.8 23 | 24 | before_install: 25 | - pip install --user codecov 26 | 27 | after_success: 28 | - bash <(curl -s https://codecov.io/bash) 29 | 30 | install: 31 | - source tests/travis/setup.sh 32 | 33 | script: 34 | - tests/travis/run_test.sh 35 | 36 | cache: 37 | directories: 38 | - ${HOME}/.cache/usr 39 | 40 | before_cache: 41 | - dmlc-core/scripts/travis/travis_before_cache.sh 42 | 43 | notifications: 44 | # Emails are sent to the committer's git-configured email address by default, 45 | email: 46 | on_success: change 47 | on_failure: always 48 | -------------------------------------------------------------------------------- /tests/cpp/spmt_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "./utils.h" 6 | #include "difacto/base.h" 7 | #include "common/spmt.h" 8 | #include "data/localizer.h" 9 | 10 | using namespace difacto; 11 | TEST(SpMT, Transpose) { 12 | dmlc::data::RowBlockContainer data; 13 | std::vector uidx; 14 | load_data(&data, &uidx); 15 | 16 | auto X = data.GetBlock(); 17 | dmlc::data::RowBlockContainer Y, X2; 18 | SpMT::Transpose(X, &Y, uidx.size()); 19 | SpMT::Transpose(Y.GetBlock(), &X2); 20 | 21 | 
21 | auto X3 = X2.GetBlock(); 22 | size_t nnz = X3.offset[X3.size]; 23 | EXPECT_EQ(X3.size, X.size); 24 | EXPECT_EQ(norm1(X.offset, X.size+1), 25 | norm1(X3.offset, X.size+1)); 26 | EXPECT_EQ(norm1(X.index, nnz), 27 | norm1(X3.index, nnz)); 28 | EXPECT_EQ(norm2(X.value, nnz), 29 | norm2(X3.value, nnz)); 30 | } 31 | -------------------------------------------------------------------------------- /tests/cpp/test.mk: -------------------------------------------------------------------------------- 1 | GTEST_PATH = /usr 2 | 3 | CPPTEST_SRC = $(wildcard tests/cpp/*_test.cc) 4 | CPPTEST_OBJ = $(patsubst tests/cpp/%_test.cc, build/tests/%_test.o, $(CPPTEST_SRC)) 5 | 6 | build/tests/%.o : tests/cpp/%.cc ${DEPS} 7 | @mkdir -p $(@D) 8 | $(CXX) $(INCPATH) -std=c++0x -MM -MT build/tests/$*.o $< >build/tests/$*.d 9 | $(CXX) $(CFLAGS) -c $< -o $@ 10 | 11 | build/difacto_tests: $(CPPTEST_OBJ) build/tests/main.o build/libdifacto.a $(DMLC_DEPS) 12 | $(CXX) $(CFLAGS) -I$(GTEST_PATH)/include -o $@ $^ $(LDFLAGS) -L$(GTEST_PATH)/lib -lgtest 13 | 14 | CPPPERF_SRC = $(wildcard tests/cpp/*_perf.cc) 15 | CPPPERF = $(patsubst tests/cpp/%_perf.cc, build/%_perf, $(CPPPERF_SRC)) 16 | 17 | 18 | build/%_perf : tests/cpp/%_perf.cc build/libdifacto.a $(DMLC_DEPS) ${DEPS} 19 | $(CXX) -std=c++0x $(CFLAGS) -MM -MT $@ $< >$@.d 20 | $(CXX) -std=c++0x $(CFLAGS) -I$(GTEST_PATH)/include -o $@ $(filter %.cc %.a, $^) $(LDFLAGS) 21 | 22 | cpp-perf: $(CPPPERF) 23 | -------------------------------------------------------------------------------- /include/difacto/node_id.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_NODE_ID_H_ 5 | #define DIFACTO_NODE_ID_H_ 6 | namespace difacto { 7 | 8 | class NodeID { 9 | public: 10 | /** \brief node ID for the scheduler */ 11 | static const int kScheduler = 1; 12 | /** 13 | * \brief the server node group ID 14 | * 15 | * group id can be combined: 16 | * - kServerGroup
+ kScheduler means all server nodes and the scheduler 17 | * - kServerGroup + kWorkerGroup means all server and worker nodes 18 | */ 19 | static const int kServerGroup = 2; 20 | /** \brief the worker node group ID */ 21 | static const int kWorkerGroup = 4; 22 | 23 | static int Encode(int group, int rank) { 24 | return group + (rank+1) * 8; 25 | } 26 | 27 | /** \brief return the node group id */ 28 | static int GetGroup(int id) { 29 | return (id % 8); 30 | } 31 | }; 32 | } // namespace difacto 33 | 34 | #endif // DIFACTO_NODE_ID_H_ 35 | -------------------------------------------------------------------------------- /src/learner.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/learner.h" 5 | #include "./sgd/sgd_param.h" 6 | #include "./sgd/sgd_learner.h" 7 | #include "./bcd/bcd_param.h" 8 | #include "./bcd/bcd_learner.h" 9 | #include "./lbfgs/lbfgs_learner.h" 10 | namespace difacto { 11 | 12 | DMLC_REGISTER_PARAMETER(SGDLearnerParam); 13 | DMLC_REGISTER_PARAMETER(BCDLearnerParam); 14 | 15 | Learner* Learner::Create(const std::string& type) { 16 | if (type == "sgd") { 17 | return new SGDLearner(); 18 | } else if (type == "bcd") { 19 | return new BCDLearner(); 20 | } else if (type == "lbfgs") { 21 | return new LBFGSLearner(); 22 | } else { 23 | LOG(FATAL) << "unknown learner type: " << type; 24 | } 25 | return nullptr; 26 | } 27 | 28 | KWArgs Learner::Init(const KWArgs& kwargs) { 29 | // init job tracker 30 | tracker_ = Tracker::Create(); 31 | auto remain = tracker_->Init(kwargs); 32 | using namespace std::placeholders; 33 | tracker_->SetExecutor(std::bind(&Learner::Process, this, _1, _2)); 34 | return remain; 35 | } 36 | 37 | } // namespace difacto 38 | -------------------------------------------------------------------------------- /tests/matlab/sgd_test.m: -------------------------------------------------------------------------------- 1 | %% 2 | load
data 3 | [i,j,k]=find(x); 4 | cnt = full(sum(sparse(i,j,ones(size(i))))); 5 | X = x(:,cnt>0); 6 | Y = y; 7 | X2 = x(:,cnt>0); 8 | 9 | %% 10 | lr = 1; 11 | lr_V = .8; 12 | l1 = 1; 13 | l2 = .1; 14 | V_l2 = .1; 15 | 16 | V_dim = 0; 17 | if V_dim == 0 18 | X2 = []; 19 | end 20 | 21 | [n, p] = size(X); 22 | w = zeros(p,1); 23 | V = repmat(((1:V_dim) - V_dim/2)*.01, size(X2,2), 1); 24 | 25 | sq_w = zeros(size(w)); 26 | sq_V = zeros(size(V)); 27 | z = zeros(size(w)); 28 | 29 | for k = 1 : 20 30 | [objv, gw, gV] = fm_loss_l2(Y, X, w, X2, V, lr, lr_V); 31 | objv = objv + l1 * sum(abs(w)); 32 | objv = fm_loss(Y, X, w, X2, V); 33 | sq_w_new = sqrt(sq_w.*sq_w + gw.*gw); 34 | sq_V = sqrt(sq_V.*sq_V + gV.*gV); 35 | 36 | z = z - gw - (sq_w - sq_w_new) ./ lr .* w; 37 | sq_w = sq_w_new; 38 | 39 | ix = (z <= l1) & (z >= -l1); 40 | w(ix) = 0; 41 | eta = (1 + sq_w) / lr; 42 | 43 | ix1 = (~ ix) & (z > 0); 44 | w(ix1) = (z(ix1) - l1) ./ eta(ix1); 45 | 46 | ix2 = (~ ix) & (z < 0); 47 | w(ix2) = (z(ix2) + l1) ./ eta(ix2); 48 | 49 | V = V - lr_V * gV./(sq_V+1); 50 | 51 | fprintf('objv = %f, nnz_w = %f\n', objv, nnz(w)) 52 | end 53 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # test codes 2 | 3 | ## test data 4 | 5 | [data](data) contains the first 100 lines of the 6 | [rcv1](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#rcv1.binary) 7 | dataset in libsvm format 8 | 9 | ## c++ tests 10 | [cpp/](cpp/) contains c++ test codes 11 | 12 | - `gtest` is needed for compiling. One can install `libgtest-dev` using a package 13 | manager, e.g. for ubuntu 14 | 15 | ```bash 16 | sudo apt-get install libgtest-dev 17 | cd /usr/src/gtest 18 | sudo cmake . 
19 | sudo make 20 | sudo mv libg* /usr/lib/ 21 | ``` 22 | 23 | - compile by `make test` on the project root directory 24 | 25 | - run all tests by 26 | ```bash 27 | cd build; ./difacto_tests 28 | ``` 29 | 30 | Use `./difacto_tests --gtest_list_tests` to list all tests and 31 | `./difacto_tests --gtest_filter=PATTERN` to run some particular tests 32 | 33 | ### coverage 34 | 35 | ```bash 36 | make clean; make -j8 test ADD_CFLAGS=-coverage; cd build; ./difacto_tests; codecov 37 | ``` 38 | 39 | ### disable feature id reversing 40 | ```bash 41 | make clean; make -j8 NO_REVERSE_ID=1 42 | ``` 43 | ## matlab tests 44 | 45 | Some scripts used to generate the *ground truth*. 46 | -------------------------------------------------------------------------------- /include/difacto/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the abstract classes of difacto. 2 | 3 | 1. [Learner](learner.h) the base class of the learning algorithm. The system 4 | starts by calling `Learner.Run()` 5 | 6 | 2. [Loss](loss.h) the base class of a loss function, such as the logistic loss, 7 | which is able to evaluate the objective value and calculate the gradients based 8 | on the weights and data. 9 | 10 | 3. [Updater](updater.h) the base class of an updater, which maintains the 11 | weights and allows one to get (`Get()`) the weights and update (`Update()`) the 12 | weights based on the inputs (often the gradients) 13 | 14 | 4. [Store](store.h) the data communication interface for the Updater. So a remote 15 | node (such as a worker) can get (`Pull()`) and update (`Push()`) the weights. 16 | 17 | 5. [Tracker](tracker.h) the control interface, which is able to 18 | issue remote procedure calls (RPCs) from the scheduler to any workers and 19 | servers, and monitors the progress. 20 | 21 | The following figure shows the scheduler sending an RPC to a worker, which gets and 22 | updates the weights on the servers.
23 | 24 | 25 | -------------------------------------------------------------------------------- /tests/matlab/lr_bcd_test.m: -------------------------------------------------------------------------------- 1 | load data 2 | a = sum(abs(x)); 3 | X = x(:,a~=0); 4 | 5 | %% 6 | load rcv1 7 | y = Y; 8 | %% 9 | 10 | l1 = .1; 11 | % l2 = .01; 12 | lr = .5; 13 | nblk = 100; 14 | 15 | [n,p] = size(X); 16 | w = zeros(p,1); 17 | delta = ones(p,1); 18 | 19 | blks = round(linspace(1,p+1,nblk+1)) 20 | 21 | for i = 1 : 1000 22 | objv = sum(log(1 + exp ( - y .* (X * w)))); 23 | fprintf('iter %d, objv %f, nnz w %d\n', i, objv, nnz(w)); 24 | rdp = randperm(nblk); 25 | for b = 1 : nblk 26 | blk = false(p,1); 27 | blk(blks(rdp(b)) : blks(rdp(b)+1)-1) = true; 28 | 29 | tau = 1 ./ (1 + exp(y .* (X * w))); 30 | g = full(X(:,blk)' * (-y .* tau)); 31 | h = full((X(:,blk).^2)' * (tau .* (1-tau))) / lr + 1e-6; 32 | 33 | % soft-threshold 34 | d = -w(blk); 35 | gp = g + l1; 36 | ix = gp <= h .* w(blk); 37 | d(ix) = - gp(ix) ./ h(ix); 38 | gn = g - l1; 39 | ix = gn >= h .* w(blk); 40 | d(ix) = - gn(ix) ./ h(ix); 41 | 42 | d = max(min(d, delta(blk)), -delta(blk)); 43 | delta(blk) = 2*abs(d) + .1; 44 | w(blk) = w(blk) + d; 45 | % fprintf('%f %f %f %f\n', norm(g)^2, norm(h*lr)^2, norm(w)^2, norm(delta)^2); 46 | end 47 | 48 | delta = max(min(delta, 5), -5); 49 | end 50 | -------------------------------------------------------------------------------- /include/difacto/sarray.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_SARRAY_H_ 5 | #define DIFACTO_SARRAY_H_ 6 | #include "ps/sarray.h" 7 | namespace difacto { 8 | 9 | /** 10 | * \brief Shared array 11 | * 12 | * A smart array that retains shared ownership. It provides functionality 13 | * similar to std::vector, including data(), size(), 14 | * operator[], resize(), clear(). SArray can be easily constructed from 15 | * std::vector, such as 16 | * 17 | * \code 18 | * std::vector a(10); SArray b(a); // copying 19 | * std::shared_ptr> c(new std::vector(10)); 20 | * SArray d(c); // only pointer copying 21 | * \endcode 22 | * 23 | * SArray also behaves like a C pointer when copying and assigning, namely 24 | * both copying and assignment only pass the pointer. The memory is released only 25 | * when no copy remains. It can also be cast without a memory copy, such as 26 | * 27 | * \code 28 | * SArray a(10); 29 | * SArray b(a); // now b.size() = 10 * sizeof(int); 30 | * \endcode 31 | * 32 | * \tparam T the value type 33 | */ 34 | template 35 | using SArray = ps::SArray; 36 | 37 | } // namespace difacto 38 | #endif // DIFACTO_SARRAY_H_ 39 | --------------------------------------------------------------------------------
SArray can be easily constructed from 15 | * std::vector, such as 16 | * 17 | * \code 18 | * std::vector a(10); SArray b(a); // copying 19 | * std::shared_ptr> c(new std::vector(10)); 20 | * SArray d(c); // only pointer copying 21 | * \endcode 22 | * 23 | * SArray is also like a C pointer when copying and assigning, namely 24 | * both copy are assign are passing by pointers. The memory will be release only 25 | * if there is no copy exists. It is also can be cast without memory copy, such as 26 | * 27 | * \code 28 | * SArray a(10); 29 | * SArray b(a); // now b.size() = 10 * sizeof(int); 30 | * \endcode 31 | * 32 | * \tparam T the value type 33 | */ 34 | template 35 | using SArray = ps::SArray; 36 | 37 | } // namespace difacto 38 | #endif // DIFACTO_SARRAY_H_ 39 | -------------------------------------------------------------------------------- /tests/cpp/sgd_learner_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "sgd/sgd_learner.h" 6 | 7 | using namespace difacto; 8 | 9 | TEST(SGDLearner, Basic) { 10 | std::vector objv = { 11 | 69.314718, 12 | 69.314718, 13 | 67.151912, 14 | 61.414778, 15 | 56.244989, 16 | 53.218700, 17 | 51.248737, 18 | 49.846688, 19 | 48.650164, 20 | 47.698351, 21 | 46.924038, 22 | 46.388223, 23 | 45.970721, 24 | 45.499307, 25 | 45.102245, 26 | 44.798413, 27 | 44.565211, 28 | 44.386417, 29 | 44.240657, 30 | 44.109764}; 31 | SGDLearner learner; 32 | KWArgs args = {{"data_in", "../tests/data"}, 33 | {"V_dim", "0"}, 34 | {"l2", "1"}, 35 | {"l1", "1"}, 36 | {"lr", "1"}, 37 | {"num_jobs_per_epoch", "1"}, 38 | {"batch_size", "100"}, 39 | {"max_num_epochs", "20"}}; 40 | auto remain = learner.Init(args); 41 | EXPECT_EQ(remain.size(), 0); 42 | 43 | auto callback = [objv]( 44 | int epoch, const sgd::Progress& train, const sgd::Progress& val) { 45 | EXPECT_LT(fabs(objv[epoch] - train.loss), 5e-5); 46 | }; 47 | 
learner.AddEpochEndCallback(callback); 48 | learner.Run(); 49 | } 50 | -------------------------------------------------------------------------------- /tools/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ $# -ne 1 ]; then 3 | echo "usage: $0 dataset_name" 4 | echo " dataset_name can be rcv1, criteo, ctra, ..." 5 | echo "sample: $0 ctra" 6 | exit 0 7 | fi 8 | 9 | mkdir -p data && cd data 10 | 11 | # download from http://data.dmlc.ml/difacto/datasets/ 12 | dmlc_download() { 13 | url=http://data.dmlc.ml/difacto/datasets/ 14 | file=$1 15 | dir=`dirname $file` 16 | if [ ! -e $file ]; then 17 | wget ${url}/${file} -P ${dir} 18 | fi 19 | } 20 | 21 | # download from https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/ 22 | libsvm_download() { 23 | url=https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/ 24 | file=$1 25 | if [ ! -e $file ]; then 26 | if [ ! -e ${file}.bz2 ]; then 27 | wget ${url}/${file}.bz2 28 | fi 29 | bzip2 -d ${file}.bz2 30 | fi 31 | } 32 | if [ $1 == "ctra" ]; then 33 | dmlc_download ctra/ctra_train.rec 34 | dmlc_download ctra/ctra_val.rec 35 | elif [ $1 == "criteo" ]; then 36 | dmlc_download criteo_kaggle/criteo_train.rec 37 | dmlc_download criteo_kaggle/criteo_val.rec 38 | elif [ $1 == "gisette" ]; then 39 | libsvm_download gisette_scale 40 | libsvm_download gisette_scale.t 41 | elif [ $1 == "rcv1" ]; then 42 | libsvm_download rcv1_train.binary 43 | else 44 | echo "unknown dataset name : $1" 45 | fi 46 | -------------------------------------------------------------------------------- /src/common/parallel_sort.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_PARALLEL_SORT_H_ 5 | #define DIFACTO_COMMON_PARALLEL_SORT_H_ 6 | #include 7 | #include 8 | #include 9 | namespace difacto { 10 | namespace { 11 | /** 12 | * \brief the thread function 13 | */ 14 | 
template 15 | void ParallelSort_(T* data, size_t len, size_t grainsize, const Fn& cmp) { 16 | if (len <= grainsize) { 17 | std::sort(data, data + len, cmp); 18 | } else { 19 | std::thread thr(ParallelSort_, data, len/2, grainsize, cmp); 20 | ParallelSort_(data + len/2, len - len/2, grainsize, cmp); 21 | thr.join(); 22 | std::inplace_merge(data, data + len/2, data + len, cmp); 23 | } 24 | } 25 | } // namespace 26 | 27 | /** 28 | * @brief Parallel Sort 29 | * 30 | * @param arr the array for sorting 31 | * @param num_threads 32 | * @param cmp the comparision function, such as [](const T& a, const T& b) { 33 | * return a < b; } or an even simplier version: std::less() 34 | */ 35 | template 36 | void ParallelSort(std::vector* arr, int num_threads, const Fn& cmp) { 37 | size_t grainsize = std::max(arr->size() / num_threads + 5, (size_t)1024*16); 38 | ParallelSort_(arr->data(), arr->size(), grainsize, cmp); 39 | } 40 | 41 | } // namespace difacto 42 | #endif // DIFACTO_COMMON_PARALLEL_SORT_H_ 43 | -------------------------------------------------------------------------------- /src/reader/crb_parser.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * @file crb_parser.h 4 | * @brief parser for compressed row block data format 5 | */ 6 | #ifndef DIFACTO_READER_CRB_PARSER_H_ 7 | #define DIFACTO_READER_CRB_PARSER_H_ 8 | #include 9 | #include "data/parser.h" 10 | #include "dmlc/recordio.h" 11 | #include "data/compressed_row_block.h" 12 | namespace difacto { 13 | /** 14 | * \brief compressed row block parser 15 | */ 16 | class CRBParser : public dmlc::data::ParserImpl { 17 | public: 18 | explicit CRBParser(dmlc::InputSplit *source) 19 | : bytes_read_(0), source_(source) { 20 | } 21 | virtual ~CRBParser() { 22 | delete source_; 23 | } 24 | void BeforeFirst(void) override { 25 | source_->BeforeFirst(); 26 | } 27 | size_t BytesRead(void) const override { 28 | return bytes_read_; 29 | } 30 | bool 
ParseNext( 31 | std::vector > *data) override { 32 | dmlc::InputSplit::Blob rec; 33 | if (!source_->NextRecord(&rec)) return false; 34 | CHECK_NE(rec.size, 0); 35 | bytes_read_ += rec.size; 36 | data->resize(1); (*data)[0].Clear(); 37 | CompressedRowBlock crb; 38 | crb.Decompress((char const*)rec.dptr, rec.size, &(*data)[0]); 39 | return true; 40 | } 41 | 42 | private: 43 | // number of bytes readed 44 | size_t bytes_read_; 45 | // source split that provides the data 46 | dmlc::InputSplit *source_; 47 | }; 48 | } // namespace difacto 49 | #endif // DIFACTO_READER_CRB_PARSER_H_ 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *Distributed Factorization Machines* 2 | 3 | [![Build Status](https://travis-ci.org/dmlc/difacto.svg?branch=master)](https://travis-ci.org/dmlc/difacto) 4 | [![codecov.io](https://codecov.io/github/dmlc/difacto/coverage.svg?branch=master)](https://codecov.io/github/dmlc/difacto?branch=master) 5 | [![Documentation Status](https://readthedocs.org/projects/difacto/badge/?version=latest)](http://difacto.readthedocs.org/en/latest/?badge=latest) 6 | [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE) 7 | 8 | Fast and memory efficient library for factorization machines (FM). 9 | 10 | - Supports both ℓ1 regularized logistic regression and factorization 11 | machines. 12 | - Runs on local machine and distributed clusters. 13 | - Scales to datasets with billions examples and features. 14 | 15 | ### Quick Start 16 | 17 | The following commands clone and build difacto, then download a sample dataset, 18 | and train FM with 2-dimension on it. 
19 | 20 | ```bash 21 | git clone --recursive https://github.com/dmlc/difacto 22 | cd difacto; git submodule update --init; make -j8 23 | ./tools/download.sh gisette 24 | build/difacto data_in=data/gisette_scale val_data=data/gisette_scale.t lr=.02 V_dim=2 V_lr=.001 25 | ``` 26 | 27 | ### History 28 | 29 | Origins from 30 | [wormhole/learn/difacto](https://github.com/dmlc/wormhole/tree/master/learn/difacto). 31 | 32 | (NOTE: this project is still under developing) 33 | 34 | ### References 35 | 36 | Mu Li, Ziqi Liu, Alex Smola, and Yu-Xiang Wang. 37 | DiFacto — Distributed Factorization Machines. In WSDM, 2016 38 | -------------------------------------------------------------------------------- /src/common/arg_parser.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_ARG_PARSER_H_ 5 | #define DIFACTO_COMMON_ARG_PARSER_H_ 6 | #include 7 | #include 8 | #include "dmlc/io.h" 9 | #include "dmlc/config.h" 10 | #include "difacto/base.h" 11 | namespace difacto { 12 | class ArgParser { 13 | public: 14 | ArgParser() { } 15 | ~ArgParser() { } 16 | 17 | /** 18 | * \brief add an arg 19 | */ 20 | void AddArg(const char* argv) { data_.append(argv); data_.append(" "); } 21 | 22 | /** 23 | * \brief return parsed kwargs 24 | */ 25 | KWArgs GetKWArgs() { 26 | std::stringstream ss(data_); 27 | dmlc::Config* conf = new dmlc::Config(ss); 28 | 29 | for (auto it : *conf) { 30 | if (it.first == "argfile") { 31 | AddArgFile(it.second.c_str()); 32 | delete conf; 33 | std::stringstream ss(data_); 34 | conf = new dmlc::Config(ss); 35 | break; 36 | } 37 | } 38 | KWArgs kwargs; 39 | for (auto it : *conf) { 40 | if (it.first == "argfile") continue; 41 | kwargs.push_back(it); 42 | } 43 | delete conf; 44 | return kwargs; 45 | } 46 | 47 | private: 48 | /** 49 | * \brief read all args in a file 50 | */ 51 | void AddArgFile(const char* const filename) { 52 | dmlc::Stream *fs = 
dmlc::Stream::Create(filename, "r"); 53 | CHECK(fs != nullptr) << "failed to open " << filename; 54 | char buf[1000]; 55 | while (true) { 56 | size_t r = fs->Read(buf, 1000); 57 | data_.append(buf, r); 58 | if (!r) break; 59 | } 60 | } 61 | std::string data_; 62 | }; 63 | } // namespace difacto 64 | #endif // DIFACTO_COMMON_ARG_PARSER_H_ 65 | -------------------------------------------------------------------------------- /src/store/store_local.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_STORE_STORE_LOCAL_H_ 5 | #define DIFACTO_STORE_STORE_LOCAL_H_ 6 | #include 7 | #include 8 | #include 9 | #include "difacto/store.h" 10 | #include "difacto/updater.h" 11 | #include "dmlc/parameter.h" 12 | namespace difacto { 13 | 14 | /** 15 | * \brief model sync within a machine 16 | */ 17 | class StoreLocal : public Store { 18 | public: 19 | StoreLocal() { } 20 | virtual ~StoreLocal() { } 21 | 22 | KWArgs Init(const KWArgs& kwargs) { return kwargs; } 23 | 24 | int Push(const SArray& fea_ids, 25 | int val_type, 26 | const SArray& vals, 27 | const SArray& lens, 28 | const std::function& on_complete) override { 29 | SArray vals_copy; vals_copy.CopyFrom(vals); 30 | SArray lens_copy; lens_copy.CopyFrom(lens); 31 | updater_->Update(fea_ids, val_type, vals_copy, lens_copy); 32 | if (on_complete) on_complete(); 33 | return time_++; 34 | } 35 | 36 | int Pull(const SArray& fea_ids, 37 | int val_type, 38 | SArray* vals, 39 | SArray* lens, 40 | const std::function& on_complete) override { 41 | updater_->Get(fea_ids, val_type, vals, lens); 42 | if (on_complete) on_complete(); 43 | return time_++; 44 | } 45 | 46 | void Wait(int time) override { } 47 | int Rank() override { return 0; } 48 | int NumWorkers() override { return 1; } 49 | int NumServers() override { return 1; } 50 | 51 | private: 52 | int time_; 53 | }; 54 | } // namespace difacto 55 | #endif // 
DIFACTO_STORE_STORE_LOCAL_H_ 56 | -------------------------------------------------------------------------------- /include/difacto/reporter.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_REPORTER_H_ 5 | #define DIFACTO_REPORTER_H_ 6 | #include 7 | #include 8 | #include 9 | #include "./base.h" 10 | namespace difacto { 11 | /** 12 | * \brief report to the scheduler 13 | */ 14 | class Reporter { 15 | public: 16 | /** 17 | * \brief factory function 18 | */ 19 | static Reporter* Create(); 20 | /** \brief constructor */ 21 | Reporter() { } 22 | /** \brief deconstructor */ 23 | virtual ~Reporter() { } 24 | /** 25 | * \brief init 26 | * @param kwargs keyword arguments 27 | * @return the unknown kwargs 28 | */ 29 | virtual KWArgs Init(const KWArgs& kwargs) = 0; 30 | 31 | /////////////// functions for the scheduler node ///////////////// 32 | /** 33 | * \brief the function to process the report sent by a node 34 | * @param node_id the node id 35 | * @param report the received report 36 | */ 37 | typedef std::function Monitor; 38 | /** 39 | * \brief set the monitor function 40 | */ 41 | virtual void SetMonitor(const Monitor& monitor) = 0; 42 | 43 | /////////////// functions for a server/worker node ///////////////// 44 | 45 | /** 46 | * \brief report to the scheduler 47 | * \param report the report 48 | * \return the timestamp of the report 49 | */ 50 | virtual int Report(const std::string& report) = 0; 51 | /** 52 | * \brief wait until a particular Report has been finished. 
53 | * @param timestamp the timestamp of the report 54 | */ 55 | virtual void Wait(int timestamp) = 0; 56 | }; 57 | 58 | } // namespace difacto 59 | #endif // DIFACTO_REPORTER_H_ 60 | -------------------------------------------------------------------------------- /src/common/learner_utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_LEARNER_UTILS_H_ 5 | #define DIFACTO_COMMON_LEARNER_UTILS_H_ 6 | #include 7 | #include 8 | #include 9 | #include "difacto/tracker.h" 10 | #include "dmlc/io.h" 11 | #include "dmlc/memory_io.h" 12 | namespace difacto { 13 | /** 14 | * \brief send jobs to a node group and wait them finished. 15 | * 16 | * @param node_group 17 | * @param job_args 18 | * @param tracker 19 | * @param job_rets 20 | */ 21 | inline void SendJobAndWait(int node_group, 22 | const std::string& job_args, 23 | Tracker* tracker, 24 | std::vector* job_rets) { 25 | // set monitor 26 | Tracker::Monitor monitor = nullptr; 27 | if (job_rets != nullptr) { 28 | monitor = [job_rets](int node_id, const std::string& rets) { 29 | auto copy = rets; dmlc::Stream* ss = new dmlc::MemoryStringStream(©); 30 | std::vector vec; ss->Read(&vec); delete ss; 31 | if (job_rets->empty()) { 32 | *job_rets = vec; 33 | } else { 34 | CHECK_EQ(job_rets->size(), vec.size()); 35 | for (size_t i = 0; i < vec.size(); ++i) (*job_rets)[i] += vec[i]; 36 | } 37 | }; 38 | } 39 | tracker->SetMonitor(monitor); 40 | 41 | // sent job 42 | std::pair job; 43 | job.first = node_group; 44 | job.second = job_args; 45 | tracker->Issue({job}); 46 | 47 | // wait until finished 48 | while (tracker->NumRemains() != 0) { 49 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); 50 | } 51 | } 52 | } // namespace difacto 53 | #endif // DIFACTO_COMMON_LEARNER_UTILS_H_ 54 | -------------------------------------------------------------------------------- /src/reader/match_file.h: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_READER_MATCH_FILE_H_ 5 | #define DIFACTO_READER_MATCH_FILE_H_ 6 | #include 7 | #include 8 | #include 9 | #include "io/filesys.h" 10 | #include "dmlc/logging.h" 11 | namespace difacto { 12 | 13 | /** 14 | * 15 | * \brief match file by regex patterns 16 | * such as s3://my_path/part-.* 17 | * 18 | * @param pattern the regex pattern 19 | * @param matched matched filenames 20 | */ 21 | inline void MatchFile(const std::string& pattern, 22 | std::vector* matched) { 23 | // get the path 24 | size_t pos = pattern.find_last_of("/\\"); 25 | std::string path = "./"; 26 | if (pos != std::string::npos) path = pattern.substr(0, pos); 27 | 28 | // find all files 29 | dmlc::io::URI path_uri(path.c_str()); 30 | dmlc::io::FileSystem *fs = 31 | dmlc::io::FileSystem::GetInstance(path_uri.protocol); 32 | std::vector info; 33 | fs->ListDirectory(path_uri, &info); 34 | 35 | // store all matached files 36 | regex_t pat; 37 | std::string file = 38 | pos == std::string::npos ? 
pattern : pattern.substr(pos+1); 39 | file = ".*" + file; 40 | int status = regcomp(&pat, file.c_str(), REG_EXTENDED|REG_NEWLINE); 41 | if (status != 0) { 42 | char error_message[1000]; 43 | regerror(status, &pat, error_message, 1000); 44 | LOG(FATAL) << "error regex '" << pattern << "' : " << error_message; 45 | } 46 | 47 | regmatch_t m[1]; 48 | CHECK_NOTNULL(matched); 49 | for (size_t i = 0; i < info.size(); ++i) { 50 | std::string file = info[i].path.str(); 51 | if (regexec(&pat, file.c_str(), 1, m, 0)) continue; 52 | matched->push_back(file); 53 | } 54 | } 55 | 56 | } // namespace difacto 57 | #endif // DIFACTO_READER_MATCH_FILE_H_ 58 | -------------------------------------------------------------------------------- /src/reader/reader.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_READER_READER_H_ 5 | #define DIFACTO_READER_READER_H_ 6 | #include 7 | #include "difacto/base.h" 8 | #include "dmlc/data.h" 9 | #include "data/parser.h" 10 | #include "data/libsvm_parser.h" 11 | #include "./adfea_parser.h" 12 | #include "./crb_parser.h" 13 | #include "./criteo_parser.h" 14 | namespace difacto { 15 | /** 16 | * \brief a reader reads a chunk of data with roughly same size a time 17 | */ 18 | class Reader { 19 | public: 20 | Reader() { parser_ = nullptr; } 21 | Reader(const std::string& uri, 22 | const std::string& format, 23 | int part_index, 24 | int num_parts, 25 | int chunk_size_hint) { 26 | char const* c_uri = uri.c_str(); 27 | dmlc::InputSplit* input = dmlc::InputSplit::Create( 28 | c_uri, part_index, num_parts, format == "rec" ? 
"recordio" : "text"); 29 | input->HintChunkSize(chunk_size_hint); 30 | 31 | if (format == "libsvm") { 32 | parser_ = new dmlc::data::LibSVMParser(input, 1); 33 | } else if (format == "criteo") { 34 | parser_ = new CriteoParser(input, true); 35 | } else if (format == "criteo_test") { 36 | parser_ = new CriteoParser(input, false); 37 | } else if (format == "adfea") { 38 | parser_ = new AdfeaParser(input); 39 | } else if (format == "rec") { 40 | parser_ = new CRBParser(input); 41 | } else { 42 | LOG(FATAL) << "unknown format " << format; 43 | } 44 | parser_ = new dmlc::data::ThreadedParser(parser_); 45 | } 46 | 47 | virtual ~Reader() { delete parser_; } 48 | 49 | virtual bool Next() { return parser_->Next(); } 50 | 51 | virtual const dmlc::RowBlock& Value() const { return parser_->Value(); } 52 | 53 | private: 54 | dmlc::data::ParserImpl* parser_; 55 | }; 56 | 57 | } // namespace difacto 58 | #endif // DIFACTO_READER_READER_H_ 59 | -------------------------------------------------------------------------------- /tests/cpp/batch_reader_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "reader/batch_reader.h" 6 | #include "./utils.h" 7 | 8 | using namespace difacto; 9 | int batch_size = 37; 10 | int label[] = { 11, 15, 10}; 11 | int len[] = { 37, 37, 26}; 12 | size_t os[] = { 85035 , 63968 , 31323 }; 13 | uint32_t idx[] = { 95285478 , 70504854 , 62972349}; 14 | float val[] = { 37.0000 , 37.0000 , 26.0000}; 15 | 16 | TEST(BatchReader, Read) { 17 | BatchReader reader("../tests/data", "libsvm", 0, 1, batch_size); 18 | int i = 0; 19 | while (reader.Next()) { 20 | auto batch = reader.Value(); 21 | int size = batch.size; 22 | EXPECT_EQ(label[i], sum(batch.label, size)); 23 | EXPECT_EQ(len[i], size); 24 | EXPECT_EQ(os[i], norm1(batch.offset, size+1)); 25 | EXPECT_EQ(idx[i], norm1(batch.index, batch.offset[size])); 26 | EXPECT_LE(fabs(val[i] - 
norm2(batch.value, batch.offset[size])), 1e-5); 27 | ++i; 28 | } 29 | } 30 | 31 | TEST(BatchReader, RandRead) { 32 | BatchReader reader("../tests/data", "libsvm", 0, 1, batch_size, batch_size); 33 | int i = 0; 34 | while (reader.Next()) { 35 | auto batch = reader.Value(); 36 | int size = batch.size; 37 | EXPECT_EQ(label[i], sum(batch.label, size)); 38 | EXPECT_EQ(len[i], size); 39 | EXPECT_NE(os[i], norm1(batch.offset, size+1)); 40 | EXPECT_EQ(idx[i], norm1(batch.index, batch.offset[size])); 41 | EXPECT_LE(fabs(val[i] - norm2(batch.value, batch.offset[size])), 1e-5); 42 | ++i; 43 | } 44 | } 45 | 46 | TEST(BatchReader, PartRead) { 47 | BatchReader reader("../tests/data", "libsvm", 1, 2, batch_size); 48 | int ttl = 0; 49 | while (reader.Next()) { 50 | auto batch = reader.Value(); 51 | int size = batch.size; 52 | EXPECT_LE(fabs(size - norm2(batch.value, batch.offset[size])), 1e-5); 53 | ttl += size; 54 | } 55 | CHECK_LE(ttl, 60); 56 | CHECK_GE(ttl, 40); 57 | } 58 | -------------------------------------------------------------------------------- /tests/cpp/bcd_learner_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "bcd/bcd_learner.h" 6 | 7 | using namespace difacto; 8 | 9 | TEST(BCDLearer, DiagNewton) { 10 | BCDLearner learner; 11 | KWArgs args = {{"data_in", "../tests/data"}, 12 | {"l1", ".1"}, 13 | {"lr", ".05"}, 14 | {"block_ratio", "0.001"}, 15 | {"tail_feature_filter", "0"}, 16 | {"max_num_epochs", "10"}}; 17 | auto remain = learner.Init(args); 18 | EXPECT_EQ(remain.size(), 0); 19 | 20 | std::vector objv = { 21 | 34.877064, 22 | 33.885559, 23 | 29.572740, 24 | 27.458964, 25 | 25.317689, 26 | 23.917098, 27 | 22.855843, 28 | 22.099876, 29 | 21.552682, 30 | 21.137216 31 | }; 32 | 33 | auto callback = [objv](int epoch, const std::vector& prog) { 34 | EXPECT_LT(fabs(prog[1] - objv[epoch])/prog[1], 1e-5); 35 | }; 36 | 
learner.AddEpochEndCallback(callback); 37 | learner.Run(); 38 | } 39 | 40 | // the optimal solution with ../tests/data and l1 = .1 is objv = 15.884923, nnz 41 | // w = 47 42 | 43 | TEST(BCDLearer, Convergence) { 44 | std::vector ratio = {.4, 1, 10}; 45 | 46 | for (real_t r : ratio) { 47 | real_t objv; 48 | BCDLearner learner; 49 | KWArgs args = {{"data_in", "../tests/data"}, 50 | {"l1", ".1"}, 51 | {"lr", ".8"}, 52 | {"block_ratio", std::to_string(r)}, 53 | {"tail_feature_filter", "0"}, 54 | {"max_num_epochs", "50"}}; 55 | auto remain = learner.Init(args); 56 | EXPECT_EQ(remain.size(), 0); 57 | 58 | auto callback = [&objv](int epoch, const std::vector& prog) { 59 | objv = prog[1]; 60 | }; 61 | learner.AddEpochEndCallback(callback); 62 | learner.Run(); 63 | 64 | EXPECT_LT(fabs(objv - 15.884923)/objv, 1e-3); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/common/find_position.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_FIND_POSITION_H_ 5 | #define DIFACTO_COMMON_FIND_POSITION_H_ 6 | #include 7 | #include 8 | #include "dmlc/logging.h" 9 | #include "difacto/base.h" 10 | #include "difacto/sarray.h" 11 | #include "./range.h" 12 | namespace difacto { 13 | 14 | namespace { 15 | template 16 | size_t FindPosition(K const* src_begin, K const* src_end, 17 | K const* dst_begin, K const* dst_end, 18 | int* pos_begin, int* pos_end) { 19 | size_t n = 0; 20 | K const* src = std::lower_bound(src_begin, src_end, *dst_begin); 21 | K const* dst = std::lower_bound(dst_begin, dst_end, *src); 22 | 23 | int *pos = pos_begin + (dst - dst_begin); 24 | while (pos_begin != pos) { *pos_begin = -1; ++pos_begin; } 25 | while (src != src_end && dst != dst_end) { 26 | if (*src < *dst) { 27 | ++src; 28 | } else { 29 | if (!(*dst < *src)) { // equal 30 | *pos = static_cast(src - src_begin); 31 | ++src; ++n; 32 | } else { 33 
| *pos = -1; 34 | } 35 | ++dst; ++pos; 36 | } 37 | } 38 | while (pos != pos_end) {*pos = -1; ++pos; } 39 | return n; 40 | } 41 | 42 | } // namespace 43 | 44 | /** 45 | * \brief store the position of dst[i] in src into pos[i], namely src[pos[i]] == dst[i] 46 | * 47 | * @param src unique and sorted vector 48 | * @param dst unique and sorted vector 49 | * @param pos the positions, -1 means no matched 50 | * @return the number of matched 51 | */ 52 | template 53 | size_t FindPosition(const SArray& src, const SArray& dst, SArray* pos) { 54 | CHECK_NOTNULL(pos)->resize(dst.size()); 55 | return FindPosition(src.begin(), src.end(), 56 | dst.begin(), dst.end(), 57 | pos->begin(), pos->end()); 58 | } 59 | } // namespace difacto 60 | #endif // DIFACTO_COMMON_FIND_POSITION_H_ 61 | -------------------------------------------------------------------------------- /tests/matlab/lbfgs.m: -------------------------------------------------------------------------------- 1 | %% 2 | load data 3 | [i,j,k]=find(x); 4 | cnt = full(sum(sparse(i,j,ones(size(i))))); 5 | X = x(:,cnt>0); 6 | Y = y; 7 | X2 = x(:,cnt>0); 8 | 9 | %% 10 | V_dim = 5; 11 | if V_dim == 0 12 | X2 = []; 13 | end 14 | 15 | V = repmat(((1:V_dim) - V_dim/2)*.01, size(X2,2), 1); 16 | % V = randn(size(X2,2),V_dim) * .01; 17 | 18 | max_m = 5; 19 | 20 | lw = .1; 21 | lV = .01; 22 | 23 | c1 = 1e-4; 24 | c2 = .9; 25 | rho = .5; 26 | 27 | [n, p] = size(X); 28 | w = zeros(p,1); 29 | 30 | s = []; 31 | y = []; 32 | 33 | [objv, gw, gV] = fm_loss_l2(Y, X, w, X2, V, lw, lV); 34 | g = [gw(:); gV(:)]; 35 | 36 | g'*g 37 | %% 38 | for k = 1 : 23 39 | % two loop 40 | m = size(y, 2); 41 | p = - g; 42 | alpha = zeros(m,1); 43 | for i = m : -1 : 1 44 | alpha(i) = (s(:,i)' * p ) / (s(:,i)' * y(:,i) + 1e-10); 45 | p = p - alpha(i) * y(:,i); 46 | end 47 | if m > 0 48 | p = (s(:,m)'*y(:,m)) / (y(:,m)'*y(:,m) + 1e-10) * p; 49 | end 50 | for i = 1 : m 51 | beta = (y(:,i)'*p) / (s(:,i)'*y(:,i)); 52 | p = p + (alpha(i) - beta) * s(:,i); 53 | end 54 | 
p = min(max(p, -5), 5); 55 | 56 | % back tracking 57 | alpha = 1; 58 | gp = g'*p; 59 | fprintf('epoch %d, objv %f, gp %f\n', k, objv, gp); 60 | dw = p(1:length(w)); 61 | dV = reshape(p((length(w)+1):end), [], V_dim); 62 | for j = 1 : 10 63 | [new_o, gw, gV] = fm_loss_l2(Y, X, w+alpha*dw, X2, V+alpha*dV, lw, lV); 64 | new_g = [gw(:); gV(:)]; 65 | 66 | new_gp = new_g' * p; 67 | fprintf('alpha %f, new_objv %f, new_gp %f\n', alpha, new_o, new_gp); 68 | if (new_o <= objv + c1 * alpha * gp) && (new_gp >= c2 * gp) 69 | break; 70 | end 71 | alpha = alpha * rho; 72 | end 73 | 74 | if m == max_m 75 | s = s(:,2:m); 76 | y = y(:,2:m); 77 | end 78 | 79 | w = w + alpha * dw; 80 | V = V + alpha * dV; 81 | old_g = g; 82 | 83 | [objv, gw, gV] = fm_loss_l2(Y, X, w, X2, V, lw, lV); 84 | g = [gw(:); gV(:)]; 85 | 86 | s = [s, alpha*p]; 87 | y = [y, g - old_g]; 88 | end 89 | -------------------------------------------------------------------------------- /include/difacto/learner.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_LEARNER_H_ 5 | #define DIFACTO_LEARNER_H_ 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "dmlc/io.h" 11 | #include "./base.h" 12 | #include "./tracker.h" 13 | namespace difacto { 14 | 15 | /** 16 | * \brief the base class of a learner 17 | * 18 | * a learner runs the learning algorithm to train a model 19 | */ 20 | class Learner { 21 | public: 22 | /** 23 | * \brief the factory function 24 | * \param type the learner type such as "sgd" 25 | */ 26 | static Learner* Create(const std::string& type); 27 | /** \brief construct */ 28 | Learner() { } 29 | /** \brief deconstruct */ 30 | virtual ~Learner() { } 31 | /** 32 | * \brief init learner 33 | * 34 | * @param kwargs keyword arguments 35 | * @return the unknown kwargs 36 | */ 37 | virtual KWArgs Init(const KWArgs& kwargs); 38 | /** 39 | * \brief train 40 | */ 41 | void Run() { 42 | if 
(!IsDistributed() || !strcmp(getenv("DMLC_ROLE"), "scheduler")) { 43 | RunScheduler(); 44 | } else { 45 | tracker_->Wait(); 46 | } 47 | } 48 | /** 49 | * \brief Stop learner. It is often used to stop the training earlier 50 | */ 51 | void Stop() { 52 | tracker_->Stop(); 53 | } 54 | 55 | protected: 56 | /** 57 | * \brief the function runs on the scheduler, which issues jobs to workers and 58 | * servers 59 | */ 60 | virtual void RunScheduler() = 0; 61 | 62 | /** 63 | * \brief the function runs on the worker/server to process jobs issued by the 64 | * scheduler 65 | * 66 | * \param args the job arguments received from the scheduler 67 | * \param rets the results send back to the scheduler 68 | */ 69 | virtual void Process(const std::string& args, std::string* rets) = 0; 70 | 71 | /** \brief the job tracker */ 72 | Tracker* tracker_; 73 | }; 74 | 75 | } // namespace difacto 76 | #endif // DIFACTO_LEARNER_H_ 77 | -------------------------------------------------------------------------------- /src/common/range.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_RANGE_H_ 5 | #define DIFACTO_COMMON_RANGE_H_ 6 | #include "dmlc/logging.h" 7 | namespace difacto { 8 | /** 9 | * \brief a range between [begin, end) 10 | */ 11 | struct Range { 12 | Range(uint64_t _begin, uint64_t _end) : begin(_begin), end(_end) { } 13 | Range() : Range(0, 0) { } 14 | ~Range() { } 15 | /** 16 | * \brief evenly divide this range into npart segments, and return the idx-th 17 | * one 18 | */ 19 | inline Range Segment(uint64_t idx, uint64_t nparts) const { 20 | CHECK_GE(end, begin); 21 | CHECK_GT(nparts, (uint64_t)0); 22 | CHECK_LT(idx, nparts); 23 | double itv = static_cast(end - begin) / 24 | static_cast(nparts); 25 | uint64_t _begin = static_cast(begin + itv * idx); 26 | uint64_t _end = (idx == nparts - 1) ? 
27 | end : static_cast(begin + itv * (idx+1)); 28 | return Range(_begin, _end); 29 | } 30 | 31 | /** 32 | * \brief Return true if i is contained in this range, namely begin <= i < end 33 | */ 34 | inline bool Has(uint64_t i) const { 35 | return (begin <= i && i < end); 36 | } 37 | 38 | /** 39 | * \brief return a range covering all indices; -1 presumably wraps to the largest uint64_t (TODO confirm against the constructor) 40 | */ 41 | static Range All() { return Range(0, -1); } 42 | 43 | inline bool Valid() const { return end > begin; } 44 | /** \brief the number of indices; NOTE(review): the unsigned subtraction wraps when end < begin, so check Valid() first */ 45 | inline uint64_t Size() const { return end - begin; } 46 | 47 | bool operator== (const Range& rhs) const { 48 | return (begin == rhs.begin && end == rhs.end); 49 | } 50 | bool operator!= (const Range& rhs) const { 51 | return !(*this == rhs); 52 | } 53 | 54 | Range operator+ (const uint64_t v) const { return Range(begin+v, end+v); } 55 | Range operator- (const uint64_t v) const { return Range(begin-v, end-v); } 56 | Range operator* (const uint64_t v) const { return Range(begin*v, end*v); } 57 | 58 | uint64_t begin; 59 | uint64_t end; 60 | }; 61 | 62 | } // namespace difacto 63 | #endif // DIFACTO_COMMON_RANGE_H_ 64 | -------------------------------------------------------------------------------- /src/loss/fm_loss_delta.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_LOSS_FM_LOSS_DELTA_H_ 5 | #define DIFACTO_LOSS_FM_LOSS_DELTA_H_ 6 | #include 7 | #include "difacto/sarray.h" 8 | #include "./fm_loss.h" 9 | namespace difacto { 10 | 11 | /** 12 | * \brief the FM loss; different from \ref FMLoss, \ref FMLossDelta is fed with 13 | * the delta weight and the transpose of X each time 14 | */ 15 | class FMLossDelta : public FMLoss { 16 | public: 17 | /** \brief constructor */ 18 | FMLossDelta() { } 19 | /** \brief destructor */ 20 | virtual ~FMLossDelta() { } 21 | /** NOTE(review): kwargs are returned untouched and FMLoss::Init is not called, so FMLoss parameters such as V_dim are not parsed here; confirm this is intended */ 22 | KWArgs Init(const KWArgs& kwargs) override { 23 | return kwargs; 24 | } 25 | 26 | /** 27 | * @param data X', the transpose of X 28 | * @param param parameters 29 | * - 
param[0], real_t, the previous prediction 30 | * - param[1], int, param[1][i] is the length of the gradient on the i-th feature 31 | * and sum(param[1]) = length(grad) 32 | * - param[2], real_t, the prediction (results of \ref Predict) 33 | * @param grad output gradient 34 | */ 35 | void CalcGrad(const dmlc::RowBlock& data, 36 | const std::vector>& param, 37 | SArray* grad) override { 38 | // TODO(mli): not implemented yet, grad is left unchanged 39 | } 40 | 41 | /** 42 | * @param data X', the transpose of X 43 | * @param param parameters 44 | * - param[0], real_t, the previous prediction 45 | * - param[1], real_t, delta weight, namely new_w - old_w 46 | * - param[2], int, param[2][i] is the length of delta_w[i], 47 | * and sum(param[2]) = length(delta_w) 48 | * @param pred output prediction, it may overwrite param[0] 49 | */ 50 | void Predict(const dmlc::RowBlock& data, 51 | const std::vector>& param, 52 | SArray* pred) override { 53 | *CHECK_NOTNULL(pred) = param[0]; 54 | FMLoss::Predict(data, {param[1], param[2]}, pred); 55 | } 56 | }; 57 | } // namespace difacto 58 | 59 | #endif // DIFACTO_LOSS_FM_LOSS_DELTA_H_ 60 | -------------------------------------------------------------------------------- /tests/cpp/data_store_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "dmlc/memory_io.h" 6 | #include "data/data_store.h" 7 | #include "./utils.h" 8 | 9 | using namespace difacto; 10 | 11 | TEST(DataStore, MemBase) { 12 | DataStore store; 13 | int n = 1000; 14 | SArray val1; 15 | SArray val2; 16 | SArray val3; 17 | 18 | gen_vals(n, -100, 100, &val1); 19 | gen_vals(n, -100, 100, &val2); 20 | gen_vals(n, -100, 100, &val3); 21 | 22 | store.Store("1", val1.data(), val1.size()); 23 | store.Store("2", val2.data(), val2.size()); 24 | 25 | SArray ret1; 26 | SArray ret2; 27 | store.Fetch("1", &ret1); 28 | store.Fetch("2", &ret2, Range(10, 30)); 29 | 30 | // overwrite key 31 | SArray ret3; 32 | store.Store("1", 
val3.data(), val3.size()); 33 | store.Fetch("1", &ret3); 34 | 35 | // noncopy: store the SArray itself instead of data()/size(), and drop the local handle before fetching 36 | { 37 | SArray val4(val2); 38 | store.Store("4", val4); 39 | } 40 | SArray ret4; 41 | store.Fetch("4", &ret4); 42 | 43 | EXPECT_EQ(norm2(val1), norm2(ret1)); 44 | EXPECT_EQ(norm2(SArray(val2).segment(10, 30)), norm2(ret2)); 45 | EXPECT_EQ(norm2(val3), norm2(ret3)); 46 | EXPECT_EQ(norm2(val2), norm2(ret4)); 47 | } 48 | 49 | TEST(DataStore, Meta) { 50 | DataStore store; 51 | int n = 1000; 52 | SArray val1; 53 | SArray val2; 54 | SArray val3; 55 | gen_vals(n, -100, 100, &val1); 56 | gen_vals(n, -100, 100, &val2); 57 | gen_vals(n, -100, 100, &val3); 58 | 59 | 60 | store.Store("1", val1); 61 | store.Store("2", val2); 62 | store.Store("3", val3); 63 | 64 | std::string meta; 65 | dmlc::Stream* os = new dmlc::MemoryStringStream(&meta); 66 | store.Save(os); 67 | delete os; 68 | 69 | LL << meta; 70 | 71 | DataStore store2; 72 | dmlc::Stream* is = new dmlc::MemoryStringStream(&meta); 73 | store2.Load(is); 74 | delete is; 75 | 76 | EXPECT_EQ(store2.size("1"), n); 77 | EXPECT_EQ(store2.size("2"), n); 78 | EXPECT_EQ(store2.size("3"), n); 79 | } 80 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # default configurations; each one can be overridden by passing a new value to make, 2 | # e.g.
`make CXX=g++-4.9` 3 | CXX = g++ 4 | DEPS_PATH = $(shell pwd)/deps 5 | USE_CITY=0 6 | USE_LZ4=1 7 | NO_REVERSE_ID=0 8 | 9 | all: build/difacto 10 | 11 | INCPATH = -I./src -I./include -I./dmlc-core/include -I./ps-lite/include -I./dmlc-core/src -I$(DEPS_PATH)/include 12 | PROTOC = ${DEPS_PATH}/bin/protoc 13 | CFLAGS = -std=c++11 -fopenmp -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) -DDMLC_LOG_FATAL_THROW=0 $(ADD_CFLAGS) 14 | 15 | ifeq ($(NO_REVERSE_ID), 1) 16 | CFLAGS += -DREVERSE_FEATURE_ID=0 17 | endif 18 | 19 | include ps-lite/make/deps.mk 20 | 21 | ifeq ($(USE_CITY), 1) 22 | DEPS += ${CITYHASH} 23 | CFLAGS += -DDIFACTO_USE_CITY=1 24 | LDFLAGS += ${DEPS_PATH}/lib/libcityhash.a 25 | endif 26 | 27 | ifeq ($(USE_LZ4), 1) 28 | DEPS += ${LZ4} 29 | CFLAGS += -DDIFACTO_USE_LZ4=1 30 | LDFLAGS += ${DEPS_PATH}/lib/liblz4.a 31 | endif 32 | 33 | 34 | 35 | # LDFLAGS += $(addprefix $(DEPS_PATH)/lib/, libprotobuf.a libzmq.a) 36 | 37 | OBJS = $(addprefix build/, loss/loss.o \ 38 | updater.o \ 39 | sgd/sgd_updater.o sgd/sgd_learner.o \ 40 | learner.o \ 41 | bcd/bcd_learner.o \ 42 | lbfgs/lbfgs_learner.o \ 43 | store/store.o \ 44 | tracker/tracker.o \ 45 | reporter/reporter.o \ 46 | data/localizer.o reader/batch_reader.o ) 47 | 48 | DMLC_DEPS = dmlc-core/libdmlc.a 49 | .PHONY: all clean lint test 50 | clean: 51 | rm -rf build/* 52 | $(MAKE) -C dmlc-core clean 53 | $(MAKE) -C ps-lite clean 54 | 55 | lint: 56 | python2 dmlc-core/scripts/lint.py difacto all include src tests/cpp 57 | 58 | 59 | build/%.o: src/%.cc ${DEPS} 60 | @mkdir -p $(@D) 61 | $(CXX) $(INCPATH) -std=c++11 -MM -MT build/$*.o $< >build/$*.d 62 | $(CXX) $(CFLAGS) -c $< -o $@ 63 | 64 | build/libdifacto.a: $(OBJS) 65 | ar crv $@ $(filter %.o, $?)
66 | 67 | build/difacto: build/main.o build/libdifacto.a $(DMLC_DEPS) 68 | $(CXX) $(CFLAGS) -o $@ $^ $(LDFLAGS) 69 | 70 | dmlc-core/libdmlc.a: 71 | $(MAKE) -C dmlc-core libdmlc.a DEPS_PATH=$(DEPS_PATH) CXX=$(CXX) 72 | 73 | include tests/cpp/test.mk 74 | 75 | 76 | test: build/difacto_tests 77 | 78 | -include build/*.d 79 | -include build/*/*.d 80 | -------------------------------------------------------------------------------- /tests/cpp/localizer_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "./utils.h" 6 | #include "reader/batch_reader.h" 7 | #include "difacto/base.h" 8 | #include "data/localizer.h" 9 | 10 | using namespace difacto; 11 | 12 | TEST(Localizer, Base) { 13 | BatchReader reader("../tests/data", "libsvm", 0, 1, 100); 14 | CHECK(reader.Next()); 15 | dmlc::data::RowBlockContainer compact; 16 | std::vector uidx; 17 | std::vector freq; 18 | 19 | Localizer lc; 20 | lc.Compact(reader.Value(), &compact, &uidx, &freq); 21 | auto batch = compact.GetBlock(); 22 | int size = batch.size; 23 | 24 | for (auto& i : uidx) i = ReverseBytes(i); 25 | 26 | EXPECT_EQ(norm1(uidx.data(), uidx.size()), (uint32_t)65111856); 27 | EXPECT_EQ(norm1(freq.data(), freq.size()), (real_t)9648); 28 | EXPECT_EQ(norm1(reader.Value().offset, size+1), 29 | norm1(batch.offset, size+1)); 30 | EXPECT_EQ(norm2(reader.Value().value, batch.offset[size]), 31 | norm2(batch.value, batch.offset[size])); 32 | } 33 | 34 | TEST(Localizer, BaseHash) { 35 | BatchReader reader("../tests/data", "libsvm", 0, 1, 100); 36 | CHECK(reader.Next()); 37 | dmlc::data::RowBlockContainer compact; 38 | std::vector uidx; 39 | std::vector freq; 40 | 41 | Localizer lc(1000); 42 | lc.Compact(reader.Value(), &compact, &uidx, &freq); 43 | auto batch = compact.GetBlock(); 44 | int size = batch.size; 45 | 46 | for (auto& i : uidx) i = ReverseBytes(i); 47 | 48 | EXPECT_EQ(norm1(uidx.data(), uidx.size()), 
(uint32_t)478817); 49 | EXPECT_EQ(norm1(freq.data(), freq.size()), 9648); 50 | EXPECT_EQ(norm1(reader.Value().offset, size+1), 51 | norm1(batch.offset, size+1)); 52 | EXPECT_EQ(norm2(reader.Value().value, batch.offset[size]), 53 | norm2(batch.value, batch.offset[size])); 54 | } 55 | 56 | TEST(Localizer, ReverseBytes) { 57 | feaid_t max = -1; 58 | int n = 1000000; 59 | for (int i = 0; i < n; ++i) { 60 | feaid_t j = (max / n) * i; 61 | EXPECT_EQ(j, ReverseBytes(ReverseBytes(j))); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /tests/cpp/spmv_perf.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "./utils.h" 5 | #include "common/arg_parser.h" 6 | #include "common/spmv.h" 7 | #include "dmlc/config.h" 8 | #include "dmlc/timer.h" 9 | #include "reader/reader.h" 10 | 11 | using namespace difacto; 12 | using namespace dmlc; 13 | /** \brief command-line parameters of this benchmark */ 14 | struct Param : public Parameter { 15 | std::string data; 16 | std::string format; 17 | int nthreads; 18 | DMLC_DECLARE_PARAMETER(Param) { 19 | DMLC_DECLARE_FIELD(format).set_default("libsvm").describe("data format"); 20 | DMLC_DECLARE_FIELD(data).describe("input data filename"); 21 | DMLC_DECLARE_FIELD(nthreads).set_default(2).describe("number of threads"); 22 | } 23 | }; 24 | 25 | DMLC_REGISTER_PARAMETER(Param); 26 | /** \brief time SpMV::Times and SpMV::TransTimes on one data chunk, averaged over 20 runs after one warm-up run */ 27 | int main(int argc, char *argv[]) { 28 | Param param; 29 | if (argc < 2) { 30 | LOG(ERROR) << "not enough input.. 
\n\nusage: ./difacto key1=val1 key2=val2 ...\n\n" 31 | << param.__DOC__(); 32 | return 1; // missing arguments: exit with a failure status 33 | } 34 | ArgParser parser; 35 | for (int i = 1; i < argc; ++i) parser.AddArg(argv[i]); 36 | param.Init(parser.GetKWArgs()); 37 | 38 | Reader reader(param.data, param.format, 0, 1, 512<<20); 39 | CHECK(reader.Next()); 40 | dmlc::data::RowBlockContainer data; 41 | std::vector uidx; 42 | Localizer lc; lc.Compact(reader.Value(), &data, &uidx); 43 | auto D = data.GetBlock(); 44 | size_t n = D.size; 45 | size_t p = uidx.size(); 46 | LOG(INFO) << "load " << n << " x " << p << " matrix"; 47 | 48 | double start = 0; // reset at i == 1, right after the warm-up run 49 | int repeat = 20; 50 | SArray x(p), y(n); 51 | 52 | for (int i = 0; i < repeat+1; ++i) { 53 | if (i == 1) start = GetTime(); // warmup when i == 0 54 | SpMV::Times(D, x, &y, param.nthreads); 55 | } 56 | double t1 = (GetTime() - start) / repeat; 57 | 58 | for (int i = 0; i < repeat+1; ++i) { 59 | if (i == 1) start = GetTime(); // warmup when i == 0 60 | SpMV::TransTimes(D, y, &x, param.nthreads); 61 | } 62 | double t2 = (GetTime() - start) / repeat; 63 | 64 | 65 | LOG(INFO) << "Times: " << t1 << ",\t TransTimes: " << t2; 66 | 67 | return 0; 68 | } 69 | -------------------------------------------------------------------------------- /include/difacto/updater.h: -------------------------------------------------------------------------------- 1 | /*!
2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_UPDATER_H_ 5 | #define DIFACTO_UPDATER_H_ 6 | #include 7 | #include 8 | #include "./base.h" 9 | #include "./sarray.h" 10 | #include "dmlc/io.h" 11 | namespace difacto { 12 | /** 13 | * \brief the base class of an updater 14 | * 15 | * the main job of an updater is to update the model based 16 | * on the received data (often gradients) 17 | */ 18 | class Updater { 19 | public: 20 | /** 21 | * \brief default constructor 22 | */ 23 | Updater() { } 24 | /** 25 | * \brief virtual destructor, so a derived updater can be deleted through an Updater pointer 26 | */ 27 | virtual ~Updater() { } 28 | /** 29 | * \brief init the updater 30 | * @param kwargs keyword arguments 31 | * @return the unknown kwargs 32 | */ 33 | virtual KWArgs Init(const KWArgs& kwargs) = 0; 34 | 35 | /** 36 | * \brief load the updater 37 | * \param fi input stream 38 | * \param has_aux whether the loaded updater has aux data 39 | */ 40 | virtual void Load(dmlc::Stream* fi, bool* has_aux) = 0; 41 | 42 | /** 43 | * \brief save the updater 44 | * \param save_aux whether or not to save aux data 45 | * \param fo output stream 46 | */ 47 | virtual void Save(bool save_aux, dmlc::Stream *fo) const = 0; 48 | /** 49 | * \brief get the weights on the given features 50 | * 51 | * @param fea_ids the list of feature ids; data_type presumably selects which data is returned (TODO confirm) 52 | * @param data output data for the given features 53 | * @param data_offset output offsets into data, could be empty 54 | */ 55 | virtual void Get(const SArray& fea_ids, 56 | int data_type, 57 | SArray* data, 58 | SArray* data_offset) = 0; 59 | /** 60 | * \brief update the model given a list of key-value pairs 61 | * 62 | * @param fea_ids the list of feature ids; data_type presumably describes what data holds (TODO confirm) 63 | * @param data the received data (often gradients) 64 | * @param data_offset offsets into data, presumably may be empty as in Get (TODO confirm) 65 | */ 66 | virtual void 
Update(const SArray& fea_ids, 67 | int data_type, 68 | const SArray& data, 69 | const SArray& data_offset) = 0; 70 | }; 71 | 72 | } // namespace difacto 73 | #endif // DIFACTO_UPDATER_H_ 74 | -------------------------------------------------------------------------------- /src/tracker/local_tracker.h:
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_TRACKER_LOCAL_TRACKER_H_ 5 | #define DIFACTO_TRACKER_LOCAL_TRACKER_H_ 6 | #include 7 | #include 8 | #include 9 | #include "difacto/tracker.h" 10 | #include "./async_local_tracker.h" 11 | namespace difacto { 12 | /** 13 | * \brief an implementation of the tracker which only runs within a local 14 | * process, forwarding every call to an AsyncLocalTracker it owns 15 | */ 16 | class LocalTracker : public Tracker { 17 | public: 18 | typedef std::pair Job; 19 | 20 | LocalTracker() { 21 | tracker_ = new AsyncLocalTracker(); 22 | } 23 | virtual ~LocalTracker() { delete tracker_; } 24 | LocalTracker(const LocalTracker&) = delete; // tracker_ is owned, so a copy would delete it twice 25 | LocalTracker& operator=(const LocalTracker&) = delete; 26 | KWArgs Init(const KWArgs& kwargs) override { return kwargs; } 27 | 28 | void Issue(const std::vector& jobs) override { 29 | if (!tracker_) tracker_ = new AsyncLocalTracker(); // NOTE(review): a tracker re-created after Stop() presumably has no executor/monitor set 30 | tracker_->Issue(jobs); 31 | } 32 | 33 | int NumRemains() override { 34 | return CHECK_NOTNULL(tracker_)->NumRemains(); 35 | } 36 | 37 | void Clear() override { 38 | CHECK_NOTNULL(tracker_)->Clear(); 39 | } 40 | /** \brief destroy the owned tracker; afterwards only Issue() re-creates it, while NumRemains/Clear/Wait/SetMonitor/SetExecutor CHECK-fail */ 41 | void Stop() override { 42 | if (tracker_) { 43 | delete tracker_; 44 | tracker_ = nullptr; 45 | } 46 | } 47 | 48 | void Wait() override { 49 | CHECK_NOTNULL(tracker_)->Wait(); 50 | } 51 | 52 | void SetMonitor(const Monitor& monitor) override { 53 | CHECK_NOTNULL(tracker_)->SetMonitor( 54 | [monitor](const Job& rets) { 55 | if (monitor) monitor(rets.first, rets.second); 56 | }); 57 | } 58 | 59 | void SetExecutor(const Executor& executor) override { 60 | CHECK_NOTNULL(executor); 61 | 
CHECK_NOTNULL(tracker_)->SetExecutor( 62 | [executor](const Job& args, 63 | const std::function& on_complete, 64 | Job* rets) { 65 | rets->first = args.first; 66 | executor(args.second, &(rets->second)); 67 | on_complete(); 68 | }); 69 | } 70 | 71 | private: 72 | AsyncLocalTracker* tracker_ = nullptr; 73 | }; 74 | 75 | } // namespace difacto 76 | #endif // DIFACTO_TRACKER_LOCAL_TRACKER_H_ 77 |
-------------------------------------------------------------------------------- /src/main.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "difacto/learner.h" 5 | #include "common/arg_parser.h" 6 | #include "dmlc/parameter.h" 7 | #include "reader/converter.h" 8 | namespace difacto { 9 | struct DifactoParam : public dmlc::Parameter { 10 | /** 11 | * \brief the type of task, 12 | * - train: train a model, which is the default 13 | * - predict: predict by using a trained model 14 | * - convert: convert data from one format into another 15 | */ 16 | std::string task; 17 | /** \brief the learner's type, required for a training task */ 18 | std::string learner; 19 | DMLC_DECLARE_PARAMETER(DifactoParam) { 20 | DMLC_DECLARE_FIELD(learner).set_default("sgd"); 21 | DMLC_DECLARE_FIELD(task).set_default("train"); 22 | } 23 | }; 24 | 25 | void WarnUnknownKWArgs(const DifactoParam& param, const KWArgs& remain) { 26 | if (remain.empty()) return; 27 | LOG(WARNING) << "Unrecognized keyword argument for task = " << param.task; 28 | for (auto kw : remain) { 29 | LOG(WARNING) << " - " << kw.first << " = " << kw.second; 30 | } 31 | } 32 | 33 | DMLC_REGISTER_PARAMETER(DifactoParam); 34 | DMLC_REGISTER_PARAMETER(ConverterParam); 35 | 36 | } // namespace difacto 37 | 38 | int main(int argc, char *argv[]) { 39 | if (argc < 2) { 40 | LOG(ERROR) << "usage: difacto key1=val1 key2=val2 ..."; 41 | return 0; 42 | } 43 | using namespace difacto; 44 | 45 | // parse configuure 46 | ArgParser parser; 47 | for (int i = 1; i < argc; ++i) parser.AddArg(argv[i]); 48 | DifactoParam param; 49 | auto kwargs_remain = param.InitAllowUnknown(parser.GetKWArgs()); 50 | 51 | // run 52 | if (param.task == "train") { 53 | Learner* learner = Learner::Create(param.learner); 54 | WarnUnknownKWArgs(param, learner->Init(kwargs_remain)); 55 | learner->Run(); 56 | delete learner; 57 | } else if (param.task == 
"convert") { 58 | Converter converter; 59 | WarnUnknownKWArgs(param, converter.Init(kwargs_remain)); 60 | converter.Run(); 61 | } else if (param.task == "predict") { 62 | LOG(FATAL) << "TODO"; 63 | } else { 64 | LOG(FATAL) << "unknown task: " << param.task; 65 | } 66 | return 0; 67 | } 68 | -------------------------------------------------------------------------------- /src/reader/batch_reader.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_READER_BATCH_READER_H_ 5 | #define DIFACTO_READER_BATCH_READER_H_ 6 | #include 7 | #include 8 | #include "difacto/base.h" 9 | #include "dmlc/data.h" 10 | #include "./reader.h" 11 | namespace difacto { 12 | 13 | /** 14 | * \brief a reader reads a batch with a given number of examples 15 | * each time. 16 | */ 17 | class BatchReader : public Reader { 18 | public: 19 | /** 20 | * \brief create a batch iterator 21 | * 22 | * @param uri filename 23 | * @param format the data format, support libsvm, rec, ... 24 | * @param part_index the i-th part to read 25 | * @param num_parts partition the file into serveral parts 26 | * @param batch_size the batch size. 
27 | * @param shuffle_size if nonzero, then the batch is randomly picked from a buffer with 28 | * shuffle_buf_size examples 29 | * @param neg_sampling the probability to pickup a negative sample (label <= 0) 30 | */ 31 | BatchReader(const std::string& uri, 32 | const std::string& format, 33 | unsigned part_index, 34 | unsigned num_parts, 35 | unsigned batch_size, 36 | unsigned shuffle_buf_size = 0, 37 | float neg_sampling = 1.0); 38 | 39 | virtual ~BatchReader() { 40 | delete reader_; 41 | delete buf_reader_; 42 | } 43 | 44 | /** 45 | * \brief read the next batch 46 | */ 47 | bool Next() override; 48 | 49 | /** 50 | * \brief get the current batch 51 | * 52 | */ 53 | const dmlc::RowBlock& Value() const override { 54 | return out_blk_; 55 | } 56 | 57 | private: 58 | /** 59 | * \brief batch_.push(in_blk_(pos:pos+len)) 60 | */ 61 | void Push(size_t pos, size_t len); 62 | 63 | unsigned batch_size_, shuf_buf_; 64 | 65 | Reader *reader_; 66 | BatchReader* buf_reader_; 67 | 68 | float neg_sampling_; 69 | size_t start_, end_; 70 | dmlc::RowBlock in_blk_, out_blk_; 71 | dmlc::data::RowBlockContainer batch_; 72 | 73 | // random pertubation 74 | std::vector rdp_; 75 | unsigned int seed_; 76 | }; 77 | 78 | } // namespace difacto 79 | #endif // DIFACTO_READER_BATCH_READER_H_ 80 | -------------------------------------------------------------------------------- /src/bcd/bcd_param.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_BCD_BCD_PARAM_H_ 5 | #define DIFACTO_BCD_BCD_PARAM_H_ 6 | #include 7 | #include "dmlc/parameter.h" 8 | namespace difacto { 9 | 10 | struct BCDLearnerParam : public dmlc::Parameter { 11 | /** \brief The input data, either a filename or a directory. */ 12 | std::string data_in; 13 | /** \brief The optional validation dataset, either a filename or a directory */ 14 | std::string data_val; 15 | /** \brief the data format. 
default is libsvm */ 16 | std::string data_format; 17 | /** \brief the directory for the data chache */ 18 | std::string data_cache; 19 | /** \brief the model output for a training task */ 20 | std::string model_out; 21 | /** \brief the model input for warm start */ 22 | std::string model_in; 23 | /** \brief type of loss, defaut is fm*/ 24 | std::string loss; 25 | /** \brief the maximal number of data passes, defaut is 20 */ 26 | int max_num_epochs; 27 | /** \brief controls the number of feature blocks, default is 4 */ 28 | float block_ratio; 29 | /** \brief if or not process feature blocks in a random order, default is true */ 30 | int random_block; 31 | /** \brief the number of heading bits used to encode the feature group, default is 12 */ 32 | int num_feature_group_bits; 33 | float neg_sampling; 34 | /** \brief the size of data in MB read each time for processing, in default 256 MB */ 35 | int data_chunk_size; 36 | 37 | DMLC_DECLARE_PARAMETER(BCDLearnerParam) { 38 | DMLC_DECLARE_FIELD(data_format).set_default("libsvm"); 39 | DMLC_DECLARE_FIELD(data_in); 40 | DMLC_DECLARE_FIELD(data_val).set_default(""); 41 | DMLC_DECLARE_FIELD(data_cache).set_default("/tmp/difacto_bcd_"); 42 | DMLC_DECLARE_FIELD(data_chunk_size).set_default(1<<28); 43 | DMLC_DECLARE_FIELD(model_out).set_default(""); 44 | DMLC_DECLARE_FIELD(model_in).set_default(""); 45 | DMLC_DECLARE_FIELD(loss).set_default("fm"); 46 | DMLC_DECLARE_FIELD(max_num_epochs).set_default(20); 47 | DMLC_DECLARE_FIELD(random_block).set_default(1); 48 | DMLC_DECLARE_FIELD(num_feature_group_bits).set_default(0); 49 | DMLC_DECLARE_FIELD(block_ratio).set_default(4); 50 | } 51 | }; 52 | } // namespace difacto 53 | #endif // DIFACTO_BCD_BCD_PARAM_H_ 54 | -------------------------------------------------------------------------------- /src/sgd/sgd_utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef 
DIFACTO_SGD_SGD_UTILS_H_ 5 | #define DIFACTO_SGD_SGD_UTILS_H_ 6 | #include 7 | #include 8 | #include 9 | #include "dmlc/memory_io.h" 10 | namespace difacto { 11 | namespace sgd { 12 | 13 | /** 14 | * \brief a sgd job 15 | */ 16 | struct Job { 17 | static const int kLoadModel = 1; 18 | static const int kSaveModel = 2; 19 | static const int kTraining = 3; 20 | static const int kValidation = 4; 21 | static const int kEvaluation = 5; 22 | int type; 23 | /** \brief number of partitions of this file */ 24 | int num_parts; 25 | /** \brief the part will be processed, -1 means all */ 26 | int part_idx; 27 | /** \brief the current epoch */ 28 | int epoch; 29 | Job() { } 30 | void SerializeToString(std::string* str) const { 31 | *str = std::string(reinterpret_cast(this), sizeof(Job)); 32 | } 33 | 34 | void ParseFromString(const std::string& str) { 35 | CHECK_EQ(str.size(), sizeof(Job)); 36 | memcpy(this, str.data(), sizeof(Job)); 37 | } 38 | }; 39 | 40 | struct Progress { 41 | real_t loss = 0; // 42 | real_t penalty = 0; // 43 | real_t auc = 0; // auc 44 | real_t nnz_w = 0; // |w|_0 45 | real_t nrows = 0; // number of examples 46 | 47 | std::string TextString() { 48 | std::stringstream ss; 49 | ss << "loss = " << loss << ", AUC = " << auc / nrows; 50 | return ss.str(); 51 | } 52 | 53 | void SerializeToString(std::string* str) const { 54 | *str = std::string(reinterpret_cast(this), sizeof(Progress)); 55 | } 56 | 57 | void ParseFrom(char const* data, size_t size) { 58 | if (size == 0) return; 59 | CHECK_EQ(size, sizeof(Progress)); 60 | memcpy(this, data, sizeof(Progress)); 61 | } 62 | 63 | void Merge(const std::string& str) { 64 | Progress other; 65 | other.ParseFrom(str.data(), str.size()); 66 | Merge(other); 67 | } 68 | 69 | void Merge(const Progress& other) { 70 | size_t n = sizeof(Progress) / sizeof(real_t); 71 | auto a = reinterpret_cast(this); 72 | auto b = reinterpret_cast(&other); 73 | for (size_t i = 0; i < n; ++i) a[i] += b[i]; 74 | } 75 | }; 76 | 77 | } // namespace 
sgd 78 | } // namespace difacto 79 | #endif // DIFACTO_SGD_SGD_UTILS_H_ 80 | -------------------------------------------------------------------------------- /src/sgd/sgd_updater.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * @file sgd_updater.h 4 | * @brief the stochastic gradient descent solver 5 | */ 6 | #ifndef DIFACTO_SGD_SGD_UPDATER_H_ 7 | #define DIFACTO_SGD_SGD_UPDATER_H_ 8 | #include 9 | #include 10 | #include 11 | #include "difacto/updater.h" 12 | #include "./sgd_param.h" 13 | #include "./sgd_utils.h" 14 | #include "dmlc/io.h" 15 | namespace difacto { 16 | /** 17 | * \brief the weight entry for one feature 18 | */ 19 | struct SGDEntry { 20 | public: 21 | SGDEntry() { } 22 | ~SGDEntry() { delete [] V; } /* NOTE(review): SGDEntry owns V but keeps the implicit copy operations, so copying an entry would delete V twice */ 23 | /** \brief the number of appearances of this feature in the data so far */ 24 | real_t fea_cnt = 0; 25 | /** \brief w and its aux data */ 26 | real_t w = 0, sqrt_g = 0, z = 0; 27 | /** \brief V and its aux data */ 28 | real_t *V = nullptr; 29 | }; 30 | /** 31 | * \brief sgd updater 32 | * 33 | * - w is updated by FTRL, a smoothed version of adagrad that works well with 34 | * the l1 regularizer 35 | * - V is updated by adagrad 36 | */ 37 | class SGDUpdater : public Updater { 38 | public: 39 | SGDUpdater() {} 40 | virtual ~SGDUpdater() {} 41 | 42 | KWArgs Init(const KWArgs& kwargs) override; 43 | 44 | void Load(dmlc::Stream* fi, bool* has_aux) override { 45 | // TODO(mli): not implemented; nothing is read and *has_aux is left unset 46 | } 47 | 48 | void Save(bool save_aux, dmlc::Stream *fo) const override { 49 | // TODO(mli): not implemented; nothing is written 50 | } 51 | 52 | void Get(const SArray& fea_ids, 53 | int value_type, 54 | SArray* weights, 55 | SArray* val_lens) override; 56 | 57 | 58 | void Update(const SArray& fea_ids, 59 | int value_type, 60 | const SArray& 
values, 61 | const SArray& val_lens) override; 62 | 63 | void Evaluate(sgd::Progress* prog) const; 64 | 65 | const SGDUpdaterParam& param() const { return param_; } 66 | 67 | private: 68 | /** 
\brief update w by FTRL */ 69 | void UpdateW(real_t gw, SGDEntry* e); 70 | 71 | /** \brief update V by adagrad */ 72 | void UpdateV(real_t const* gV, SGDEntry* e); 73 | 74 | /** \brief init V */ 75 | void InitV(SGDEntry* e); 76 | 77 | SGDUpdaterParam param_; 78 | std::unordered_map model_; 79 | mutable std::mutex mu_; 80 | bool has_aux_ = true; 81 | }; 82 | 83 | 84 | } // namespace difacto 85 | #endif // DIFACTO_SGD_SGD_UPDATER_H_ 86 | -------------------------------------------------------------------------------- /tests/cpp/kv_union_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include 6 | #include "./utils.h" 7 | #include "common/kv_union.h" 8 | 9 | using namespace difacto; 10 | 11 | // a reference implementation based on std::map: the val_len values of a key present in both inputs are summed 12 | template 13 | void KVUnionRefer( 14 | const SArray& keys_a, 15 | const SArray& vals_a, 16 | const SArray& keys_b, 17 | const SArray& vals_b, 18 | SArray* joined_keys, 19 | SArray* joined_vals, 20 | int val_len = 1) { 21 | std::map> data; 22 | for (size_t i = 0; i < keys_a.size(); ++i) { 23 | auto& v = data[keys_a[i]]; 24 | v.resize(val_len); 25 | for (int j = 0; j < val_len; ++j) { 26 | v[j] = vals_a[i*val_len +j]; 27 | } 28 | } 29 | 30 | for (size_t i = 0; i < keys_b.size(); ++i) { 31 | auto it = data.find(keys_b[i]); 32 | if (it == data.end()) { 33 | auto& v = data[keys_b[i]]; 34 | v.resize(val_len); 35 | for (int j = 0; j < val_len; ++j) { 36 | v[j] = vals_b[i*val_len +j]; 37 | } 38 | } else { 39 | auto& v = it->second; 40 | for (int j = 0; j < val_len; ++j) { 41 | v[j] += vals_b[i*val_len +j]; 42 | } 43 | } 44 | } 45 | 46 | for (auto it : data) { 47 | joined_keys->push_back(it.first); 48 | for (V v : it.second) joined_vals->push_back(v); 49 | } 50 | } 51 | 52 | namespace { 53 | void test(int n, int k) { 54 | SArray key1, key2, jkey1, jkey2; 55 | SArray val1, val2, 
jval1, jval2; 56 | gen_keys(n, n*10, &key1); 57 | 
gen_keys(n, n*10, &key2); 58 | gen_vals(key1.size()*k, -100, 100, &val1); 59 | gen_vals(key2.size()*k, -100, 100, &val2); 60 | // NOTE(review): the trailing 4 below is presumably a thread count, not the value length k; confirm against common/kv_union.h 61 | KVUnion(key1, val1, key2, val2, &jkey1, &jval1, PLUS, 4); 62 | KVUnionRefer(key1, val1, key2, val2, &jkey2, &jval2, k); 63 | 64 | EXPECT_EQ(jval1.size(), jval2.size()); 65 | EXPECT_EQ(jkey1.size(), jkey2.size()); 66 | 67 | EXPECT_EQ(norm2(jkey1.data(), jkey1.size()), 68 | norm2(jkey2.data(), jkey2.size())); 69 | EXPECT_EQ(norm1(jval1.data(), jval1.size()), 70 | norm1(jval2.data(), jval2.size())); 71 | } 72 | } // namespace 73 | 74 | 75 | TEST(KVUnion, Union) { 76 | for (int i = 0; i < 10; ++i) { 77 | test(1000, 1); 78 | } 79 | } 80 | // NOTE(review): despite its name, Val3 runs test() with 4 values per key 81 | TEST(KVUnion, Val3) { 82 | for (int i = 0; i < 10; ++i) { 83 | test(1000, 4); 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /tests/cpp/lbfgs_twoloop_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "lbfgs/lbfgs_twoloop.h" 6 | #include "./utils.h" 7 | 8 | using namespace difacto; 9 | using namespace difacto::lbfgs; 10 | 11 | /** 12 | * \brief a reference implementation of the L-BFGS two-loop recursion, writing the search direction into p 13 | */ 14 | void TwoloopRefer(const std::vector>& s, 15 | const std::vector>& y, 16 | const SArray& g, 17 | SArray* p) { 18 | int m = s.size(); 19 | std::vector alpha(m); 20 | size_t n = g.size(); 21 | p->resize(n); memset(p->data(), 0, n*sizeof(real_t)); 22 | 23 | Add(-1, g, p); 24 | for (int i = m-1; i >= 0; i--) { 25 | alpha[i] = Inner(s[i], *p) / Inner(s[i], y[i]); 26 | Add(-alpha[i], y[i], p); 27 | } 28 | 29 | if (m > 0) { 30 | double x = Inner(s[m-1], y[m-1])/Inner(y[m-1], y[m-1])-1; 31 | Add(x, *p, p); 32 | } 33 | 34 | for (int 
i = 0; i < m; ++i) { 35 | double beta = Inner(y[i], *p) / Inner(s[i], y[i]); 36 | Add(alpha[i]-beta, s[i], p); 37 | } 38 | } 39 | 40 | TEST(Twoloop, naive) { 41 | SArray g = {1, 2}; 42 | std::vector> s = {{2, 3}}, y = {{3, 4}}; 43 | 
SArray p0, p1; 44 | 45 | TwoloopRefer(s, y, g, &p0); 46 | 47 | Twoloop two; 48 | std::vector B; 49 | two.CalcIncreB(s, y, g, &B); 50 | two.ApplyIncreB(B); 51 | two.CalcDirection(s, y, g, &p1); 52 | 53 | real_t a = (54.0-202)/25/9; 54 | real_t b = (-36.0-303)/25/9; 55 | 56 | EXPECT_LE(fabs(p0[0] - a), 1e-5); 57 | EXPECT_LE(fabs(p0[1] - b), 1e-5); 58 | EXPECT_LE(fabs(p1[0] - a), 1e-5); 59 | EXPECT_LE(fabs(p1[1] - b), 1e-5); 60 | } 61 | 62 | TEST(Twoloop, basic) { 63 | int m = 4; 64 | int n = 100; 65 | std::vector> s, y; 66 | SArray g; 67 | SArray p0, p1; 68 | Twoloop two; 69 | for (int k = 0; k < 10; ++k) { 70 | gen_vals(n, -1, 1, &g); 71 | 72 | if (static_cast(s.size()) == m-1) { 73 | s.erase(s.begin()); 74 | y.erase(y.begin()); 75 | } 76 | SArray a, b; 77 | gen_vals(n, -1, 1, &a); 78 | gen_vals(n, -1, 1, &b); 79 | s.push_back(a); 80 | y.push_back(b); 81 | 82 | TwoloopRefer(s, y, g, &p0); 83 | 84 | std::vector B; 85 | two.CalcIncreB(s, y, g, &B); 86 | two.ApplyIncreB(B); 87 | two.CalcDirection(s, y, g, &p1); 88 | 89 | EXPECT_LE(fabs(norm2(p0) - norm2(p1)) / norm2(p1), 1e-5); 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /tests/cpp/fm_loss_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "./utils.h" 6 | #include "loss/fm_loss.h" 7 | #include "data/localizer.h" 8 | #include "loss/bin_class_metric.h" 9 | 10 | using namespace difacto; 11 | 12 | TEST(FMLoss, NoV) { 13 | SArray weight(47149); 14 | for (size_t i = 0; i < weight.size(); ++i) { 15 | weight[i] = i / 5e4; 16 | } 17 | 18 | dmlc::data::RowBlockContainer rowblk; 19 | std::vector uidx; 20 | load_data(&rowblk, &uidx); 21 | SArray w(uidx.size()); 22 | for (size_t i = 0; i < uidx.size(); ++i) { 23 | w[i] = weight[uidx[i]]; 24 | } 25 | 26 | KWArgs args = {{"V_dim", "0"}}; 27 | FMLoss loss; loss.Init(args); 28 | auto data = rowblk.GetBlock(); 29 
| SArray pred(data.size); 30 | loss.Predict(data, w, {}, {}, &pred); 31 | 32 | BinClassMetric eval(data.label, pred.data(), data.size); 33 | 34 | // Progress prog; 35 | EXPECT_LT(fabs(eval.LogitObjv() - 147.4672), 1e-3); 36 | 37 | SArray grad(w.size()); 38 | loss.CalcGrad(data, w, {}, {}, pred, &grad); 39 | EXPECT_LT(fabs(norm2(grad) - 90.5817), 1e-3); 40 | } 41 | 42 | TEST(FMLoss, HasV) { 43 | int V_dim = 5; 44 | int n = 47149; 45 | std::vector weight(n*(V_dim+1)); 46 | for (int i = 0; i < n; ++i) { 47 | weight[i*(V_dim+1)] = i / 5e4; 48 | for (int j = 1; j <= V_dim; ++j) { 49 | weight[i*(V_dim+1)+j] = i * j / 5e5; 50 | } 51 | } 52 | 53 | dmlc::data::RowBlockContainer rowblk; 54 | std::vector uidx; 55 | load_data(&rowblk, &uidx); 56 | 57 | SArray w_pos(uidx.size()); 58 | SArray V_pos(uidx.size()); 59 | int p = 0; 60 | SArray w(uidx.size()*(V_dim+1)); 61 | for (size_t i = 0; i < uidx.size(); ++i) { 62 | for (int j = 0; j < V_dim+1; ++j) { 63 | w[i*(V_dim+1)+j] = weight[uidx[i]*(V_dim+1)+j]; 64 | } 65 | w_pos[i] = p; 66 | V_pos[i] = p+1; 67 | p += V_dim + 1; 68 | } 69 | 70 | KWArgs args = {{"V_dim", std::to_string(V_dim)}}; 71 | FMLoss loss; loss.Init(args); 72 | auto data = rowblk.GetBlock(); 73 | SArray pred(data.size); 74 | loss.Predict(data, w, w_pos, V_pos, &pred); 75 | 76 | // Progress prog; 77 | BinClassMetric eval(data.label, pred.data(), data.size); 78 | EXPECT_LT(fabs(eval.LogitObjv() - 330.628), 1e-3); 79 | 80 | SArray grad(w.size()); 81 | loss.CalcGrad(data, w, w_pos, V_pos, pred, &grad); 82 | EXPECT_LT(fabs(norm2(grad) - 1.2378e+03), 1e-1); 83 | } 84 | -------------------------------------------------------------------------------- /tests/cpp/spmv_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "./utils.h" 6 | #include "common/spmv.h" 7 | #include "dmlc/timer.h" 8 | #include "./spmv_test.h" 9 | 10 | using namespace difacto; 11 | 12
| namespace { 13 | dmlc::data::RowBlockContainer data; 14 | std::vector uidx; 15 | } // namespace 16 | 17 | 18 | TEST(SpMV, Times) { 19 | load_data(&data, &uidx); 20 | auto D = data.GetBlock(); 21 | SArray x; 22 | gen_vals(uidx.size(), -10, 10, &x); 23 | 24 | SArray y1(D.size); 25 | SArray y2(D.size); 26 | 27 | test::SpMV::Times(D, x, &y1); 28 | SpMV::Times(D, x, &y2); 29 | EXPECT_EQ(norm2(y1), norm2(y2)); 30 | } 31 | 32 | TEST(SpMV, TransTimes) { 33 | load_data(&data, &uidx); 34 | auto D = data.GetBlock(); 35 | SArray x; 36 | gen_vals(D.size, -10, 10, &x); 37 | 38 | SArray y1(uidx.size()); 39 | SArray y2(uidx.size()); 40 | 41 | test::SpMV::TransTimes(D, x, &y1); 42 | SpMV::TransTimes(D, x, &y2); 43 | EXPECT_EQ(norm2(y1), norm2(y2)); 44 | } 45 | 46 | TEST(SpMV, TimesPos) { 47 | load_data(&data, &uidx); 48 | auto D = data.GetBlock(); 49 | SArray x; 50 | gen_vals(uidx.size(), -10, 10, &x); 51 | SArray x_pos; 52 | SArray x_val; 53 | test::gen_sliced_vec(x, &x_val, &x_pos); 54 | 55 | SArray y(D.size); 56 | SArray y_pos; 57 | SArray y_val; 58 | test::gen_sliced_vec(y, &y_val, &y_pos); 59 | memset(y_val.data(), 0, y_val.size()*sizeof(real_t)); 60 | 61 | test::SpMV::Times(D, x, &y); 62 | SpMV::Times(D, x_val, &y_val, DEFAULT_NTHREADS, x_pos, y_pos); 63 | 64 | EXPECT_EQ(norm2(y), norm2(y_val)); 65 | 66 | SArray y2; 67 | test::slice_vec(y_val, y_pos, &y2); 68 | EXPECT_EQ(norm2(y), norm2(y2)); 69 | } 70 | 71 | TEST(SpMV, TransTimesPos) { 72 | load_data(&data, &uidx); 73 | auto D = data.GetBlock(); 74 | SArray x; 75 | gen_vals(D.size, -10, 10, &x); 76 | SArray x_pos; 77 | SArray x_val; 78 | test::gen_sliced_vec(x, &x_val, &x_pos); 79 | 80 | SArray y(uidx.size()); 81 | SArray y_pos; 82 | SArray y_val; 83 | test::gen_sliced_vec(y, &y_val, &y_pos); 84 | memset(y_val.data(), 0, y_val.size()*sizeof(real_t)); 85 | 86 | test::SpMV::TransTimes(D, x, &y); 87 | SpMV::TransTimes(D, x_val, &y_val, DEFAULT_NTHREADS, x_pos, y_pos); 88 | EXPECT_EQ(norm2(y), norm2(y_val)); 89 | 90 | SArray 
y2; 91 | test::slice_vec(y_val, y_pos, &y2); 92 | EXPECT_EQ(norm2(y), norm2(y2)); 93 | } 94 | -------------------------------------------------------------------------------- /tests/cpp/spmm_test.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include 5 | #include "./utils.h" 6 | #include "common/spmv.h" 7 | #include "common/spmm.h" 8 | #include "dmlc/timer.h" 9 | #include "./spmv_test.h" 10 | #include "./spmm_test.h" 11 | 12 | using namespace difacto; 13 | namespace { /* file-local test fixtures (as in spmv_test.cc), so they get no external linkage in the shared test binary */ 14 | dmlc::data::RowBlockContainer data; 15 | std::vector uidx; 16 | } // namespace 17 | TEST(SpMM, Times) { 18 | load_data(&data, &uidx); 19 | auto D = data.GetBlock(); 20 | int k = 10; 21 | SArray x; 22 | gen_vals(uidx.size()*k, -10, 10, &x); 23 | 24 | SArray y1(D.size*k); 25 | SArray y2(D.size*k); 26 | 27 | test::SpMM::Times(D, x, &y1); 28 | SpMM::Times(D, x, k, &y2); 29 | EXPECT_EQ(norm2(y1), norm2(y2)); 30 | } 31 | 32 | TEST(SpMM, TransTimes) { 33 | load_data(&data, &uidx); 34 | auto D = data.GetBlock(); 35 | SArray x; 36 | int k = 10; 37 | gen_vals(D.size*k, -10, 10, &x); 38 | 39 | SArray y1(uidx.size()*k); 40 | SArray y2(uidx.size()*k); 41 | 42 | test::SpMM::TransTimes(D, x, &y1); 43 | SpMM::TransTimes(D, x, k, &y2); 44 | EXPECT_EQ(norm2(y1), norm2(y2)); 45 | } 46 | /* The PosVec tests check that SpMM with k = 1 matches SpMV on the same sliced (value, position) operands. */ 47 | TEST(SpMM, TimesPosVec) { 48 | load_data(&data, &uidx); 49 | auto D = data.GetBlock(); 50 | SArray x; 51 | gen_vals(uidx.size(), -10, 10, &x); 52 | SArray x_pos; 53 | SArray x_val; 54 | test::gen_sliced_vec(x, &x_val, &x_pos); 55 | 56 | SArray y(D.size); 57 | SArray y_pos; 58 | SArray y_val; 59 | test::gen_sliced_vec(y, &y_val, &y_pos); 60 | memset(y_val.data(), 0, y_val.size()*sizeof(real_t)); 61 | SArray y_val2(y_val.size()); 62 | 63 | SpMM::Times(D, x_val, 1, &y_val2, DEFAULT_NTHREADS, x_pos, y_pos); 64 | SpMV::Times(D, x_val, &y_val, DEFAULT_NTHREADS, x_pos, y_pos); 65 | 66 | EXPECT_EQ(norm2(y_val), norm2(y_val2)); 67 | } 68 | 69 | 70 | TEST(SpMM,
TransTimesPosVec) { 71 | load_data(&data, &uidx); 72 | auto D = data.GetBlock(); 73 | SArray x; 74 | gen_vals(D.size, -10, 10, &x); 75 | SArray x_pos; 76 | SArray x_val; 77 | test::gen_sliced_vec(x, &x_val, &x_pos); 78 | 79 | SArray y(uidx.size()); 80 | SArray y_pos; 81 | SArray y_val; 82 | test::gen_sliced_vec(y, &y_val, &y_pos); 83 | memset(y_val.data(), 0, y_val.size()*sizeof(real_t)); 84 | SArray y_val2(y_val.size()); 85 | 86 | SpMM::TransTimes(D, x_val, 1, &y_val2, DEFAULT_NTHREADS, x_pos, y_pos); 87 | SpMV::TransTimes(D, x_val, &y_val, DEFAULT_NTHREADS, x_pos, y_pos); 88 | EXPECT_EQ(norm2(y_val), norm2(y_val2)); 89 | } 90 | -------------------------------------------------------------------------------- /src/data/data_store_impl.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_DATA_DATA_STORE_IMPL_H_ 5 | #define DIFACTO_DATA_DATA_STORE_IMPL_H_ 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "common/range.h" 13 | #include "difacto/sarray.h" 14 | namespace difacto { 15 | /** \brief abstract interface for storing, fetching, prefetching and removing data buffers by string key */ 16 | class DataStoreImpl { 17 | public: 18 | DataStoreImpl() { } 19 | virtual ~DataStoreImpl() { } 20 | /** 21 | * \brief push data into the store 22 | * 23 | * @param key the unique key 24 | * @param data the data buffer 25 | */ 26 | virtual void Store(const std::string& key, const SArray& data) = 0; 27 | /** 28 | * \brief pull data from the store 29 | * 30 | * @param key the unique key 31 | * @param range only pull a range of the data. If it is Range::All(), then pull
If it is Range::All(), then pul 32 | * the whole data 33 | * @param data the pulled data 34 | */ 35 | virtual void Fetch(const std::string& key, Range range, SArray* data) = 0; 36 | 37 | /** 38 | * \brief pretech a data 39 | * 40 | * @param key 41 | * @param range 42 | */ 43 | virtual void Prefetch(const std::string& key, Range range) = 0; 44 | 45 | /** 46 | * \brief remove data from the store 47 | * \param key the unique key of the data 48 | */ 49 | virtual void Remove(const std::string& key) = 0; 50 | }; 51 | 52 | /** 53 | * \brief a naive implementation which puts all things in memory 54 | */ 55 | class DataStoreMemory : public DataStoreImpl { 56 | public: 57 | DataStoreMemory() { } 58 | virtual ~DataStoreMemory() { } 59 | void Store(const std::string& key, const SArray& data) override { 60 | store_[key] = data; 61 | } 62 | void Fetch(const std::string& key, Range range, SArray* data) override { 63 | auto it = store_.find(key); 64 | CHECK(it != store_.end()); 65 | *CHECK_NOTNULL(data) = it->second.segment(range.begin, range.end); 66 | } 67 | void Prefetch(const std::string& key, Range range) override { } 68 | void Remove(const std::string& key) override { store_.erase(key); } 69 | 70 | private: 71 | std::unordered_map> store_; 72 | }; 73 | 74 | /** 75 | * \brief write data back to disk if exeeds the maximal memory capacity 76 | */ 77 | class DataStoreDisk : public DataStoreImpl { 78 | public: 79 | DataStoreDisk(const std::string& cache_prefix, 80 | size_t max_mem_capacity) { 81 | } 82 | virtual ~DataStoreDisk() { } 83 | }; 84 | 85 | } // namespace difacto 86 | #endif // DIFACTO_DATA_DATA_STORE_IMPL_H_ 87 | -------------------------------------------------------------------------------- /src/loss/logit_loss.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_LOSS_LOGIT_LOSS_H_ 5 | #define DIFACTO_LOSS_LOGIT_LOSS_H_ 6 | #include 7 | #include 8 | #include 
"difacto/base.h" 9 | #include "difacto/loss.h" 10 | #include "dmlc/data.h" 11 | #include "dmlc/omp.h" 12 | #include "common/spmv.h" 13 | namespace difacto { 14 | 15 | /** 16 | * \brief the logistic loss 17 | * 18 | * :math:`\ell(x,y,w) = log(1 + exp(- y ))` 19 | * 20 | */ 21 | class LogitLoss : public Loss { 22 | public: 23 | LogitLoss() {} 24 | virtual ~LogitLoss() {} 25 | 26 | KWArgs Init(const KWArgs& kwargs) override { 27 | return kwargs; 28 | } 29 | 30 | /** 31 | * \brief perform prediction 32 | * 33 | * pred += X * w 34 | * 35 | * @param data the data X 36 | * @param param input parameters 37 | * - param[0], real_t vector, the weights 38 | * - param[1], optional int vector, the weight positions 39 | * @param pred predict output, should be pre-allocated 40 | */ 41 | void Predict(const dmlc::RowBlock& data, 42 | const std::vector>& param, 43 | SArray* pred) override { 44 | int psize = param.size(); 45 | CHECK_GE(psize, 1); CHECK_LE(psize, 2); 46 | SArray w(param[0]); 47 | SArray w_pos = psize == 2 ? SArray(param[1]) : SArray(); 48 | SpMV::Times(data, w, pred, nthreads_, w_pos, {}); 49 | } 50 | 51 | /*! 52 | * \brief compute the gradients 53 | * 54 | * p = - y ./ (1 + exp (y .* pred)); 55 | * grad += X' * p; 56 | * 57 | * @param data the data X 58 | * @param param input parameters 59 | * - param[0], real_t vector, the predict output 60 | * - param[1], optional int vector, the gradient positions 61 | * @param grad the results, should be pre-allocated 62 | */ 63 | void CalcGrad(const dmlc::RowBlock& data, 64 | const std::vector>& param, 65 | SArray* grad) override { 66 | int psize = param.size(); 67 | CHECK_GE(psize, 1); 68 | CHECK_LE(psize, 2); 69 | SArray p; p.CopyFrom(SArray(param[0])); 70 | SArray grad_pos = psize == 2 ? SArray(param[1]) : SArray(); 71 | // p = ... 72 | CHECK_NOTNULL(data.label); 73 | #pragma omp parallel for num_threads(nthreads_) 74 | for (size_t i = 0; i < p.size(); ++i) { 75 | real_t y = data.label[i] > 0 ? 
1 : -1; 76 | p[i] = - y / (1 + std::exp(y * p[i])); 77 | } 78 | 79 | // grad += X' * p 80 | SpMV::TransTimes(data, p, grad, nthreads_, {}, grad_pos); 81 | } 82 | }; 83 | 84 | } // namespace difacto 85 | #endif // DIFACTO_LOSS_LOGIT_LOSS_H_ 86 | -------------------------------------------------------------------------------- /include/difacto/loss.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * @file loss.h 4 | * @brief the basic class of a loss function 5 | */ 6 | #ifndef DIFACTO_LOSS_H_ 7 | #define DIFACTO_LOSS_H_ 8 | #include 9 | #include 10 | #include "./base.h" 11 | #include "dmlc/data.h" 12 | #include "dmlc/omp.h" 13 | #include "./sarray.h" 14 | namespace difacto { 15 | /** 16 | * \brief the basic class of a loss function 17 | */ 18 | class Loss { 19 | public: 20 | /** 21 | * \brief the factory function 22 | * \param type the loss type such as "fm" 23 | * \param nthreads number of threads 24 | */ 25 | static Loss* Create(const std::string& type, int nthreads = DEFAULT_NTHREADS); 26 | /** \brief constructor */ 27 | Loss() : nthreads_(DEFAULT_NTHREADS) { } 28 | /** \brief destructor */ 29 | virtual ~Loss() { } 30 | /** 31 | * \brief init the loss function 32 | * 33 | * @param kwargs keyword arguments 34 | * @return the unknown kwargs 35 | */ 36 | virtual KWArgs Init(const KWArgs& kwargs) = 0; 37 | /** 38 | * \brief predict given the data and model weights.
often known as "forward" 39 | * 40 | * @param data the data 41 | * @param param model weights 42 | * @param pred the prediction output 43 | */ 44 | virtual void Predict(const dmlc::RowBlock& data, 45 | const std::vector>& param, 46 | SArray* pred) = 0; 47 | /** 48 | * \brief evaluate the loss 49 | * 50 | * by default returns the logistic loss sum_i log(1 + exp(-y_i * pred_i)), with y_i = 1 if label_i > 0 and -1 otherwise 51 | * 52 | * @param label the labels (an example is positive iff its label > 0) 53 | * @param pred prediction 54 | * 55 | * @return the objective value 56 | */ 57 | virtual real_t Evaluate(dmlc::real_t const* label, 58 | const SArray& pred) const { 59 | real_t objv = 0; 60 | #pragma omp parallel for reduction(+:objv) num_threads(nthreads_) 61 | for (size_t i = 0; i < pred.size(); ++i) { 62 | real_t y = label[i] > 0 ? 1 : -1; 63 | objv += log(1 + exp(- y * pred[i])); 64 | } 65 | return objv; 66 | } 67 | 68 | /** 69 | * \brief calculate gradient given the data and model weights. often known as "backward" 70 | * @param data the data 71 | * @param param model weights 72 | * @param grad the gradient output 73 | */ 74 | virtual void CalcGrad(const dmlc::RowBlock& data, 75 | const std::vector>& param, 76 | SArray* grad) = 0; 77 | /** 78 | * \brief set the number of threads, which must be in [1, 50) 79 | */ 80 | void set_nthreads(int nthreads) { 81 | CHECK_GE(nthreads, 1); CHECK_LT(nthreads, 50); 82 | nthreads_ = nthreads; 83 | } 84 | 85 | /** \brief number of threads used by the OpenMP loops, e.g. in Evaluate */ 86 | int nthreads_; 87 | }; 88 | } // namespace difacto 89 | #endif // DIFACTO_LOSS_H_ 90 | -------------------------------------------------------------------------------- /src/reader/adfea_parser.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | * @file adfea_parser.h 4 | * @brief parse adfea ctr data format 5 | */ 6 | #ifndef DIFACTO_READER_ADFEA_PARSER_H_ 7 | #define DIFACTO_READER_ADFEA_PARSER_H_ 8 | #include 9 | #include 10 | #include "difacto/base.h" 11 | #include "data/row_block.h" 12 | #include "data/parser.h" 13 | #include "data/strtonum.h" 14 | namespace difacto { 15 | 16 | /** 17 | *
\brief adfea ctr dataset 18 | * the top 10 bits store the feature group id (NOTE(review): ParseNext passes 12 to EncodeFeaGrpID -- confirm which bit count is intended) 19 | */ 20 | class AdfeaParser : public dmlc::data::ParserImpl { 21 | public: 22 | explicit AdfeaParser(dmlc::InputSplit *source) 23 | : bytes_read_(0), source_(source) { } 24 | virtual ~AdfeaParser() { 25 | delete source_; 26 | } 27 | 28 | void BeforeFirst(void) override { 29 | source_->BeforeFirst(); 30 | } 31 | size_t BytesRead(void) const override { 32 | return bytes_read_; 33 | } 34 | bool ParseNext( 35 | std::vector > *data) override { 36 | using dmlc::data::isspace; 37 | using dmlc::data::isdigit; 38 | using dmlc::data::strtoull; 39 | 40 | dmlc::InputSplit::Blob chunk; 41 | 42 | if (!source_->NextChunk(&chunk)) return false; 43 | 44 | CHECK_NE(chunk.size, 0); 45 | bytes_read_ += chunk.size; 46 | data->resize(1); 47 | dmlc::data::RowBlockContainer& blk = (*data)[0]; 48 | blk.Clear(); 49 | int i = 0; 50 | char *p = reinterpret_cast(chunk.dptr); 51 | char *end = p + chunk.size; 52 | // always test p != end before dereferencing p: *end lies outside the chunk 53 | while (p != end && isspace(*p)) ++p; 54 | while (p != end) { 55 | char *head = p; 56 | while (p != end && isdigit(*p)) ++p; 57 | CHECK_NE(head, p); 58 | 59 | if (p != end && *p == ':') { 60 | ++p; 61 | feaid_t idx = strtoull(head, NULL, 10); 62 | feaid_t gid = strtoull(p, NULL, 10); 63 | blk.index.push_back(EncodeFeaGrpID(idx, gid, 12)); 64 | while (p != end && isdigit(*p)) ++p; 65 | } else { 66 | // not a feature: skip the line id and the first count; the third such token (i == 2) is the label 67 | if (i == 2) { 68 | i = 0; 69 | if (blk.label.size() != 0) { 70 | blk.offset.push_back(blk.index.size()); 71 | } 72 | blk.label.push_back(*head == '1'); 73 | } else { 74 | ++i; 75 | } 76 | } 77 | 78 | while (p != end && isspace(*p)) ++p; 79 | } 80 | if (blk.label.size() != 0) { 81 | blk.offset.push_back(blk.index.size()); 82 | } 83 | return true; 84 | } 85 | 86 | private: 87 | // number of bytes read so far 88 | size_t bytes_read_; 89 | // source split that provides the data 90 | dmlc::InputSplit *source_; 91 | }; 92 | 93 | } // namespace difacto 94 | #endif //
DIFACTO_READER_ADFEA_PARSER_H_ 95 | -------------------------------------------------------------------------------- /src/common/thread_pool.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_THREAD_POOL_H_ 5 | #define DIFACTO_COMMON_THREAD_POOL_H_ 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | namespace difacto { 14 | /** 15 | * \brief a pool with multiple threads 16 | */ 17 | class ThreadPool { 18 | public: 19 | /** 20 | * \brief create a threadpool 21 | * 22 | * @param num_workers number of threads 23 | * @param max_capacity the maximal number of queued jobs; Add blocks while the queue is full 24 | */ 25 | explicit ThreadPool(int num_workers, int max_capacity = 1000000) { 26 | CHECK_GT(max_capacity, 0); 27 | CHECK_GT(num_workers, 0); 28 | CHECK_LT(num_workers, 100); 29 | capacity_ = max_capacity; 30 | for (int i = 0; i < num_workers; ++i) { 31 | workers_.push_back(std::thread([this, i](){ 32 | RunWorker(i); 33 | })); 34 | } 35 | } 36 | 37 | /** 38 | * \brief waits until all jobs are done, then stops and joins the worker threads 39 | */ 40 | ~ThreadPool() { 41 | Wait(); 42 | { std::lock_guard<std::mutex> lk(mu_); done_ = true; } // set under mu_: otherwise a worker can miss the wake-up between its predicate check and its wait, and join() hangs 43 | add_cond_.notify_all(); 44 | for (size_t i = 0; i < workers_.size(); ++i) { 45 | workers_[i].join(); 46 | } 47 | } 48 | 49 | /** 50 | * \brief add a job to the pool 51 | * returns immediately if the number of queued (not yet started) jobs is less than the 52 | * max_capacity.
otherwise wait until the pool is available 53 | * @param job 54 | */ 55 | void Add(const std::function& job) { 56 | std::unique_lock lk(mu_); 57 | fin_cond_.wait(lk, [this]{ return tasks_.size() < capacity_; }); 58 | tasks_.push_back(job); 59 | add_cond_.notify_one(); 60 | } 61 | 62 | /** 63 | * \brief wait untill all jobs are finished 64 | */ 65 | void Wait() { 66 | std::unique_lock lk(mu_); 67 | fin_cond_.wait(lk, [this]{ return num_running_ == 0 && tasks_.empty(); }); 68 | } 69 | 70 | private: 71 | void RunWorker(int tid) { 72 | std::unique_lock lk(mu_); 73 | while (true) { 74 | add_cond_.wait(lk, [this]{ return done_ || !tasks_.empty(); }); 75 | if (done_) break; 76 | // run a job 77 | auto task = std::move(tasks_.front()); 78 | tasks_.pop_front(); 79 | ++num_running_; 80 | lk.unlock(); 81 | CHECK(task); task(tid); 82 | --num_running_; 83 | fin_cond_.notify_all(); 84 | lk.lock(); 85 | } 86 | } 87 | std::atomic num_running_{0}; 88 | std::atomic done_{false}; 89 | size_t capacity_; 90 | std::mutex mu_; 91 | std::condition_variable fin_cond_, add_cond_; 92 | std::vector workers_; 93 | std::list> tasks_; 94 | }; 95 | } // namespace difacto 96 | #endif // DIFACTO_COMMON_THREAD_POOL_H_ 97 | -------------------------------------------------------------------------------- /src/common/spmt.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_SPMT_H_ 5 | #define DIFACTO_COMMON_SPMT_H_ 6 | #include 7 | #include 8 | #include "dmlc/data.h" 9 | #include "dmlc/omp.h" 10 | #include "./range.h" 11 | #include "data/row_block.h" 12 | #include "difacto/base.h" 13 | namespace difacto { 14 | 15 | /** 16 | * \brief multi-thread sparse matrix transpose 17 | */ 18 | class SpMT { 19 | public: 20 | /** 21 | * \brief transpose matrix Y = X' 22 | * \param X sparse matrix in row major 23 | * \param Y sparse matrix in row major 24 | * \param X_ncols optional, number of columns in 
X 25 | * \param nt optional, number of threads 26 | */ 27 | static void Transpose(const dmlc::RowBlock& X, 28 | dmlc::data::RowBlockContainer* Y, 29 | unsigned X_ncols = 0, 30 | int nt = DEFAULT_NTHREADS) { 31 | // find number of columns in X. NOTE(review): this loop and the counting loop below read X.index[0, nnz) while the fill loop reads X.index[X.offset[i], X.offset[i+1]); they agree only when X.offset[0] == 0 -- confirm callers never pass a block whose offsets do not start at 0 32 | size_t nrows = X.size; 33 | size_t nnz = X.offset[nrows] - X.offset[0]; 34 | if (X_ncols == 0) { 35 | for (size_t i = 0; i < nnz; ++i) { 36 | if (X_ncols < X.index[i]) X_ncols = X.index[i]; 37 | } 38 | ++X_ncols; 39 | } 40 | 41 | // allocate Y 42 | CHECK_NOTNULL(Y); 43 | Y->offset.clear(); 44 | Y->offset.resize(X_ncols+1, 0); 45 | Y->index.resize(nnz); 46 | if (X.value) Y->value.resize(nnz); 47 | 48 | // fill Y->offset: count the entries of each column; every thread only touches the columns of its own Range segment, so no counter is written by two threads 49 | #pragma omp parallel num_threads(nt) 50 | { 51 | Range range = Range(0, X_ncols).Segment( 52 | omp_get_thread_num(), omp_get_num_threads()); 53 | for (size_t i = 0; i < nnz; ++i) { 54 | unsigned k = X.index[i]; 55 | if (!range.Has(k)) continue; 56 | ++Y->offset[k+1]; 57 | } 58 | } 59 | for (size_t i = 0; i < X_ncols; ++i) { 60 | Y->offset[i+1] += Y->offset[i]; 61 | } 62 | 63 | // fill Y->index and Y->value, using Y->offset[k] as the write cursor of column k 64 | #pragma omp parallel num_threads(nt) 65 | { 66 | Range range = Range(0, X_ncols).Segment( 67 | omp_get_thread_num(), omp_get_num_threads()); 68 | 69 | for (size_t i = 0; i < nrows; ++i) { 70 | if (X.offset[i] == X.offset[i+1]) continue; 71 | for (size_t j = X.offset[i]; j < X.offset[i+1]; ++j) { 72 | unsigned k = X.index[j]; 73 | if (!range.Has(k)) continue; 74 | if (X.value) { 75 | Y->value[Y->offset[k]] = X.value[j]; 76 | } 77 | Y->index[Y->offset[k]] = static_cast(i); 78 | ++Y->offset[k]; 79 | } 80 | } 81 | } 82 | 83 | // restore Y->offset: the fill advanced Y->offset[k] to the start of column k+1, so shift the values back by one 84 | if (X_ncols > 0) { 85 | for (size_t i = X_ncols -1; i > 0; --i) { 86 | Y->offset[i] = Y->offset[i-1]; 87 | } 88 | Y->offset[0] = 0; 89 | } 90 | } 91 | }; 92 | } // namespace difacto 93 | #endif // DIFACTO_COMMON_SPMT_H_ 94 | -------------------------------------------------------------------------------- /include/difacto/store.h:
-------------------------------------------------------------------------------- 1 | /*! 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_STORE_H_ 5 | #define DIFACTO_STORE_H_ 6 | #include 7 | #include 8 | #include 9 | #include "./base.h" 10 | #include "dmlc/io.h" 11 | #include "dmlc/parameter.h" 12 | #include "./sarray.h" 13 | #include "./updater.h" 14 | namespace difacto { 15 | 16 | /** 17 | * \brief the store allows workers to get and set the model 18 | */ 19 | class Store { 20 | public: 21 | /** 22 | * \brief the factory function 23 | */ 24 | static Store* Create(); 25 | 26 | /** \brief default constructor */ 27 | Store() { } 28 | /** \brief default destructor */ 29 | virtual ~Store() { } 30 | 31 | static const int kFeaCount = 1; 32 | static const int kWeight = 2; 33 | static const int kGradient = 3; 34 | /** 35 | * \brief initialize the store 36 | * 37 | * @param kwargs keyword arguments 38 | * @return the unknown kwargs 39 | */ 40 | virtual KWArgs Init(const KWArgs& kwargs) = 0; 41 | 42 | /** 43 | * \brief push a list of (feature id, value) into the store 44 | * 45 | * @param fea_ids the feature ids 46 | * @param val_type the value type, presumably one of kFeaCount, kWeight, kGradient 47 | * @param vals the values 48 | * @param lens the value length of each feature id (TODO confirm, including whether it may be empty) 49 | * @param on_complete callback, presumably invoked once the push has finished 50 | * 51 | * @return presumably a timestamp that can be passed to Wait -- confirm 52 | */ 53 | virtual int Push(const SArray& fea_ids, 54 | int val_type, 55 | const SArray& vals, 56 | const SArray& lens, 57 | const std::function& on_complete = nullptr) = 0; 58 | /** 59 | * \brief pull the values for a list of feature ids 60 | * 61 | * @param fea_ids the feature ids 62 | * @param val_type the value type, presumably one of kFeaCount, kWeight, kGradient 63 | * @param vals output values 64 | * @param lens output value length of each feature id (TODO confirm) 65 | * @param on_complete callback, presumably invoked once the pulled values are available 66 | * 67 | * @return presumably a timestamp that can be passed to Wait -- confirm 68 | */ 69 | virtual int Pull(const SArray& fea_ids, 70 | int val_type, 71 | SArray* vals, 72 | SArray* lens, 73 | const std::function& on_complete = nullptr) = 0; 74 | 75 | /** 76 | * \brief wait until a push or a pull is actually finished 77 | * 78 | * @param time presumably the value returned by Push or Pull 79 | */ 80 | virtual void Wait(int time) = 0; 81 | 82 | /** 83 | * \brief return number of workers 84 | */ 85 | virtual int
NumWorkers() = 0; 86 | /** 87 | * \brief return number of servers 88 | */ 89 | virtual int NumServers() = 0; 90 | /** 91 | * \brief return the rank of this node 92 | */ 93 | virtual int Rank() = 0; 94 | 95 | /** \brief set an updater for the store, only required for a server node */ 96 | void SetUpdater(const std::shared_ptr& updater) { 97 | updater_ = updater; 98 | } 99 | /** \brief get the updater */ 100 | std::shared_ptr updater() { return updater_; } 101 | 102 | protected: 103 | std::shared_ptr updater_; 104 | }; 105 | 106 | } // namespace difacto 107 | 108 | #endif // DIFACTO_STORE_H_ 109 | -------------------------------------------------------------------------------- /src/reader/batch_reader.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "./batch_reader.h" 5 | namespace difacto { 6 | /* The constructor opens exactly one source: buf_reader_ (an inner BatchReader that reads shuffle_buf_size examples at a time) when shuffle_buf_size > 0, otherwise reader_. */ 7 | BatchReader::BatchReader( 8 | const std::string& uri, const std::string& format, 9 | unsigned part_index, unsigned num_parts, 10 | unsigned batch_size, unsigned shuffle_buf_size, 11 | float neg_sampling) { 12 | batch_size_ = batch_size; 13 | shuf_buf_ = shuffle_buf_size; 14 | neg_sampling_ = neg_sampling; 15 | start_ = 0; 16 | end_ = 0; 17 | seed_ = 0; 18 | if (shuf_buf_) { 19 | CHECK_GE(shuf_buf_, batch_size_); 20 | buf_reader_ = new BatchReader( 21 | uri, format, part_index, num_parts, shuf_buf_); 22 | reader_ = NULL; 23 | } else { 24 | buf_reader_ = NULL; 25 | reader_ = new Reader(uri, format, part_index, num_parts, 1<<26); 26 | } 27 | } 28 | /* Next() gathers examples into batch_ until batch_.offset.size() reaches batch_size_ + 1 or the input is exhausted. */ 29 | bool BatchReader::Next() { 30 | batch_.Clear(); 31 | while (batch_.offset.size() < batch_size_ + 1) { 32 | if (start_ == end_) { 33 | if (shuf_buf_ == 0) { 34 | // no random shuffle 35 | if (!reader_->Next()) break; 36 | in_blk_ = reader_->Value(); 37 | } else { 38 | // do random shuffle 39 | if (!buf_reader_->Next()) break; 40 | in_blk_ = buf_reader_->Value(); 41 | if (rdp_.size() != in_blk_.size) { 42 |
rdp_.resize(in_blk_.size); 43 | for (size_t i = 0; i < in_blk_.size; ++i) rdp_[i] = i; 44 | } 45 | std::random_shuffle(rdp_.begin(), rdp_.end()); /* NOTE(review): std::random_shuffle is deprecated in C++14 and removed in C++17 */ 46 | } 47 | start_ = 0; 48 | end_ = in_blk_.size; 49 | } 50 | 51 | size_t len = std::min(end_ - start_, batch_size_ + 1 - batch_.offset.size()); 52 | if (shuf_buf_ == 0 && neg_sampling_ == 1.0) { 53 | Push(start_, len); 54 | } else { 55 | for (size_t i = start_; i < start_ + len; ++i) { 56 | size_t j = shuf_buf_ ? rdp_[i] : i; // rdp_ is only filled when shuffling; without a shuffle buffer visit rows in order 57 | // downsampling: keep a negative example with probability neg_sampling_, so 1.0 (the fast path above) keeps all of them 58 | float p = static_cast(rand_r(&seed_)) / 59 | static_cast(RAND_MAX); 60 | if (neg_sampling_ < 1.0 && 61 | in_blk_.label[j] <= 0 && 62 | p > neg_sampling_) { 63 | continue; 64 | } 65 | batch_.Push(in_blk_[j]); 66 | } 67 | } 68 | start_ += len; 69 | } 70 | 71 | bool binary = true; 72 | for (auto f : batch_.value) if (f != 1) { binary = false; break; } 73 | if (binary) batch_.value.clear(); 74 | 75 | out_blk_ = batch_.GetBlock(); 76 | 77 | return out_blk_.size > 0; 78 | } 79 | 80 | void BatchReader::Push(size_t pos, size_t len) { 81 | if (!len) return; 82 | CHECK_LE(pos + len, in_blk_.size); 83 | dmlc::RowBlock slice; 84 | slice.weight = NULL; 85 | slice.size = len; 86 | slice.offset = in_blk_.offset + pos; 87 | slice.label = in_blk_.label + pos; 88 | slice.index = in_blk_.index + in_blk_.offset[pos]; 89 | if (in_blk_.value) { 90 | slice.value = in_blk_.value + in_blk_.offset[pos]; 91 | } else { 92 | slice.value = NULL; 93 | } 94 | batch_.Push(slice); 95 | } 96 | 97 | } // namespace difacto 98 | -------------------------------------------------------------------------------- /src/common/kv_union.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_COMMON_KV_UNION_H_ 5 | #define DIFACTO_COMMON_KV_UNION_H_ 6 | #include 7 | #include "./kv_match.h" 8 | namespace difacto { 9 | 10 | 11 | /** 12 | * \brief Join two key-value lists 13 | * 14 | * \code 15 | * key_a = {1,2,3}; 16 | * val_a =
{2,3,4}; 17 | * key_b = {1,3,5}; 18 | * val_b = {3,4,5}; 19 | * KVUnion(key_a, val_a, key_b, val_b, &joined_key, &joined_val); 20 | * // then joined_key = {1,2,3,5}; and joined_val = {5,3,8,5}; 21 | * \endcode 22 | * 23 | * \tparam K type of key 24 | * \tparam V type of value 25 | * @param keys_a keys from list a 26 | * @param vals_a values from list a 27 | * @param keys_b keys from list b 28 | * @param vals_b values from list b 29 | * @param joined_keys the union of keys_a and keys_b; keys_a and keys_b must be sorted (std::set_union requires it), and so is the result 30 | * @param joined_vals the merged values: those of list a are assigned first, then those of list b are combined with op 31 | * @param op the assignment operator (default is PLUS) 32 | * @param num_threads number of threads (default is DEFAULT_NTHREADS) 33 | */ 34 | template 35 | void KVUnion( 36 | const SArray& keys_a, 37 | const SArray& vals_a, 38 | const SArray& keys_b, 39 | const SArray& vals_b, 40 | SArray* joined_keys, 41 | SArray* joined_vals, 42 | AssignOp op = PLUS, 43 | int num_threads = DEFAULT_NTHREADS) { 44 | if (keys_a.empty()) { // NOTE(review): this shortcut, like the next one, copies the other list's values as-is without applying op 45 | joined_keys->CopyFrom(keys_b); 46 | joined_vals->CopyFrom(vals_b); 47 | return; 48 | } 49 | if (keys_b.empty()) { 50 | joined_keys->CopyFrom(keys_a); 51 | joined_vals->CopyFrom(vals_a); 52 | return; 53 | } 54 | 55 | // merge keys 56 | CHECK_NOTNULL(joined_keys)->resize(keys_a.size() + keys_b.size()); 57 | auto end = std::set_union(keys_a.begin(), keys_a.end(), keys_b.begin(), keys_b.end(), 58 | joined_keys->begin()); 59 | joined_keys->resize(end - joined_keys->begin()); 60 | 61 | // assign the values of list a; each key owns val_len consecutive values 62 | size_t val_len = vals_a.size() / keys_a.size(); 63 | CHECK_NOTNULL(joined_vals)->clear(); 64 | size_t n1 = KVMatch( 65 | keys_a, vals_a, *joined_keys, joined_vals, ASSIGN, num_threads); 66 | CHECK_EQ(n1, keys_a.size() * val_len); 67 | 68 | // combine the values of list b into the result with op 69 | auto n2 = KVMatch( 70 | keys_b, vals_b, *joined_keys, joined_vals, op, num_threads); 71 | CHECK_EQ(n2, keys_b.size() * val_len); 72 | } 73 | 74 | /** 75 | * \brief Join two key-value lists 76 | * 77 | * joined_keys = joined_keys \cup keys 78 | * joined_vals =
joined_vals \cup vals 79 | */ 80 | template 81 | void KVUnion( 82 | const SArray& keys, 83 | const SArray& vals, 84 | SArray* joined_keys, 85 | SArray* joined_vals, 86 | AssignOp op = PLUS, 87 | int num_threads = DEFAULT_NTHREADS) { 88 | SArray new_keys; 89 | SArray new_vals; 90 | KVUnion(keys, vals, *joined_keys, *joined_vals, 91 | &new_keys, &new_vals, op, num_threads); 92 | *joined_keys = new_keys; 93 | *joined_vals = new_vals; 94 | } 95 | 96 | 97 | } // namespace difacto 98 | #endif // DIFACTO_COMMON_KV_UNION_H_ 99 | -------------------------------------------------------------------------------- /src/loss/bin_class_metric.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_LOSS_BIN_CLASS_METRIC_H_ 5 | #define DIFACTO_LOSS_BIN_CLASS_METRIC_H_ 6 | #include 7 | #include 8 | #include "difacto/base.h" 9 | #include "dmlc/logging.h" 10 | #include "dmlc/omp.h" 11 | #include "difacto/sarray.h" 12 | namespace difacto { 13 | 14 | /** 15 | * \brief binary classification metrics 16 | * none of the returned values is divided by num_examples (AUC, for instance, is scaled by n) 17 | */ 18 | class BinClassMetric { 19 | public: 20 | /** 21 | * \brief constructor 22 | * 23 | * @param label label vector; an example is positive iff its label > 0 24 | * @param predict prediction scores (margins: LogLoss maps them through a sigmoid) 25 | * @param n number of examples 26 | * @param nthreads num threads 27 | */ 28 | BinClassMetric(const dmlc::real_t* const label, 29 | const real_t* const predict, 30 | size_t n, int nthreads = DEFAULT_NTHREADS) 31 | : label_(label), predict_(predict), size_(n), nt_(nthreads) { } 32 | 33 | ~BinClassMetric() { } 34 | /** \brief area under the ROC curve multiplied by n (flipped to 1 - AUC when below 0.5) */ 35 | real_t AUC() { 36 | size_t n = size_; 37 | struct Entry { dmlc::real_t label; real_t predict; }; 38 | std::vector buff(n); 39 | for (size_t i = 0; i < n; ++i) { 40 | buff[i].label = label_[i]; 41 | buff[i].predict = predict_[i]; 42 | } 43 | std::sort(buff.data(), buff.data()+n, [](const Entry& a, const Entry&b) { 44 | return a.predict < b.predict; }); 45 | real_t area = 0,
cum_tp = 0; 46 | for (size_t i = 0; i < n; ++i) { 47 | if (buff[i].label > 0) { 48 | cum_tp += 1; 49 | } else { 50 | area += cum_tp; 51 | } 52 | } 53 | if (cum_tp == 0 || cum_tp == n) return 1; 54 | area /= cum_tp * (n - cum_tp); 55 | return (area < 0.5 ? 1 - area : area) * n; 56 | } 57 | 58 | real_t Accuracy(real_t threshold) { 59 | real_t correct = 0; 60 | size_t n = size_; 61 | #pragma omp parallel for reduction(+:correct) num_threads(nt_) 62 | for (size_t i = 0; i < n; ++i) { 63 | if ((label_[i] > 0 && predict_[i] > threshold) || 64 | (label_[i] <= 0 && predict_[i] <= threshold)) 65 | correct += 1; 66 | } 67 | return correct > 0.5 * n ? correct : n - correct; 68 | } 69 | 70 | real_t LogLoss() { 71 | real_t loss = 0; 72 | size_t n = size_; 73 | #pragma omp parallel for reduction(+:loss) num_threads(nt_) 74 | for (size_t i = 0; i < n; ++i) { 75 | real_t y = label_[i] > 0; 76 | real_t p = 1 / (1 + exp(- predict_[i])); 77 | p = p < 1e-10 ? 1e-10 : p; 78 | loss += y * log(p) + (1 - y) * log(1 - p); 79 | } 80 | return - loss; 81 | } 82 | 83 | real_t LogitObjv() { 84 | real_t objv = 0; 85 | #pragma omp parallel for reduction(+:objv) num_threads(nt_) 86 | for (size_t i = 0; i < size_; ++i) { 87 | real_t y = label_[i] > 0 ? 
1 : -1; 88 | objv += log(1 + exp(- y * predict_[i])); 89 | } 90 | return objv; 91 | } 92 | 93 | private: 94 | dmlc::real_t const* label_; 95 | real_t const* predict_; 96 | size_t size_; 97 | int nt_; 98 | }; 99 | 100 | } // namespace difacto 101 | #endif // DIFACTO_LOSS_BIN_CLASS_METRIC_H_ 102 | -------------------------------------------------------------------------------- /src/sgd/sgd_learner.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_SGD_SGD_LEARNER_H_ 5 | #define DIFACTO_SGD_SGD_LEARNER_H_ 6 | #include 7 | #include 8 | #include "difacto/learner.h" 9 | #include "./sgd_utils.h" 10 | #include "./sgd_updater.h" 11 | #include "./sgd_param.h" 12 | #include "difacto/loss.h" 13 | #include "difacto/store.h" 14 | namespace difacto { 15 | /** \brief SGD learner; it owns store_ and loss_, which are deleted in the destructor */ 16 | class SGDLearner : public Learner { 17 | public: 18 | SGDLearner() { 19 | store_ = nullptr; 20 | loss_ = nullptr; 21 | } 22 | 23 | virtual ~SGDLearner() { 24 | delete loss_; 25 | delete store_; 26 | } 27 | KWArgs Init(const KWArgs& kwargs) override; 28 | 29 | void AddEpochEndCallback(const std::function& callback) { 31 | epoch_end_callback_.push_back(callback); 32 | } 33 | /** \brief the store's updater cast to SGDUpdater with static_pointer_cast (no runtime type check); the store keeps ownership */ 34 | SGDUpdater* GetUpdater() { 35 | return CHECK_NOTNULL(std::static_pointer_cast( 36 | CHECK_NOTNULL(store_)->updater()).get()); 37 | } 38 | 39 | protected: 40 | void RunScheduler() override; 41 | /** \brief run one serialized job (training, validation or evaluation) and serialize the resulting progress into rets */ 42 | void Process(const std::string& args, std::string* rets) { 43 | using sgd::Job; 44 | sgd::Progress prog; 45 | Job job; job.ParseFromString(args); 46 | if (job.type == Job::kTraining || 47 | job.type == Job::kValidation) { 48 | IterateData(job, &prog); 49 | } else if (job.type == Job::kEvaluation) { 50 | GetUpdater()->Evaluate(&prog); 51 | } 52 | prog.SerializeToString(rets); 53 | } 54 | 55 | private: 56 | void RunEpoch(int epoch, int job_type, sgd::Progress* prog); 57 | 58 | /** 59 | * \brief iterate over a part of the data 60 | * 61 | * it repeats the
following steps 62 | * 63 | * 1. read batch_size examples 64 | * 2. preprocess data (map uint64 feature indices into continuous ones) 65 | * 3. pull the newest model for this batch from the servers 66 | * 4. compute the gradients on this batch 67 | * 5. push the gradients to the servers to update the model 68 | * 69 | * to maximize the parallelization of i/o and computation, we use three 70 | * threads here, coordinated asynchronously through callbacks 71 | * 72 | * a. main thread does 1 and 2 73 | * b. batch_tracker's thread does 3 once a batch is preprocessed 74 | * c. store_'s threads do 4 and 5 when the weight is pulled back 75 | */ 76 | void IterateData(const sgd::Job& job, sgd::Progress* prog); 77 | 78 | real_t EvaluatePenalty(const SArray& weight, 79 | const SArray& w_pos, 80 | const SArray& V_pos); 81 | void GetPos(const SArray& len, 82 | SArray* w_pos, SArray* V_pos); 83 | /** \brief the model store, owned by this learner */ 84 | Store* store_; 85 | /** \brief the loss, owned by this learner */ 86 | Loss* loss_; 87 | /** \brief parameters */ 88 | SGDLearnerParam param_; 89 | // ProgressPrinter pprinter_; 90 | int blk_nthreads_ = DEFAULT_NTHREADS; 91 | 92 | std::vector> epoch_end_callback_; 94 | }; 95 | 96 | } // namespace difacto 97 | #endif // DIFACTO_SGD_SGD_LEARNER_H_ 98 | -------------------------------------------------------------------------------- /src/data/localizer.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #ifndef DIFACTO_DATA_LOCALIZER_H_ 5 | #define DIFACTO_DATA_LOCALIZER_H_ 6 | #include 7 | #include 8 | #include "difacto/base.h" 9 | #include "dmlc/io.h" 10 | #include "data/row_block.h" 11 | namespace difacto { 12 | 13 | /** 14 | * @brief Compact a rowblock's feature indices 15 | */ 16 | class Localizer { 17 | public: 18 | /** 19 | * \brief constructor 20 | * 21 | * \param max_index feature index will be projected into [0, max_index) by mod 22 | * \param nthreads number of threads 23 | */ 24 |
25 | Localizer(feaid_t max_index = std::numeric_limits::max(), 26 | int nthreads = DEFAULT_NTHREADS) 27 | : max_index_(max_index), nt_(nthreads) { } 28 | ~Localizer() { } 29 | 30 | /** 31 | * \brief compact blk's feature indices 32 | * 33 | * This function maps a RowBlock with arbitrary feature indices into continuous 34 | * feature indices starting from 0 35 | * 36 | * @param blk the data block 37 | * @param compacted the new block with feature index remapped 38 | * @param uniq_idx if not null, returns the sorted unique feature keys, i.e. ReverseBytes(index % max_index) 39 | * @param idx_frq if not null, returns the occurrence count of each unique key 40 | */ 41 | void Compact(const dmlc::RowBlock& blk, 42 | dmlc::data::RowBlockContainer *compacted, 43 | std::vector* uniq_idx = NULL, 44 | std::vector* idx_frq = NULL) { 45 | std::vector* uidx = 46 | uniq_idx == NULL ? new std::vector() : uniq_idx; 47 | CountUniqIndex(blk, uidx, idx_frq); 48 | RemapIndex(blk, *uidx, compacted); 49 | if (uniq_idx == NULL) delete uidx; 50 | Clear(); 51 | } 52 | 53 | /** 54 | * @brief find the unique indices and count their occurrences 55 | * 56 | * This function stores temporary results that \ref RemapIndex requires, so it must be called on the same blk before RemapIndex. 57 | * 58 | * @param blk the data block; its feature indices may be in any order 59 | * @param uniq_idx returns the sorted unique items 60 | * @param idx_frq if not NULL, returns the occurrence count of each unique item 61 | */ 62 | void CountUniqIndex(const dmlc::RowBlock& blk, 63 | std::vector* uniq_idx, 64 | std::vector* idx_frq); 65 | 66 | /** 67 | * @brief Remaps the index. 68 | * 69 | * @param idx_dict the index dictionary, which should be ordered. Any index 70 | * that does not exist in this dictionary is dropped. 71 | * 72 | * @param compacted a rowblock with index mapped: idx_dict[i] -> i.
73 | */ 74 | void RemapIndex(const dmlc::RowBlock& blk, 75 | const std::vector& idx_dict, 76 | dmlc::data::RowBlockContainer *compacted); 77 | 78 | /** 79 | * @brief Clears the temporary results cached by CountUniqIndex 80 | */ 81 | void Clear() { pair_.clear(); } 82 | 83 | private: 84 | feaid_t max_index_; 85 | /** \brief number of threads */ 86 | int nt_; 87 | 88 | #pragma pack(push) 89 | #pragma pack(4) 90 | struct Pair { 91 | feaid_t k; unsigned i; 92 | }; 93 | #pragma pack(pop) 94 | std::vector pair_; 95 | }; 96 | } // namespace difacto 97 | 98 | #endif // DIFACTO_DATA_LOCALIZER_H_ 99 | -------------------------------------------------------------------------------- /src/data/localizer.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by Contributors 3 | */ 4 | #include "./localizer.h" 5 | #include "dmlc/omp.h" 6 | #include "dmlc/logging.h" 7 | #include "common/parallel_sort.h" 8 | #include "difacto/sarray.h" 9 | namespace difacto { 10 | 11 | void Localizer::CountUniqIndex( 12 | const dmlc::RowBlock& blk, 13 | std::vector* uniq_idx, 14 | std::vector* idx_frq) { 15 | // sort 16 | if (blk.size == 0) return; 17 | size_t idx_size = blk.offset[blk.size]; 18 | CHECK_LT(idx_size, static_cast(std::numeric_limits::max())) 19 | << "you need to change Pair.i from unsigned to uint64"; 20 | pair_.resize(idx_size); 21 | 22 | #pragma omp parallel for num_threads(nt_) 23 | for (size_t i = 0; i < idx_size; ++i) { 24 | pair_[i].k = ReverseBytes(blk.index[i] % max_index_); 25 | pair_[i].i = i; 26 | } 27 | 28 | ParallelSort(&pair_, nt_, 29 | [](const Pair& a, const Pair& b) {return a.k < b.k; }); 30 | 31 | // save data 32 | CHECK_NOTNULL(uniq_idx); 33 | uniq_idx->clear(); 34 | if (idx_frq) idx_frq->clear(); 35 | 36 | feaid_t curr = pair_[0].k; 37 | real_t cnt = 0; 38 | for (size_t i = 0; i < pair_.size(); ++i) { 39 | const Pair& v = pair_[i]; 40 | if (v.k != curr) { 41 | uniq_idx->push_back(curr); 42 | curr = v.k; 43 | if (idx_frq)
idx_frq->push_back(cnt); 44 | cnt = 0; 45 | } 46 | ++cnt; 47 | } 48 | uniq_idx->push_back(curr); 49 | if (idx_frq) idx_frq->push_back(cnt); 50 | } 51 | 52 | 53 | void Localizer::RemapIndex( 54 | const dmlc::RowBlock& blk, 55 | const std::vector& idx_dict, 56 | dmlc::data::RowBlockContainer *compacted) { 57 | if (blk.size == 0 || idx_dict.empty()) return; 58 | CHECK_LT(idx_dict.size(), 59 | static_cast(std::numeric_limits::max())); 60 | CHECK_EQ(blk.offset[blk.size], pair_.size()); 61 | 62 | // build the index mapping 63 | unsigned matched = 0; 64 | std::vector remapped_idx(pair_.size(), 0); 65 | auto cur_dict = idx_dict.cbegin(); 66 | auto cur_pair = pair_.cbegin(); 67 | while (cur_dict != idx_dict.cend() && cur_pair != pair_.cend()) { 68 | if (*cur_dict < cur_pair->k) { 69 | ++cur_dict; 70 | } else { 71 | if (*cur_dict == cur_pair->k) { 72 | remapped_idx[cur_pair->i] 73 | = static_cast((cur_dict-idx_dict.cbegin()) + 1); 74 | ++matched; 75 | } 76 | ++cur_pair; 77 | } 78 | } 79 | 80 | // construct the new rowblock 81 | auto o = compacted; 82 | CHECK_NOTNULL(o); 83 | o->offset.resize(blk.size+1); o->offset[0] = 0; 84 | o->index.resize(matched); 85 | if (blk.value) o->value.resize(matched); 86 | 87 | size_t k = 0; 88 | for (size_t i = 0; i < blk.size; ++i) { 89 | for (size_t j = blk.offset[i]; j < blk.offset[i+1]; ++j) { 90 | if (remapped_idx[j] == 0) continue; 91 | if (blk.value) o->value[k] = blk.value[j]; 92 | o->index[k++] = remapped_idx[j] - 1; 93 | } 94 | o->offset[i+1] = k; 95 | } 96 | CHECK_EQ(k, matched); 97 | 98 | if (blk.label) { 99 | o->label.resize(blk.size); 100 | memcpy(o->label.data(), blk.label, blk.size*sizeof(*blk.label)); 101 | } 102 | o->max_index = idx_dict.size() - 1; 103 | } 104 | 105 | } // namespace difacto 106 | -------------------------------------------------------------------------------- /src/reader/criteo_parser.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015 by 
Contributors 3 | * @file criteo_parser.h 4 | * @brief parses the Criteo CTR data format 5 | */ 6 | #ifndef DIFACTO_READER_CRITEO_PARSER_H_ 7 | #define DIFACTO_READER_CRITEO_PARSER_H_ 8 | #include 9 | #if DIFACTO_USE_CITY 10 | #include 11 | #endif // DIFACTO_USE_CITY 12 | #include 13 | #include "difacto/base.h" 14 | #include "data/row_block.h" 15 | #include "data/parser.h" 16 | #include "data/strtonum.h" 17 | namespace difacto { 18 | 19 | /** 20 | * \brief criteo ctr dataset: 21 | * The columns are tab separated with the following schema: 22 | *