├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── doc ├── Doxyfile ├── README.md └── gendoc.sh ├── docker ├── README.md ├── build.sh ├── cloud.sh ├── fire_amazonec2.sh ├── get_host_within_container.sh ├── local.sh ├── rm_local.sh ├── shut.sh └── upload_s3.sh ├── example └── linear │ ├── README.md │ ├── criteo │ ├── README.md │ ├── batch_l1lr.conf │ ├── download.sh │ └── eval_batch.conf │ ├── ctr │ ├── batch_l1lr.conf │ ├── download.sh │ ├── eval_batch.conf │ ├── eval_online.conf │ └── online_l1lr.conf │ └── rcv1 │ ├── batch_l1lr.conf │ ├── download.sh │ ├── eval_batch.conf │ ├── eval_online.conf │ └── online_l1lr.conf ├── make ├── README.md └── config.mk ├── script ├── get_root_node.sh ├── install_third.sh ├── kill_node.sh ├── local.sh ├── mpi_node.sh ├── mpi_root.sh └── ps.sh └── src ├── README.mk ├── app └── linear_method │ ├── async_sgd.h │ ├── darlin.h │ ├── learning_rate.h │ ├── loss.h │ ├── main.cc │ ├── model_evaluation.h │ ├── penalty.h │ └── proto │ └── linear.proto ├── data ├── common.cc ├── common.h ├── info_parser.cc ├── info_parser.h ├── matlab │ ├── bin2mat.m │ ├── filter_fea.m │ ├── load_bin.m │ ├── mat2bin.m │ ├── save_bin.m │ └── saveas_pserver.m ├── proto │ ├── data.proto │ └── example.proto ├── show_example.h ├── slot_reader.cc ├── slot_reader.h ├── stream_reader.h ├── text2proto.h ├── text_parser.cc └── text_parser.h ├── filter ├── add_noise.h ├── compressing.h ├── filter.cc ├── filter.h ├── fixing_float.h ├── frequency_filter.h ├── key_caching.h ├── proto │ └── filter.proto └── sparse_filter.h ├── learner ├── bcd.cc ├── bcd.h ├── proto │ ├── bcd.proto │ ├── sgd.proto │ └── workload.proto ├── sgd.cc ├── sgd.h ├── workload_pool.cc └── workload_pool.h ├── parameter ├── README.org ├── kv_layer.h ├── kv_map.h ├── kv_store.h ├── kv_vector.h ├── parameter.cc ├── parameter.h └── proto │ └── param.proto ├── ps.h ├── ps_main.cc ├── system ├── assigner.cc ├── assigner.h ├── customer.h ├── dashboard.cc ├── dashboard.h ├── env.cc ├── env.h ├── executor.cc ├── executor.h ├── heartbeat_info.cc ├── heartbeat_info.h ├── manager.cc ├── manager.h ├── message.cc ├── message.h ├── monitor.h ├── postoffice.cc ├── postoffice.h ├── proto │ ├── heartbeat.proto │ ├── node.proto │ └── task.proto ├── remote_node.cc ├── remote_node.h ├── task_tracker.h ├── van.cc └── van.h ├── test ├── aggregation_ps.cc ├── assign_op_test.cc ├── bloom_filter_test.cc ├── build.mk ├── common_test.cc ├── countmin_test.cc ├── fixing_float_test.cc ├── hello_ps.cc ├── kv_layer_perf_ps.cc ├── kv_layer_ps.cc ├── kv_map_perf_ps.cc ├── kv_map_ps.cc ├── kv_vector_buffer_ps.cc ├── kv_vector_perf_ps.cc ├── kv_vector_ps.cc ├── localizer_test.cc ├── network_perf_ps.cc ├── parallel_ordered_match_test.cc ├── reassign_server_key_range_ps.cc ├── slot_reader_test.cc ├── sparse_matrix_perf.cc ├── sparse_matrix_test.cc └── stream_reader_test.cc ├── test_main.cc └── util ├── assign_op.h ├── auc.h ├── barrier.h ├── bitmap.h ├── block_bloom_filter.h ├── bloom_filter.h ├── common.h ├── countmin.h ├── crc32c.cc ├── crc32c.h ├── dense_matrix.h ├── evaluation.h ├── file.cc ├── file.h ├── filelinereader.cc ├── filelinereader.h ├── hdfs.h ├── integral_types.h ├── local_machine.h ├── localizer.h ├── macros.h ├── matrix.h ├── murmurhash3.cc ├── murmurhash3.h ├── parallel_ordered_match.h ├── parallel_sort.h ├── producer_consumer.h ├── proto ├── assign_op.proto ├── auc.proto ├── matrix.proto └── range.proto ├── range.h ├── recordio.cc ├── recordio.h ├── resource_usage.h ├── shared_array.h ├── 
shared_array_inl.h ├── sketch.h ├── sparse_matrix.h ├── split.h ├── strtonum.h ├── threadpool.cc ├── threadpool.h ├── threadsafe_limited_queue.h └── threadsafe_queue.h /.dockerignore: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | build 4 | example 5 | doc 6 | script 7 | third_party 8 | 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | example/**/data 2 | example/**/model 3 | build 4 | third_party 5 | **/log/** 6 | **/cache/** 7 | **/output/** 8 | *.pb.cc 9 | *.pb.h 10 | .* 11 | *core* 12 | config.mk 13 | **/bak/** 14 | doc/html 15 | doc/latex 16 | test/ 17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM qicongc/ps-baseimage 2 | WORKDIR /home/parameter_server 3 | ADD src src 4 | ADD make make 5 | ADD Makefile Makefile 6 | # compile ps 7 | RUN make -j8 8 | # add script used by container to get the ip of its docker host provided by its cloud_provider 9 | ADD docker/get_host_within_container.sh get_host_within_container.sh 10 | # TODO: install dependency to /usr 11 | ENV LD_LIBRARY_PATH /usr/local/lib 12 | # run ps according to args passed in 13 | CMD build/linear -bind_to $my_port -my_node "role:$my_role,hostname:'`./get_host_within_container.sh`',port:$my_port,id:'$my_id'" -scheduler "role:SCHEDULER,hostname:'$scheduler_host',port:$scheduler_port,id:'H'" -app_file $app_file -num_servers $num_servers -num_workers $num_workers $args 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ifneq ("$(wildcard ./config.mk)","") 2 | include ./config.mk 3 | else 4 | include make/config.mk 5 | endif 6 | 7 | ifeq ($(STATIC_THIRD_LIB), 1) 8 | THIRD_LIB=$(addprefix $(THIRD_PATH)/lib/, libgflags.a libzmq.a libprotobuf.a libglog.a libz.a libsnappy.a) 9 | ifeq ($(USE_S3),1) 10 | THIRD_LIB+=$(addprefix $(THIRD_PATH)/lib/, libxml2.a) 11 | endif 12 | else 13 | THIRD_LIB=-L$(THIRD_PATH)/lib -lgflags -lzmq -lprotobuf -lglog -lz -lsnappy 14 | ifeq ($(USE_S3),1) 15 | THIRD_LIB+=-lxml2 16 | endif 17 | endif 18 | 19 | WARN = -Wall -Wno-unused-function -finline-functions -Wno-sign-compare #-Wconversion 20 | INCPATH = -I./src -I$(THIRD_PATH)/include 21 | CFLAGS = -std=c++0x $(WARN) $(OPT) $(INCPATH) $(EXTRA_CFLAGS) 22 | ifeq ($(USE_S3), 1) 23 | CFLAGS += -DUSE_S3=1 24 | endif 25 | LDFLAGS = $(EXTRA_LDFLAGS) $(THIRD_LIB) -lpthread # -lrt 26 | 27 | PS_LIB = build/libps.a 28 | PS_MAIN = build/libpsmain.a 29 | TEST_MAIN = build/test_main.o 30 | 31 | clean: 32 | rm -rf build 33 | find src -name "*.pb.[ch]*" -delete 34 | 35 | ps: $(PS_LIB) $(PS_MAIN) $(TEST_MAIN) 36 | 37 | # PS system 38 | sys_dir = $(addprefix src/, util data system filter learner parameter) 39 | sys_srcs = $(wildcard $(patsubst %, %/*.cc, $(sys_dir))) 40 | sys_protos = $(wildcard $(patsubst %, %/proto/*.proto, $(sys_dir))) 41 | sys_objs = $(patsubst src/%.proto, build/%.pb.o, $(sys_protos)) \ 42 | $(patsubst src/%.cc, build/%.o, $(sys_srcs)) 43 | 44 | build/libps.a: $(patsubst %.proto, %.pb.h, $(sys_protos)) $(sys_objs) 45 | ar crv $@ $(filter %.o, $?) 46 | 47 | build/libpsmain.a: build/ps_main.o 48 | ar crv $@ $? 
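# Note: an application links the core library $(PS_LIB) and provides its own main().
# As a hedged illustration (the app name and paths below are hypothetical, not part
# of this repo), a new application target could mirror the build/linear rule that
# follows:
#
#   build/myapp: build/app/myapp/main.o $(PS_LIB)
#   	$(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@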
49 | 50 | # applications 51 | build/linear: $(addprefix build/app/linear_method/, proto/linear.pb.o main.o) $(PS_LIB) 52 | $(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@ 53 | 54 | # general rules 55 | build/%.o: src/%.cc 56 | @mkdir -p $(@D) 57 | $(CC) $(INCPATH) -std=c++0x -MM -MT build/$*.o $< >build/$*.d 58 | $(CC) $(CFLAGS) -c $< -o $@ 59 | 60 | %.pb.cc %.pb.h : %.proto 61 | ${THIRD_PATH}/bin/protoc --cpp_out=./src --proto_path=./src $< 62 | 63 | -include build/*/*.d 64 | -include build/*/*/*.d 65 | -include src/test/build.mk 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Parameter Server 2 | 3 | The parameter server is a distributed system that scales to industry-scale machine 4 | learning problems. It provides asynchronous and zero-copy key-value pair 5 | communication between worker machines and server machines. It also supports 6 | flexible data consistency models, data filters, and flexible server-side 7 | programming. 8 | 9 | **NOTE: We have stopped maintaining this repo. Please check the newer version, [ps-lite](https://github.com/dmlc/ps-lite).** 10 | 11 | - [Document](doc/) 12 | - [Wiki](https://github.com/dmlc/parameter_server/wiki/) 13 | - How to [build](make/) 14 | - Examples 15 | - [Linear method](example/linear), [Linear method with Cloud](docker) 16 | - Deep neural network, see [CXXNET](https://github.com/dmlc/cxxnet) and [Minerva](https://github.com/minerva-developers/minerva) 17 | -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | # Documents 2 | 3 | Use `gendoc.sh` to generate Doxygen documents. 4 | -------------------------------------------------------------------------------- /doc/gendoc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dir=`dirname "$0"` 4 | cd $dir/.. 5 | doxygen doc/Doxyfile 6 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $# -lt 2 ]]; then 4 | echo "usage: $0 manager image" 5 | exit -1 6 | fi 7 | 8 | if [[ ! -e ../Dockerfile ]]; then 9 | echo "cannot find Dockerfile! this command must be executed under docker/" 10 | exit -1 11 | fi 12 | 13 | manager=$1 14 | shift 15 | 16 | image=$1 17 | shift 18 | 19 | eval "`docker-machine env $manager`" 20 | docker build -t $image .. 21 | docker push $image -------------------------------------------------------------------------------- /docker/cloud.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [[ $# -lt 6 ]]; then 4 | echo "usage: $0 cloud_provider manager image num_servers num_workers app_conf_file [args...]" 5 | exit -1 6 | fi 7 | 8 | cloud_provider=$1 9 | shift 10 | 11 | manager=$1 12 | shift 13 | 14 | mount="-v /tmp/parameter_server/data/cache:/home/parameter_server/data/cache" 15 | case $cloud_provider in 16 | amazonec2) 17 | mount="$mount -v /var/log/cloud-init.log:/var/log/cloud-init.log" 18 | scheduler_host=`docker-machine ssh $manager cat /var/log/cloud-init.log | awk '/update hostname/ {print $10}'` 19 | ;; 20 | *) 21 | echo "Currently only amazonec2 is supported!"
22 | exit -1 23 | ;; 24 | esac 25 | 26 | 27 | image=$1 28 | shift 29 | 30 | num_servers=$1 31 | shift 32 | 33 | num_workers=$1 34 | shift 35 | 36 | app_file=$1 37 | shift 38 | 39 | args=$@ 40 | 41 | # stop all running containers 42 | echo "cleaning previous apps ..." 43 | eval "`docker-machine env --swarm $manager`" 44 | clean_list=`docker ps -q` 45 | docker stop $clean_list > /dev/null 2>&1 46 | echo "update parameter server image cluster wide ..." 47 | docker pull $image 48 | 49 | 50 | # launch scheduler 51 | echo "launching scheduler ..." 52 | eval "`docker-machine env $manager`" 53 | scheduler_port=8000 54 | env="\ 55 | -e cloud_provider=$cloud_provider \ 56 | -e scheduler_host=$scheduler_host \ 57 | -e scheduler_port=$scheduler_port \ 58 | -e my_role=SCHEDULER \ 59 | -e my_host=$scheduler_host \ 60 | -e my_port=$scheduler_port \ 61 | -e my_id=H \ 62 | -e num_servers=$num_servers \ 63 | -e num_workers=$num_workers \ 64 | -e app_file=$app_file \ 65 | " 66 | docker rm n0 > /dev/null 2>&1 67 | docker run -d -p $scheduler_port:$scheduler_port $env -e "args=$args" --name n0 $image 68 | 69 | #launch servers and workers 70 | echo "launching workers and servers ..." 71 | eval "`docker-machine env --swarm $manager`" 72 | for (( i = 1; i < $num_servers + $num_workers + 1; ++i )); do 73 | my_port=$(($scheduler_port + $i)) 74 | if (( $i <= $num_servers )); then 75 | my_role="SERVER" 76 | my_id="S$i" 77 | else 78 | my_role="WORKER" 79 | my_id="W$i" 80 | fi 81 | env="\ 82 | -e cloud_provider=$cloud_provider \ 83 | -e scheduler_host=$scheduler_host \ 84 | -e scheduler_port=$scheduler_port \ 85 | -e my_role=$my_role \ 86 | -e my_port=$my_port \ 87 | -e my_id=$my_id \ 88 | -e num_servers=$num_servers \ 89 | -e num_workers=$num_workers \ 90 | -e app_file=$app_file \ 91 | " 92 | docker rm n$i > /dev/null 2>&1 93 | docker run -d -p $my_port:$my_port $env -e "args=$args" $mount --name n$i $image & 94 | done 95 | 96 | wait 97 | -------------------------------------------------------------------------------- /docker/fire_amazonec2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 10 ]]; then 3 | echo "usage: $0 n-instances amazonec2-instance-type spot|regular price-to-bid amazonec2-access-key amazonec2-secret-key amazonec2-region amazonec2-vpc-id amazonec2-security-group amazonec2-zone" 4 | exit -1 5 | fi 6 | 7 | echo "check whether docker is ready ..." 8 | docker version 9 | if [[ $? = 1 ]]; then 10 | echo "please point your docker client to the right docker server!" 11 | exit -1 12 | fi 13 | 14 | if [[ ! `which docker-machine` ]]; then 15 | rm -rf machine 16 | git clone https://github.com/docker/machine 17 | cd machine 18 | # check platform 19 | platform=`uname` 20 | if [[ $platform = 'Darwin' ]];then 21 | platform='darwin' 22 | else 23 | platform='linux' 24 | fi 25 | script/build -osarch="$platform/amd64" 26 | mv docker-machine_$platform* /usr/local/bin/docker-machine 27 | cd .. 28 | rm -rf machine 29 | fi 30 | 31 | n=$1 32 | shift 33 | 34 | amazonec2_instance_type=$1 35 | shift 36 | 37 | if [[ $1 = 'spot' ]]; then 38 | shift 39 | amazonec2_request_spot_instance="--amazonec2-request-spot-instance --amazonec2-spot-price $1" 40 | shift 41 | elif [[ $1 = 'regular' ]]; then 42 | shift 43 | shift 44 | else 45 | echo "Currently only support regular and spot, but get $1." 
46 | exit -1 47 | fi 48 | 49 | 50 | 51 | amazonec2_access_key=$1 52 | shift 53 | 54 | amazonec2_secret_key=$1 55 | shift 56 | 57 | amazonec2_region=$1 58 | shift 59 | 60 | amazonec2_vpc_id=$1 61 | shift 62 | 63 | amazonec2_security_group=$1 64 | shift 65 | 66 | amazonec2_zone=$1 67 | shift 68 | 69 | discovery="token://`docker run swarm create`" 70 | 71 | docker-machine create -d amazonec2 --amazonec2-access-key $amazonec2_access_key --amazonec2-secret-key $amazonec2_secret_key --amazonec2-region $amazonec2_region --amazonec2-vpc-id $amazonec2_vpc_id --amazonec2-security-group $amazonec2_security_group --amazonec2-zone $amazonec2_zone --amazonec2-instance-type $amazonec2_instance_type $amazonec2_request_spot_instance --swarm --swarm-master --swarm-discovery $discovery swarm-master & 72 | for (( i = 0; i < n-1; i++ )); do 73 | docker-machine create -d amazonec2 --amazonec2-access-key $amazonec2_access_key --amazonec2-secret-key $amazonec2_secret_key --amazonec2-region $amazonec2_region --amazonec2-vpc-id $amazonec2_vpc_id --amazonec2-security-group $amazonec2_security_group --amazonec2-zone $amazonec2_zone --amazonec2-instance-type $amazonec2_instance_type $amazonec2_request_spot_instance --swarm --swarm-discovery $discovery swarm-node-$i & 74 | done 75 | wait 76 | 77 | 78 | -------------------------------------------------------------------------------- /docker/get_host_within_container.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $my_host ]]; then 3 | echo "$my_host" 4 | else 5 | case $cloud_provider in 6 | amazonec2) 7 | cat /var/log/cloud-init.log | awk '/update hostname/ {print $10}' 8 | ;; 9 | *) 10 | echo "Currently only support amazonec2!" 11 | exit -1 12 | ;; 13 | esac 14 | fi 15 | 16 | 17 | -------------------------------------------------------------------------------- /docker/local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -lt 5 ]; then 4 | echo "usage: $0 num_servers num_workers app_conf_file data_dir model_dir [args...]" 5 | exit -1; 6 | fi 7 | 8 | ip=`/sbin/ifconfig docker0 | grep inet | grep -v inet6 | awk '{print $2}' | sed -e 's/[a-z]*://'` 9 | if [ -z ${ip} ]; then 10 | echo "failed to get the ip address for docker" 11 | exit -1 12 | fi 13 | 14 | num_servers=$1 15 | shift 16 | 17 | num_workers=$1 18 | shift 19 | 20 | app=$1 21 | if [[ "$app" != /* ]]; then 22 | app=`pwd`/$app 23 | fi 24 | shift 25 | 26 | data=$1 27 | if [[ "$data" != /* ]]; then 28 | data=`pwd`/$data 29 | fi 30 | shift 31 | 32 | model=$1 33 | if [[ "$model" != /* ]]; then 34 | model=`pwd`/$model 35 | fi 36 | shift 37 | 38 | port=8000 39 | bin="muli/parameter-server /build/ps" 40 | # bin_v="-v /home/muli/work/ps/build:/build" 41 | app_v="-v $app:/app.conf" 42 | data_v="-v $data:/data -v $model:/model" 43 | mount="$bin_v $app_v $data_v" 44 | 45 | arg="-app_file /app.conf -num_servers $num_servers -num_workers $num_workers $@" 46 | 47 | sch="role:SCHEDULER,hostname:'$ip',port:$port,id:'H'" 48 | for (( i = 0; i < ${num_servers} + ${num_workers} + 1; ++i )); do 49 | myport=$(($port + ${i})) 50 | if (( $i == 0)); then 51 | node=$sch 52 | elif (( $i <= ${num_servers} )); then 53 | node="role:SERVER,hostname:'$ip',port:${myport},id:'S${i}'" 54 | else 55 | node="role:WORKER,hostname:'$ip',port:${myport},id:'W${i}'" 56 | fi 57 | docker run --rm -p $myport:$myport --name n${i} $mount $bin \ 58 | -my_node ${node} -scheduler ${sch} ${arg} & 59 | done 60 | 61 | wait 62 | 
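# Example usage (a minimal sketch; the config name and the data/model directories
# below are placeholders, assuming the config refers to its inputs and outputs via
# the /data and /model mount points):
#
#   ./docker/local.sh 1 2 my_app.conf ./data ./model
#
# This launches one scheduler, 1 server and 2 workers as local docker containers,
# mounting my_app.conf into each container as /app.conf.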
-------------------------------------------------------------------------------- /docker/rm_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker rm -f $(docker ps -a -q) 3 | -------------------------------------------------------------------------------- /docker/shut.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]]; then 3 | echo "usage: $0 n-instances" 4 | exit -1 5 | fi 6 | 7 | n=$1 8 | docker-machine stop swarm-master > /dev/null 2>&1 & 9 | for (( i = 0; i < n-1; i++ )); do 10 | docker-machine stop swarm-node-$i > /dev/null 2>&1 & 11 | done 12 | wait 13 | docker-machine rm -f swarm-master & 14 | for (( i = 0; i < n-1; i++ )); do 15 | docker-machine rm -f swarm-node-$i & 16 | done 17 | wait 18 | -------------------------------------------------------------------------------- /docker/upload_s3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 2 ]]; then 3 | echo "usage: $0 local_file s3_path" 4 | exit -1 5 | fi 6 | 7 | local_file=$1 8 | shift 9 | 10 | if [[ ! -e $local_file ]]; then 11 | echo "$local_file does not exist!" 12 | exit -1 13 | fi 14 | 15 | s3_path=$1 16 | shift 17 | 18 | bucket=`echo $s3_path | cut -d'/' -f3` 19 | path=${s3_path#s3://$bucket} 20 | 21 | length=`stat $local_file | awk '{print $8}'` 22 | curl "http://$bucket.s3.amazonaws.com$path?Content-Length=$length&x-amz-acl=public-read" --upload-file $local_file 23 | 24 | 25 | -------------------------------------------------------------------------------- /example/linear/README.md: -------------------------------------------------------------------------------- 1 | # Tutorial to run linear method 2 | 3 | **Preparing Data** 4 | 5 | Use one of the download scripts, such as `rcv1/download.sh` or `ctr/download.sh`, to prepare 6 | sample data. 7 | 8 | 9 | **Run L1 Logistic Regression on the CTR dataset** 10 | 11 | Let's first start one worker and one server on the local machine to run the online 12 | solver. 13 | ```bash 14 | ../../script/ps.sh start ../../build/linear -app_file ctr/online_l1lr.conf 15 | ``` 16 | 17 | Next we use 3 workers and 4 servers: 18 | 19 | ```bash 20 | ../../script/ps.sh start -nw 3 -ns 4 ../../build/linear -app_file ctr/online_l1lr.conf 21 | ``` 22 | 23 | We can also start the jobs on multiple machines. Assume there is a file `your_hostfile` 24 | containing all machine IPs, one per line. Then start the job via 25 | the `-hostfile` option. 26 | ```bash 27 | ../../script/ps.sh start -hostfile your_hostfile -nw 3 -ns 4 ../../build/linear -app_file ctr/online_l1lr.conf 28 | ``` 29 | 30 | Finally, we can change the application via the `-app_file` option.
For example, evaluate the 31 | trained model 32 | ```bash 33 | ../../script/ps.sh start ../../build/linear -app_file ctr/eval_online.conf 34 | ``` 35 | or train with the batch solver 36 | ```bash 37 | ../../script/ps.sh start ../../build/linear -app_file ctr/batch_l1lr.conf 38 | ``` 39 | 40 | See [linear.proto](../../src/app/linear_method/proto/linear.proto) for more information about the configuration file. 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /example/linear/criteo/batch_l1lr.conf: -------------------------------------------------------------------------------- 1 | training_data { 2 | format: TEXT 3 | text: CRITEO 4 | file: "data/criteo/train/part.*" 5 | 6 | # If the data is placed on hdfs and HADOOP_HOME="/usr" 7 | # hdfs { 8 | # home: "/usr" 9 | # } 10 | } 11 | 12 | model_output { 13 | format: TEXT 14 | file: "model/criteo_batch" 15 | } 16 | 17 | loss { 18 | type: LOGIT 19 | } 20 | 21 | # lambda * \| w \|_1 22 | penalty { 23 | type: L1 24 | lambda: 4 25 | lambda: 1 26 | } 27 | 28 | learning_rate { 29 | type: CONSTANT 30 | alpha: .9 31 | } 32 | 33 | darlin { 34 | # the max number of data passes 35 | max_pass_of_data : 50 36 | # convergence criterion: stop if the relative objective <= epsilon 37 | epsilon : 2e-5 38 | 39 | save_model_every_n_iter: 20 40 | 41 | # The maximal number of blocks that can be updated in parallel (bounded-delay 42 | # consistency). A larger delay may slow down the convergence rate, but improves 43 | # the system performance. 44 | max_block_delay: 2 45 | 46 | # features which occur <= *tail_feature_freq* times will be filtered before 47 | # training. This saves both memory and bandwidth. 48 | tail_feature_freq: 4 49 | 50 | # This controls the countmin size. We filter the tail features by countmin, which 51 | # is more efficient than a hash table, but is still the memory bottleneck for servers. A 52 | # smaller ratio reduces the memory footprint, but may increase the size of 53 | # the filtered feature set. 54 | 55 | countmin_n_ratio: .66 56 | 57 | # During preprocessing, each (text) file is parsed and then written into the local 58 | # cache in binary format to save memory. These data are then used by the 59 | # preprocessing stage, and can also be re-used in later runs. 60 | local_cache { 61 | format: BIN 62 | file: "data/cache/criteo_train_" 63 | } 64 | 65 | comm_filter { 66 | type: KEY_CACHING 67 | } 68 | 69 | # load_local_data: true 70 | 71 | # Parameters used by the trust region method. The change of w_i (the i-th 72 | # parameter) is bounded by [-delta_i, delta_i], where delta_i is an adaptive 73 | # value according to the convergence. The initial value of delta_i is 74 | # *delta_init_value* and the maximal value is *delta_max_value*. You can increase 75 | # these parameters for easy datasets. 76 | 77 | # [PS.LM.delta_init_value] : 1 78 | # [PS.LM.delta_max_value] : 5 79 | 80 | # This parameter controls the aggressiveness of the KKT filter. Increasing this 81 | # number will decrease the effect of the KKT filter; a very large number, such as 82 | # 1e20, will turn off the KKT filter. 83 | 84 | # [PS.LM.kkt_filter_threshold_ratio] : 10 85 | } 86 | -------------------------------------------------------------------------------- /example/linear/criteo/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dir=`dirname "$0"` 4 | cd $dir/../data 5 | 6 | 7 | if [ ! -f train.txt ]; then 8 | if [ !
-f dac.tar.gz ]; then 9 | wget https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz 10 | fi 11 | tar -zxvf dac.tar.gz 12 | fi 13 | 14 | echo "split train.txt..." 15 | mkdir -p criteo/train 16 | split -n l/18 --numeric-suffixes=1 --suffix-length=3 train.txt criteo/train/part- 17 | 18 | echo "make a test set" 19 | mkdir -p criteo/test 20 | mv criteo/train/part-01[7-8] criteo/test 21 | -------------------------------------------------------------------------------- /example/linear/criteo/eval_batch.conf: -------------------------------------------------------------------------------- 1 | validation_data { 2 | format: TEXT 3 | text: CRITEO 4 | file: "data/criteo/test/part.*" 5 | } 6 | 7 | model_input { 8 | format: TEXT 9 | file: "model/criteo_batch_S.*" 10 | } 11 | -------------------------------------------------------------------------------- /example/linear/ctr/batch_l1lr.conf: -------------------------------------------------------------------------------- 1 | training_data { 2 | format: TEXT 3 | text: SPARSE_BINARY 4 | file: "data/ctr/train/part.*" 5 | 6 | # If the data is placed on hdfs and HADOOP_HOME="/usr" 7 | # hdfs { 8 | # home: "/usr" 9 | # } 10 | } 11 | 12 | model_output { 13 | format: TEXT 14 | file: "model/ctr_batch" 15 | } 16 | 17 | loss { 18 | type: LOGIT 19 | } 20 | 21 | # lambda * \| w \|_1 22 | penalty { 23 | type: L1 24 | lambda: 10 25 | } 26 | 27 | learning_rate { 28 | type: CONSTANT 29 | alpha: 1 30 | } 31 | 32 | darlin { 33 | # the max number of data passes 34 | max_pass_of_data : 20 35 | # convergance critiria. stop if the relative objective <= epsilon 36 | epsilon : 2e-5 37 | 38 | # Features are partitioned into groups and each time only one group is 39 | # updated. We divide a feature group into 40 | # feature_block_ratio * nnz_feature_per_example 41 | # blocks. A larger ratio often accelerate the convergence, however, it may slow 42 | # down the system performance because of the increased number of global barriers. 43 | feature_block_ratio : 4 44 | 45 | # The maximal number of blocks can be updating in parallel (bounded-delay 46 | # consistency). A larger delay may slow down the convergence rate, but improves 47 | # the system performance. 48 | max_block_delay: 0 49 | 50 | # important feature groups, update them earlier to get a better model 51 | # initialization. 52 | prior_fea_group: 127 53 | prior_fea_group: 120 54 | 55 | # features which occurs <= *tail_feature_freq* will be filtered before 56 | # training. it save both memory and bandwidth. 57 | tail_feature_freq: 4 58 | 59 | # It controls the countmin size. We filter the tail features by countmin, which 60 | # is more efficient than hash, but still is the memory bottleneck for servers. A 61 | # smaller ratio reduces the memory footprint, but may increase the size of 62 | # filtered feature. 63 | 64 | countmin_n_ratio: .66 65 | 66 | # In preprocessing, feature group is processed one by one. It is the main memory 67 | # bottleneck for workers. This number control how many feature groups can be in 68 | # memory at the same time. A smaller number reduce the workers' memory 69 | # footprint, but may slow down the preprocessing speed. 70 | 71 | # max_num_parallel_groups_in_preprocessing: 1000 72 | 73 | # During preprocessing, each (text) file is parsed and then write into the local 74 | # cache in binary format to save the memory. These data are then used by the 75 | # preprocessing stage, and also can be re-used when running next time. 
76 | local_cache { 77 | format: BIN 78 | file: "data/cache/ctr_train_" 79 | } 80 | 81 | # Parameters used by the trust region method. The change of w_i (the i-th 82 | # parameter) is bouned by [-delta_i, delta_i], where delta_i is an adaptive 83 | # value according to the convergence. The initial value of delta_i is 84 | # *delta_init_value* and maximal value is *delta_max_value*. You can increase 85 | # these parameters for easy datasets. 86 | 87 | # [PS.LM.delta_init_value] : 1 88 | # [PS.LM.delta_max_value] : 5 89 | 90 | # This parameter controls the aggressiveness of the KKT filter. Increasing this 91 | # number will decrease the effect of KKT filter. a very large number, such as 92 | # 1e20 will turn off the KKT filter. 93 | 94 | # [PS.LM.kkt_filter_threshold_ratio] : 10 95 | } 96 | -------------------------------------------------------------------------------- /example/linear/ctr/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dir=`dirname "$0"` 4 | git clone https://github.com/mli/ctr-data $dir/../data/ctr 5 | -------------------------------------------------------------------------------- /example/linear/ctr/eval_batch.conf: -------------------------------------------------------------------------------- 1 | validation_data { 2 | format: TEXT 3 | text: SPARSE_BINARY 4 | file: "data/ctr/test/part.*" 5 | } 6 | 7 | model_input { 8 | format: TEXT 9 | file: "model/ctr_batch.*" 10 | } 11 | -------------------------------------------------------------------------------- /example/linear/ctr/eval_online.conf: -------------------------------------------------------------------------------- 1 | validation_data { 2 | format: TEXT 3 | # text: SPARSE_BINARY 4 | # file: "data/ctr/test/part.*" 5 | 6 | text: ADFEA 7 | file: "/home/muli/work/data/ctrc-rand/part-05.*" 8 | } 9 | 10 | model_input { 11 | format: TEXT 12 | file: "model/ctr_online.*" 13 | } 14 | -------------------------------------------------------------------------------- /example/linear/ctr/online_l1lr.conf: -------------------------------------------------------------------------------- 1 | training_data { 2 | format: TEXT 3 | text: SPARSE_BINARY 4 | file: "data/ctr/train/part.*" 5 | } 6 | 7 | model_output { 8 | format: TEXT 9 | file: "model/ctr_online" 10 | } 11 | 12 | loss { type: LOGIT } 13 | 14 | # lambda_0 * |w|_1 + lambda_1 * |w|^2_2 15 | penalty { 16 | type: L1 17 | lambda: 10 18 | lambda: 1 19 | } 20 | 21 | # lr = alpha: .1 22 | learning_rate { 23 | type: DECAY 24 | alpha: .01 25 | beta: 10 26 | } 27 | 28 | # see more config options in linear.proto 29 | async_sgd { 30 | algo: FTRL 31 | # The size of minibatch 32 | minibatch: 10000 33 | # The number of data passes 34 | num_data_pass: 10 35 | 36 | push_filter { 37 | type: KEY_CACHING 38 | clear_cache_if_done: true 39 | } 40 | 41 | push_filter { 42 | type: FIXING_FLOAT 43 | num_bytes: 1 44 | } 45 | 46 | pull_filter { 47 | type: KEY_CACHING 48 | } 49 | 50 | pull_filter { 51 | type: FIXING_FLOAT 52 | num_bytes: 1 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /example/linear/rcv1/batch_l1lr.conf: -------------------------------------------------------------------------------- 1 | training_data { 2 | format: TEXT 3 | text: LIBSVM 4 | file: "data/rcv1/train/part.*" 5 | } 6 | 7 | model_output { 8 | format: TEXT 9 | file: "model/rcv1_batch_l1lr" 10 | } 11 | 12 | loss { 13 | type: LOGIT 14 | } 15 | 16 | # lambda * |w|_1 17 | penalty { 18 | type: L1 19 | lambda: 1 
20 | } 21 | 22 | learning_rate { 23 | type: CONSTANT 24 | alpha: 1 25 | } 26 | 27 | darlin { 28 | # max number pass of traing data 29 | max_pass_of_data : 20 30 | # convergance critiria. stop if the relative objective <= epsilon 31 | epsilon : 1e-4 32 | 33 | # temp data 34 | local_cache { 35 | format: BIN 36 | file: "data/cache/rcv1_train" 37 | } 38 | 39 | # debug 40 | # feature_block_ratio : .0001 41 | # tail_feature_freq: 1 42 | # [PS.LM.kkt_filter_threshold_ratio] : 1e20 43 | } 44 | -------------------------------------------------------------------------------- /example/linear/rcv1/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dir=`dirname "$0"` 4 | cd $dir/../data 5 | 6 | # download 7 | if ! [ -e rcv1_train.binary ]; then 8 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_train.binary.bz2 9 | bunzip2 rcv1_train.binary.bz2 10 | fi 11 | 12 | if ! [ -e rcv1_test.binary ]; then 13 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_test.binary.bz2 14 | bunzip2 rcv1_test.binary.bz2 15 | fi 16 | 17 | 18 | name=(train test) 19 | for t in "${name[@]}" 20 | do 21 | echo $t 22 | # shuffle 23 | rnd=rcv1_${t}_rand 24 | shuf rcv1_${t}.binary >$rnd 25 | 26 | # split 27 | mkdir -p rcv1/${t} 28 | rm -f rcv1/${t}/* 29 | split -n l/8 --numeric-suffixes=1 --suffix-length=3 $rnd rcv1/${t}/part- 30 | rm $rnd 31 | done 32 | 33 | # swap train and test 34 | mv rcv1/train tmp 35 | mv rcv1/test rcv1/train 36 | mv tmp rcv1/test 37 | -------------------------------------------------------------------------------- /example/linear/rcv1/eval_batch.conf: -------------------------------------------------------------------------------- 1 | validation_data { 2 | format: TEXT 3 | text: LIBSVM 4 | file: "data/rcv1/test/part.*" 5 | } 6 | 7 | model_input { 8 | format: TEXT 9 | file: "model/rcv1_batch.*" 10 | } 11 | -------------------------------------------------------------------------------- /example/linear/rcv1/eval_online.conf: -------------------------------------------------------------------------------- 1 | validation_data { 2 | format: TEXT 3 | text: LIBSVM 4 | file: "data/rcv1/test/part.*" 5 | } 6 | 7 | model_input { 8 | format: TEXT 9 | file: "model/rcv1_online.*" 10 | } 11 | -------------------------------------------------------------------------------- /example/linear/rcv1/online_l1lr.conf: -------------------------------------------------------------------------------- 1 | training_data { 2 | format: TEXT 3 | text: LIBSVM 4 | file: "data/rcv1/train/part.*" 5 | } 6 | 7 | model_output { 8 | format: TEXT 9 | file: "model/rcv1_online" 10 | } 11 | 12 | loss { 13 | type: LOGIT 14 | } 15 | 16 | # lambda * |w|_1 17 | penalty { 18 | type: L1 19 | lambda: 1 20 | } 21 | 22 | learning_rate { 23 | type: DECAY 24 | alpha: 1 25 | beta: 1 26 | } 27 | 28 | async_sgd { 29 | algo: FTRL 30 | minibatch : 1000 31 | } 32 | -------------------------------------------------------------------------------- /make/README.md: -------------------------------------------------------------------------------- 1 | # Build the Parameter Server 2 | 3 | **Requirement** 4 | 5 | The parameter server needs a C++ compiler supporting c++11, such as `gcc` >= 6 | 4.7.2 (prefer >= 4.8) or `llvm` >= 3.4. 7 | You can update `gcc` via either downloading 8 | packages, e.g. 
[centos](http://linux.web.cern.ch/linux/devtoolset/), 9 | [ubuntu](http://ubuntuhandbook.org/index.php/2013/08/install-gcc-4-8-via-ppa-in-ubuntu-12-04-13-04/), 10 | [mac os x](http://hpc.sourceforge.net/), or building from source, such as for 11 | [centos](http://www.codersvoice.com/a/webbase/install/08/202014/131.html). 12 | 13 | **Build the Parameter Server** 14 | 15 | Assume `git` is installed: 16 | 17 | ```bash 18 | git clone https://github.com/dmlc/parameter_server -b dev 19 | cd parameter_server 20 | ./script/install_third.sh 21 | make -j8 22 | ``` 23 | 24 | **Customized Building** 25 | 26 | You can modify [config.mk](config.mk) to customize the building. You can copy 27 | this file to the upper directory so that the changes will be ignored by git. 28 | -------------------------------------------------------------------------------- /make/config.mk: -------------------------------------------------------------------------------- 1 | # default configuration of make 2 | # 3 | # you can copy it to the parent directory and modify it as you want. then 4 | # compile by `make -j 8` using 8 threads 5 | 6 | # compiler 7 | CC = g++ 8 | 9 | # optimization flag. -O0 -ggdb for debug 10 | OPT = -O3 -ggdb 11 | 12 | # statically link all dependent libraries, such as gflags, zeromq, if 13 | # 1. otherwise use dynamic linking 14 | STATIC_THIRD_LIB = 0 15 | 16 | # the installed path of third party libraries 17 | THIRD_PATH = $(shell pwd)/third_party 18 | 19 | # additional link flags, such as -ltcmalloc_and_profiler 20 | EXTRA_LDFLAGS = 21 | 22 | # additional compile flags 23 | EXTRA_CFLAGS = 24 | 25 | # io option 26 | USE_S3 = 0 27 | 28 | all: ps build/linear 29 | -------------------------------------------------------------------------------- /script/get_root_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ $# -ne 2 ]; then 3 | echo "usage: ./self interface port" 4 | exit -1; 5 | fi 6 | 7 | ip=`/sbin/ifconfig ${1} | grep inet | grep -v inet6 | awk '{print $2}' | sed -e 's/[a-z]*:/''/'` 8 | if [ -z ${ip} ]; then 9 | echo "failed to get the ip address" 10 | exit -1 11 | fi 12 | 13 | echo "role:SCHEDULER,hostname:'${ip}',port:${2},id:'H'" 14 | -------------------------------------------------------------------------------- /script/install_third.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dir=`dirname "$0"` 3 | cd $dir/.. 4 | git clone https://github.com/mli/third_party 5 | cd third_party 6 | ./install.sh 7 | -------------------------------------------------------------------------------- /script/kill_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ $# -ne 1 ]; then 3 | echo "usage: $0 node_id" 4 | exit -1; 5 | fi 6 | 7 | out=/tmp/aux 8 | ps aux >$out 9 | pid=`grep \'$1\'.*scheduler $out | awk '{print $2}'` 10 | if [ ! 
-z "$pid" ]; then 11 | kill -9 $pid 12 | fi 13 | rm -f $out 14 | -------------------------------------------------------------------------------- /script/local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # set -x 3 | dir=`dirname "$0"` 4 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$dir/../third_party/lib 5 | if [ $# -lt 3 ]; then 6 | echo "usage: $0 num_servers num_workers bin [args..]" 7 | exit -1; 8 | fi 9 | 10 | num_servers=$1 11 | shift 12 | num_workers=$1 13 | shift 14 | bin=$1 15 | shift 16 | arg="-num_servers ${num_servers} -num_workers ${num_workers} -log_dir log $@" #" -app ${dir}/$@" 17 | 18 | 19 | # killall -q $(basename ${bin}) 20 | # killall -q ${bin} 21 | HEAPCHECK=draconian 22 | # start the scheduler 23 | Sch="role:SCHEDULER,hostname:'127.0.0.1',port:8001,id:'H'" 24 | ${bin} -my_node ${Sch} -scheduler ${Sch} ${arg} & 25 | 26 | # start servers 27 | for ((i=0; i<${num_servers}; ++i)); do 28 | port=$((9600 + ${i})) 29 | N="role:SERVER,hostname:'127.0.0.1',port:${port},id:'S${i}'" 30 | # HEAPPROFILE=/tmp/S${i} \ 31 | # CPUPROFILE=/tmp/S${i} \ 32 | ${bin} -my_node ${N} -scheduler ${Sch} ${arg} & 33 | done 34 | 35 | # start workers 36 | for ((i=0; i<${num_workers}; ++i)); do 37 | port=$((9500 + ${i})) 38 | N="role:WORKER,hostname:'127.0.0.1',port:${port},id:'W${i}'" 39 | # HEAPPROFILE=/tmp/W${i} \ 40 | # CPUPROFILE=/tmp/W${i} \ 41 | ${bin} -my_node ${N} -scheduler ${Sch} ${arg} & 42 | done 43 | 44 | wait 45 | -------------------------------------------------------------------------------- /script/mpi_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # set -x 3 | if [ $# -ne 2 ]; then 4 | echo $# 5 | echo "usage: ./self scheduler_node mpi.conf" 6 | exit -1; 7 | fi 8 | 9 | # support mpich and openmpi 10 | # try mpirun -n 1 env to get all available environment 11 | if [ ! -z ${PMI_RANK} ]; then 12 | my_rank=${PMI_RANK} 13 | elif [ ! -z ${OMPI_COMM_WORLD_RANK} ]; then 14 | my_rank=${OMPI_COMM_WORLD_RANK} 15 | else 16 | echo "failed to get my rank id" 17 | exit -1 18 | fi 19 | 20 | if [ ! -z ${PMI_SIZE} ]; then 21 | rank_size=${PMI_SIZE} 22 | elif [ ! -z ${OMPI_COMM_WORLD_SIZE} ]; then 23 | rank_size=${OMPI_COMM_WORLD_SIZE} 24 | else 25 | echo "failed to get the rank size" 26 | exit -1 27 | fi 28 | 29 | source ${2} 30 | 31 | if (( ${rank_size} == 0 )); then 32 | root_node=`${dir}/get_root_node.sh ${network_interface} ${network_port}` 33 | if [${root_node} -ne ${my_node}]; then 34 | echo "start ./mpi_root.sh on the first machine in your hostfile" 35 | exit -1 36 | fi 37 | fi 38 | 39 | if (( ${rank_size} < ${num_workers} + ${num_servers} + 1 )); then 40 | echo "too small rank size ${rank_size}" 41 | exit -1 42 | fi 43 | 44 | # mkdir -p ${3}/../output 45 | # -num_threads ${num_threads} \ 46 | ${bin} \ 47 | -num_servers ${num_servers} \ 48 | -num_workers ${num_workers} \ 49 | -scheduler ${1} \ 50 | -my_rank ${my_rank} \ 51 | -interface ${network_interface} \ 52 | ${arg} 53 | 54 | exit $? 
55 | # echo "rank:${my_rank} launch failed"; exit -1; 56 | -------------------------------------------------------------------------------- /script/mpi_root.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # set -x 3 | 4 | if [ $# -ne 1 ]; then 5 | echo "usage: ./self mpi.conf" 6 | exit -1; 7 | fi 8 | 9 | conf=${1} 10 | source ${conf} 11 | 12 | # mpirun=${dir}/mpirun 13 | 14 | dir=`dirname "$0"` 15 | root_node=`${dir}/get_root_node.sh ${network_interface} ${network_port}` 16 | np=$((${num_workers} + ${num_servers} + 1)) 17 | 18 | if [ ! -z ${hostfile} ]; then 19 | hostfile="-hostfile ${hostfile}" 20 | fi 21 | 22 | # mpirun ${hostfile} killall -q ps 23 | # mpirun ${hostfile} md5sum ../bin/ps 24 | 25 | mpirun ${hostfile} -np ${np} ${dir}/mpi_node.sh ${root_node} ${conf} 26 | -------------------------------------------------------------------------------- /src/README.mk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmlc/parameter_server/2feb6f4224188e11f63b0843ea76ade4003a60f3/src/README.mk -------------------------------------------------------------------------------- /src/app/linear_method/learning_rate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "app/linear_method/proto/linear.pb.h" 3 | namespace PS { 4 | namespace LM { 5 | 6 | template 7 | class LearningRate { 8 | public: 9 | LearningRate(const LearningRateConfig& conf) : conf_(conf) { 10 | CHECK_GT(alpha(), 0); 11 | CHECK_GE(beta(), 0); 12 | } 13 | ~LearningRate() { } 14 | 15 | V eval(V x = 0) const { 16 | if (conf_.type() == LearningRateConfig::CONSTANT) { 17 | // if (x == 0 && beta() == 0) { 18 | return alpha(); 19 | } else { 20 | return alpha() / ( x + beta() ); 21 | } 22 | } 23 | 24 | V alpha() const { return conf_.alpha(); } 25 | V beta() const { return conf_.beta(); } 26 | private: 27 | // const LearningRateConfig& conf() { return conf_; } 28 | LearningRateConfig conf_; 29 | }; 30 | 31 | 32 | } // namespace LM 33 | } // namespace PS 34 | -------------------------------------------------------------------------------- /src/app/linear_method/main.cc: -------------------------------------------------------------------------------- 1 | #include "ps.h" 2 | #include "app/linear_method/async_sgd.h" 3 | #include "app/linear_method/darlin.h" 4 | #include "app/linear_method/model_evaluation.h" 5 | 6 | namespace PS { 7 | App* App::Create(const string& conf_str) { 8 | using namespace LM; 9 | // parse config 10 | Config conf; 11 | CHECK(google::protobuf::TextFormat::ParseFromString(conf_str, &conf)) 12 | << " failed to parse conf: " << conf.ShortDebugString(); 13 | 14 | // create app 15 | auto my_role = MyNode().role(); 16 | App* app = nullptr; 17 | if (conf.has_darlin()) { 18 | if (my_role == Node::SCHEDULER) { 19 | app = new DarlinScheduler(conf); 20 | } else if (my_role == Node::WORKER) { 21 | app = new DarlinWorker(conf); 22 | } else if (my_role == Node::SERVER) { 23 | app = new DarlinServer(conf); 24 | } 25 | } else if (conf.has_async_sgd()) { 26 | typedef float Real; 27 | if (my_role == Node::SCHEDULER) { 28 | app = new AsyncSGDScheduler(conf); 29 | } else if (my_role == Node::WORKER) { 30 | app = new AsyncSGDWorker(conf); 31 | } else if (my_role == Node::SERVER) { 32 | app = new AsyncSGDServer(conf); 33 | } 34 | } else if (conf.has_validation_data()) { 35 | app = new ModelEvaluation(conf); 36 | } 37 | CHECK(app) << "fail to create " << 
conf.ShortDebugString() 38 | << " at " << MyNode().ShortDebugString(); 39 | return app; 40 | } 41 | } // namespace PS 42 | 43 | int main(int argc, char *argv[]) { 44 | PS::RunSystem(argc, argv); 45 | return 0; 46 | } 47 | -------------------------------------------------------------------------------- /src/app/linear_method/model_evaluation.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "system/customer.h" 3 | #include "data/stream_reader.h" 4 | #include "util/evaluation.h" 5 | namespace PS { 6 | 7 | #if USE_S3 8 | bool s3file(const std::string& name); 9 | std::string s3Prefix(const std::string& path); 10 | std::string s3Bucket(const std::string& path); 11 | std::string s3FileUrl(const std::string& path); 12 | #endif // USE_S3 13 | 14 | 15 | namespace LM { 16 | 17 | class ModelEvaluation : public App { 18 | public: 19 | ModelEvaluation(const Config& conf) : App(), conf_(conf) { } 20 | virtual ~ModelEvaluation() { } 21 | virtual void Run(); 22 | private: 23 | typedef float Real; 24 | Config conf_; 25 | }; 26 | 27 | void ModelEvaluation::Run() { 28 | if (!IsScheduler()) return; 29 | // load model 30 | std::unordered_map weight; 31 | auto model = searchFiles(conf_.model_input()); 32 | NOTICE("find %d model files", model.file_size()); 33 | for (int i = 0; i < model.file_size(); ++i) { 34 | #if USE_S3 35 | std::ifstream in; 36 | if (s3file(model.file(i))) { 37 | // download file from s3 38 | std::string cmd="curl -s -o model_file "+s3FileUrl(model.file(i)); 39 | LOG(INFO)<> k >> v; 52 | weight[k] = v; 53 | } 54 | } 55 | #if USE_S3 56 | // remove local model after read done 57 | std::string cmd="rm -rf model_file"; 58 | system(cmd.c_str()); 59 | #endif // USE_S3 60 | 61 | NOTICE("load %lu model entries", weight.size()); 62 | 63 | // load evaluation data and compute the predicted value 64 | auto data = searchFiles(conf_.validation_data()); 65 | data.set_ignore_feature_group(true); 66 | NOTICE("find %d data files", data.file_size()); 67 | 68 | SArray label; 69 | SArray predict; 70 | MatrixPtrList mat; 71 | StreamReader reader(data); 72 | // TODO read in an another thread 73 | bool good = false; 74 | do { 75 | good = reader.readMatrices(100000, &mat); 76 | CHECK_EQ(mat.size(), 2); 77 | label.append(mat[0]->value()); 78 | 79 | SArray Xw(mat[1]->rows()); Xw.SetZero(); 80 | auto X = std::static_pointer_cast>(mat[1]); 81 | for (int i = 0; i < X->rows(); ++i) { 82 | Real re = 0; 83 | for (size_t j = X->offset()[i]; j < X->offset()[i+1]; ++j) { 84 | // TODO build a bloom filter 85 | auto it = weight.find(X->index()[j]); 86 | if (it != weight.end()) { 87 | re += it->second * (X->binary() ? 
1 : X->value()[j]); 88 | } 89 | } 90 | Xw[i] = re; 91 | } 92 | predict.append(Xw); 93 | printf("\r \r"); 94 | printf(" load %lu examples", label.size()); 95 | fflush(stdout); 96 | } while (good); 97 | printf("\n"); 98 | 99 | // label.writeToFile("label"); 100 | // predict.writeToFile("predict"); 101 | 102 | // evaluation 103 | 104 | NOTICE("auc: %f", Evaluation::auc(label, predict)); 105 | NOTICE("accuracy: %f", Evaluation::accuracy(label, predict)); 106 | NOTICE("logloss: %f", Evaluation::logloss(label, predict)); 107 | } 108 | 109 | } // namespace LM 110 | } // namespace PS 111 | -------------------------------------------------------------------------------- /src/app/linear_method/penalty.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "util/matrix.h" 4 | #include "app/linear_method/proto/linear.pb.h" 5 | namespace PS { 6 | namespace LM { 7 | 8 | /** 9 | * @brief Interface for the penalty 10 | */ 11 | template class Penalty { 12 | public: 13 | Penalty() { } 14 | virtual ~Penalty() { } 15 | /** 16 | * @brief evaluate the objective 17 | * 18 | * @param model 19 | * 20 | * @return objective value 21 | */ 22 | virtual T eval(const MatrixPtr& model) = 0; 23 | 24 | /** 25 | * @brief Solve the proximal operator 26 | * 27 | * \f$ \argmin_x 0.5/\eta (x - z)^2 + h(x)\f$, where h denote this penatly, and in 28 | * proximal gradient descent, z = w - eta * grad 29 | * 30 | * @param z 31 | * @param eta 32 | * @return 33 | */ 34 | virtual T proximal(T z, T eta) = 0; 35 | }; 36 | 37 | /** 38 | * @brief \f$ \lambda_1 * \|x\|_1 + \lambda_2 * \|x\|_2^2 \f$ 39 | */ 40 | template 41 | class ElasticNet : public Penalty { 42 | public: 43 | ElasticNet(T lambda1, T lambda2) : lambda1_(lambda1), lambda2_(lambda2) { 44 | CHECK_GE(lambda1, 0); 45 | CHECK_GE(lambda2, 0); 46 | } 47 | ~ElasticNet() { } 48 | 49 | T eval(const MatrixPtr& model) { return 0; } // TODO 50 | 51 | T proximal(T z, T eta) { 52 | CHECK_GT(eta, 0); 53 | T leta = lambda1_ * eta; 54 | if (z <= leta && z >= -leta) return 0; 55 | return z > 0 ? (z - leta) / ( 1 + lambda2_ * eta) : (z + leta) / ( 1 + lambda2_ * eta); 56 | } 57 | private: 58 | T lambda1_, lambda2_; 59 | }; 60 | 61 | 62 | // template 63 | // class L2 : public Penalty { 64 | // public: 65 | // L2(T lambda) : lambda_(lambda) { } 66 | // CHECK_GE(lambda, 0); 67 | // } 68 | // ~L2() { } 69 | // private: 70 | // T evaluate(const MatrixPtr& model) { return 0; } // TODO 71 | // T proximal(T z, T eta) { 72 | // } 73 | // T lambda_; 74 | // }; 75 | 76 | template 77 | Penalty* createPenalty(const PenaltyConfig& conf) { 78 | CHECK_GE(conf.lambda_size(), 1); 79 | switch (conf.type()) { 80 | case PenaltyConfig::L1: { 81 | T l1 = conf.lambda(0); 82 | T l2 = conf.lambda_size() > 1 ? 
conf.lambda(1) : 0; 83 | return new ElasticNet(l1, l2); 84 | } 85 | case PenaltyConfig::L2: 86 | return new ElasticNet(0, conf.lambda(0)); 87 | default: 88 | CHECK(false) << "unknown type: " << conf.DebugString(); 89 | } 90 | return nullptr; 91 | } 92 | 93 | } // namespace LM 94 | } // namespace PS 95 | 96 | // // lambda * ||w||_p^P = lambda * \sum_i w_i^p 97 | // // TODO infinity 98 | // template 99 | // class PNormPenalty : public Penalty { 100 | // public: 101 | // PNormPenalty(T p, T lambda) : p_(p), lambda_(lambda) { 102 | // CHECK_GE(p_, 0); 103 | // CHECK_GE(lambda_, 0); 104 | // } 105 | // bool smooth() { return p_ > 1; } 106 | 107 | // T evaluate(const MatrixPtr& model) { 108 | // auto w = model->value().EigenArray(); 109 | // return lambda_ * pow(w.abs(), p_).sum(); 110 | // } 111 | 112 | // T lambda() { return lambda_; } 113 | // T p() { return p_; } 114 | // private: 115 | // T p_; 116 | // T lambda_; 117 | // }; 118 | -------------------------------------------------------------------------------- /src/app/linear_method/proto/linear.proto: -------------------------------------------------------------------------------- 1 | // configuration of linear methods 2 | package PS.LM; 3 | import "data/proto/data.proto"; 4 | import "learner/proto/bcd.proto"; 5 | import "filter/proto/filter.proto"; 6 | 7 | message Config { 8 | optional DataConfig training_data = 1; 9 | optional DataConfig validation_data = 2; 10 | 11 | optional DataConfig model_output = 4; 12 | optional DataConfig model_input = 5; 13 | 14 | optional LossConfig loss = 10; 15 | optional PenaltyConfig penalty = 11; 16 | 17 | optional LearningRateConfig learning_rate = 12; 18 | 19 | optional SGDConfig async_sgd = 17; 20 | optional BCDConfig darlin = 15; 21 | 22 | } 23 | 24 | extend BCDConfig { 25 | // Used by the trust region method. All changes of parameters will be bounded 26 | // by *delta*. *delta* is updated according to the convergence, whose intial 27 | // value is *delta_init_value* and maximal value is *delta_max_value* 28 | optional double delta_init_value = 101 [default = 1]; 29 | optional double delta_max_value = 102 [default = 5]; 30 | // kkt_filter_threshold = max_gradient_violation / num_examples * 31 | // kkt_filter_threshold_ratio. increasing this number reduces the effect of 32 | // kkt filter. 33 | optional double kkt_filter_threshold_ratio = 103 [default = 10]; 34 | } 35 | 36 | message SGDConfig { 37 | enum Algo { 38 | STANDARD = 1; 39 | FTRL = 2; 40 | } 41 | required Algo algo = 1; 42 | 43 | // The size of minibatch 44 | optional int32 minibatch = 2 [default = 1000]; 45 | 46 | optional int32 data_buf = 12 [default = 1000]; // in mb 47 | 48 | optional bool ada_grad = 5 [default = true]; 49 | 50 | optional int32 max_delay = 4 [default = 0]; 51 | 52 | // The number of data passes 53 | optional int32 num_data_pass = 11 [default = 1]; 54 | 55 | // in sec 56 | optional int32 report_interval = 3 [default = 1]; 57 | 58 | // features which occurs <= *tail_feature_freq* will be filtered before 59 | // training. it save both memory and bandwidth. 60 | optional int32 tail_feature_freq = 6 [default = 0]; 61 | 62 | // It controls the countmin size. We filter the tail features by countmin, which 63 | // is more efficient than hash, but still is the memory bottleneck for servers. A 64 | // smaller ratio reduces the memory footprint, but may increase the size of 65 | // filtered feature. 
66 | optional float countmin_n = 8 [default = 1e8]; 67 | optional int32 countmin_k = 7 [default = 2]; 68 | 69 | repeated FilterConfig push_filter = 13; 70 | repeated FilterConfig pull_filter = 14; 71 | } 72 | 73 | message LossConfig { 74 | enum Type { 75 | SQUARE = 1; 76 | LOGIT = 2; 77 | HINGE = 3; 78 | SQUARE_HINGE = 4; 79 | } 80 | required Type type = 1; 81 | } 82 | 83 | message PenaltyConfig { 84 | enum Type { 85 | L1 = 1; // lambda(0) * ||w||_1 + lambda(1) * ||w||_F^2 86 | L2 = 2; // lambda(0) * ||w||_F^2 87 | } 88 | required Type type = 1; 89 | repeated double lambda = 2; 90 | } 91 | 92 | 93 | message LearningRateConfig { 94 | enum Type { 95 | CONSTANT = 1; // = alpha 96 | DECAY = 2; // = alpha / (beta + x), where x is user-defined, such as sqrt(iter) 97 | } 98 | optional Type type = 1; 99 | optional double alpha = 2; 100 | optional double beta = 3; 101 | } 102 | -------------------------------------------------------------------------------- /src/data/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "data/proto/example.pb.h" 4 | #include "data/proto/data.pb.h" 5 | #include "util/proto/matrix.pb.h" 6 | 7 | namespace PS { 8 | 9 | DECLARE_string(input); 10 | DECLARE_string(output); 11 | DECLARE_string(format); 12 | 13 | // return files matches the regex in *config* 14 | DataConfig searchFiles(const DataConfig& config); 15 | 16 | // evenly parttion the files into *num* parts 17 | std::vector divideFiles(const DataConfig& data, int num); 18 | 19 | // locate the i-th file in *conf*, append it with suffix, and keep the rest metadata 20 | DataConfig ithFile(const DataConfig& conf, int i, const string& suffix = ""); 21 | 22 | // return A + B 23 | DataConfig appendFiles(const DataConfig& A, const DataConfig& B); 24 | 25 | ExampleInfo mergeExampleInfo(const ExampleInfo& A, const ExampleInfo& B); 26 | 27 | MatrixInfo readMatrixInfo( 28 | const ExampleInfo& info, int slot_id, int sizeof_idx, int sizeof_val); 29 | 30 | 31 | DataConfig shuffleFiles(const DataConfig& data); 32 | 33 | 34 | } // namespace PS 35 | -------------------------------------------------------------------------------- /src/data/info_parser.cc: -------------------------------------------------------------------------------- 1 | #include "data/info_parser.h" 2 | namespace PS { 3 | 4 | void InfoParser::clear() { 5 | for (int i = 0; i < kSlotIDmax; ++i) slot_info_[i].Clear(); 6 | info_.Clear(); 7 | num_ex_ = 0; 8 | } 9 | 10 | bool InfoParser::add(const Example& ex) { 11 | for (int i = 0; i < ex.slot_size(); ++i) { 12 | const auto& slot = ex.slot(i); 13 | if (slot.id() >= kSlotIDmax) return false; 14 | auto& sinfo = slot_info_[slot.id()]; 15 | for (int j = 0; j < slot.key_size(); ++j) { 16 | uint64 key = slot.key(j); 17 | sinfo.set_min_key(std::min((uint64)sinfo.min_key(), key)); 18 | sinfo.set_max_key(std::max((uint64)sinfo.max_key(), key + 1)); 19 | } 20 | if (slot.key_size() > 0) { 21 | if (slot.val_size() == slot.key_size()) { 22 | sinfo.set_format(SlotInfo::SPARSE); 23 | } else { 24 | sinfo.set_format(SlotInfo::SPARSE_BINARY); 25 | } 26 | } else if (slot.val_size() > 0) { 27 | sinfo.set_format(SlotInfo::DENSE); 28 | } 29 | sinfo.set_nnz_ex(sinfo.nnz_ex() + 1); 30 | sinfo.set_nnz_ele(sinfo.nnz_ele() + std::max(slot.key_size(), slot.val_size())); 31 | } 32 | ++ num_ex_; 33 | return true; 34 | } 35 | 36 | ExampleInfo InfoParser::info() { 37 | info_.set_num_ex(num_ex_); 38 | info_.clear_slot(); 39 | for (int i = 0; i < kSlotIDmax; 
++i) { 40 | auto &sinfo = slot_info_[i]; 41 | if (!sinfo.nnz_ele()) continue; 42 | sinfo.set_id(i); 43 | if (i == 0) { // the label 44 | sinfo.set_min_key(0); 45 | sinfo.set_max_key(1); 46 | } 47 | *info_.add_slot() = sinfo; 48 | } 49 | return info_; 50 | } 51 | } // namespace PS 52 | -------------------------------------------------------------------------------- /src/data/info_parser.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "data/proto/example.pb.h" 4 | #include "data/proto/data.pb.h" 5 | namespace PS { 6 | 7 | // the maximal allowed slot id 8 | static const int kSlotIDmax = 4096; 9 | 10 | class InfoParser { 11 | public: 12 | // void init(const DataConfig& conf) { conf_ = conf; } 13 | bool add(const Example& ex); 14 | void clear(); 15 | ExampleInfo info(); 16 | // int maxSlotID() { return conf_.ignore_fea_slot() ? 2 : kSlotIDmax; } 17 | private: 18 | // DataConfig conf_; 19 | size_t num_ex_ = 0; 20 | ExampleInfo info_; 21 | SlotInfo slot_info_[kSlotIDmax]; 22 | }; 23 | 24 | } // namespace PS 25 | -------------------------------------------------------------------------------- /src/data/matlab/bin2mat.m: -------------------------------------------------------------------------------- 1 | function [D, group_id, group_offset] = bin2mat(name) 2 | % load a binary format data, which can be generated by ./bin/recordio2bin 3 | function result = grep(name, pattern) 4 | [ret, result] = system(... 5 | ['grep ', pattern, ' ', name, '.info | awk ''{print $2}''']); 6 | assert(ret == 0); 7 | result(end) = []; 8 | end 9 | 10 | type = grep(name, 'type'); 11 | n = str2num(grep(name, 'end')); 12 | row_major = all(grep(name, 'row_major') == 'true'); 13 | sizeof_v = str2num(grep(name, 'sizeof_value')); 14 | 15 | if strfind(type, 'DENSE') 16 | assert(row_major); 17 | assert(sizeof_v == 8); 18 | D = load_bin([name, '.value'], 'double'); 19 | D = reshape(D, n(2), n(1))'; 20 | elseif strfind(type, 'SPARSE') 21 | 22 | if exist([name, '.key'], 'file') 23 | group_id = str2num(grep(name, 'group_id')); 24 | begin = sscanf(grep(name, 'fea_begin'), '%lu'); 25 | [begin, ix] = sort(begin); 26 | group_id = group_id(ix); 27 | % begin = str2num(; %, 'uint64'); 28 | % begin = uint64(begin)' 29 | % return 30 | % e = str2num(grep(name, 'feature_end'), 'uint64'); 31 | key = load_bin([name, '.key'], 'uint64'); 32 | group_offset = zeros(length(begin)+1, 1); 33 | group_offset(end) = length(key) + 1; 34 | for i = 1 : length(begin) 35 | group_offset(i) = find(key == begin(i), 1, 'first'); 36 | end 37 | clear key 38 | % group_offset 39 | end 40 | 41 | sizeof_i = str2num(grep(name, 'sizeof_index')); 42 | assert(sizeof_i == 4); 43 | 44 | j = load_bin([name '.index'], 'uint32'); 45 | J = double(j) + 1; 46 | clear j 47 | 48 | i = load_bin([name '.offset'], 'uint64'); 49 | I = zeros(length(J), 1); 50 | for k = 1 : length(i) - 1 51 | I(i(k)+1 : i(k+1)) = k; 52 | end 53 | clear i 54 | 55 | if strfind(type, 'SPARSE_BINARY') 56 | D = sparse(I, J, 1); 57 | else 58 | V = load_bin([name '.value'], 'double'); 59 | D = sparse(I, J, V); 60 | end 61 | 62 | end 63 | 64 | end 65 | -------------------------------------------------------------------------------- /src/data/matlab/filter_fea.m: -------------------------------------------------------------------------------- 1 | % Y = bin2mat('CTRb.Y'); 2 | % [X, group_id, group_os] = bin2mat('CTRb.X'); 3 | 4 | pv = 4; 5 | 6 | ix = sum(X) > pv; 7 | X = X(:,ix); 8 | 9 | os = group_os; 10 | for i = 2 : 
length(os) 11 | os(i) = os(i-1) + nnz(ix(group_os(i-1):group_os(i)-1)); 12 | end 13 | group_os = os; 14 | 15 | 16 | % save CTRb_pv4 X Y group_id group_os 17 | -------------------------------------------------------------------------------- /src/data/matlab/load_bin.m: -------------------------------------------------------------------------------- 1 | function data = load_bin(filename, format, offset, length) 2 | %load a vector from a binary file 3 | %load_bin(filename, format, offset, length) 4 | 5 | if nargin < 2, format = 'double'; end 6 | if nargin < 3, offset = 0; end 7 | if nargin < 4, length = inf; end 8 | 9 | fid=fopen(filename,'r'); 10 | if (fid < 0) 11 | data = []; 12 | disp([filename ' doesn''t exist']); 13 | return; 14 | end 15 | 16 | eval(['v=', format, '(1);']); 17 | w = whos('v'); 18 | bsize = w.bytes; 19 | 20 | fseek(fid, bsize*offset, -1); 21 | data = fread(fid, length, [format,'=>',format]); 22 | fclose(fid); 23 | 24 | end 25 | -------------------------------------------------------------------------------- /src/data/matlab/mat2bin.m: -------------------------------------------------------------------------------- 1 | function mat2bin(name, Y, X) 2 | % new version 3 | % save to row-majored binary format 4 | % if nargin < 4, index_fmt = 'uint32'; end 5 | 6 | [a,b,x] = find(X'); 7 | 8 | save_bin([name '.index'], uint32(a-1), 'uint32'); 9 | 10 | bool = 'false'; 11 | if nnz(x==0) + nnz(x==1) == length(x) 12 | disp('binary data'); 13 | bool = 'true'; 14 | save_bin([name '.offset'], uint64(cumsum([0; full(sum(X,2))])), 'uint64'); 15 | else 16 | save_bin([name '.value'], x, 'double'); 17 | clear x 18 | E = sparse(a,b,1); 19 | save_bin([name '.offset'], uint64(cumsum([0 full(sum(E))])), 'uint64'); 20 | end 21 | 22 | save_bin([name '.label'], Y, 'double'); 23 | 24 | fid = fopen([name, '.info'], 'w'); 25 | 26 | 27 | fprintf(fid, strcat('row_major : true\n',... 28 | 'sparse : true\n',... 29 | 'bool_value : %s\n',... 30 | 'row_begin : 0\n',... 31 | 'row_end : %lu\n',... 32 | 'col_begin : 0\n',... 33 | 'col_end : %lu\n',... 34 | 'entries : %lu\n',... 35 | 'sizeof_index : 4\n',... 
36 | 'sizeof_value : 8\n'), bool, size(X,1), size(X,2), nnz(X)); 37 | fclose(fid); 38 | % fprintf(fid, '0 %lu\n0 %lu\n%lu\n', size(X,1), size(X,2), nnz(X)); 39 | 40 | end 41 | -------------------------------------------------------------------------------- /src/data/matlab/save_bin.m: -------------------------------------------------------------------------------- 1 | function save_bin(name, X, format) 2 | %save a vector into a binary file 3 | %save_bin(name, X, format) 4 | 5 | if nargin < 3 6 | format = 'double'; 7 | end 8 | 9 | fid = fopen(name, 'w'); 10 | n = fwrite(fid, X, format); 11 | fclose(fid); 12 | 13 | if n ~= numel(X) 14 | disp(sprintf('only write %d element, the actual one is %d', n, numel(X))); 15 | end 16 | -------------------------------------------------------------------------------- /src/data/matlab/saveas_pserver.m: -------------------------------------------------------------------------------- 1 | function saveas_pserver(file_name, Y, X, group_id, binary) 2 | % debug use only, not efficient for large data 3 | 4 | if nargin < 4, group_id = zeros(size(X,2),1); end 5 | if nargin < 5, binary = false; end 6 | 7 | assert(issorted(group_id)); 8 | 9 | tX = X'; 10 | fd = fopen(file_name, 'w'); 11 | for i = 1 : length(Y) 12 | fprintf(fd, '%d', Y(i)); 13 | [a,b,c] = find(tX(:,i)); 14 | pre_gid = -1; 15 | for j = 1 : length(a) 16 | id = group_id(a(j)); 17 | if id ~= pre_gid 18 | fprintf(fd, '; %d', id); 19 | pre_gid = id; 20 | end 21 | 22 | if binary 23 | fprintf(fd, ' %d', a(j)-1); 24 | else 25 | fprintf(fd, ' %d:%g', a(j)-1, c(j)); 26 | end 27 | end 28 | fprintf(fd, ';\n'); 29 | end 30 | fclose(fd); 31 | -------------------------------------------------------------------------------- /src/data/proto/data.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "util/proto/range.proto"; 3 | 4 | message DataConfig { 5 | enum DataFormat { 6 | BIN = 1; 7 | PROTO = 2; 8 | TEXT = 3; 9 | } 10 | required DataFormat format = 1; 11 | 12 | // see https://github.com/mli/parameter_server/wiki/Data 13 | enum TextFormat { 14 | DENSE = 1; 15 | SPARSE = 2; 16 | SPARSE_BINARY = 3; 17 | ADFEA = 4; 18 | LIBSVM = 5; 19 | TERAFEA = 6; 20 | VW = 7; 21 | CRITEO = 9; 22 | } 23 | optional TextFormat text = 2; 24 | 25 | // filenames, supports regular expressions 26 | repeated string file = 3; 27 | // files stored in hdfs 28 | optional HDFSConfig hdfs = 5; 29 | // ignore the feature group information 30 | optional bool ignore_feature_group = 6; 31 | // the maximal number of files will be assigned to a worker, -1 means no limit 32 | optional int32 max_num_files_per_worker = 7 [default = -1]; 33 | 34 | // the maximal number of lines will be read from a file, -1 means no limit 35 | optional int32 max_num_lines_per_file = 8 [default = -1]; 36 | 37 | // randomly shuffle the file order 38 | optional bool shuffle = 9 [default = false]; 39 | 40 | // only valid for the binary format 41 | optional PbRange range = 4; 42 | // duplicate the file several times 43 | optional int32 replica = 10 [default = 1]; 44 | } 45 | 46 | message HDFSConfig { 47 | optional string home = 1; // HADOOP_HOME 48 | optional string ugi = 2; // hadoop.job.ugi, format: user,passwd 49 | optional string namenode = 4; // fs.default.name 50 | } 51 | -------------------------------------------------------------------------------- /src/data/proto/example.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | message SlotInfo { 3 | enum 
Format { 4 | DENSE = 1; 5 | SPARSE = 2; 6 | SPARSE_BINARY = 3; 7 | } 8 | optional Format format = 1; 9 | optional int32 id = 2; 10 | optional uint64 min_key = 3 [default = 0xFFFFFFFFFFFFFFFF]; 11 | optional uint64 max_key = 4; 12 | // total number of non-zero elements 13 | optional uint64 nnz_ele = 5; 14 | // total number of non-zero instance, (non-empty rows) 15 | optional uint64 nnz_ex = 6; 16 | } 17 | 18 | message ExampleInfo { 19 | repeated SlotInfo slot = 1; 20 | // total number of instances 21 | optional uint64 num_ex = 2; 22 | 23 | } 24 | 25 | message Slot { 26 | optional int32 id = 1; 27 | repeated uint64 key = 2 [packed=true]; 28 | repeated float val = 3 [packed=true]; 29 | } 30 | 31 | message Example { 32 | repeated Slot slot = 1; 33 | } 34 | -------------------------------------------------------------------------------- /src/data/show_example.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "data/common.h" 4 | #include "util/recordio.h" 5 | #include "proto/example.pb.h" 6 | namespace PS { 7 | 8 | DEFINE_int32(n, 3, "show the first *n* instances in text format"); 9 | 10 | static void showExample() { 11 | File* in = File::openOrDie(FLAGS_input, "r"); 12 | RecordReader reader(in); 13 | for (int i = 0; i < FLAGS_n; ++i) { 14 | Example ex; 15 | CHECK(reader.ReadProtocolMessage(&ex)); 16 | std::cout << ex.ShortDebugString() << std::endl; 17 | } 18 | } 19 | 20 | } // namespace PS 21 | -------------------------------------------------------------------------------- /src/data/slot_reader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/shared_array_inl.h" 3 | #include "proto/example.pb.h" 4 | #include "data/common.h" 5 | namespace PS { 6 | 7 | // read all slots in *data* with multithreadd, save them into *cache*. 
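// A minimal usage sketch (hypothetical; assumes DataConfig messages *data* and
// *cache* built elsewhere, names are illustrative only):
//   SlotReader reader(data, cache);
//   ExampleInfo info;
//   reader.Read(&info);              // parse every file once and fill the cache
//   auto offset = reader.offset(1);  // then load slot 1 back from the cache
//   auto index  = reader.index(1);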
8 | class SlotReader { 9 | public: 10 | SlotReader() { } 11 | SlotReader(const DataConfig& data, const DataConfig& cache) { 12 | Init(data, cache); 13 | } 14 | 15 | void Init(const DataConfig& data, const DataConfig& cache); 16 | 17 | // first read, then save 18 | int Read(ExampleInfo* info = nullptr); 19 | 20 | template MatrixInfo info(int slot_id) const { 21 | return readMatrixInfo(info_, slot_id, sizeof(uint64), sizeof(V)); 22 | } 23 | 24 | // load a slot from cache 25 | SArray offset(int slot_id); 26 | SArray index(int slot_id); 27 | template SArray value(int slot_id) const; 28 | 29 | void clear(int slot_id) { 30 | offset_cache_.erase(slot_id); 31 | index_cache_.erase(slot_id); 32 | } 33 | 34 | private: 35 | string cacheName(const DataConfig& data, int slot_id) const; 36 | size_t nnzEle(int slot_id) const; 37 | bool readOneFile(const DataConfig& data, int ith_file); 38 | string cache_; 39 | DataConfig data_; 40 | // bool dump_to_disk_; 41 | ExampleInfo info_; 42 | std::unordered_map slot_info_; 43 | std::mutex mu_; 44 | size_t loaded_file_count_; 45 | std::vector num_ex_; 46 | std::unordered_map> offset_cache_; 47 | std::unordered_map> index_cache_; 48 | }; 49 | 50 | template SArray SlotReader::value(int slot_id) const { 51 | SArray val; 52 | if (nnzEle(slot_id) == 0) return val; 53 | for (int i = 0; i < data_.file_size(); ++i) { 54 | string file = cacheName(ithFile(data_, i), slot_id) + ".value"; 55 | SArray comp; CHECK(comp.ReadFromFile(file)); 56 | SArray uncomp; uncomp.UncompressFrom(comp); 57 | size_t n = val.size(); 58 | val.resize(n+uncomp.size()); 59 | for (size_t i = 0; i < uncomp.size(); ++i) val[n+i] = uncomp[i]; 60 | } 61 | CHECK_EQ(val.size(), nnzEle(slot_id)) << slot_id; 62 | return val; 63 | } 64 | 65 | } // namespace PS 66 | -------------------------------------------------------------------------------- /src/data/text2proto.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | namespace PS { 3 | 4 | void textToProto() { 5 | // TODO 6 | // string format = FLAGS_format; 7 | // std::transform(format.begin(), format.end(), format.begin(), ::tolower); 8 | 9 | // TextParser parser; 10 | // if (format == "libsvm") { 11 | // parser.setFormat(DataConfig::LIBSVM); 12 | // } else if (format == "adfea") { 13 | // parser.setFormat(DataConfig::ADFEA); 14 | // } 15 | 16 | 17 | // auto record_file = File::openOrDie(FLAGS_output+".recordio", "w"); 18 | // RecordWriter writer(record_file); 19 | 20 | // Instance ins; 21 | // int ignored = 0; 22 | // FileLineReader reader(FLAGS_input.c_str()); 23 | // reader.set_line_callback([&parser, &ins, &writer, &ignored] (char *line) { 24 | // ignored += !parser.toProtobuf(line, &ins); 25 | // writer.WriteProtocolMessage(ins); 26 | // }); 27 | 28 | // Timer t; t.start(); 29 | // reader.Reload(); 30 | // auto info = parser.info(); 31 | // writeProtoToASCIIFileOrDie(info, FLAGS_output+".info"); 32 | // t.stop(); 33 | 34 | // std::cerr << "written " << info.num_ins() 35 | // << " instances in " << t.get() << " sec." 
<< std::endl; 36 | // if (ignored) { 37 | // std::cerr << ignored << " bad instances are skipped" << std::endl; 38 | // } 39 | } 40 | 41 | } // namespace PS 42 | -------------------------------------------------------------------------------- /src/data/text_parser.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "data/proto/example.pb.h" 4 | #include "data/proto/data.pb.h" 5 | namespace PS { 6 | 7 | // parse an example from various text formats into the protobuf format, 8 | // e.g. proto/example.proto 9 | class ExampleParser { 10 | public: 11 | typedef DataConfig::TextFormat TextFormat; 12 | void Init(TextFormat format, bool ignore_fea_slot = false); 13 | bool ToProto(char*, Example*); 14 | private: 15 | bool ParseLibsvm(char*, Example*); 16 | bool ParseAdfea(char*, Example*); 17 | bool ParseTerafea(char*, Example*); 18 | bool ParsePS(char*, Example*, TextFormat); 19 | bool ParseCriteo(char*, Example*); 20 | std::function parser_; 21 | bool ignore_fea_slot_; 22 | }; 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/filter/add_noise.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "filter/filter.h" 3 | namespace PS { 4 | 5 | /** 6 | * @brief Add noise 7 | * 8 | */ 9 | class AddNoiseFilter : public Filter { 10 | public: 11 | void encode(Message* msg) { 12 | auto filter_conf = CHECK_NOTNULL(find(FilterConfig::NOISE, msg)); 13 | int n = msg->value.size(); 14 | CHECK_EQ(n, msg->task.value_type_size()); 15 | for (int i = 0; i < n; ++i) { 16 | if (msg->value[i].size() == 0) continue; 17 | auto type = msg->task.value_type(i); 18 | if (type == DataType::FLOAT) { 19 | AddNoise(msg->value[i], filter_conf); 20 | } 21 | if (type == DataType::DOUBLE) { 22 | AddNoise(msg->value[i], filter_conf); 23 | } 24 | } 25 | } 26 | 27 | private: 28 | 29 | template 30 | void AddNoise(const SArray& array, FilterConfig* cf) { 31 | std::default_random_engine generator; 32 | std::normal_distribution distribution((V)cf->mean(), (V)cf->std()); 33 | SArray data(array); 34 | // SArray noise(data.size()); 35 | for (size_t i = 0; i < data.size(); ++i) { 36 | data[i] += distribution(generator); 37 | } 38 | // LL << noise.Std() << " " << noise; 39 | } 40 | 41 | }; 42 | 43 | } // namespace PS 44 | -------------------------------------------------------------------------------- /src/filter/compressing.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "filter/filter.h" 3 | 4 | namespace PS { 5 | 6 | class CompressingFilter : public Filter { 7 | public: 8 | void encode(Message* msg) { 9 | auto conf = find(FilterConfig::COMPRESSING, msg); 10 | if (!conf) return; 11 | conf->clear_uncompressed_size(); 12 | if (msg->has_key()) { 13 | conf->add_uncompressed_size(msg->key.size()); 14 | msg->key = msg->key.CompressTo(); 15 | } 16 | for (auto& v : msg->value) { 17 | conf->add_uncompressed_size(v.size()); 18 | v = v.CompressTo(); 19 | } 20 | } 21 | void decode(Message* msg) { 22 | auto conf = find(FilterConfig::COMPRESSING, msg); 23 | if (!conf) return; 24 | int has_key = msg->has_key(); 25 | CHECK_EQ(conf->uncompressed_size_size(), msg->value.size() + has_key); 26 | 27 | if (has_key) { 28 | SArray raw(conf->uncompressed_size(0)); 29 | raw.UncompressFrom(msg->key); 30 | msg->key = raw; 31 | } 32 | for (int i = 0; i < msg->value.size(); ++i) { 33 | SArray 
raw(conf->uncompressed_size(i+has_key)); 34 | raw.UncompressFrom(msg->value[i]); 35 | msg->value[i] = raw; 36 | } 37 | } 38 | }; 39 | 40 | } // namespace PS 41 | -------------------------------------------------------------------------------- /src/filter/filter.cc: -------------------------------------------------------------------------------- 1 | #include "filter/filter.h" 2 | #include "filter/compressing.h" 3 | #include "filter/key_caching.h" 4 | #include "filter/fixing_float.h" 5 | #include "filter/add_noise.h" 6 | 7 | namespace PS { 8 | 9 | Filter* Filter::create(const FilterConfig& conf) { 10 | switch (conf.type()) { 11 | case FilterConfig::KEY_CACHING: 12 | return new KeyCachingFilter(); 13 | case FilterConfig::COMPRESSING: 14 | return new CompressingFilter(); 15 | case FilterConfig::FIXING_FLOAT: 16 | return new FixingFloatFilter(); 17 | case FilterConfig::NOISE: 18 | return new AddNoiseFilter(); 19 | default: 20 | CHECK(false) << "unknow filter type"; 21 | } 22 | return nullptr; 23 | } 24 | 25 | 26 | FilterConfig* Filter::find(FilterConfig::Type type, Task* task) { 27 | for (int i = 0; i < task->filter_size(); ++i) { 28 | if (task->filter(i).type() == type) return task->mutable_filter(i); 29 | } 30 | return nullptr; 31 | } 32 | 33 | } // namespace PS 34 | -------------------------------------------------------------------------------- /src/filter/filter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "system/message.h" 3 | #include "filter/proto/filter.pb.h" 4 | #include "util/shared_array_inl.h" 5 | 6 | namespace PS { 7 | 8 | // A filter should be thread safe 9 | class Filter { 10 | public: 11 | Filter() { } 12 | virtual ~Filter() { } 13 | 14 | static Filter* create(const FilterConfig& conf); 15 | 16 | 17 | virtual void encode(Message* msg) { } 18 | virtual void decode(Message* msg) { } 19 | 20 | static FilterConfig* find(FilterConfig::Type type, Message* msg) { 21 | return find(type, &(msg->task)); 22 | } 23 | static FilterConfig* find(FilterConfig::Type type, Task* task); 24 | }; 25 | 26 | } // namespace 27 | -------------------------------------------------------------------------------- /src/filter/fixing_float.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "filter/filter.h" 3 | #include 4 | namespace PS { 5 | 6 | class FixingFloatFilter : public Filter { 7 | public: 8 | void encode(Message* msg) { 9 | convert(msg, true); 10 | } 11 | 12 | void decode(Message* msg) { 13 | convert(msg, false); 14 | } 15 | 16 | private: 17 | // a fast random function 18 | static bool boolrand(int* seed) { 19 | *seed = (214013 * *seed + 2531011); 20 | return ((*seed >> 16) & 0x1) == 0; 21 | } 22 | 23 | // decode / encode a message 24 | void convert(Message* msg, bool encode) { 25 | auto filter_conf = CHECK_NOTNULL(find(FilterConfig::FIXING_FLOAT, msg)); 26 | if (filter_conf->num_bytes() == 0) return; 27 | int n = msg->value.size(); 28 | CHECK_EQ(n, msg->task.value_type_size()); 29 | int k = 0; 30 | for (int i = 0; i < n; ++i) { 31 | if (msg->value[i].size() == 0) continue; 32 | auto type = msg->task.value_type(i); 33 | if (filter_conf->fixed_point_size() <= k) { 34 | filter_conf->add_fixed_point(); 35 | } 36 | if (type == DataType::FLOAT) { 37 | msg->value[i] = convert( 38 | msg->value[i], encode, filter_conf->num_bytes(), 39 | filter_conf->mutable_fixed_point(k++)); 40 | } 41 | if (type == DataType::DOUBLE) { 42 | msg->value[i] = convert( 43 | msg->value[i], 
encode, filter_conf->num_bytes(), 44 | filter_conf->mutable_fixed_point(k++)); 45 | } 46 | } 47 | } 48 | 49 | // decode / encode an array 50 | template 51 | SArray convert(const SArray& array, bool encode, int nbytes, 52 | FilterConfig::FixedFloatConfig* conf) { 53 | CHECK_GT(nbytes, 0); 54 | CHECK_LT(nbytes, 8); 55 | double ratio = static_cast(1 << (nbytes*8)) - 2; 56 | 57 | if (encode) { 58 | if (!conf->has_min_value()) { 59 | conf->set_min_value(SArray(array).EigenArray().minCoeff()); 60 | } 61 | if (!conf->has_max_value()) { 62 | conf->set_max_value(SArray(array).EigenArray().maxCoeff() + 1e-6); // to avoid max_v == min_v 63 | } 64 | } 65 | 66 | CHECK(conf->has_min_value()); 67 | double min_v = static_cast(conf->min_value()); 68 | CHECK(conf->has_max_value()); 69 | double max_v = static_cast(conf->max_value()); 70 | double bin = max_v - min_v; 71 | CHECK_GT(bin, 0); 72 | 73 | if (encode) { 74 | // float/double to nbytes*8 int 75 | SArray orig(array); 76 | SArray code(orig.size() * nbytes); 77 | uint8* code_ptr = code.data(); 78 | int seed = time(NULL); 79 | for (int i = 0; i < orig.size(); ++i) { 80 | double proj = orig[i] > max_v ? max_v : orig[i] < min_v ? min_v : orig[i]; 81 | double tmp = (proj - min_v) / bin * ratio; 82 | uint64 r = static_cast(floor(tmp)) + boolrand(&seed); 83 | for (int j = 0; j < nbytes; ++j) { 84 | *(code_ptr++) = static_cast(r & 0xFF); 85 | r = r >> 8; 86 | } 87 | } 88 | return SArray(code); 89 | } else { 90 | // nbytes*8 int to float/double 91 | uint8* code_ptr = SArray(array).data(); 92 | SArray orig(array.size() / nbytes); 93 | for (int i = 0; i < orig.size(); ++i) { 94 | double r = 0; 95 | for (int j = 0; j < nbytes; ++j) { 96 | r += static_cast(*(code_ptr++)) << 8 * j; 97 | } 98 | orig[i] = static_cast(r / ratio * bin + min_v); 99 | } 100 | return SArray(orig); 101 | } 102 | } 103 | }; 104 | 105 | } // namespace PS 106 | -------------------------------------------------------------------------------- /src/filter/frequency_filter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/countmin.h" 3 | #include "util/shared_array_inl.h" 4 | namespace PS { 5 | 6 | /** 7 | * @brief Filters infrequent keys via the countmin sketch 8 | * @tparam K key type 9 | * @tparam V counter type 10 | */ 11 | template 12 | class FreqencyFilter { 13 | public: 14 | /** 15 | * @brief Add keys with their key count 16 | * 17 | * @param key the list of keys 18 | * @param count the according frequency count 19 | */ 20 | void InsertKeys(const SArray& key, const SArray& count); 21 | 22 | /** 23 | * @brief Filters infrequency keys 24 | * 25 | * @param key the list of keys 26 | * @param freq_thr the frequency threshold 27 | * 28 | * @return the keys whose frequency is greater than freq_thr 29 | */ 30 | SArray QueryKeys(const SArray& key, int freq_thr); 31 | 32 | bool Empty() { return count_.empty(); } 33 | 34 | /** 35 | * @brief resize the countmin sketch 36 | * 37 | */ 38 | void Resize(int n, int k) { count_.resize(n, k, 254); } 39 | 40 | void Clear() { count_.clear(); } 41 | 42 | private: 43 | CountMin count_; 44 | }; 45 | 46 | // countmin implementation 47 | template 48 | SArray FreqencyFilter::QueryKeys(const SArray& key, int freqency) { 49 | CHECK_LT(freqency, kuint8max) << "change to uint16 or uint32..."; 50 | SArray filtered_key; 51 | for (auto k : key) { 52 | if ((int)count_.query(k) > freqency) { 53 | filtered_key.push_back(k); 54 | } 55 | } 56 | return filtered_key; 57 | } 58 | 59 | template 60 | void 
FreqencyFilter::InsertKeys(const SArray& key, const SArray& count) { 61 | CHECK_EQ(key.size(), count.size()); 62 | for (size_t i = 0; i < key.size(); ++i) { 63 | count_.insert(key[i], count[i]); 64 | } 65 | } 66 | 67 | // DEPRECATED hash implementation 68 | // std::unordered_map map_; 69 | 70 | // template 71 | // SArray FreqencyFilter::QueryKeys(const SArray& key, int freqency) { 72 | // SArray filtered_key; 73 | // for (K k : key) { 74 | // if (map_[k] > freqency) filtered_key.push_back(k); 75 | // } 76 | // return filtered_key; 77 | // } 78 | 79 | // template 80 | // void FreqencyFilter::InsertKeys(const SArray& key, const SArray& count) { 81 | // CHECK_EQ(key.size(), count.size()); 82 | // for (size_t i = 0; i < key.size(); ++i) { 83 | // map_[key[i]] += count[i]; 84 | // } 85 | // } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/filter/key_caching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "filter/filter.h" 3 | #include "util/crc32c.h" 4 | namespace PS { 5 | 6 | class KeyCachingFilter : public Filter { 7 | public: 8 | // thread safe 9 | void encode(Message* msg) { 10 | // if (!msg->task.has_key_range()) return; 11 | auto conf = find(FilterConfig::KEY_CACHING, msg); 12 | if (!conf) return; 13 | if (!msg->has_key()) { 14 | conf->clear_signature(); 15 | return; 16 | } 17 | const auto& key = msg->key; 18 | auto sig = crc32c::Value(key.data(), std::min(key.size(), max_sig_len_)); 19 | conf->set_signature(sig); 20 | auto cache_k = std::make_pair( 21 | msg->task.key_channel(), Range(msg->task.key_range())); 22 | Lock l(mu_); 23 | auto& cache = cache_[cache_k]; 24 | bool hit_cache = cache.first == sig && cache.second.size() == key.size(); 25 | if (hit_cache) { 26 | msg->clear_key(); 27 | } else { 28 | cache.first = sig; 29 | cache.second = key; 30 | } 31 | if (conf->clear_cache_if_done() && isDone(msg->task)) { 32 | cache_.erase(cache_k); 33 | } 34 | } 35 | 36 | void decode(Message* msg) { 37 | // if (!msg->task.has_key_range()) return; 38 | auto conf = find(FilterConfig::KEY_CACHING, msg); 39 | if (!conf || !conf->has_signature()) return; 40 | auto sig = conf->signature(); 41 | // do a double check 42 | if (msg->has_key()) { 43 | CHECK_EQ(crc32c::Value(msg->key.data(), std::min(msg->key.size(), max_sig_len_)), sig); 44 | } 45 | auto cache_k = std::make_pair( 46 | msg->task.key_channel(), Range(msg->task.key_range())); 47 | Lock l(mu_); 48 | auto& cache = cache_[cache_k]; 49 | if (msg->has_key()) { 50 | cache.first = sig; 51 | cache.second = msg->key; 52 | } else { 53 | // the cache is invalid... 
may ask the sender to resend this task 54 | CHECK_EQ(sig, cache.first) << msg->DebugString(); 55 | msg->set_key(cache.second); 56 | } 57 | if (conf->clear_cache_if_done() && isDone(msg->task)) { 58 | cache_.erase(cache_k); 59 | } 60 | } 61 | 62 | private: 63 | bool isDone(const Task& task) { 64 | return (!task.request() || 65 | (task.has_param() 66 | && task.param().push())); 67 | } 68 | 69 | std::unordered_map< 70 | std::pair>, std::pair>> cache_; 71 | 72 | // calculate the signature using the first max_sig_len_*4 bytes to accelerate 73 | // the computation 74 | const size_t max_sig_len_ = 2048; 75 | std::mutex mu_; 76 | }; 77 | 78 | } // namespace 79 | -------------------------------------------------------------------------------- /src/filter/proto/filter.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | 3 | message FilterConfig { 4 | enum Type { 5 | // cache the keys at both sender and receiver 6 | KEY_CACHING = 1; 7 | // compress data by snappy 8 | COMPRESSING = 2; 9 | // convert a float/double into a fixed-point integer 10 | FIXING_FLOAT = 3; 11 | // add noise to data 12 | NOISE = 4; 13 | } 14 | required Type type = 1; 15 | 16 | // -- key caching -- 17 | // if the task is done, then clear the cache (to save memory) 18 | optional bool clear_cache_if_done = 20 [default = false]; 19 | 20 | // -- fixing float filter -- 21 | optional int32 num_bytes = 5 [default = 3]; 22 | message FixedFloatConfig { 23 | optional float min_value = 1 [default = -1]; 24 | optional float max_value = 2 [default = 1]; 25 | } 26 | repeated FixedFloatConfig fixed_point = 4; 27 | 28 | // -- noise -- 29 | optional float mean = 6; 30 | optional float std = 7; 31 | 32 | // -- runtime parameters used by the system -- 33 | optional uint32 signature = 2; 34 | repeated uint64 uncompressed_size = 3; 35 | } 36 | -------------------------------------------------------------------------------- /src/filter/sparse_filter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "filter/filter.h" 3 | namespace PS { 4 | 5 | class SparseFilter : public Filter { 6 | public: 7 | SparseFilter() { 8 | // use 0xffff..ff as the mark when a value is filtered; it is NaN for float 9 | // and double. 10 | memcpy(&double_v_, &kuint64max, sizeof(double)); 11 | memcpy(&float_v_, &kuint32max, sizeof(float)); 12 | } 13 | 14 | // mark an entry as filtered 15 | void mark(float* v) { *v = float_v_; } 16 | void mark(double* v) { *v = double_v_; } 17 | 18 | // test whether or not an entry is filtered 19 | bool marked(double v) { return v != v; } 20 | bool marked(float v) { return v != v; } 21 | private: 22 | float float_v_; 23 | double double_v_; 24 | }; 25 | 26 | } // namespace PS 27 | -------------------------------------------------------------------------------- /src/learner/proto/bcd.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "util/proto/range.proto"; 3 | import "data/proto/data.proto"; 4 | import "data/proto/example.proto"; 5 | import "parameter/proto/param.proto"; 6 | import "filter/proto/filter.proto"; 7 | 8 | message BCDConfig { 9 | // Divide a feature group into feature_block_ratio x nnz_feature_per_example 10 | // blocks 11 | optional float feature_block_ratio = 1 [default = 4]; 12 | // Use a random order when updating feature blocks. Turn it off to eliminate the 13 | // randomness (often for debugging), but this may affect the convergence rate.
14 | optional bool random_feature_block_order = 2 [default = true]; 15 | // Update important feature groups at the beginning to get a good initial 16 | // starting point. 17 | repeated int32 prior_fea_group = 14; 18 | optional int32 num_iter_for_prior_fea_group = 13 [default = 5]; 19 | 20 | // Bounded-delay consistency 21 | optional int32 max_block_delay = 3 [default = 0]; 22 | // maximal number of passes over the training data 23 | optional int32 max_pass_of_data = 4 [default = 10]; 24 | // convergence criterion: stop if the relative objective change is <= epsilon 25 | optional double epsilon = 5 [default = 1e-4]; 26 | 27 | // features that occur <= *tail_feature_freq* times will be filtered before training 28 | optional int32 tail_feature_freq = 6 [default = 0]; 29 | // a countmin sketch is used to filter the tail features. It has two 30 | // parameters, k and n. 31 | optional int32 countmin_k = 7 [default = 2]; 32 | // n = the_first_arrive_key_length * num_workers * countmin_n_ratio 33 | optional double countmin_n_ratio = 8 [default = 2.0]; 34 | 35 | 36 | optional int32 max_num_parallel_groups_in_preprocessing = 9 [default = 1000]; 37 | optional int32 max_data_buf_size_in_mb = 10 [default = 1000]; 38 | optional DataConfig local_cache = 11; 39 | 40 | optional ParamInitConfig init_w = 12; 41 | 42 | repeated FilterConfig comm_filter = 15; 43 | 44 | optional int32 save_model_every_n_iter = 16 [default = 0]; 45 | 46 | optional bool load_local_data = 17 [default = false]; 47 | extensions 100 to 199; 48 | } 49 | 50 | message BCDProgress { 51 | optional double objective = 1; 52 | optional double relative_obj = 2; 53 | optional uint64 nnz_w = 5; 54 | optional double violation = 6; 55 | optional uint64 nnz_active_set = 7; 56 | 57 | // performance 58 | optional double total_time = 10; 59 | repeated double busy_time = 11; 60 | 61 | extensions 100 to 199; 62 | } 63 | 64 | 65 | message BCDCall { 66 | enum Command { 67 | LOAD_DATA = 1; 68 | PREPROCESS_DATA = 2; 69 | UPDATE_MODEL = 3; 70 | EVALUATE_PROGRESS = 4; 71 | SAVE_MODEL = 5; // save w 72 | RECOVER = 6; 73 | COMPUTE_VALIDATION_AUC = 7; 74 | REQUEST_WORKLOAD = 8; 75 | // SAVE_AS_DENSE = 7; // save X * w in a given key range 76 | } 77 | required Command cmd = 1; 78 | optional PbRange key = 2; 79 | // optional int32 feature_group_id = 3; 80 | 81 | optional double kkt_filter_threshold = 4; 82 | optional bool reset_kkt_filter = 5; 83 | 84 | optional int32 iter = 11; 85 | repeated int32 fea_grp = 8; 86 | optional bool hit_cache = 9; 87 | optional DataConfig data = 10; 88 | optional int32 time = 12; 89 | } 90 | 91 | 92 | message LoadDataResponse { 93 | optional ExampleInfo example_info = 1; 94 | optional int32 hit_cache = 2; 95 | } 96 | -------------------------------------------------------------------------------- /src/learner/proto/sgd.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "data/proto/data.proto"; 3 | import "learner/proto/workload.proto"; 4 | 5 | message SGDProgress { 6 | repeated double objective = 1; 7 | optional uint64 num_examples_processed = 2; 8 | repeated double accuracy = 3; 9 | repeated double auc = 4; 10 | optional uint64 nnz = 5; 11 | optional double weight_sum = 6; 12 | optional double delta_sum = 7; 13 | } 14 | 15 | message SGDCall { 16 | enum Command { 17 | REQUEST_WORKLOAD = 6; 18 | UPDATE_MODEL = 1; 19 | REPORT_PROGRESS = 2; 20 | SAVE_MODEL = 3; 21 | RECOVER = 4; 22 | COMPUTE_VALIDATION_AUC = 5; 23 | } 24 | required Command cmd = 1; 25 | optional Workload load = 2; 26 | } 27 |
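As a hedged sketch of how SGDProgress is meant to be filled on the worker side (the function and variable names below are hypothetical, not code from this repository), a worker can accumulate per-batch statistics into the message before reporting them; the scheduler-side merging and printing lives in learner/sgd.cc further down:

#include "learner/proto/sgd.pb.h"

void ReportBatch(double batch_loss, double batch_auc, uint64_t batch_size) {
  PS::SGDProgress prog;
  prog.add_objective(batch_loss);               // repeated field: one entry per mini-batch
  prog.add_auc(batch_auc);
  prog.set_num_examples_processed(batch_size);
  // attach prog to a REPORT_PROGRESS task and submit it to the scheduler (omitted)
}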
-------------------------------------------------------------------------------- /src/learner/proto/workload.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "data/proto/data.proto"; 3 | 4 | message Workload { 5 | // the workload id 6 | optional int32 id = 1; 7 | 8 | // the data associated with this workload 9 | optional DataConfig data = 2; 10 | 11 | // randomly shuffle the file order 12 | optional bool shuffle = 3 [default = false]; 13 | 14 | // duplicate the data several times 15 | optional int32 replica = 4 [default = 1]; 16 | 17 | // all finished workload ids 18 | repeated int32 finished = 6; 19 | 20 | // all workload is been done 21 | optional bool all_is_done = 5 [default = false]; 22 | } 23 | -------------------------------------------------------------------------------- /src/learner/sgd.cc: -------------------------------------------------------------------------------- 1 | #include "learner/sgd.h" 2 | namespace PS { 3 | 4 | ISGDScheduler::~ISGDScheduler() { 5 | // core dump when delete workload_pool_; 6 | } 7 | 8 | void ISGDScheduler::Run() { 9 | // init monitor 10 | using namespace std::placeholders; 11 | monitor_.set_merger(std::bind(&ISGDScheduler::MergeProgress, this, _1, _2)); 12 | monitor_.set_printer(1, std::bind(&ISGDScheduler::ShowProgress, this, _1, _2)); 13 | 14 | // wait all jobs are finished 15 | sys_.manager().AddNodeFailureHandler([this](const NodeID& id) { 16 | CHECK_NOTNULL(workload_pool_)->restore(id); 17 | }); 18 | CHECK_NOTNULL(workload_pool_)->waitUtilDone(); 19 | 20 | // save model 21 | Task task; 22 | task.mutable_sgd()->set_cmd(SGDCall::SAVE_MODEL); 23 | int ts = Submit(task, kServerGroup); 24 | Wait(ts); 25 | } 26 | 27 | void ISGDScheduler::ProcessResponse(Message* response) { 28 | const auto& sgd = response->task.sgd(); 29 | if (sgd.cmd() == SGDCall::UPDATE_MODEL) { 30 | for (int i = 0; i < sgd.load().finished_size(); ++i) { 31 | workload_pool_->finish(sgd.load().finished(i)); 32 | } 33 | SendWorkload(response->sender); 34 | } 35 | } 36 | 37 | void ISGDScheduler::ProcessRequest(Message* request) { 38 | if (request->task.sgd().cmd() == SGDCall::REQUEST_WORKLOAD) { 39 | SendWorkload(request->sender); 40 | } 41 | } 42 | 43 | void ISGDScheduler::SendWorkload(const NodeID& recver) { 44 | Task task; 45 | task.mutable_sgd()->set_cmd(SGDCall::UPDATE_MODEL); 46 | if (workload_pool_->assign(recver, task.mutable_sgd()->mutable_load())) { 47 | Submit(task, recver); 48 | } 49 | } 50 | 51 | void ISGDScheduler::ShowProgress( 52 | double time, std::unordered_map* progress) { 53 | uint64 num_ex = 0, nnz_w = 0; 54 | SArray objv, auc, acc; 55 | double weight_sum = 0, delta_sum = 1e-20; 56 | for (const auto& it : *progress) { 57 | auto& prog = it.second; 58 | num_ex += prog.num_examples_processed(); 59 | nnz_w += prog.nnz(); 60 | for (int i = 0; i < prog.objective_size(); ++i) { 61 | objv.push_back(prog.objective(i)); 62 | } 63 | for (int i = 0; i < prog.auc_size(); ++i) { 64 | auc.push_back(prog.auc(i)); 65 | } 66 | for (int i = 0; i < prog.accuracy_size(); ++i) { 67 | acc.push_back(prog.accuracy(i)); 68 | } 69 | weight_sum += prog.weight_sum(); 70 | delta_sum += prog.delta_sum(); 71 | } 72 | progress->clear(); 73 | num_ex_processed_ += num_ex; 74 | if (show_prog_head_) { 75 | NOTICE(" sec examples loss auc accuracy |w|_0 updt ratio"); 76 | show_prog_head_ = false; 77 | } 78 | NOTICE("%4d %.2e %.3e %.4f %.4f %.2e %.2e", 79 | (int)time, 80 | (double)num_ex_processed_ , 81 | objv.Sum()/(double)num_ex, 82 | 
auc.Mean(), 83 | acc.Mean(), 84 | (double)nnz_w, 85 | sqrt(delta_sum) / sqrt(weight_sum)); 86 | } 87 | 88 | 89 | void ISGDScheduler::MergeProgress(const SGDProgress& src, SGDProgress* dst) { 90 | auto old = *dst; *dst = src; 91 | // TODO also append objv 92 | dst->set_num_examples_processed( 93 | dst->num_examples_processed() + old.num_examples_processed()); 94 | } 95 | 96 | } // namespace PS 97 | -------------------------------------------------------------------------------- /src/learner/workload_pool.cc: -------------------------------------------------------------------------------- 1 | #include "learner/workload_pool.h" 2 | #include "data/common.h" 3 | namespace PS { 4 | 5 | void WorkloadPool::set(const Workload& load) { 6 | VLOG(1) << "init workload " << load.ShortDebugString(); 7 | Lock l(mu_); 8 | CHECK_GT(load.replica(), 0); 9 | DataConfig files = searchFiles(load.data()); 10 | VLOG(1) << "find " << files.file_size() << " files: " << files.ShortDebugString(); 11 | 12 | loads_.resize(files.file_size() * load.replica()); 13 | int k = 0; 14 | for (int r = 0; r < load.replica(); ++r) { 15 | if (load.shuffle()) files = shuffleFiles(files); 16 | for (int i = 0; i < files.file_size(); ++i) { 17 | // be careful here, do not use loads_[k] = xxx; it will copy the pointer 18 | // of load.data_ rather than the value. 19 | *loads_[k].load.mutable_data() = ithFile(files, i); 20 | loads_[k].load.set_id(k); 21 | ++ k; 22 | } 23 | } 24 | CHECK_EQ(k, loads_.size()); 25 | } 26 | 27 | 28 | bool WorkloadPool::assign(const NodeID& node_id, Workload* load) { 29 | Lock l(mu_); 30 | for (auto& info : loads_) { 31 | if (!info.assigned) { 32 | load->CopyFrom(info.load); 33 | info.node = node_id; 34 | info.assigned = true; 35 | VLOG(1) << "assign to [" << node_id << "] " << load->ShortDebugString(); 36 | return true; 37 | } 38 | } 39 | return false; 40 | } 41 | 42 | 43 | void WorkloadPool::restore(const NodeID& node_id) { 44 | Lock l(mu_); 45 | for (auto& info : loads_) { 46 | if (info.assigned && !info.finished && info.node == node_id) { 47 | info.assigned = false; 48 | LOG(INFO) << "restore workload " << info.load.id() << " from " << node_id; 49 | } 50 | } 51 | } 52 | 53 | 54 | void WorkloadPool::finish(int id) { 55 | Lock l(mu_); 56 | CHECK_GE(id, 0); CHECK_LT(id, loads_.size()); 57 | loads_[id].finished = true; 58 | ++ num_finished_; 59 | VLOG(1) << "workload " << id << " is finished"; 60 | } 61 | 62 | 63 | void WorkloadPool::waitUtilDone() { 64 | while (true) { 65 | mu_.lock(); 66 | bool done = num_finished_ >= loads_.size(); 67 | mu_.unlock(); 68 | if (done) { 69 | break; 70 | } else { 71 | usleep(1000); 72 | } 73 | } 74 | VLOG(1) << "all workloads are done"; 75 | } 76 | 77 | } // namespace PS 78 | -------------------------------------------------------------------------------- /src/learner/workload_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "system/message.h" 4 | #include "learner/proto/workload.pb.h" 5 | namespace PS { 6 | 7 | // the base class of a workload pool. thread safe 8 | class WorkloadPool { 9 | public: 10 | WorkloadPool() { } 11 | WorkloadPool(const Workload& load) { set(load); } 12 | virtual ~WorkloadPool() { } 13 | 14 | // set all workloads 15 | void set(const Workload& load); 16 | 17 | // assign a piece of *workload* to *node_id*. return false if all is done. 
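// A typical scheduler-side flow, as a hedged sketch (names are hypothetical):
//   WorkloadPool pool(load);               // *load* lists the data files
//   Workload w;
//   if (pool.assign(worker_id, &w)) { /* send *w* to that worker */ }
//   pool.finish(w.id());                   // when the worker reports completion
//   pool.waitUtilDone();                   // block until every workload is done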
18 | bool assign(const NodeID& node_id, Workload* load); 19 | 20 | // restore unfinished workloads that were assigned to *node_id* 21 | void restore(const NodeID& node_id); 22 | 23 | // mark the workload with *id* as finished 24 | void finish(int id); 25 | 26 | // block until all workloads are finished 27 | void waitUtilDone(); 28 | 29 | protected: 30 | struct WorkloadInfo { 31 | NodeID node; 32 | Workload load; 33 | bool assigned = false; 34 | bool finished = false; 35 | }; 36 | std::vector loads_; 37 | int num_finished_ = 0; 38 | std::mutex mu_; 39 | }; 40 | 41 | } // namespace PS 42 | -------------------------------------------------------------------------------- /src/parameter/README.org: -------------------------------------------------------------------------------- 1 | #+TITLE: Shared Parameters 2 | 3 | 4 | | name | KV pairs op | sorted key | value size | storage | usage | 5 | |----------------+-------------+------------+------------+----------+-------------------------------------------------------------------------------------| 6 | | KVVector | batched | yes | fixed | array | sparse data with billions of keys; used in worker nodes, or in server nodes for batch methods | 7 | | KVMap | individual | no | fixed | hash map | | 8 | | KVLayer | individual | yes | vary | hash map | neural network | 9 | | KVStore (TODO) | individual | no | vary | hash map | | 10 | -------------------------------------------------------------------------------- /src/parameter/kv_map.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "ps.h" 3 | #include "parameter/parameter.h" 4 | namespace PS { 5 | 6 | /** 7 | * @brief Default entry type for KVMap 8 | */ 9 | template 10 | struct KVMapEntry { 11 | void Get(V* data, void* state) { *data = value; } 12 | void Set(const V* data, void* state) { value = *data; } 13 | V value; 14 | }; 15 | 16 | /** 17 | * @brief Default state type for KVMap 18 | */ 19 | struct KVMapState { 20 | void Update() { } 21 | }; 22 | 23 | /** 24 | * @brief A key-value store with fixed-length values.
25 | * 26 | * 27 | * @tparam K the key type 28 | * @tparam V the value type 29 | * @tparam E the entry type 30 | * @tparam S the state type 31 | */ 32 | template , 34 | typename S = KVMapState> 35 | class KVMap : public Parameter { 36 | public: 37 | /** 38 | * @brief Constructor 39 | * 40 | * @param k the length of a value entry 41 | * @param id customer id 42 | */ 43 | KVMap(int k = 1, int id = NextCustomerID()) : 44 | Parameter(id), k_(k) { 45 | CHECK_GT(k, 0); 46 | } 47 | virtual ~KVMap() { } 48 | 49 | void set_state(const S& s) { state_ = s; } 50 | 51 | virtual void Slice(const Message& request, const std::vector>& krs, 52 | std::vector* msgs) { 53 | SliceKOFVMessage(request, krs, msgs); 54 | } 55 | 56 | virtual void GetValue(Message* msg); 57 | virtual void SetValue(const Message* msg); 58 | 59 | virtual void WriteToFile(std::string file); 60 | 61 | protected: 62 | int k_; 63 | S state_; 64 | // TODO use multi-thread cuokoo hash 65 | std::unordered_map data_; 66 | }; 67 | 68 | template 69 | void KVMap::GetValue(Message* msg) { 70 | SArray key(msg->key); 71 | size_t n = key.size(); 72 | SArray val(n * k_); 73 | for (size_t i = 0; i < n; ++i) { 74 | data_[key[i]].Get(val.data() + i * k_, &state_); 75 | } 76 | msg->add_value(val); 77 | } 78 | 79 | template 80 | void KVMap::SetValue(const Message* msg) { 81 | SArray key(msg->key); 82 | size_t n = key.size(); 83 | CHECK_EQ(msg->value.size(), 1); 84 | SArray val(msg->value[0]); 85 | CHECK_EQ(n * k_, val.size()); 86 | 87 | for (size_t i = 0; i < n; ++i) { 88 | data_[key[i]].Set(val.data() + i * k_, &state_); 89 | } 90 | state_.Update(); 91 | } 92 | #if USE_S3 93 | bool s3file(const std::string& name); 94 | std::string s3Prefix(const std::string& path); 95 | std::string s3Bucket(const std::string& path); 96 | std::string s3FileUrl(const std::string& path); 97 | #endif // USE_S3 98 | 99 | template 100 | void KVMap::WriteToFile(std::string file) { 101 | #if USE_S3 102 | std::string s3_file; 103 | if (s3file(file)) { 104 | s3_file=file; 105 | // create a local model dir 106 | file=s3Prefix(s3_file); 107 | } 108 | #endif // USE_S3 109 | if (!dirExists(getPath(file))) { 110 | createDir(getPath(file)); 111 | } 112 | std::ofstream out(file); CHECK(out.good()); 113 | V v; 114 | for (auto& e : data_) { 115 | e.second.Get(&v, &state_); 116 | if (v != 0) out << e.first << "\t" << v << std::endl; 117 | } 118 | #if USE_S3 119 | if (s3file(s3_file)) { 120 | // upload model 121 | std::string cmd = "curl -s '"+s3FileUrl(s3_file)+"?Content-Length=" 122 | +std::to_string(File::size(file))+"&x-amz-acl=public-read' --upload-file "+file; 123 | LOG(INFO)<task.param(); 7 | Message* response = nullptr; 8 | bool push = call.push(); 9 | if (!push) { 10 | // a pull request, need to reply with the value 11 | response = new Message(*request); 12 | } 13 | 14 | if (call.replica()) { 15 | // a replication request 16 | if (push) { 17 | SetReplica(request); 18 | } else { 19 | GetReplica(response); 20 | } 21 | } else { 22 | // a normal request 23 | if (push) { 24 | SetValue(request); 25 | } else { 26 | GetValue(response); 27 | } 28 | } 29 | 30 | if (response) Reply(request, response); 31 | } 32 | 33 | void Parameter::ProcessResponse(Message* response) { 34 | const auto& call = response->task.param(); 35 | bool push = call.push(); 36 | 37 | if (call.replica()) { 38 | // a replication response 39 | if (push) return; // an ACK response 40 | if (Range(response->task.key_range()) == MyKeyRange()) { 41 | Recover(response); 42 | } else { 43 | SetReplica(response); 44 | } 45 | } else 
{ 46 | // a normal response 47 | if (!push) SetValue(response); 48 | } 49 | 50 | } 51 | } // namespace PS 52 | -------------------------------------------------------------------------------- /src/parameter/parameter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "system/customer.h" 3 | #include "parameter/proto/param.pb.h" 4 | namespace PS { 5 | 6 | /// The base class of shared parameters 7 | class Parameter : public Customer { 8 | public: 9 | Parameter(int id) : Customer(id) { } 10 | virtual ~Parameter() { } 11 | 12 | typedef std::initializer_list Timestamps; 13 | typedef ::google::protobuf::RepeatedPtrField Filters; 14 | /** 15 | * @brief Creates a request task 16 | * 17 | * @param channel communication channel 18 | * @param ts the timestamp of this request 19 | * @param wait a list of timestamps this request should wait for 20 | * @param filters a list of filters to compress the request message 21 | * @param key_range the key range of this request 22 | * 23 | * @return A Task 24 | */ 25 | static Task Request(int channel, 26 | int ts = Message::kInvalidTime, 27 | const Timestamps& wait = {}, 28 | const Filters& filters = Filters(), 29 | const Range& key_range = Range::All()) { 30 | Task req; req.set_request(true); 31 | req.set_key_channel(channel); 32 | if (ts > Message::kInvalidTime) req.set_time(ts); 33 | for (int t : wait) req.add_wait_time(t); 34 | for (const auto& f : filters) req.add_filter()->CopyFrom(f); 35 | key_range.To(req.mutable_key_range()); 36 | return req; 37 | } 38 | 39 | /// @brief Submit a push message to msg->recver 40 | inline int Push(Message* msg) { 41 | msg->task.mutable_param()->set_push(true); 42 | return Submit(msg); 43 | } 44 | 45 | /// @brief Submit a pull message to msg->recver 46 | inline int Pull(Message* msg) { 47 | msg->task.mutable_param()->set_push(false); 48 | return Submit(msg); 49 | } 50 | 51 | virtual void WriteToFile(std::string file) { } 52 | 53 | virtual void ProcessRequest(Message* request); 54 | virtual void ProcessResponse(Message* response); 55 | protected: 56 | 57 | /// @brief Fill "msg" with the values it requests, e.g., 58 | /// msg->value(0)[0] = my_val_[msg->key[0]]; 59 | virtual void GetValue(Message* msg) = 0; 60 | 61 | /// @brief Set the values in "msg" into my data structure, e.g., 62 | /// my_val_[msg->key[0]] = msg->value(0)[0]; 63 | virtual void SetValue(const Message* msg) = 0; 64 | 65 | /// @brief The message contains the backup KV pairs sent by the master node of the key 66 | /// segment to its replica node. Merge these pairs into my replica, say 67 | /// replica_[msg->sender] = ... 68 | virtual void SetReplica(const Message* msg) { } 69 | 70 | /// @brief Retrieve the replica.
A new server node replacing a dead server will first 71 | /// ask the dead node's replica for the data 72 | virtual void GetReplica(Message* msg) { } 73 | 74 | /// @brief A new server node fills its own data structure via the replica data from 75 | /// the dead node's replica 76 | virtual void Recover(Message* msg) { } 77 | }; 78 | 79 | } // namespace PS 80 | -------------------------------------------------------------------------------- /src/parameter/proto/param.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "util/proto/assign_op.proto"; 3 | 4 | message ParamCall { 5 | // push or pull 6 | optional bool push = 1 [default = true]; 7 | 8 | // merge operator 9 | optional AssignOpType op = 2; 10 | 11 | optional TailKeyFilter tail_filter = 3; 12 | 13 | // optional bool insert_key = 5; 14 | // optional bool gather = 6; 15 | 16 | // it's a replica request 17 | optional bool replica = 10; 18 | repeated Timestamp backup = 11; 19 | } 20 | 21 | message ParamInitConfig { 22 | enum Type { 23 | ZERO = 1; 24 | CONSTANT = 2; 25 | GAUSSIAN = 3; 26 | FILE = 4; 27 | CLONE = 5; 28 | } 29 | optional Type type = 1 [default = ZERO]; 30 | optional double constant = 2 [default = 1]; 31 | // gaussian random 32 | optional double mean = 3 [default = 0]; 33 | optional double std = 4 [default = 1]; 34 | optional string file_name = 5; 35 | } 36 | 37 | message Timestamp { 38 | required string sender = 1; 39 | required int32 time = 2; 40 | } 41 | 42 | message TailKeyFilter { 43 | optional bool insert_count = 1; 44 | optional int32 freq_threshold = 2; 45 | optional bool query_value = 3; 46 | optional int32 countmin_n = 4 [default = 1000000]; 47 | optional int32 countmin_k = 5 [default = 2]; 48 | } 49 | -------------------------------------------------------------------------------- /src/ps.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "system/customer.h" 3 | 4 | // A simple interface for writing parameter server (PS) programs. See the example in 5 | // src/app/hello_world 6 | 7 | // A typical PS program should define the following two functions: 8 | 9 | // This is the main entry point for a worker node. All flags and their arguments 10 | // (e.g. -name value) have been removed from argv, and argc has been adjusted 11 | // accordingly. However, command line arguments remain. 12 | // 13 | // In the example "head -n 100 file", -n is a flag and 100 is its argument, 14 | // but file is a command line argument 15 | int WorkerNodeMain(int argc, char *argv[]); 16 | 17 | namespace PS { 18 | 19 | // Return an instance of a server node.
This node is started with "-app_file 20 | // app.conf -app_conf 'key: value'", then conf has both the content of file "app.conf" 21 | // and 'key:value' 22 | App* CreateServerNode(const std::string& conf); 23 | 24 | // Utility functions: 25 | 26 | // The app this node runs 27 | inline App* MyApp() { return Postoffice::instance().manager().app(); } 28 | 29 | // My node information 30 | inline Node MyNode() { return Postoffice::instance().manager().van().my_node(); } 31 | // Each unique string id of my node 32 | inline std::string MyNodeID() { return MyNode().id(); } 33 | // Query the role of this node 34 | inline int IsWorker() { return MyNode().role() == Node::WORKER; } 35 | inline int IsServer() { return MyNode().role() == Node::SERVER; } 36 | inline int IsScheduler() { return MyNode().role() == Node::SCHEDULER; } 37 | 38 | inline Range MyKeyRange() { return Range(MyNode().key()); } 39 | inline std::string SchedulerID() { 40 | return Postoffice::instance().manager().van().scheduler().id(); 41 | } 42 | 43 | inline int NextCustomerID() { 44 | return Postoffice::instance().manager().NextCustomerID(); 45 | } 46 | 47 | // The rank ID of this node in its group. Assume this a worker node in a worker 48 | // group with N workers. Then this node will be assigned an unique ID from 0, 49 | // ..., N. Similarly for server and scheduler. 50 | inline int MyRank() { return MyNode().rank(); } 51 | // Total nodes in this node group. 52 | inline int RankSize() { 53 | auto& mng = Postoffice::instance().manager(); 54 | return IsWorker() ? mng.num_workers() : (IsServer() ? mng.num_servers() : 1); 55 | } 56 | 57 | // Wait until all FLAGS_num_servers servers are ready. 58 | inline void WaitServersReady() { 59 | PS::Postoffice::instance().manager().WaitServersReady(); 60 | } 61 | 62 | // Wait until all FLAGS_num_workers workers are ready. 
63 | inline void WaitWorkersReady() { 64 | PS::Postoffice::instance().manager().WaitWorkersReady(); 65 | } 66 | 67 | inline void StartSystem(int argc, char *argv[]) { 68 | PS::Postoffice::instance().Run(&argc, &argv); 69 | } 70 | 71 | inline void StopSystem() { 72 | PS::Postoffice::instance().Stop(); 73 | } 74 | 75 | inline int RunSystem(int argc, char *argv[]) { 76 | StartSystem(argc, argv); StopSystem(); 77 | return 0; 78 | } 79 | 80 | } // namespace PS 81 | -------------------------------------------------------------------------------- /src/ps_main.cc: -------------------------------------------------------------------------------- 1 | #include "ps.h" 2 | namespace PS { 3 | 4 | App* App::Create(const string& conf) { 5 | auto my_role = MyNode().role(); 6 | if (my_role == Node::SERVER) { 7 | return CreateServerNode(conf); 8 | } 9 | return new App(); 10 | } 11 | } // namespace PS 12 | 13 | int main(int argc, char *argv[]) { 14 | auto& sys = PS::Postoffice::instance(); 15 | sys.Run(&argc, &argv); 16 | 17 | int ret = 0; 18 | if (PS::MyNode().role() == PS::Node::WORKER) { 19 | ret = WorkerNodeMain(argc, argv); 20 | } 21 | 22 | sys.Stop(); 23 | return ret; 24 | } 25 | -------------------------------------------------------------------------------- /src/system/assigner.cc: -------------------------------------------------------------------------------- 1 | #include "system/assigner.h" 2 | #include "data/common.h" 3 | namespace PS { 4 | 5 | void DataAssigner::set(const DataConfig& data, int num, bool local) { 6 | // search all files 7 | CHECK_GT(num, 0); 8 | parts_.resize(num); 9 | if (local) { 10 | for (int i = 0; i < num; ++i) parts_[i].CopyFrom(data); 11 | return; 12 | } 13 | 14 | CHECK_GT(data.replica(), 0); 15 | DataConfig files = searchFiles(data); 16 | VLOG(1) << "find " << files.file_size() 17 | << " files: " << files.ShortDebugString(); 18 | 19 | // divide them 20 | for (int r = 0; r < data.replica(); ++r) { 21 | if (data.shuffle()) files = shuffleFiles(files); 22 | auto pts = divideFiles(files, num); 23 | if (r == 0) { 24 | parts_ = pts; 25 | } else { 26 | for (int i = 0; i < num; ++i) { 27 | parts_[i] = appendFiles(parts_[i], pts[i]); 28 | } 29 | } 30 | } 31 | VLOG(1) << "divide into " << num << " jobs"; 32 | } 33 | 34 | bool DataAssigner::next(DataConfig *data) { 35 | if (cur_i_ >= parts_.size()) return false; 36 | *data = parts_[cur_i_ ++]; 37 | return true; 38 | } 39 | 40 | } // namespace PS 41 | -------------------------------------------------------------------------------- /src/system/assigner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "util/range.h" 4 | #include "system/proto/node.pb.h" 5 | #include "data/proto/data.pb.h" 6 | namespace PS { 7 | 8 | // assign *node* with proper rank_id, key_range, etc.. 
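// For example (a hedged sketch with illustrative values): with two servers over
// the whole key range, the first registered server gets rank 0 and the lower half
// of the keys, the second gets rank 1 and the upper half; workers just get
// increasing ranks while their key field keeps the full range:
//   NodeAssigner assigner(2, whole_key_range);
//   assigner.assign(&server_node);   // fills in the node's rank and key range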
9 | class NodeAssigner { 10 | public: 11 | NodeAssigner(int num_servers, Range key_range) { 12 | num_servers_ = num_servers; 13 | key_range_ = key_range; 14 | } 15 | ~NodeAssigner() { } 16 | 17 | void assign(Node* node) { 18 | Range kr = key_range_; 19 | int rank = 0; 20 | if (node->role() == Node::SERVER) { 21 | kr = key_range_.EvenDivide(num_servers_, server_rank_); 22 | rank = server_rank_ ++; 23 | } else if (node->role() == Node::WORKER) { 24 | rank = worker_rank_ ++; 25 | } 26 | node->set_rank(rank); 27 | kr.To(node->mutable_key()); 28 | } 29 | 30 | void remove(const Node& node) { 31 | // TODO 32 | } 33 | protected: 34 | int num_servers_ = 0; 35 | int server_rank_ = 0; 36 | int worker_rank_ = 0; 37 | Range key_range_; 38 | }; 39 | 40 | // divide *data* into *num* parts. 41 | class DataAssigner { 42 | public: 43 | DataAssigner() { } 44 | DataAssigner(const DataConfig& data, int num, bool local) { 45 | set(data, num, local); 46 | } 47 | ~DataAssigner() { } 48 | 49 | void set(const DataConfig& data, int num, bool local); 50 | bool next(DataConfig *data); 51 | 52 | int cur_i() { return cur_i_; } 53 | int size() { return parts_.size(); } 54 | private: 55 | std::vector parts_; 56 | int cur_i_ = 0; 57 | }; 58 | 59 | } // namespace PS 60 | -------------------------------------------------------------------------------- /src/system/dashboard.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "system/message.h" 3 | #include "system/proto/heartbeat.pb.h" 4 | 5 | namespace PS { 6 | 7 | struct NodeIDCmp { 8 | void splitNodeID(const NodeID& in, string& primary, string& secondary); 9 | bool operator()(const NodeID& a, const NodeID& b); 10 | }; 11 | 12 | class Dashboard { 13 | public: 14 | void addTask(const NodeID& node, int task_id); 15 | void addReport(const NodeID& node, const string& report); 16 | string report(); 17 | private: 18 | string title(); 19 | string report(const NodeID& node, const HeartbeatReport& report); 20 | std::mutex mu_; 21 | std::map data_; 22 | }; 23 | 24 | } // namespace PS 25 | -------------------------------------------------------------------------------- /src/system/env.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | namespace PS { 3 | 4 | /** 5 | * @brief Setups environment 6 | */ 7 | class Env { 8 | public: 9 | Env() { } 10 | ~Env() { } 11 | 12 | void Init(char* argv0); 13 | private: 14 | void InitGlog(char* argv0); 15 | void InitDMLC(); 16 | void AssembleMyNode(); 17 | 18 | }; 19 | 20 | } // namespace PS 21 | -------------------------------------------------------------------------------- /src/system/heartbeat_info.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "system/proto/heartbeat.pb.h" 4 | #include "util/resource_usage.h" 5 | #include "util/common.h" 6 | 7 | namespace PS { 8 | class HeartbeatInfo { 9 | public: 10 | enum class TimerType : unsigned char { 11 | BUSY = 0, 12 | NUM 13 | }; 14 | 15 | public: 16 | HeartbeatInfo(); 17 | ~HeartbeatInfo(); 18 | HeartbeatInfo(const HeartbeatInfo& other) = delete; 19 | HeartbeatInfo& operator= (const HeartbeatInfo& rhs) = delete; 20 | 21 | HeartbeatReport get(); 22 | 23 | // set network interface which is under use 24 | // such as "eth0" 25 | // set hostname 26 | void init(const string& interface, const string& hostname); 27 | 28 | void startTimer(const HeartbeatInfo::TimerType type); 29 | void stopTimer(const HeartbeatInfo::TimerType 
type); 30 | 31 | // TODO need lock? 32 | void increaseInBytes(const size_t delta) { Lock l(mu_); in_bytes_ += delta; } 33 | void increaseOutBytes(const size_t delta) { Lock l(mu_); out_bytes_ += delta; } 34 | 35 | private: 36 | std::vector timers_; 37 | MilliTimer total_timer_; 38 | 39 | size_t in_bytes_; 40 | size_t out_bytes_; 41 | 42 | string interface_; 43 | string hostname_; 44 | 45 | // snapshot of performance counters 46 | struct Snapshot { 47 | uint64 process_user; 48 | uint64 process_sys; 49 | uint64 host_user; 50 | uint64 host_sys; 51 | uint64 host_cpu; 52 | 53 | uint64 host_in_bytes; 54 | uint64 host_out_bytes; 55 | 56 | Snapshot() : 57 | process_user(0), 58 | process_sys(0), 59 | host_user(0), 60 | host_sys(0), 61 | host_cpu(0), 62 | host_in_bytes(0), 63 | host_out_bytes(0) { 64 | // do nothing 65 | } 66 | 67 | string shortDebugString() { 68 | std::stringstream ss; 69 | ss << "{"; 70 | ss << "process_user: " << process_user << ", "; 71 | ss << "process_sys: " << process_sys << ", "; 72 | ss << "host_user: " << host_user << ", "; 73 | ss << "host_sys: " << host_sys << ", "; 74 | ss << "host_cpu: " << host_cpu << ", "; 75 | ss << "host_in_bytes: " << host_in_bytes << ", "; 76 | ss << "host_out_bytes: " << host_out_bytes; 77 | ss << "}"; 78 | 79 | return ss.str(); 80 | } 81 | }; // struct Snapshot 82 | 83 | HeartbeatInfo::Snapshot last_; 84 | HeartbeatInfo::Snapshot dump(); 85 | 86 | std::mutex mu_; 87 | size_t cpu_core_number_; 88 | }; // class Heartbeatinfo 89 | }; // namespace PS 90 | -------------------------------------------------------------------------------- /src/system/manager.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "system/proto/node.pb.h" 4 | #include "system/proto/task.pb.h" 5 | #include "system/van.h" 6 | #include "system/env.h" 7 | #include "system/assigner.h" 8 | namespace PS { 9 | 10 | class App; 11 | class Customer; 12 | 13 | class Manager { 14 | public: 15 | Manager(); 16 | ~Manager(); 17 | 18 | void Init(char* argv0); 19 | void Run(); 20 | void Stop(); 21 | bool Process(Message* msg); 22 | 23 | // manage nodes 24 | void AddNode(const Node& node); 25 | void RemoveNode(const NodeID& node_id); 26 | // detect that *node_id* is disconnected 27 | void NodeDisconnected(const NodeID node_id); 28 | // add a function handler which will be called in *nodeDisconnected* 29 | typedef std::function NodeFailureHandler; 30 | void AddNodeFailureHandler(NodeFailureHandler handler) { 31 | node_failure_handlers_.push_back(handler); 32 | } 33 | 34 | // manage customer 35 | Customer* customer(int id); 36 | void AddCustomer(Customer* obj); 37 | void RemoveCustomer(int id); 38 | int NextCustomerID(); 39 | 40 | // workers and servers 41 | void WaitServersReady(); 42 | void WaitWorkersReady(); 43 | 44 | int num_workers() { return num_workers_; } 45 | int num_servers() { return num_servers_; } 46 | 47 | // manage message TODO 48 | void AddRequest(Message* msg) { delete msg; } 49 | void AddResponse(Message* msg) { } 50 | 51 | // accessors 52 | Van& van() { return van_; } 53 | App* app() { return app_; } 54 | 55 | private: 56 | bool IsScheduler() { return van_.my_node().role() == Node::SCHEDULER; } 57 | Task NewControlTask(Control::Command cmd); 58 | void SendTask(const NodeID& recver, const Task& task); 59 | void SendTask(const Node& recver, const Task& task) { 60 | SendTask(recver.id(), task); 61 | } 62 | 63 | // the app 64 | void CreateApp(const string& conf); 65 | App* app_ = nullptr; 
66 | string app_conf_; 67 | // std::promise my_node_promise_; 68 | 69 | // nodes 70 | std::map nodes_; 71 | std::mutex nodes_mu_; 72 | int num_workers_ = 0; 73 | int num_servers_ = 0; 74 | int num_active_nodes_ = 0; 75 | std::vector node_failure_handlers_; 76 | bool is_my_node_inited_ = false; 77 | 78 | // only available at the scheduler node 79 | NodeAssigner* node_assigner_ = nullptr; 80 | 81 | // customers 82 | // format: > 83 | std::map> customers_; 84 | 85 | bool done_ = false; 86 | bool in_exit_ = false; 87 | int time_ = 0; 88 | 89 | Van van_; 90 | Env env_; 91 | 92 | DISALLOW_COPY_AND_ASSIGN(Manager); 93 | }; 94 | 95 | } // namespace PS 96 | -------------------------------------------------------------------------------- /src/system/message.cc: -------------------------------------------------------------------------------- 1 | #include "system/message.h" 2 | namespace PS { 3 | 4 | // Message::Message(const NodeID& dest, int time, int wait_time) 5 | // : recver(dest) { 6 | // task.set_time(time); 7 | // if (wait_time != kInvalidTime) task.add_wait_time(wait_time); 8 | // } 9 | 10 | FilterConfig* Message::add_filter(FilterConfig::Type type) { 11 | auto ptr = task.add_filter(); 12 | ptr->set_type(type); 13 | return ptr; 14 | } 15 | 16 | size_t Message::mem_size() { 17 | size_t nbytes = task.SpaceUsed() + key.MemSize(); 18 | for (const auto& v : value) nbytes += v.MemSize(); 19 | return nbytes; 20 | } 21 | 22 | std::string Message::ShortDebugString() const { 23 | std::stringstream ss; 24 | if (key.size()) ss << "key [" << key.size() << "] "; 25 | if (value.size()) { 26 | ss << "value ["; 27 | for (int i = 0; i < value.size(); ++i) { 28 | ss << value[i].size(); 29 | if (i < value.size() - 1) ss << ","; 30 | } 31 | ss << "] "; 32 | } 33 | auto t = task; t.clear_msg(); ss << t.ShortDebugString(); 34 | return ss.str(); 35 | } 36 | 37 | std::string Message::DebugString() const { 38 | std::stringstream ss; 39 | ss << "[message]: " << sender << "=>" << recver 40 | << "[task]:" << task.ShortDebugString() 41 | << "\n[key]:" << key.size() 42 | << "\n[" << value.size() << " value]: "; 43 | for (const auto& x: value) 44 | ss << x.size() << " "; 45 | return ss.str(); 46 | } 47 | 48 | 49 | } // namespace PS 50 | -------------------------------------------------------------------------------- /src/system/monitor.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file monitor.h 3 | * @brief A distributed monitor 4 | * 5 | */ 6 | #pragma once 7 | #include "system/customer.h" 8 | namespace PS { 9 | 10 | /** 11 | * @brief The master of the monitor, which collects reports from slavers and 12 | * display the progress 13 | * 14 | * @tparam Progress A proto buffer class 15 | */ 16 | template 17 | class MonitorMaster : public Customer { 18 | public: 19 | MonitorMaster(int id = NextCustomerID()) : Customer(id) {} 20 | 21 | typedef std::function*)> Printer; 23 | /** 24 | * @brief set the printer 25 | * 26 | * @param time_interval in sec 27 | * @param printer 28 | */ 29 | void set_printer(double time_interval, Printer printer) { 30 | timer_.start(); 31 | printer_ = printer; 32 | interval_ = time_interval; 33 | } 34 | 35 | typedef std::function Merger; 36 | /** 37 | * @brief set the merger 38 | * 39 | * @param merger merges two reports 40 | */ 41 | void set_merger(Merger merger) { 42 | merger_ = merger; 43 | } 44 | 45 | virtual void ProcessRequest(Message* request) { 46 | NodeID sender = request->sender; 47 | Progress prog; 48 | 
CHECK(prog.ParseFromString(request->task.msg())); 49 | if (merger_) { 50 | merger_(prog, &progress_[sender]); 51 | } else { 52 | progress_[sender] = prog; 53 | } 54 | 55 | double time = timer_.stop(); 56 | if (time > interval_ && printer_) { 57 | total_time_ += time; 58 | printer_(total_time_, &progress_); 59 | timer_.restart(); 60 | } else { 61 | timer_.start(); 62 | } 63 | } 64 | private: 65 | std::unordered_map progress_; 66 | double interval_; 67 | Timer timer_; 68 | double total_time_ = 0; 69 | Merger merger_; 70 | Printer printer_; 71 | }; 72 | 73 | /** 74 | * @brief A slave monitor, which report to the master monitor 75 | * 76 | * @tparam Progress a proto class 77 | */ 78 | template 79 | class MonitorSlaver : public Customer { 80 | public: 81 | MonitorSlaver(const NodeID& master, int id = NextCustomerID()) 82 | : Customer(id), master_(master) { } 83 | virtual ~MonitorSlaver() { } 84 | 85 | /** 86 | * @brief Sends a report to the master 87 | * 88 | * @param prog 89 | */ 90 | void Report(const Progress& prog) { 91 | string str; CHECK(prog.SerializeToString(&str)); 92 | Task report; report.set_msg(str); 93 | Submit(report, master_); 94 | } 95 | protected: 96 | NodeID master_; 97 | }; 98 | 99 | } // namespace PS 100 | -------------------------------------------------------------------------------- /src/system/postoffice.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "system/message.h" 4 | #include "util/threadsafe_queue.h" 5 | #include "system/manager.h" 6 | #include "system/heartbeat_info.h" 7 | namespace PS { 8 | 9 | class Postoffice { 10 | public: 11 | SINGLETON(Postoffice); 12 | ~Postoffice(); 13 | 14 | /** 15 | * @brief Starts the system 16 | */ 17 | void Run(int* argc, char***); 18 | /** 19 | * @brief Stops the system 20 | */ 21 | void Stop() { manager_.Stop(); } 22 | 23 | /** 24 | * @brief Queue a message into the sending buffer, which will be sent by the 25 | * sending thread. It is thread safe. 26 | * 27 | * @param msg it will be DELETE by system after sent successfully. 
so do NOT 28 | * delete it before 29 | */ 30 | void Queue(Message* msg); 31 | 32 | Manager& manager() { return manager_; } 33 | HeartbeatInfo& pm() { return perf_monitor_; } 34 | 35 | private: 36 | Postoffice(); 37 | void Send(); 38 | void Recv(); 39 | bool Process(Message* msg); 40 | std::unique_ptr recv_thread_; 41 | std::unique_ptr send_thread_; 42 | ThreadsafeQueue sending_queue_; 43 | 44 | Manager manager_; 45 | HeartbeatInfo perf_monitor_; 46 | 47 | // key: , value: messages that will be packed 48 | std::map, std::vector> pack_; 49 | std::mutex pack_mu_; 50 | 51 | DISALLOW_COPY_AND_ASSIGN(Postoffice); 52 | }; 53 | 54 | } // namespace PS 55 | -------------------------------------------------------------------------------- /src/system/proto/heartbeat.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | 3 | message HeartbeatReport { 4 | optional int32 task_id = 1 [default = 0]; 5 | optional string hostname = 14; 6 | 7 | // timestamp: 8 | // the latest heartbeat report the scheduler has received 9 | // from a specified worker/server 10 | optional uint32 seconds_since_epoch = 2; 11 | 12 | optional uint32 total_time_milli = 13; 13 | optional uint32 busy_time_milli = 3; 14 | 15 | // received/sent bytes via ZMQ 16 | optional uint32 net_in_mb = 4; 17 | optional uint32 net_out_mb = 5; 18 | 19 | // user+sys (percentage) 20 | optional uint32 process_cpu_usage = 6; 21 | optional uint32 host_cpu_usage = 7; 22 | 23 | optional uint32 process_rss_mb = 8; 24 | optional uint32 process_virt_mb = 9; 25 | optional uint32 host_in_use_gb = 10; 26 | optional uint32 host_in_use_percentage = 15; 27 | 28 | // host's network in/out bandwidth usage (MB/s) 29 | optional uint32 host_net_in_bw = 11; 30 | optional uint32 host_net_out_bw = 12; 31 | } 32 | -------------------------------------------------------------------------------- /src/system/proto/node.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "util/proto/range.proto"; 3 | 4 | message Node { 5 | enum Role { 6 | SERVER = 0; 7 | WORKER = 1; 8 | SCHEDULER = 3; // each running application has a single scheduler 9 | GROUP = 4; // a virtual node, representing a group of nodes 10 | UNUSED = 5; // a backup node, which could turn into another node 11 | } 12 | 13 | required Role role = 1; 14 | optional string id = 2; 15 | optional int32 rank = 5; 16 | // network address 17 | optional string hostname = 3; 18 | optional int32 port = 4; 19 | 20 | optional PbRange key = 6; 21 | } 22 | -------------------------------------------------------------------------------- /src/system/proto/task.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "util/proto/range.proto"; 3 | import "data/proto/data.proto"; 4 | import "system/proto/node.proto"; 5 | import "parameter/proto/param.proto"; 6 | import "filter/proto/filter.proto"; 7 | import "learner/proto/sgd.proto"; 8 | import "learner/proto/bcd.proto"; 9 | 10 | message Task { 11 | // true: system control task, typically *ctrl* should be set 12 | // false: a task for a customer, and *customer_id* should be set 13 | optional bool control = 1 [default = false]; 14 | // true: a request task 15 | // false: the response task to the request task with the same *time* 16 | optional bool request = 2 [default = false]; 17 | // the unique id of a customer 18 | optional int32 customer_id = 3; 19 | 20 | // the timestamp of this task 21 | optional int32 time = 5; 22 | // the depended
tasks of this one. that is, this task is executed only if all 23 | // tasks from the same node with time contained in *wait_time* are finished. 24 | // only valid if *request*=true 25 | repeated int32 wait_time = 6; 26 | 27 | // the key range of this task 28 | optional PbRange key_range = 7; 29 | // namespace of keys 30 | optional int32 key_channel = 8; 31 | // true: the message sent with this task will contain a list of keys 32 | optional bool has_key = 9 [default = false]; 33 | 34 | // data type 35 | optional DataType key_type = 13; 36 | repeated DataType value_type = 14; 37 | 38 | // filters applied to the data 39 | repeated FilterConfig filter = 12; 40 | 41 | // if true, the tasks will be packed into a single one during sending 42 | optional bool more = 16 [default = false]; 43 | repeated Task task = 15; 44 | 45 | // the place to store a small amount of data 46 | optional bytes msg = 17; 47 | 48 | // system control signals 49 | optional Control ctrl = 18; 50 | 51 | // parameters 52 | optional ParamCall param = 20; 53 | optional SGDCall sgd = 21; 54 | optional BCDCall bcd = 22; 55 | 56 | extensions 100 to 199; 57 | } 58 | 59 | message Control { 60 | enum Command { 61 | // a node => the scheduler 62 | REQUEST_APP = 1; 63 | REGISTER_NODE = 2; 64 | REPORT_PERF = 3; 65 | READY_TO_EXIT = 4; 66 | 67 | // the scheduler => a node 68 | ADD_NODE = 10; 69 | UPDATE_NODE = 11; 70 | REPLACE_NODE = 12; 71 | REMOVE_NODE = 13; 72 | EXIT = 14; 73 | } 74 | required Command cmd = 1; 75 | repeated Node node = 2; 76 | } 77 | 78 | enum DataType { 79 | OTHER = 0; 80 | INT8 = 1; 81 | INT16 = 2; 82 | INT32 = 3; 83 | INT64 = 4; 84 | UINT8 = 5; 85 | UINT16 = 6; 86 | UINT32 = 7; 87 | UINT64 = 8; 88 | FLOAT = 9; 89 | DOUBLE = 10; 90 | CHAR = 11; 91 | } 92 | -------------------------------------------------------------------------------- /src/system/remote_node.cc: -------------------------------------------------------------------------------- 1 | #include "system/remote_node.h" 2 | #include "system/customer.h" 3 | #include "util/crc32c.h" 4 | #include "util/shared_array_inl.h" 5 | namespace PS { 6 | 7 | Filter* RemoteNode::FindFilterOrCreate(const FilterConfig& conf) { 8 | int id = conf.type(); 9 | auto it = filters.find(id); 10 | if (it == filters.end()) { 11 | filters[id] = Filter::create(conf); 12 | it = filters.find(id); 13 | } 14 | return it->second; 15 | } 16 | 17 | void RemoteNode::EncodeMessage(Message* msg) { 18 | const auto& tk = msg->task; 19 | for (int i = 0; i < tk.filter_size(); ++i) { 20 | FindFilterOrCreate(tk.filter(i))->encode(msg); 21 | } 22 | } 23 | void RemoteNode::DecodeMessage(Message* msg) { 24 | const auto& tk = msg->task; 25 | // a reverse order comparing to encode 26 | for (int i = tk.filter_size()-1; i >= 0; --i) { 27 | FindFilterOrCreate(tk.filter(i))->decode(msg); 28 | } 29 | } 30 | 31 | void RemoteNode::AddGroupNode(RemoteNode* rnode) { 32 | CHECK_NOTNULL(rnode); 33 | // insert s into sub_nodes such as sub_nodes is still ordered 34 | int pos = 0; 35 | Range kr(rnode->node.key()); 36 | while (pos < group.size()) { 37 | if (kr.InLeft(Range(group[pos]->node.key()))) { 38 | break; 39 | } 40 | ++ pos; 41 | } 42 | group.insert(group.begin() + pos, rnode); 43 | keys.insert(keys.begin() + pos, kr); 44 | } 45 | 46 | void RemoteNode::RemoveGroupNode(RemoteNode* rnode) { 47 | size_t n = group.size(); 48 | CHECK_EQ(n, keys.size()); 49 | for (int i = 0; i < n; ++i) { 50 | if (group[i] == rnode) { 51 | group.erase(group.begin() + i); 52 | keys.erase(keys.begin() + i); 53 | return; 54 | } 55 | } 56 
| } 57 | 58 | } // namespace PS 59 | -------------------------------------------------------------------------------- /src/system/remote_node.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "system/proto/task.pb.h" 4 | #include "system/van.h" 5 | #include "system/postoffice.h" 6 | #include "filter/filter.h" 7 | namespace PS { 8 | 9 | // The representation of a remote node used by the Executor. It is not thread 10 | // safe; do not use it directly. 11 | 12 | // Track a request by its timestamp. 13 | class RequestTracker { 14 | public: 15 | RequestTracker() { } 16 | ~RequestTracker() { } 17 | 18 | // Returns true if timestamp "ts" is marked as finished. 19 | bool IsFinished(int ts) { 20 | return ts < 0 || (((int)data_.size() > ts) && data_[ts]); 21 | } 22 | 23 | // Mark timestamp "ts" as finished. 24 | void Finish(int ts) { 25 | CHECK_GE(ts, 0); 26 | CHECK_LT(ts, 1000000); 27 | if ((int)data_.size() <= ts) data_.resize(ts*2+5); 28 | data_[ts] = true; 29 | } 30 | private: 31 | std::vector data_; 32 | }; 33 | 34 | // A remote node 35 | struct RemoteNode { 36 | public: 37 | RemoteNode() { } 38 | ~RemoteNode() { 39 | for (auto f : filters) delete f.second; 40 | } 41 | 42 | void EncodeMessage(Message* msg); 43 | void DecodeMessage(Message* msg); 44 | 45 | Node node; // the remote node 46 | bool alive = true; // aliveness 47 | 48 | // timestamp tracker 49 | RequestTracker sent_req_tracker; 50 | RequestTracker recv_req_tracker; 51 | 52 | // node group info. if "node" is a node group, then "group" contains pointers to 53 | // all nodes in this group; otherwise, "group" contains only "this" 54 | void AddGroupNode(RemoteNode* rnode); 55 | void RemoveGroupNode(RemoteNode* rnode); 56 | std::vector group; 57 | // keys[i] is the key range of group[i] 58 | std::vector> keys; 59 | 60 | private: 61 | Filter* FindFilterOrCreate(const FilterConfig& conf); 62 | // key: filter_type 63 | std::unordered_map filters; 64 | 65 | }; 66 | 67 | 68 | } // namespace PS 69 | -------------------------------------------------------------------------------- /src/system/task_tracker.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "glog/logging.h" 7 | 8 | namespace PS { 9 | 10 | // tracks a task by its timestamp; thread safe 11 | class TaskTracker { 12 | public: 13 | TaskTracker() { } 14 | ~TaskTracker() { } 15 | 16 | // wait until task k has been finished 17 | void wait(int k) { 18 | ULock l(mu_); 19 | cond_.wait(l, [this, k]{ return (task_.count(k) && task_[k] == true); }); 20 | } 21 | 22 | // non-blocking wait 23 | bool tryWait(int k) { 24 | Lock l(mu_); 25 | return (task_.count(k) && task_[k] == true); 26 | } 27 | 28 | // start task k, do nothing if it has been started or finished 29 | void start(int k) { 30 | Lock l(mu_); 31 | if (!task_.count(k)) task_[k] = false; 32 | } 33 | 34 | // whether or not task k has been finished 35 | bool hasFinished(int k) { 36 | Lock l(mu_); 37 | return (task_.count(k) && task_[k] == true); 38 | } 39 | 40 | // finish task k; no warning if it has not been started or has already finished 41 | void finish(int k) { 42 | Lock l(mu_); 43 | task_[k] = true; 44 | cond_.notify_all(); 45 | } 46 | 47 | private: 48 | typedef std::lock_guard Lock; 49 | typedef std::unique_lock ULock; 50 | 51 | std::mutex mu_; 52 | std::condition_variable cond_; 53 | 54 | std::map task_; 55 | }; 56 | 57 | } // namespace PS 58 |
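A minimal usage sketch of the TaskTracker above (an editorial illustration, not a file from the repository; the thread and variable names are invented):

    TaskTracker tracker;
    tracker.start(7);                  // register task 7; no-op if already known
    std::thread producer([&tracker]() {
      // ... perform the work associated with timestamp 7 ...
      tracker.finish(7);               // mark it done and wake every waiter
    });
    tracker.wait(7);                   // blocks until finish(7) has been called
    producer.join();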
-------------------------------------------------------------------------------- /src/system/van.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | #include "system/proto/node.pb.h" 4 | #include "system/message.h" 5 | namespace PS { 6 | 7 | /** 8 | * @brief Van sends (receives) packages to (from) a node The current 9 | * implementation uses ZeroMQ 10 | * 11 | */ 12 | class Van { 13 | public: 14 | Van() { } 15 | ~Van(); 16 | 17 | void Init(); 18 | 19 | void Disconnect(const Node& node); 20 | bool Connect(const Node& node); 21 | 22 | bool Send(Message* msg, size_t* send_bytes); 23 | bool Recv(Message* msg, size_t* recv_bytes); 24 | 25 | static Node ParseNode(const string& node_str); 26 | 27 | Node& my_node() { return my_node_; } 28 | Node& scheduler() { return scheduler_; }; 29 | private: 30 | // bind to my port 31 | void Bind(); 32 | 33 | static void FreeData(void *data, void *hint) { 34 | if (hint == NULL) { 35 | delete [] (char*)data; 36 | } else { 37 | delete (SArray*)hint; 38 | } 39 | } 40 | 41 | bool IsScheduler() { return my_node_.role() == Node::SCHEDULER; } 42 | // for scheduler: monitor the liveness of all other nodes 43 | // for other nodes: monitor the liveness of the scheduler 44 | void Monitor(); 45 | 46 | void *context_ = nullptr; 47 | void *receiver_ = nullptr; 48 | Node my_node_; 49 | Node scheduler_; 50 | std::unordered_map senders_; 51 | 52 | DISALLOW_COPY_AND_ASSIGN(Van); 53 | 54 | // TODO move to postoffice::perf_monitor_ 55 | // print statistic info 56 | void Statistic(); 57 | std::unordered_map hostnames_; 58 | size_t sent_to_local_ = 0; 59 | size_t sent_to_others_ = 0; 60 | size_t received_from_local_ = 0; 61 | size_t received_from_others_ = 0; 62 | 63 | // for monitor 64 | std::unordered_map fd_to_nodeid_; 65 | std::mutex fd_to_nodeid_mu_; 66 | std::thread* monitor_thread_; 67 | 68 | // debug performance 69 | // double send_time_ = 0; 70 | // double recv_time_ = 0; 71 | // int num_call_ = 0; 72 | }; 73 | 74 | } // namespace PS 75 | -------------------------------------------------------------------------------- /src/test/aggregation_ps.cc: -------------------------------------------------------------------------------- 1 | #include "ps.h" 2 | namespace PS { 3 | 4 | DEFINE_int32(n, 100, "# of aggregation"); 5 | DEFINE_int32(interval, 100000, "time (usec) between two aggregation"); 6 | 7 | class Server : public App { 8 | public: 9 | virtual void Run() { 10 | WaitWorkersReady(); 11 | for (int i = 0; i < FLAGS_n; ++i) { 12 | WaitReceivedRequest(i, kWorkerGroup); 13 | LL << MyNodeID() << " " << i; 14 | } 15 | LL << MyNodeID() << " done"; 16 | } 17 | }; 18 | 19 | class Worker : public App { 20 | public: 21 | virtual void Run() { 22 | for (int i = 0; i < FLAGS_n; ++i) { 23 | int ts = Submit(Task(), kServerGroup); 24 | usleep(FLAGS_interval); 25 | Wait(ts); 26 | LL << MyNodeID() << " " << i; 27 | } 28 | LL << MyNodeID() << " done"; 29 | } 30 | }; 31 | 32 | App* App::Create(const std::string& conf) { 33 | if (IsWorker()) return new Worker(); 34 | if (IsServer()) return new Server(); 35 | return new App(); 36 | } 37 | 38 | } // namespace PS 39 | 40 | int main(int argc, char *argv[]) { 41 | return PS::RunSystem(argc, argv); 42 | } 43 | -------------------------------------------------------------------------------- /src/test/assign_op_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "util/assign_op.h" 3 | using 
namespace PS; 4 | // evaluation the performance of assignop comparing to the plain version 5 | 6 | size_t n = 1000000000; 7 | 8 | TEST(AssignOp, OpPlus) { 9 | double a = 0; 10 | double b = 1; 11 | for (int i = 0; i < n; ++i) { 12 | AssignOp(a, b, AssignOpType::PLUS); 13 | } 14 | EXPECT_EQ(a,(double)n); 15 | } 16 | 17 | TEST(AssignOp, PlusPlain) { 18 | double a = 0; 19 | double b = 1; 20 | for (int i = 0; i < n; ++i) { 21 | a += b; 22 | } 23 | EXPECT_EQ(a,(double)n); 24 | } 25 | 26 | TEST(AssignOp, OpSet) { 27 | double a = 0; 28 | double b = 1; 29 | for (int i = 0; i < n; ++i) { 30 | AssignOp(a, b, AssignOpType::ASSIGN); 31 | } 32 | EXPECT_EQ(a,b); 33 | } 34 | 35 | TEST(AssignOp, SetPlain) { 36 | double a = 0; 37 | double b = 1; 38 | for (int i = 0; i < n; ++i) { 39 | a = b; 40 | } 41 | EXPECT_EQ(a,b); 42 | } 43 | -------------------------------------------------------------------------------- /src/test/bloom_filter_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "util/bloom_filter.h" 3 | #include "util/block_bloom_filter.h" 4 | #include "util/shared_array_inl.h" 5 | 6 | DEFINE_int32(m, 100, ""); 7 | DEFINE_int32(k, 2, ""); 8 | TEST(BloomFilter, Speed) { 9 | 10 | using namespace PS; 11 | // see src/test/prepare_test_data to get the data 12 | SArray key1; key1.readFromFile("../data/test/key.1"); 13 | SArray key2; key2.readFromFile("../data/test/key.3"); 14 | BloomFilter bloom(FLAGS_m, FLAGS_k); 15 | auto tv = tic(); 16 | for (auto k : key1) bloom.insert(k); 17 | LL << key1.size() / toc(tv) << " insert per sec"; 18 | 19 | tv = tic(); 20 | int res = 0; 21 | for (auto k : key2) res += bloom[k]; 22 | LL << key2.size() / toc(tv) << " query per sec"; 23 | 24 | LL << "FPR: " << (double) res / (double) key2.size(); 25 | 26 | 27 | BlockBloomFilter blk_bloom(FLAGS_m, FLAGS_k); 28 | 29 | tv = tic(); 30 | for (auto k : key1) blk_bloom.insert(k); 31 | LL << key1.size() / toc(tv) << " insert per sec"; 32 | 33 | tv = tic(); 34 | res = 0; 35 | for (auto k : key2) res += blk_bloom[k]; 36 | LL << key2.size() / toc(tv) << " query per sec"; 37 | 38 | LL << "FPR: " << (double) res / (double) key2.size(); 39 | } 40 | 41 | int main(int argc, char **argv) { 42 | FLAGS_logtostderr = 1; 43 | testing::InitGoogleTest(&argc, argv); 44 | google::ParseCommandLineFlags(&argc, &argv, true); 45 | 46 | return RUN_ALL_TESTS(); 47 | } 48 | -------------------------------------------------------------------------------- /src/test/build.mk: -------------------------------------------------------------------------------- 1 | test: build/hello_ps \ 2 | build/aggregation_ps \ 3 | build/network_perf_ps \ 4 | build/kv_vector_ps \ 5 | build/kv_vector_buffer_ps \ 6 | build/kv_map_ps \ 7 | build/kv_map_perf_ps \ 8 | build/kv_layer_ps \ 9 | build/kv_layer_perf_ps \ 10 | build/assign_op_test \ 11 | build/parallel_ordered_match_test \ 12 | build/common_test 13 | 14 | build/%_ps: src/test/%_ps.cc $(PS_LIB) 15 | $(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@ 16 | 17 | # google test 18 | TESTFLAGS = $(TEST_MAIN) -lgtest $(LDFLAGS) 19 | 20 | build/parallel_ordered_match_test: build/util/file.o build/util/proto/*.o build/data/proto/*.pb.o 21 | 22 | build/%_test: build/test/%_test.o 23 | $(CC) $(CFLAGS) $(filter %.o %.a %.cc, $^) $(TESTFLAGS) -o $@ 24 | 25 | # build/reassign_server_key_range: src/test/reassign_server_key_range.cc $(PS_LIB) 26 | # $(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@ 27 | 28 | # build/fixing_float_test: src/test/fixing_float_test.cc src/filter/fixing_float.h 
$(PS_LIB) 29 | # $(CC) $(CFLAGS) $< $(PS_LIB) $(TESTFLAGS) -o $@ 30 | -------------------------------------------------------------------------------- /src/test/common_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "util/common.h" 3 | #include "util/resource_usage.h" 4 | #include "util/threadsafe_queue.h" 5 | using namespace PS; 6 | 7 | std::shared_ptr p(new int()); 8 | inline void fun1(std::shared_ptr a) { 9 | *a += 1; 10 | } 11 | 12 | inline void fun2(std::shared_ptr& a) { 13 | *a += 1; 14 | } 15 | 16 | void run(int t) { 17 | int n = 100000; 18 | if (t == 1){ 19 | for (int i = 0; i < n; ++i) fun1(p); 20 | } else { 21 | for (int i = 0; i < n; ++i) fun2(p); 22 | } 23 | } 24 | 25 | TEST(xx, xx) { 26 | *p = 1; 27 | auto tv = tic(); 28 | std::thread p1(run, 1); 29 | std::thread p2(run, 1); 30 | p1.join(); 31 | p2.join(); 32 | LL << toc(tv) << " " << *p; 33 | 34 | *p = 1; 35 | tv = tic(); 36 | std::thread p3(run, 2); 37 | std::thread p4(run, 2); 38 | p3.join(); 39 | p4.join(); 40 | LL << toc(tv) << " " << *p; 41 | 42 | } 43 | 44 | 45 | TEST(xx,bb) { 46 | ThreadsafeQueue> queue; 47 | 48 | std::unique_ptr a(new int()); 49 | *a = 1; 50 | LL << a.get(); 51 | queue.push(std::move(a)); 52 | LL << a.get(); 53 | 54 | std::unique_ptr b; 55 | queue.wait_and_pop(b); 56 | LL << b.get(); 57 | } 58 | -------------------------------------------------------------------------------- /src/test/countmin_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "util/countmin.h" 3 | using namespace PS; 4 | 5 | TEST(xx, xx) { 6 | std::shared_ptr p(new int()); 7 | 8 | 9 | } 10 | 11 | // class CountMinTest : public ::testing::Test { 12 | // protected: 13 | // virtual void SetUp() { 14 | // int m = 1000; 15 | 16 | // key.resize(n); 17 | // cnt.resize(n); 18 | // for (int i = 0; i < n; ++i) { 19 | // uint64 a = (uint64) rand(); 20 | // uint64 b = (uint64) rand(); 21 | // key[i] = a | (b << 32); 22 | // cnt[i] = rand() % m; 23 | // } 24 | 25 | // auto tv = tic(); 26 | // for (int i = 0; i < n; ++i) map[key[i]] += cnt[i]; 27 | // LL << toc(tv); 28 | // } 29 | 30 | // void test(int len, int k) { 31 | // CountMin cm; 32 | // cm.resize(len, k); 33 | // auto tv = tic(); 34 | // cm.bulkInsert(key, cnt); 35 | // auto t = toc(tv); 36 | 37 | // double err = 0, tol = 0; 38 | // for (int i = 0; i < n; ++i) { 39 | // int a = cm.query(key[i]); 40 | // int b = map[key[i]]; 41 | // EXPECT_GE(a, b); 42 | // err += (a-b); 43 | // tol += b; 44 | // } 45 | // LL << t << " " << err / tol; 46 | // } 47 | // int n = 1000000; 48 | // SArray key; 49 | // SArray cnt; 50 | // std::unordered_map map; 51 | // }; 52 | 53 | // TEST_F(CountMinTest, test) { 54 | 55 | // for (int k = 1; k < 10; ++k) { 56 | // test(n*.2, k); 57 | // test(n, k); 58 | // test(n*2, k); 59 | // test(n*5, k); 60 | // test(n*10, k); 61 | // } 62 | // } 63 | -------------------------------------------------------------------------------- /src/test/fixing_float_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "filter/fixing_float.h" 3 | 4 | using namespace PS; 5 | 6 | TEST(FIXING_FLOAT, EncodeDecode) { 7 | MessagePtr msg(new Message()); 8 | auto filter_conf = msg->add_filter(FilterConfig::FIXING_FLOAT); 9 | auto conf = filter_conf->add_fixed_point(); 10 | conf->set_min_value(-90); 11 | conf->set_max_value(90); 12 | conf->set_num_bytes(3); 
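// Editor's note (illustrative arithmetic, not original code): with min_value = -90,
// max_value = 90 and num_bytes = 3, each float is presumably quantized into 2^24
// levels, a step of roughly (90 - (-90)) / 2^24 ~= 1.1e-5, so the out-of-range
// inputs below (100.0 and -100.0) should come back saturated to the [-90, 90] interval.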
13 | 14 | conf = filter_conf->add_fixed_point(); 15 | conf->set_num_bytes(3); 16 | 17 | SArray ax = {100.0, .1, -100.0}; msg->addValue(ax); 18 | SArray bx = {100.0, .1, -100.0}; msg->addValue(bx); 19 | 20 | FixingFloatFilter filter; 21 | filter.encode(msg); 22 | filter.decode(msg); 23 | 24 | LL << SArray(msg->value[0]); 25 | LL << SArray(msg->value[1]); 26 | } 27 | 28 | 29 | TEST(FIXING_FLOAT, Error) { 30 | 31 | } 32 | 33 | // TEST(FIXING_FLOAT, BoolRand) { 34 | // int n = 100000; 35 | // SArray res(n); 36 | // int seed = time(NULL); 37 | // for (int i = 0; i < n; ++i) { 38 | // res[i] = FixingFloatFilter::boolrand(&seed); 39 | // } 40 | // LL << res.mean() << " " << res.std(); 41 | // } 42 | -------------------------------------------------------------------------------- /src/test/hello_ps.cc: -------------------------------------------------------------------------------- 1 | #include "ps.h" 2 | namespace PS { 3 | 4 | class Server : public App { 5 | public: 6 | virtual void ProcessRequest(Message* req) { 7 | std::cout << MyNodeID() << ": processing request " << req->task.time() << 8 | " from " << req->sender << std::endl; 9 | } 10 | }; 11 | 12 | class Worker : public App { 13 | public: 14 | virtual void ProcessResponse(Message* res) { 15 | std::cout << MyNodeID() << ": received response " << res->task.time() << 16 | " from " << res->sender << std::endl; 17 | } 18 | 19 | virtual void Run() { 20 | int ts = Submit(Task(), kServerGroup); 21 | Wait(ts); 22 | 23 | ts = Submit(Task(), kServerGroup); 24 | Wait(ts); 25 | 26 | Message req; 27 | req.recver = kServerGroup; 28 | req.callback = [this]() { 29 | std::cout << MyNodeID() << ": request " << LastResponse()->task.time() << 30 | " is finished" << std::endl; 31 | }; 32 | Wait(Submit(&req)); 33 | } 34 | }; 35 | 36 | App* App::Create(const std::string& conf) { 37 | if (IsWorker()) return new Worker(); 38 | if (IsServer()) return new Server(); 39 | return new App(); 40 | } 41 | 42 | } // namespace PS 43 | 44 | int main(int argc, char *argv[]) { 45 | return PS::RunSystem(argc, argv); 46 | } 47 | -------------------------------------------------------------------------------- /src/test/kv_layer_perf_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Performance test of KVLayer 3 | */ 4 | #include "ps.h" 5 | #include "parameter/kv_layer.h" 6 | #include "util/resource_usage.h" 7 | namespace PS { 8 | typedef int V; // value type 9 | 10 | class Updater { 11 | public: 12 | void Init(int id, size_t size, V* data) { 13 | memset(data, 0, sizeof(V)*size); 14 | } 15 | 16 | void Update(int id, size_t size, const V* recv_data, V* data) { 17 | // sum 18 | for (int i = 0; i < size; ++i) { 19 | data[i] += recv_data[i]; 20 | } 21 | } 22 | }; 23 | 24 | class Server : public App { 25 | public: 26 | Server() { 27 | model_.set_updater(&updt_); 28 | } 29 | private: 30 | KVLayer model_; 31 | Updater updt_; 32 | }; 33 | 34 | class Worker : public App { 35 | public: 36 | virtual void Run() { 37 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl; 38 | 39 | // alexnet 40 | SArray layer_size = 41 | {11*11*96, 5*5*256, 3*3*284, 3*3*256, 43264*4096, 4096*4096, 4096*1000}; 42 | int n = layer_size.size(); 43 | 44 | std::vector> layers(n); 45 | for (int i = 0; i < n; ++i) layers[i].resize(layer_size[i]); 46 | 47 | auto tv = tic(); 48 | std::vector pull_time(n); 49 | for (int i = 0; i < n; ++i) { 50 | auto& val = layers[i]; 51 | val.SetValue(1); 52 | int ts = model_.Push( 53 | Parameter::Request(i), 
val.data(), val.size()); 54 | pull_time[i] = model_.Pull( 55 | Parameter::Request(i, -1, {ts}), val.data(), val.size()); 56 | } 57 | 58 | for (int i = 0; i < n; ++i) { 59 | model_.Wait(pull_time[i]); 60 | const auto& val = model_[i]; 61 | // for (int j = 0; j < val.size(); ++j) { 62 | // CHECK_EQ(val[j], val[0]); 63 | // CHECK_LE(val[j], RankSize()); 64 | // } 65 | } 66 | LL << (double)layer_size.Sum() * sizeof(V) / toc(tv) / 1e6; 67 | } 68 | private: 69 | KVLayer model_; 70 | }; 71 | 72 | App* App::Create(const std::string& conf) { 73 | if (IsWorker()) return new Worker(); 74 | if (IsServer()) return new Server(); 75 | return new App(); 76 | } 77 | 78 | } // namespace PS 79 | 80 | int main(int argc, char *argv[]) { 81 | return PS::RunSystem(argc, argv); 82 | } 83 | -------------------------------------------------------------------------------- /src/test/kv_layer_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Simple test of KVLayer 3 | */ 4 | #include "ps.h" 5 | #include "parameter/kv_layer.h" 6 | namespace PS { 7 | typedef uint64 K; // key type 8 | typedef int V; // value type 9 | 10 | class Updater { 11 | public: 12 | void Init(int id, size_t size, V* data) { 13 | memset(data, 0, sizeof(V)*size); 14 | } 15 | 16 | void Update(int id, size_t size, const V* recv_data, V* data) { 17 | // sum 18 | for (int i = 0; i < size; ++i) { 19 | data[i] += recv_data[i]; 20 | } 21 | } 22 | }; 23 | 24 | class Server : public App { 25 | public: 26 | Server() { 27 | model_.set_updater(&updt_); 28 | } 29 | private: 30 | KVLayer model_; 31 | Updater updt_; 32 | }; 33 | 34 | class Worker : public App { 35 | public: 36 | virtual void Run() { 37 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl; 38 | 39 | std::vector layer_size = {5, 10, 4}; 40 | int n = layer_size.size(); 41 | 42 | std::vector> layers(n); 43 | for (int i = 0; i < n; ++i) layers[i].resize(layer_size[i]); 44 | 45 | std::vector pull_time(n); 46 | for (int i = 0; i < n; ++i) { 47 | auto& val = layers[i]; 48 | val.SetValue(1); 49 | int ts = model_.Push( 50 | Parameter::Request(i), val.data(), val.size()); 51 | pull_time[i] = model_.Pull( 52 | Parameter::Request(i, -1, {ts}), val.data(), val.size(), 53 | [i, this](){ 54 | LL << "layer " << i << " " << model_[i]; 55 | }); 56 | } 57 | 58 | for (int i = 0; i < n; ++i) { 59 | model_.Wait(pull_time[i]); 60 | } 61 | sleep(1); 62 | } 63 | private: 64 | KVLayer model_; 65 | }; 66 | 67 | App* App::Create(const std::string& conf) { 68 | if (IsWorker()) return new Worker(); 69 | if (IsServer()) return new Server(); 70 | return new App(); 71 | } 72 | 73 | } // namespace PS 74 | 75 | int main(int argc, char *argv[]) { 76 | return PS::RunSystem(argc, argv); 77 | } 78 | -------------------------------------------------------------------------------- /src/test/kv_map_perf_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Performance test of KVMap 3 | */ 4 | #include 5 | #include "ps.h" 6 | #include "parameter/kv_map.h" 7 | #include "parameter/kv_vector.h" 8 | #include "util/shared_array_inl.h" 9 | #include "util/resource_usage.h" 10 | namespace PS { 11 | DEFINE_int32(n, 10, "repeat n times"); 12 | 13 | typedef uint64 K; // key 14 | typedef float V; // value type 15 | 16 | struct Entry { 17 | void Get(V* data, void* state) { *data = value; } 18 | void Set(const V* data, void* state) { value += *data; } 19 | V value = 0; 20 | }; 21 | 22 | class Server : public App { 23 | 
private: 24 | KVMap vec_; 25 | }; 26 | 27 | class Worker : public App { 28 | public: 29 | virtual void Run() { 30 | std::random_device rd; 31 | std::mt19937 gen(rd()); 32 | std::uniform_int_distribution<> dis(0, 21); 33 | size_t bytes = 0; 34 | auto tv = tic(); 35 | for (int i = 0; i < FLAGS_n; ++i) { 36 | 37 | int k = dis(gen); 38 | SArray key; 39 | key.ReadFromFile("../test/keys/key_" + std::to_string(k)); 40 | SArray val(key.size()); 41 | ParamInitConfig cf; 42 | cf.set_type(ParamInitConfig::GAUSSIAN); 43 | cf.set_mean(0); 44 | cf.set_std(1); 45 | val.SetValue(cf); 46 | 47 | int ts = vec_.Push(Parameter::Request(i), key, {val}); 48 | vec_.Wait(vec_.Pull(Parameter::Request(i, ts+1, {ts}), key)); 49 | CHECK_EQ(vec_[i].value.size(), key.size()); 50 | vec_.Clear(i); 51 | bytes += key.size() * (sizeof(K) + sizeof(V)) * 2; 52 | } 53 | 54 | double thr = (double) bytes / toc(tv) / 1000000; 55 | printf("%s: %d push/pull, throughput %.3lf MB/sec\n", 56 | MyNodeID().c_str(), FLAGS_n, thr); 57 | } 58 | private: 59 | KVVector vec_; 60 | }; 61 | 62 | App* App::Create(const std::string& conf) { 63 | if (IsWorker()) return new Worker(); 64 | if (IsServer()) return new Server(); 65 | return new App(); 66 | } 67 | 68 | } // namespace PS 69 | 70 | int main(int argc, char *argv[]) { 71 | return PS::RunSystem(argc, argv); 72 | } 73 | -------------------------------------------------------------------------------- /src/test/kv_map_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Simple test of KVMap 3 | */ 4 | #include "ps.h" 5 | #include "parameter/kv_map.h" 6 | #include "parameter/kv_vector.h" 7 | namespace PS { 8 | typedef uint64 K; // key 9 | typedef float V; // value type 10 | 11 | struct Entry { 12 | void Get(V* data, void* state) { *data = value; } 13 | void Set(const V* data, void* state) { value += *data; } 14 | V value = 0; 15 | }; 16 | 17 | class Server : public App { 18 | private: 19 | KVMap vec_; 20 | }; 21 | 22 | class Worker : public App { 23 | public: 24 | virtual void Run() { 25 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl; 26 | 27 | SArray key; 28 | if (MyRank() == 0) { 29 | key = {0, 2, 4, 5}; 30 | } else { 31 | key = {0, 1, 3, 4}; 32 | } 33 | SArray val = {1, 1, 1, 1}; 34 | 35 | vec_.Wait(vec_.Push(Parameter::Request(0), key, {val})); 36 | vec_.Wait(vec_.Pull(Parameter::Request(0), key)); 37 | 38 | std::cout << MyNodeID() << ": pulled value in channel 0 " << vec_[0].value 39 | << std::endl; 40 | } 41 | private: 42 | KVVector vec_; 43 | }; 44 | 45 | App* App::Create(const std::string& conf) { 46 | if (IsWorker()) return new Worker(); 47 | if (IsServer()) return new Server(); 48 | return new App(); 49 | } 50 | 51 | } // namespace PS 52 | 53 | int main(int argc, char *argv[]) { 54 | return PS::RunSystem(argc, argv); 55 | } 56 | -------------------------------------------------------------------------------- /src/test/kv_vector_buffer_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Simple test of buffered KVVector 3 | */ 4 | #include "ps.h" 5 | #include "parameter/kv_vector.h" 6 | namespace PS { 7 | typedef uint64 K; // key 8 | typedef float V; // value type 9 | 10 | class Server : public App { 11 | public: 12 | Server() : vec_(true, 2) { 13 | // channel 4 14 | vec_[4].key = {0, 1, 3, 4, 5}; 15 | } 16 | 17 | virtual void Run() { 18 | // aggregate data received from all workers 19 | WaitWorkersReady(); 20 | int ts = 0; 21 | 
vec_.WaitReceivedRequest(ts, kWorkerGroup); 22 | auto recv = vec_.buffer(ts); 23 | vec_[4].value = recv.values[0]; 24 | vec_[4].value.vec() += recv.values[1].vec(); 25 | vec_.FinishReceivedRequest(ts+1, kWorkerGroup); 26 | } 27 | private: 28 | KVVector vec_; 29 | }; 30 | 31 | class Worker : public App { 32 | public: 33 | Worker() : vec_(false, 2) { } 34 | 35 | virtual void Run() { 36 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl; 37 | 38 | SArray key; 39 | if (MyRank() == 0) { 40 | key = {0, 2, 4, 5}; 41 | } else { 42 | key = {0, 1, 3, 4}; 43 | } 44 | // push [1 1 1 1 and [3 3 3 3 into servers 45 | // 2 2 2 2] 4 4 4 4] 46 | SArray val1 = {1, 2, 1, 2, 1, 2, 1, 2}; 47 | SArray val2 = {3, 4, 3, 4, 3, 4, 3, 4}; 48 | 49 | int ts = vec_.Push(Parameter::Request(4), key, {val1, val2}); 50 | // this pull request will depends on a virtual timestamp ts+1 which is used for 51 | // aggregation 52 | vec_.Wait(vec_.Pull(Parameter::Request(4, ts+2, {ts+1}), key)); 53 | 54 | std::cout << MyNodeID() << ": pulled value in channel 4 " << vec_[4].value 55 | << std::endl; 56 | } 57 | private: 58 | KVVector vec_; 59 | }; 60 | 61 | App* App::Create(const std::string& conf) { 62 | if (IsWorker()) return new Worker(); 63 | if (IsServer()) return new Server(); 64 | return new App(); 65 | } 66 | 67 | } // namespace PS 68 | 69 | int main(int argc, char *argv[]) { 70 | return PS::RunSystem(argc, argv); 71 | } 72 | -------------------------------------------------------------------------------- /src/test/kv_vector_perf_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file kv_vector_perf_ps.cc 3 | * @author Mu Li 4 | * @date Wed Mar 18 16:34:31 2015 5 | * 6 | * @brief Performance test of KVVector 7 | * 8 | * 9 | */ 10 | #include "ps.h" 11 | #include "parameter/kv_vector.h" 12 | namespace PS { 13 | typedef uint64 K; // key 14 | typedef float V; // value type 15 | 16 | class Server : public App { 17 | public: 18 | Server() : vec_(true, 2) { 19 | // channel 4 20 | vec_[4].key = {0, 1, 3, 4, 5}; 21 | } 22 | 23 | virtual void Run() { 24 | // aggregate data received from all workers 25 | WaitWorkersReady(); 26 | int ts = 0; 27 | vec_.WaitReceivedRequest(ts, kWorkerGroup); 28 | auto recv = vec_.buffer(ts); 29 | vec_[4].value = recv.values[0]; 30 | vec_[4].value.vec() += recv.values[1].vec(); 31 | vec_.FinishReceivedRequest(ts+1, kWorkerGroup); 32 | } 33 | private: 34 | KVVector vec_; 35 | }; 36 | 37 | class Worker : public App { 38 | public: 39 | Worker() : vec_(false, 2) { } 40 | 41 | virtual void Run() { 42 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl; 43 | 44 | SArray key; 45 | if (MyRank() == 0) { 46 | key = {0, 2, 4, 5}; 47 | } else { 48 | key = {0, 1, 3, 4}; 49 | } 50 | // push [1 1 1 1 and [3 3 3 3 into servers 51 | // 2 2 2 2] 4 4 4 4] 52 | SArray val1 = {1, 2, 1, 2, 1, 2, 1, 2}; 53 | SArray val2 = {3, 4, 3, 4, 3, 4, 3, 4}; 54 | 55 | int ts = vec_.Push(Parameter::Request(4), key, {val1, val2}); 56 | // this pull request will depends on a virtual timestamp ts+1 which is used for 57 | // aggregation 58 | vec_.Wait(vec_.Pull(Parameter::Request(4, ts+2, {ts+1}), key)); 59 | 60 | std::cout << MyNodeID() << ": pulled value in channel 4 " << vec_[4].value 61 | << std::endl; 62 | } 63 | private: 64 | KVVector vec_; 65 | }; 66 | 67 | App* App::Create(const std::string& conf) { 68 | if (IsWorker()) return new Worker(); 69 | if (IsServer()) return new Server(); 70 | return new App(); 71 | } 72 | 73 | } // namespace PS 74 
| 75 | int main(int argc, char *argv[]) { 76 | return PS::RunSystem(argc, argv); 77 | } 78 | -------------------------------------------------------------------------------- /src/test/kv_vector_ps.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file kv_vector_ps.cc 3 | * @author Mu Li 4 | * @date Wed Mar 18 16:34:09 2015 5 | * 6 | * @brief Simple test of KVVector 7 | * 8 | * 9 | */ 10 | 11 | #include "ps.h" 12 | #include "parameter/kv_vector.h" 13 | namespace PS { 14 | typedef uint64 K; // key type 15 | typedef int V; // value type 16 | 17 | class Server : public App { 18 | public: 19 | Server() : vec_() { 20 | // channel 0 21 | vec_[0].key = {0, 1, 3, 4, 5}; 22 | vec_[0].value = {1, 2, 3, 4, 5}; 23 | 24 | // channel 1 25 | vec_[1].key = {0, 1, 3, 4, 5}; 26 | vec_[1].value = {2, 3, 4, 5, 6}; 27 | } 28 | 29 | private: 30 | KVVector vec_; 31 | }; 32 | 33 | class Worker : public App { 34 | public: 35 | Worker() : vec_() { } 36 | 37 | virtual void Run() { 38 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl; 39 | 40 | SArray key; 41 | if (MyRank() == 0) { 42 | key = {0, 2, 4, 5}; 43 | } else { 44 | key = {0, 1, 3, 4}; 45 | } 46 | 47 | int ts1 = vec_.Pull(Parameter::Request(0), key); 48 | int ts2 = vec_.Pull(Parameter::Request(1), key); 49 | 50 | vec_.Wait(ts1); 51 | std::cout << MyNodeID() << ": pulled value in channel 0 " << vec_[0].value 52 | << std::endl; 53 | 54 | vec_.Wait(ts2); 55 | std::cout << MyNodeID() << ": pulled value in channel 1 " << vec_[1].value 56 | << std::endl; 57 | } 58 | private: 59 | KVVector vec_; 60 | }; 61 | 62 | App* App::Create(const std::string& conf) { 63 | if (IsWorker()) return new Worker(); 64 | if (IsServer()) return new Server(); 65 | return new App(); 66 | } 67 | 68 | } // namespace PS 69 | 70 | int main(int argc, char *argv[]) { 71 | return PS::RunSystem(argc, argv); 72 | } 73 | -------------------------------------------------------------------------------- /src/test/localizer_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "util/localizer.h" 3 | 4 | using namespace PS; 5 | 6 | TEST(Localizer, RCV1) { 7 | DataConfig cache, dc; 8 | cache.add_file("/tmp/test/"); 9 | 10 | dc.set_format(DataConfig::TEXT); 11 | dc.set_text(DataConfig::LIBSVM); 12 | dc.add_file("../data/rcv1_train.binary"); 13 | 14 | SlotReader sr(dc, cache); 15 | ExampleInfo info; 16 | sr.read(&info); 17 | 18 | Localizer lc; 19 | 20 | SArray key; 21 | SArray freq; 22 | lc.countUniqIndex(sr.index(1), &key, &freq); 23 | 24 | EXPECT_EQ(key.eigenArray().sum(), 1051859373); 25 | EXPECT_EQ(freq.eigenArray().sum(), 1498952); 26 | EXPECT_EQ(freq.eigenArray().square().sum(), 1924492682); 27 | EXPECT_EQ(key.size(), freq.size()); 28 | EXPECT_EQ(key.size(), 44504); 29 | 30 | int filter = 2; 31 | SArray f_key; 32 | for (int i = 0; i < key.size(); ++i) { 33 | if (freq[i] > filter) f_key.pushBack(key[i]); 34 | } 35 | EXPECT_EQ(f_key.size(), 19959); 36 | 37 | auto X = std::static_pointer_cast>( 38 | lc.remapIndex( 1, f_key, &sr)); 39 | 40 | EXPECT_EQ(X->offset().eigenArray().sum(), 14702421805); 41 | SArray idx; 42 | for (auto k : X->index()) idx.pushBack(k); 43 | EXPECT_EQ(idx.size(), 1467683); 44 | EXPECT_EQ(idx.eigenArray().sum(), 14708054959 - idx.size()); 45 | EXPECT_LT(X->value().eigenArray().sum(), 132224); 46 | EXPECT_GT(X->value().eigenArray().sum(), 132223); 47 | 48 | // LL << X->debugString(); 49 | } 50 | 51 | // TEST(Localizer, ADFEA) { 52 | 
// DataConfig dc; 53 | // dc.set_format(DataConfig::TEXT); 54 | // dc.set_text(DataConfig::ADFEA); 55 | // dc.add_file("../../data/ctrc/train/part-000[0-1].gz"); 56 | // auto data = readMatricesOrDie(searchFiles(dc)); 57 | 58 | // for (int i = 1; i < data.size(); ++i) { 59 | // Localizer lc(data[i]); 60 | // SArray key; 61 | // SArray freq; 62 | // lc.countUniqIndex(&key, &freq); 63 | 64 | // int filter = 4; 65 | // SArray f_key; 66 | // for (int i = 0; i < key.size(); ++i) { 67 | // if (freq[i] > filter) f_key.pushBack(key[i]); 68 | // } 69 | // LL << f_key.size(); 70 | // auto X = std::static_pointer_cast>(lc.remapIndex(f_key)); 71 | // if (X) { 72 | // LL << X->index().eigenArray().maxCoeff(); 73 | // } 74 | // // LL << X->debugString(); 75 | // } 76 | // } 77 | -------------------------------------------------------------------------------- /src/test/network_perf_ps.cc: -------------------------------------------------------------------------------- 1 | #include "ps.h" 2 | #include "util/resource_usage.h" 3 | 4 | DEFINE_int32(n, 1000, "repeat n times"); 5 | DEFINE_int32(data_size, 1000, 6 | "data in KB sent from a worker to a server"); 7 | DEFINE_bool(server_aggregation, false, 8 | "servers will aggregate the data from servs if true"); 9 | namespace PS { 10 | 11 | class Server : public App { 12 | public: 13 | virtual void Run() { 14 | // if (FLAGS_server_aggregation) { 15 | // WaitWorkersReady(); 16 | // auto data = split(FLAGS_data_size, ',',true); 17 | // for (int i = 0; i < data.size() * FLAGS_n; ++i) { 18 | // } 19 | // } 20 | } 21 | }; 22 | 23 | class Worker : public App { 24 | public: 25 | 26 | virtual void Slice(const Message& request, const std::vector>& krs, 27 | std::vector* msgs) { 28 | for (auto m : *msgs) *m = request; 29 | } 30 | 31 | virtual void Run() { 32 | int n = FLAGS_n; 33 | int m = FLAGS_data_size; 34 | auto tv = tic(); 35 | for (int j = 0; j < n; ++j) { 36 | SArray val(m*1000/sizeof(int), 1); 37 | Message msg; 38 | msg.add_value(val); 39 | msg.recver = kServerGroup; 40 | int ts = Submit(&msg); 41 | Wait(ts); 42 | } 43 | double thr = (double)m / 1000.0 * n * sys_.manager().num_servers() / toc(tv); 44 | printf("%s: packet size: %d KB, throughput %.3lf MB/sec\n", 45 | MyNodeID().c_str(), m, thr); 46 | } 47 | }; 48 | 49 | App* App::Create(const std::string& conf) { 50 | if (IsWorker()) return new Worker(); 51 | if (IsServer()) return new Server(); 52 | return new App(); 53 | } 54 | 55 | } // namespace PS 56 | 57 | int main(int argc, char *argv[]) { 58 | return PS::RunSystem(argc, argv); 59 | } 60 | -------------------------------------------------------------------------------- /src/test/parallel_ordered_match_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "util/parallel_ordered_match.h" 3 | #include "util/shared_array_inl.h" 4 | 5 | using namespace PS; 6 | namespace PS { 7 | DEFINE_int32(num_threads, 2, ""); 8 | DEFINE_int32(k, 1, "k"); 9 | } // namespace PS 10 | 11 | 12 | class PMatchTest : public ::testing::Test { 13 | protected: 14 | virtual void SetUp() { 15 | key1.ReadFromFile("../test/keys/key_1"); 16 | key2.ReadFromFile("../test/keys/key_2"); 17 | }; 18 | 19 | SArray key1, key2; 20 | }; 21 | 22 | TEST_F(PMatchTest, simple) { 23 | 24 | } 25 | 26 | TEST_F(PMatchTest, match) { 27 | 28 | int k = FLAGS_k; 29 | SArray val1(key1.size()*k, 1); 30 | SArray val2; 31 | 32 | size_t n = ParallelOrderedMatch(key1, val1, key2, &val2, k); 33 | 34 | LL << n << " " << val2; 35 | } 36 | 
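For context, a small sketch of what ParallelOrderedMatch is expected to do in the test above (semantics inferred from the call site, not an authoritative spec): both key arrays are sorted, the values attached to the source keys are copied into the slots of the matching destination keys, k values per key, and unmatched slots stay zero. With k = 1:

    SArray<uint64> src_key = {1, 3, 5};
    SArray<int>    src_val = {10, 30, 50};
    SArray<uint64> dst_key = {1, 2, 3, 5};
    SArray<int>    dst_val;
    size_t n = ParallelOrderedMatch(src_key, src_val, dst_key, &dst_val, 1);
    // expected: n == 3 and dst_val == {10, 0, 30, 50}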
-------------------------------------------------------------------------------- /src/test/reassign_server_key_range_ps.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "system/postoffice.h" 3 | #include "system/postmaster.h" 4 | #include "system/app.h" 5 | 6 | // build: make build/reassign_server_key_range 7 | // test: script/local.sh build/reassign_server_key_range 3 3 8 | namespace PS { 9 | class Root : public App { 10 | public: 11 | Root() : App() { } 12 | virtual ~Root() { } 13 | 14 | void init() { 15 | // repartition the key range 16 | Range range(0, 100); 17 | auto nodes = sys_.yp().nodes(); 18 | nodes = Postmaster::partitionServerKeyRange(nodes, range); 19 | 20 | Task task; 21 | task.set_type(Task::MANAGE); 22 | task.mutable_mng_node()->set_cmd(ManageNode::UPDATE); 23 | for (const auto& n : nodes) { 24 | *task.mutable_mng_node()->add_node() = n; 25 | } 26 | port(kLiveGroup)->submitAndWait(task); 27 | } 28 | }; 29 | 30 | class Slave : public App { 31 | public: 32 | Slave() : App() { } 33 | virtual ~Slave() { } 34 | 35 | void run() { 36 | LL << exec_.myNode().id() << " key range: " 37 | << exec_.myNode().key().ShortDebugString(); 38 | } 39 | }; 40 | 41 | App* App::create(const string& name, const string& conf) { 42 | auto my_role = Postoffice::instance().myNode().role(); 43 | if (my_role == Node::SCHEDULER) { 44 | return new Root(); 45 | } else { 46 | return new Slave(); 47 | } 48 | } 49 | 50 | } // namespace PS 51 | 52 | int main(int argc, char *argv[]) { 53 | auto& sys = PS::Postoffice::instance(); 54 | sys.start(&argc, &argv); 55 | 56 | sys.Stop(); 57 | return 0; 58 | } 59 | -------------------------------------------------------------------------------- /src/test/slot_reader_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "data/slot_reader.h" 3 | 4 | using namespace PS; 5 | 6 | TEST(SlotReader, read) { 7 | DataConfig cache, dc; 8 | cache.add_file("/tmp/test/"); 9 | dc.set_format(DataConfig::TEXT); 10 | 11 | // load adfea 12 | dc.set_text(DataConfig::ADFEA); 13 | dc.add_file("../../data/ctrc/train/part-000[0-1].gz"); 14 | 15 | // load libsvm 16 | // dc.set_text(DataConfig::LIBSVM); 17 | // dc.add_file("../data/rcv1/train/part-.*"); 18 | 19 | DataConfig dc2 = searchFiles(dc); 20 | SlotReader gr; gr.init(dc2, cache); gr.read(); 21 | 22 | auto data = readMatricesOrDie(dc2); 23 | 24 | auto label = gr.value(0); 25 | // LL << label; 26 | // LL << data[0]->value(); 27 | EXPECT_EQ((label.eigenVector() - data[0]->value().eigenVector()).norm(), 0); 28 | 29 | for (int i = 1; i < data.size(); ++i) { 30 | auto X = std::static_pointer_cast>(data[i]); 31 | int id = X->info().id(); 32 | auto index = gr.index(id); 33 | auto offset = gr.offset(id); 34 | // LL << index; 35 | // LL << offset; 36 | EXPECT_EQ((index.eigenVector() - X->index().eigenVector()).norm(), 0); 37 | EXPECT_EQ((offset.eigenVector() - X->offset().eigenVector()).norm(), 0); 38 | 39 | if (!X->value().empty()) { 40 | auto value = gr.value(id); 41 | EXPECT_EQ((value.eigenVector() - X->value().eigenVector()).norm(), 0); 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/test/stream_reader_test.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "data/stream_reader.h" 3 | 4 | using namespace PS; 5 | 6 | // TEST(StreamReader, read_proto) { 7 | 
// DataConfig dc; 8 | // // load adfea 9 | // dc.set_format(DataConfig::PROTO); 10 | // dc.add_file("../output/parsa_.*"); 11 | 12 | // DataConfig dc2 = searchFiles(dc); 13 | // StreamReader reader; reader.init(dc2); 14 | 15 | // MatrixPtrList X; 16 | // while (reader.readMatrices(10000, &X)) { 17 | // CHECK_EQ(X.size(), 2); 18 | // } 19 | // } 20 | 21 | // TEST(StreamReader, read) { 22 | // DataConfig dc; 23 | // // load adfea 24 | // dc.set_format(DataConfig::TEXT); 25 | // dc.set_text(DataConfig::ADFEA); 26 | // dc.add_file("../../data/ctrc/train/part-000[0-1].gz"); 27 | // dc.set_ignore_feature_group(true); 28 | 29 | // // load libsvm 30 | // // dc.set_text(DataConfig::LIBSVM); 31 | // // dc.add_file("../data/rcv1/train/part-.*"); 32 | 33 | // DataConfig dc2 = searchFiles(dc); 34 | // StreamReader reader; reader.init(dc2); 35 | 36 | // MatrixPtrList X; 37 | // reader.readMatrices(100, &X); 38 | // } 39 | 40 | 41 | // TEST(StreamReader, convert) { 42 | // DataConfig dc; 43 | // dc.set_format(DataConfig::TEXT); 44 | // dc.set_text(DataConfig::TERAFEA); 45 | // dc.add_file("../data/toutiao/data.txt"); 46 | // dc.set_ignore_feature_group(true); 47 | 48 | // DataConfig dc2 = searchFiles(dc); 49 | // StreamReader reader; reader.init(dc2); 50 | 51 | // MatrixPtrList X; 52 | // reader.readMatrices(1000000, &X); 53 | // // CHECK_EQ(X.size(), 2); 54 | 55 | // X[0]->writeToBinFile("toutiao.Y"); 56 | 57 | // for (int i = 1; i < X.size(); ++i) { 58 | // Localizer localizer; 59 | // SArray key; 60 | // localizer.countUniqIndex(X[i], &key); 61 | // key.writeToFile("toutiao.key."+std::to_string(X[i]->info().id())); 62 | // auto Z = localizer.remapIndex(key); 63 | // Z->writeToBinFile("toutiao.X."+std::to_string(X[i]->info().id())); 64 | // } 65 | // } 66 | -------------------------------------------------------------------------------- /src/test_main.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gflags/gflags.h" 3 | #include "glog/logging.h" 4 | 5 | int main(int argc, char **argv) { 6 | FLAGS_logtostderr = 1; 7 | testing::InitGoogleTest(&argc, argv); 8 | google::ParseCommandLineFlags(&argc, &argv, true); 9 | 10 | return RUN_ALL_TESTS(); 11 | } 12 | -------------------------------------------------------------------------------- /src/util/assign_op.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/proto/assign_op.pb.h" 3 | #include "glog/logging.h" 4 | namespace PS { 5 | // The cost of the switch is minimal. Once "op" is a constant, the compiler will 6 | // do optimization. see test/assign_op_test.cc 7 | 8 | // Returns right op= left. bascial version, works for both floast and intergers 9 | template 10 | T& AssignOp(T& right, const T& left, const AssignOpType& op) { 11 | switch (op) { 12 | case AssignOpType::ASSIGN: 13 | right = left; break; 14 | case AssignOpType::PLUS: 15 | right += left; break; 16 | case AssignOpType::MINUS: 17 | right -= left; break; 18 | case AssignOpType::TIMES: 19 | right *= left; break; 20 | case AssignOpType::DIVIDE: 21 | right /= left; break; 22 | default: 23 | CHECK(false) << "use AssignOpI.." ; 24 | } 25 | return right; 26 | } 27 | 28 | // Returns right op= left. 
for integers 29 | template 30 | T& AssignOpI(T& right, const T& left, const AssignOpType& op) { 31 | switch (op) { 32 | case AssignOpType::ASSIGN: 33 | right = left; break; 34 | case AssignOpType::PLUS: 35 | right += left; break; 36 | case AssignOpType::MINUS: 37 | right -= left; break; 38 | case AssignOpType::TIMES: 39 | right *= left; break; 40 | case AssignOpType::DIVIDE: 41 | right /= left; break; 42 | case AssignOpType::AND: 43 | right &= left; break; 44 | case AssignOpType::OR: 45 | right |= left; break; 46 | case AssignOpType::XOR: 47 | right ^= left; break; 48 | } 49 | return right; 50 | } 51 | 52 | } // namespace PS 53 | -------------------------------------------------------------------------------- /src/util/auc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/shared_array.h" 3 | // #include "util/proto/evaluation.pb.h" 4 | 5 | namespace PS { 6 | 7 | // distributed auc 8 | class AUC { 9 | public: 10 | void setGoodness(int64 goodness) { goodness_ = goodness; } 11 | 12 | // functions for the scheduler (or the root machine) 13 | // merge local results from a worker 14 | void merge(const AUCData& data) { 15 | CHECK_EQ(data.tp_key_size(), data.tp_count_size()); 16 | for (size_t i = 0; i < data.tp_key_size(); ++i) 17 | tp_count_[data.tp_key(i)] += data.tp_count(i); 18 | 19 | CHECK_EQ(data.fp_key_size(), data.fp_count_size()); 20 | for (size_t i = 0; i < data.fp_key_size(); ++i) 21 | fp_count_[data.fp_key(i)] += data.fp_count(i); 22 | 23 | // LL << tp_count_.size() << " " << fp_count_.size(); 24 | } 25 | 26 | double accuracy(double threshold = 0) { 27 | double total = 0, correct = 0; 28 | double x = threshold * goodness_; 29 | for (auto& it : tp_count_) { 30 | if (it.first >= x) correct += it.second; 31 | total += it.second; 32 | } 33 | 34 | for (auto& it : fp_count_) { 35 | if (it.first < x) correct += it.second; 36 | total += it.second; 37 | } 38 | return (correct / total); 39 | } 40 | 41 | // evaluate the auc after merging all workers' results 42 | double evaluate() { 43 | if (tp_count_.empty() || fp_count_.empty()) return 0.5; 44 | double tp_sum = 0, fp_sum = 0, auc = 0; 45 | auto tp_it = tp_count_.begin(); 46 | 47 | for (auto& fp_it : fp_count_) { 48 | auto fp_v = fp_it.second; 49 | for (; tp_it != tp_count_.end() && tp_it->second <= fp_v; ++ tp_it) 50 | tp_sum += tp_it->second; 51 | fp_sum += fp_v; 52 | auc += tp_sum * fp_v; 53 | } 54 | for (; tp_it != tp_count_.end(); ++tp_it) tp_sum += tp_it->second; 55 | 56 | // LL << tp_sum << " " << fp_sum; 57 | auc = auc / tp_sum / fp_sum; 58 | return (auc < .5 ? 
1 - auc : auc); 59 | } 60 | 61 | // clear cached results of workers 62 | void clear() { 63 | tp_count_.clear(); 64 | fp_count_.clear(); 65 | } 66 | 67 | // worker: compute results using local data 68 | template 69 | void compute(const SArray& label, const SArray& predict, AUCData* res) { 70 | CHECK_EQ(label.size(), predict.size()); 71 | CHECK_GT(label.size(), 0); 72 | 73 | clear(); 74 | for (size_t i = 0; i < predict.size(); ++i) { 75 | int64 k = (int64)(predict[i] * goodness_); 76 | if (label[i] > 0) 77 | ++ tp_count_[k]; 78 | else 79 | ++ fp_count_[k]; 80 | } 81 | 82 | 83 | // LL << goodness_ << " " << tp_count_.size() << " " << fp_count_.size(); 84 | 85 | res->clear_tp_key(); 86 | res->clear_tp_count(); 87 | for (auto& it : tp_count_) { 88 | res->add_tp_key(it.first); 89 | res->add_tp_count(it.second); 90 | } 91 | 92 | res->clear_fp_key(); 93 | res->clear_fp_count(); 94 | for (auto& it : fp_count_) { 95 | res->add_fp_key(it.first); 96 | res->add_fp_count(it.second); 97 | } 98 | } 99 | private: 100 | int64 goodness_ = 1000; 101 | std::map fp_count_, tp_count_; 102 | 103 | }; 104 | 105 | } // namespace PS 106 | -------------------------------------------------------------------------------- /src/util/barrier.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "util/macros.h" 6 | #include "glog/logging.h" 7 | 8 | namespace PS { 9 | 10 | class Barrier { 11 | public: 12 | explicit Barrier(int num_threads) 13 | : num_to_block_(num_threads), num_to_exit_(num_threads) {} 14 | 15 | // return true if this is the last thread 16 | bool Block() { 17 | std::unique_lock l(mu_); 18 | num_to_block_--; 19 | CHECK_GE(num_to_block_, 0); 20 | 21 | if (num_to_block_ > 0) { 22 | while (num_to_block_ > 0) cv_.wait(l); 23 | } else { 24 | cv_.notify_all(); 25 | } 26 | 27 | num_to_exit_--; 28 | CHECK_GE(num_to_exit_, 0); 29 | return (num_to_exit_ == 0); 30 | } 31 | 32 | private: 33 | DISALLOW_COPY_AND_ASSIGN(Barrier); 34 | std::mutex mu_; 35 | std::condition_variable cv_; 36 | int num_to_block_; 37 | int num_to_exit_; 38 | }; 39 | 40 | 41 | } // PS 42 | -------------------------------------------------------------------------------- /src/util/bitmap.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/common.h" 4 | 5 | namespace PS { 6 | class Bitmap; 7 | typedef std::shared_ptr BitmapPtr; 8 | 9 | #define BITCOUNT_(x) (((BX_(x)+(BX_(x)>>4)) & 0x0F0F0F0F) % 255) 10 | #define BX_(x) ((x) - (((x)>>1)&0x77777777) \ 11 | - (((x)>>2)&0x33333333) \ 12 | - (((x)>>3)&0x11111111)) 13 | class Bitmap { 14 | public: 15 | Bitmap() { } 16 | Bitmap(uint32 size, bool value = false) { resize(size, value); } 17 | ~Bitmap() { clear(); } 18 | 19 | void resize(uint32 size, bool value = false) { 20 | CHECK_EQ(size_, 0) 21 | << "TODO didn't support resize non-empty bitmap... 
clear() first "; 22 | size_ = size; 23 | map_size_ = (size >> kBitmapShift) + 1; 24 | map_ = new uint16[map_size_]; 25 | fill(value); 26 | } 27 | 28 | void clear() { 29 | delete [] map_; 30 | map_ = nullptr; 31 | map_size_ = 0; 32 | size_ = 0; 33 | } 34 | 35 | void set(uint32 i) { 36 | map_[i>>kBitmapShift] |= (uint16) (1 << (i&kBitmapMask)); 37 | } 38 | void clear(uint32 i) { 39 | map_[i>>kBitmapShift] &= ~((uint16) (1 << (i&kBitmapMask))); 40 | } 41 | 42 | bool test(uint32 i) const { 43 | return static_cast((map_[i>>kBitmapShift] >> (i&kBitmapMask)) & 1); 44 | } 45 | bool operator[] (uint32 i) const { 46 | return test(i); 47 | } 48 | 49 | void fill(bool value) { 50 | if (value) 51 | memset(map_, 0xFF, map_size_*sizeof(uint16)); 52 | else 53 | memset(map_, 0, map_size_*sizeof(uint16)); 54 | } 55 | 56 | // TODO flip all bits 57 | void flip() { } 58 | 59 | uint32 size() const { return size_; } 60 | size_t memSize() const { return map_size_*sizeof(uint16); } 61 | 62 | // number of bit == true 63 | uint32 nnz() { 64 | if (!init_nnz_) { 65 | for(int i=0; i<65536; i++) 66 | LUT_[i] = (unsigned char)BITCOUNT_(i); 67 | init_nnz_ = true; 68 | } 69 | 70 | uint32 bn = size_ >> kBitmapShift; 71 | uint32 v = 0; 72 | for (uint32_t i = 0; i < bn; i++) 73 | v += LUT_[map_[i]]; 74 | return v + nnz(bn << kBitmapShift, size_); 75 | } 76 | 77 | private: 78 | uint32 nnz(uint32 start, uint32 end) { 79 | CHECK_LE(end, size_); 80 | uint32 v = 0; 81 | for (uint32 i = start; i < end; ++i) 82 | v += (*this)[i]; 83 | return v; 84 | } 85 | 86 | private: 87 | uint16* map_ = nullptr; 88 | uint32 map_size_ = 0; 89 | uint32 size_ = 0; 90 | 91 | static const uint32 kBitmapShift = 4; 92 | static const uint32 kBitmapMask = 0x0F; 93 | 94 | unsigned char LUT_[65536]; 95 | bool init_nnz_ = false; 96 | 97 | }; 98 | 99 | } // namespace PS 100 | -------------------------------------------------------------------------------- /src/util/block_bloom_filter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/sketch.h" 3 | namespace PS { 4 | 5 | // a blocked version, see 6 | // Cache-, Hash- and Space-Efficient Bloom Filters, 7 | // http://algo2.iti.kit.edu/documents/cacheefficientbloomfilters-jea.pdf 8 | 9 | // 1.2x - 1.8x faster than BloomFilter, but may give slightly large FPR 10 | template 11 | class BlockBloomFilter : public Sketch { 12 | public: 13 | BlockBloomFilter() { } 14 | BlockBloomFilter(int m, int k) { resize(m, k); } 15 | ~BlockBloomFilter() { delete [] data_; } 16 | void resize(int m, int k) { 17 | m = std::max(m, 1024); 18 | num_bin_ = (m / 8 / bin_size_) + 1; 19 | data_size_ = num_bin_ * bin_size_; 20 | if (m > m_) { 21 | delete [] data_; 22 | data_ = new char[data_size_]; 23 | // CHECK_EQ(posix_memalign((void**)&data_, bin_size_*8, data_size_), 0); 24 | } 25 | k_ = std::min(64, std::max(1, k)); 26 | m_ = m; 27 | reset(); 28 | } 29 | 30 | void reset() { 31 | memset(data_, 0, data_size_ * sizeof(char)); 32 | } 33 | 34 | // make the api be similar to std::set 35 | bool count(K key) const { return query(key); } 36 | bool operator[] (K key) const { return query(key); } 37 | bool query(K key) const { 38 | // auto h = crc32(key); 39 | auto h = hash(key); 40 | auto delta = (h >> 17) | (h << 15); // Rotate right 17 bits 41 | char* data = data_ + (h % num_bin_) * bin_size_; 42 | for (int j = 0; j < k_; ++j) { 43 | uint32 bitpos = h % (bin_size_ * 8); 44 | if ((data[bitpos/8] & (1 << (bitpos % 8))) == 0) return false; 45 | h += delta; 46 | } 47 | return true; 
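// (Note: all k_ probes above fall inside the single 64-byte bin selected by
//  "h % num_bin_", and each probe position is derived from one hash of the key
//  via the "h += delta" double-hashing step -- the same scheme insert() uses
//  below. Keeping the probes in one cache line is what makes this blocked
//  variant faster than the plain BloomFilter.)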
48 | } 49 | 50 | void insert(K key) { 51 | // auto h = crc32(key); 52 | auto h = hash(key); 53 | auto delta = (h >> 17) | (h << 15); // Rotate right 17 bits 54 | char* data = data_ + (h % num_bin_) * bin_size_; 55 | for (int j = 0; j < k_; ++j) { 56 | uint32 bitpos = h % (bin_size_ * 8); 57 | data[bitpos/8] |= (1 << (bitpos % 8)); 58 | h += delta; 59 | } 60 | } 61 | 62 | private: 63 | char* data_ = NULL; 64 | int data_size_ = 0; 65 | uint32 m_ = 0; 66 | int k_ = 0; 67 | const uint32 bin_size_ = 64; // cache line size 68 | uint32 num_bin_ = 0; 69 | }; 70 | 71 | } 72 | -------------------------------------------------------------------------------- /src/util/bloom_filter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/sketch.h" 3 | namespace PS { 4 | 5 | template 6 | class BloomFilter : public Sketch { 7 | public: 8 | BloomFilter() { } 9 | BloomFilter(int m, int k) { resize(m, k); } 10 | ~BloomFilter() { delete [] data_; } 11 | void resize(int m, int k) { 12 | delete [] data_; 13 | k_ = std::min(64, std::max(1, k)); 14 | m_ = m; 15 | data_size_ = (m / 8) + 1; 16 | data_ = new char[data_size_]; 17 | memset(data_, 0, data_size_ * sizeof(char)); 18 | } 19 | 20 | bool operator[] (K key) const { return query(key); } 21 | bool query(K key) const { 22 | uint32 h = hash(key); 23 | const uint32 delta = (h >> 17) | (h << 15); // Rotate right 17 bits 24 | for (int j = 0; j < k_; ++j) { 25 | uint32 bitpos = h % m_; 26 | if ((data_[bitpos/8] & (1 << (bitpos % 8))) == 0) return false; 27 | h += delta; 28 | } 29 | return true; 30 | } 31 | 32 | void insert(K key) { 33 | uint32 h = hash(key); 34 | const uint32 delta = (h >> 17) | (h << 15); // Rotate right 17 bits 35 | for (int j = 0; j < k_; ++j) { 36 | uint32 bitpos = h % m_; 37 | data_[bitpos/8] |= (1 << (bitpos % 8)); 38 | h += delta; 39 | } 40 | } 41 | 42 | private: 43 | char* data_ = NULL; 44 | int data_size_ = 0; 45 | uint32 m_ = 0; 46 | int k_ = 0; 47 | }; 48 | } // namespace PS 49 | -------------------------------------------------------------------------------- /src/util/common.h: -------------------------------------------------------------------------------- 1 | // some utility functions 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // concurrency 11 | #include 12 | #include 13 | #include 14 | // smart pointers 15 | #include 16 | // stream 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | // containers 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include 33 | 34 | 35 | // google staff 36 | #include "gflags/gflags.h" 37 | #include "glog/logging.h" 38 | 39 | // util 40 | #include "util/macros.h" 41 | #include "util/integral_types.h" 42 | #include "util/resource_usage.h" 43 | 44 | // base 45 | #include 46 | #include "google/protobuf/text_format.h" 47 | 48 | //const int MAX_NUM_LEN = 1000; 49 | 50 | namespace PS { 51 | 52 | // uint64 is the default key size. We can change it into uint32 to reduce the 53 | // spaces for storing the keys. 
Howerver, if we want a larger key size, say 54 | // uint128, we need to change proto/range.proto to string type, because uint64 55 | // is the largest integer type supported by protobuf 56 | typedef uint64 Key; 57 | static const Key kMaxKey = kuint64max; 58 | 59 | typedef std::string NodeID; 60 | 61 | typedef std::lock_guard Lock; 62 | using std::string; 63 | 64 | #define LL LOG(ERROR) 65 | #define LI LOG(INFO) 66 | 67 | DECLARE_int32(num_threads); 68 | 69 | // print the array's head and tail 70 | template 71 | inline string dbstr(const V* data, int n, int m = 5) { 72 | std::stringstream ss; 73 | ss << "[" << n << "]: "; 74 | if (n < 2 * m) { 75 | for (int i = 0; i < n; ++i) ss << data[i] << " "; 76 | } else { 77 | for (int i = 0; i < m; ++i) ss << data[i] << " "; 78 | ss << "... "; 79 | for (int i = n-m; i < n; ++i) ss << data[i] << " "; 80 | } 81 | return ss.str(); 82 | } 83 | 84 | #define NOTICE(_fmt_, args...) do { \ 85 | struct timeval tv; gettimeofday(&tv, NULL); \ 86 | time_t ts = (time_t)(tv.tv_sec); \ 87 | struct ::tm tm_time; localtime_r(&ts, &tm_time); \ 88 | int n = strlen(__FILE__) - 1; \ 89 | for (; n > -1; --n) { if (n==-1 || __FILE__[n] == '/') break; } \ 90 | fprintf(stdout, "[%02d%02d %02d:%02d:%02d.%03d %s:%d] " _fmt_ "\n", \ 91 | 1+tm_time.tm_mon, tm_time.tm_mday, tm_time.tm_hour, \ 92 | tm_time.tm_min, tm_time.tm_sec, (int)tv.tv_usec/1000, \ 93 | __FILE__+n+1, __LINE__, ##args); \ 94 | } while (0) 95 | 96 | } // namespace PS 97 | 98 | 99 | // basename(__FILE__) 100 | -------------------------------------------------------------------------------- /src/util/countmin.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/sketch.h" 3 | #include 4 | #include "util/shared_array_inl.h" 5 | namespace PS { 6 | 7 | template 8 | class CountMin : public Sketch { 9 | public: 10 | // TODO prefetch to accelerate the memory access 11 | bool empty() { return n_ == 0; } 12 | void clear() { data_.clear(); n_ = 0; } 13 | void resize(int n, int k, V v_max) { 14 | n_ = std::max(n, 64); 15 | data_.resize(n_); 16 | data_.SetZero(); 17 | k_ = std::min(30, std::max(1, k)); 18 | v_max_ = v_max; 19 | } 20 | 21 | void insert(const K& key, const V& count) { 22 | uint32 h = hash(key); 23 | const uint32 delta = (h >> 17) | (h << 15); // Rotate right 17 bits 24 | for (int j = 0; j < k_; ++j) { 25 | V v = data_[h % n_]; 26 | // to avoid overflow 27 | data_[h % n_] = count > v_max_ - v ? v_max_ : v + count; 28 | h += delta; 29 | } 30 | } 31 | 32 | V query(const K& key) const { 33 | V res = v_max_; 34 | uint32 h = hash(key); 35 | const uint32 delta = (h >> 17) | (h << 15); // Rotate right 17 bits 36 | for (int j = 0; j < k_; ++j) { 37 | res = std::min(res, data_[h % n_]); 38 | h += delta; 39 | } 40 | return res; 41 | } 42 | 43 | private: 44 | SArray data_; 45 | int n_ = 0; 46 | int k_ = 1; 47 | V v_max_ = 0; 48 | }; 49 | 50 | } // namespace PS 51 | -------------------------------------------------------------------------------- /src/util/crc32c.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2011 The LevelDB Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style license that can be 3 | // found in the LICENSE file. See the AUTHORS file for names of contributors. 
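// A minimal usage sketch of the helpers declared below:
//   uint32_t c = PS::crc32c::Value(buf, n);        // crc32c of buf[0,n-1]
//   uint32_t stored = PS::crc32c::Mask(c);         // masked form, safe to embed in files
//   uint32_t back = PS::crc32c::Unmask(stored);    // == c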
4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | namespace PS { 11 | namespace crc32c { 12 | 13 | // Return the crc32c of concat(A, data[0,n-1]) where init_crc is the 14 | // crc32c of some string A. Extend() is often used to maintain the 15 | // crc32c of a stream of data. 16 | extern uint32_t Extend(uint32_t init_crc, const char* data, size_t n); 17 | 18 | // Return the crc32c of data[0,n-1] 19 | inline uint32_t Value(const char* data, size_t n) { 20 | return Extend(0, data, n); 21 | } 22 | 23 | static const uint32_t kMaskDelta = 0xa282ead8ul; 24 | 25 | // Return a masked representation of crc. 26 | // 27 | // Motivation: it is problematic to compute the CRC of a string that 28 | // contains embedded CRCs. Therefore we recommend that CRCs stored 29 | // somewhere (e.g., in files) should be masked before being stored. 30 | inline uint32_t Mask(uint32_t crc) { 31 | // Rotate right by 15 bits and add a constant. 32 | return ((crc >> 15) | (crc << 17)) + kMaskDelta; 33 | } 34 | 35 | // Return the crc whose masked representation is masked_crc. 36 | inline uint32_t Unmask(uint32_t masked_crc) { 37 | uint32_t rot = masked_crc - kMaskDelta; 38 | return ((rot >> 17) | (rot << 15)); 39 | } 40 | 41 | } // namespace crc32c 42 | } 43 | -------------------------------------------------------------------------------- /src/util/dense_matrix.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/matrix.h" 4 | 5 | namespace PS { 6 | 7 | template 8 | class DenseMatrix : public Matrix { 9 | public: 10 | USING_MATRIX; 11 | DenseMatrix() { } 12 | DenseMatrix(size_t rows, size_t cols, bool row_major = true) { 13 | resize(rows, cols, rows*cols, row_major); 14 | } 15 | 16 | void resize(size_t rows, size_t cols, size_t nnz, bool row_major); 17 | 18 | DenseMatrix(const MatrixInfo& info, SArray value) 19 | : Matrix(info, value) { } 20 | 21 | // TODO 22 | virtual void times(const V* x, V *y) const {CHECK(false); } 23 | 24 | // C = A .* B 25 | virtual MatrixPtr dotTimes(const MatrixPtr& B) const { CHECK(false); return MatrixPtr(); } 26 | 27 | // (nearly) non-copy matrix transpose 28 | virtual MatrixPtr trans() const {CHECK(false); return MatrixPtr(); } 29 | 30 | // convert global index into local index (0,1,2,3...) 
and return the key map 31 | // virtual MatrixPtr localize(SArray* key_map) const {CHECK(false); return MatrixPtr(); } 32 | 33 | virtual MatrixPtr alterStorage() const; 34 | 35 | // non-copy matrix block 36 | virtual MatrixPtr rowBlock(SizeR range) const { 37 | if (colMajor()) CHECK_EQ(range, SizeR(0, rows())); 38 | auto info = info_; 39 | range.To(info.mutable_row()); 40 | info.set_nnz(range.size() * cols()); 41 | return MatrixPtr(new DenseMatrix(info, value_.Segment(range*cols()))); 42 | } 43 | 44 | virtual MatrixPtr colBlock(SizeR range) const { 45 | if (rowMajor()) CHECK_EQ(range, SizeR(0, cols())); 46 | auto info = info_; 47 | range.To(info.mutable_col()); 48 | info.set_nnz(range.size() * rows()); 49 | return MatrixPtr(new DenseMatrix(info, value_.Segment(range*rows()))); 50 | } 51 | 52 | virtual bool writeToBinFile(string name) const { 53 | return (writeProtoToASCIIFile(info_, name+".info") 54 | && value_.WriteToFile(name+".value")); 55 | } 56 | 57 | virtual string debugString() const { 58 | std::stringstream ss; 59 | ss << rows() << " x " << cols() << " dense matrix " << std::endl 60 | << dbstr(value_.data(), value_.size(), 8); 61 | return ss.str(); 62 | } 63 | }; 64 | 65 | 66 | template 67 | void DenseMatrix::resize( 68 | size_t rows, size_t cols, size_t nnz, bool row_major) { 69 | info_.set_type(MatrixInfo::DENSE); 70 | info_.set_row_major(row_major); 71 | SizeR(0, rows).To(info_.mutable_row()); 72 | SizeR(0, cols).To(info_.mutable_col()); 73 | nnz = rows * cols; 74 | // CHECK_EQ(nnz, rows*cols); 75 | info_.set_nnz(nnz); 76 | info_.set_sizeof_value(sizeof(V)); 77 | // info_.set_nnz_per_row(cols); 78 | // info_.set_nnz_per_col(rows); 79 | // data 80 | value_.resize(nnz); 81 | value_.SetZero(); 82 | 83 | } 84 | 85 | 86 | template 87 | MatrixPtr DenseMatrix::alterStorage() const { 88 | size_t in = innerSize(); 89 | size_t out = outerSize(); 90 | CHECK_EQ(value_.size(), in*out); 91 | 92 | SArray new_value(value_.size()); 93 | 94 | for (size_t i = 0; i < in; ++i) { 95 | for (size_t j = 0; j < out; ++j) { 96 | new_value[i*out+j] = value_[j*in+i]; 97 | } 98 | } 99 | 100 | auto new_info = info_; 101 | new_info.set_row_major(!info_.row_major()); 102 | 103 | return MatrixPtr(new DenseMatrix(new_info, new_value)); 104 | } 105 | 106 | 107 | 108 | } // namespace PS 109 | -------------------------------------------------------------------------------- /src/util/evaluation.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/shared_array.h" 4 | #include "util/parallel_sort.h" 5 | 6 | namespace PS { 7 | 8 | // evaluation in a single machine 9 | template 10 | class Evaluation { 11 | public: 12 | static V auc(const SArray& label, 13 | const SArray& predict); 14 | 15 | static V accuracy(const SArray& label, 16 | const SArray& predict, 17 | V threshold = 0); 18 | 19 | static V logloss(const SArray& label, 20 | const SArray& predict); 21 | }; 22 | 23 | template 24 | V Evaluation::auc(const SArray& label, const SArray& predict) { 25 | int n = label.size(); 26 | CHECK_EQ(n, predict.size()); 27 | struct Entry { 28 | V label; 29 | V predict; 30 | }; 31 | SArray buff(n); 32 | for (int i = 0; i < n; ++i) { 33 | buff[i].label = label[i]; 34 | buff[i].predict = predict[i]; 35 | } 36 | // parallelSort(buff.data(), n, FLAGS_num_threads, []( 37 | // const Entry& a, const Entry&b) { return a.predict < b.predict; }); 38 | std::sort(buff.data(), buff.data()+n, [](const Entry& a, const Entry&b) { 39 | return a.predict < b.predict; }); 40 | V area 
= 0, cum_tp = 0; 41 | for (int i = 0; i < n; ++i) { 42 | if (buff[i].label > 0) { 43 | cum_tp += 1; 44 | } else { 45 | area += cum_tp; 46 | } 47 | } 48 | area /= cum_tp * (n - cum_tp); 49 | return area < 0.5 ? 1 - area : area; 50 | } 51 | 52 | 53 | template 54 | V Evaluation::accuracy(const SArray& label, const SArray& predict, V threshold) { 55 | int n = label.size(); 56 | CHECK_EQ(n, predict.size()); 57 | V correct = 0; 58 | for (int i = 0; i < n; ++i) { 59 | if ((label[i] > 0 && predict[i] > threshold) || 60 | (label[i] < 0 && predict[i] <= threshold)) 61 | correct += 1; 62 | } 63 | V acc = correct / (V) n; 64 | return acc > 0.5 ? acc : 1 - acc; 65 | } 66 | 67 | 68 | template 69 | V Evaluation::logloss(const SArray& label, const SArray& predict) { 70 | int n = label.size(); 71 | CHECK_EQ(n, predict.size()); 72 | V loss = 0; 73 | for (int i = 0; i < n; ++i) { 74 | V y = label[i] > 0; 75 | V p = 1 / (1 + exp(- predict[i])); 76 | loss += y * log(p) + (1 - y) * log(1 - p); 77 | } 78 | return - loss / n; 79 | } 80 | 81 | } // namespace PS 82 | -------------------------------------------------------------------------------- /src/util/filelinereader.cc: -------------------------------------------------------------------------------- 1 | #include "util/filelinereader.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "util/file.h" 8 | #include "glog/logging.h" 9 | #include "util/common.h" 10 | 11 | namespace PS { 12 | 13 | DEFINE_int32(line_limit, 0, 14 | "line number limit that one data file could read"); 15 | 16 | void FileLineReader::Reload() { 17 | const int kMaxLineLength = 60 * 1024; 18 | File* const data_file = File::open(data_conf_, "r"); 19 | if (data_file == NULL) { 20 | loaded_successfully_ = false; 21 | return; 22 | } 23 | 24 | size_t readed_line_count = 0; 25 | std::unique_ptr line(new char[kMaxLineLength]); 26 | for (;;) { 27 | char* const result = data_file->readLine(line.get(), kMaxLineLength); 28 | if (result == NULL || 29 | (FLAGS_line_limit > 0 && readed_line_count > FLAGS_line_limit)) { 30 | data_file->close(); 31 | loaded_successfully_ = true; 32 | return; 33 | } 34 | // Chop the last linefeed if present. 35 | int len = strlen(result); 36 | if (len > 0 && result[len - 1] == '\n') { // Linefeed. 37 | result[--len] = '\0'; 38 | } 39 | if (len > 0 && result[len - 1] == '\r') { // Carriage return. 40 | result[--len] = '\0'; 41 | } 42 | if (line_callback_) line_callback_(result); 43 | 44 | // increase line counter 45 | readed_line_count++; 46 | } 47 | data_file->close(); 48 | } 49 | 50 | 51 | } // namespace PS 52 | -------------------------------------------------------------------------------- /src/util/filelinereader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "util/file.h" 9 | 10 | namespace PS { 11 | 12 | // The FileLineReader class will read a text file specified by 13 | // 'filename' line by line. Each line will be cleaned with respect to 14 | // termination ('\n' and '\r'). The line callback will be called in 15 | // sequence on each line. 16 | class FileLineReader { 17 | public: 18 | // Creates a file line reader object that will read the file 'filename' 19 | // line by line. 20 | explicit FileLineReader(const DataConfig& data_conf) : 21 | data_conf_(data_conf), loaded_successfully_(false) {}; 22 | 23 | ~FileLineReader() { } 24 | 25 | // Sets the line callback and takes ownership. 
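// Reload() invokes the callback once per input line, after any trailing '\n'
// or '\r' has been stripped; reading stops early once more than --line_limit
// lines have been read (when that flag is positive).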
26 | void set_line_callback(std::function callback) { 27 | line_callback_ = callback; 28 | } 29 | 30 | // Reloads the file line by line. 31 | void Reload(); 32 | 33 | // Indicates if the file was loaded successfully. 34 | bool loaded_successfully() const { return loaded_successfully_; } 35 | 36 | private: 37 | DataConfig data_conf_; 38 | std::function line_callback_; 39 | bool loaded_successfully_; 40 | }; 41 | 42 | } // namespace PS 43 | -------------------------------------------------------------------------------- /src/util/hdfs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "util/common.h" 4 | #include "proto/config.pb.h" 5 | 6 | namespace PS { 7 | 8 | } // namespace PS 9 | -------------------------------------------------------------------------------- /src/util/integral_types.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef SWIG 4 | // Standard typedefs 5 | typedef signed char schar; 6 | typedef signed char int8; 7 | typedef short int16; // NOLINT 8 | typedef int int32; 9 | #ifdef COMPILER_MSVC 10 | typedef __int64 int64; // NOLINT 11 | #else 12 | typedef long long int64; // NOLINT 13 | #endif /* COMPILER_MSVC */ 14 | 15 | // NOTE: unsigned types are DANGEROUS in loops and other arithmetical 16 | // places. Use the signed types unless your variable represents a bit 17 | // pattern (eg a hash value) or you really need the extra bit. Do NOT 18 | // use 'unsigned' to express "this value should always be positive"; 19 | // use assertions for this. 20 | 21 | typedef unsigned char uint8; 22 | typedef unsigned short uint16; // NOLINT 23 | typedef unsigned int uint32; 24 | #ifdef COMPILER_MSVC 25 | typedef unsigned __int64 uint64; 26 | #else 27 | typedef unsigned long long uint64; // NOLINT 28 | #endif /* COMPILER_MSVC */ 29 | 30 | // A type to represent a Unicode code-point value. As of Unicode 4.0, 31 | // such values require up to 21 bits. 32 | // (For type-checking on pointers, make this explicitly signed, 33 | // and it should always be the signed version of whatever int32 is.) 34 | typedef signed int char32; 35 | 36 | // A type to represent a natural machine word (for e.g. efficiently 37 | // scanning through memory for checksums or index searching). Don't use 38 | // this for storing normal integers. Ideally this would be just 39 | // unsigned int, but our 64-bit architectures use the LP64 model 40 | // (http://www.opengroup.org/public/tech/aspen/lp64_wp.htm), hence 41 | // their ints are only 32 bits. We want to use the same fundamental 42 | // type on all archs if possible to preserve *printf() compatability. 43 | typedef unsigned long uword_t; // NOLINT 44 | 45 | // A signed natural machine word. In general you want to use "int" 46 | // rather than "sword_t" 47 | typedef long sword_t; // NOLINT 48 | 49 | #endif /* SWIG */ 50 | 51 | // long long macros to be used because gcc and vc++ use different suffixes, 52 | // and different size specifiers in format strings 53 | #undef GG_LONGLONG 54 | #undef GG_ULONGLONG 55 | #undef GG_LL_FORMAT 56 | 57 | #ifdef COMPILER_MSVC /* if Visual C++ */ 58 | 59 | // VC++ long long suffixes 60 | #define GG_LONGLONG(x) x##I64 61 | #define GG_ULONGLONG(x) x##UI64 62 | 63 | // Length modifier in printf format string for int64's (e.g. within %d) 64 | #define GG_LL_FORMAT "I64" // As in printf("%I64d", ...) 
65 | #define GG_LL_FORMAT_W L"I64" 66 | 67 | #else /* not Visual C++ */ 68 | 69 | #define GG_LONGLONG(x) x##LL 70 | #define GG_ULONGLONG(x) x##ULL 71 | #define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also. 72 | #define GG_LL_FORMAT_W L"ll" 73 | 74 | #endif // COMPILER_MSVC 75 | 76 | 77 | static const uint8 kuint8max = static_cast(0xFF); 78 | static const uint16 kuint16max = static_cast(0xFFFF); 79 | static const uint32 kuint32max = static_cast(0xFFFFFFFF); 80 | static const uint64 kuint64max = 81 | static_cast(GG_LONGLONG(0xFFFFFFFFFFFFFFFF)); 82 | static const int8 kint8min = static_cast(0x80); 83 | static const int8 kint8max = static_cast(0x7F); 84 | static const int16 kint16min = static_cast(0x8000); 85 | static const int16 kint16max = static_cast(0x7FFF); 86 | static const int32 kint32min = static_cast(0x80000000); 87 | static const int32 kint32max = static_cast(0x7FFFFFFF); 88 | static const int64 kint64min = 89 | static_cast(GG_LONGLONG(0x8000000000000000)); 90 | static const int64 kint64max = 91 | static_cast(GG_LONGLONG(0x7FFFFFFFFFFFFFFF)); 92 | -------------------------------------------------------------------------------- /src/util/macros.h: -------------------------------------------------------------------------------- 1 | namespace PS { 2 | 3 | // DISALLOW_COPY_AND_ASSIGN disallows the copy and operator= functions. 4 | // It goes in the private: declarations in a class. 5 | #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ 6 | TypeName(const TypeName&); \ 7 | void operator=(const TypeName&) 8 | 9 | #define SINGLETON(Typename) \ 10 | static Typename& instance() { \ 11 | static Typename e; \ 12 | return e; \ 13 | } 14 | 15 | } // namespace PS 16 | -------------------------------------------------------------------------------- /src/util/murmurhash3.h: -------------------------------------------------------------------------------- 1 | //----------------------------------------------------------------------------- 2 | // MurmurHash3 was written by Austin Appleby, and is placed in the public 3 | // domain. The author hereby disclaims copyright to this source code. 
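// A minimal usage sketch for the declarations below (the 128-bit variants
// write 16 bytes to 'out'):
//   uint64_t key = 42, out[2];
//   MurmurHash3_x64_128(&key, sizeof(key), /*seed=*/0, out);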
4 | 5 | #ifndef _MURMURHASH3_H_ 6 | #define _MURMURHASH3_H_ 7 | 8 | //----------------------------------------------------------------------------- 9 | // Platform-specific functions and macros 10 | 11 | // Microsoft Visual Studio 12 | 13 | #if defined(_MSC_VER) 14 | 15 | typedef unsigned char uint8_t; 16 | typedef unsigned long uint32_t; 17 | typedef unsigned __int64 uint64_t; 18 | 19 | // Other compilers 20 | 21 | #else // defined(_MSC_VER) 22 | 23 | #include 24 | 25 | #endif // !defined(_MSC_VER) 26 | 27 | //----------------------------------------------------------------------------- 28 | 29 | void MurmurHash3_x86_32 ( const void * key, int len, uint32_t seed, void * out ); 30 | 31 | void MurmurHash3_x86_128 ( const void * key, int len, uint32_t seed, void * out ); 32 | 33 | void MurmurHash3_x64_128 ( const void * key, int len, uint32_t seed, void * out ); 34 | 35 | //----------------------------------------------------------------------------- 36 | 37 | #endif // _MURMURHASH3_H_ 38 | -------------------------------------------------------------------------------- /src/util/parallel_sort.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file parallel_sort.h 3 | * @date Tue Mar 31 17:01:58 2015 4 | * 5 | * @brief Parallel sort 6 | */ 7 | #pragma once 8 | #include "util/shared_array.h" 9 | namespace PS { 10 | 11 | namespace { 12 | /// @brief the thread function 13 | template 14 | void ParallelSort(T* data, size_t len, size_t grainsize, const Fn& cmp) { 15 | if (len <= grainsize) { 16 | std::sort(data, data + len, cmp); 17 | } else { 18 | std::thread thr(ParallelSort, data, len/2, grainsize, cmp); 19 | ParallelSort(data + len/2, len - len/2, grainsize, cmp); 20 | thr.join(); 21 | 22 | std::inplace_merge(data, data + len/2, data + len, cmp); 23 | } 24 | } 25 | } // namespace 26 | 27 | 28 | /** 29 | * @brief Parallel Sort 30 | * 31 | * @param arr array 32 | * @param num_threads 33 | * @param cmp the comparision function, such as [](const T& a, const T& b) { 34 | * return a < b; } or an even simplier version: std::less() 35 | */ 36 | template 37 | void ParallelSort(SArray* arr, int num_threads, const Fn& cmp) { 38 | CHECK_GT(num_threads, 0); 39 | size_t grainsize = std::max(arr->size() / num_threads + 5, (size_t)1024*16); 40 | ParallelSort(arr->data(), arr->size(), grainsize, cmp); 41 | } 42 | 43 | } // namespace PS 44 | -------------------------------------------------------------------------------- /src/util/producer_consumer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/threadsafe_limited_queue.h" 3 | #include "util/common.h" 4 | namespace PS { 5 | 6 | template 7 | class ProducerConsumer { 8 | public: 9 | ProducerConsumer() { setCapacity(1000); } 10 | ProducerConsumer(int capacity_in_mb) { setCapacity(capacity_in_mb); } 11 | void setCapacity(int mb) { queue_.setMaxCapacity(mb*1000000); } 12 | 13 | // *func* returns false if finished, true otherwise 14 | void startProducer(const std::function& func) { 15 | producer_thr_ = std::thread([this, func](){ 16 | V entry; 17 | bool done = false; 18 | while (!done) { 19 | size_t size = 0; 20 | done = !func(&entry, &size); 21 | queue_.push(entry, size, done); 22 | } 23 | }); 24 | producer_thr_.detach(); 25 | } 26 | 27 | void startConsumer(const std::function& func) { 28 | consumer_thr_ = std::thread([this, func](){ 29 | V entry; 30 | while (pop(&entry)) { 31 | func(entry); 32 | } 33 | }); 34 | // consumer_thr_.detach(); 35 | } 36 
| void waitConsumer() { consumer_thr_.join(); } 37 | 38 | bool pop(V* data) { 39 | return queue_.pop(*data); 40 | } 41 | void push(const V& entry, size_t size = 1, bool finished = false) { 42 | queue_.push(entry, size, finished); 43 | } 44 | void setFinished() { 45 | V empty; 46 | queue_.push(empty, 0, true); 47 | } 48 | private: 49 | DISALLOW_COPY_AND_ASSIGN(ProducerConsumer); 50 | ThreadsafeLimitedQueue queue_; 51 | std::thread producer_thr_; 52 | std::thread consumer_thr_; 53 | }; 54 | } // namespace PS 55 | -------------------------------------------------------------------------------- /src/util/proto/assign_op.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | 3 | // assignment operators: http://en.cppreference.com/w/cpp/language/operator_assignment 4 | enum AssignOpType { 5 | ASSIGN = 0; // a = b 6 | PLUS = 1; // a += b 7 | MINUS = 2; // a -= b 8 | TIMES = 3; // a *= b 9 | DIVIDE = 4; // a /= b 10 | AND = 5; // a &= b 11 | OR = 6; // a |= b 12 | XOR = 7; // a ^= b 13 | } 14 | -------------------------------------------------------------------------------- /src/util/proto/auc.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | 3 | message AUCData { 4 | repeated int64 tp_key = 1; 5 | repeated uint64 tp_count = 2; 6 | 7 | repeated int64 fp_key = 3; 8 | repeated uint64 fp_count = 4; 9 | } 10 | -------------------------------------------------------------------------------- /src/util/proto/matrix.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | import "util/proto/range.proto"; 3 | 4 | message MatrixInfo { 5 | enum Type { 6 | DENSE = 1; 7 | SPARSE = 2; 8 | SPARSE_BINARY = 3; 9 | // TODO gpus 10 | } 11 | required Type type = 1; 12 | required bool row_major = 2; 13 | // e.g. feature group id 14 | optional int32 id = 3; 15 | 16 | // size 17 | required PbRange row = 5; 18 | required PbRange col = 6; 19 | // number of non-zero entries 20 | optional uint64 nnz = 7; 21 | 22 | optional uint32 sizeof_index = 8; 23 | required uint32 sizeof_value = 9; 24 | 25 | // // statistic 26 | // optional float nnz_per_row = 11; 27 | // optional float nnz_per_col = 12; 28 | } 29 | -------------------------------------------------------------------------------- /src/util/proto/range.proto: -------------------------------------------------------------------------------- 1 | package PS; 2 | 3 | // TODO may change it to string 4 | message PbRange { 5 | required uint64 begin = 1; 6 | required uint64 end = 2; 7 | } 8 | -------------------------------------------------------------------------------- /src/util/recordio.cc: -------------------------------------------------------------------------------- 1 | // #include "util/recordio.h" 2 | 3 | // #include 4 | // #include 5 | // #include "glog/logging.h" 6 | 7 | // namespace PS { 8 | // const int RecordWriter::kMagicNumber = 0x3ed7230a; 9 | 10 | // std::string RecordWriter::Compress(std::string const& s) const { 11 | // const unsigned long source_size = s.size(); // NOLINT 12 | // const char* source = s.c_str(); 13 | 14 | // unsigned long dsize = source_size + (source_size * 0.1f) + 16; // NOLINT 15 | // std::unique_ptr destination(new char[dsize]); 16 | // // Use compress() from zlib.h.
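// (zlib also provides compressBound(source_size), which returns a destination
//  size guaranteed to be large enough for compress(), as an alternative to the
//  10% + 16 heuristic above.)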
17 | // const int result = 18 | // compress(reinterpret_cast(destination.get()), &dsize, 19 | // reinterpret_cast(source), source_size); 20 | 21 | // if (result != Z_OK) { 22 | // LOG(FATAL) << "Compress error occured! Error code: " << result; 23 | // } 24 | // return std::string(destination.get(), dsize); 25 | // } 26 | 27 | // void RecordReader::Uncompress(const char* const source, uint64 source_size, 28 | // char* const output_buffer, 29 | // uint64 output_size) const { 30 | // unsigned long result_size = output_size; // NOLINT 31 | // // Use uncompress() from zlib.h 32 | // const int result = 33 | // uncompress(reinterpret_cast(output_buffer), &result_size, 34 | // reinterpret_cast(source), source_size); 35 | // if (result != Z_OK) { 36 | // LOG(FATAL) << "Uncompress error occured! Error code: " << result; 37 | // } 38 | // CHECK_LE(result_size, static_cast(output_size)); // NOLINT 39 | // } 40 | 41 | // } // namespace PS 42 | -------------------------------------------------------------------------------- /src/util/recordio.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "util/file.h" 6 | namespace PS { 7 | 8 | namespace { 9 | static const int kMagicNumber = 0x3ed7230a; 10 | } 11 | 12 | // This class appends a protocol buffer to a file in a binary format. 13 | // The data written in the file follows the following format (sequentially): 14 | // - MagicNumber (32 bits) to recognize this format. 15 | // - data payload size (32 bits) 16 | // - Payload 17 | class RecordWriter { 18 | public: 19 | // Magic number when writing and reading protocol buffers. 20 | RecordWriter() { file_ = NULL; } 21 | explicit RecordWriter(File* const file) : file_(file) { } 22 | bool Close() { return file_ && file_->close(); } 23 | 24 | template bool WriteProtocolMessage(const P& proto) { 25 | if (file_ == NULL) return false; 26 | std::string buffer; 27 | proto.SerializeToString(&buffer); 28 | const uint32 buff_size = (uint32) buffer.size(); 29 | if (file_->write(&kMagicNumber, sizeof(kMagicNumber)) != 30 | sizeof(kMagicNumber)) { 31 | return false; 32 | } 33 | if (file_->write(&buff_size, sizeof(buff_size)) != sizeof(buff_size)) { 34 | return false; 35 | } 36 | if (file_->write(buffer.c_str(), buff_size) != buff_size) { 37 | return false; 38 | } 39 | return true; 40 | } 41 | 42 | private: 43 | File* file_; 44 | }; 45 | 46 | // This class reads a protocol buffer from a file. 47 | // The format must be the one described in RecordWriter, above. 
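// A minimal round-trip sketch (assuming 'out' and 'in' are File* handles from
// util/file.h already opened for writing and reading, and MyProto is some
// hypothetical protobuf message type):
//   RecordWriter writer(out);
//   writer.WriteProtocolMessage(my_proto);
//   writer.Close();
//   RecordReader reader(in);
//   MyProto p;
//   while (reader.ReadProtocolMessage(&p)) { /* consume p */ }
//   reader.Close();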
48 | class RecordReader { 49 | public: 50 | explicit RecordReader(File* const file) : file_(file) { } 51 | bool Close() { return file_->close(); } 52 | 53 | template bool ReadProtocolMessage(P* const proto) { 54 | uint32 size = 0; 55 | int magic_number = 0; 56 | 57 | if (file_->read(&magic_number, sizeof(magic_number)) != 58 | sizeof(magic_number)) { 59 | return false; 60 | } 61 | if (magic_number != kMagicNumber) { 62 | return false; 63 | } 64 | if (file_->read(&size, sizeof(size)) != sizeof(size)) { 65 | return false; 66 | } 67 | std::unique_ptr buffer(new char[size + 1]); 68 | if (file_->read(buffer.get(), size) != size) { 69 | return false; 70 | } 71 | proto->ParseFromArray(buffer.get(), size); 72 | return true; 73 | } 74 | 75 | private: 76 | File* file_; 77 | }; 78 | 79 | } // namespace PS 80 | -------------------------------------------------------------------------------- /src/util/sketch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | // #include 4 | namespace PS { 5 | 6 | // the basc class for bloom filters, countmin, etc... 7 | class Sketch { 8 | public: 9 | protected: 10 | 11 | // ver 1 is faster than ver 2, but is comparable to the murmurhash version 12 | // need add -msse4.2 in CFLAGS 13 | // uint64 crc32(uint64 key) const { 14 | // return _mm_crc32_u64(0, key); 15 | // } 16 | // uint32 crc32(uint64 key) const { 17 | // return _mm_crc32_u32((uint32)(key<<32), (uint32)key); 18 | // } 19 | 20 | uint32 hash(const uint64& key) const { 21 | // similar to murmurhash 22 | const uint32 seed = 0xbc9f1d34; 23 | const uint32 m = 0xc6a4a793; 24 | const uint32 n = 8; // sizeof uint64 25 | uint32 h = seed ^ (n * m); 26 | 27 | uint32 w = (uint32) key; 28 | h += w; h *= m; h ^= (h >> 16); 29 | 30 | w = (uint32) (key >> 32); 31 | h += w; h *= m; h ^= (h >> 16); 32 | return h; 33 | } 34 | }; 35 | } // namespace PS 36 | -------------------------------------------------------------------------------- /src/util/split.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "util/common.h" 3 | 4 | namespace PS { 5 | 6 | // split a std::string using a character delimiter. 
if skip_empty == true, 7 | // split("one:two::three", ':'); will return 4 items 8 | 9 | inline std::vector 10 | split(const std::string &s, char delim, bool skip_empty = false) { 11 | std::vector elems; 12 | std::stringstream ss(s); 13 | string item; 14 | while (std::getline(ss, item, delim)) 15 | if (!(skip_empty && item.empty())) 16 | elems.push_back(item); 17 | return elems; 18 | } 19 | 20 | // TODO support bool skip_empty = false 21 | inline std::string join(const std::vector &elems, const string& delim) { 22 | std::string str; 23 | for (int i = 0; i < elems.size() - 1; ++i) { 24 | str += elems[i] + delim; 25 | } 26 | str += elems.back(); 27 | return str; 28 | } 29 | 30 | } // namespace PS 31 | -------------------------------------------------------------------------------- /src/util/strtonum.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "util/integral_types.h" 6 | // #include "util/common.h" 7 | 8 | namespace PS { 9 | 10 | // return true if success 11 | 12 | inline bool strtofloat(const char* str, float* num) { 13 | char* end; 14 | *num = strtof(str, &end); 15 | if (*end == '\0') return true; 16 | return false; 17 | } 18 | 19 | 20 | inline bool strtoi32(const char* str, int32* num) { 21 | char* end; 22 | *num = strtol(str, &end, 10); 23 | if (*end == '\0') return true; 24 | return false; 25 | } 26 | 27 | inline bool strtou64(const char* str, uint64* num) { 28 | char* end; 29 | *num = strtoull(str, &end, 10); 30 | if (*end == '\0') return true; 31 | return false; 32 | } 33 | 34 | // convinient wrapper 35 | inline bool strtofloat(const std::string& str, float* num) { 36 | return strtofloat(str.c_str(), num); 37 | } 38 | 39 | inline bool strtoi32(const std::string& str, int32* num) { 40 | return strtoi32(str.c_str(), num); 41 | } 42 | 43 | inline bool strtou64(const std::string& str, uint64* num) { 44 | return strtou64(str.c_str(), num); 45 | } 46 | 47 | } // namespace PS 48 | -------------------------------------------------------------------------------- /src/util/threadpool.cc: -------------------------------------------------------------------------------- 1 | #include "util/threadpool.h" 2 | 3 | namespace PS { 4 | 5 | ThreadPool::~ThreadPool() { 6 | if (!started_) return; 7 | // if (started_) { 8 | mu_.lock(); 9 | waiting_to_finish_ = true; 10 | cv_.notify_all(); 11 | mu_.unlock(); 12 | 13 | stopOnFinalBarrier(); 14 | 15 | for (int i = 0; i < num_workers_; ++i) 16 | all_workers_[i].join(); 17 | // } 18 | } 19 | 20 | void RunWorker(void* data) { 21 | ThreadPool* const thread_pool = reinterpret_cast(data); 22 | auto task = thread_pool->getNextTask(); 23 | while (task) { 24 | task(); 25 | task = thread_pool->getNextTask(); 26 | } 27 | thread_pool->stopOnFinalBarrier(); 28 | } 29 | 30 | void ThreadPool::startWorkers() { 31 | started_ = true; 32 | for (int i = 0; i < num_workers_; ++i) 33 | all_workers_.push_back(std::move(std::thread(&RunWorker, this))); 34 | } 35 | 36 | typename ThreadPool::Task ThreadPool::getNextTask() { 37 | std::unique_lock l(mu_); 38 | for (;;) { 39 | if (!tasks_.empty()) { 40 | auto task = tasks_.front(); 41 | tasks_.pop_front(); 42 | return task; 43 | } 44 | 45 | if (waiting_to_finish_) 46 | return Task(); 47 | else 48 | cv_.wait(l); 49 | } 50 | return Task(); 51 | } 52 | 53 | void ThreadPool::add(const Task& task) { 54 | std::lock_guard l(mu_); 55 | tasks_.push_back(task); 56 | if (started_) cv_.notify_all(); 57 | } 58 | 59 | } // namespace PS 60 | 
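// A minimal usage sketch of the ThreadPool declared in util/threadpool.h
// (the destructor drains the remaining tasks and joins all workers):
//   {
//     ThreadPool pool(4);
//     for (int i = 0; i < 8; ++i) pool.add([i]() { /* work item i */ });
//     pool.startWorkers();
//   }  // ~ThreadPool() blocks here until every queued task has run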
-------------------------------------------------------------------------------- /src/util/threadpool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "util/macros.h" 13 | #include "util/barrier.h" 14 | 15 | namespace PS { 16 | 17 | class ThreadPool { 18 | public: 19 | explicit ThreadPool(int num_workers) 20 | : num_workers_(num_workers), final_barrier_(num_workers + 1) {} 21 | ~ThreadPool(); 22 | 23 | typedef std::function Task; 24 | void add(const Task& task); 25 | 26 | void startWorkers(); 27 | 28 | // for internal use 29 | Task getNextTask(); 30 | void stopOnFinalBarrier() { final_barrier_.Block(); } 31 | private: 32 | DISALLOW_COPY_AND_ASSIGN(ThreadPool); 33 | 34 | 35 | const int num_workers_; 36 | std::list tasks_; 37 | std::mutex mu_; 38 | std::condition_variable cv_; 39 | 40 | // std::unique_ptr final_barrier_; 41 | Barrier final_barrier_; 42 | std::vector all_workers_; 43 | 44 | bool waiting_to_finish_ = false; 45 | bool started_ = false; 46 | }; 47 | 48 | } // PS 49 | -------------------------------------------------------------------------------- /src/util/threadsafe_limited_queue.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "util/common.h" 7 | 8 | namespace PS { 9 | 10 | template 11 | class ThreadsafeLimitedQueue { 12 | public: 13 | ThreadsafeLimitedQueue() { } 14 | ThreadsafeLimitedQueue(size_t capacity) { setMaxCapacity(capacity); } 15 | void setMaxCapacity(size_t capacity) { max_capacity_ = capacity; } 16 | 17 | void push(const T& value, size_t capacity, bool finished = false) { 18 | CHECK(!done_) << "must not call push again if *finished* is set true"; 19 | if (capacity > max_capacity_) { 20 | LL << "push obj with size " << capacity 21 | << " into queue with capacity " << max_capacity_ 22 | << ". 
you will be blocked here forever..."; 23 | } 24 | // do not insert 25 | if (finished == false && capacity == 0) return; 26 | std::unique_lock l(mu_); 27 | full_cond_.wait(l, [this, capacity]{ 28 | return (capacity + cur_capacity_ <= max_capacity_); }); 29 | queue_.push(std::move(std::make_pair(value, capacity))); 30 | cur_capacity_ += capacity; 31 | done_ = finished; 32 | empty_cond_.notify_all(); 33 | } 34 | 35 | bool pop(T& value) { 36 | std::unique_lock l(mu_); 37 | // already finished 38 | if (done_ && queue_.empty()) return false; 39 | 40 | empty_cond_.wait(l, [this]{ return !queue_.empty(); }); 41 | std::pair e = std::move(queue_.front()); 42 | 43 | // an empty item, which is inserted only when finished 44 | if (e.second == 0) { 45 | CHECK(done_); 46 | return false; 47 | } 48 | 49 | // get a valid item 50 | value = std::move(e.first); 51 | cur_capacity_ -= e.second; 52 | queue_.pop(); 53 | full_cond_.notify_all(); 54 | return true; 55 | } 56 | 57 | size_t size() const { 58 | std::lock_guard l(mu_); 59 | return queue_.size(); 60 | } 61 | 62 | bool empty() const { 63 | return size() == 0; 64 | } 65 | 66 | private: 67 | mutable std::mutex mu_; 68 | bool done_ = false; 69 | size_t max_capacity_ = 0, cur_capacity_ = 0; 70 | std::queue > queue_; 71 | std::condition_variable empty_cond_, full_cond_; 72 | }; 73 | } // namespace PS 74 | -------------------------------------------------------------------------------- /src/util/threadsafe_queue.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | namespace PS { 6 | 7 | // TODO the code style is inconsistent with others 8 | template class ThreadsafeQueue { 9 | public: 10 | ThreadsafeQueue() {} 11 | 12 | void push(T new_value) { 13 | std::lock_guard lk(mut); 14 | data_queue.push(std::move(new_value)); 15 | data_cond.notify_all(); 16 | } 17 | 18 | void wait_and_pop(T& value) { 19 | std::unique_lock lk(mut); 20 | data_cond.wait(lk, [this]{return !data_queue.empty();}); 21 | value = std::move(data_queue.front()); 22 | data_queue.pop(); 23 | } 24 | 25 | bool try_pop(T& value) { 26 | std::lock_guard lk(mut); 27 | if(data_queue.empty()) 28 | return false; 29 | value=std::move(data_queue.front()); 30 | data_queue.pop(); 31 | return true; 32 | } 33 | 34 | size_t size() const { 35 | std::lock_guard lk(mut); 36 | return data_queue.size(); 37 | } 38 | 39 | bool empty() const { 40 | std::lock_guard lk(mut); 41 | return data_queue.empty(); 42 | } 43 | 44 | private: 45 | mutable std::mutex mut; 46 | std::queue data_queue; 47 | std::condition_variable data_cond; 48 | }; 49 | 50 | } // namespace PS 51 | 52 | // std::shared_ptr wait_and_pop() { 53 | // std::unique_lock lk(mut); 54 | // data_cond.wait(lk, [this]{return !data_queue.empty();}); 55 | // std::shared_ptr res( 56 | // std::make_shared(std::move(data_queue.front()))); 57 | // data_queue.pop(); 58 | // return res; 59 | // } 60 | 61 | 62 | // std::shared_ptr try_pop() { 63 | // std::lock_guard lk(mut); 64 | // if(data_queue.empty()) 65 | // return std::shared_ptr(); 66 | // std::shared_ptr res( 67 | // std::make_shared(std::move(data_queue.front()))); 68 | // data_queue.pop(); 69 | // return res; 70 | // } 71 | --------------------------------------------------------------------------------