├── .dockerignore
├── .gitignore
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── doc
├── Doxyfile
├── README.md
└── gendoc.sh
├── docker
├── README.md
├── build.sh
├── cloud.sh
├── fire_amazonec2.sh
├── get_host_within_container.sh
├── local.sh
├── rm_local.sh
├── shut.sh
└── upload_s3.sh
├── example
└── linear
│ ├── README.md
│ ├── criteo
│ ├── README.md
│ ├── batch_l1lr.conf
│ ├── download.sh
│ └── eval_batch.conf
│ ├── ctr
│ ├── batch_l1lr.conf
│ ├── download.sh
│ ├── eval_batch.conf
│ ├── eval_online.conf
│ └── online_l1lr.conf
│ └── rcv1
│ ├── batch_l1lr.conf
│ ├── download.sh
│ ├── eval_batch.conf
│ ├── eval_online.conf
│ └── online_l1lr.conf
├── make
├── README.md
└── config.mk
├── script
├── get_root_node.sh
├── install_third.sh
├── kill_node.sh
├── local.sh
├── mpi_node.sh
├── mpi_root.sh
└── ps.sh
└── src
├── README.mk
├── app
└── linear_method
│ ├── async_sgd.h
│ ├── darlin.h
│ ├── learning_rate.h
│ ├── loss.h
│ ├── main.cc
│ ├── model_evaluation.h
│ ├── penalty.h
│ └── proto
│ └── linear.proto
├── data
├── common.cc
├── common.h
├── info_parser.cc
├── info_parser.h
├── matlab
│ ├── bin2mat.m
│ ├── filter_fea.m
│ ├── load_bin.m
│ ├── mat2bin.m
│ ├── save_bin.m
│ └── saveas_pserver.m
├── proto
│ ├── data.proto
│ └── example.proto
├── show_example.h
├── slot_reader.cc
├── slot_reader.h
├── stream_reader.h
├── text2proto.h
├── text_parser.cc
└── text_parser.h
├── filter
├── add_noise.h
├── compressing.h
├── filter.cc
├── filter.h
├── fixing_float.h
├── frequency_filter.h
├── key_caching.h
├── proto
│ └── filter.proto
└── sparse_filter.h
├── learner
├── bcd.cc
├── bcd.h
├── proto
│ ├── bcd.proto
│ ├── sgd.proto
│ └── workload.proto
├── sgd.cc
├── sgd.h
├── workload_pool.cc
└── workload_pool.h
├── parameter
├── README.org
├── kv_layer.h
├── kv_map.h
├── kv_store.h
├── kv_vector.h
├── parameter.cc
├── parameter.h
└── proto
│ └── param.proto
├── ps.h
├── ps_main.cc
├── system
├── assigner.cc
├── assigner.h
├── customer.h
├── dashboard.cc
├── dashboard.h
├── env.cc
├── env.h
├── executor.cc
├── executor.h
├── heartbeat_info.cc
├── heartbeat_info.h
├── manager.cc
├── manager.h
├── message.cc
├── message.h
├── monitor.h
├── postoffice.cc
├── postoffice.h
├── proto
│ ├── heartbeat.proto
│ ├── node.proto
│ └── task.proto
├── remote_node.cc
├── remote_node.h
├── task_tracker.h
├── van.cc
└── van.h
├── test
├── aggregation_ps.cc
├── assign_op_test.cc
├── bloom_filter_test.cc
├── build.mk
├── common_test.cc
├── countmin_test.cc
├── fixing_float_test.cc
├── hello_ps.cc
├── kv_layer_perf_ps.cc
├── kv_layer_ps.cc
├── kv_map_perf_ps.cc
├── kv_map_ps.cc
├── kv_vector_buffer_ps.cc
├── kv_vector_perf_ps.cc
├── kv_vector_ps.cc
├── localizer_test.cc
├── network_perf_ps.cc
├── parallel_ordered_match_test.cc
├── reassign_server_key_range_ps.cc
├── slot_reader_test.cc
├── sparse_matrix_perf.cc
├── sparse_matrix_test.cc
└── stream_reader_test.cc
├── test_main.cc
└── util
├── assign_op.h
├── auc.h
├── barrier.h
├── bitmap.h
├── block_bloom_filter.h
├── bloom_filter.h
├── common.h
├── countmin.h
├── crc32c.cc
├── crc32c.h
├── dense_matrix.h
├── evaluation.h
├── file.cc
├── file.h
├── filelinereader.cc
├── filelinereader.h
├── hdfs.h
├── integral_types.h
├── local_machine.h
├── localizer.h
├── macros.h
├── matrix.h
├── murmurhash3.cc
├── murmurhash3.h
├── parallel_ordered_match.h
├── parallel_sort.h
├── producer_consumer.h
├── proto
├── assign_op.proto
├── auc.proto
├── matrix.proto
└── range.proto
├── range.h
├── recordio.cc
├── recordio.h
├── resource_usage.h
├── shared_array.h
├── shared_array_inl.h
├── sketch.h
├── sparse_matrix.h
├── split.h
├── strtonum.h
├── threadpool.cc
├── threadpool.h
├── threadsafe_limited_queue.h
└── threadsafe_queue.h
/.dockerignore:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | build
4 | example
5 | doc
6 | script
7 | third_party
8 |
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | example/**/data
2 | example/**/model
3 | build
4 | third_party
5 | **/log/**
6 | **/cache/**
7 | **/output/**
8 | *.pb.cc
9 | *.pb.h
10 | .*
11 | *core*
12 | config.mk
13 | **/bak/**
14 | doc/html
15 | doc/latex
16 | test/
17 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM qicongc/ps-baseimage
2 | WORKDIR /home/parameter_server
3 | ADD src src
4 | ADD make make
5 | ADD Makefile Makefile
6 | # compile ps
7 | RUN make -j8
8 | # add script used by container to get the ip of its docker host provided by its cloud_provider
9 | ADD docker/get_host_within_container.sh get_host_within_container.sh
10 | # TODO: install dependency to /usr
11 | ENV LD_LIBRARY_PATH /usr/local/lib
12 | # run ps according to args passed in
13 | CMD build/linear -bind_to $my_port -my_node "role:$my_role,hostname:'`./get_host_within_container.sh`',port:$my_port,id:'$my_id'" -scheduler "role:SCHEDULER,hostname:'$scheduler_host',port:$scheduler_port,id:'H'" -app_file $app_file -num_servers $num_servers -num_workers $num_workers $args
14 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | ifneq ("$(wildcard ./config.mk)","")
2 | include ./config.mk
3 | else
4 | include make/config.mk
5 | endif
6 |
7 | ifeq ($(STATIC_THIRD_LIB), 1)
8 | THIRD_LIB=$(addprefix $(THIRD_PATH)/lib/, libgflags.a libzmq.a libprotobuf.a libglog.a libz.a libsnappy.a)
9 | ifeq ($(USE_S3),1)
10 | THIRD_LIB+=$(addprefix $(THIRD_PATH)/lib/, libxml2.a)
11 | endif
12 | else
13 | THIRD_LIB=-L$(THIRD_PATH)/lib -lgflags -lzmq -lprotobuf -lglog -lz -lsnappy
14 | ifeq ($(USE_S3),1)
15 | THIRD_LIB+=-lxml2
16 | endif
17 | endif
18 |
19 | WARN = -Wall -Wno-unused-function -finline-functions -Wno-sign-compare #-Wconversion
20 | INCPATH = -I./src -I$(THIRD_PATH)/include
21 | CFLAGS = -std=c++0x $(WARN) $(OPT) $(INCPATH) $(EXTRA_CFLAGS)
22 | ifeq ($(USE_S3), 1)
23 | CFLAGS += -DUSE_S3=1
24 | endif
25 | LDFLAGS = $(EXTRA_LDFLAGS) $(THIRD_LIB) -lpthread # -lrt
26 |
27 | PS_LIB = build/libps.a
28 | PS_MAIN = build/libpsmain.a
29 | TEST_MAIN = build/test_main.o
30 |
31 | clean:
32 | rm -rf build
33 | find src -name "*.pb.[ch]*" -delete
34 |
35 | ps: $(PS_LIB) $(PS_MAIN) $(TEST_MAIN)
36 |
37 | # PS system
38 | sys_dir = $(addprefix src/, util data system filter learner parameter)
39 | sys_srcs = $(wildcard $(patsubst %, %/*.cc, $(sys_dir)))
40 | sys_protos = $(wildcard $(patsubst %, %/proto/*.proto, $(sys_dir)))
41 | sys_objs = $(patsubst src/%.proto, build/%.pb.o, $(sys_protos)) \
42 | $(patsubst src/%.cc, build/%.o, $(sys_srcs))
43 |
44 | build/libps.a: $(patsubst %.proto, %.pb.h, $(sys_protos)) $(sys_objs)
45 | ar crv $@ $(filter %.o, $?)
46 |
47 | build/libpsmain.a: build/ps_main.o
48 | ar crv $@ $?
49 |
50 | # applications
51 | build/linear: $(addprefix build/app/linear_method/, proto/linear.pb.o main.o) $(PS_LIB)
52 | $(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@
53 |
54 | # general rules
55 | build/%.o: src/%.cc
56 | @mkdir -p $(@D)
57 | $(CC) $(INCPATH) -std=c++0x -MM -MT build/$*.o $< >build/$*.d
58 | $(CC) $(CFLAGS) -c $< -o $@
59 |
60 | %.pb.cc %.pb.h : %.proto
61 | ${THIRD_PATH}/bin/protoc --cpp_out=./src --proto_path=./src $<
62 |
63 | -include build/*/*.d
64 | -include build/*/*/*.d
65 | -include src/test/build.mk
66 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | The parameter server is a distributed system scaling to industry size machine
4 | learning problems. It provides asynchronous and zero-copy key-value pair
5 | communications between worker machines and server machines. It also supports
6 | flexible data consistency models, data filters, and flexible server machine
7 | programming.
8 |
9 | **NOTE: We have stopped maintaining this repo. Please check the newer version called [ps-lite](https://github.com/dmlc/ps-lite)**
10 |
11 | - [Document](doc/)
12 | - [Wiki](https://github.com/dmlc/parameter_server/wiki/)
13 | - How to [build](make/)
14 | - Examples
15 | - [Linear method](example/linear), [Linear method with Cloud](docker)
16 | - Deep neural network, see [CXXNET](https://github.com/dmlc/cxxnet) and [Minerva](https://github.com/minerva-developers/minerva)
17 |
--------------------------------------------------------------------------------
/doc/README.md:
--------------------------------------------------------------------------------
1 | # Documents
2 |
3 | Use `gendoc.sh` to generate Doxygen documents.
4 |
--------------------------------------------------------------------------------
/doc/gendoc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | dir=`dirname "$0"`
4 | cd $dir/..
5 | doxygen doc/Doxyfile
6 |
--------------------------------------------------------------------------------
/docker/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ $# -lt 2 ]]; then
4 | echo "usage: $0 manager image"
5 | exit -1
6 | fi
7 |
8 | if [[ ! -e ../Dockerfile ]]; then
9 | echo "cannot find Dockerfile! this command must be executed under docker/"
10 | exit -1
11 | fi
12 |
13 | manager=$1
14 | shift
15 |
16 | image=$1
17 | shift
18 |
19 | eval "`docker-machine env $manager`"
20 | docker build -t $image ..
21 | docker push $image
--------------------------------------------------------------------------------
/docker/cloud.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [[ $# -lt 6 ]]; then
4 | echo "usage: $0 cloud_provider manager image num_servers num_workers app_conf_file [args...]"
5 | exit -1
6 | fi
7 |
8 | cloud_provider=$1
9 | shift
10 |
11 | manager=$1
12 | shift
13 |
14 | mount="-v /tmp/parameter_server/data/cache:/home/parameter_server/data/cache"
15 | case $cloud_provider in
16 | amazonec2)
17 | mount="$mount -v /var/log/cloud-init.log:/var/log/cloud-init.log"
18 | scheduler_host=`docker-machine ssh $manager cat /var/log/cloud-init.log | awk '/update hostname/ {print $10}'`
19 | ;;
20 | *)
21 | echo "Currently only support amazonec2!"
22 | exit -1
23 | ;;
24 | esac
25 |
26 |
27 | image=$1
28 | shift
29 |
30 | num_servers=$1
31 | shift
32 |
33 | num_workers=$1
34 | shift
35 |
36 | app_file=$1
37 | shift
38 |
39 | args=$@
40 |
41 | # stop all running containers
42 | echo "cleaning previous apps ..."
43 | eval "`docker-machine env --swarm $manager`"
44 | clean_list=`docker ps -q`
45 | docker stop $clean_list > /dev/null 2>&1
46 | echo "update parameter server image cluster wide ..."
47 | docker pull $image
48 |
49 |
50 | # launch scheduler
51 | echo "launching scheduler ..."
52 | eval "`docker-machine env $manager`"
53 | scheduler_port=8000
54 | env="\
55 | -e cloud_provider=$cloud_provider \
56 | -e scheduler_host=$scheduler_host \
57 | -e scheduler_port=$scheduler_port \
58 | -e my_role=SCHEDULER \
59 | -e my_host=$scheduler_host \
60 | -e my_port=$scheduler_port \
61 | -e my_id=H \
62 | -e num_servers=$num_servers \
63 | -e num_workers=$num_workers \
64 | -e app_file=$app_file \
65 | "
66 | docker rm n0 > /dev/null 2>&1
67 | docker run -d -p $scheduler_port:$scheduler_port $env -e "args=$args" --name n0 $image
68 |
69 | #launch servers and workers
70 | echo "launching workers and servers ..."
71 | eval "`docker-machine env --swarm $manager`"
72 | for (( i = 1; i < $num_servers + $num_workers + 1; ++i )); do
73 | my_port=$(($scheduler_port + $i))
74 | if (( $i <= $num_servers )); then
75 | my_role="SERVER"
76 | my_id="S$i"
77 | else
78 | my_role="WORKER"
79 | my_id="W$i"
80 | fi
81 | env="\
82 | -e cloud_provider=$cloud_provider \
83 | -e scheduler_host=$scheduler_host \
84 | -e scheduler_port=$scheduler_port \
85 | -e my_role=$my_role \
86 | -e my_port=$my_port \
87 | -e my_id=$my_id \
88 | -e num_servers=$num_servers \
89 | -e num_workers=$num_workers \
90 | -e app_file=$app_file \
91 | "
92 | docker rm n$i > /dev/null 2>&1
93 | docker run -d -p $my_port:$my_port $env -e "args=$args" $mount --name n$i $image &
94 | done
95 |
96 | wait
97 |
--------------------------------------------------------------------------------
/docker/fire_amazonec2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [[ $# -lt 10 ]]; then
3 | echo "usage: $0 n-instances amazonec2-instance-type spot|regular price-to-bid amazonec2-access-key amazonec2-secret-key amazonec2-region amazonec2-vpc-id amazonec2-security-group amazonec2-zone"
4 | exit -1
5 | fi
6 |
7 | echo "check whether docker is ready ..."
8 | docker version
9 | if [[ $? = 1 ]]; then
10 | echo "please point your docker client to the right docker server!"
11 | exit -1
12 | fi
13 |
14 | if [[ ! `which docker-machine` ]]; then
15 | rm -rf machine
16 | git clone https://github.com/docker/machine
17 | cd machine
18 | # check platform
19 | platform=`uname`
20 | if [[ $platform = 'Darwin' ]];then
21 | platform='darwin'
22 | else
23 | platform='linux'
24 | fi
25 | script/build -osarch="$platform/amd64"
26 | mv docker-machine_$platform* /usr/local/bin/docker-machine
27 | cd ..
28 | rm -rf machine
29 | fi
30 |
31 | n=$1
32 | shift
33 |
34 | amazonec2_instance_type=$1
35 | shift
36 |
37 | if [[ $1 = 'spot' ]]; then
38 | shift
39 | amazonec2_request_spot_instance="--amazonec2-request-spot-instance --amazonec2-spot-price $1"
40 | shift
41 | elif [[ $1 = 'regular' ]]; then
42 | shift
43 | shift
44 | else
45 | echo "Currently only support regular and spot, but get $1."
46 | exit -1
47 | fi
48 |
49 |
50 |
51 | amazonec2_access_key=$1
52 | shift
53 |
54 | amazonec2_secret_key=$1
55 | shift
56 |
57 | amazonec2_region=$1
58 | shift
59 |
60 | amazonec2_vpc_id=$1
61 | shift
62 |
63 | amazonec2_security_group=$1
64 | shift
65 |
66 | amazonec2_zone=$1
67 | shift
68 |
69 | discovery="token://`docker run swarm create`"
70 |
71 | docker-machine create -d amazonec2 --amazonec2-access-key $amazonec2_access_key --amazonec2-secret-key $amazonec2_secret_key --amazonec2-region $amazonec2_region --amazonec2-vpc-id $amazonec2_vpc_id --amazonec2-security-group $amazonec2_security_group --amazonec2-zone $amazonec2_zone --amazonec2-instance-type $amazonec2_instance_type $amazonec2_request_spot_instance --swarm --swarm-master --swarm-discovery $discovery swarm-master &
72 | for (( i = 0; i < n-1; i++ )); do
73 | docker-machine create -d amazonec2 --amazonec2-access-key $amazonec2_access_key --amazonec2-secret-key $amazonec2_secret_key --amazonec2-region $amazonec2_region --amazonec2-vpc-id $amazonec2_vpc_id --amazonec2-security-group $amazonec2_security_group --amazonec2-zone $amazonec2_zone --amazonec2-instance-type $amazonec2_instance_type $amazonec2_request_spot_instance --swarm --swarm-discovery $discovery swarm-node-$i &
74 | done
75 | wait
76 |
77 |
78 |
--------------------------------------------------------------------------------
/docker/get_host_within_container.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [[ $my_host ]]; then
3 | echo "$my_host"
4 | else
5 | case $cloud_provider in
6 | amazonec2)
7 | cat /var/log/cloud-init.log | awk '/update hostname/ {print $10}'
8 | ;;
9 | *)
10 | echo "Currently only support amazonec2!"
11 | exit -1
12 | ;;
13 | esac
14 | fi
15 |
16 |
17 |
--------------------------------------------------------------------------------
/docker/local.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -lt 5 ]; then
4 | echo "usage: $0 num_servers num_workers app_conf_file data_dir model_dir [args...]"
5 | exit -1;
6 | fi
7 |
8 | ip=`/sbin/ifconfig docker0 | grep inet | grep -v inet6 | awk '{print $2}' | sed -e 's/[a-z]*://'`
9 | if [ -z ${ip} ]; then
10 | echo "failed to get the ip address for docker"
11 | exit -1
12 | fi
13 |
14 | num_servers=$1
15 | shift
16 |
17 | num_workers=$1
18 | shift
19 |
20 | app=$1
21 | if [[ "$app" != /* ]]; then
22 | app=`pwd`/$app
23 | fi
24 | shift
25 |
26 | data=$1
27 | if [[ "$data" != /* ]]; then
28 | data=`pwd`/$data
29 | fi
30 | shift
31 |
32 | model=$1
33 | if [[ "$model" != /* ]]; then
34 | model=`pwd`/$model
35 | fi
36 | shift
37 |
38 | port=8000
39 | bin="muli/parameter-server /build/ps"
40 | # bin_v="-v /home/muli/work/ps/build:/build"
41 | app_v="-v $app:/app.conf"
42 | data_v="-v $data:/data -v $model:/model"
43 | mount="$bin_v $app_v $data_v"
44 |
45 | arg="-app_file /app.conf -num_servers $num_servers -num_workers $num_workers $@"
46 |
47 | sch="role:SCHEDULER,hostname:'$ip',port:$port,id:'H'"
48 | for (( i = 0; i < ${num_servers} + ${num_workers} + 1; ++i )); do
49 | myport=$(($port + ${i}))
50 | if (( $i == 0)); then
51 | node=$sch
52 | elif (( $i <= ${num_servers} )); then
53 | node="role:SERVER,hostname:'$ip',port:${myport},id:'S${i}'"
54 | else
55 | node="role:WORKER,hostname:'$ip',port:${myport},id:'W${i}'"
56 | fi
57 | docker run --rm -p $myport:$myport --name n${i} $mount $bin \
58 | -my_node ${node} -scheduler ${sch} ${arg} &
59 | done
60 |
61 | wait
62 |
--------------------------------------------------------------------------------
/docker/rm_local.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | docker rm -f $(docker ps -a -q)
3 |
--------------------------------------------------------------------------------
/docker/shut.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [[ $# -lt 1 ]]; then
3 | echo "usage: $0 n-instances"
4 | exit -1
5 | fi
6 |
7 | n=$1
8 | docker-machine stop swarm-master > /dev/null 2>&1 &
9 | for (( i = 0; i < n-1; i++ )); do
10 | docker-machine stop swarm-node-$i > /dev/null 2>&1 &
11 | done
12 | wait
13 | docker-machine rm -f swarm-master &
14 | for (( i = 0; i < n-1; i++ )); do
15 | docker-machine rm -f swarm-node-$i &
16 | done
17 | wait
18 |
--------------------------------------------------------------------------------
/docker/upload_s3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [[ $# -lt 2 ]]; then
3 | echo "usage: $0 local_file s3_path"
4 | exit -1
5 | fi
6 |
7 | local_file=$1
8 | shift
9 |
10 | if [[ ! -e $local_file ]]; then
11 | echo "$local_file does not exist!"
12 | exit -1
13 | fi
14 |
15 | s3_path=$1
16 | shift
17 |
18 | bucket=`echo $s3_path | cut -d'/' -f3`
19 | path=${s3_path#s3://$bucket}
20 |
21 | length=`stat $local_file | awk '{print $8}'`
22 | curl "http://$bucket.s3.amazonaws.com$path?Content-Length=$length&x-amz-acl=public-read" --upload-file $local_file
23 |
24 |
25 |
--------------------------------------------------------------------------------
/example/linear/README.md:
--------------------------------------------------------------------------------
1 | # Tutorial to run linear method
2 |
3 | **Preparing Data**
4 |
5 | Use the script such as `rcv1/download.sh` and `ctr/download.sh` to prepare
6 | sample data.
7 |
8 |
9 | **Run L1 Logistic Regression on the CTR dataset**
10 |
11 | Let's first start one worker and one server in the local machine to run the online
12 | solver.
13 | ```bash
14 | ../../script/ps.sh start ../../build/linear -app_file ctr/online_l1lr.conf
15 | ```
16 |
17 | Next we use 3 workers and 4 servers:
18 |
19 | ```bash
20 | ../../script/ps.sh start -nw 3 -ns 4 ../../build/linear -app_file ctr/online_l1lr.conf
21 | ```
22 |
23 | We can also start the jobs in multiple machines. Assume there is a `your_hostfile`
24 | containing all machines IPs line by line. Then start the job via
25 | the `-hostfile` option.
26 | ```bash
27 | ../../script/ps.sh start -hostfile your_hostfile -nw 3 -ns 4 ../../build/linear -app_file ctr/online_l1lr.conf
28 | ```
29 |
30 | Finally, we can change the application by the `-app_file` option. For example, Evaluate the
31 | trained model
32 | ```bash
33 | ../../script/ps.sh start ../../build/linear -app_file ctr/eval_online.conf
34 | ```
35 | or train by the batch solver
36 | ```bash
37 | ../../script/ps.sh start ../../build/linear -app_file ctr/batch_l1lr.conf
38 | ```
39 |
40 | See more information about the configuration file in [linear.proto](../../src/app/linear_method/proto/linear.proto)
41 |
42 |
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
/example/linear/criteo/batch_l1lr.conf:
--------------------------------------------------------------------------------
1 | training_data {
2 | format: TEXT
3 | text: CRITEO
4 | file: "data/criteo/train/part.*"
5 |
6 | # If the data is placed on hdfs and HADOOP_HOME="/usr"
7 | # hdfs {
8 | # home: "/usr"
9 | # }
10 | }
11 |
12 | model_output {
13 | format: TEXT
14 | file: "model/criteo_batch"
15 | }
16 |
17 | loss {
18 | type: LOGIT
19 | }
20 |
21 | # lambda * \| w \|_1
22 | penalty {
23 | type: L1
24 | lambda: 4
25 | lambda: 1
26 | }
27 |
28 | learning_rate {
29 | type: CONSTANT
30 | alpha: .9
31 | }
32 |
33 | darlin {
34 | # the max number of data passes
35 | max_pass_of_data : 50
36 | # convergence criterion: stop if the relative objective <= epsilon
37 | epsilon : 2e-5
38 |
39 | save_model_every_n_iter: 20
40 |
41 | # The maximal number of blocks that can be updated in parallel (bounded-delay
42 | # consistency). A larger delay may slow down the convergence rate, but improves
43 | # the system performance.
44 | max_block_delay: 2
45 |
46 | # features which occur <= *tail_feature_freq* times will be filtered before
47 | # training. It saves both memory and bandwidth.
48 | tail_feature_freq: 4
49 |
50 | # It controls the countmin size. We filter the tail features by countmin, which
51 | # is more efficient than hash, but still is the memory bottleneck for servers. A
52 | # smaller ratio reduces the memory footprint, but may increase the size of
53 | # filtered feature.
54 |
55 | countmin_n_ratio: .66
56 |
57 | # During preprocessing, each (text) file is parsed and then written into the local
58 | # cache in binary format to save the memory. These data are then used by the
59 | # preprocessing stage, and also can be re-used when running next time.
60 | local_cache {
61 | format: BIN
62 | file: "data/cache/criteo_train_"
63 | }
64 |
65 | comm_filter {
66 | type: KEY_CACHING
67 | }
68 |
69 | # load_local_data: true
70 |
71 | # Parameters used by the trust region method. The change of w_i (the i-th
72 | # parameter) is bounded by [-delta_i, delta_i], where delta_i is an adaptive
73 | # value according to the convergence. The initial value of delta_i is
74 | # *delta_init_value* and maximal value is *delta_max_value*. You can increase
75 | # these parameters for easy datasets.
76 |
77 | # [PS.LM.delta_init_value] : 1
78 | # [PS.LM.delta_max_value] : 5
79 |
80 | # This parameter controls the aggressiveness of the KKT filter. Increasing this
81 | # number will decrease the effect of KKT filter. a very large number, such as
82 | # 1e20 will turn off the KKT filter.
83 |
84 | # [PS.LM.kkt_filter_threshold_ratio] : 10
85 | }
86 |
--------------------------------------------------------------------------------
/example/linear/criteo/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | dir=`dirname "$0"`
4 | cd $dir/../data
5 |
6 |
7 | if [ ! -f train.txt ]; then
8 | if [ ! -f dac.tar.gz ]; then
9 | wget https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
10 | fi
11 | tar -zxvf dac.tar.gz
12 | fi
13 |
14 | echo "split train.txt..."
15 | mkdir -p criteo/train
16 | split -n l/18 --numeric-suffixes=1 --suffix-length=3 train.txt criteo/train/part-
17 |
18 | echo "make a test set"
19 | mkdir -p criteo/test
20 | mv criteo/train/part-01[7-8] criteo/test
21 |
--------------------------------------------------------------------------------
/example/linear/criteo/eval_batch.conf:
--------------------------------------------------------------------------------
1 | validation_data {
2 | format: TEXT
3 | text: CRITEO
4 | file: "data/criteo/test/part.*"
5 | }
6 |
7 | model_input {
8 | format: TEXT
9 | file: "model/criteo_batch_S.*"
10 | }
11 |
--------------------------------------------------------------------------------
/example/linear/ctr/batch_l1lr.conf:
--------------------------------------------------------------------------------
1 | training_data {
2 | format: TEXT
3 | text: SPARSE_BINARY
4 | file: "data/ctr/train/part.*"
5 |
6 | # If the data is placed on hdfs and HADOOP_HOME="/usr"
7 | # hdfs {
8 | # home: "/usr"
9 | # }
10 | }
11 |
12 | model_output {
13 | format: TEXT
14 | file: "model/ctr_batch"
15 | }
16 |
17 | loss {
18 | type: LOGIT
19 | }
20 |
21 | # lambda * \| w \|_1
22 | penalty {
23 | type: L1
24 | lambda: 10
25 | }
26 |
27 | learning_rate {
28 | type: CONSTANT
29 | alpha: 1
30 | }
31 |
32 | darlin {
33 | # the max number of data passes
34 | max_pass_of_data : 20
35 | # convergence criterion: stop if the relative objective <= epsilon
36 | epsilon : 2e-5
37 |
38 | # Features are partitioned into groups and each time only one group is
39 | # updated. We divide a feature group into
40 | # feature_block_ratio * nnz_feature_per_example
41 | # blocks. A larger ratio often accelerates the convergence; however, it may slow
42 | # down the system performance because of the increased number of global barriers.
43 | feature_block_ratio : 4
44 |
45 | # The maximal number of blocks that can be updated in parallel (bounded-delay
46 | # consistency). A larger delay may slow down the convergence rate, but improves
47 | # the system performance.
48 | max_block_delay: 0
49 |
50 | # important feature groups, update them earlier to get a better model
51 | # initialization.
52 | prior_fea_group: 127
53 | prior_fea_group: 120
54 |
55 | # features which occur <= *tail_feature_freq* times will be filtered before
56 | # training. It saves both memory and bandwidth.
57 | tail_feature_freq: 4
58 |
59 | # It controls the countmin size. We filter the tail features by countmin, which
60 | # is more efficient than hash, but still is the memory bottleneck for servers. A
61 | # smaller ratio reduces the memory footprint, but may increase the size of
62 | # filtered feature.
63 |
64 | countmin_n_ratio: .66
65 |
66 | # In preprocessing, feature group is processed one by one. It is the main memory
67 | # bottleneck for workers. This number controls how many feature groups can be in
68 | # memory at the same time. A smaller number reduces the workers' memory
69 | # footprint, but may slow down the preprocessing speed.
70 |
71 | # max_num_parallel_groups_in_preprocessing: 1000
72 |
73 | # During preprocessing, each (text) file is parsed and then written into the local
74 | # cache in binary format to save the memory. These data are then used by the
75 | # preprocessing stage, and also can be re-used when running next time.
76 | local_cache {
77 | format: BIN
78 | file: "data/cache/ctr_train_"
79 | }
80 |
81 | # Parameters used by the trust region method. The change of w_i (the i-th
82 | # parameter) is bounded by [-delta_i, delta_i], where delta_i is an adaptive
83 | # value according to the convergence. The initial value of delta_i is
84 | # *delta_init_value* and maximal value is *delta_max_value*. You can increase
85 | # these parameters for easy datasets.
86 |
87 | # [PS.LM.delta_init_value] : 1
88 | # [PS.LM.delta_max_value] : 5
89 |
90 | # This parameter controls the aggressiveness of the KKT filter. Increasing this
91 | # number will decrease the effect of KKT filter. a very large number, such as
92 | # 1e20 will turn off the KKT filter.
93 |
94 | # [PS.LM.kkt_filter_threshold_ratio] : 10
95 | }
96 |
--------------------------------------------------------------------------------
/example/linear/ctr/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | dir=`dirname "$0"`
4 | git clone https://github.com/mli/ctr-data $dir/../data/ctr
5 |
--------------------------------------------------------------------------------
/example/linear/ctr/eval_batch.conf:
--------------------------------------------------------------------------------
1 | validation_data {
2 | format: TEXT
3 | text: SPARSE_BINARY
4 | file: "data/ctr/test/part.*"
5 | }
6 |
7 | model_input {
8 | format: TEXT
9 | file: "model/ctr_batch.*"
10 | }
11 |
--------------------------------------------------------------------------------
/example/linear/ctr/eval_online.conf:
--------------------------------------------------------------------------------
1 | validation_data {
2 | format: TEXT
3 | # text: SPARSE_BINARY
4 | # file: "data/ctr/test/part.*"
5 |
6 | text: ADFEA
7 | file: "/home/muli/work/data/ctrc-rand/part-05.*"
8 | }
9 |
10 | model_input {
11 | format: TEXT
12 | file: "model/ctr_online.*"
13 | }
14 |
--------------------------------------------------------------------------------
/example/linear/ctr/online_l1lr.conf:
--------------------------------------------------------------------------------
1 | training_data {
2 | format: TEXT
3 | text: SPARSE_BINARY
4 | file: "data/ctr/train/part.*"
5 | }
6 |
7 | model_output {
8 | format: TEXT
9 | file: "model/ctr_online"
10 | }
11 |
12 | loss { type: LOGIT }
13 |
14 | # lambda_0 * |w|_1 + lambda_1 * |w|^2_2
15 | penalty {
16 | type: L1
17 | lambda: 10
18 | lambda: 1
19 | }
20 |
21 | # lr = alpha: .1
22 | learning_rate {
23 | type: DECAY
24 | alpha: .01
25 | beta: 10
26 | }
27 |
28 | # see more config options in linear.proto
29 | async_sgd {
30 | algo: FTRL
31 | # The size of minibatch
32 | minibatch: 10000
33 | # The number of data passes
34 | num_data_pass: 10
35 |
36 | push_filter {
37 | type: KEY_CACHING
38 | clear_cache_if_done: true
39 | }
40 |
41 | push_filter {
42 | type: FIXING_FLOAT
43 | num_bytes: 1
44 | }
45 |
46 | pull_filter {
47 | type: KEY_CACHING
48 | }
49 |
50 | pull_filter {
51 | type: FIXING_FLOAT
52 | num_bytes: 1
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/example/linear/rcv1/batch_l1lr.conf:
--------------------------------------------------------------------------------
1 | training_data {
2 | format: TEXT
3 | text: LIBSVM
4 | file: "data/rcv1/train/part.*"
5 | }
6 |
7 | model_output {
8 | format: TEXT
9 | file: "model/rcv1_batch_l1lr"
10 | }
11 |
12 | loss {
13 | type: LOGIT
14 | }
15 |
16 | # lambda * |w|_1
17 | penalty {
18 | type: L1
19 | lambda: 1
20 | }
21 |
22 | learning_rate {
23 | type: CONSTANT
24 | alpha: 1
25 | }
26 |
27 | darlin {
28 | # max number of passes of training data
29 | max_pass_of_data : 20
31 | # convergence criterion: stop if the relative objective <= epsilon
31 | epsilon : 1e-4
32 |
33 | # temp data
34 | local_cache {
35 | format: BIN
36 | file: "data/cache/rcv1_train"
37 | }
38 |
39 | # debug
40 | # feature_block_ratio : .0001
41 | # tail_feature_freq: 1
42 | # [PS.LM.kkt_filter_threshold_ratio] : 1e20
43 | }
44 |
--------------------------------------------------------------------------------
/example/linear/rcv1/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | dir=`dirname "$0"`
4 | cd $dir/../data
5 |
6 | # download
7 | if ! [ -e rcv1_train.binary ]; then
8 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_train.binary.bz2
9 | bunzip2 rcv1_train.binary.bz2
10 | fi
11 |
12 | if ! [ -e rcv1_test.binary ]; then
13 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_test.binary.bz2
14 | bunzip2 rcv1_test.binary.bz2
15 | fi
16 |
17 |
18 | name=(train test)
19 | for t in "${name[@]}"
20 | do
21 | echo $t
22 | # shuffle
23 | rnd=rcv1_${t}_rand
24 | shuf rcv1_${t}.binary >$rnd
25 |
26 | # split
27 | mkdir -p rcv1/${t}
28 | rm -f rcv1/${t}/*
29 | split -n l/8 --numeric-suffixes=1 --suffix-length=3 $rnd rcv1/${t}/part-
30 | rm $rnd
31 | done
32 |
33 | # swap train and test
34 | mv rcv1/train tmp
35 | mv rcv1/test rcv1/train
36 | mv tmp rcv1/test
37 |
--------------------------------------------------------------------------------
/example/linear/rcv1/eval_batch.conf:
--------------------------------------------------------------------------------
1 | validation_data {
2 | format: TEXT
3 | text: LIBSVM
4 | file: "data/rcv1/test/part.*"
5 | }
6 |
7 | model_input {
8 | format: TEXT
9 | file: "model/rcv1_batch.*"
10 | }
11 |
--------------------------------------------------------------------------------
/example/linear/rcv1/eval_online.conf:
--------------------------------------------------------------------------------
1 | validation_data {
2 | format: TEXT
3 | text: LIBSVM
4 | file: "data/rcv1/test/part.*"
5 | }
6 |
7 | model_input {
8 | format: TEXT
9 | file: "model/rcv1_online.*"
10 | }
11 |
--------------------------------------------------------------------------------
/example/linear/rcv1/online_l1lr.conf:
--------------------------------------------------------------------------------
1 | training_data {
2 | format: TEXT
3 | text: LIBSVM
4 | file: "data/rcv1/train/part.*"
5 | }
6 |
7 | model_output {
8 | format: TEXT
9 | file: "model/rcv1_online"
10 | }
11 |
12 | loss {
13 | type: LOGIT
14 | }
15 |
16 | # lambda * |w|_1
17 | penalty {
18 | type: L1
19 | lambda: 1
20 | }
21 |
22 | learning_rate {
23 | type: DECAY
24 | alpha: 1
25 | beta: 1
26 | }
27 |
28 | async_sgd {
29 | algo: FTRL
30 | minibatch : 1000
31 | }
32 |
--------------------------------------------------------------------------------
/make/README.md:
--------------------------------------------------------------------------------
1 | # Build the Parameter Server
2 |
3 | **Requirement**
4 |
5 | The parameter server needs a C++ compiler supporting c++11, such as `gcc` >=
6 | 4.7.2 (prefer >= 4.8) or `llvm` >= 3.4.
7 | You can update `gcc` via either downloading
8 | packages, e.g. [centos](http://linux.web.cern.ch/linux/devtoolset/),
9 | [ubuntu](http://ubuntuhandbook.org/index.php/2013/08/install-gcc-4-8-via-ppa-in-ubuntu-12-04-13-04/),
10 | [mac os x](http://hpc.sourceforge.net/), or building from source, such as for
11 | [centos](http://www.codersvoice.com/a/webbase/install/08/202014/131.html).
12 |
13 | **Build the Parameter Server**
14 |
15 | Assume `git` is installed:
16 |
17 | ```bash
18 | git clone https://github.com/dmlc/parameter_server -b dev
19 | cd parameter_server
20 | ./script/install_third.sh
21 | make -j8
22 | ```
23 |
24 | **Customized Building**
25 |
26 | You can modify [config.mk](config.mk) to customize the building. You can copy
27 | this file to the upper directory so that the changes will be ignored by git.
28 |
--------------------------------------------------------------------------------
/make/config.mk:
--------------------------------------------------------------------------------
1 | # default configuration of make
2 | #
3 | # you can copy it to the parent directory and modify it as you want. then
4 | # compile by `make -j 8` using 8 threads
5 |
6 | # compiler
7 | CC = g++
8 |
9 | # optimization flag. -O0 -ggdb for debug
10 | OPT = -O3 -ggdb
11 |
12 | # statically link all dependent libraries, such as gflags, zeromq, if
13 | # 1. otherwise use dynamic linking
14 | STATIC_THIRD_LIB = 0
15 |
16 | # the installed path of third party libraries
17 | THIRD_PATH = $(shell pwd)/third_party
18 |
19 | # additional link flags, such as -ltcmalloc_and_profiler
20 | EXTRA_LDFLAGS =
21 |
22 | # additional compile flags
23 | EXTRA_CFLAGS =
24 |
25 | # io option
26 | USE_S3 = 0
27 |
28 | all: ps build/linear
29 |
--------------------------------------------------------------------------------
/script/get_root_node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [ $# -ne 2 ]; then
3 | echo "usage: ./self interface port"
4 | exit -1;
5 | fi
6 |
7 | ip=`/sbin/ifconfig ${1} | grep inet | grep -v inet6 | awk '{print $2}' | sed -e 's/[a-z]*:/''/'`
8 | if [ -z ${ip} ]; then
9 | echo "failed to get the ip address"
10 | exit -1
11 | fi
12 |
13 | echo "role:SCHEDULER,hostname:'${ip}',port:${2},id:'H'"
14 |
--------------------------------------------------------------------------------
/script/install_third.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | dir=`dirname "$0"`
3 | cd $dir/..
4 | git clone https://github.com/mli/third_party
5 | cd third_party
6 | ./install.sh
7 |
--------------------------------------------------------------------------------
/script/kill_node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [ $# -ne 1 ]; then
3 | echo "usage: $0 node_id"
4 | exit -1;
5 | fi
6 |
7 | out=/tmp/aux
8 | ps aux >$out
9 | pid=`grep \'$1\'.*scheduler $out | awk '{print $2}'`
10 | if [ ! -z "$pid" ]; then
11 | kill -9 $pid
12 | fi
13 | rm -f $out
14 |
--------------------------------------------------------------------------------
/script/local.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # set -x
3 | dir=`dirname "$0"`
4 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$dir/../third_party/lib
5 | if [ $# -lt 3 ]; then
6 | echo "usage: $0 num_servers num_workers bin [args..]"
7 | exit -1;
8 | fi
9 |
10 | num_servers=$1
11 | shift
12 | num_workers=$1
13 | shift
14 | bin=$1
15 | shift
16 | arg="-num_servers ${num_servers} -num_workers ${num_workers} -log_dir log $@" #" -app ${dir}/$@"
17 |
18 |
19 | # killall -q $(basename ${bin})
20 | # killall -q ${bin}
21 | HEAPCHECK=draconian
22 | # start the scheduler
23 | Sch="role:SCHEDULER,hostname:'127.0.0.1',port:8001,id:'H'"
24 | ${bin} -my_node ${Sch} -scheduler ${Sch} ${arg} &
25 |
26 | # start servers
27 | for ((i=0; i<${num_servers}; ++i)); do
28 | port=$((9600 + ${i}))
29 | N="role:SERVER,hostname:'127.0.0.1',port:${port},id:'S${i}'"
30 | # HEAPPROFILE=/tmp/S${i} \
31 | # CPUPROFILE=/tmp/S${i} \
32 | ${bin} -my_node ${N} -scheduler ${Sch} ${arg} &
33 | done
34 |
35 | # start workers
36 | for ((i=0; i<${num_workers}; ++i)); do
37 | port=$((9500 + ${i}))
38 | N="role:WORKER,hostname:'127.0.0.1',port:${port},id:'W${i}'"
39 | # HEAPPROFILE=/tmp/W${i} \
40 | # CPUPROFILE=/tmp/W${i} \
41 | ${bin} -my_node ${N} -scheduler ${Sch} ${arg} &
42 | done
43 |
44 | wait
45 |
--------------------------------------------------------------------------------
/script/mpi_node.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # set -x
3 | if [ $# -ne 2 ]; then
4 | echo $#
5 | echo "usage: ./self scheduler_node mpi.conf"
6 | exit -1;
7 | fi
8 |
9 | # support mpich and openmpi
10 | # try mpirun -n 1 env to get all available environment
11 | if [ ! -z ${PMI_RANK} ]; then
12 | my_rank=${PMI_RANK}
13 | elif [ ! -z ${OMPI_COMM_WORLD_RANK} ]; then
14 | my_rank=${OMPI_COMM_WORLD_RANK}
15 | else
16 | echo "failed to get my rank id"
17 | exit -1
18 | fi
19 |
20 | if [ ! -z ${PMI_SIZE} ]; then
21 | rank_size=${PMI_SIZE}
22 | elif [ ! -z ${OMPI_COMM_WORLD_SIZE} ]; then
23 | rank_size=${OMPI_COMM_WORLD_SIZE}
24 | else
25 | echo "failed to get the rank size"
26 | exit -1
27 | fi
28 |
29 | source ${2}
30 |
31 | if (( ${rank_size} == 0 )); then
32 | root_node=`${dir}/get_root_node.sh ${network_interface} ${network_port}`
33 | if [${root_node} -ne ${my_node}]; then
34 | echo "start ./mpi_root.sh on the first machine in your hostfile"
35 | exit -1
36 | fi
37 | fi
38 |
39 | if (( ${rank_size} < ${num_workers} + ${num_servers} + 1 )); then
40 | echo "too small rank size ${rank_size}"
41 | exit -1
42 | fi
43 |
44 | # mkdir -p ${3}/../output
45 | # -num_threads ${num_threads} \
46 | ${bin} \
47 | -num_servers ${num_servers} \
48 | -num_workers ${num_workers} \
49 | -scheduler ${1} \
50 | -my_rank ${my_rank} \
51 | -interface ${network_interface} \
52 | ${arg}
53 |
54 | exit $?
55 | # echo "rank:${my_rank} launch failed"; exit -1;
56 |
--------------------------------------------------------------------------------
/script/mpi_root.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # set -x
3 |
4 | if [ $# -ne 1 ]; then
5 | echo "usage: ./self mpi.conf"
6 | exit -1;
7 | fi
8 |
9 | conf=${1}
10 | source ${conf}
11 |
12 | # mpirun=${dir}/mpirun
13 |
14 | dir=`dirname "$0"`
15 | root_node=`${dir}/get_root_node.sh ${network_interface} ${network_port}`
16 | np=$((${num_workers} + ${num_servers} + 1))
17 |
18 | if [ ! -z ${hostfile} ]; then
19 | hostfile="-hostfile ${hostfile}"
20 | fi
21 |
22 | # mpirun ${hostfile} killall -q ps
23 | # mpirun ${hostfile} md5sum ../bin/ps
24 |
25 | mpirun ${hostfile} -np ${np} ${dir}/mpi_node.sh ${root_node} ${conf}
26 |
--------------------------------------------------------------------------------
/src/README.mk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmlc/parameter_server/2feb6f4224188e11f63b0843ea76ade4003a60f3/src/README.mk
--------------------------------------------------------------------------------
/src/app/linear_method/learning_rate.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "app/linear_method/proto/linear.pb.h"
3 | namespace PS {
4 | namespace LM {
5 |
6 | template <typename V>
7 | class LearningRate {
8 | public:
9 | LearningRate(const LearningRateConfig& conf) : conf_(conf) {
10 | CHECK_GT(alpha(), 0);
11 | CHECK_GE(beta(), 0);
12 | }
13 | ~LearningRate() { }
14 |
15 | V eval(V x = 0) const {
16 | if (conf_.type() == LearningRateConfig::CONSTANT) {
17 | // if (x == 0 && beta() == 0) {
18 | return alpha();
19 | } else {
20 | return alpha() / ( x + beta() );
21 | }
22 | }
23 |
24 | V alpha() const { return conf_.alpha(); }
25 | V beta() const { return conf_.beta(); }
26 | private:
27 | // const LearningRateConfig& conf() { return conf_; }
28 | LearningRateConfig conf_;
29 | };
30 |
31 |
32 | } // namespace LM
33 | } // namespace PS
34 |
--------------------------------------------------------------------------------
/src/app/linear_method/main.cc:
--------------------------------------------------------------------------------
1 | #include "ps.h"
2 | #include "app/linear_method/async_sgd.h"
3 | #include "app/linear_method/darlin.h"
4 | #include "app/linear_method/model_evaluation.h"
5 |
6 | namespace PS {
7 | App* App::Create(const string& conf_str) {
8 | using namespace LM;
9 | // parse config
10 | Config conf;
11 | CHECK(google::protobuf::TextFormat::ParseFromString(conf_str, &conf))
12 | << " failed to parse conf: " << conf.ShortDebugString();
13 |
14 | // create app
15 | auto my_role = MyNode().role();
16 | App* app = nullptr;
17 | if (conf.has_darlin()) {
18 | if (my_role == Node::SCHEDULER) {
19 | app = new DarlinScheduler(conf);
20 | } else if (my_role == Node::WORKER) {
21 | app = new DarlinWorker(conf);
22 | } else if (my_role == Node::SERVER) {
23 | app = new DarlinServer(conf);
24 | }
25 | } else if (conf.has_async_sgd()) {
26 | typedef float Real;
27 | if (my_role == Node::SCHEDULER) {
28 | app = new AsyncSGDScheduler(conf);
29 | } else if (my_role == Node::WORKER) {
30 | app = new AsyncSGDWorker(conf);
31 | } else if (my_role == Node::SERVER) {
32 | app = new AsyncSGDServer(conf);
33 | }
34 | } else if (conf.has_validation_data()) {
35 | app = new ModelEvaluation(conf);
36 | }
37 | CHECK(app) << "fail to create " << conf.ShortDebugString()
38 | << " at " << MyNode().ShortDebugString();
39 | return app;
40 | }
41 | } // namespace PS
42 |
43 | int main(int argc, char *argv[]) {
44 | PS::RunSystem(argc, argv);
45 | return 0;
46 | }
47 |
--------------------------------------------------------------------------------
/src/app/linear_method/model_evaluation.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "system/customer.h"
3 | #include "data/stream_reader.h"
4 | #include "util/evaluation.h"
5 | namespace PS {
6 |
7 | #if USE_S3
8 | bool s3file(const std::string& name);
9 | std::string s3Prefix(const std::string& path);
10 | std::string s3Bucket(const std::string& path);
11 | std::string s3FileUrl(const std::string& path);
12 | #endif // USE_S3
13 |
14 |
15 | namespace LM {
16 |
17 | class ModelEvaluation : public App {
18 | public:
19 | ModelEvaluation(const Config& conf) : App(), conf_(conf) { }
20 | virtual ~ModelEvaluation() { }
21 | virtual void Run();
22 | private:
23 | typedef float Real;
24 | Config conf_;
25 | };
26 |
27 | void ModelEvaluation::Run() {
28 | if (!IsScheduler()) return;
29 | // load model
30 | std::unordered_map<uint64, Real> weight;
31 | auto model = searchFiles(conf_.model_input());
32 | NOTICE("find %d model files", model.file_size());
33 | for (int i = 0; i < model.file_size(); ++i) {
34 | #if USE_S3
35 | std::ifstream in;
36 | if (s3file(model.file(i))) {
37 | // download file from s3
38 | std::string cmd="curl -s -o model_file "+s3FileUrl(model.file(i));
39 | LOG(INFO)<<cmd;
40 | system(cmd.c_str());
41 | in.open("model_file");
42 | } else {
43 | in.open(model.file(i));
44 | }
45 | #else
46 | std::ifstream in(model.file(i));
47 | #endif // USE_S3
48 | while (in.good()) {
49 | uint64 k;
50 | Real v;
51 | in >> k >> v;
52 | weight[k] = v;
53 | }
54 | }
55 | #if USE_S3
56 | // remove local model after read done
57 | std::string cmd="rm -rf model_file";
58 | system(cmd.c_str());
59 | #endif // USE_S3
60 |
61 | NOTICE("load %lu model entries", weight.size());
62 |
63 | // load evaluation data and compute the predicted value
64 | auto data = searchFiles(conf_.validation_data());
65 | data.set_ignore_feature_group(true);
66 | NOTICE("find %d data files", data.file_size());
67 |
68 | SArray<Real> label;
69 | SArray<Real> predict;
70 | MatrixPtrList<Real> mat;
71 | StreamReader<Real> reader(data);
72 | // TODO read in an another thread
73 | bool good = false;
74 | do {
75 | good = reader.readMatrices(100000, &mat);
76 | CHECK_EQ(mat.size(), 2);
77 | label.append(mat[0]->value());
78 |
79 | SArray<Real> Xw(mat[1]->rows()); Xw.SetZero();
80 | auto X = std::static_pointer_cast<SparseMatrix<uint64, Real>>(mat[1]);
81 | for (int i = 0; i < X->rows(); ++i) {
82 | Real re = 0;
83 | for (size_t j = X->offset()[i]; j < X->offset()[i+1]; ++j) {
84 | // TODO build a bloom filter
85 | auto it = weight.find(X->index()[j]);
86 | if (it != weight.end()) {
87 | re += it->second * (X->binary() ? 1 : X->value()[j]);
88 | }
89 | }
90 | Xw[i] = re;
91 | }
92 | predict.append(Xw);
93 | printf("\r \r");
94 | printf(" load %lu examples", label.size());
95 | fflush(stdout);
96 | } while (good);
97 | printf("\n");
98 |
99 | // label.writeToFile("label");
100 | // predict.writeToFile("predict");
101 |
102 | // evaluation
103 |
104 | NOTICE("auc: %f", Evaluation::auc(label, predict));
105 | NOTICE("accuracy: %f", Evaluation::accuracy(label, predict));
106 | NOTICE("logloss: %f", Evaluation::logloss(label, predict));
107 | }
108 |
109 | } // namespace LM
110 | } // namespace PS
111 |
--------------------------------------------------------------------------------
/src/app/linear_method/penalty.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "util/matrix.h"
4 | #include "app/linear_method/proto/linear.pb.h"
5 | namespace PS {
6 | namespace LM {
7 |
8 | /**
9 | * @brief Interface for the penalty
10 | */
11 | template <typename T> class Penalty {
12 | public:
13 | Penalty() { }
14 | virtual ~Penalty() { }
15 | /**
16 | * @brief evaluate the objective
17 | *
18 | * @param model
19 | *
20 | * @return objective value
21 | */
22 | virtual T eval(const MatrixPtr& model) = 0;
23 |
24 | /**
25 | * @brief Solve the proximal operator
26 | *
27 | * \f$ \argmin_x 0.5/\eta (x - z)^2 + h(x)\f$, where h denote this penatly, and in
28 | * proximal gradient descent, z = w - eta * grad
29 | *
30 | * @param z
31 | * @param eta
32 | * @return
33 | */
34 | virtual T proximal(T z, T eta) = 0;
35 | };
36 |
37 | /**
38 | * @brief \f$ \lambda_1 * \|x\|_1 + \lambda_2 * \|x\|_2^2 \f$
39 | */
40 | template <typename T>
41 | class ElasticNet : public Penalty {
42 | public:
43 | ElasticNet(T lambda1, T lambda2) : lambda1_(lambda1), lambda2_(lambda2) {
44 | CHECK_GE(lambda1, 0);
45 | CHECK_GE(lambda2, 0);
46 | }
47 | ~ElasticNet() { }
48 |
49 | T eval(const MatrixPtr& model) { return 0; } // TODO
50 |
51 | T proximal(T z, T eta) {
52 | CHECK_GT(eta, 0);
53 | T leta = lambda1_ * eta;
54 | if (z <= leta && z >= -leta) return 0;
55 | return z > 0 ? (z - leta) / ( 1 + lambda2_ * eta) : (z + leta) / ( 1 + lambda2_ * eta);
56 | }
57 | private:
58 | T lambda1_, lambda2_;
59 | };
60 |
61 |
62 | // template <typename T>
63 | // class L2 : public Penalty<T> {
64 | // public:
65 | // L2(T lambda) : lambda_(lambda) { }
66 | // CHECK_GE(lambda, 0);
67 | // }
68 | // ~L2() { }
69 | // private:
70 | // T evaluate(const MatrixPtr& model) { return 0; } // TODO
71 | // T proximal(T z, T eta) {
72 | // }
73 | // T lambda_;
74 | // };
75 |
76 | template <typename T>
77 | Penalty<T>* createPenalty(const PenaltyConfig& conf) {
78 | CHECK_GE(conf.lambda_size(), 1);
79 | switch (conf.type()) {
80 | case PenaltyConfig::L1: {
81 | T l1 = conf.lambda(0);
82 | T l2 = conf.lambda_size() > 1 ? conf.lambda(1) : 0;
83 | return new ElasticNet<T>(l1, l2);
84 | }
85 | case PenaltyConfig::L2:
86 | return new ElasticNet<T>(0, conf.lambda(0));
87 | default:
88 | CHECK(false) << "unknown type: " << conf.DebugString();
89 | }
90 | return nullptr;
91 | }
92 |
93 | } // namespace LM
94 | } // namespace PS
95 |
96 | // // lambda * ||w||_p^P = lambda * \sum_i w_i^p
97 | // // TODO infinity
98 | // template <typename T>
99 | // class PNormPenalty : public Penalty<T> {
100 | // public:
101 | // PNormPenalty(T p, T lambda) : p_(p), lambda_(lambda) {
102 | // CHECK_GE(p_, 0);
103 | // CHECK_GE(lambda_, 0);
104 | // }
105 | // bool smooth() { return p_ > 1; }
106 |
107 | // T evaluate(const MatrixPtr& model) {
108 | // auto w = model->value().EigenArray();
109 | // return lambda_ * pow(w.abs(), p_).sum();
110 | // }
111 |
112 | // T lambda() { return lambda_; }
113 | // T p() { return p_; }
114 | // private:
115 | // T p_;
116 | // T lambda_;
117 | // };
118 |
--------------------------------------------------------------------------------
/src/app/linear_method/proto/linear.proto:
--------------------------------------------------------------------------------
1 | // configuration of linear methods
2 | package PS.LM;
3 | import "data/proto/data.proto";
4 | import "learner/proto/bcd.proto";
5 | import "filter/proto/filter.proto";
6 |
7 | message Config {
8 | optional DataConfig training_data = 1;
9 | optional DataConfig validation_data = 2;
10 |
11 | optional DataConfig model_output = 4;
12 | optional DataConfig model_input = 5;
13 |
14 | optional LossConfig loss = 10;
15 | optional PenaltyConfig penalty = 11;
16 |
17 | optional LearningRateConfig learning_rate = 12;
18 |
19 | optional SGDConfig async_sgd = 17;
20 | optional BCDConfig darlin = 15;
21 |
22 | }
23 |
24 | extend BCDConfig {
25 | // Used by the trust region method. All changes of parameters will be bounded
26 | // by *delta*. *delta* is updated according to the convergence, whose initial
27 | // value is *delta_init_value* and maximal value is *delta_max_value*
28 | optional double delta_init_value = 101 [default = 1];
29 | optional double delta_max_value = 102 [default = 5];
30 | // kkt_filter_threshold = max_gradient_violation / num_examples *
31 | // kkt_filter_threshold_ratio. increasing this number reduces the effect of
32 | // kkt filter.
33 | optional double kkt_filter_threshold_ratio = 103 [default = 10];
34 | }
35 |
36 | message SGDConfig {
37 | enum Algo {
38 | STANDARD = 1;
39 | FTRL = 2;
40 | }
41 | required Algo algo = 1;
42 |
43 | // The size of minibatch
44 | optional int32 minibatch = 2 [default = 1000];
45 |
46 | optional int32 data_buf = 12 [default = 1000]; // in mb
47 |
48 | optional bool ada_grad = 5 [default = true];
49 |
50 | optional int32 max_delay = 4 [default = 0];
51 |
52 | // The number of data passes
53 | optional int32 num_data_pass = 11 [default = 1];
54 |
55 | // in sec
56 | optional int32 report_interval = 3 [default = 1];
57 |
58 | // features which occurs <= *tail_feature_freq* will be filtered before
59 | // training. it saves both memory and bandwidth.
60 | optional int32 tail_feature_freq = 6 [default = 0];
61 |
62 | // It controls the countmin size. We filter the tail features by countmin, which
63 | // is more efficient than hash, but still is the memory bottleneck for servers. A
64 | // smaller ratio reduces the memory footprint, but may increase the size of
65 | // filtered feature.
66 | optional float countmin_n = 8 [default = 1e8];
67 | optional int32 countmin_k = 7 [default = 2];
68 |
69 | repeated FilterConfig push_filter = 13;
70 | repeated FilterConfig pull_filter = 14;
71 | }
72 |
73 | message LossConfig {
74 | enum Type {
75 | SQUARE = 1;
76 | LOGIT = 2;
77 | HINGE = 3;
78 | SQUARE_HINGE = 4;
79 | }
80 | required Type type = 1;
81 | }
82 |
83 | message PenaltyConfig {
84 | enum Type {
85 | L1 = 1; // lambda(0) * ||w||_1 + lambda(1) * ||w||_F^2
86 | L2 = 2; // lambda(0) * ||w||_F^2
87 | }
88 | required Type type = 1;
89 | repeated double lambda = 2;
90 | }
91 |
92 |
93 | message LearningRateConfig {
94 | enum Type {
95 | CONSTANT = 1; // = alpha
96 | DECAY = 2; // = alpha / (beta + x), where x is user-defined, such as sqrt(iter)
97 | }
98 | optional Type type = 1;
99 | optional double alpha = 2;
100 | optional double beta = 3;
101 | }
102 |
--------------------------------------------------------------------------------
/src/data/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "data/proto/example.pb.h"
4 | #include "data/proto/data.pb.h"
5 | #include "util/proto/matrix.pb.h"
6 |
7 | namespace PS {
8 |
9 | DECLARE_string(input);
10 | DECLARE_string(output);
11 | DECLARE_string(format);
12 |
13 | // return files matches the regex in *config*
14 | DataConfig searchFiles(const DataConfig& config);
15 |
16 | // evenly partition the files into *num* parts
17 | std::vector<DataConfig> divideFiles(const DataConfig& data, int num);
18 |
19 | // locate the i-th file in *conf*, append it with suffix, and keep the rest metadata
20 | DataConfig ithFile(const DataConfig& conf, int i, const string& suffix = "");
21 |
22 | // return A + B
23 | DataConfig appendFiles(const DataConfig& A, const DataConfig& B);
24 |
25 | ExampleInfo mergeExampleInfo(const ExampleInfo& A, const ExampleInfo& B);
26 |
27 | MatrixInfo readMatrixInfo(
28 | const ExampleInfo& info, int slot_id, int sizeof_idx, int sizeof_val);
29 |
30 |
31 | DataConfig shuffleFiles(const DataConfig& data);
32 |
33 |
34 | } // namespace PS
35 |
--------------------------------------------------------------------------------
/src/data/info_parser.cc:
--------------------------------------------------------------------------------
1 | #include "data/info_parser.h"
2 | namespace PS {
3 |
4 | void InfoParser::clear() {
5 | for (int i = 0; i < kSlotIDmax; ++i) slot_info_[i].Clear();
6 | info_.Clear();
7 | num_ex_ = 0;
8 | }
9 |
10 | bool InfoParser::add(const Example& ex) {
11 | for (int i = 0; i < ex.slot_size(); ++i) {
12 | const auto& slot = ex.slot(i);
13 | if (slot.id() >= kSlotIDmax) return false;
14 | auto& sinfo = slot_info_[slot.id()];
15 | for (int j = 0; j < slot.key_size(); ++j) {
16 | uint64 key = slot.key(j);
17 | sinfo.set_min_key(std::min((uint64)sinfo.min_key(), key));
18 | sinfo.set_max_key(std::max((uint64)sinfo.max_key(), key + 1));
19 | }
20 | if (slot.key_size() > 0) {
21 | if (slot.val_size() == slot.key_size()) {
22 | sinfo.set_format(SlotInfo::SPARSE);
23 | } else {
24 | sinfo.set_format(SlotInfo::SPARSE_BINARY);
25 | }
26 | } else if (slot.val_size() > 0) {
27 | sinfo.set_format(SlotInfo::DENSE);
28 | }
29 | sinfo.set_nnz_ex(sinfo.nnz_ex() + 1);
30 | sinfo.set_nnz_ele(sinfo.nnz_ele() + std::max(slot.key_size(), slot.val_size()));
31 | }
32 | ++ num_ex_;
33 | return true;
34 | }
35 |
36 | ExampleInfo InfoParser::info() {
37 | info_.set_num_ex(num_ex_);
38 | info_.clear_slot();
39 | for (int i = 0; i < kSlotIDmax; ++i) {
40 | auto &sinfo = slot_info_[i];
41 | if (!sinfo.nnz_ele()) continue;
42 | sinfo.set_id(i);
43 | if (i == 0) { // the label
44 | sinfo.set_min_key(0);
45 | sinfo.set_max_key(1);
46 | }
47 | *info_.add_slot() = sinfo;
48 | }
49 | return info_;
50 | }
51 | } // namespace PS
52 |
--------------------------------------------------------------------------------
/src/data/info_parser.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "data/proto/example.pb.h"
4 | #include "data/proto/data.pb.h"
5 | namespace PS {
6 |
7 | // the maximal allowed slot id
8 | static const int kSlotIDmax = 4096;
9 |
10 | class InfoParser {
11 | public:
12 | // void init(const DataConfig& conf) { conf_ = conf; }
13 | bool add(const Example& ex);
14 | void clear();
15 | ExampleInfo info();
16 | // int maxSlotID() { return conf_.ignore_fea_slot() ? 2 : kSlotIDmax; }
17 | private:
18 | // DataConfig conf_;
19 | size_t num_ex_ = 0;
20 | ExampleInfo info_;
21 | SlotInfo slot_info_[kSlotIDmax];
22 | };
23 |
24 | } // namespace PS
25 |
--------------------------------------------------------------------------------
/src/data/matlab/bin2mat.m:
--------------------------------------------------------------------------------
1 | function [D, group_id, group_offset] = bin2mat(name)
2 | % load a binary format data, which can be generated by ./bin/recordio2bin
3 | function result = grep(name, pattern)
4 | [ret, result] = system(...
5 | ['grep ', pattern, ' ', name, '.info | awk ''{print $2}''']);
6 | assert(ret == 0);
7 | result(end) = [];
8 | end
9 |
10 | type = grep(name, 'type');
11 | n = str2num(grep(name, 'end'));
12 | row_major = all(grep(name, 'row_major') == 'true');
13 | sizeof_v = str2num(grep(name, 'sizeof_value'));
14 |
15 | if strfind(type, 'DENSE')
16 | assert(row_major);
17 | assert(sizeof_v == 8);
18 | D = load_bin([name, '.value'], 'double');
19 | D = reshape(D, n(2), n(1))';
20 | elseif strfind(type, 'SPARSE')
21 |
22 | if exist([name, '.key'], 'file')
23 | group_id = str2num(grep(name, 'group_id'));
24 | begin = sscanf(grep(name, 'fea_begin'), '%lu');
25 | [begin, ix] = sort(begin);
26 | group_id = group_id(ix);
27 | % begin = str2num(; %, 'uint64');
28 | % begin = uint64(begin)'
29 | % return
30 | % e = str2num(grep(name, 'feature_end'), 'uint64');
31 | key = load_bin([name, '.key'], 'uint64');
32 | group_offset = zeros(length(begin)+1, 1);
33 | group_offset(end) = length(key) + 1;
34 | for i = 1 : length(begin)
35 | group_offset(i) = find(key == begin(i), 1, 'first');
36 | end
37 | clear key
38 | % group_offset
39 | end
40 |
41 | sizeof_i = str2num(grep(name, 'sizeof_index'));
42 | assert(sizeof_i == 4);
43 |
44 | j = load_bin([name '.index'], 'uint32');
45 | J = double(j) + 1;
46 | clear j
47 |
48 | i = load_bin([name '.offset'], 'uint64');
49 | I = zeros(length(J), 1);
50 | for k = 1 : length(i) - 1
51 | I(i(k)+1 : i(k+1)) = k;
52 | end
53 | clear i
54 |
55 | if strfind(type, 'SPARSE_BINARY')
56 | D = sparse(I, J, 1);
57 | else
58 | V = load_bin([name '.value'], 'double');
59 | D = sparse(I, J, V);
60 | end
61 |
62 | end
63 |
64 | end
65 |
--------------------------------------------------------------------------------
/src/data/matlab/filter_fea.m:
--------------------------------------------------------------------------------
1 | % Y = bin2mat('CTRb.Y');
2 | % [X, group_id, group_os] = bin2mat('CTRb.X');
3 |
4 | pv = 4;
5 |
6 | ix = sum(X) > pv;
7 | X = X(:,ix);
8 |
9 | os = group_os;
10 | for i = 2 : length(os)
11 | os(i) = os(i-1) + nnz(ix(group_os(i-1):group_os(i)-1));
12 | end
13 | group_os = os;
14 |
15 |
16 | % save CTRb_pv4 X Y group_id group_os
17 |
--------------------------------------------------------------------------------
/src/data/matlab/load_bin.m:
--------------------------------------------------------------------------------
1 | function data = load_bin(filename, format, offset, length)
2 | %load a vector from a binary file
3 | %load_bin(filename, format, offset, length)
4 |
5 | if nargin < 2, format = 'double'; end
6 | if nargin < 3, offset = 0; end
7 | if nargin < 4, length = inf; end
8 |
9 | fid=fopen(filename,'r');
10 | if (fid < 0)
11 | data = [];
12 | disp([filename ' doesn''t exist']);
13 | return;
14 | end
15 |
16 | eval(['v=', format, '(1);']);
17 | w = whos('v');
18 | bsize = w.bytes;
19 |
20 | fseek(fid, bsize*offset, -1);
21 | data = fread(fid, length, [format,'=>',format]);
22 | fclose(fid);
23 |
24 | end
25 |
--------------------------------------------------------------------------------
/src/data/matlab/mat2bin.m:
--------------------------------------------------------------------------------
1 | function mat2bin(name, Y, X)
2 | % new version
3 | % save to row-majored binary format
4 | % if nargin < 4, index_fmt = 'uint32'; end
5 |
6 | [a,b,x] = find(X');
7 |
8 | save_bin([name '.index'], uint32(a-1), 'uint32');
9 |
10 | bool = 'false';
11 | if nnz(x==0) + nnz(x==1) == length(x)
12 | disp('binary data');
13 | bool = 'true';
14 | save_bin([name '.offset'], uint64(cumsum([0; full(sum(X,2))])), 'uint64');
15 | else
16 | save_bin([name '.value'], x, 'double');
17 | clear x
18 | E = sparse(a,b,1);
19 | save_bin([name '.offset'], uint64(cumsum([0 full(sum(E))])), 'uint64');
20 | end
21 |
22 | save_bin([name '.label'], Y, 'double');
23 |
24 | fid = fopen([name, '.info'], 'w');
25 |
26 |
27 | fprintf(fid, strcat('row_major : true\n',...
28 | 'sparse : true\n',...
29 | 'bool_value : %s\n',...
30 | 'row_begin : 0\n',...
31 | 'row_end : %lu\n',...
32 | 'col_begin : 0\n',...
33 | 'col_end : %lu\n',...
34 | 'entries : %lu\n',...
35 | 'sizeof_index : 4\n',...
36 | 'sizeof_value : 8\n'), bool, size(X,1), size(X,2), nnz(X));
37 | fclose(fid);
38 | % fprintf(fid, '0 %lu\n0 %lu\n%lu\n', size(X,1), size(X,2), nnz(X));
39 |
40 | end
41 |
--------------------------------------------------------------------------------
/src/data/matlab/save_bin.m:
--------------------------------------------------------------------------------
1 | function save_bin(name, X, format)
2 | %save a vector into a binary file
3 | %save_bin(name, X, format)
4 |
5 | if nargin < 3
6 | format = 'double';
7 | end
8 |
9 | fid = fopen(name, 'w');
10 | n = fwrite(fid, X, format);
11 | fclose(fid);
12 |
13 | if n ~= numel(X)
14 | disp(sprintf('only write %d element, the actual one is %d', n, numel(X)));
15 | end
16 |
--------------------------------------------------------------------------------
/src/data/matlab/saveas_pserver.m:
--------------------------------------------------------------------------------
1 | function saveas_pserver(file_name, Y, X, group_id, binary)
2 | % debug use only, not efficient for large data
3 |
4 | if nargin < 4, group_id = zeros(size(X,2),1); end
5 | if nargin < 5, binary = false; end
6 |
7 | assert(issorted(group_id));
8 |
9 | tX = X';
10 | fd = fopen(file_name, 'w');
11 | for i = 1 : length(Y)
12 | fprintf(fd, '%d', Y(i));
13 | [a,b,c] = find(tX(:,i));
14 | pre_gid = -1;
15 | for j = 1 : length(a)
16 | id = group_id(a(j));
17 | if id ~= pre_gid
18 | fprintf(fd, '; %d', id);
19 | pre_gid = id;
20 | end
21 |
22 | if binary
23 | fprintf(fd, ' %d', a(j)-1);
24 | else
25 | fprintf(fd, ' %d:%g', a(j)-1, c(j));
26 | end
27 | end
28 | fprintf(fd, ';\n');
29 | end
30 | fclose(fd);
31 |
--------------------------------------------------------------------------------
/src/data/proto/data.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "util/proto/range.proto";
3 |
// Describes an input (or cached) dataset: storage format, file list, and
// options controlling how it is read and partitioned among workers.
message DataConfig {
  // on-disk storage format
  enum DataFormat {
    BIN = 1;    // raw binary
    PROTO = 2;  // protobuf recordio
    TEXT = 3;   // plain text; see TextFormat below
  }
  required DataFormat format = 1;

  // see https://github.com/mli/parameter_server/wiki/Data
  enum TextFormat {
    DENSE = 1;
    SPARSE = 2;
    SPARSE_BINARY = 3;
    ADFEA = 4;
    LIBSVM = 5;
    TERAFEA = 6;
    VW = 7;
    CRITEO = 9;  // NOTE: tag 8 is unused
  }
  optional TextFormat text = 2;

  // filenames, supports regular expressions
  repeated string file = 3;
  // files stored in hdfs
  optional HDFSConfig hdfs = 5;
  // ignore the feature group information
  optional bool ignore_feature_group = 6;
  // the maximal number of files will be assigned to a worker, -1 means no limit
  optional int32 max_num_files_per_worker = 7 [default = -1];

  // the maximal number of lines will be read from a file, -1 means no limit
  optional int32 max_num_lines_per_file = 8 [default = -1];

  // randomly shuffle the file order
  optional bool shuffle = 9 [default = false];

  // only valid for the binary format
  optional PbRange range = 4;
  // duplicate the file several times
  optional int32 replica = 10 [default = 1];
}
45 |
// Connection settings for reading files from HDFS.
message HDFSConfig {
  optional string home = 1;     // HADOOP_HOME
  optional string ugi = 2;      // hadoop.job.ugi, format: user,passwd
  optional string namenode = 4; // fs.default.name
}
51 |
--------------------------------------------------------------------------------
/src/data/proto/example.proto:
--------------------------------------------------------------------------------
1 | package PS;
// Per-slot (feature group) metadata collected while reading a dataset.
message SlotInfo {
  // storage layout of this slot
  enum Format {
    DENSE = 1;
    SPARSE = 2;
    SPARSE_BINARY = 3;
  }
  optional Format format = 1;
  // slot (feature group) id
  optional int32 id = 2;
  // key range covered by this slot; min_key defaults to uint64 max,
  // presumably so the first observed key always lowers it -- TODO confirm
  optional uint64 min_key = 3 [default = 0xFFFFFFFFFFFFFFFF];
  optional uint64 max_key = 4;
  // total number of non-zero elements
  optional uint64 nnz_ele = 5;
  // total number of non-zero instance, (non-empty rows)
  optional uint64 nnz_ex = 6;
}
17 |
// Dataset-level summary: one SlotInfo per slot plus the example count.
message ExampleInfo {
  repeated SlotInfo slot = 1;
  // total number of instances
  optional uint64 num_ex = 2;

}
24 |
// One feature group of a single example: keys and (optional) values.
message Slot {
  optional int32 id = 1;
  repeated uint64 key = 2 [packed=true];
  repeated float val = 3 [packed=true];
}
30 |
// A single training instance, made of one or more slots.
message Example {
  repeated Slot slot = 1;
}
34 |
--------------------------------------------------------------------------------
/src/data/show_example.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "data/common.h"
4 | #include "util/recordio.h"
5 | #include "proto/example.pb.h"
6 | namespace PS {
7 |
8 | DEFINE_int32(n, 3, "show the first *n* instances in text format");
9 |
10 | static void showExample() {
11 | File* in = File::openOrDie(FLAGS_input, "r");
12 | RecordReader reader(in);
13 | for (int i = 0; i < FLAGS_n; ++i) {
14 | Example ex;
15 | CHECK(reader.ReadProtocolMessage(&ex));
16 | std::cout << ex.ShortDebugString() << std::endl;
17 | }
18 | }
19 |
20 | } // namespace PS
21 |
--------------------------------------------------------------------------------
/src/data/slot_reader.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/shared_array_inl.h"
3 | #include "proto/example.pb.h"
4 | #include "data/common.h"
5 | namespace PS {
6 |
// read all slots in *data* with multiple threads, save them into *cache*.
8 | class SlotReader {
9 | public:
10 | SlotReader() { }
11 | SlotReader(const DataConfig& data, const DataConfig& cache) {
12 | Init(data, cache);
13 | }
14 |
15 | void Init(const DataConfig& data, const DataConfig& cache);
16 |
17 | // first read, then save
18 | int Read(ExampleInfo* info = nullptr);
19 |
20 | template MatrixInfo info(int slot_id) const {
21 | return readMatrixInfo(info_, slot_id, sizeof(uint64), sizeof(V));
22 | }
23 |
24 | // load a slot from cache
25 | SArray offset(int slot_id);
26 | SArray index(int slot_id);
27 | template SArray value(int slot_id) const;
28 |
29 | void clear(int slot_id) {
30 | offset_cache_.erase(slot_id);
31 | index_cache_.erase(slot_id);
32 | }
33 |
34 | private:
35 | string cacheName(const DataConfig& data, int slot_id) const;
36 | size_t nnzEle(int slot_id) const;
37 | bool readOneFile(const DataConfig& data, int ith_file);
38 | string cache_;
39 | DataConfig data_;
40 | // bool dump_to_disk_;
41 | ExampleInfo info_;
42 | std::unordered_map slot_info_;
43 | std::mutex mu_;
44 | size_t loaded_file_count_;
45 | std::vector num_ex_;
46 | std::unordered_map> offset_cache_;
47 | std::unordered_map> index_cache_;
48 | };
49 |
50 | template SArray SlotReader::value(int slot_id) const {
51 | SArray val;
52 | if (nnzEle(slot_id) == 0) return val;
53 | for (int i = 0; i < data_.file_size(); ++i) {
54 | string file = cacheName(ithFile(data_, i), slot_id) + ".value";
55 | SArray comp; CHECK(comp.ReadFromFile(file));
56 | SArray uncomp; uncomp.UncompressFrom(comp);
57 | size_t n = val.size();
58 | val.resize(n+uncomp.size());
59 | for (size_t i = 0; i < uncomp.size(); ++i) val[n+i] = uncomp[i];
60 | }
61 | CHECK_EQ(val.size(), nnzEle(slot_id)) << slot_id;
62 | return val;
63 | }
64 |
65 | } // namespace PS
66 |
--------------------------------------------------------------------------------
/src/data/text2proto.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | namespace PS {
3 |
// Convert a text dataset (FLAGS_input) into the protobuf recordio format
// (FLAGS_output). Currently a stub: the previous implementation is kept
// below, commented out, pending a port to the current parser API.
void textToProto() {
  // TODO
  // string format = FLAGS_format;
  // std::transform(format.begin(), format.end(), format.begin(), ::tolower);

  // TextParser parser;
  // if (format == "libsvm") {
  //   parser.setFormat(DataConfig::LIBSVM);
  // } else if (format == "adfea") {
  //   parser.setFormat(DataConfig::ADFEA);
  // }


  // auto record_file = File::openOrDie(FLAGS_output+".recordio", "w");
  // RecordWriter writer(record_file);

  // Instance ins;
  // int ignored = 0;
  // FileLineReader reader(FLAGS_input.c_str());
  // reader.set_line_callback([&parser, &ins, &writer, &ignored] (char *line) {
  //     ignored += !parser.toProtobuf(line, &ins);
  //     writer.WriteProtocolMessage(ins);
  //   });

  // Timer t; t.start();
  // reader.Reload();
  // auto info = parser.info();
  // writeProtoToASCIIFileOrDie(info, FLAGS_output+".info");
  // t.stop();

  // std::cerr << "written " << info.num_ins()
  //           << " instances in " << t.get() << " sec." << std::endl;
  // if (ignored) {
  //   std::cerr << ignored << " bad instances are skipped" << std::endl;
  // }
}
40 |
41 | } // namespace PS
42 |
--------------------------------------------------------------------------------
/src/data/text_parser.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "data/proto/example.pb.h"
4 | #include "data/proto/data.pb.h"
5 | namespace PS {
6 |
7 | // parse an example from various text formats into the protobuf format,
8 | // e.g. proto/example.proto
9 | class ExampleParser {
10 | public:
11 | typedef DataConfig::TextFormat TextFormat;
12 | void Init(TextFormat format, bool ignore_fea_slot = false);
13 | bool ToProto(char*, Example*);
14 | private:
15 | bool ParseLibsvm(char*, Example*);
16 | bool ParseAdfea(char*, Example*);
17 | bool ParseTerafea(char*, Example*);
18 | bool ParsePS(char*, Example*, TextFormat);
19 | bool ParseCriteo(char*, Example*);
20 | std::function parser_;
21 | bool ignore_fea_slot_;
22 | };
23 |
24 | }
25 |
--------------------------------------------------------------------------------
/src/filter/add_noise.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "filter/filter.h"
3 | namespace PS {
4 |
5 | /**
6 | * @brief Add noise
7 | *
8 | */
9 | class AddNoiseFilter : public Filter {
10 | public:
11 | void encode(Message* msg) {
12 | auto filter_conf = CHECK_NOTNULL(find(FilterConfig::NOISE, msg));
13 | int n = msg->value.size();
14 | CHECK_EQ(n, msg->task.value_type_size());
15 | for (int i = 0; i < n; ++i) {
16 | if (msg->value[i].size() == 0) continue;
17 | auto type = msg->task.value_type(i);
18 | if (type == DataType::FLOAT) {
19 | AddNoise(msg->value[i], filter_conf);
20 | }
21 | if (type == DataType::DOUBLE) {
22 | AddNoise(msg->value[i], filter_conf);
23 | }
24 | }
25 | }
26 |
27 | private:
28 |
29 | template
30 | void AddNoise(const SArray& array, FilterConfig* cf) {
31 | std::default_random_engine generator;
32 | std::normal_distribution distribution((V)cf->mean(), (V)cf->std());
33 | SArray data(array);
34 | // SArray noise(data.size());
35 | for (size_t i = 0; i < data.size(); ++i) {
36 | data[i] += distribution(generator);
37 | }
38 | // LL << noise.Std() << " " << noise;
39 | }
40 |
41 | };
42 |
43 | } // namespace PS
44 |
--------------------------------------------------------------------------------
/src/filter/compressing.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "filter/filter.h"
3 |
4 | namespace PS {
5 |
6 | class CompressingFilter : public Filter {
7 | public:
8 | void encode(Message* msg) {
9 | auto conf = find(FilterConfig::COMPRESSING, msg);
10 | if (!conf) return;
11 | conf->clear_uncompressed_size();
12 | if (msg->has_key()) {
13 | conf->add_uncompressed_size(msg->key.size());
14 | msg->key = msg->key.CompressTo();
15 | }
16 | for (auto& v : msg->value) {
17 | conf->add_uncompressed_size(v.size());
18 | v = v.CompressTo();
19 | }
20 | }
21 | void decode(Message* msg) {
22 | auto conf = find(FilterConfig::COMPRESSING, msg);
23 | if (!conf) return;
24 | int has_key = msg->has_key();
25 | CHECK_EQ(conf->uncompressed_size_size(), msg->value.size() + has_key);
26 |
27 | if (has_key) {
28 | SArray raw(conf->uncompressed_size(0));
29 | raw.UncompressFrom(msg->key);
30 | msg->key = raw;
31 | }
32 | for (int i = 0; i < msg->value.size(); ++i) {
33 | SArray raw(conf->uncompressed_size(i+has_key));
34 | raw.UncompressFrom(msg->value[i]);
35 | msg->value[i] = raw;
36 | }
37 | }
38 | };
39 |
40 | } // namespace PS
41 |
--------------------------------------------------------------------------------
/src/filter/filter.cc:
--------------------------------------------------------------------------------
1 | #include "filter/filter.h"
2 | #include "filter/compressing.h"
3 | #include "filter/key_caching.h"
4 | #include "filter/fixing_float.h"
5 | #include "filter/add_noise.h"
6 |
7 | namespace PS {
8 |
9 | Filter* Filter::create(const FilterConfig& conf) {
10 | switch (conf.type()) {
11 | case FilterConfig::KEY_CACHING:
12 | return new KeyCachingFilter();
13 | case FilterConfig::COMPRESSING:
14 | return new CompressingFilter();
15 | case FilterConfig::FIXING_FLOAT:
16 | return new FixingFloatFilter();
17 | case FilterConfig::NOISE:
18 | return new AddNoiseFilter();
19 | default:
20 | CHECK(false) << "unknow filter type";
21 | }
22 | return nullptr;
23 | }
24 |
25 |
26 | FilterConfig* Filter::find(FilterConfig::Type type, Task* task) {
27 | for (int i = 0; i < task->filter_size(); ++i) {
28 | if (task->filter(i).type() == type) return task->mutable_filter(i);
29 | }
30 | return nullptr;
31 | }
32 |
33 | } // namespace PS
34 |
--------------------------------------------------------------------------------
/src/filter/filter.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "system/message.h"
3 | #include "filter/proto/filter.pb.h"
4 | #include "util/shared_array_inl.h"
5 |
6 | namespace PS {
7 |
// Base class of message filters applied on the wire: the sender runs
// encode() before a message goes out, the receiver runs decode() on
// arrival. A filter implementation should be thread safe.
class Filter {
 public:
  Filter() { }
  virtual ~Filter() { }

  // factory: build a filter from its config; aborts on an unknown type
  static Filter* create(const FilterConfig& conf);


  // transform the message in place; the defaults are no-ops
  virtual void encode(Message* msg) { }
  virtual void decode(Message* msg) { }

  // return the config of filter *type* attached to the message's task,
  // or nullptr if absent
  static FilterConfig* find(FilterConfig::Type type, Message* msg) {
    return find(type, &(msg->task));
  }
  static FilterConfig* find(FilterConfig::Type type, Task* task);
};
25 |
26 | } // namespace
27 |
--------------------------------------------------------------------------------
/src/filter/fixing_float.h:
--------------------------------------------------------------------------------
1 | #pragma once
#include "filter/filter.h"
#include <time.h>
4 | namespace PS {
5 |
6 | class FixingFloatFilter : public Filter {
7 | public:
8 | void encode(Message* msg) {
9 | convert(msg, true);
10 | }
11 |
12 | void decode(Message* msg) {
13 | convert(msg, false);
14 | }
15 |
16 | private:
17 | // a fast random function
18 | static bool boolrand(int* seed) {
19 | *seed = (214013 * *seed + 2531011);
20 | return ((*seed >> 16) & 0x1) == 0;
21 | }
22 |
23 | // decode / encode a message
24 | void convert(Message* msg, bool encode) {
25 | auto filter_conf = CHECK_NOTNULL(find(FilterConfig::FIXING_FLOAT, msg));
26 | if (filter_conf->num_bytes() == 0) return;
27 | int n = msg->value.size();
28 | CHECK_EQ(n, msg->task.value_type_size());
29 | int k = 0;
30 | for (int i = 0; i < n; ++i) {
31 | if (msg->value[i].size() == 0) continue;
32 | auto type = msg->task.value_type(i);
33 | if (filter_conf->fixed_point_size() <= k) {
34 | filter_conf->add_fixed_point();
35 | }
36 | if (type == DataType::FLOAT) {
37 | msg->value[i] = convert(
38 | msg->value[i], encode, filter_conf->num_bytes(),
39 | filter_conf->mutable_fixed_point(k++));
40 | }
41 | if (type == DataType::DOUBLE) {
42 | msg->value[i] = convert(
43 | msg->value[i], encode, filter_conf->num_bytes(),
44 | filter_conf->mutable_fixed_point(k++));
45 | }
46 | }
47 | }
48 |
49 | // decode / encode an array
50 | template
51 | SArray convert(const SArray& array, bool encode, int nbytes,
52 | FilterConfig::FixedFloatConfig* conf) {
53 | CHECK_GT(nbytes, 0);
54 | CHECK_LT(nbytes, 8);
55 | double ratio = static_cast(1 << (nbytes*8)) - 2;
56 |
57 | if (encode) {
58 | if (!conf->has_min_value()) {
59 | conf->set_min_value(SArray(array).EigenArray().minCoeff());
60 | }
61 | if (!conf->has_max_value()) {
62 | conf->set_max_value(SArray(array).EigenArray().maxCoeff() + 1e-6); // to avoid max_v == min_v
63 | }
64 | }
65 |
66 | CHECK(conf->has_min_value());
67 | double min_v = static_cast(conf->min_value());
68 | CHECK(conf->has_max_value());
69 | double max_v = static_cast(conf->max_value());
70 | double bin = max_v - min_v;
71 | CHECK_GT(bin, 0);
72 |
73 | if (encode) {
74 | // float/double to nbytes*8 int
75 | SArray orig(array);
76 | SArray code(orig.size() * nbytes);
77 | uint8* code_ptr = code.data();
78 | int seed = time(NULL);
79 | for (int i = 0; i < orig.size(); ++i) {
80 | double proj = orig[i] > max_v ? max_v : orig[i] < min_v ? min_v : orig[i];
81 | double tmp = (proj - min_v) / bin * ratio;
82 | uint64 r = static_cast(floor(tmp)) + boolrand(&seed);
83 | for (int j = 0; j < nbytes; ++j) {
84 | *(code_ptr++) = static_cast(r & 0xFF);
85 | r = r >> 8;
86 | }
87 | }
88 | return SArray(code);
89 | } else {
90 | // nbytes*8 int to float/double
91 | uint8* code_ptr = SArray(array).data();
92 | SArray orig(array.size() / nbytes);
93 | for (int i = 0; i < orig.size(); ++i) {
94 | double r = 0;
95 | for (int j = 0; j < nbytes; ++j) {
96 | r += static_cast(*(code_ptr++)) << 8 * j;
97 | }
98 | orig[i] = static_cast(r / ratio * bin + min_v);
99 | }
100 | return SArray(orig);
101 | }
102 | }
103 | };
104 |
105 | } // namespace PS
106 |
--------------------------------------------------------------------------------
/src/filter/frequency_filter.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/countmin.h"
3 | #include "util/shared_array_inl.h"
4 | namespace PS {
5 |
6 | /**
7 | * @brief Filters infrequent keys via the countmin sketch
8 | * @tparam K key type
9 | * @tparam V counter type
10 | */
11 | template
12 | class FreqencyFilter {
13 | public:
14 | /**
15 | * @brief Add keys with their key count
16 | *
17 | * @param key the list of keys
18 | * @param count the according frequency count
19 | */
20 | void InsertKeys(const SArray& key, const SArray& count);
21 |
22 | /**
23 | * @brief Filters infrequency keys
24 | *
25 | * @param key the list of keys
26 | * @param freq_thr the frequency threshold
27 | *
28 | * @return the keys whose frequency is greater than freq_thr
29 | */
30 | SArray QueryKeys(const SArray& key, int freq_thr);
31 |
32 | bool Empty() { return count_.empty(); }
33 |
34 | /**
35 | * @brief resize the countmin sketch
36 | *
37 | */
38 | void Resize(int n, int k) { count_.resize(n, k, 254); }
39 |
40 | void Clear() { count_.clear(); }
41 |
42 | private:
43 | CountMin count_;
44 | };
45 |
46 | // countmin implementation
47 | template
48 | SArray FreqencyFilter::QueryKeys(const SArray& key, int freqency) {
49 | CHECK_LT(freqency, kuint8max) << "change to uint16 or uint32...";
50 | SArray filtered_key;
51 | for (auto k : key) {
52 | if ((int)count_.query(k) > freqency) {
53 | filtered_key.push_back(k);
54 | }
55 | }
56 | return filtered_key;
57 | }
58 |
59 | template
60 | void FreqencyFilter::InsertKeys(const SArray& key, const SArray& count) {
61 | CHECK_EQ(key.size(), count.size());
62 | for (size_t i = 0; i < key.size(); ++i) {
63 | count_.insert(key[i], count[i]);
64 | }
65 | }
66 |
67 | // DEPRECATED hash implementation
68 | // std::unordered_map map_;
69 |
70 | // template
71 | // SArray FreqencyFilter::QueryKeys(const SArray& key, int freqency) {
72 | // SArray filtered_key;
73 | // for (K k : key) {
74 | // if (map_[k] > freqency) filtered_key.push_back(k);
75 | // }
76 | // return filtered_key;
77 | // }
78 |
79 | // template
80 | // void FreqencyFilter::InsertKeys(const SArray& key, const SArray& count) {
81 | // CHECK_EQ(key.size(), count.size());
82 | // for (size_t i = 0; i < key.size(); ++i) {
83 | // map_[key[i]] += count[i];
84 | // }
85 | // }
86 |
87 | }
88 |
--------------------------------------------------------------------------------
/src/filter/key_caching.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "filter/filter.h"
3 | #include "util/crc32c.h"
4 | namespace PS {
5 |
6 | class KeyCachingFilter : public Filter {
7 | public:
8 | // thread safe
9 | void encode(Message* msg) {
10 | // if (!msg->task.has_key_range()) return;
11 | auto conf = find(FilterConfig::KEY_CACHING, msg);
12 | if (!conf) return;
13 | if (!msg->has_key()) {
14 | conf->clear_signature();
15 | return;
16 | }
17 | const auto& key = msg->key;
18 | auto sig = crc32c::Value(key.data(), std::min(key.size(), max_sig_len_));
19 | conf->set_signature(sig);
20 | auto cache_k = std::make_pair(
21 | msg->task.key_channel(), Range(msg->task.key_range()));
22 | Lock l(mu_);
23 | auto& cache = cache_[cache_k];
24 | bool hit_cache = cache.first == sig && cache.second.size() == key.size();
25 | if (hit_cache) {
26 | msg->clear_key();
27 | } else {
28 | cache.first = sig;
29 | cache.second = key;
30 | }
31 | if (conf->clear_cache_if_done() && isDone(msg->task)) {
32 | cache_.erase(cache_k);
33 | }
34 | }
35 |
36 | void decode(Message* msg) {
37 | // if (!msg->task.has_key_range()) return;
38 | auto conf = find(FilterConfig::KEY_CACHING, msg);
39 | if (!conf || !conf->has_signature()) return;
40 | auto sig = conf->signature();
41 | // do a double check
42 | if (msg->has_key()) {
43 | CHECK_EQ(crc32c::Value(msg->key.data(), std::min(msg->key.size(), max_sig_len_)), sig);
44 | }
45 | auto cache_k = std::make_pair(
46 | msg->task.key_channel(), Range(msg->task.key_range()));
47 | Lock l(mu_);
48 | auto& cache = cache_[cache_k];
49 | if (msg->has_key()) {
50 | cache.first = sig;
51 | cache.second = msg->key;
52 | } else {
53 | // the cache is invalid... may ask the sender to resend this task
54 | CHECK_EQ(sig, cache.first) << msg->DebugString();
55 | msg->set_key(cache.second);
56 | }
57 | if (conf->clear_cache_if_done() && isDone(msg->task)) {
58 | cache_.erase(cache_k);
59 | }
60 | }
61 |
62 | private:
63 | bool isDone(const Task& task) {
64 | return (!task.request() ||
65 | (task.has_param()
66 | && task.param().push()));
67 | }
68 |
69 | std::unordered_map<
70 | std::pair>, std::pair>> cache_;
71 |
72 | // calculate the signature using the first max_sig_len_*4 bytes to accelerate
73 | // the computation
74 | const size_t max_sig_len_ = 2048;
75 | std::mutex mu_;
76 | };
77 |
78 | } // namespace
79 |
--------------------------------------------------------------------------------
/src/filter/proto/filter.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 |
// Configuration (and runtime state) of one message filter on the wire.
message FilterConfig {
  enum Type {
    // cache the keys at both sender and receiver
    KEY_CACHING = 1;
    // compress data by snappy
    COMPRESSING = 2;
    // convert a float/double into a fixed-point integer
    FIXING_FLOAT = 3;
    // add noise to data
    NOISE = 4;
  }
  required Type type = 1;

  // -- key caching --
  // if the task is done, then clear the cache (to save memory)
  optional bool clear_cache_if_done = 20 [default = false];

  // -- fixing float filter --
  optional int32 num_bytes = 5 [default = 3];
  message FixedFloatConfig {
    optional float min_value = 1 [default = -1];
    optional float max_value = 2 [default = 1];
  }
  repeated FixedFloatConfig fixed_point = 4;

  // -- noise --
  optional float mean = 6;
  optional float std = 7;

  // -- runtime parameters used by the system --
  optional uint32 signature = 2;
  repeated uint64 uncompressed_size = 3;
}
36 |
--------------------------------------------------------------------------------
/src/filter/sparse_filter.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "filter/filter.h"
3 | namespace PS {
4 |
// Marks individual float/double entries as "filtered" by writing the
// all-ones bit pattern (a NaN) into them, and tests for that mark.
class SparseFilter : public Filter {
 public:
  SparseFilter() {
    // use 0xffff..ff as the mark when a value is filtered, it is nan for float
    // and double. (memcpy reinterprets the max-uint bit patterns as FP values)
    memcpy(&double_v_, &kuint64max, sizeof(double));
    memcpy(&float_v_, &kuint32max, sizeof(float));
  }

  // mark an entry as filtered
  void mark(float* v) { *v = float_v_; }
  void mark(double* v) { *v = double_v_; }

  // test whether or not an entry is filtered
  // (v != v is true exactly when v is NaN)
  bool marked(double v) { return v != v; }
  bool marked(float v) { return v != v; }
 private:
  float float_v_;    // NaN bit pattern used as the float mark
  double double_v_;  // NaN bit pattern used as the double mark
};
25 |
26 | } // namespace PS
27 |
--------------------------------------------------------------------------------
/src/learner/proto/bcd.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "util/proto/range.proto";
3 | import "data/proto/data.proto";
4 | import "data/proto/example.proto";
5 | import "parameter/proto/param.proto";
6 | import "filter/proto/filter.proto";
7 |
// Configuration of the block coordinate descent (BCD) solvers.
message BCDConfig {
  // Divide a feature group into feature_block_ratio x nnz_feature_per_example
  // blocks
  optional float feature_block_ratio = 1 [default = 4];
  // Use a random order to updating feature block. Turn it off to eliminate the
  // randomness (often for debugging), but may affect the convergence rate.
  optional bool random_feature_block_order = 2 [default = true];
  // Updating important feature group at the beginning to get a good initial
  // start point.
  repeated int32 prior_fea_group = 14;
  optional int32 num_iter_for_prior_fea_group = 13 [default = 5];

  // Bounded-delay consistency
  optional int32 max_block_delay = 3 [default = 0];
  // maximal number of passes over the training data
  optional int32 max_pass_of_data = 4 [default = 10];
  // convergence criteria. stop if the relative objective <= epsilon
  optional double epsilon = 5 [default = 1e-4];

  // features which occur <= *tail_feature_freq* will be filtered before training
  optional int32 tail_feature_freq = 6 [default = 0];
  // countmin sketch is used to filter the tail features. It has two
  // parameters, k and n.
  optional int32 countmin_k = 7 [default = 2];
  // n = the_first_arrive_key_length * num_workers * countmin_n_ratio
  optional double countmin_n_ratio = 8 [default = 2.0];

  // upper bound on feature groups preprocessed in parallel
  optional int32 max_num_parallel_groups_in_preprocessing = 9 [default = 1000];
  optional int32 max_data_buf_size_in_mb = 10 [default = 1000];
  optional DataConfig local_cache = 11;

  // how to initialize the weight vector w
  optional ParamInitConfig init_w = 12;

  // filters (compression, key caching, ...) applied to communication
  repeated FilterConfig comm_filter = 15;

  // 0 disables periodic model saving
  optional int32 save_model_every_n_iter = 16 [default = 0];

  optional bool load_local_data = 17 [default = false];
  extensions 100 to 199;
}
49 |
// Per-iteration progress report of a BCD solver.
message BCDProgress {
  optional double objective = 1;
  // relative change of the objective (compared against epsilon)
  optional double relative_obj = 2;
  optional uint64 nnz_w = 5;
  optional double violation = 6;
  optional uint64 nnz_active_set = 7;

  // performance
  optional double total_time = 10;
  repeated double busy_time = 11;

  extensions 100 to 199;
}
63 |
64 |
// A command from the BCD scheduler to workers/servers, plus its arguments.
message BCDCall {
  enum Command {
    LOAD_DATA = 1;
    PREPROCESS_DATA = 2;
    UPDATE_MODEL = 3;
    EVALUATE_PROGRESS = 4;
    SAVE_MODEL = 5; // save w
    RECOVER = 6;
    COMPUTE_VALIDATION_AUC = 7;
    REQUEST_WORKLOAD = 8;
    // SAVE_AS_DENSE = 7; // save X * w in a given key range
  }
  required Command cmd = 1;
  optional PbRange key = 2;
  // optional int32 feature_group_id = 3;

  optional double kkt_filter_threshold = 4;
  optional bool reset_kkt_filter = 5;

  optional int32 iter = 11;
  // feature groups involved in this command
  repeated int32 fea_grp = 8;
  optional bool hit_cache = 9;
  optional DataConfig data = 10;
  optional int32 time = 12;
}
90 |
91 |
// Reply to a LOAD_DATA command: dataset summary and cache status.
message LoadDataResponse {
  optional ExampleInfo example_info = 1;
  optional int32 hit_cache = 2;
}
96 |
--------------------------------------------------------------------------------
/src/learner/proto/sgd.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "data/proto/data.proto";
3 | import "learner/proto/workload.proto";
4 |
// Progress reported by an SGD worker/server since its last report.
message SGDProgress {
  repeated double objective = 1;
  optional uint64 num_examples_processed = 2;
  repeated double accuracy = 3;
  repeated double auc = 4;
  optional uint64 nnz = 5;
  // accumulated |w| and |delta w| sums, used for the update-ratio display
  optional double weight_sum = 6;
  optional double delta_sum = 7;
}
14 |
// A command between SGD scheduler and workers/servers.
message SGDCall {
  enum Command {
    REQUEST_WORKLOAD = 6;
    UPDATE_MODEL = 1;
    REPORT_PROGRESS = 2;
    SAVE_MODEL = 3;
    RECOVER = 4;
    COMPUTE_VALIDATION_AUC = 5;
  }
  required Command cmd = 1;
  // the workload attached to UPDATE_MODEL / REQUEST_WORKLOAD replies
  optional Workload load = 2;
}
27 |
--------------------------------------------------------------------------------
/src/learner/proto/workload.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "data/proto/data.proto";
3 |
// A unit of work handed out by the scheduler, plus bookkeeping fields.
message Workload {
  // the workload id
  optional int32 id = 1;

  // the data associated with this workload
  optional DataConfig data = 2;

  // randomly shuffle the file order
  optional bool shuffle = 3 [default = false];

  // duplicate the data several times
  optional int32 replica = 4 [default = 1];

  // all finished workload ids
  repeated int32 finished = 6;

  // true if every workload has been done
  optional bool all_is_done = 5 [default = false];
}
23 |
--------------------------------------------------------------------------------
/src/learner/sgd.cc:
--------------------------------------------------------------------------------
1 | #include "learner/sgd.h"
2 | namespace PS {
3 |
ISGDScheduler::~ISGDScheduler() {
  // NOTE(review): workload_pool_ is intentionally leaked -- deleting it here
  // caused a core dump; clarify its ownership before "fixing" this.
}
7 |
// Scheduler main loop: wire up progress merging/printing, hand out
// workloads until all are finished, then ask the servers to save the model.
void ISGDScheduler::Run() {
  // init monitor
  using namespace std::placeholders;
  monitor_.set_merger(std::bind(&ISGDScheduler::MergeProgress, this, _1, _2));
  monitor_.set_printer(1, std::bind(&ISGDScheduler::ShowProgress, this, _1, _2));

  // re-queue workloads owned by a dead node so other nodes can pick them up
  sys_.manager().AddNodeFailureHandler([this](const NodeID& id) {
      CHECK_NOTNULL(workload_pool_)->restore(id);
    });
  // wait all jobs are finished
  CHECK_NOTNULL(workload_pool_)->waitUtilDone();

  // save model
  Task task;
  task.mutable_sgd()->set_cmd(SGDCall::SAVE_MODEL);
  int ts = Submit(task, kServerGroup);
  Wait(ts);
}
26 |
// Handle a worker's UPDATE_MODEL reply: mark the workloads it reports as
// finished, then immediately hand the worker its next workload.
void ISGDScheduler::ProcessResponse(Message* response) {
  const auto& sgd = response->task.sgd();
  if (sgd.cmd() == SGDCall::UPDATE_MODEL) {
    for (int i = 0; i < sgd.load().finished_size(); ++i) {
      workload_pool_->finish(sgd.load().finished(i));
    }
    SendWorkload(response->sender);
  }
}
36 |
37 | void ISGDScheduler::ProcessRequest(Message* request) {
38 | if (request->task.sgd().cmd() == SGDCall::REQUEST_WORKLOAD) {
39 | SendWorkload(request->sender);
40 | }
41 | }
42 |
43 | void ISGDScheduler::SendWorkload(const NodeID& recver) {
44 | Task task;
45 | task.mutable_sgd()->set_cmd(SGDCall::UPDATE_MODEL);
46 | if (workload_pool_->assign(recver, task.mutable_sgd()->mutable_load())) {
47 | Submit(task, recver);
48 | }
49 | }
50 |
51 | void ISGDScheduler::ShowProgress(
52 | double time, std::unordered_map* progress) {
53 | uint64 num_ex = 0, nnz_w = 0;
54 | SArray objv, auc, acc;
55 | double weight_sum = 0, delta_sum = 1e-20;
56 | for (const auto& it : *progress) {
57 | auto& prog = it.second;
58 | num_ex += prog.num_examples_processed();
59 | nnz_w += prog.nnz();
60 | for (int i = 0; i < prog.objective_size(); ++i) {
61 | objv.push_back(prog.objective(i));
62 | }
63 | for (int i = 0; i < prog.auc_size(); ++i) {
64 | auc.push_back(prog.auc(i));
65 | }
66 | for (int i = 0; i < prog.accuracy_size(); ++i) {
67 | acc.push_back(prog.accuracy(i));
68 | }
69 | weight_sum += prog.weight_sum();
70 | delta_sum += prog.delta_sum();
71 | }
72 | progress->clear();
73 | num_ex_processed_ += num_ex;
74 | if (show_prog_head_) {
75 | NOTICE(" sec examples loss auc accuracy |w|_0 updt ratio");
76 | show_prog_head_ = false;
77 | }
78 | NOTICE("%4d %.2e %.3e %.4f %.4f %.2e %.2e",
79 | (int)time,
80 | (double)num_ex_processed_ ,
81 | objv.Sum()/(double)num_ex,
82 | auc.Mean(),
83 | acc.Mean(),
84 | (double)nnz_w,
85 | sqrt(delta_sum) / sqrt(weight_sum));
86 | }
87 |
88 |
// Merge a freshly reported *src* into the accumulated *dst*: every field is
// overwritten by src except num_examples_processed, which accumulates.
void ISGDScheduler::MergeProgress(const SGDProgress& src, SGDProgress* dst) {
  auto old = *dst; *dst = src;
  // TODO also append objv
  dst->set_num_examples_processed(
      dst->num_examples_processed() + old.num_examples_processed());
}
95 |
96 | } // namespace PS
97 |
--------------------------------------------------------------------------------
/src/learner/workload_pool.cc:
--------------------------------------------------------------------------------
1 | #include "learner/workload_pool.h"
2 | #include "data/common.h"
3 | namespace PS {
4 |
// Expand *load* into one workload per (file x replica) and store them all,
// assigning sequential ids.
void WorkloadPool::set(const Workload& load) {
  VLOG(1) << "init workload " << load.ShortDebugString();
  Lock l(mu_);
  CHECK_GT(load.replica(), 0);
  DataConfig files = searchFiles(load.data());
  VLOG(1) << "find " << files.file_size() << " files: " << files.ShortDebugString();

  loads_.resize(files.file_size() * load.replica());
  int k = 0;
  for (int r = 0; r < load.replica(); ++r) {
    // reshuffle the file order for every replica when requested
    if (load.shuffle()) files = shuffleFiles(files);
    for (int i = 0; i < files.file_size(); ++i) {
      // be careful here, do not use loads_[k] = xxx; it will copy the pointer
      // of load.data_ rather than the value.
      *loads_[k].load.mutable_data() = ithFile(files, i);
      loads_[k].load.set_id(k);
      ++ k;
    }
  }
  CHECK_EQ(k, loads_.size());
}
26 |
27 |
28 | bool WorkloadPool::assign(const NodeID& node_id, Workload* load) {
29 | Lock l(mu_);
30 | for (auto& info : loads_) {
31 | if (!info.assigned) {
32 | load->CopyFrom(info.load);
33 | info.node = node_id;
34 | info.assigned = true;
35 | VLOG(1) << "assign to [" << node_id << "] " << load->ShortDebugString();
36 | return true;
37 | }
38 | }
39 | return false;
40 | }
41 |
42 |
// Re-queue the unfinished workloads assigned to *node_id* (called when the
// node fails) so they can be handed out to other nodes.
void WorkloadPool::restore(const NodeID& node_id) {
  Lock l(mu_);
  for (auto& info : loads_) {
    if (info.assigned && !info.finished && info.node == node_id) {
      info.assigned = false;
      LOG(INFO) << "restore workload " << info.load.id() << " from " << node_id;
    }
  }
}
52 |
53 |
54 | void WorkloadPool::finish(int id) {
55 | Lock l(mu_);
56 | CHECK_GE(id, 0); CHECK_LT(id, loads_.size());
57 | loads_[id].finished = true;
58 | ++ num_finished_;
59 | VLOG(1) << "workload " << id << " is finished";
60 | }
61 |
62 |
63 | void WorkloadPool::waitUtilDone() {
64 | while (true) {
65 | mu_.lock();
66 | bool done = num_finished_ >= loads_.size();
67 | mu_.unlock();
68 | if (done) {
69 | break;
70 | } else {
71 | usleep(1000);
72 | }
73 | }
74 | VLOG(1) << "all workloads are done";
75 | }
76 |
77 | } // namespace PS
78 |
--------------------------------------------------------------------------------
/src/learner/workload_pool.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "system/message.h"
4 | #include "learner/proto/workload.pb.h"
5 | namespace PS {
6 |
7 | // the base class of a workload pool. thread safe
8 | class WorkloadPool {
9 | public:
10 | WorkloadPool() { }
11 | WorkloadPool(const Workload& load) { set(load); }
12 | virtual ~WorkloadPool() { }
13 |
14 | // set all workloads
15 | void set(const Workload& load);
16 |
17 | // assign a piece of *workload* to *node_id*. return false if all is done.
18 | bool assign(const NodeID& node_id, Workload* load);
19 |
20 | // restored unfinished workloads have been assigned to *node_id*
21 | void restore(const NodeID& node_id);
22 |
23 | // mark the workload with *id* as finished
24 | void finish(int id);
25 |
26 | // block until all workloads are finished
27 | void waitUtilDone();
28 |
29 | protected:
30 | struct WorkloadInfo {
31 | NodeID node;
32 | Workload load;
33 | bool assigned = false;
34 | bool finished = false;
35 | };
36 | std::vector loads_;
37 | int num_finished_ = 0;
38 | std::mutex mu_;
39 | };
40 |
41 | } // namespace PS
42 |
--------------------------------------------------------------------------------
/src/parameter/README.org:
--------------------------------------------------------------------------------
1 | #+TITLE: Shared Parameters
2 |
3 |
4 | | name | KV pairs op | sorted key | value size | storage | usage |
5 | |----------------+-------------+------------+------------+----------+-------------------------------------------------------------------------------------|
| KVVector       | batched     | yes        | fixed      | array    | sparse data with billions of keys, used in worker node, or server node for batch method |
7 | | KVMap | individual | no | fixed | hash map | |
8 | | KVLayer | individual | yes | vary | hash map | neural network |
9 | | KVStore (TODO) | individual | no | vary | hash map | |
10 |
--------------------------------------------------------------------------------
/src/parameter/kv_map.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "ps.h"
3 | #include "parameter/parameter.h"
4 | namespace PS {
5 |
/**
 * @brief Default entry type for KVMap: stores a single value of type V.
 * The shared *state* argument is ignored by this default implementation.
 */
template
struct KVMapEntry {
  // Copy the stored value into *data.
  void Get(V* data, void* state) { *data = value; }
  // Overwrite the stored value from *data.
  void Set(const V* data, void* state) { value = *data; }
  V value;
};
15 |
/**
 * @brief Default state type for KVMap: a no-op shared state.
 */
struct KVMapState {
  // Called once after each batch of Set()s; nothing to do by default.
  void Update() { }
};
22 |
/**
 * @brief A key-value store with fixed length value.
 *
 * Each key maps to an entry object of type E holding k_ values of type V.
 * A single shared state object of type S is passed to every entry Get/Set
 * and notified (Update) once per received push batch.
 *
 * @tparam K the key type
 * @tparam V the value type
 * @tparam E the entry type
 * @tparam S the state type
 */
template ,
          typename S = KVMapState>
class KVMap : public Parameter {
 public:
  /**
   * @brief Constructor
   *
   * @param k the length of a value entry
   * @param id customer id
   */
  KVMap(int k = 1, int id = NextCustomerID()) :
      Parameter(id), k_(k) {
    CHECK_GT(k, 0);
  }
  virtual ~KVMap() { }

  // Replace the shared state object.
  void set_state(const S& s) { state_ = s; }

  // Partition *request* along the server key ranges *krs*.
  virtual void Slice(const Message& request, const std::vector>& krs,
                     std::vector* msgs) {
    SliceKOFVMessage(request, krs, msgs);
  }

  virtual void GetValue(Message* msg);
  virtual void SetValue(const Message* msg);

  // Dump all non-zero entries to *file* as "key\tvalue" lines.
  virtual void WriteToFile(std::string file);

 protected:
  int k_;    // number of V values per key
  S state_;  // shared state passed to every entry Get/Set
  // TODO use multi-thread cuckoo hash
  std::unordered_map data_;
};
67 |
template
void KVMap::GetValue(Message* msg) {
  // Serve a pull request: for every requested key, copy its k_-length value
  // into one flat array appended to the reply message.
  SArray key(msg->key);
  size_t n = key.size();
  SArray val(n * k_);
  for (size_t i = 0; i < n; ++i) {
    // operator[] default-constructs an entry for previously unseen keys
    data_[key[i]].Get(val.data() + i * k_, &state_);
  }
  msg->add_value(val);
}
78 |
template
void KVMap::SetValue(const Message* msg) {
  // Serve a push request: the message must carry exactly one value array
  // holding k_ values per key.
  SArray key(msg->key);
  size_t n = key.size();
  CHECK_EQ(msg->value.size(), 1);
  SArray val(msg->value[0]);
  CHECK_EQ(n * k_, val.size());

  for (size_t i = 0; i < n; ++i) {
    data_[key[i]].Set(val.data() + i * k_, &state_);
  }
  // notify the shared state once per received batch
  state_.Update();
}
#if USE_S3
// Helpers (defined elsewhere) for reading/writing models on Amazon S3.
bool s3file(const std::string& name);  // true if *name* looks like an S3 path
std::string s3Prefix(const std::string& path);
std::string s3Bucket(const std::string& path);
std::string s3FileUrl(const std::string& path);
#endif // USE_S3
98 |
99 | template
100 | void KVMap::WriteToFile(std::string file) {
101 | #if USE_S3
102 | std::string s3_file;
103 | if (s3file(file)) {
104 | s3_file=file;
105 | // create a local model dir
106 | file=s3Prefix(s3_file);
107 | }
108 | #endif // USE_S3
109 | if (!dirExists(getPath(file))) {
110 | createDir(getPath(file));
111 | }
112 | std::ofstream out(file); CHECK(out.good());
113 | V v;
114 | for (auto& e : data_) {
115 | e.second.Get(&v, &state_);
116 | if (v != 0) out << e.first << "\t" << v << std::endl;
117 | }
118 | #if USE_S3
119 | if (s3file(s3_file)) {
120 | // upload model
121 | std::string cmd = "curl -s '"+s3FileUrl(s3_file)+"?Content-Length="
122 | +std::to_string(File::size(file))+"&x-amz-acl=public-read' --upload-file "+file;
123 | LOG(INFO)<task.param();
7 | Message* response = nullptr;
8 | bool push = call.push();
9 | if (!push) {
10 | // a pull request, need to reply with the value
11 | response = new Message(*request);
12 | }
13 |
14 | if (call.replica()) {
15 | // a replication request
16 | if (push) {
17 | SetReplica(request);
18 | } else {
19 | GetReplica(response);
20 | }
21 | } else {
22 | // a normal request
23 | if (push) {
24 | SetValue(request);
25 | } else {
26 | GetValue(response);
27 | }
28 | }
29 |
30 | if (response) Reply(request, response);
31 | }
32 |
void Parameter::ProcessResponse(Message* response) {
  const auto& call = response->task.param();
  bool push = call.push();

  if (call.replica()) {
    // a replication response
    if (push) return; // an ACK response, nothing to store
    // A response covering exactly my own key range is recovery data for me;
    // anything else is backup data I hold on behalf of another node.
    if (Range(response->task.key_range()) == MyKeyRange()) {
      Recover(response);
    } else {
      SetReplica(response);
    }
  } else {
    // a normal response: only a pull response carries values to store
    if (!push) SetValue(response);
  }
}
50 |
51 | } // namespace PS
52 |
--------------------------------------------------------------------------------
/src/parameter/parameter.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "system/customer.h"
3 | #include "parameter/proto/param.pb.h"
4 | namespace PS {
5 |
/// The base class of shared parameters. It routes push/pull requests and
/// responses to the Get/Set/replica hooks implemented by subclasses.
class Parameter : public Customer {
 public:
  Parameter(int id) : Customer(id) { }
  virtual ~Parameter() { }

  // NOTE(review): the template arguments below look stripped; presumably
  // std::initializer_list<int> and RepeatedPtrField<FilterConfig> — confirm.
  typedef std::initializer_list Timestamps;
  typedef ::google::protobuf::RepeatedPtrField Filters;
  /**
   * @brief Creates a request task
   *
   * @param channel communication channel
   * @param ts the timestamp of this request
   * @param wait a list of timestamps this request should wait for
   * @param filters a list of filters to compress the request message
   * @param key_range the key range of this request
   *
   * @return A Task
   */
  static Task Request(int channel,
                      int ts = Message::kInvalidTime,
                      const Timestamps& wait = {},
                      const Filters& filters = Filters(),
                      const Range& key_range = Range::All()) {
    Task req; req.set_request(true);
    req.set_key_channel(channel);
    // kInvalidTime means "let the system pick a timestamp"
    if (ts > Message::kInvalidTime) req.set_time(ts);
    for (int t : wait) req.add_wait_time(t);
    for (const auto& f : filters) req.add_filter()->CopyFrom(f);
    key_range.To(req.mutable_key_range());
    return req;
  }

  /// @brief Submit a push message to msg->recver
  inline int Push(Message* msg) {
    msg->task.mutable_param()->set_push(true);
    return Submit(msg);
  }

  /// @brief Submit a pull message to msg->recver
  inline int Pull(Message* msg) {
    msg->task.mutable_param()->set_push(false);
    return Submit(msg);
  }

  /// @brief Dump the parameters to *file*; no-op by default
  virtual void WriteToFile(std::string file) { }

  virtual void ProcessRequest(Message* request);
  virtual void ProcessResponse(Message* response);
 protected:

  /// @brief Fill "msg" with the values it requests, e.g.,
  /// msg->value(0)[0] = my_val_[msg->key[0]];
  virtual void GetValue(Message* msg) = 0;

  /// @brief Set the values in "msg" into my data structure, e.g..
  /// my_val_[msg->key[0]] = msg->value(0)[0];
  virtual void SetValue(const Message* msg) = 0;

  /// @brief the message contains the backup KV pairs sent by the master node of the key
  /// segment to its replica node. merge these pairs into my replica, say
  /// replica_[msg->sender] = ...
  virtual void SetReplica(const Message* msg) { }

  /// @brief retrieve the replica. a new server node replacing a dead server will first
  /// ask the dead's replica node for the data
  virtual void GetReplica(Message* msg) { }

  /// @brief a new server node fills its own data structure via the replica data from
  /// the dead's replica node
  virtual void Recover(Message* msg) { }
};
78 |
79 | } // namespace PS
80 |
--------------------------------------------------------------------------------
/src/parameter/proto/param.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "util/proto/assign_op.proto";
3 |
// Metadata attached to a parameter push/pull task.
message ParamCall {
  // push or pull
  optional bool push = 1 [default = true];

  // merge operator
  optional AssignOpType op = 2;

  // configuration of the tail-key (infrequent key) filter
  optional TailKeyFilter tail_filter = 3;

  // optional bool insert_key = 5;
  // optional bool gather = 6;

  // it's a replica request
  optional bool replica = 10;
  // NOTE(review): presumably the timestamps of backed-up messages carried in
  // a replication request — confirm against the replication code.
  repeated Timestamp backup = 11;
}
20 |
// How to initialize a parameter: zero, a constant, Gaussian noise, loaded
// from a file, or cloned from another parameter.
message ParamInitConfig {
  enum Type {
    ZERO = 1;
    CONSTANT = 2;
    GAUSSIAN = 3;
    FILE = 4;
    CLONE = 5;
  }
  optional Type type = 1 [default = ZERO];
  // used when type = CONSTANT
  optional double constant = 2 [default = 1];
  // gaussian random
  optional double mean = 3 [default = 0];
  optional double std = 4 [default = 1];
  // used when type = FILE
  optional string file_name = 5;
}
36 |
// Identifies one message: the sending node plus its logical time.
message Timestamp {
  required string sender = 1;
  required int32 time = 2;
}
41 |
// Settings for the tail-key filter, which uses a count-min sketch to track
// key frequencies.
message TailKeyFilter {
  // whether to insert key counts into the sketch
  optional bool insert_count = 1;
  // frequency threshold; presumably keys with a lower estimated count are
  // filtered out — confirm against frequency_filter.h
  optional int32 freq_threshold = 2;
  optional bool query_value = 3;
  // count-min sketch dimensions: n counters, k hash functions
  optional int32 countmin_n = 4 [default = 1000000];
  optional int32 countmin_k = 5 [default = 2];
}
49 |
--------------------------------------------------------------------------------
/src/ps.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "system/customer.h"
3 |
4 | // A simple interface for writing parameter server (PS) programs. see example in
5 | // src/app/hello_world
6 |
7 | // A typical PS program should define the following two functions:
8 |
// This is the main entrance for a worker node. All flags and their arguments
// (e.g. -name value) have been removed from argv, and argc has been adjusted
// accordingly. However, plain commandline arguments are preserved.
12 | //
13 | // In example "head -n 100 file", -n is a flag and 100 is this flag's argument,
14 | // but file is a commandline argument
15 | int WorkerNodeMain(int argc, char *argv[]);
16 |
17 | namespace PS {
18 |
19 | // Return an instance of a server node. This node is started with "-app_file
20 | // app.conf -app_conf 'key: value'", then conf has both the content of file "app.conf"
21 | // and 'key:value'
22 | App* CreateServerNode(const std::string& conf);
23 |
24 | // Utility functions:
25 |
// The app this node runs
inline App* MyApp() { return Postoffice::instance().manager().app(); }

// My node information
inline Node MyNode() { return Postoffice::instance().manager().van().my_node(); }
// The unique string id of my node
inline std::string MyNodeID() { return MyNode().id(); }
// Query the role of this node
inline int IsWorker() { return MyNode().role() == Node::WORKER; }
inline int IsServer() { return MyNode().role() == Node::SERVER; }
inline int IsScheduler() { return MyNode().role() == Node::SCHEDULER; }

// The key range assigned to this node
inline Range MyKeyRange() { return Range(MyNode().key()); }
// The string id of the scheduler node
inline std::string SchedulerID() {
  return Postoffice::instance().manager().van().scheduler().id();
}

// Draw a fresh, unused customer id from the manager
inline int NextCustomerID() {
  return Postoffice::instance().manager().NextCustomerID();
}
46 |
// The rank ID of this node in its group. Assume this is a worker node in a
// worker group with N workers. Then this node will be assigned a unique ID
// from 0, ..., N-1. Similarly for server and scheduler.
inline int MyRank() { return MyNode().rank(); }
// Total nodes in this node group (workers, servers, or 1 for the scheduler).
inline int RankSize() {
  auto& mng = Postoffice::instance().manager();
  return IsWorker() ? mng.num_workers() : (IsServer() ? mng.num_servers() : 1);
}
56 |
// Wait until all FLAGS_num_servers servers are ready.
inline void WaitServersReady() {
  PS::Postoffice::instance().manager().WaitServersReady();
}

// Wait until all FLAGS_num_workers workers are ready.
inline void WaitWorkersReady() {
  PS::Postoffice::instance().manager().WaitWorkersReady();
}

// Parse flags and bring this node online.
inline void StartSystem(int argc, char *argv[]) {
  PS::Postoffice::instance().Run(&argc, &argv);
}

// Shut this node down and leave the system.
inline void StopSystem() {
  PS::Postoffice::instance().Stop();
}

// Convenience wrapper: start the system, then stop it.
inline int RunSystem(int argc, char *argv[]) {
  StartSystem(argc, argv); StopSystem();
  return 0;
}
79 |
80 | } // namespace PS
81 |
--------------------------------------------------------------------------------
/src/ps_main.cc:
--------------------------------------------------------------------------------
1 | #include "ps.h"
2 | namespace PS {
3 |
4 | App* App::Create(const string& conf) {
5 | auto my_role = MyNode().role();
6 | if (my_role == Node::SERVER) {
7 | return CreateServerNode(conf);
8 | }
9 | return new App();
10 | }
11 | } // namespace PS
12 |
13 | int main(int argc, char *argv[]) {
14 | auto& sys = PS::Postoffice::instance();
15 | sys.Run(&argc, &argv);
16 |
17 | int ret = 0;
18 | if (PS::MyNode().role() == PS::Node::WORKER) {
19 | ret = WorkerNodeMain(argc, argv);
20 | }
21 |
22 | sys.Stop();
23 | return ret;
24 | }
25 |
--------------------------------------------------------------------------------
/src/system/assigner.cc:
--------------------------------------------------------------------------------
1 | #include "system/assigner.h"
2 | #include "data/common.h"
3 | namespace PS {
4 |
5 | void DataAssigner::set(const DataConfig& data, int num, bool local) {
6 | // search all files
7 | CHECK_GT(num, 0);
8 | parts_.resize(num);
9 | if (local) {
10 | for (int i = 0; i < num; ++i) parts_[i].CopyFrom(data);
11 | return;
12 | }
13 |
14 | CHECK_GT(data.replica(), 0);
15 | DataConfig files = searchFiles(data);
16 | VLOG(1) << "find " << files.file_size()
17 | << " files: " << files.ShortDebugString();
18 |
19 | // divide them
20 | for (int r = 0; r < data.replica(); ++r) {
21 | if (data.shuffle()) files = shuffleFiles(files);
22 | auto pts = divideFiles(files, num);
23 | if (r == 0) {
24 | parts_ = pts;
25 | } else {
26 | for (int i = 0; i < num; ++i) {
27 | parts_[i] = appendFiles(parts_[i], pts[i]);
28 | }
29 | }
30 | }
31 | VLOG(1) << "divide into " << num << " jobs";
32 | }
33 |
34 | bool DataAssigner::next(DataConfig *data) {
35 | if (cur_i_ >= parts_.size()) return false;
36 | *data = parts_[cur_i_ ++];
37 | return true;
38 | }
39 |
40 | } // namespace PS
41 |
--------------------------------------------------------------------------------
/src/system/assigner.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "util/range.h"
4 | #include "system/proto/node.pb.h"
5 | #include "data/proto/data.pb.h"
6 | namespace PS {
7 |
// assign *node* with proper rank_id, key_range, etc..
class NodeAssigner {
 public:
  NodeAssigner(int num_servers, Range key_range) {
    num_servers_ = num_servers;
    key_range_ = key_range;
  }
  ~NodeAssigner() { }

  // Give *node* its rank and key range: servers get an even slice of the
  // whole key range; all other roles keep the full range.
  void assign(Node* node) {
    Range kr = key_range_;
    int rank = 0;
    if (node->role() == Node::SERVER) {
      kr = key_range_.EvenDivide(num_servers_, server_rank_);
      rank = server_rank_ ++;
    } else if (node->role() == Node::WORKER) {
      rank = worker_rank_ ++;
    }
    node->set_rank(rank);
    kr.To(node->mutable_key());
  }

  // Forget a removed node; not implemented yet.
  void remove(const Node& node) {
    // TODO
  }
 protected:
  // NOTE(review): Range's template argument appears stripped in this dump
  // (presumably Range<Key>) — confirm against the original header.
  int num_servers_ = 0;
  int server_rank_ = 0;  // next rank to hand to a server
  int worker_rank_ = 0;  // next rank to hand to a worker
  Range key_range_;
};
39 |
40 | // divide *data* into *num* parts.
41 | class DataAssigner {
42 | public:
43 | DataAssigner() { }
44 | DataAssigner(const DataConfig& data, int num, bool local) {
45 | set(data, num, local);
46 | }
47 | ~DataAssigner() { }
48 |
49 | void set(const DataConfig& data, int num, bool local);
50 | bool next(DataConfig *data);
51 |
52 | int cur_i() { return cur_i_; }
53 | int size() { return parts_.size(); }
54 | private:
55 | std::vector parts_;
56 | int cur_i_ = 0;
57 | };
58 |
59 | } // namespace PS
60 |
--------------------------------------------------------------------------------
/src/system/dashboard.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "system/message.h"
3 | #include "system/proto/heartbeat.pb.h"
4 |
5 | namespace PS {
6 |
// Comparator for ordering NodeIDs in the dashboard's report map.
// NOTE(review): the exact ordering rule lives in dashboard.cc — it splits a
// NodeID into primary/secondary parts before comparing; confirm there.
struct NodeIDCmp {
  // split *in* into its primary and secondary components
  void splitNodeID(const NodeID& in, string& primary, string& secondary);
  // strict weak ordering over NodeIDs
  bool operator()(const NodeID& a, const NodeID& b);
};
11 |
// Collects heartbeat reports from nodes and renders them as a text table.
class Dashboard {
 public:
  // record the latest task id observed for *node*
  void addTask(const NodeID& node, int task_id);
  // store the (serialized) heartbeat report received from *node*
  void addReport(const NodeID& node, const string& report);
  // render all stored reports into a printable string
  string report();
 private:
  string title();
  string report(const NodeID& node, const HeartbeatReport& report);
  std::mutex mu_;  // guards data_
  // NOTE(review): the map's template arguments appear stripped in this dump;
  // presumably std::map<NodeID, HeartbeatReport, NodeIDCmp> — confirm.
  std::map data_;
};
23 |
24 | } // namespace PS
25 |
--------------------------------------------------------------------------------
/src/system/env.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | namespace PS {
3 |
/**
 * @brief Setups environment
 *
 * Process-wide initialization performed before the system starts: logging,
 * DMLC-related setup, and assembly of this node's description.
 */
class Env {
 public:
  Env() { }
  ~Env() { }

  // Run all initialization steps; argv0 is the program name (passed to glog).
  void Init(char* argv0);
 private:
  void InitGlog(char* argv0);
  void InitDMLC();
  void AssembleMyNode();

};
19 |
20 | } // namespace PS
21 |
--------------------------------------------------------------------------------
/src/system/heartbeat_info.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "system/proto/heartbeat.pb.h"
4 | #include "util/resource_usage.h"
5 | #include "util/common.h"
6 |
7 | namespace PS {
// Accumulates per-node performance counters (CPU, memory, network) and
// packages them into HeartbeatReport protos.
class HeartbeatInfo {
 public:
  // kinds of timers tracked; NUM is the number of timer types
  enum class TimerType : unsigned char {
    BUSY = 0,
    NUM
  };

 public:
  HeartbeatInfo();
  ~HeartbeatInfo();
  HeartbeatInfo(const HeartbeatInfo& other) = delete;
  HeartbeatInfo& operator= (const HeartbeatInfo& rhs) = delete;

  // build a report from the counters accumulated since the last call
  HeartbeatReport get();

  // set network interface which is under use
  // such as "eth0"
  // set hostname
  void init(const string& interface, const string& hostname);

  void startTimer(const HeartbeatInfo::TimerType type);
  void stopTimer(const HeartbeatInfo::TimerType type);

  // TODO need lock?  (note: both methods below do take mu_ already)
  void increaseInBytes(const size_t delta) { Lock l(mu_); in_bytes_ += delta; }
  void increaseOutBytes(const size_t delta) { Lock l(mu_); out_bytes_ += delta; }

 private:
  // NOTE(review): the element type appears stripped in this dump; presumably
  // std::vector<MilliTimer>, one per TimerType — confirm.
  std::vector timers_;
  MilliTimer total_timer_;

  // byte counters updated via increaseInBytes/increaseOutBytes
  size_t in_bytes_;
  size_t out_bytes_;

  string interface_;  // network interface to monitor, e.g. "eth0"
  string hostname_;

  // snapshot of performance counters
  struct Snapshot {
    uint64 process_user;
    uint64 process_sys;
    uint64 host_user;
    uint64 host_sys;
    uint64 host_cpu;

    uint64 host_in_bytes;
    uint64 host_out_bytes;

    Snapshot() :
      process_user(0),
      process_sys(0),
      host_user(0),
      host_sys(0),
      host_cpu(0),
      host_in_bytes(0),
      host_out_bytes(0) {
      // do nothing
    }

    // human-readable dump of all counters, for debugging
    string shortDebugString() {
      std::stringstream ss;
      ss << "{";
      ss << "process_user: " << process_user << ", ";
      ss << "process_sys: " << process_sys << ", ";
      ss << "host_user: " << host_user << ", ";
      ss << "host_sys: " << host_sys << ", ";
      ss << "host_cpu: " << host_cpu << ", ";
      ss << "host_in_bytes: " << host_in_bytes << ", ";
      ss << "host_out_bytes: " << host_out_bytes;
      ss << "}";

      return ss.str();
    }
  }; // struct Snapshot

  HeartbeatInfo::Snapshot last_;   // snapshot taken at the previous report
  HeartbeatInfo::Snapshot dump();  // take a fresh snapshot of the counters

  std::mutex mu_;
  size_t cpu_core_number_;
}; // class Heartbeatinfo
89 | }; // namespace PS
90 |
--------------------------------------------------------------------------------
/src/system/manager.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "system/proto/node.pb.h"
4 | #include "system/proto/task.pb.h"
5 | #include "system/van.h"
6 | #include "system/env.h"
7 | #include "system/assigner.h"
8 | namespace PS {
9 |
10 | class App;
11 | class Customer;
12 |
// The central bookkeeper of one node: owns the communication Van, tracks all
// known nodes and customers, and drives startup/shutdown.
class Manager {
 public:
  Manager();
  ~Manager();

  void Init(char* argv0);
  void Run();
  void Stop();
  // process a message addressed to the system itself
  bool Process(Message* msg);

  // manage nodes
  void AddNode(const Node& node);
  void RemoveNode(const NodeID& node_id);
  // detect that *node_id* is disconnected
  void NodeDisconnected(const NodeID node_id);
  // add a function handler which will be called in *nodeDisconnected*
  // NOTE(review): the std::function template argument appears stripped in
  // this dump; presumably std::function<void(const NodeID&)> — confirm.
  typedef std::function NodeFailureHandler;
  void AddNodeFailureHandler(NodeFailureHandler handler) {
    node_failure_handlers_.push_back(handler);
  }

  // manage customer
  Customer* customer(int id);
  void AddCustomer(Customer* obj);
  void RemoveCustomer(int id);
  int NextCustomerID();

  // workers and servers
  void WaitServersReady();
  void WaitWorkersReady();

  int num_workers() { return num_workers_; }
  int num_servers() { return num_servers_; }

  // manage message TODO
  void AddRequest(Message* msg) { delete msg; }
  void AddResponse(Message* msg) { }

  // accessors
  Van& van() { return van_; }
  App* app() { return app_; }

 private:
  bool IsScheduler() { return van_.my_node().role() == Node::SCHEDULER; }
  Task NewControlTask(Control::Command cmd);
  void SendTask(const NodeID& recver, const Task& task);
  void SendTask(const Node& recver, const Task& task) {
    SendTask(recver.id(), task);
  }

  // the app
  void CreateApp(const string& conf);
  App* app_ = nullptr;
  string app_conf_;
  // std::promise my_node_promise_;

  // nodes
  // NOTE(review): several container template arguments below appear stripped
  // in this dump (e.g. std::map<NodeID, Node>) — confirm against the
  // original header.
  std::map nodes_;
  std::mutex nodes_mu_;
  int num_workers_ = 0;
  int num_servers_ = 0;
  int num_active_nodes_ = 0;
  std::vector node_failure_handlers_;
  bool is_my_node_inited_ = false;

  // only available at the scheduler node
  NodeAssigner* node_assigner_ = nullptr;

  // customers
  // format: keyed by customer id (rest of the original comment was garbled)
  std::map> customers_;

  bool done_ = false;
  bool in_exit_ = false;
  int time_ = 0;

  Van van_;
  Env env_;

  DISALLOW_COPY_AND_ASSIGN(Manager);
};
94 |
95 | } // namespace PS
96 |
--------------------------------------------------------------------------------
/src/system/message.cc:
--------------------------------------------------------------------------------
1 | #include "system/message.h"
2 | namespace PS {
3 |
4 | // Message::Message(const NodeID& dest, int time, int wait_time)
5 | // : recver(dest) {
6 | // task.set_time(time);
7 | // if (wait_time != kInvalidTime) task.add_wait_time(wait_time);
8 | // }
9 |
10 | FilterConfig* Message::add_filter(FilterConfig::Type type) {
11 | auto ptr = task.add_filter();
12 | ptr->set_type(type);
13 | return ptr;
14 | }
15 |
16 | size_t Message::mem_size() {
17 | size_t nbytes = task.SpaceUsed() + key.MemSize();
18 | for (const auto& v : value) nbytes += v.MemSize();
19 | return nbytes;
20 | }
21 |
22 | std::string Message::ShortDebugString() const {
23 | std::stringstream ss;
24 | if (key.size()) ss << "key [" << key.size() << "] ";
25 | if (value.size()) {
26 | ss << "value [";
27 | for (int i = 0; i < value.size(); ++i) {
28 | ss << value[i].size();
29 | if (i < value.size() - 1) ss << ",";
30 | }
31 | ss << "] ";
32 | }
33 | auto t = task; t.clear_msg(); ss << t.ShortDebugString();
34 | return ss.str();
35 | }
36 |
std::string Message::DebugString() const {
  // Verbose multi-line dump: route, full task, key size, and value sizes.
  std::stringstream ss;
  ss << "[message]: " << sender << "=>" << recver
     << "[task]:" << task.ShortDebugString()
     << "\n[key]:" << key.size()
     << "\n[" << value.size() << " value]: ";
  for (const auto& x: value)
    ss << x.size() << " ";
  return ss.str();
}
47 |
48 |
49 | } // namespace PS
50 |
--------------------------------------------------------------------------------
/src/system/monitor.h:
--------------------------------------------------------------------------------
1 | /**
2 | * @file monitor.h
3 | * @brief A distributed monitor
4 | *
5 | */
6 | #pragma once
7 | #include "system/customer.h"
8 | namespace PS {
9 |
/**
 * @brief The master of the monitor, which collects reports from slavers and
 * displays the progress
 *
 * @tparam Progress A proto buffer class
 */
template
class MonitorMaster : public Customer {
 public:
  MonitorMaster(int id = NextCustomerID()) : Customer(id) {}

  // NOTE(review): the std::function template arguments appear stripped in
  // this dump; the printer receives the total elapsed time and the progress
  // map — confirm against the original header.
  typedef std::function*)> Printer;
  /**
   * @brief set the printer
   *
   * @param time_interval in sec
   * @param printer callback that renders the current progress
   */
  void set_printer(double time_interval, Printer printer) {
    timer_.start();
    printer_ = printer;
    interval_ = time_interval;
  }

  typedef std::function Merger;
  /**
   * @brief set the merger
   *
   * @param merger merges two reports
   */
  void set_merger(Merger merger) {
    merger_ = merger;
  }

  // Handle one report from a slaver: merge it into that sender's progress
  // (or overwrite if no merger is set), then invoke the printer if at least
  // interval_ seconds have elapsed since the last print.
  virtual void ProcessRequest(Message* request) {
    NodeID sender = request->sender;
    Progress prog;
    CHECK(prog.ParseFromString(request->task.msg()));
    if (merger_) {
      merger_(prog, &progress_[sender]);
    } else {
      progress_[sender] = prog;
    }

    double time = timer_.stop();
    if (time > interval_ && printer_) {
      total_time_ += time;
      printer_(total_time_, &progress_);
      timer_.restart();
    } else {
      timer_.start();
    }
  }
 private:
  std::unordered_map progress_;  // latest/merged progress per sender
  double interval_;              // minimum seconds between two prints
  Timer timer_;
  double total_time_ = 0;        // accumulated time reported to the printer
  Merger merger_;
  Printer printer_;
};
72 |
/**
 * @brief A slave monitor, which reports to the master monitor
 *
 * @tparam Progress a proto class
 */
template
class MonitorSlaver : public Customer {
 public:
  MonitorSlaver(const NodeID& master, int id = NextCustomerID())
      : Customer(id), master_(master) { }
  virtual ~MonitorSlaver() { }

  /**
   * @brief Sends a serialized progress report to the master
   *
   * @param prog the progress to report
   */
  void Report(const Progress& prog) {
    string str; CHECK(prog.SerializeToString(&str));
    Task report; report.set_msg(str);
    Submit(report, master_);
  }
 protected:
  NodeID master_;  // node id of the master monitor
};
98 |
99 | } // namespace PS
100 |
--------------------------------------------------------------------------------
/src/system/postoffice.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "system/message.h"
4 | #include "util/threadsafe_queue.h"
5 | #include "system/manager.h"
6 | #include "system/heartbeat_info.h"
7 | namespace PS {
8 |
// The per-process hub: owns the Manager and the background send/receive
// threads that move messages between nodes.
class Postoffice {
 public:
  SINGLETON(Postoffice);
  ~Postoffice();

  /**
   * @brief Starts the system
   */
  void Run(int* argc, char***);
  /**
   * @brief Stops the system
   */
  void Stop() { manager_.Stop(); }

  /**
   * @brief Queue a message into the sending buffer, which will be sent by the
   * sending thread. It is thread safe.
   *
   * @param msg it will be DELETE by system after sent successfully. so do NOT
   * delete it before
   */
  void Queue(Message* msg);

  Manager& manager() { return manager_; }
  HeartbeatInfo& pm() { return perf_monitor_; }

 private:
  Postoffice();
  void Send();  // presumably the sending thread's loop — confirm in the .cc
  void Recv();  // presumably the receiving thread's loop — confirm in the .cc
  bool Process(Message* msg);
  // NOTE(review): several template arguments below appear stripped in this
  // dump, e.g. std::unique_ptr<std::thread> and ThreadsafeQueue<Message*> —
  // confirm against the original header.
  std::unique_ptr recv_thread_;
  std::unique_ptr send_thread_;
  ThreadsafeQueue sending_queue_;

  Manager manager_;
  HeartbeatInfo perf_monitor_;

  // key: , value: messages will be packed
  std::map, std::vector> pack_;
  std::mutex pack_mu_;  // guards pack_

  DISALLOW_COPY_AND_ASSIGN(Postoffice);
};
53 |
54 | } // namespace PS
55 |
--------------------------------------------------------------------------------
/src/system/proto/heartbeat.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 |
// One performance report sent from a worker/server to the scheduler.
message HeartbeatReport {
  // id of the most recent task the node reported
  optional int32 task_id = 1 [default = 0];
  optional string hostname = 14;

  // time stamp
  // latest heartbeat report the scheduler has ever received
  // from a specified worker/server
  optional uint32 seconds_since_epoch = 2;

  // wall time covered by this report and time spent busy (milliseconds)
  optional uint32 total_time_milli = 13;
  optional uint32 busy_time_milli = 3;

  // recv/sent bytes via zmq
  optional uint32 net_in_mb = 4;
  optional uint32 net_out_mb = 5;

  // user+sys (percentage)
  optional uint32 process_cpu_usage = 6;
  optional uint32 host_cpu_usage = 7;

  // memory usage of the process and of the whole host
  optional uint32 process_rss_mb = 8;
  optional uint32 process_virt_mb = 9;
  optional uint32 host_in_use_gb = 10;
  optional uint32 host_in_use_percentage = 15;

  // host's network in/out bandwidth usage (MB/s)
  optional uint32 host_net_in_bw = 11;
  optional uint32 host_net_out_bw = 12;
}
32 |
--------------------------------------------------------------------------------
/src/system/proto/node.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "util/proto/range.proto";
3 |
// Description of one node in the system: its role, identity and address.
message Node {
  enum Role {
    SERVER = 0;
    WORKER = 1;
    SCHEDULER = 3; // each running application has a single scheduler
    GROUP = 4;     // a virtual node, present a group of node
    UNUSED = 5;    // a backup node, could turn into another node
  }

  required Role role = 1;
  // unique string id of the node -- format assigned by the scheduler
  optional string id = 2;
  // rank within the node's role group -- presumably 0-based; confirm
  optional int32 rank = 5;
  // network address
  optional string hostname = 3;
  optional int32 port = 4;

  // key range associated with this node (see util/proto/range.proto)
  optional PbRange key = 6;
}
22 |
--------------------------------------------------------------------------------
/src/system/proto/task.proto:
--------------------------------------------------------------------------------
1 | package PS;
2 | import "util/proto/range.proto";
3 | import "data/proto/data.proto";
4 | import "system/proto/node.proto";
5 | import "parameter/proto/param.proto";
6 | import "filter/proto/filter.proto";
7 | import "learner/proto/sgd.proto";
8 | import "learner/proto/bcd.proto";
9 |
// The basic unit of work exchanged between nodes: either a system control
// task or an application (customer) task, and either a request or the
// response matching a request's *time*.
message Task {
  // true: system control task, typically *ctrl* should be set
  // false: a task for a customer, and *customer_id* should be set
  optional bool control = 1 [default = false];
  // true: a request task
  // false: the response task to the request task with the same *time*
  optional bool request = 2 [default = false];
  // the unique id of a customer
  optional int32 customer_id = 3;

  // the timestamp of this task
  optional int32 time = 5;
  // the depended tasks of this one. that is, this task is executed only if all
  // tasks from the same node with time contained in *wait_time* are finished.
  // only valid if *request*=true
  repeated int32 wait_time = 6;

  // the key range of this task
  optional PbRange key_range = 7;
  // namespace of keys
  optional int32 key_channel = 8;
  // true: the message sent with this task will contain a list of keys
  optional bool has_key = 9 [default = false];

  // element type of the key array and of each value array
  optional DataType key_type = 13;
  repeated DataType value_type = 14;

  // filters applied to the data
  repeated FilterConfig filter = 12;

  // if true, the tasks will be packed into a single one during sending;
  // the packed sub-tasks travel in *task*
  optional bool more = 16 [default = false];
  repeated Task task = 15;

  // the place to store a small amount of data
  optional bytes msg = 17;

  // system control signals
  optional Control ctrl = 18;

  // application-level calls: parameter access, SGD and BCD learners
  optional ParamCall param = 20;
  optional SGDCall sgd = 21;
  optional BCDCall bcd = 22;

  extensions 100 to 199;
}
58 |
// System control signal, carried in Task.ctrl when Task.control is true.
message Control {
  enum Command {
    // a node => the scheduler
    REQUEST_APP = 1;
    REGISTER_NODE = 2;
    REPORT_PERF = 3;
    READY_TO_EXIT = 4;

    // the scheduler => a node
    ADD_NODE = 10;
    UPDATE_NODE = 11;
    REPLACE_NODE = 12;
    REMOVE_NODE = 13;
    EXIT = 14;
  }
  required Command cmd = 1;
  // the node(s) the command refers to
  repeated Node node = 2;
}
77 |
// Element type of the raw key/value byte arrays referenced by Task.key_type
// and Task.value_type.
enum DataType {
  OTHER = 0;
  INT8 = 1;
  INT16 = 2;
  INT32 = 3;
  INT64 = 4;
  UINT8 = 5;
  UINT16 = 6;
  UINT32 = 7;
  UINT64 = 8;
  FLOAT = 9;
  DOUBLE = 10;
  CHAR = 11;
}
92 |
--------------------------------------------------------------------------------
/src/system/remote_node.cc:
--------------------------------------------------------------------------------
1 | #include "system/remote_node.h"
2 | #include "system/customer.h"
3 | #include "util/crc32c.h"
4 | #include "util/shared_array_inl.h"
5 | namespace PS {
6 |
7 | Filter* RemoteNode::FindFilterOrCreate(const FilterConfig& conf) {
8 | int id = conf.type();
9 | auto it = filters.find(id);
10 | if (it == filters.end()) {
11 | filters[id] = Filter::create(conf);
12 | it = filters.find(id);
13 | }
14 | return it->second;
15 | }
16 |
17 | void RemoteNode::EncodeMessage(Message* msg) {
18 | const auto& tk = msg->task;
19 | for (int i = 0; i < tk.filter_size(); ++i) {
20 | FindFilterOrCreate(tk.filter(i))->encode(msg);
21 | }
22 | }
23 | void RemoteNode::DecodeMessage(Message* msg) {
24 | const auto& tk = msg->task;
25 | // a reverse order comparing to encode
26 | for (int i = tk.filter_size()-1; i >= 0; --i) {
27 | FindFilterOrCreate(tk.filter(i))->decode(msg);
28 | }
29 | }
30 |
// Inserts "rnode" into this (virtual group) node's member list, keeping the
// parallel arrays "group" and "keys" sorted by key range.
// NOTE(review): ordering relies on Range::InLeft between member ranges;
// assumes member ranges are disjoint -- confirm with callers in manager.
void RemoteNode::AddGroupNode(RemoteNode* rnode) {
  CHECK_NOTNULL(rnode);
  // insert s into sub_nodes such as sub_nodes is still ordered
  int pos = 0;
  Range kr(rnode->node.key());
  while (pos < group.size()) {
    // stop at the first member whose range lies to the right of kr
    if (kr.InLeft(Range(group[pos]->node.key()))) {
      break;
    }
    ++ pos;
  }
  group.insert(group.begin() + pos, rnode);
  keys.insert(keys.begin() + pos, kr);
}
45 |
46 | void RemoteNode::RemoveGroupNode(RemoteNode* rnode) {
47 | size_t n = group.size();
48 | CHECK_EQ(n, keys.size());
49 | for (int i = 0; i < n; ++i) {
50 | if (group[i] == rnode) {
51 | group.erase(group.begin() + i);
52 | keys.erase(keys.begin() + i);
53 | return;
54 | }
55 | }
56 | }
57 |
58 | } // namespace PS
59 |
--------------------------------------------------------------------------------
/src/system/remote_node.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "system/proto/task.pb.h"
4 | #include "system/van.h"
5 | #include "system/postoffice.h"
6 | #include "filter/filter.h"
7 | namespace PS {
8 |
9 | // The presentation of a remote node used by Executor. It's not thread
10 | // safe, do not use them directly.
11 |
// Track a request by its timestamp. Not thread safe (see the note at the
// top of this header).
class RequestTracker {
 public:
  RequestTracker() { }
  ~RequestTracker() { }

  // Returns true if timestamp "ts" is marked as finished. A negative
  // timestamp always counts as finished.
  bool IsFinished(int ts) {
    return ts < 0 || (((int)data_.size() > ts) && data_[ts]);
  }

  // Mark timestamp "ts" as finished.
  void Finish(int ts) {
    CHECK_GE(ts, 0);
    // sanity bound on timestamps; also caps data_'s growth below
    CHECK_LT(ts, 1000000);
    // grow geometrically (2x plus slack) so repeated calls amortize
    if ((int)data_.size() <= ts) data_.resize(ts*2+5);
    data_[ts] = true;
  }
 private:
  // data_[ts] == true iff timestamp ts has finished
  // (element type lost in extraction; presumably std::vector<bool> --
  // confirm against the upstream source)
  std::vector data_;
};
33 |
// A remote node as seen by the local Executor: identity, liveness,
// per-direction request trackers, group membership and cached filters.
struct RemoteNode {
 public:
  RemoteNode() { }
  // owns the filter instances created by FindFilterOrCreate
  ~RemoteNode() {
    for (auto f : filters) delete f.second;
  }

  // apply / undo the filters listed in msg->task (see remote_node.cc)
  void EncodeMessage(Message* msg);
  void DecodeMessage(Message* msg);

  Node node; // the remote node
  bool alive = true; // aliveness

  // timestamp tracker
  RequestTracker sent_req_tracker;
  RequestTracker recv_req_tracker;

  // node group info. if "node" is a node group, then "group" contains all node
  // pointer in this group. otherwise, group contains "this".
  // "group" and "keys" are parallel arrays kept sorted by key range.
  void AddGroupNode(RemoteNode* rnode);
  void RemoveGroupNode(RemoteNode* rnode);
  // NOTE(review): template arguments appear lost in extraction on the two
  // containers below (likely vector<RemoteNode*> and vector<Range<Key>>)
  std::vector group;
  // keys[i] is the key range of group[i]
  std::vector> keys;

 private:
  Filter* FindFilterOrCreate(const FilterConfig& conf);
  // key: filter_type
  std::unordered_map filters;

};
66 |
67 |
68 | } // namespace PS
69 |
--------------------------------------------------------------------------------
/src/system/task_tracker.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include "glog/logging.h"
7 |
8 | namespace PS {
9 |
// track a task by using its timestamp, thread safe.
// A task is in one of three states: unknown (never started), started
// (present with value false), finished (present with value true).
class TaskTracker {
 public:
  TaskTracker() { }
  ~TaskTracker() { }

  // wait until task k has been finished; blocks on a condition variable
  // that finish() notifies
  void wait(int k) {
    ULock l(mu_);
    cond_.wait(l, [this, k]{ return (task_.count(k) && task_[k] == true); });
  }

  // non-blocking wait: returns true iff task k has finished
  bool tryWait(int k) {
    Lock l(mu_);
    return (task_.count(k) && task_[k] == true);
  }

  // start task k, do nothing if it has been started or finished
  void start(int k) {
    Lock l(mu_);
    if (!task_.count(k)) task_[k] = false;
  }

  // whether or not task k has been finished (same predicate as tryWait)
  bool hasFinished(int k) {
    Lock l(mu_);
    return (task_.count(k) && task_[k] == true);
  }

  // finish k, no warning if it has not been started or has been finished
  void finish(int k) {
    Lock l(mu_);
    task_[k] = true;
    cond_.notify_all();
  }

 private:
  typedef std::lock_guard Lock;
  typedef std::unique_lock ULock;

  std::mutex mu_;
  std::condition_variable cond_;

  // key: timestamp, value: finished? (template arguments appear lost in
  // extraction; presumably std::map<int, bool> -- confirm upstream)
  std::map task_;
};
56 |
57 | } // namespace PS
58 |
--------------------------------------------------------------------------------
/src/system/van.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "util/common.h"
3 | #include "system/proto/node.pb.h"
4 | #include "system/message.h"
5 | namespace PS {
6 |
7 | /**
8 | * @brief Van sends (receives) packages to (from) a node The current
9 | * implementation uses ZeroMQ
10 | *
11 | */
12 | class Van {
13 | public:
14 | Van() { }
15 | ~Van();
16 |
17 | void Init();
18 |
19 | void Disconnect(const Node& node);
20 | bool Connect(const Node& node);
21 |
22 | bool Send(Message* msg, size_t* send_bytes);
23 | bool Recv(Message* msg, size_t* recv_bytes);
24 |
25 | static Node ParseNode(const string& node_str);
26 |
27 | Node& my_node() { return my_node_; }
28 | Node& scheduler() { return scheduler_; };
29 | private:
30 | // bind to my port
31 | void Bind();
32 |
33 | static void FreeData(void *data, void *hint) {
34 | if (hint == NULL) {
35 | delete [] (char*)data;
36 | } else {
37 | delete (SArray*)hint;
38 | }
39 | }
40 |
41 | bool IsScheduler() { return my_node_.role() == Node::SCHEDULER; }
42 | // for scheduler: monitor the liveness of all other nodes
43 | // for other nodes: monitor the liveness of the scheduler
44 | void Monitor();
45 |
46 | void *context_ = nullptr;
47 | void *receiver_ = nullptr;
48 | Node my_node_;
49 | Node scheduler_;
50 | std::unordered_map senders_;
51 |
52 | DISALLOW_COPY_AND_ASSIGN(Van);
53 |
54 | // TODO move to postoffice::perf_monitor_
55 | // print statistic info
56 | void Statistic();
57 | std::unordered_map hostnames_;
58 | size_t sent_to_local_ = 0;
59 | size_t sent_to_others_ = 0;
60 | size_t received_from_local_ = 0;
61 | size_t received_from_others_ = 0;
62 |
63 | // for monitor
64 | std::unordered_map fd_to_nodeid_;
65 | std::mutex fd_to_nodeid_mu_;
66 | std::thread* monitor_thread_;
67 |
68 | // debug performance
69 | // double send_time_ = 0;
70 | // double recv_time_ = 0;
71 | // int num_call_ = 0;
72 | };
73 |
74 | } // namespace PS
75 |
--------------------------------------------------------------------------------
/src/test/aggregation_ps.cc:
--------------------------------------------------------------------------------
1 | #include "ps.h"
2 | namespace PS {
3 |
4 | DEFINE_int32(n, 100, "# of aggregation");
5 | DEFINE_int32(interval, 100000, "time (usec) between two aggregation");
6 |
7 | class Server : public App {
8 | public:
9 | virtual void Run() {
10 | WaitWorkersReady();
11 | for (int i = 0; i < FLAGS_n; ++i) {
12 | WaitReceivedRequest(i, kWorkerGroup);
13 | LL << MyNodeID() << " " << i;
14 | }
15 | LL << MyNodeID() << " done";
16 | }
17 | };
18 |
19 | class Worker : public App {
20 | public:
21 | virtual void Run() {
22 | for (int i = 0; i < FLAGS_n; ++i) {
23 | int ts = Submit(Task(), kServerGroup);
24 | usleep(FLAGS_interval);
25 | Wait(ts);
26 | LL << MyNodeID() << " " << i;
27 | }
28 | LL << MyNodeID() << " done";
29 | }
30 | };
31 |
32 | App* App::Create(const std::string& conf) {
33 | if (IsWorker()) return new Worker();
34 | if (IsServer()) return new Server();
35 | return new App();
36 | }
37 |
38 | } // namespace PS
39 |
40 | int main(int argc, char *argv[]) {
41 | return PS::RunSystem(argc, argv);
42 | }
43 |
--------------------------------------------------------------------------------
/src/test/assign_op_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "util/assign_op.h"
3 | using namespace PS;
4 | // evaluation the performance of assignop comparing to the plain version
5 |
6 | size_t n = 1000000000;
7 |
8 | TEST(AssignOp, OpPlus) {
9 | double a = 0;
10 | double b = 1;
11 | for (int i = 0; i < n; ++i) {
12 | AssignOp(a, b, AssignOpType::PLUS);
13 | }
14 | EXPECT_EQ(a,(double)n);
15 | }
16 |
17 | TEST(AssignOp, PlusPlain) {
18 | double a = 0;
19 | double b = 1;
20 | for (int i = 0; i < n; ++i) {
21 | a += b;
22 | }
23 | EXPECT_EQ(a,(double)n);
24 | }
25 |
26 | TEST(AssignOp, OpSet) {
27 | double a = 0;
28 | double b = 1;
29 | for (int i = 0; i < n; ++i) {
30 | AssignOp(a, b, AssignOpType::ASSIGN);
31 | }
32 | EXPECT_EQ(a,b);
33 | }
34 |
35 | TEST(AssignOp, SetPlain) {
36 | double a = 0;
37 | double b = 1;
38 | for (int i = 0; i < n; ++i) {
39 | a = b;
40 | }
41 | EXPECT_EQ(a,b);
42 | }
43 |
--------------------------------------------------------------------------------
/src/test/bloom_filter_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "util/bloom_filter.h"
3 | #include "util/block_bloom_filter.h"
4 | #include "util/shared_array_inl.h"
5 |
6 | DEFINE_int32(m, 100, "");
7 | DEFINE_int32(k, 2, "");
// Benchmarks insert/query throughput and false-positive rate for the plain
// BloomFilter and the cache-blocked BlockBloomFilter on the same key files.
TEST(BloomFilter, Speed) {

  using namespace PS;
  // see src/test/prepare_test_data to get the data
  SArray key1; key1.readFromFile("../data/test/key.1");
  SArray key2; key2.readFromFile("../data/test/key.3");
  // FLAGS_m / FLAGS_k parameterize the filter (see util/bloom_filter.h)
  BloomFilter bloom(FLAGS_m, FLAGS_k);
  auto tv = tic();
  for (auto k : key1) bloom.insert(k);
  LL << key1.size() / toc(tv) << " insert per sec";

  tv = tic();
  int res = 0;
  for (auto k : key2) res += bloom[k];
  LL << key2.size() / toc(tv) << " query per sec";

  // every hit counts as a false positive -- assumes key.1 and key.3 are
  // disjoint; TODO confirm the fixture guarantees this
  LL << "FPR: " << (double) res / (double) key2.size();


  // same workload against the blocked variant
  BlockBloomFilter blk_bloom(FLAGS_m, FLAGS_k);

  tv = tic();
  for (auto k : key1) blk_bloom.insert(k);
  LL << key1.size() / toc(tv) << " insert per sec";

  tv = tic();
  res = 0;
  for (auto k : key2) res += blk_bloom[k];
  LL << key2.size() / toc(tv) << " query per sec";

  LL << "FPR: " << (double) res / (double) key2.size();
}
40 |
int main(int argc, char **argv) {
  FLAGS_logtostderr = 1;
  // gtest is initialized first so it strips its own flags before gflags
  // parses the remainder
  testing::InitGoogleTest(&argc, argv);
  google::ParseCommandLineFlags(&argc, &argv, true);

  return RUN_ALL_TESTS();
}
48 |
--------------------------------------------------------------------------------
/src/test/build.mk:
--------------------------------------------------------------------------------
# Aggregate target: build every test binary (PS apps and gtest unit tests).
test: build/hello_ps \
	build/aggregation_ps \
	build/network_perf_ps \
	build/kv_vector_ps \
	build/kv_vector_buffer_ps \
	build/kv_map_ps \
	build/kv_map_perf_ps \
	build/kv_layer_ps \
	build/kv_layer_perf_ps \
	build/assign_op_test \
	build/parallel_ordered_match_test \
	build/common_test

# *_ps apps link against the full parameter-server library
build/%_ps: src/test/%_ps.cc $(PS_LIB)
	$(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@

# google test
TESTFLAGS = $(TEST_MAIN) -lgtest $(LDFLAGS)

# extra object files this particular test needs beyond its own .o
build/parallel_ordered_match_test: build/util/file.o build/util/proto/*.o build/data/proto/*.pb.o

# *_test binaries are gtest programs; only .o/.a/.cc prerequisites are linked
build/%_test: build/test/%_test.o
	$(CC) $(CFLAGS) $(filter %.o %.a %.cc, $^) $(TESTFLAGS) -o $@

# build/reassign_server_key_range: src/test/reassign_server_key_range.cc $(PS_LIB)
# 	$(CC) $(CFLAGS) $^ $(LDFLAGS) -o $@

# build/fixing_float_test: src/test/fixing_float_test.cc src/filter/fixing_float.h $(PS_LIB)
# 	$(CC) $(CFLAGS) $< $(PS_LIB) $(TESTFLAGS) -o $@
30 |
--------------------------------------------------------------------------------
/src/test/common_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "util/common.h"
3 | #include "util/resource_usage.h"
4 | #include "util/threadsafe_queue.h"
5 | using namespace PS;
6 |
7 | std::shared_ptr p(new int());
// Increment the pointed-to int. Takes the shared_ptr BY VALUE, so every
// call pays an atomic refcount bump -- the cost this benchmark measures.
// (Restores the <int> template argument lost in extraction; grounded in
// the global "std::shared_ptr p(new int())" above.)
inline void fun1(std::shared_ptr<int> a) {
  *a += 1;
}
11 |
// Increment the pointed-to int. Takes the shared_ptr BY REFERENCE, so no
// refcount traffic occurs -- the baseline fun1 is compared against.
// (Restores the <int> template argument lost in extraction.)
inline void fun2(std::shared_ptr<int>& a) {
  *a += 1;
}
15 |
16 | void run(int t) {
17 | int n = 100000;
18 | if (t == 1){
19 | for (int i = 0; i < n; ++i) fun1(p);
20 | } else {
21 | for (int i = 0; i < n; ++i) fun2(p);
22 | }
23 | }
24 |
// Times by-value (fun1) vs by-reference (fun2) shared_ptr passing, each
// from two concurrent threads.
// NOTE(review): both threads increment *p without synchronization, so the
// printed count may fall short of 2*100000+1; harmless for a timing test.
TEST(xx, xx) {
  *p = 1;
  auto tv = tic();
  std::thread p1(run, 1);
  std::thread p2(run, 1);
  p1.join();
  p2.join();
  LL << toc(tv) << " " << *p;

  *p = 1;
  tv = tic();
  std::thread p3(run, 2);
  std::thread p4(run, 2);
  p3.join();
  p4.join();
  LL << toc(tv) << " " << *p;

}
43 |
44 |
// Moves a unique_ptr through ThreadsafeQueue: after push the source is
// moved-from (prints null); the popped pointer owns the original object.
TEST(xx,bb) {
  ThreadsafeQueue> queue;

  std::unique_ptr a(new int());
  *a = 1;
  LL << a.get();
  queue.push(std::move(a));
  // a was moved into the queue; its get() is now null
  LL << a.get();

  std::unique_ptr b;
  queue.wait_and_pop(b);
  LL << b.get();
}
58 |
--------------------------------------------------------------------------------
/src/test/countmin_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "util/countmin.h"
3 | using namespace PS;
4 |
// Placeholder so the binary has one runnable test; the real CountMin
// accuracy/speed test is kept commented out below.
TEST(xx, xx) {
  std::shared_ptr p(new int());


}
10 |
11 | // class CountMinTest : public ::testing::Test {
12 | // protected:
13 | // virtual void SetUp() {
14 | // int m = 1000;
15 |
16 | // key.resize(n);
17 | // cnt.resize(n);
18 | // for (int i = 0; i < n; ++i) {
19 | // uint64 a = (uint64) rand();
20 | // uint64 b = (uint64) rand();
21 | // key[i] = a | (b << 32);
22 | // cnt[i] = rand() % m;
23 | // }
24 |
25 | // auto tv = tic();
26 | // for (int i = 0; i < n; ++i) map[key[i]] += cnt[i];
27 | // LL << toc(tv);
28 | // }
29 |
30 | // void test(int len, int k) {
31 | // CountMin cm;
32 | // cm.resize(len, k);
33 | // auto tv = tic();
34 | // cm.bulkInsert(key, cnt);
35 | // auto t = toc(tv);
36 |
37 | // double err = 0, tol = 0;
38 | // for (int i = 0; i < n; ++i) {
39 | // int a = cm.query(key[i]);
40 | // int b = map[key[i]];
41 | // EXPECT_GE(a, b);
42 | // err += (a-b);
43 | // tol += b;
44 | // }
45 | // LL << t << " " << err / tol;
46 | // }
47 | // int n = 1000000;
48 | // SArray key;
49 | // SArray cnt;
50 | // std::unordered_map map;
51 | // };
52 |
53 | // TEST_F(CountMinTest, test) {
54 |
55 | // for (int k = 1; k < 10; ++k) {
56 | // test(n*.2, k);
57 | // test(n, k);
58 | // test(n*2, k);
59 | // test(n*5, k);
60 | // test(n*10, k);
61 | // }
62 | // }
63 |
--------------------------------------------------------------------------------
/src/test/fixing_float_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "filter/fixing_float.h"
3 |
4 | using namespace PS;
5 |
// Round-trips two value arrays through the fixed-point filter and prints
// the results for manual inspection (no assertions).
TEST(FIXING_FLOAT, EncodeDecode) {
  MessagePtr msg(new Message());
  auto filter_conf = msg->add_filter(FilterConfig::FIXING_FLOAT);
  // first array: quantize into [-90, 90] with 3 bytes per value
  auto conf = filter_conf->add_fixed_point();
  conf->set_min_value(-90);
  conf->set_max_value(90);
  conf->set_num_bytes(3);

  // second array: 3 bytes per value, default min/max
  conf = filter_conf->add_fixed_point();
  conf->set_num_bytes(3);

  // +/-100 fall outside [-90, 90]; presumably clamped by the filter --
  // TODO confirm against filter/fixing_float.h
  SArray ax = {100.0, .1, -100.0}; msg->addValue(ax);
  SArray bx = {100.0, .1, -100.0}; msg->addValue(bx);

  FixingFloatFilter filter;
  filter.encode(msg);
  filter.decode(msg);

  LL << SArray(msg->value[0]);
  LL << SArray(msg->value[1]);
}
27 |
28 |
// Empty placeholder; the name suggests a quantization-error measurement
// was planned here -- TODO implement or remove.
TEST(FIXING_FLOAT, Error) {

}
32 |
33 | // TEST(FIXING_FLOAT, BoolRand) {
34 | // int n = 100000;
35 | // SArray res(n);
36 | // int seed = time(NULL);
37 | // for (int i = 0; i < n; ++i) {
38 | // res[i] = FixingFloatFilter::boolrand(&seed);
39 | // }
40 | // LL << res.mean() << " " << res.std();
41 | // }
42 |
--------------------------------------------------------------------------------
/src/test/hello_ps.cc:
--------------------------------------------------------------------------------
1 | #include "ps.h"
2 | namespace PS {
3 |
4 | class Server : public App {
5 | public:
6 | virtual void ProcessRequest(Message* req) {
7 | std::cout << MyNodeID() << ": processing request " << req->task.time() <<
8 | " from " << req->sender << std::endl;
9 | }
10 | };
11 |
12 | class Worker : public App {
13 | public:
14 | virtual void ProcessResponse(Message* res) {
15 | std::cout << MyNodeID() << ": received response " << res->task.time() <<
16 | " from " << res->sender << std::endl;
17 | }
18 |
19 | virtual void Run() {
20 | int ts = Submit(Task(), kServerGroup);
21 | Wait(ts);
22 |
23 | ts = Submit(Task(), kServerGroup);
24 | Wait(ts);
25 |
26 | Message req;
27 | req.recver = kServerGroup;
28 | req.callback = [this]() {
29 | std::cout << MyNodeID() << ": request " << LastResponse()->task.time() <<
30 | " is finished" << std::endl;
31 | };
32 | Wait(Submit(&req));
33 | }
34 | };
35 |
36 | App* App::Create(const std::string& conf) {
37 | if (IsWorker()) return new Worker();
38 | if (IsServer()) return new Server();
39 | return new App();
40 | }
41 |
42 | } // namespace PS
43 |
44 | int main(int argc, char *argv[]) {
45 | return PS::RunSystem(argc, argv);
46 | }
47 |
--------------------------------------------------------------------------------
/src/test/kv_layer_perf_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @brief Performance test of KVLayer
3 | */
4 | #include "ps.h"
5 | #include "parameter/kv_layer.h"
6 | #include "util/resource_usage.h"
7 | namespace PS {
8 | typedef int V; // value type
9 |
10 | class Updater {
11 | public:
12 | void Init(int id, size_t size, V* data) {
13 | memset(data, 0, sizeof(V)*size);
14 | }
15 |
16 | void Update(int id, size_t size, const V* recv_data, V* data) {
17 | // sum
18 | for (int i = 0; i < size; ++i) {
19 | data[i] += recv_data[i];
20 | }
21 | }
22 | };
23 |
class Server : public App {
 public:
  // install the sum-aggregating updater into the layer store
  Server() {
    model_.set_updater(&updt_);
  }
 private:
  // (template argument lost in extraction; presumably KVLayer<V>)
  KVLayer model_;
  Updater updt_;
};
33 |
// Pushes all-ones arrays sized like AlexNet's layers, pulls the aggregated
// result back, and reports throughput in MB/s.
class Worker : public App {
 public:
  virtual void Run() {
    std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl;

    // alexnet
    // NOTE(review): 3*3*284 looks like a typo for 3*3*384 (AlexNet conv3);
    // harmless here since the sizes only set the benchmark volume
    SArray layer_size =
        {11*11*96, 5*5*256, 3*3*284, 3*3*256, 43264*4096, 4096*4096, 4096*1000};
    int n = layer_size.size();

    std::vector> layers(n);
    for (int i = 0; i < n; ++i) layers[i].resize(layer_size[i]);

    auto tv = tic();
    // issue every push and its dependent pull without blocking ...
    std::vector pull_time(n);
    for (int i = 0; i < n; ++i) {
      auto& val = layers[i];
      val.SetValue(1);
      int ts = model_.Push(
          Parameter::Request(i), val.data(), val.size());
      // the pull waits on the push's timestamp ts
      pull_time[i] = model_.Pull(
          Parameter::Request(i, -1, {ts}), val.data(), val.size());
    }

    // ... then wait for all pulls at the end
    for (int i = 0; i < n; ++i) {
      model_.Wait(pull_time[i]);
      // val is only used by the commented-out sanity checks below
      const auto& val = model_[i];
      // for (int j = 0; j < val.size(); ++j) {
      //   CHECK_EQ(val[j], val[0]);
      //   CHECK_LE(val[j], RankSize());
      // }
    }
    // aggregate throughput in MB/s
    LL << (double)layer_size.Sum() * sizeof(V) / toc(tv) / 1e6;
  }
 private:
  KVLayer model_;
};
71 |
72 | App* App::Create(const std::string& conf) {
73 | if (IsWorker()) return new Worker();
74 | if (IsServer()) return new Server();
75 | return new App();
76 | }
77 |
78 | } // namespace PS
79 |
80 | int main(int argc, char *argv[]) {
81 | return PS::RunSystem(argc, argv);
82 | }
83 |
--------------------------------------------------------------------------------
/src/test/kv_layer_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @brief Simple test of KVLayer
3 | */
4 | #include "ps.h"
5 | #include "parameter/kv_layer.h"
6 | namespace PS {
7 | typedef uint64 K; // key type
8 | typedef int V; // value type
9 |
10 | class Updater {
11 | public:
12 | void Init(int id, size_t size, V* data) {
13 | memset(data, 0, sizeof(V)*size);
14 | }
15 |
16 | void Update(int id, size_t size, const V* recv_data, V* data) {
17 | // sum
18 | for (int i = 0; i < size; ++i) {
19 | data[i] += recv_data[i];
20 | }
21 | }
22 | };
23 |
class Server : public App {
 public:
  // install the sum-aggregating updater into the layer store
  Server() {
    model_.set_updater(&updt_);
  }
 private:
  // (template argument lost in extraction; presumably KVLayer<V>)
  KVLayer model_;
  Updater updt_;
};
33 |
// Pushes all-ones arrays for three tiny "layers", pulls the aggregated
// result back, and prints each layer from the pull-completion callback.
class Worker : public App {
 public:
  virtual void Run() {
    std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl;

    std::vector layer_size = {5, 10, 4};
    int n = layer_size.size();

    std::vector> layers(n);
    for (int i = 0; i < n; ++i) layers[i].resize(layer_size[i]);

    // push each layer, then pull it with a dependency on the push's ts
    std::vector pull_time(n);
    for (int i = 0; i < n; ++i) {
      auto& val = layers[i];
      val.SetValue(1);
      int ts = model_.Push(
          Parameter::Request(i), val.data(), val.size());
      pull_time[i] = model_.Pull(
          Parameter::Request(i, -1, {ts}), val.data(), val.size(),
          [i, this](){
            // runs when the pull for layer i completes
            LL << "layer " << i << " " << model_[i];
          });
    }

    for (int i = 0; i < n; ++i) {
      model_.Wait(pull_time[i]);
    }
    // give the asynchronous pull callbacks time to print
    sleep(1);
  }
 private:
  KVLayer model_;
};
66 |
67 | App* App::Create(const std::string& conf) {
68 | if (IsWorker()) return new Worker();
69 | if (IsServer()) return new Server();
70 | return new App();
71 | }
72 |
73 | } // namespace PS
74 |
75 | int main(int argc, char *argv[]) {
76 | return PS::RunSystem(argc, argv);
77 | }
78 |
--------------------------------------------------------------------------------
/src/test/kv_map_perf_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @brief Performance test of KVMap
3 | */
4 | #include
5 | #include "ps.h"
6 | #include "parameter/kv_map.h"
7 | #include "parameter/kv_vector.h"
8 | #include "util/shared_array_inl.h"
9 | #include "util/resource_usage.h"
10 | namespace PS {
11 | DEFINE_int32(n, 10, "repeat n times");
12 |
13 | typedef uint64 K; // key
14 | typedef float V; // value type
15 |
// Server-side storage for one key. Set() accumulates pushed values
// (value += *data); Get() returns the current sum. "state" is unused here.
struct Entry {
  void Get(V* data, void* state) { *data = value; }
  void Set(const V* data, void* state) { value += *data; }
  V value = 0;
};
21 |
class Server : public App {
 private:
  // key -> Entry store answering worker pushes/pulls
  // (template arguments lost in extraction; presumably KVMap<K, Entry>)
  KVMap vec_;
};
26 |
// Repeatedly pushes Gaussian-random values for a randomly chosen
// pre-generated key file and pulls them back, reporting MB/s throughput.
class Worker : public App {
 public:
  virtual void Run() {
    // choose one of the 22 key files at random each round
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> dis(0, 21);
    size_t bytes = 0;
    auto tv = tic();
    for (int i = 0; i < FLAGS_n; ++i) {

      int k = dis(gen);
      SArray key;
      key.ReadFromFile("../test/keys/key_" + std::to_string(k));
      // values drawn from N(0, 1)
      SArray val(key.size());
      ParamInitConfig cf;
      cf.set_type(ParamInitConfig::GAUSSIAN);
      cf.set_mean(0);
      cf.set_std(1);
      val.SetValue(cf);

      // push, then pull the same keys; the pull (ts+1) depends on the push (ts)
      int ts = vec_.Push(Parameter::Request(i), key, {val});
      vec_.Wait(vec_.Pull(Parameter::Request(i, ts+1, {ts}), key));
      CHECK_EQ(vec_[i].value.size(), key.size());
      vec_.Clear(i);
      // keys + values, counted in both directions
      bytes += key.size() * (sizeof(K) + sizeof(V)) * 2;
    }

    double thr = (double) bytes / toc(tv) / 1000000;
    printf("%s: %d push/pull, throughput %.3lf MB/sec\n",
           MyNodeID().c_str(), FLAGS_n, thr);
  }
 private:
  KVVector vec_;
};
61 |
62 | App* App::Create(const std::string& conf) {
63 | if (IsWorker()) return new Worker();
64 | if (IsServer()) return new Server();
65 | return new App();
66 | }
67 |
68 | } // namespace PS
69 |
70 | int main(int argc, char *argv[]) {
71 | return PS::RunSystem(argc, argv);
72 | }
73 |
--------------------------------------------------------------------------------
/src/test/kv_map_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @brief Simple test of KVMap
3 | */
4 | #include "ps.h"
5 | #include "parameter/kv_map.h"
6 | #include "parameter/kv_vector.h"
7 | namespace PS {
8 | typedef uint64 K; // key
9 | typedef float V; // value type
10 |
// Server-side storage for one key. Set() accumulates pushed values
// (value += *data); Get() returns the current sum. "state" is unused here.
struct Entry {
  void Get(V* data, void* state) { *data = value; }
  void Set(const V* data, void* state) { value += *data; }
  V value = 0;
};
16 |
class Server : public App {
 private:
  // key -> Entry store answering worker pushes/pulls
  // (template arguments lost in extraction; presumably KVMap<K, Entry>)
  KVMap vec_;
};
21 |
// Each worker pushes ones for its own (overlapping) key set; Entry::Set
// accumulates, so keys pushed by both workers pull back the sum.
class Worker : public App {
 public:
  virtual void Run() {
    std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl;

    // key sets overlap at {0, 4}
    SArray key;
    if (MyRank() == 0) {
      key = {0, 2, 4, 5};
    } else {
      key = {0, 1, 3, 4};
    }
    SArray val = {1, 1, 1, 1};

    // blocking push then blocking pull on channel 0
    vec_.Wait(vec_.Push(Parameter::Request(0), key, {val}));
    vec_.Wait(vec_.Pull(Parameter::Request(0), key));

    std::cout << MyNodeID() << ": pulled value in channel 0 " << vec_[0].value
              << std::endl;
  }
 private:
  KVVector vec_;
};
44 |
45 | App* App::Create(const std::string& conf) {
46 | if (IsWorker()) return new Worker();
47 | if (IsServer()) return new Server();
48 | return new App();
49 | }
50 |
51 | } // namespace PS
52 |
53 | int main(int argc, char *argv[]) {
54 | return PS::RunSystem(argc, argv);
55 | }
56 |
--------------------------------------------------------------------------------
/src/test/kv_vector_buffer_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @brief Simple test of buffered KVVector
3 | */
4 | #include "ps.h"
5 | #include "parameter/kv_vector.h"
6 | namespace PS {
7 | typedef uint64 K; // key
8 | typedef float V; // value type
9 |
// Buffered-mode server: incoming pushes are held in a buffer until Run()
// explicitly aggregates them into channel 4.
class Server : public App {
 public:
  // buffered KVVector carrying 2 value arrays per message
  Server() : vec_(true, 2) {
    // channel 4
    vec_[4].key = {0, 1, 3, 4, 5};
  }

  virtual void Run() {
    // aggregate data received from all workers
    WaitWorkersReady();
    int ts = 0;
    vec_.WaitReceivedRequest(ts, kWorkerGroup);
    auto recv = vec_.buffer(ts);
    // sum the two pushed value arrays into channel 4
    vec_[4].value = recv.values[0];
    vec_[4].value.vec() += recv.values[1].vec();
    // releases the workers' pulls, which wait on virtual timestamp ts+1
    vec_.FinishReceivedRequest(ts+1, kWorkerGroup);
  }
 private:
  KVVector vec_;
};
30 |
// Each worker pushes two value arrays for its own key set, then pulls the
// server-side aggregate once the server finishes timestamp ts+1.
class Worker : public App {
 public:
  Worker() : vec_(false, 2) { }

  virtual void Run() {
    std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl;

    // the two workers use overlapping but different key sets
    SArray key;
    if (MyRank() == 0) {
      key = {0, 2, 4, 5};
    } else {
      key = {0, 1, 3, 4};
    }
    // push [1 1 1 1   and  [3 3 3 3   into servers
    //       2 2 2 2]        4 4 4 4]
    SArray val1 = {1, 2, 1, 2, 1, 2, 1, 2};
    SArray val2 = {3, 4, 3, 4, 3, 4, 3, 4};

    int ts = vec_.Push(Parameter::Request(4), key, {val1, val2});
    // this pull request will depends on a virtual timestamp ts+1 which is used for
    // aggregation (the server calls FinishReceivedRequest(ts+1))
    vec_.Wait(vec_.Pull(Parameter::Request(4, ts+2, {ts+1}), key));

    std::cout << MyNodeID() << ": pulled value in channel 4 " << vec_[4].value
              << std::endl;
  }
 private:
  KVVector vec_;
};
60 |
61 | App* App::Create(const std::string& conf) {
62 | if (IsWorker()) return new Worker();
63 | if (IsServer()) return new Server();
64 | return new App();
65 | }
66 |
67 | } // namespace PS
68 |
69 | int main(int argc, char *argv[]) {
70 | return PS::RunSystem(argc, argv);
71 | }
72 |
--------------------------------------------------------------------------------
/src/test/kv_vector_perf_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @file kv_vector_perf_ps.cc
3 | * @author Mu Li
4 | * @date Wed Mar 18 16:34:31 2015
5 | *
6 | * @brief Performance test of KVVector
7 | *
8 | *
9 | */
10 | #include "ps.h"
11 | #include "parameter/kv_vector.h"
12 | namespace PS {
13 | typedef uint64 K; // key
14 | typedef float V; // value type
15 |
// Buffered-mode server: aggregates the workers' pushed arrays into channel 4.
// NOTE(review): this file's body is identical to kv_vector_buffer_ps.cc
// despite the "performance test" header -- the perf workload appears missing.
class Server : public App {
 public:
  Server() : vec_(true, 2) {
    // channel 4
    vec_[4].key = {0, 1, 3, 4, 5};
  }

  virtual void Run() {
    // aggregate data received from all workers
    WaitWorkersReady();
    int ts = 0;
    vec_.WaitReceivedRequest(ts, kWorkerGroup);
    auto recv = vec_.buffer(ts);
    vec_[4].value = recv.values[0];
    vec_[4].value.vec() += recv.values[1].vec();
    // releases the workers' pulls waiting on virtual timestamp ts+1
    vec_.FinishReceivedRequest(ts+1, kWorkerGroup);
  }
 private:
  KVVector vec_;
};
36 |
37 | class Worker : public App {
38 | public:
39 | Worker() : vec_(false, 2) { }
40 |
41 | virtual void Run() {
42 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl;
43 |
44 | SArray key;
45 | if (MyRank() == 0) {
46 | key = {0, 2, 4, 5};
47 | } else {
48 | key = {0, 1, 3, 4};
49 | }
50 | // push [1 1 1 1 and [3 3 3 3 into servers
51 | // 2 2 2 2] 4 4 4 4]
52 | SArray val1 = {1, 2, 1, 2, 1, 2, 1, 2};
53 | SArray val2 = {3, 4, 3, 4, 3, 4, 3, 4};
54 |
55 | int ts = vec_.Push(Parameter::Request(4), key, {val1, val2});
56 | // this pull request will depends on a virtual timestamp ts+1 which is used for
57 | // aggregation
58 | vec_.Wait(vec_.Pull(Parameter::Request(4, ts+2, {ts+1}), key));
59 |
60 | std::cout << MyNodeID() << ": pulled value in channel 4 " << vec_[4].value
61 | << std::endl;
62 | }
63 | private:
64 | KVVector vec_;
65 | };
66 |
67 | App* App::Create(const std::string& conf) {
68 | if (IsWorker()) return new Worker();
69 | if (IsServer()) return new Server();
70 | return new App();
71 | }
72 |
73 | } // namespace PS
74 |
int main(int argc, char *argv[]) {
  // Hand control to the PS runtime, which creates the role-specific App via
  // App::Create() and runs it; its return value is the process exit code.
  return PS::RunSystem(argc, argv);
}
78 |
--------------------------------------------------------------------------------
/src/test/kv_vector_ps.cc:
--------------------------------------------------------------------------------
1 | /**
2 | * @file kv_vector_ps.cc
3 | * @author Mu Li
4 | * @date Wed Mar 18 16:34:09 2015
5 | *
6 | * @brief Simple test of KVVector
7 | *
8 | *
9 | */
10 |
11 | #include "ps.h"
12 | #include "parameter/kv_vector.h"
13 | namespace PS {
14 | typedef uint64 K; // key type
15 | typedef int V; // value type
16 |
17 | class Server : public App {
18 | public:
19 | Server() : vec_() {
20 | // channel 0
21 | vec_[0].key = {0, 1, 3, 4, 5};
22 | vec_[0].value = {1, 2, 3, 4, 5};
23 |
24 | // channel 1
25 | vec_[1].key = {0, 1, 3, 4, 5};
26 | vec_[1].value = {2, 3, 4, 5, 6};
27 | }
28 |
29 | private:
30 | KVVector vec_;
31 | };
32 |
33 | class Worker : public App {
34 | public:
35 | Worker() : vec_() { }
36 |
37 | virtual void Run() {
38 | std::cout << MyNodeID() << ": this is worker " << MyRank() << std::endl;
39 |
40 | SArray key;
41 | if (MyRank() == 0) {
42 | key = {0, 2, 4, 5};
43 | } else {
44 | key = {0, 1, 3, 4};
45 | }
46 |
47 | int ts1 = vec_.Pull(Parameter::Request(0), key);
48 | int ts2 = vec_.Pull(Parameter::Request(1), key);
49 |
50 | vec_.Wait(ts1);
51 | std::cout << MyNodeID() << ": pulled value in channel 0 " << vec_[0].value
52 | << std::endl;
53 |
54 | vec_.Wait(ts2);
55 | std::cout << MyNodeID() << ": pulled value in channel 1 " << vec_[1].value
56 | << std::endl;
57 | }
58 | private:
59 | KVVector vec_;
60 | };
61 |
62 | App* App::Create(const std::string& conf) {
63 | if (IsWorker()) return new Worker();
64 | if (IsServer()) return new Server();
65 | return new App();
66 | }
67 |
68 | } // namespace PS
69 |
int main(int argc, char *argv[]) {
  // Hand control to the PS runtime, which creates the role-specific App via
  // App::Create() and runs it; its return value is the process exit code.
  return PS::RunSystem(argc, argv);
}
73 |
--------------------------------------------------------------------------------
/src/test/localizer_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "util/localizer.h"
3 |
4 | using namespace PS;
5 |
TEST(Localizer, RCV1) {
  // End-to-end check of Localizer on the RCV1 libsvm training file:
  // count unique feature indices, filter low-frequency keys, then remap the
  // slot-1 matrix to the compacted key space and compare checksums.
  // NOTE(review): template arguments appear stripped from this dump (e.g.
  // `SArray key`, `std::static_pointer_cast>`); restore them from the
  // original source before compiling.
  DataConfig cache, dc;
  cache.add_file("/tmp/test/");  // slot-cache directory used by SlotReader

  dc.set_format(DataConfig::TEXT);
  dc.set_text(DataConfig::LIBSVM);
  dc.add_file("../data/rcv1_train.binary");

  SlotReader sr(dc, cache);
  ExampleInfo info;
  sr.read(&info);

  Localizer lc;

  // Unique keys of slot 1 and their occurrence counts.
  SArray key;
  SArray freq;
  lc.countUniqIndex(sr.index(1), &key, &freq);

  // Checksums of the expected key/frequency arrays for this dataset.
  EXPECT_EQ(key.eigenArray().sum(), 1051859373);
  EXPECT_EQ(freq.eigenArray().sum(), 1498952);
  EXPECT_EQ(freq.eigenArray().square().sum(), 1924492682);
  EXPECT_EQ(key.size(), freq.size());
  EXPECT_EQ(key.size(), 44504);

  // Keep only keys appearing more than `filter` times.
  int filter = 2;
  SArray f_key;
  for (int i = 0; i < key.size(); ++i) {
    if (freq[i] > filter) f_key.pushBack(key[i]);
  }
  EXPECT_EQ(f_key.size(), 19959);

  // Remap slot 1 onto the filtered key set.
  auto X = std::static_pointer_cast>(
      lc.remapIndex( 1, f_key, &sr));

  EXPECT_EQ(X->offset().eigenArray().sum(), 14702421805);
  SArray idx;
  for (auto k : X->index()) idx.pushBack(k);
  EXPECT_EQ(idx.size(), 1467683);
  EXPECT_EQ(idx.eigenArray().sum(), 14708054959 - idx.size());
  // Value sum is floating point, so bracket it rather than test equality.
  EXPECT_LT(X->value().eigenArray().sum(), 132224);
  EXPECT_GT(X->value().eigenArray().sum(), 132223);

  // LL << X->debugString();
}
50 |
51 | // TEST(Localizer, ADFEA) {
52 | // DataConfig dc;
53 | // dc.set_format(DataConfig::TEXT);
54 | // dc.set_text(DataConfig::ADFEA);
55 | // dc.add_file("../../data/ctrc/train/part-000[0-1].gz");
56 | // auto data = readMatricesOrDie(searchFiles(dc));
57 |
58 | // for (int i = 1; i < data.size(); ++i) {
59 | // Localizer lc(data[i]);
60 | // SArray key;
61 | // SArray freq;
62 | // lc.countUniqIndex(&key, &freq);
63 |
64 | // int filter = 4;
65 | // SArray f_key;
66 | // for (int i = 0; i < key.size(); ++i) {
67 | // if (freq[i] > filter) f_key.pushBack(key[i]);
68 | // }
69 | // LL << f_key.size();
70 | // auto X = std::static_pointer_cast>(lc.remapIndex(f_key));
71 | // if (X) {
72 | // LL << X->index().eigenArray().maxCoeff();
73 | // }
74 | // // LL << X->debugString();
75 | // }
76 | // }
77 |
--------------------------------------------------------------------------------
/src/test/network_perf_ps.cc:
--------------------------------------------------------------------------------
1 | #include "ps.h"
2 | #include "util/resource_usage.h"
3 |
4 | DEFINE_int32(n, 1000, "repeat n times");
5 | DEFINE_int32(data_size, 1000,
6 | "data in KB sent from a worker to a server");
7 | DEFINE_bool(server_aggregation, false,
8 | "servers will aggregate the data from servs if true");
9 | namespace PS {
10 |
class Server : public App {
 public:
  // Server side of the network throughput benchmark.  Run() is intentionally
  // a no-op: workers push data at the server group and time the round trips.
  // The commented-out sketch below is an unfinished aggregation variant
  // gated on FLAGS_server_aggregation.
  virtual void Run() {
    // if (FLAGS_server_aggregation) {
    //   WaitWorkersReady();
    //   auto data = split(FLAGS_data_size, ',',true);
    //   for (int i = 0; i < data.size() * FLAGS_n; ++i) {
    //   }
    // }
  }
};
22 |
23 | class Worker : public App {
24 | public:
25 |
26 | virtual void Slice(const Message& request, const std::vector>& krs,
27 | std::vector* msgs) {
28 | for (auto m : *msgs) *m = request;
29 | }
30 |
31 | virtual void Run() {
32 | int n = FLAGS_n;
33 | int m = FLAGS_data_size;
34 | auto tv = tic();
35 | for (int j = 0; j < n; ++j) {
36 | SArray val(m*1000/sizeof(int), 1);
37 | Message msg;
38 | msg.add_value(val);
39 | msg.recver = kServerGroup;
40 | int ts = Submit(&msg);
41 | Wait(ts);
42 | }
43 | double thr = (double)m / 1000.0 * n * sys_.manager().num_servers() / toc(tv);
44 | printf("%s: packet size: %d KB, throughput %.3lf MB/sec\n",
45 | MyNodeID().c_str(), m, thr);
46 | }
47 | };
48 |
49 | App* App::Create(const std::string& conf) {
50 | if (IsWorker()) return new Worker();
51 | if (IsServer()) return new Server();
52 | return new App();
53 | }
54 |
55 | } // namespace PS
56 |
int main(int argc, char *argv[]) {
  // Hand control to the PS runtime, which creates the role-specific App via
  // App::Create() and runs it; its return value is the process exit code.
  return PS::RunSystem(argc, argv);
}
60 |
--------------------------------------------------------------------------------
/src/test/parallel_ordered_match_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "util/parallel_ordered_match.h"
3 | #include "util/shared_array_inl.h"
4 |
5 | using namespace PS;
6 | namespace PS {
7 | DEFINE_int32(num_threads, 2, "");
8 | DEFINE_int32(k, 1, "k");
9 | } // namespace PS
10 |
11 |
12 | class PMatchTest : public ::testing::Test {
13 | protected:
14 | virtual void SetUp() {
15 | key1.ReadFromFile("../test/keys/key_1");
16 | key2.ReadFromFile("../test/keys/key_2");
17 | };
18 |
19 | SArray key1, key2;
20 | };
21 |
TEST_F(PMatchTest, simple) {
  // Intentionally empty: only verifies that the fixture's SetUp
  // (key-file loading) runs without failing.
}
25 |
26 | TEST_F(PMatchTest, match) {
27 |
28 | int k = FLAGS_k;
29 | SArray val1(key1.size()*k, 1);
30 | SArray val2;
31 |
32 | size_t n = ParallelOrderedMatch(key1, val1, key2, &val2, k);
33 |
34 | LL << n << " " << val2;
35 | }
36 |
--------------------------------------------------------------------------------
/src/test/reassign_server_key_range_ps.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "system/postoffice.h"
3 | #include "system/postmaster.h"
4 | #include "system/app.h"
5 |
6 | // build: make build/reassign_server_key_range
7 | // test: script/local.sh build/reassign_server_key_range 3 3
8 | namespace PS {
9 | class Root : public App {
10 | public:
11 | Root() : App() { }
12 | virtual ~Root() { }
13 |
14 | void init() {
15 | // repartition the key range
16 | Range range(0, 100);
17 | auto nodes = sys_.yp().nodes();
18 | nodes = Postmaster::partitionServerKeyRange(nodes, range);
19 |
20 | Task task;
21 | task.set_type(Task::MANAGE);
22 | task.mutable_mng_node()->set_cmd(ManageNode::UPDATE);
23 | for (const auto& n : nodes) {
24 | *task.mutable_mng_node()->add_node() = n;
25 | }
26 | port(kLiveGroup)->submitAndWait(task);
27 | }
28 | };
29 |
30 | class Slave : public App {
31 | public:
32 | Slave() : App() { }
33 | virtual ~Slave() { }
34 |
35 | void run() {
36 | LL << exec_.myNode().id() << " key range: "
37 | << exec_.myNode().key().ShortDebugString();
38 | }
39 | };
40 |
41 | App* App::create(const string& name, const string& conf) {
42 | auto my_role = Postoffice::instance().myNode().role();
43 | if (my_role == Node::SCHEDULER) {
44 | return new Root();
45 | } else {
46 | return new Slave();
47 | }
48 | }
49 |
50 | } // namespace PS
51 |
int main(int argc, char *argv[]) {
  // Start the system singleton (connects this node and runs the App created
  // by App::create above), then shut it down.
  // NOTE(review): start() takes &argc/&argv, presumably so it can strip
  // flags it consumes -- confirm against Postoffice's interface.
  auto& sys = PS::Postoffice::instance();
  sys.start(&argc, &argv);

  sys.Stop();
  return 0;
}
59 |
--------------------------------------------------------------------------------
/src/test/slot_reader_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "data/slot_reader.h"
3 |
4 | using namespace PS;
5 |
TEST(SlotReader, read) {
  // Cross-check SlotReader's cached per-slot arrays against the matrices
  // produced by readMatricesOrDie() on the same input files.
  // NOTE(review): template arguments appear stripped from this dump (the
  // static_pointer_cast target type is missing); restore them from the
  // original source before compiling.
  DataConfig cache, dc;
  cache.add_file("/tmp/test/");  // slot-cache directory
  dc.set_format(DataConfig::TEXT);

  // load adfea
  dc.set_text(DataConfig::ADFEA);
  dc.add_file("../../data/ctrc/train/part-000[0-1].gz");

  // load libsvm
  // dc.set_text(DataConfig::LIBSVM);
  // dc.add_file("../data/rcv1/train/part-.*");

  DataConfig dc2 = searchFiles(dc);
  SlotReader gr; gr.init(dc2, cache); gr.read();

  // Reference parse of the same files.
  auto data = readMatricesOrDie(dc2);

  // Slot 0 is compared against data[0]'s values -- presumably the labels.
  auto label = gr.value(0);
  // LL << label;
  // LL << data[0]->value();
  EXPECT_EQ((label.eigenVector() - data[0]->value().eigenVector()).norm(), 0);

  // Remaining slots: index/offset/value arrays must match element-wise
  // (zero norm of the difference).
  for (int i = 1; i < data.size(); ++i) {
    auto X = std::static_pointer_cast>(data[i]);
    int id = X->info().id();
    auto index = gr.index(id);
    auto offset = gr.offset(id);
    // LL << index;
    // LL << offset;
    EXPECT_EQ((index.eigenVector() - X->index().eigenVector()).norm(), 0);
    EXPECT_EQ((offset.eigenVector() - X->offset().eigenVector()).norm(), 0);

    // Some slots are binary (no stored values); only compare when present.
    if (!X->value().empty()) {
      auto value = gr.value(id);
      EXPECT_EQ((value.eigenVector() - X->value().eigenVector()).norm(), 0);
    }
  }
}
45 |
--------------------------------------------------------------------------------
/src/test/stream_reader_test.cc:
--------------------------------------------------------------------------------
1 | #include "gtest/gtest.h"
2 | #include "data/stream_reader.h"
3 |
4 | using namespace PS;
5 |
6 | // TEST(StreamReader, read_proto) {
7 | // DataConfig dc;
8 | // // load adfea
9 | // dc.set_format(DataConfig::PROTO);
10 | // dc.add_file("../output/parsa_.*");
11 |
12 | // DataConfig dc2 = searchFiles(dc);
13 | // StreamReader reader; reader.init(dc2);
14 |
15 | // MatrixPtrList X;
16 | // while (reader.readMatrices(10000, &X)) {
17 | // CHECK_EQ(X.size(), 2);
18 | // }
19 | // }
20 |
21 | // TEST(StreamReader, read) {
22 | // DataConfig dc;
23 | // // load adfea
24 | // dc.set_format(DataConfig::TEXT);
25 | // dc.set_text(DataConfig::ADFEA);
26 | // dc.add_file("../../data/ctrc/train/part-000[0-1].gz");
27 | // dc.set_ignore_feature_group(true);
28 |
29 | // // load libsvm
30 | // // dc.set_text(DataConfig::LIBSVM);
31 | // // dc.add_file("../data/rcv1/train/part-.*");
32 |
33 | // DataConfig dc2 = searchFiles(dc);
34 | // StreamReader reader; reader.init(dc2);
35 |
36 | // MatrixPtrList X;
37 | // reader.readMatrices(100, &X);
38 | // }
39 |
40 |
41 | // TEST(StreamReader, convert) {
42 | // DataConfig dc;
43 | // dc.set_format(DataConfig::TEXT);
44 | // dc.set_text(DataConfig::TERAFEA);
45 | // dc.add_file("../data/toutiao/data.txt");
46 | // dc.set_ignore_feature_group(true);
47 |
48 | // DataConfig dc2 = searchFiles(dc);
49 | // StreamReader reader; reader.init(dc2);
50 |
51 | // MatrixPtrList X;
52 | // reader.readMatrices(1000000, &X);
53 | // // CHECK_EQ(X.size(), 2);
54 |
55 | // X[0]->writeToBinFile("toutiao.Y");
56 |
57 | // for (int i = 1; i < X.size(); ++i) {
58 | // Localizer localizer;
59 | // SArray