├── .gitignore
├── CHANGES.txt
├── DistML.iml
├── LICENSE
├── README.md
├── data
│   ├── MNISTReader.java
│   └── mnist_prepare.sh
├── doc
│   ├── architect.png
│   ├── build.md
│   ├── lr-implementation.md
│   ├── runtime.png
│   ├── sample_lda.md
│   ├── sample_lr.md
│   └── src_tree.md
├── pom.xml
└── src
    └── main
        ├── java
        │   └── com
        │       └── intel
        │           └── distml
        │               ├── api
        │               │   ├── DMatrix.java
        │               │   ├── DataBus.java
        │               │   ├── Model.java
        │               │   └── Session.java
        │               ├── platform
        │               │   ├── DataBusProtocol.java
        │               │   ├── MonitorActor.java
        │               │   ├── PSActor.java
        │               │   ├── PSAgent.java
        │               │   ├── PSManager.java
        │               │   ├── PSSync.java
        │               │   ├── SSP.java
        │               │   ├── WorkerActor.java
        │               │   └── WorkerAgent.java
        │               └── util
        │                   ├── AbstractDataReader.java
        │                   ├── AbstractDataWriter.java
        │                   ├── ByteBufferDataReader.java
        │                   ├── ByteBufferDataWriter.java
        │                   ├── Constants.java
        │                   ├── DataDesc.java
        │                   ├── DataStore.java
        │                   ├── DefaultDataReader.java
        │                   ├── DefaultDataWriter.java
        │                   ├── DoubleArray.java
        │                   ├── DoubleMatrix.java
        │                   ├── IntArray.java
        │                   ├── IntArrayWithIntKey.java
        │                   ├── IntMatrix.java
        │                   ├── IntMatrixWithIntKey.java
        │                   ├── KeyCollection.java
        │                   ├── KeyHash.java
        │                   ├── KeyList.java
        │                   ├── KeyRange.java
        │                   ├── Logger.java
        │                   ├── SparseArray.java
        │                   ├── SparseMatrix.java
        │                   ├── Utils.java
        │                   └── store
        │                       ├── DoubleArrayStore.java
        │                       ├── DoubleMatrixStore.java
        │                       ├── FloatArrayStore.java
        │                       ├── FloatMatrixStore.java
        │                       ├── FloatMatrixStoreAdaGrad.java
        │                       ├── IntArrayStore.java
        │                       └── IntMatrixStore.java
        ├── main.iml
        └── scala
            └── com
                └── intel
                    └── distml
                        ├── Dict.scala
                        ├── clustering
                        │   ├── AliasTable.scala
                        │   ├── LDAModel.scala
                        │   ├── LDAParams.scala
                        │   └── LightLDA.scala
                        ├── example
                        │   ├── SimpleALS.scala
                        │   ├── clustering
                        │   │   └── LDAExample.scala
                        │   ├── feature
                        │   │   ├── MllibWord2Vec.scala
                        │   │   └── Word2VecExample.scala
                        │   └── regression
                        │       ├── LargeLRTest.scala
                        │       ├── MelBlanc.scala
                        │       └── Mnist.scala
                        ├── feature
                        │   └── Word2Vec.scala
                        ├── platform
                        │   ├── Clock.scala
                        │   ├── DistML.scala
                        │   ├── ParamServerDriver.scala
                        │   └── PipeLine.scala
                        ├── regression
                        │   ├── LogisticRegression.scala
                        │   └── MLR.scala
                        └── util
                            └── scala
                                ├── DoubleArray.scala
                                ├── DoubleArrayWithIntKey.scala
                                ├── DoubleMatrixWithIntKey.scala
                                ├── FloatArray.scala
                                ├── FloatMatrix.scala
                                ├── FloatMatrixAdapGradWithIntKey.scala
                                ├── FloatMatrixWithIntKey.scala
                                ├── IntArray.scala
                                ├── IntMatrix.scala
                                ├── SparseArray.scala
                                ├── SparseMatrix.scala
                                └── SparseMatrixAdapGrad.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | .idea
3 | bin/
4 | target
5 |
--------------------------------------------------------------------------------
/CHANGES.txt:
--------------------------------------------------------------------------------
1 | [2016-01-18]
2 | 1. Reimplemented data transfer with pure sockets
3 |
4 | [2016-01-08]
5 | 1. Implemented LightLDA and MLR; LR/LightLDA/MLR now all work fine
6 | 2. Added API support for Double Array/Matrix and Int Array/Matrix
7 | 3. Added Scala API for DoubleArray
8 |
9 | [2015-12-14]
10 | 1. Simplified APIs, allowing users to use native data types for customized algorithms
11 | 2. With the API change, only LR and LDA remain, to prove the API works well
12 | 3. Run examples on YARN, with support for spark-submit
13 |
14 | [2015-10-15]
15 | 1. Added environment variable "PS_NETWORK_PREFIX", used by parameter servers
16 | 2. Fixed an issue in handling messages after an iteration is done
17 |
18 | [2015-10-12]
19 | 1. Added serialization support for CNN classes
20 |
21 | [2015-08-21]
22 | 1. Added draft LDA support
23 | 2. Support iterative UDFs on the server side
24 |
25 | [2015-07-01]
26 | 1. Removed the worker-group feature to reduce code complexity; a new branch "workergrou" was created
27 | 2. Use akka-io for data transport, which allows transferring big data blocks
28 | 3. Start parameter servers in a separate thread, allowing iterative training without restarting parameter servers
29 |
--------------------------------------------------------------------------------
/DistML.iml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/DistML.iml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DistML (Distributed Machine Learning platform)
2 |
3 | DistML is a machine learning tool that allows training very large models on Spark. It is fully compatible with Spark (tested on 1.2 and above).
4 |
5 | ![Architecture](doc/architect.png)
6 |
7 | Reference paper: [Large Scale Distributed Deep Networks](http://research.google.com/archive/large_deep_networks_nips2012.html)
8 |
9 |
10 | Runtime view:
11 |
12 | ![Runtime view](doc/runtime.png)
13 |
14 | DistML provides several algorithms (LR, LDA, Word2Vec, ALS) to demonstrate its scalability. You may also write your own algorithms based on the DistML APIs (Model, Session, Matrix, DataStore, ...); it is generally simple to port an existing algorithm to DistML. Here we take LR as an example: [How to implement logistic regression on DistML](https://github.com/intel-machine-learning/DistML/tree/master/doc/lr-implementation.md).
15 |
16 | ### User Guide
17 | 1. [Download and build DistML](https://github.com/intel-machine-learning/DistML/tree/master/doc/build.md).
18 | 2. [Typical options](https://github.com/intel-machine-learning/DistML/tree/master/doc/options.md).
19 | 3. [Run Sample - LR](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_lr.md).
20 | 4. [Run Sample - MLR](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_mlr.md).
21 | 5. [Run Sample - LDA](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_lda.md).
22 | 6. [Run Sample - Word2Vec](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_word2vec.md).
23 | 7. [Run Sample - ALS](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_als.md).
24 | 8. [Benchmarks](https://github.com/intel-machine-learning/DistML/tree/master/doc/benchmarks.md).
25 | 9. [FAQ](https://github.com/intel-machine-learning/DistML/tree/master/doc/faq.md).
26 |
27 | ### API Document
28 | 1. [Source Tree](https://github.com/intel-machine-learning/DistML/tree/master/doc/src_tree.md).
29 | 2. [DistML API](https://github.com/intel-machine-learning/DistML/tree/master/doc/api.md).
30 |
31 |
32 | ## Contributors
33 | He Yunlong (Intel)
34 | Sun Yongjie (Intel)
35 | Liu Lantao (Intern, Graduated)
36 | Hao Ruixiang (Intern, Graduated)
37 |
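38 | ### Quick sketch
39 | A minimal sketch of the typical workflow, condensed from [doc/lr-implementation.md](https://github.com/intel-machine-learning/DistML/tree/master/doc/lr-implementation.md); not a complete program, and `sc`, `samples`, `dim` and `psCount` are assumed to be defined:
40 |
41 | ```scala
42 | // Define the model: one distributed double vector named "weights".
43 | val m = new Model() {
44 |   registerMatrix("weights", new DoubleArrayWithIntKey(dim + 1))
45 | }
46 |
47 | // Distribute the parameters across psCount parameter server nodes.
48 | val dm = DistML.distribute(sc, m, psCount, DistML.defaultF)
49 | val monitorPath = dm.monitorPath
50 |
51 | samples.mapPartitionsWithIndex((index, it) => {
52 |   val session = new Session(m, monitorPath, index) // connect this worker to the PS
53 |   val wd = m.getMatrix("weights").asInstanceOf[DoubleArrayWithIntKey]
54 |   // ... wd.fetch(keys, session), compute updates, wd.push(updates, session) ...
55 |   session.disconnect()
56 |   Iterator(0)
57 | }).count()
58 | ```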
--------------------------------------------------------------------------------
/data/MNISTReader.java:
--------------------------------------------------------------------------------
1 |
2 | import java.io.*;
3 |
4 | /**
5 | * This class implements a reader for the MNIST dataset of handwritten digits. The dataset is found
6 | * at http://yann.lecun.com/exdb/mnist/.
7 | *
8 | * @author Gabe Johnson
9 | */
10 | public class MNISTReader {
11 |
12 | /**
13 | * @param args
14 | * args[0]: label file; args[1]: image file; args[2]: output file.
15 | * @throws IOException
16 | */
17 | public static void main(String[] args) throws IOException {
18 |
19 | DataOutputStream dos = new DataOutputStream(new FileOutputStream(args[2]));
20 |
21 | DataInputStream labels = new DataInputStream(new FileInputStream(args[0]));
22 | DataInputStream images = new DataInputStream(new FileInputStream(args[1]));
23 |
24 | int magicNumber = labels.readInt();
25 | if (magicNumber != 2049) {
26 | System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)");
27 | System.exit(0);
28 | }
29 | magicNumber = images.readInt();
30 | if (magicNumber != 2051) {
31 | System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)");
32 | System.exit(0);
33 | }
34 | int numLabels = labels.readInt();
35 | int numImages = images.readInt();
36 | int numRows = images.readInt();
37 | int numCols = images.readInt();
38 | if (numLabels != numImages) {
39 | System.err.println("Image file and label file do not contain the same number of entries.");
40 | System.err.println(" Label file contains: " + numLabels);
41 | System.err.println(" Image file contains: " + numImages);
42 | System.exit(0);
43 | }
44 |
45 | long start = System.currentTimeMillis();
46 | int numLabelsRead = 0;
47 | int numImagesRead = 0;
48 | while (labels.available() > 0 && numLabelsRead < numLabels) {
49 | StringBuilder buf = new StringBuilder();
50 | byte label = labels.readByte();
51 | numLabelsRead++;
52 | int[][] image = new int[numCols][numRows];
53 | for (int colIdx = 0; colIdx < numCols; colIdx++) {
54 | for (int rowIdx = 0; rowIdx < numRows; rowIdx++) {
55 | int point = images.readUnsignedByte();
56 | image[colIdx][rowIdx] = point;
57 | buf.append("" + point + " ");
58 | }
59 | }
60 | numImagesRead++;
61 |
62 | buf.append("" + label + "\n");
63 | dos.write(buf.toString().getBytes());
64 |
65 |
66 | // At this point, 'label' and 'image' agree and you can do whatever you like with them.
67 |
68 | if (numLabelsRead % 10 == 0) {
69 | System.out.print(".");
70 | }
71 | if ((numLabelsRead % 800) == 0) {
72 | System.out.print(" " + numLabelsRead + " / " + numLabels);
73 | long end = System.currentTimeMillis();
74 | long elapsed = end - start;
75 | long minutes = elapsed / (1000 * 60);
76 | long seconds = (elapsed / 1000) - (minutes * 60);
77 | System.out.println(" " + minutes + " m " + seconds + " s ");
78 | }
79 | }
80 | System.out.println();
81 | long end = System.currentTimeMillis();
82 | long elapsed = end - start;
83 | long minutes = elapsed / (1000 * 60);
84 | long seconds = (elapsed / 1000) - (minutes * 60);
85 | System.out
86 | .println("Read " + numLabelsRead + " samples in " + minutes + " m " + seconds + " s ");
87 |
88 | dos.close();
89 | }
90 |
91 | }
92 |
93 |
--------------------------------------------------------------------------------
/data/mnist_prepare.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
3 | wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
4 |
5 | gunzip train-images-idx3-ubyte.gz
6 | gunzip train-labels-idx1-ubyte.gz
7 |
8 | javac MNISTReader.java
9 | java MNISTReader train-labels-idx1-ubyte train-images-idx3-ubyte mnist_train.txt
10 |
--------------------------------------------------------------------------------
/doc/architect.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/doc/architect.png
--------------------------------------------------------------------------------
/doc/build.md:
--------------------------------------------------------------------------------
1 |
2 | ## Guide to download and build DistML
3 |
4 |
5 | ### Download DistML
6 | ```sh
7 | $ git clone https://github.com/intel-machine-learning/DistML.git
8 | ```
9 |
10 | ### Configure and build
11 | You may run DistML with a different Spark version; DistML is compatible with recent Spark releases, but we encourage you to build DistML against the same Spark version you are running.
12 | ```sh
13 | $ cd DistML
14 | $ vi pom.xml
15 | $ mvn package
16 | ```
17 |
18 | After a successful build, you will find target/distml-<version>.jar (e.g. distml-0.2.jar); refer to the sample guides to run the examples.
19 |
--------------------------------------------------------------------------------
/doc/lr-implementation.md:
--------------------------------------------------------------------------------
1 |
2 | ## A sample showing how to write algorithms based on DistML APIs
3 |
4 | ```scala
5 | val m = new Model() {
6 | registerMatrix("weights", new DoubleArrayWithIntKey(dim + 1))
7 | }
8 |
9 | val dm = DistML.distribute(sc, m, psCount, DistML.defaultF)
10 | val monitorPath = dm.monitorPath
11 |
12 | dm.setTrainSetSize(samples.count())
13 |
14 | for (iter <- 0 to maxIterations - 1) {
15 | println("============ Iteration: " + iter + " ==============")
16 |
17 | val t = samples.mapPartitionsWithIndex((index, it) => {
18 | println("--- connecting to PS ---")
19 | val session = new Session(m, monitorPath, index)
20 | val wd = m.getMatrix("weights").asInstanceOf[DoubleArrayWithIntKey]
21 | var cost = 0.0
22 | val batch = new util.LinkedList[(mutable.HashMap[Int, Double], Int)]
23 | while (it.hasNext) {
24 | batch.clear()
25 | var count = 0
26 | while ((count < batchSize) && it.hasNext) {
27 | batch.add(it.next())
28 | count = count + 1
29 | }
30 |
31 | val keys = new KeyList()
32 | for ((x, label) <- batch) {
33 | for (key <- x.keySet) {
34 | keys.addKey(key)
35 | }
36 | }
37 |
38 | val w = wd.fetch(keys, session)
39 | val w_old = new util.HashMap[Long, Double]
40 | for ((key, value) <- w) {
41 | w_old.put(key, value)
42 | }
43 |
44 | for ((x, label) <- batch) {
45 | var sum = 0.0
46 | for ((k, v) <- x) {
47 | sum += w(k) * v
48 | }
49 | val h = 1.0 / (1.0 + Math.exp(-sum))
50 |
51 | val err = eta * (h - label)
52 | for ((k, v) <- x) {
53 | w.put(k, w(k) - err * v)
54 | }
55 |
56 | cost = cost + label * Math.log(h) + (1 - label) * Math.log(1 - h)
57 | }
58 |
59 | cost /= batch.size()
60 | for (key <- w.keySet) {
61 | val grad: Double = w(key) - w_old(key)
62 | w.put(key, grad)
63 | }
64 | wd.push(w, session)
65 | }
66 |
67 | session.disconnect()
68 |
69 | val r = new Array[Double](1)
70 | r(0) = -cost
71 | r.iterator
72 | })
73 |
74 | val totalCost = t.reduce(_+_)
75 | println("============ Iteration done, Total Cost: " + totalCost + " ============")
76 | }
77 | ```
78 |
79 | ## Instructions
80 |
81 | First, define your model with its parameter type and dimension. For logistic regression we need a double vector; DistML provides Array/Matrix types for int/long/float/double.
82 | ```scala
83 | val m = new Model() {
84 | registerMatrix("weights", new DoubleArrayWithIntKey(dim + 1))
85 | }
86 | ```
87 |
88 | Before training the model, we need to distribute the parameters across several parameter server nodes; the number of parameter servers is specified by psCount.
89 | ```scala
90 | val dm = DistML.distribute(sc, m, psCount, DistML.defaultF)
91 | val monitorPath = dm.monitorPath
92 |
93 | dm.setTrainSetSize(samples.count())
94 | ```
95 |
96 | In each worker that performs training, we need to set up a session, which establishes the data buses between the worker and the parameter servers.
97 | ```scala
98 | val session = new Session(m, monitorPath, index)
99 | val wd = m.getMatrix("weights").asInstanceOf[DoubleArrayWithIntKey]
100 | ```
101 |
102 | Once connected to the parameter servers, we can fetch the parameters. Note that w_old is used to calculate the updates after each iteration.
103 | ```scala
104 | val w = wd.fetch(keys, session)
105 | val w_old = new util.HashMap[Long, Double]
106 | for ((key, value) <- w) {
107 | w_old.put(key, value)
108 | }
109 | ```
110 |
111 | As training proceeds, the parameters are updated; we compute the updates here and then push them to the parameter servers.
112 | ```scala
113 | for (key <- w.keySet) {
114 | val grad: Double = w(key) - w_old(key)
115 | w.put(key, grad)
116 | }
117 | wd.push(w, session)
118 | ```
119 | When a worker finishes training for an iteration, it disconnects from the parameter servers.
120 | ```scala
121 | session.disconnect()
122 | ```
123 |
124 |
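125 | ## Reporting progress (optional)
126 |
127 | Besides fetch and push, the Session API (see src/main/java/com/intel/distml/api/Session.java) lets a worker report its progress to the monitor. A minimal sketch, assuming SSP-style training where the monitor coordinates iterations:
128 | ```scala
129 | session.progress(batch.size())     // report how many samples this worker has processed
130 | session.iterationDone(iter, -cost) // blocks until the monitor permits the next iteration
131 | ```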
--------------------------------------------------------------------------------
/doc/runtime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/doc/runtime.png
--------------------------------------------------------------------------------
/doc/sample_lda.md:
--------------------------------------------------------------------------------
1 |
2 | == Sample command to run LightLDA
3 |
4 | ```sh
5 | #./spark-submit --class com.intel.distml.example.LightLDA --master yarn://dl-s1:8088 /home/spark/yunlong/distml-0.2.jar --psCount 1 --maxIterations 100 --k 2 --alpha 0.01 --beta 0.01 --batchSize 100 /data/lda/short.txt
6 | ```
7 |
8 | == About LDA parameters
9 | ```
10 | k: topic number, default value 20
11 | alpha: hyperparameter, default value 0.01
12 | beta: hyperparameter, default value 0.01
13 | showPlexity: whether to show perplexity after training, default value true
14 | ```
15 |
--------------------------------------------------------------------------------
/doc/sample_lr.md:
--------------------------------------------------------------------------------
1 |
2 | == Run sparse logistic regression
3 |
4 | ```sh
5 | #./spark-submit --num-executors 8 --executor-cores 3 --executor-memory 10g --class com.intel.distml.example.regression.MelBlanc --master yarn://dl-s1:8088 /home/spark/yunlong/distml-0.2.jar --psCount 1 --trainType ssp --maxIterations 150 --maxLag 2 --dim 1000000 --partitions 8 --eta 0.00001 /data/lr/Blanc-Mel.txt /models/lr/Blanc
6 | ```
7 |
8 | === Options
9 | runType: train or test, default value train
10 | psCount: number of parameter servers, default value 1
11 | psBackup: whether to enable parameter server fault tolerance, default value false
12 | trainType: how to train the LR model, "ssp" or "asgd", default value "ssp"
13 | maxIterations: how many iterations to train the model, only applicable to runType=train, default value 100
14 | batchSize: batch size when using "asgd" as the train type, default value 100
15 | maxLag: max iteration difference between the fastest and slowest workers, only applicable to trainType=ssp, default value 2
16 | dim: dimension of the LR weights, default value 10000000
17 | eta: learning rate, default value 0.0001
18 | partitions: how many partitions for the training data, default value 1
19 | input: dataset for training, mandatory
20 | modelPath: where to save the model, mandatory
21 |
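22 | === Note on maxLag
23 | The maxLag option maps to the SSP helper class (com.intel.distml.platform.SSP). A minimal sketch of its semantics, for illustration only (workers do not call this class directly):
24 | ```scala
25 | import com.intel.distml.platform.SSP
26 |
27 | val ssp = new SSP(150, 2)  // maxIterations = 150, maxLag = 2
28 | val r = ssp.progress(0, 3) // worker 0 reports finishing iteration 3
29 | if (r.waiting) {
30 |   // worker 0 is at least maxLag iterations ahead of the slowest worker and must wait
31 | }
32 | ```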
--------------------------------------------------------------------------------
/doc/src_tree.md:
--------------------------------------------------------------------------------
1 |
2 | ## DistML source tree reading guide
3 | Though it only supports Spark for now, DistML is designed to run on Spark or directly on YARN. The main engine is therefore implemented in Java, while the Spark-related support and the sample algorithms are written in Scala.
4 |
5 | --src
6 |     main
7 |         java
8 |             com
9 |                 intel
10 |                     distml
11 |                         api
12 |                         platform
13 |                         util
14 |         scala
15 |             com
16 |                 intel
17 |                     distml
18 |                         clustering
19 |                         example
20 |                         feature
21 |                         platform
22 |                         regression
23 |                         util
24 |
25 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |     <modelVersion>4.0.0</modelVersion>
6 |
7 |     <groupId>intel</groupId>
8 |     <artifactId>distml</artifactId>
9 |     <version>0.2</version>
10 |
11 |     <properties>
12 |         <scala.version>2.10.4</scala.version>
13 |         <scala211.version>2.11.1</scala211.version>
14 |     </properties>
15 |
16 |     <dependencies>
17 |         <dependency>
18 |             <groupId>com.typesafe.akka</groupId>
19 |             <artifactId>akka-remote_2.10</artifactId>
20 |             <version>2.3.4</version>
21 |         </dependency>
22 |         <dependency>
23 |             <groupId>com.typesafe.akka</groupId>
24 |             <artifactId>akka-actor_2.10</artifactId>
25 |             <version>2.3.4</version>
26 |         </dependency>
27 |         <dependency>
28 |             <groupId>org.apache.hadoop</groupId>
29 |             <artifactId>hadoop-hdfs</artifactId>
30 |             <version>2.2.0</version>
31 |             <scope>compile</scope>
32 |         </dependency>
33 |         <dependency>
34 |             <groupId>org.apache.spark</groupId>
35 |             <artifactId>spark-core_2.10</artifactId>
36 |             <version>1.4.0</version>
37 |             <scope>compile</scope>
38 |         </dependency>
39 |         <dependency>
40 |             <groupId>org.apache.spark</groupId>
41 |             <artifactId>spark-mllib_2.10</artifactId>
42 |             <version>1.4.0</version>
43 |             <scope>compile</scope>
44 |         </dependency>
45 |         <dependency>
46 |             <groupId>com.google.guava</groupId>
47 |             <artifactId>guava</artifactId>
48 |             <version>18.0</version>
49 |             <scope>compile</scope>
50 |         </dependency>
51 |         <dependency>
52 |             <groupId>com.github.fommil.netlib</groupId>
53 |             <artifactId>core</artifactId>
54 |             <version>1.1.2</version>
55 |         </dependency>
56 |         <dependency>
57 |             <groupId>org.jblas</groupId>
58 |             <artifactId>jblas</artifactId>
59 |             <version>1.2.3</version>
60 |         </dependency>
61 |         <dependency>
62 |             <groupId>com.github.scopt</groupId>
63 |             <artifactId>scopt_2.10</artifactId>
64 |             <version>3.2.0</version>
65 |         </dependency>
66 |         <dependency>
67 |             <groupId>org.scala-lang</groupId>
68 |             <artifactId>scala-library</artifactId>
69 |             <version>2.10.4</version>
70 |             <scope>compile</scope>
71 |         </dependency>
72 |         <dependency>
73 |             <groupId>org.scala-lang</groupId>
74 |             <artifactId>scala-compiler</artifactId>
75 |             <version>2.10.4</version>
76 |             <scope>compile</scope>
77 |         </dependency>
78 |         <dependency>
79 |             <groupId>org.apache.commons</groupId>
80 |             <artifactId>commons-math3</artifactId>
81 |             <version>3.0</version>
82 |         </dependency>
83 |     </dependencies>
84 |
85 |     <build>
86 |         <pluginManagement>
87 |             <plugins>
88 |                 <plugin>
89 |                     <groupId>net.alchim31.maven</groupId>
90 |                     <artifactId>scala-maven-plugin</artifactId>
91 |                     <version>3.2.0</version>
92 |                 </plugin>
93 |                 <plugin>
94 |                     <groupId>org.apache.maven.plugins</groupId>
95 |                     <artifactId>maven-compiler-plugin</artifactId>
96 |                 </plugin>
97 |             </plugins>
98 |         </pluginManagement>
99 |         <sourceDirectory>src</sourceDirectory>
100 |         <plugins>
101 |             <plugin>
102 |                 <groupId>net.alchim31.maven</groupId>
103 |                 <artifactId>scala-maven-plugin</artifactId>
104 |                 <executions>
105 |                     <execution>
106 |                         <id>scala-compile-first</id>
107 |                         <phase>process-resources</phase>
108 |                         <goals>
109 |                             <goal>add-source</goal>
110 |                             <goal>compile</goal>
111 |                         </goals>
112 |                     </execution>
113 |                 </executions>
114 |             </plugin>
115 |             <plugin>
116 |                 <groupId>org.apache.maven.plugins</groupId>
117 |                 <artifactId>maven-compiler-plugin</artifactId>
118 |                 <executions>
119 |                     <execution>
120 |                         <id>compile</id>
121 |                         <goals>
122 |                             <goal>compile</goal>
123 |                         </goals>
124 |                     </execution>
125 |                 </executions>
126 |                 <configuration>
127 |                     <source>1.7</source>
128 |                     <target>1.7</target>
129 |                     <compilerVersion>1.7</compilerVersion>
130 |                     <encoding>UTF-8</encoding>
131 |                 </configuration>
132 |             </plugin>
133 |             <plugin>
134 |                 <groupId>org.apache.maven.plugins</groupId>
135 |                 <artifactId>maven-shade-plugin</artifactId>
136 |                 <configuration>
137 |                     <createDependencyReducedPom>false</createDependencyReducedPom>
138 |                     <artifactSet>
139 |                         <includes>
140 |                             <include>org.jblas:jblas</include>
141 |                             <include>com.github.scopt:scopt_2.10</include>
142 |                             <include>org.apache.commons:commons-math3</include>
143 |                         </includes>
144 |                     </artifactSet>
145 |                 </configuration>
146 |                 <executions>
147 |                     <execution>
148 |                         <phase>package</phase>
149 |                         <goals>
150 |                             <goal>shade</goal>
151 |                         </goals>
152 |                     </execution>
153 |                 </executions>
154 |             </plugin>
155 |         </plugins>
156 |     </build>
157 | </project>
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/api/DMatrix.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.api;
2 |
3 | import com.intel.distml.util.DataDesc;
4 | import com.intel.distml.util.DataStore;
5 | import com.intel.distml.util.KeyCollection;
6 | import com.intel.distml.util.KeyRange;
7 |
8 | import java.io.Serializable;
9 |
10 | public class DMatrix implements Serializable {
11 |
12 | public static final int PARTITION_STRATEGY_LINEAR = 0;
13 | public static final int PARTITION_STRATEGY_HASH = 1;
14 |
15 | public KeyCollection[] partitions;
16 |
17 | protected KeyRange rowKeys;
18 |
19 | protected int partitionStrategy;
20 |
21 | protected String name;
22 |
23 | protected DataDesc format;
24 |
25 | protected DataStore store;
26 |
27 | public DMatrix(long rows) {
28 | this.partitionStrategy = PARTITION_STRATEGY_LINEAR;
29 |
30 | rowKeys = new KeyRange(0, rows-1);
31 | }
32 |
33 | public KeyCollection getRowKeys() {
34 | return rowKeys;
35 | }
36 |
37 | public KeyCollection getColKeys() {
38 | return KeyRange.Single;
39 | }
40 |
41 | public DataDesc getFormat() {
42 | return format;
43 | }
44 |
45 | public void setPartitionStrategy(int strategy) {
46 | if ((strategy != PARTITION_STRATEGY_LINEAR) && (strategy != PARTITION_STRATEGY_HASH)) {
47 | throw new IllegalArgumentException("partition strategy must be LINEAR or HASH.");
48 | }
49 |
50 | partitionStrategy = strategy;
51 | }
52 |
53 | void partition(int serverNum) {
54 |
55 | KeyCollection[] keySets;
56 | if (partitionStrategy == PARTITION_STRATEGY_LINEAR) {
57 | keySets = rowKeys.linearSplit(serverNum);
58 | }
59 | else {
60 | keySets = rowKeys.hashSplit(serverNum);
61 | }
62 |
63 | partitions = keySets;
64 | }
65 |
66 | public DataStore getStore() {
67 | return store;
68 | }
69 |
70 | public void setStore(DataStore store) {
71 | this.store = store;
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/api/DataBus.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.api;
2 |
3 | import com.intel.distml.util.DataDesc;
4 | import com.intel.distml.util.KeyCollection;
5 |
6 | import java.util.HashMap;
7 |
8 | /**
9 | * Created by yunlong on 6/3/15.
10 | */
11 | public interface DataBus {
12 |
13 | public byte[][] fetch(String matrixName, KeyCollection rowKeys, DataDesc format);
14 |
15 | public void push(String matrixName, DataDesc format, byte[][] data);
16 |
17 | public void disconnect();
18 |
19 | public void psAvailable(int index, String addr);
20 | }
21 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/api/Model.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.api;
2 |
3 | import java.io.Serializable;
4 | import java.util.HashMap;
5 |
6 | /**
7 | * Created by yunlong on 12/13/14.
8 | */
9 | public class Model implements Serializable {
10 |
11 | public HashMap<String, DMatrix> dataMap;
12 |
13 | public String monitorPath;
14 |
15 | public int psCount;
16 | public boolean psReady;
17 |
18 | public Model() {
19 | dataMap = new HashMap<String, DMatrix>();
20 | psReady = false;
21 | }
22 |
23 | public void registerMatrix(String name, DMatrix matrix) {
24 | if (dataMap.containsKey(name)) {
25 | throw new IllegalArgumentException("Matrix already exist: " + name);
26 | }
27 | matrix.name = name;
28 | dataMap.put(name, matrix);
29 | }
30 |
31 | public DMatrix getMatrix(String matrixName) {
32 | return dataMap.get(matrixName);
33 | }
34 |
35 | public void autoPartition(int serverNum) {
36 | this.psCount = serverNum;
37 |
38 | for (String matrixName: dataMap.keySet()) {
39 | DMatrix m = dataMap.get(matrixName);
40 | m.partition(serverNum);
41 | }
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/api/Session.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.api;
2 |
3 | import akka.actor.ActorRef;
4 | import akka.actor.ActorSystem;
5 | import com.intel.distml.platform.MonitorActor;
6 | import com.intel.distml.platform.WorkerActor;
7 | import com.typesafe.config.Config;
8 | import com.typesafe.config.ConfigFactory;
9 |
10 | /**
11 | * Created by yunlong on 12/9/15.
12 | */
13 | public class Session {
14 |
15 | static final String ACTOR_SYSTEM_CONFIG =
16 | "akka.actor.provider=\"akka.remote.RemoteActorRefProvider\"\n" +
17 | "akka.remote.netty.tcp.port=0\n" +
18 | "akka.remote.log-remote-lifecycle-events=off\n" +
19 | "akka.log-dead-letters=off\n" +
20 | "akka.io.tcp.direct-buffer-size = 2 MB\n" +
21 | "akka.io.tcp.trace-logging=off\n" +
22 | "akka.remote.netty.tcp.maximum-frame-size=4126935";
23 |
24 |
25 | public ActorSystem workerActorSystem;
26 | public DataBus dataBus;
27 | public ActorRef monitor;
28 | public ActorRef worker;
29 | public Model model;
30 |
31 | public Session(Model model, String monitorPath, int workerIndex) {
32 | this.model = model;
33 | connect(monitorPath, workerIndex);
34 | }
35 |
36 | public void connect(String monitorPath, int workerIndex) {
37 | dataBus = null;
38 |
39 | String WORKER_ACTOR_SYSTEM_NAME = "worker";
40 | String WORKER_ACTOR_NAME = "worker";
41 |
42 | Config cfg = ConfigFactory.parseString(ACTOR_SYSTEM_CONFIG);
43 | workerActorSystem = ActorSystem.create(WORKER_ACTOR_SYSTEM_NAME, ConfigFactory.load(cfg));
44 | worker = workerActorSystem.actorOf(WorkerActor.props(this, model, monitorPath, workerIndex), WORKER_ACTOR_NAME);
45 |
46 | while(dataBus == null) {
47 | try {Thread.sleep(10); } catch (InterruptedException e) {}
48 | }
49 | }
50 |
51 | public void progress(int sampleCount) {
52 | worker.tell(new WorkerActor.Progress(sampleCount), null);
53 | }
54 |
55 | public void iterationDone(int iteration) {
56 | iterationDone(iteration, 0.0);
57 | }
58 |
59 | public void iterationDone(int iteration, double cost) {
60 | WorkerActor.IterationDone req = new WorkerActor.IterationDone(iteration, cost);
61 | worker.tell(req, null);
62 | while(!req.done) {
63 | try { Thread.sleep(100); } catch (Exception e) {}
64 | }
65 | }
66 |
67 | public void disconnect() {
68 | dataBus.disconnect();
69 | workerActorSystem.stop(worker);
70 | workerActorSystem.shutdown();
71 | dataBus = null;
72 | }
73 |
74 | public void discard() {
75 | worker.tell(new WorkerActor.Command(WorkerActor.CMD_DISCONNECT), null);
76 | }
77 |
78 | @Override
79 | public void finalize() {
80 | if (dataBus != null)
81 | disconnect();
82 | }
83 |
84 | }
85 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/platform/PSActor.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.platform;
2 |
3 | import akka.actor.ActorSelection;
4 | import akka.actor.Props;
5 | import akka.japi.Creator;
6 | import com.intel.distml.api.Model;
7 |
8 | import akka.actor.UntypedActor;
9 | import com.intel.distml.util.*;
10 | import com.intel.distml.util.store.FloatMatrixStoreAdaGrad;
11 | import org.apache.hadoop.conf.Configuration;
12 | import org.apache.hadoop.fs.FileSystem;
13 | import org.apache.hadoop.fs.Path;
14 |
15 | import java.io.DataInputStream;
16 | import java.io.DataOutputStream;
17 | import java.io.IOException;
18 | import java.io.Serializable;
19 | import java.net.Socket;
20 | import java.net.SocketAddress;
21 | import java.net.URI;
22 | import java.util.HashMap;
23 |
24 | public class PSActor extends UntypedActor {
25 |
26 | public static int MIN_REPORT_INTERVAL = 1000;
27 |
28 | public static final int OP_LOAD = 0;
29 | public static final int OP_SAVE = 1;
30 | public static final int OP_ZERO = 2;
31 | public static final int OP_RAND = 3;
32 | public static final int OP_SET = 4;
33 | public static final int OP_SET_ALPHA = 5;
34 |
35 | public static class RegisterRequest implements Serializable {
36 | private static final long serialVersionUID = 1L;
37 |
38 | final public int index;
39 | final public String executorId;
40 | final public String addr;
41 | final public String hostName;
42 | final public long freeMemory;
43 | final long totalMemory;
44 |
45 | public RegisterRequest(int index, String executorId, String hostName, String addr) {
46 | this.index = index;
47 | this.executorId = executorId;
48 | this.addr = addr;
49 | this.hostName = hostName;
50 | totalMemory = Runtime.getRuntime().totalMemory();
51 | freeMemory = Runtime.getRuntime().freeMemory();
52 | }
53 | }
54 |
55 | public static class SyncServerInfo implements Serializable {
56 | private static final long serialVersionUID = 1L;
57 |
58 | final public String addr;
59 | public SyncServerInfo(String addr) {
60 | this.addr = addr;
61 | }
62 |
63 | @Override
64 | public String toString() {
65 | return "SyncServerInfo[" + addr + "]";
66 | }
67 | }
68 |
69 | public static class ModelSetup implements Serializable {
70 | private static final long serialVersionUID = 1L;
71 |
72 | int op;
73 | String path;
74 | String value;
75 | public ModelSetup(int op, String path) {
76 | this.op = op;
77 | this.path = path;
78 | this.value = null;
79 | }
80 | public ModelSetup(int op, String path, String value) {
81 | this.op = op;
82 | this.path = path;
83 | this.value = value;
84 | }
85 | }
86 |
87 | public static class ModelSetupDone implements Serializable {
88 | private static final long serialVersionUID = 1L;
89 |
90 | public ModelSetupDone() {
91 | }
92 | }
93 |
94 | public static class Stop implements Serializable {
95 | private static final long serialVersionUID = 1L;
96 |
97 | public Stop() {
98 | }
99 | }
100 |
101 | public static class AgentMessage implements Serializable {
102 | private static final long serialVersionUID = 1L;
103 |
104 | final public long freeMemory;
105 | final public long totalMemory;
106 |
107 | public AgentMessage(long freeMemory, long totalMemory) {
108 | this.freeMemory = freeMemory;
109 | this.totalMemory = totalMemory;
110 | }
111 | }
112 |
113 | public static class Report implements Serializable {
114 | private static final long serialVersionUID = 1L;
115 |
116 | final public long freeMemory;
117 | final public long totalMemory;
118 |
119 | public Report(long freeMemory, long totalMemory) {
120 | this.freeMemory = freeMemory;
121 | this.totalMemory = totalMemory;
122 | }
123 | }
124 |
125 | private Model model;
126 | private HashMap<String, DataStore> stores;
127 |
128 | private ActorSelection monitor;
129 | private int serverIndex;
130 | private String executorId;
131 |
132 | private PSAgent agent;
133 | private PSSync syncThread;
134 |
135 | private long lastReportTime;
136 |
137 | public static Props props(final Model model, final HashMap<String, DataStore> stores, final String monitorPath,
138 | final int parameterServerIndex, final String executorId, final String psNetwordPrefix) {
139 | return Props.create(new Creator<PSActor>() {
140 | private static final long serialVersionUID = 1L;
141 | public PSActor create() throws Exception {
142 | return new PSActor(model, stores, monitorPath, parameterServerIndex, executorId, psNetwordPrefix);
143 | }
144 | });
145 | }
146 |
147 | PSActor(Model model, HashMap<String, DataStore> stores, String monitorPath, int serverIndex, String executorId, String psNetwordPrefix) {
148 | this.monitor = getContext().actorSelection(monitorPath);
149 | this.serverIndex = serverIndex;
150 | this.executorId = executorId;
151 | this.model = model;
152 | this.stores = stores;
153 | this.lastReportTime = 0;
154 |
155 | agent = new PSAgent(getSelf(), model, stores, psNetwordPrefix);
156 | agent.start();
157 | this.monitor.tell(new RegisterRequest(serverIndex, executorId, agent.hostName(), agent.addr()), getSelf());
158 | }
159 |
160 |
161 | @Override
162 | public void onReceive(Object msg) throws Exception {
163 | log("onReceive: " + msg);
164 | if (msg instanceof SyncServerInfo) {
165 | SyncServerInfo info = (SyncServerInfo) msg;
166 | String[] s = info.addr.split(":");
167 | Socket sck = new Socket(s[0], Integer.parseInt(s[1]));
168 | syncThread = new PSSync(getSelf(), model, stores);
169 | syncThread.asStandBy(sck);
170 | }
171 | else if (msg instanceof ModelSetup) {
172 | ModelSetup req = (ModelSetup) msg;
173 | String path = req.path;
174 | switch (req.op) {
175 | case OP_LOAD:
176 | load(path);
177 | break;
178 | case OP_SAVE:
179 | save(path);
180 | break;
181 | case OP_RAND:
182 | DataStore store = stores.get(path);
183 | store.rand();
184 | break;
185 | case OP_ZERO:
186 | store = stores.get(path);
187 | store.zero();
188 | break;
189 | case OP_SET:
190 | store = stores.get(path);
191 | store.set(req.value);
192 | break;
193 | }
194 | monitor.tell(new ModelSetupDone(), getSelf());
195 | }
196 | else if (msg instanceof MonitorActor.SetAlpha) {
197 | MonitorActor.SetAlpha req = (MonitorActor.SetAlpha) msg;
198 | FloatMatrixStoreAdaGrad store = (FloatMatrixStoreAdaGrad) stores.get(((MonitorActor.SetAlpha) msg).matrixName);
199 | store.setAlpha(req.initialAlpha, req.minAlpha, req.factor);
200 | monitor.tell(new ModelSetupDone(), getSelf());
201 | }
202 | else if (msg instanceof AgentMessage) {
203 | AgentMessage m = (AgentMessage) msg;
204 | long now = System.currentTimeMillis();
205 | if ((now - lastReportTime) > MIN_REPORT_INTERVAL) {
206 | monitor.tell(new Report(m.freeMemory, m.totalMemory), getSelf());
207 | lastReportTime = now;
208 | }
209 | }
210 | else if (msg instanceof MonitorActor.IterationDone) {
211 | agent.closeClients();
212 | monitor.tell(new ModelSetupDone(), getSelf());
213 | }
214 | else if (msg instanceof Stop) {
215 | agent.disconnect();
216 | getContext().stop(self());
217 | }
218 | else unhandled(msg);
219 | }
220 |
221 | private void load(String path) throws IOException {
222 | Configuration conf = new Configuration();
223 | FileSystem fs = FileSystem.get(URI.create(path), conf);
224 |
225 | for (String name : stores.keySet()) {
226 | DataStore store = stores.get(name);
227 | Path dst = new Path(path + "/" + name + "." + serverIndex);
228 | DataInputStream in = fs.open(dst);
229 |
230 | store.readAll(in);
231 | in.close();
232 | }
233 | }
234 |
235 | private void save(String path) throws IOException {
236 | log("saving model: " + path);
237 |
238 | Configuration conf = new Configuration();
239 | FileSystem fs = FileSystem.get(URI.create(path), conf);
240 |
241 | for (String name : stores.keySet()) {
242 | DataStore store = stores.get(name);
243 | Path dst = new Path(path + "/" + name + "." + serverIndex);
244 | DataOutputStream out = fs.create(dst);
245 |
246 | log("saving to: " + dst.getName());
247 | store.writeAll(out);
248 | out.flush();
249 | out.close();
250 | }
251 | }
252 |
253 | @Override
254 | public void postStop() {
255 | if (syncThread != null) {
256 | syncThread.disconnect();
257 | try {
258 | syncThread.join();
259 | }
260 | catch (Exception e) {}
261 | }
262 | getContext().system().shutdown();
263 | log("Parameter server stopped");
264 | }
265 |
266 | private void log(String msg) {
267 | Logger.info(msg, "PS-" + serverIndex);
268 | }
269 | }
270 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/platform/PSManager.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.platform;
2 |
3 | import akka.actor.ActorRef;
4 | import com.intel.distml.util.Logger;
5 |
6 | import java.util.LinkedList;
7 |
8 | /**
9 | * Created by yunlong on 4/28/16.
10 | */
11 | public class PSManager {
12 |
13 | static class PSNode {
14 | String addr;
15 | String executorId;
16 | ActorRef actor;
17 |
18 | PSNode(String addr, String executorId, ActorRef actor) {
19 | this.addr = addr;
20 | this.executorId = executorId;
21 | this.actor = actor;
22 | }
23 |
24 | public String toString() {
25 | return "[ps " + addr + "]";
26 | }
27 | }
28 |
29 | static interface PSMonitor {
30 | void switchServer(int index, String addr, ActorRef actor);
31 | void psFail();
32 | }
33 |
34 | LinkedList<PSNode>[] servers;
35 | PSMonitor monitor;
36 |
37 | public PSManager(PSMonitor monitor, int psCount) {
38 | this.monitor = monitor;
39 | this.servers = new LinkedList[psCount];
40 | for (int i = 0; i < servers.length; i++) {
41 | servers[i] = new LinkedList<PSNode>();
42 | }
43 | }
44 |
45 | public LinkedList<ActorRef> getAllActors() {
46 | LinkedList<ActorRef> actors = new LinkedList<ActorRef>();
47 | for (int i = 0; i < servers.length; i++) {
48 | if (servers[i] != null) {
49 | for (PSNode node : servers[i]) {
50 | if (node != null) {
51 | actors.add(node.actor);
52 | }
53 | }
54 | }
55 | }
56 | return actors;
57 | }
58 |
59 | public ActorRef[] getActors() {
60 | ActorRef[] actors = new ActorRef[servers.length];
61 | for (int i = 0; i < servers.length; i++) {
62 | if (servers[i] != null) {
63 | PSNode node = servers[i].getFirst();
64 | if (node != null) {
65 | actors[i] = node.actor;
66 | }
67 | }
68 | }
69 | return actors;
70 | }
71 |
72 | public String[] getAddrs() {
73 | String[] addrs = new String[servers.length];
74 | for (int i = 0; i < servers.length; i++) {
75 | if (servers[i] != null) {
76 | PSNode node = servers[i].getFirst();
77 | if (node != null) {
78 | addrs[i] = node.addr;
79 | }
80 | }
81 | }
82 | return addrs;
83 | }
84 |
85 | public ActorRef getActor(int index) {
86 | if (servers[index] != null) {
87 | PSNode node = servers[index].getFirst();
88 | if (node != null) {
89 | return node.actor;
90 | }
91 | }
92 | return null;
93 | }
94 |
95 | public String getAddr(int index) {
96 | if (servers[index] != null) {
97 | PSNode node = servers[index].getFirst();
98 | if (node != null) {
99 | return node.addr;
100 | }
101 | }
102 | return null;
103 | }
104 |
105 | public void serverTerminated(ActorRef actor) {
106 | PSNode node = null;
107 | int index = -1;
108 | for (int i = 0; i < servers.length; i++) {
109 | for (PSNode n : servers[i]) {
110 | if (n.actor.equals(actor)) {
111 | node = n;
112 | index = i;
113 | break;
114 | }
115 | }
116 | if (node != null) break;
117 | }
118 |
119 | serverTerminated(index, node);
120 | }
121 |
122 | public void serverTerminated(String executorId) {
123 | PSNode node = null;
124 | int index = -1;
125 | for (int i = 0; i < servers.length; i++) {
126 | for (PSNode n : servers[i]) {
127 | if (n.executorId.equals(executorId)) {
128 | node = n;
129 | index = i;
130 | break;
131 | }
132 | }
133 | if (node != null) break;
134 | }
135 |
136 | serverTerminated(index, node);
137 | }
138 |
139 | public void serverTerminated(int index, PSNode node) {
140 | log("parameter server terminated: " + index + ", " + node);
141 | boolean isPrimary = (servers[index].getFirst() == node);
142 |
143 | servers[index].remove(node);
144 | if (servers[index].size() == 0) {
145 | monitor.psFail();
146 | }
147 | else if (isPrimary) {
148 | node = servers[index].getFirst();
149 | monitor.switchServer(index, node.addr, node.actor);
150 | }
151 | }
152 |
153 |
154 | /**
155 | * add new parameter server to the list.
156 | *
157 | *
158 | * @param index
159 | * @param addr
160 | * @param ref
161 | * @return whether the server becomes the primary
162 | */
163 | public boolean serverAvailable(int index, String executorId, String addr, ActorRef ref) {
164 | servers[index].add(new PSNode(addr, executorId, ref));
165 | return (servers[index].size() == 1);
166 | }
167 | /*
168 | private PSNode getNode(int executorId) {
169 | for (int i = 0; i < servers.length; i++) {
170 | for (PSNode node : servers[i]) {
171 | if (node.executorId == executorId) {
172 | return node;
173 | }
174 | }
175 | }
176 |
177 | return null;
178 | }
179 |
180 | private PSNode getNode(ActorRef a) {
181 | for (int i = 0; i < servers.length; i++) {
182 | for (PSNode node : servers[i]) {
183 | if (node.actor.equals(a)) {
184 | return node;
185 | }
186 | }
187 | }
188 |
189 | return null;
190 | }
191 |
192 | private int getIndex(ActorRef a) {
193 | for (int i = 0; i < servers.length; i++) {
194 | for (PSNode node : servers[i]) {
195 | if (node.actor.equals(a)) {
196 | return i;
197 | }
198 | }
199 | }
200 |
201 | return -1;
202 | }
203 |
204 | private int getIndex(int executorId) {
205 | for (int i = 0; i < servers.length; i++) {
206 | for (PSNode node : servers[i]) {
207 | if (node.executorId == executorId) {
208 | return i;
209 | }
210 | }
211 | }
212 |
213 | return -1;
214 | }
215 |
216 | private int getIndexAsPrimary(ActorRef a) {
217 | for (int i = 0; i < servers.length; i++) {
218 | if (servers[i].size() == 0) continue;
219 |
220 | PSNode node = servers[i].getFirst();
221 | if (node.actor.equals(a)) {
222 | return i;
223 | }
224 | }
225 |
226 | return -1;
227 | }
228 |
229 | private int getIndexAsPrimary(int executorId) {
230 | for (int i = 0; i < servers.length; i++) {
231 | if (servers[i].size() == 0) continue;
232 |
233 | PSNode node = servers[i].getFirst();
234 | if (node.executorId == executorId) {
235 | return i;
236 | }
237 | }
238 |
239 | return -1;
240 | }
241 | */
242 |
243 | private void log(String msg) {
244 | Logger.info(msg, "PSManager");
245 | }
246 | private void warn(String msg) {
247 | Logger.warn(msg, "PSManager");
248 | }
249 | }
250 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/platform/PSSync.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.platform;
2 |
3 | import akka.actor.ActorRef;
4 | import com.intel.distml.api.DMatrix;
5 | import com.intel.distml.api.Model;
6 | import com.intel.distml.util.*;
7 |
8 | import java.io.*;
9 | import java.net.InetSocketAddress;
10 | import java.net.Socket;
11 | import java.nio.ByteBuffer;
12 | import java.nio.channels.SelectionKey;
13 | import java.nio.channels.Selector;
14 | import java.nio.channels.ServerSocketChannel;
15 | import java.nio.channels.SocketChannel;
16 | import java.util.HashMap;
17 | import java.util.Iterator;
18 |
19 | /**
20 | * Created by yunlong on 1/18/16.
21 | */
22 | public class PSSync extends Thread {
23 |
24 | static final int ROLE_PRIMARY = 0;
25 | static final int ROLE_STANDBY = 1;
26 |
27 | static final int MAX_SYNC_BLOCK_SIZE = 1000000;
28 | static final int SYNC_INTERVAL = 10000;
29 |
30 | ActorRef owner;
31 | int role;
32 |
33 | boolean running = false;
34 | Socket socket;
35 | DataOutputStream dos;
36 | DataInputStream dis;
37 |
38 | Model model;
39 | HashMap<String, DataStore> stores;
40 |
41 | public PSSync(ActorRef owner, Model model, HashMap<String, DataStore> stores) {
42 | this.owner = owner;
43 | this.model = model;
44 | this.stores = stores;
45 | }
46 |
47 | public void asPrimary(Socket socket) {
48 | log("start PS sync as primary");
49 | this.socket = socket;
50 | role = ROLE_PRIMARY;
51 | try {
52 | dos = new DataOutputStream(socket.getOutputStream());
53 | dis = new DataInputStream(socket.getInputStream());
54 | }
55 | catch (Exception e) {
56 | e.printStackTrace();
57 | }
58 | start();
59 | }
60 |
61 | public void asStandBy(Socket socket) {
62 | log("start PS sync as standby");
63 | this.socket = socket;
64 | role = ROLE_STANDBY;
65 |
66 | log("connecting to primvary server");
67 | DataBusProtocol.SyncRequest req = new DataBusProtocol.SyncRequest();
68 | try {
69 | dos = new DataOutputStream(socket.getOutputStream());
70 | dis = new DataInputStream(socket.getInputStream());
71 | DefaultDataWriter out = new DefaultDataWriter(dos);
72 | out.writeInt(req.sizeAsBytes(model));
73 | req.write(out, model);
74 | dos.flush();
75 | log("request to sync");
76 | }
77 | catch (Exception e) {
78 | e.printStackTrace();
79 | }
80 |
81 | start();
82 | }
83 |
84 | public void disconnect() {
85 | running = false;
86 | try {
87 | socket.close();
88 | }
89 | catch (Exception e) {
90 | }
91 | }
92 |
93 | @Override
94 | public void run() {
95 | running = true;
96 |
97 | try {
98 | if (role == ROLE_PRIMARY) {
99 | syncService();
100 | } else {
101 | syncClient();
102 | }
103 | } catch (Exception e) {
104 | if (running)
105 | e.printStackTrace();
106 | }
107 | }
108 |
109 | private void syncService() throws Exception {
110 | while(running) {
111 | log("new round of sync");
112 | for (String name : model.dataMap.keySet()) {
113 | byte[] nameBytes = name.getBytes();
114 |
115 | DataStore store = stores.get(name);
116 | int rows = (int) store.rows().size();
117 | int maxRowsToSync = MAX_SYNC_BLOCK_SIZE / store.rowSize();
118 |
119 | int startRow = 0;
120 | int leftRows = rows;
121 | while (leftRows > 0) {
122 | int rowsToSync = (leftRows > maxRowsToSync) ? maxRowsToSync : leftRows;
123 | int endRow = rowsToSync + startRow - 1;
124 |
125 | log("sync: " + name + ", " + startRow + ", " + endRow);
126 | dos.writeInt(nameBytes.length);
127 | dos.write(nameBytes);
128 | dos.writeInt(startRow);
129 | dos.writeInt(endRow);
130 |
131 | store.syncTo(dos, startRow, endRow);
132 | startRow += rowsToSync;
133 | leftRows -= rowsToSync;
134 |
135 | try {
136 | Thread.sleep(SYNC_INTERVAL);
137 | } catch (InterruptedException e) {
138 | }
139 | }
140 | }
141 | }
142 | }
143 |
144 | private void syncClient() throws Exception {
145 | while (running) {
146 | log("handling sync ...");
147 | Utils.waitUntil(dis, 4);
148 | int nameLen = dis.readInt();
149 | log("name len: " + nameLen);
150 | Utils.waitUntil(dis, nameLen);
151 | byte[] nameBytes = new byte[nameLen];
152 | dis.read(nameBytes);
153 | String name = new String(nameBytes);
154 | log("name: " + name);
155 |
156 | DataStore store = stores.get(name);
157 | int startRow = dis.readInt();
158 | int endRow = dis.readInt();
159 | log("rows: " + startRow + ", " + endRow);
160 | store.syncFrom(dis, startRow, endRow);
161 | }
162 | }
163 |
164 | private void log(String msg) {
165 | Logger.info(msg, "PSSync");
166 | }
167 | private void warn(String msg) {
168 | Logger.warn(msg, "PSSync");
169 | }
170 | }
171 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/platform/SSP.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.platform;
2 |
3 | import java.util.*;
4 |
5 | /**
6 | * Created by yunlong on 16-4-7.
7 | */
8 | public class SSP {
9 |
10 | class Worker {
11 | int index;
12 | int progress;
13 | boolean waiting;
14 |
15 | public Worker(int index, int progress) {
16 | this.index = index;
17 | this.progress = progress;
18 | waiting = false;
19 | }
20 | }
21 |
22 | class SortHelper implements Comparator<Worker> {
23 |
24 | public int compare(Worker o1, Worker o2) {
25 | return o2.progress - o1.progress;
26 | }
27 | }
28 |
29 |
30 | public class CheckResult {
31 | public int index;
32 | public int progress;
33 | public boolean waiting;
34 | public LinkedList<Integer> workersToNotify;
35 |
36 | public CheckResult(int index, int progress) {
37 | this.index = index;
38 | this.progress = progress;
39 | waiting = false;
40 |
41 | }
42 |
43 | }
44 |
45 | LinkedList<Worker> workers;
46 | int maxIterations;
47 | int maxLag;
48 | SortHelper sorter = new SortHelper();
49 |
50 | public SSP(int maxIterations, int maxLag) {
51 | workers = new LinkedList<Worker>();
52 | this.maxIterations = maxIterations;
53 | this.maxLag = maxLag;
54 | }
55 |
56 | /**
57 | * set progress and check whether to wait
58 | *
59 | *
60 | * @param worker
61 | * @param iter
62 | * @return
63 | */
64 | public CheckResult progress(int worker, int iter) {
65 |
66 | CheckResult result = new CheckResult(worker, iter);
67 | Worker w = null;
68 |
69 | boolean found = false;
70 | for (int i = 0; i < workers.size(); i++) {
71 | w = workers.get(i);
72 | if (w.index == worker) {
73 | assert (w.progress == iter - 1);
74 | w.progress++;
75 | found = true;
76 | break;
77 | }
78 | }
79 | if (!found)
80 | workers.add(new Worker(worker, iter));
81 |
82 | Collections.sort(workers, sorter);
83 |
84 | Worker last = workers.getLast();
85 | if (worker != last.index) {
86 | if (iter - last.progress >= maxLag) {
87 | result.waiting = true;
88 | w.waiting = true;
89 | //System.out.println("worker " + worker + " is waiting for " + last.index);
90 | }
91 | }
92 | result.workersToNotify = workersToNotify(w);
93 |
94 | //show();
95 |
96 | return result;
97 | }
98 |
99 | public LinkedList<Integer> workersToNotify(Worker p) {
100 | Iterator<Worker> it = workers.iterator();
101 | Worker last = workers.getLast();
102 |
103 | LinkedList<Integer> workersToNotify = new LinkedList<Integer>();
104 | while(it.hasNext()) {
105 | Worker w = it.next();
106 | if (w == p) break;
107 |
108 | //System.out.println("check notify: " + w.progress + ", " + last.progress);
109 | if (w.progress - last.progress < maxLag) {
110 | if (w.waiting) {
111 | workersToNotify.add(w.index);
112 | w.waiting = false;
113 | }
114 | }
115 | }
116 |
117 | return workersToNotify;
118 | }
119 |
120 | private void show() {
121 | for (int i = 0;i < workers.size(); i++) {
122 | Worker w = workers.get(i);
123 | System.out.print("w[" + i + "]=" + w.progress + "\t");
124 | }
125 | System.out.println();
126 | }
127 |
128 | }
129 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/platform/WorkerActor.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.platform;
2 |
3 | import akka.actor.*;
4 | import akka.japi.Creator;
5 | import com.intel.distml.api.Session;
6 | import com.intel.distml.api.Model;
7 | import com.intel.distml.util.Logger;
8 |
9 | import java.io.Serializable;
10 |
11 | /**
12 | * Created by yunlong on 12/13/14.
13 | */
14 | public class WorkerActor extends UntypedActor {
15 |
16 | public static final int CMD_DISCONNECT = 1;
17 | public static final int CMD_STOP = 2;
18 | public static final int CMD_PS_TERMINATED = 3;
19 | public static final int CMD_PS_AVAILABLE = 4;
20 |
21 | public static class Progress implements Serializable {
22 | private static final long serialVersionUID = 1L;
23 |
24 | final public int sampleCount;
25 | public Progress(int sampleCount) {
26 | this.sampleCount = sampleCount;
27 | }
28 | }
29 |
30 | public static class Command implements Serializable {
31 | private static final long serialVersionUID = 1L;
32 |
33 | final public int cmd;
34 | public Command(int cmd) {
35 | this.cmd = cmd;
36 | }
37 | }
38 |
39 | public static class PsTerminated extends Command {
40 | private static final long serialVersionUID = 1L;
41 |
42 | final public int index;
43 | public PsTerminated(int index) {
44 | super(CMD_PS_TERMINATED);
45 | this.index = index;
46 | }
47 | }
48 |
49 | public static class PsAvailable extends Command {
50 | private static final long serialVersionUID = 1L;
51 |
52 | final public int index;
53 | final public String addr;
54 | public PsAvailable(int index, String addr) {
55 | super(CMD_PS_AVAILABLE);
56 | this.index = index;
57 | this.addr = addr;
58 | }
59 | }
60 |
61 | public static class RegisterRequest implements Serializable {
62 | private static final long serialVersionUID = 1L;
63 |
64 | final public int workerIndex;
65 | public RegisterRequest(int workerIndex) {
66 | this.workerIndex = workerIndex;
67 | }
68 | }
69 |
70 | public static class AppRequest implements Serializable {
71 | private static final long serialVersionUID = 1L;
72 |
73 | public boolean done;
74 | public AppRequest() {
75 | done = false;
76 | }
77 |
78 | public String toString() {
79 | return "DriverRequest";
80 | }
81 | }
82 |
83 | public static class IterationDone extends AppRequest {
84 | private static final long serialVersionUID = 1L;
85 |
86 | int iteration;
87 | double cost;
88 | public IterationDone(int iteration, double cost) {
89 | this.iteration = iteration;
90 | this.cost = cost;
91 | }
92 |
93 | public String toString() {
94 | return "IterationDone";
95 | }
96 | }
97 |
98 | private ActorSelection monitor;
99 | private Model model;
100 | private int psCount;
101 | private Session de;
102 |
103 | private String[] psAddrs;
104 | private int workerIndex;
105 |
106 | private AppRequest pendingRequest;
107 |
108 | public WorkerActor(final Session de, Model model, String monitorPath, int workerIndex) {
109 | this.monitor = getContext().actorSelection(monitorPath);
110 | this.workerIndex = workerIndex;
111 | this.model = model;
112 | this.de = de;
113 |
114 | this.monitor.tell(new RegisterRequest(this.workerIndex), getSelf());
115 | log("Worker " + workerIndex + " register to monitor: " + monitorPath);
116 | }
117 |
118 | public static Props props(final Session de,final Model model, final String monitorPath, final int index) {
119 | return Props.create(new Creator<WorkerActor>() {
120 | private static final long serialVersionUID = 1L;
121 | public WorkerActor create() throws Exception {
122 | return new WorkerActor(de, model, monitorPath, index);
123 | }
124 | });
125 | }
126 |
127 | @Override
128 | public void onReceive(Object msg) {
129 | log("onReceive: " + msg);
130 |
131 | if (msg instanceof MonitorActor.RegisterResponse) {
132 | MonitorActor.RegisterResponse res = (MonitorActor.RegisterResponse) msg;
133 | this.psAddrs = res.psAddrs;
134 | psCount = psAddrs.length;
135 |
136 | de.dataBus = new WorkerAgent(model, psAddrs);
137 | }
138 | else if (msg instanceof Progress) {
139 | this.monitor.tell(msg, getSelf());
140 | }
141 | else if (msg instanceof IterationDone) {
142 | IterationDone req = (IterationDone) msg;
143 | pendingRequest = (AppRequest)msg;
144 | monitor.tell(new MonitorActor.SSP_IterationDone(workerIndex, req.iteration, req.cost), getSelf());
145 | }
146 | else if (msg instanceof MonitorActor.SSP_IterationNext) {
147 | assert(pendingRequest != null);
148 | pendingRequest.done = true;
149 | }
150 | else if (msg instanceof PsAvailable) {
151 | PsAvailable ps = (PsAvailable) msg;
152 | de.dataBus.psAvailable(ps.index, ps.addr);
153 | }
154 | else unhandled(msg);
155 | }
156 |
157 | private void log(String msg) {
158 | Logger.info(msg, "Worker-" + workerIndex);
159 | }
160 | }
161 |
162 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/AbstractDataReader.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 |
5 | /**
6 | * Created by yunlong on 4/28/16.
7 | */
8 | public interface AbstractDataReader {
9 |
10 | int readInt() throws Exception;
11 |
12 | long readLong() throws Exception;
13 |
14 | float readFloat() throws Exception;
15 |
16 | double readDouble() throws Exception;
17 |
18 | void waitUntil(int size) throws Exception;
19 |
20 | int readBytes(byte[] bytes) throws Exception;
21 |
22 | void readFully(byte[] bytes) throws Exception;
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/AbstractDataWriter.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | /**
4 | * Created by yunlong on 4/28/16.
5 | */
6 | public interface AbstractDataWriter {
7 |
8 | void writeInt(int value) throws Exception;
9 |
10 | void writeLong(long value) throws Exception;
11 |
12 | void writeFloat(float value) throws Exception;
13 |
14 | void writeDouble(double value) throws Exception;
15 |
16 | void writeBytes(byte[] bytes) throws Exception;
17 |
18 | void flush() throws Exception;
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/ByteBufferDataReader.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.nio.ByteBuffer;
4 |
5 | /**
6 | * Created by yunlong on 4/28/16.
7 | */
8 | public class ByteBufferDataReader implements AbstractDataReader {
9 |
10 | ByteBuffer buf;
11 |
12 | public ByteBufferDataReader(ByteBuffer buf) {
13 | this.buf = buf;
14 | }
15 |
16 | public int readInt() {
17 | return buf.getInt();
18 | }
19 |
20 | public long readLong() {
21 | return buf.getLong();
22 | }
23 |
24 | public float readFloat() {
25 | return buf.getFloat();
26 | }
27 |
28 | public double readDouble() {
29 | return buf.getDouble();
30 | }
31 |
32 | public int readBytes(byte[] bytes) throws Exception {
33 | buf.get(bytes);
34 |
35 | return bytes.length;
36 | }
37 |
38 | public void readFully(byte[] bytes) throws Exception {
39 | buf.get(bytes);
40 | }
41 |
42 |
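// No-op: a ByteBuffer is fully resident in memory, so the data is already available.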
43 | public void waitUntil(int size) throws Exception {
44 |
45 | }
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/ByteBufferDataWriter.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.nio.ByteBuffer;
4 |
5 | /**
6 | * Created by yunlong on 4/28/16.
7 | */
8 | public class ByteBufferDataWriter implements AbstractDataWriter {
9 |
10 |
11 | ByteBuffer buf;
12 |
13 | public ByteBufferDataWriter(ByteBuffer buf) {
14 | this.buf = buf;
15 | }
16 |
17 | public void writeInt(int value) throws Exception {
18 | buf.putInt(value);
19 | }
20 |
21 | public void writeLong(long value) throws Exception {
22 | buf.putLong(value);
23 | }
24 |
25 | public void writeFloat(float value) throws Exception {
26 | buf.putFloat(value);
27 |
28 | }
29 |
30 | public void writeDouble(double value) throws Exception {
31 | buf.putDouble(value);
32 | }
33 |
34 | public void writeBytes(byte[] bytes) throws Exception {
35 | buf.put(bytes);
36 | }
37 |
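// No-op: writes above go directly into the backing ByteBuffer.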
38 | public void flush() throws Exception {
39 |
40 | }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/Constants.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import akka.util.Timeout;
4 | import scala.concurrent.duration.FiniteDuration;
5 |
6 | import java.util.concurrent.TimeUnit;
7 |
8 | /**
9 | * Created by taotao on 15-1-30.
10 | */
11 | public class Constants {
12 |
13 | // Data may be transferred for a long time, and we must ensure that data is successfully transferred,
14 | // so we set a much longer timeout for data future.
15 | public static final Long DATA_FUTURE_TIME = 1000000L; // Unit: ms, 1000s
16 | public static final Timeout DATA_FUTURE_TIMEOUT = new Timeout(DATA_FUTURE_TIME, TimeUnit.MILLISECONDS);
17 | public static final FiniteDuration DATA_FUTURE_TIMEOUT_DURATION = DATA_FUTURE_TIMEOUT.duration();
18 |
19 | // Stop should be done in a short period, so we set a shorter timeout for stop future.
20 | public static final Long STOP_FUTURE_TIME = 30000L; // Unit: ms, 30s
21 | public static final Timeout STOP_FUTURE_TIMEOUT = new Timeout(STOP_FUTURE_TIME, TimeUnit.MILLISECONDS);
22 | public static final FiniteDuration STOP_FUTURE_TIMEOUT_DURATION = STOP_FUTURE_TIMEOUT.duration();
23 |
24 | }
25 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/DataDesc.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.*;
4 |
5 | /**
6 | * Created by yunlong on 12/30/15.
7 | */
8 | public final class DataDesc implements Serializable {
9 |
10 | public static final int DATA_TYPE_ARRAY = 0;
11 | public static final int DATA_TYPE_MATRIX = 1;
12 |
13 | public static final int KEY_TYPE_INT = 0;
14 | public static final int KEY_TYPE_LONG = 1;
15 |
16 | public static final int ELEMENT_TYPE_INT = 0;
17 | public static final int ELEMENT_TYPE_FLOAT = 1;
18 | public static final int ELEMENT_TYPE_LONG = 2;
19 | public static final int ELEMENT_TYPE_DOUBLE = 3;
20 |
21 | public int dataType;
22 | public int keyType;
23 | public int valueType;
24 |
25 | public boolean denseRow;
26 | public boolean denseColumn;
27 | public boolean adaGrad;
28 |
29 | public int keySize;
30 | public int valueSize;
31 |
32 | public DataDesc() {
33 | }
34 |
35 | public DataDesc(int dataType, int keyType, int valueType) {
36 | this(dataType, keyType, valueType, false, true);
37 | }
38 |
39 | public DataDesc(int dataType, int keyType, int valueType, boolean denseRow, boolean denseColumn) {
40 | this(dataType, keyType, valueType, denseRow, denseColumn, false);
41 | }
42 | public DataDesc(int dataType, int keyType, int valueType, boolean denseRow, boolean denseColumn, boolean adaGrad) {
43 | this.dataType = dataType;
44 | this.valueType = valueType;
45 | this.keyType = keyType;
46 | this.denseRow = denseRow;
47 | this.denseColumn = denseColumn;
48 | this.adaGrad = adaGrad;
49 |
50 | this.keySize = (keyType == KEY_TYPE_INT)? 4 : 8;
51 | this.valueSize = ((valueType == ELEMENT_TYPE_INT) || (valueType == ELEMENT_TYPE_FLOAT))? 4 : 8;
52 | }
53 |
54 | public String toString() {
55 | return "" + dataType + ", " + keyType + ", " + keySize + ", " + valueType + ", " + valueSize;
56 | }
57 |
58 | public int sizeAsBytes() {
59 | return 24; // six 4-byte ints; keySize and valueSize are derived on the fly
60 | }
61 |
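// Wire layout written by write()/read(): six 4-byte ints in order
// (dataType, keyType, valueType, denseRow, denseColumn, adaGrad),
// matching the 24 bytes reported by sizeAsBytes().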
62 | public void write(AbstractDataWriter out) throws Exception {
63 | out.writeInt(dataType);
64 | out.writeInt(keyType);
65 | out.writeInt(valueType);
66 | out.writeInt(denseRow ? 1 : 0);
67 | out.writeInt(denseColumn ? 1 : 0);
68 | out.writeInt(adaGrad ? 1 : 0);
69 | }
70 |
71 | public void read(AbstractDataReader in) throws Exception {
72 | dataType = in.readInt();
73 | keyType = in.readInt();
74 | valueType = in.readInt();
75 | denseRow = in.readInt() == 1;
76 | denseColumn = in.readInt() == 1;
77 | adaGrad = in.readInt() == 1;
78 |
79 | this.keySize = (keyType == KEY_TYPE_INT)? 4 : 8;
80 | this.valueSize = ((valueType == ELEMENT_TYPE_INT) || (valueType == ELEMENT_TYPE_FLOAT))? 4 : 8;
81 | }
82 |
83 | public Number readKey(AbstractDataReader is) throws Exception {
84 | if (keyType == KEY_TYPE_INT) {
85 | return is.readInt();
86 | }
87 | else {
88 | return is.readLong();
89 | }
90 | }
91 |
92 |
93 | public void writeKey(Number v, AbstractDataWriter os) throws Exception {
94 | if (keyType == KEY_TYPE_INT) {
95 | os.writeInt(v.intValue());
96 | }
97 | else {
98 | os.writeLong(v.longValue());
99 | }
100 | }
101 |
102 | public Object readValue(AbstractDataReader is) throws Exception {
103 | switch(valueType) {
104 | case ELEMENT_TYPE_INT:
105 | return is.readInt();
106 | case ELEMENT_TYPE_FLOAT:
107 | return is.readFloat();
108 | case ELEMENT_TYPE_LONG:
109 | return is.readLong();
110 | case ELEMENT_TYPE_DOUBLE:
111 | return is.readDouble();
112 | }
113 |
114 | throw new IllegalStateException("invalid value type: " + valueType);
115 | }
116 |
117 | public void writeValue(Object value, AbstractDataWriter os) throws Exception {
118 | switch(valueType) {
119 | case ELEMENT_TYPE_INT:
120 | os.writeInt((Integer) value); return;
121 | case ELEMENT_TYPE_FLOAT:
122 | os.writeFloat((Float) value); return;
123 | case ELEMENT_TYPE_LONG:
124 | os.writeLong((Long) value); return;
125 | case ELEMENT_TYPE_DOUBLE:
126 | os.writeDouble((Double) value); return;
127 | }
128 | throw new IllegalStateException("invalid value type: " + valueType);
129 | }
130 |
131 | public Number readKey(byte[] data, int offset) {
132 | if (keyType == KEY_TYPE_INT) {
133 | return readInt(data, offset);
134 | }
135 | else {
136 | return readLong(data, offset);
137 | }
138 | }
139 |
140 | public int writeKey(Number v, byte[] data, int offset) {
141 | if (keyType == KEY_TYPE_INT) {
142 | write(v.intValue(), data, offset);
143 | return offset + 4;
144 | }
145 | else {
146 | write(v.longValue(), data, offset);
147 | return offset + 8;
148 | }
149 | }
150 |
151 | public Object readValue(byte[] buf, int offset) {
152 | switch(valueType) {
153 | case ELEMENT_TYPE_INT:
154 | return readInt(buf, offset);
155 | case ELEMENT_TYPE_FLOAT:
156 | return readFloat(buf, offset);
157 | case ELEMENT_TYPE_LONG:
158 | return readLong(buf, offset);
159 | case ELEMENT_TYPE_DOUBLE:
160 | return readDouble(buf, offset);
161 | }
162 |
163 | throw new IllegalStateException("invalid value type: " + valueType);
164 | }
165 |
166 | public int writeValue(Object value, byte[] buf, int offset) {
167 | switch(valueType) {
168 | case ELEMENT_TYPE_INT:
169 | return write((Integer)value, buf, offset);
170 | case ELEMENT_TYPE_FLOAT:
171 | return write((Float)value, buf, offset);
172 | case ELEMENT_TYPE_LONG:
173 | return write((Long)value, buf, offset);
174 | case ELEMENT_TYPE_DOUBLE:
175 | return write((Double)value, buf, offset);
176 | }
177 | throw new IllegalStateException("invalid value type: " + valueType);
178 | }
179 |
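// The byte[] codecs below use little-endian byte order (least significant
// byte first), unlike java.io.DataInput/DataOutput, which are big-endian.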
180 | public int readInt(byte[] data, int offset) {
181 | int targets =
182 | (data[offset ] & 0x000000ff)
183 | | ((data[offset+1] << 8) & 0x0000ff00)
184 | | ((data[offset+2] << 16) & 0x00ff0000)
185 | | ((data[offset+3] << 24) & 0xff000000);
186 |
187 | return targets;
188 | }
189 | public float readFloat(byte[] data, int offset) {
190 | int targets = readInt(data, offset);
191 | return Float.intBitsToFloat(targets);
192 | }
193 |
194 | public long readLong(byte[] data, int offset) {
195 | long targets =
196 | (data[offset ] & 0x00000000000000ffL)
197 | | ((data[offset+1] << 8) & 0x000000000000ff00L)
198 | | ((data[offset+2] << 16) & 0x0000000000ff0000L)
199 | | ((data[offset+3] << 24) & 0x00000000ff000000L)
200 | | ((((long)data[offset+4]) << 32) & 0x000000ff00000000L)
201 | | ((((long)data[offset+5]) << 40) & 0x0000ff0000000000L)
202 | | ((((long)data[offset+6] << 48)) & 0x00ff000000000000L)
203 | | ((((long)data[offset+7] << 56)) & 0xff00000000000000L);
204 |
205 | return targets;
206 | }
207 | public double readDouble(byte[] data, int offset) {
208 | long targets = readLong(data, offset);
209 | double value = Double.longBitsToDouble(targets);
210 | //System.out.println("read double: " + value);
211 | return value;
212 | }
213 |
214 | public int write(double v, byte[] data, int offset) {
215 | long value = Double.doubleToLongBits(v);
216 | //System.out.println("write double: " + v);
217 | return write(value, data, offset);
218 | }
219 |
220 | public int write(long value, byte[] data, int offset) {
221 | data[offset] = (byte) (value & 0xff);
222 | data[offset+1] = (byte) ((value >> 8) & 0xff);
223 | data[offset+2] = (byte) ((value >> 16) & 0xff);
224 | data[offset+3] = (byte) ((value >> 24) & 0xff);
225 | data[offset+4] = (byte) ((value >> 32) & 0xff);
226 | data[offset+5] = (byte) ((value >> 40) & 0xff);
227 | data[offset+6] = (byte) ((value >> 48) & 0xff);
228 | data[offset+7] = (byte) ((value >> 56) & 0xff);
229 | return offset + 8;
230 | }
231 |
232 | public int write(float v, byte[] data, int offset) {
233 | int value = Float.floatToIntBits(v);
234 | return write(value, data, offset);
235 | }
236 |
237 | public int write(int value, byte[] data, int offset) {
238 | data[offset] = (byte) (value & 0xff);
239 | data[offset+1] = (byte) ((value >> 8) & 0xff);
240 | data[offset+2] = (byte) ((value >> 16) & 0xff);
241 | data[offset+3] = (byte) (value >>> 24);
242 | return offset + 4;
243 | }
244 |
245 | }
246 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/DataStore.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import com.intel.distml.api.DMatrix;
4 | import com.intel.distml.api.Model;
5 | import com.intel.distml.util.store.*;
6 |
7 | import java.io.DataInputStream;
8 | import java.io.DataOutputStream;
9 | import java.io.IOException;
10 | import java.io.OutputStream;
11 | import java.util.HashMap;
12 | import java.util.Map;
13 |
14 | /**
15 | * Created by yunlong on 12/8/15.
16 | */
17 | public abstract class DataStore {
18 |
19 | public abstract KeyCollection rows();
20 | public abstract int rowSize();
21 |
22 | public void rand() {};
23 |
24 | public void zero() {};
25 |
26 | public void set(String value) {};
27 |
28 | public abstract byte[] handleFetch(DataDesc format, KeyCollection rows);
29 |
30 | public abstract void handlePush(DataDesc format, byte[] data);
31 |
32 | public abstract void writeAll(DataOutputStream os) throws IOException;
33 |
34 | public abstract void readAll(DataInputStream is) throws IOException;
35 |
36 | public abstract void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException;
37 |
38 | public abstract void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException;
39 |
40 |
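// Server-side setup sketch: one store per DMatrix registered in the model,
// holding the partition owned by this server. For example (the matrix name
// "weights" is hypothetical):
//   HashMap<String, DataStore> stores = DataStore.createStores(model, serverIndex);
//   DataStore weights = stores.get("weights");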
41 | public static HashMap<String, DataStore> createStores(Model model, int serverIndex) {
42 | HashMap<String, DataStore> stores = new HashMap<String, DataStore>();
43 | for (Map.Entry<String, DMatrix> m : model.dataMap.entrySet()) {
44 | stores.put(m.getKey(), DataStore.createStore(serverIndex, m.getValue()));
45 | }
46 |
47 | return stores;
48 | }
49 |
50 | public static DataStore createStore(int serverIndex, DMatrix matrix) {
51 | System.out.println("create store: " + serverIndex + ", cols: " + matrix.getColKeys().size());
52 | DataDesc format = matrix.getFormat();
53 | if (format.dataType == DataDesc.DATA_TYPE_ARRAY) {
54 | if (format.valueType == DataDesc.ELEMENT_TYPE_INT) {
55 | IntArrayStore store = new IntArrayStore();
56 | store.init(matrix.partitions[serverIndex]);
57 | return store;
58 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_DOUBLE) {
59 | DoubleArrayStore store = new DoubleArrayStore();
60 | store.init(matrix.partitions[serverIndex]);
61 | return store;
62 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_FLOAT) {
63 | FloatArrayStore store = new FloatArrayStore();
64 | store.init(matrix.partitions[serverIndex]);
65 | return store;
66 | }
67 | }
68 | else {
69 | if (format.valueType == DataDesc.ELEMENT_TYPE_INT) {
70 | IntMatrixStore store = new IntMatrixStore();
71 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size());
72 | return store;
73 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_DOUBLE) {
74 | DoubleMatrixStore store = new DoubleMatrixStore();
75 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size());
76 | return store;
77 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_FLOAT) {
78 | if (format.adaGrad) {
79 | FloatMatrixStoreAdaGrad store = new FloatMatrixStoreAdaGrad();
80 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size());
81 | return store;
82 | }
83 | else {
84 | FloatMatrixStore store = new FloatMatrixStore();
85 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size());
86 | return store;
87 | }
88 | }
89 | }
90 |
91 | throw new IllegalArgumentException("Unrecognized matrix type: " + matrix.getClass().getName());
92 | }
93 |
94 | }
95 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/DefaultDataReader.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.IOException;
5 | import java.nio.ByteBuffer;
6 |
7 | /**
8 | * Created by yunlong on 4/28/16.
9 | */
10 | public class DefaultDataReader implements AbstractDataReader {
11 |
12 | DataInputStream dis;
13 |
14 | public DefaultDataReader(DataInputStream dis) {
15 | this.dis = dis;
16 | }
17 |
18 | public int readInt() throws Exception {
19 | return dis.readInt();
20 | }
21 |
22 | public long readLong() throws Exception {
23 | return dis.readLong();
24 | }
25 |
26 | public float readFloat() throws Exception {
27 | return dis.readFloat();
28 | }
29 |
30 | public double readDouble() throws Exception {
31 | return dis.readDouble();
32 | }
33 |
34 | public int readBytes(byte[] bytes) throws Exception {
35 | return dis.read(bytes);
36 | }
37 |
38 | public void readFully(byte[] bytes) throws Exception {
39 | dis.readFully(bytes);
40 | }
41 |
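// Busy-waits (polling with a 1 ms sleep) until at least `size` bytes are buffered.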
42 | public void waitUntil(int size) throws Exception {
43 | while(dis.available() < size) { try { Thread.sleep(1); } catch (Exception e){}}
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/DefaultDataWriter.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataOutputStream;
4 | import java.nio.ByteBuffer;
5 |
6 | /**
7 | * Created by yunlong on 4/28/16.
8 | */
9 | public class DefaultDataWriter implements AbstractDataWriter {
10 |
11 |
12 | DataOutputStream dos;
13 |
14 | public DefaultDataWriter(DataOutputStream dos) {
15 | this.dos = dos;
16 | }
17 |
18 | public void writeInt(int value) throws Exception {
19 | dos.writeInt(value);
20 | }
21 |
22 | public void writeLong(long value) throws Exception {
23 | dos.writeLong(value);
24 | }
25 |
26 | public void writeFloat(float value) throws Exception {
27 | dos.writeFloat(value);
28 |
29 | }
30 |
31 | public void writeDouble(double value) throws Exception {
32 | dos.writeDouble(value);
33 | }
34 |
35 | public void writeBytes(byte[] bytes) throws Exception {
36 | dos.write(bytes);
37 | }
38 |
39 | public void flush() throws Exception {
40 | dos.flush();
41 | }
42 |
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/DoubleArray.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import com.intel.distml.api.DMatrix;
4 | import com.intel.distml.api.Session;
5 | import com.intel.distml.util.KeyCollection;
6 |
7 | import java.util.HashMap;
8 |
9 | /**
10 | * Created by yunlong on 12/8/15.
11 | */
12 |
13 | public class DoubleArray extends SparseArray<Long, Double> {
14 |
15 | public DoubleArray(long dim) {
16 | super(dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_DOUBLE);
17 | }
18 |
19 | }
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/DoubleMatrix.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import com.intel.distml.api.DMatrix;
4 | import com.intel.distml.api.Session;
5 |
6 | import java.util.HashMap;
7 |
8 | /**
9 | * Created by jimmy on 15-12-29.
10 | */
11 | public class DoubleMatrix extends SparseMatrix<Long, Double> {
12 | public DoubleMatrix(long rows, int cols) {
13 | super(rows, cols, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_DOUBLE);
14 | }
15 |
16 | protected boolean isZero(Double value) {
17 | return value == 0.0;
18 | }
19 | protected Double[] createValueArray(int size) {
20 | return new Double[size];
21 | }
22 | public HashMap<Long, Double[]> fetch(KeyCollection rows, Session session) {
23 | HashMap<Long, Double[]> result;
24 | result = super.fetch(rows, session);
25 | //System.out.println(result.get(1).length);
26 | return result;
27 | }
28 | }
29 |
30 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/IntArray.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import akka.pattern.Patterns;
4 | import com.intel.distml.api.DMatrix;
5 | import com.intel.distml.api.Session;
6 | import com.intel.distml.platform.DataBusProtocol;
7 | import com.intel.distml.util.KeyCollection;
8 | import scala.concurrent.Future;
9 |
10 | import java.util.HashMap;
11 | import java.util.Iterator;
12 | import java.util.LinkedList;
13 | import java.util.Map;
14 |
15 | /**
16 | * Created by yunlong on 12/8/15.
17 | */
18 | public class IntArray extends SparseArray<Long, Integer> {
19 |
20 | public IntArray(long dim) {
21 | super(dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_INT);
22 | }
23 |
24 | }
25 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/IntArrayWithIntKey.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | /**
4 | * Created by yunlong on 1/2/16.
5 | */
6 | public class IntArrayWithIntKey extends SparseArray<Integer, Integer> {
7 |
8 | public IntArrayWithIntKey(long dim) {
9 | super(dim, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_INT);
10 | }
11 |
12 | }
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/IntMatrix.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import com.intel.distml.api.DMatrix;
4 | import com.intel.distml.api.Session;
5 | import com.intel.distml.util.KeyCollection;
6 |
7 | import java.util.HashMap;
8 | import java.util.Map;
9 |
10 | /**
11 | * Created by yunlong on 12/8/15.
12 | */
13 | public class IntMatrix extends SparseMatrix<Long, Integer> {
14 |
15 | public IntMatrix(long rows, int cols) {
16 | super(rows, cols, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_INT);
17 | }
18 |
19 | protected boolean isZero(Integer value) {
20 | return value == 0;
21 | }
22 |
23 | protected Integer[] createValueArray(int size) {
24 | return new Integer[size];
25 | }
26 |
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/IntMatrixWithIntKey.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | /**
4 | * Created by yunlong on 12/8/15.
5 | */
6 | public class IntMatrixWithIntKey extends SparseMatrix<Integer, Integer> {
7 |
8 | public IntMatrixWithIntKey(long rows, int cols) {
9 | super(rows, cols, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_INT);
10 | }
11 |
12 | protected boolean isZero(Integer value) {
13 | return value == 0;
14 | }
15 |
16 | protected Integer[] createValueArray(int size) {
17 | return new Integer[size];
18 | }
19 |
20 |
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/KeyCollection.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.io.Serializable;
7 | import java.util.Iterator;
8 |
9 | /**
10 | * Created by yunlong on 12/11/14.
11 | */
12 | public abstract class KeyCollection implements Serializable {
13 |
14 | public static final int TYPE_ALL = 0;
15 | public static final int TYPE_EMPTY = 1;
16 | public static final int TYPE_RANGE = 2;
17 | public static final int TYPE_LIST = 3;
18 | public static final int TYPE_HASH = 4;
19 |
20 | public static final KeyCollection EMPTY = new EmptyKeys();
21 | public static final KeyCollection SINGLE = new KeyRange(0, 0);
22 | public static final KeyCollection ALL = new AllKeys();
23 |
24 | public int type;
25 |
26 | public KeyCollection(int type) {
27 | this.type = type;
28 | }
29 |
30 | public int sizeAsBytes(DataDesc format) {
31 | return 4;
32 | }
33 |
34 | public void write(AbstractDataWriter out, DataDesc format) throws Exception {
35 | out.writeInt(type);
36 | }
37 |
38 | public void read(AbstractDataReader in, DataDesc format) throws Exception {
39 | }
40 |
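// Wire format: a 4-byte type tag (TYPE_*) followed by a type-specific payload;
// ALL and EMPTY carry no payload.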
41 | public static KeyCollection readKeyCollection(AbstractDataReader in, DataDesc format) throws Exception {
42 | in.waitUntil(4);
43 | int type = in.readInt();
44 | KeyCollection ks;
45 | switch(type) {
46 | case TYPE_ALL:
47 | return ALL;
48 | case TYPE_EMPTY:
49 | return EMPTY;
50 |
51 | case TYPE_LIST:
52 | ks = new KeyList();
53 | break;
54 |
55 | case TYPE_RANGE:
56 | ks = new KeyRange();
57 | break;
58 |
59 | default:
60 | ks = new KeyHash();
61 | break;
62 |
63 | }
64 |
65 | in.waitUntil(ks.sizeAsBytes(format));
66 | ks.read(in, format);
67 |
68 | return ks;
69 | }
70 |
71 | public abstract boolean contains(long key);
72 |
73 | public abstract Iterator<Long> iterator();
74 |
75 | public abstract boolean isEmpty();
76 |
77 | public abstract long size();
78 |
79 | public KeyCollection intersect(KeyCollection keys) {
80 |
81 | if (keys.equals(KeyCollection.ALL)) {
82 | return this;
83 | }
84 |
85 | if (keys.equals(KeyCollection.EMPTY)) {
86 | return keys;
87 | }
88 |
89 | KeyList result = new KeyList();
90 |
91 | Iterator<Long> it = keys.iterator();
92 | while(it.hasNext()) {
93 | long key = it.next();
94 | if (contains(key)) {
95 | result.addKey(key);
96 | }
97 | }
98 |
99 | return result;
100 | }
101 |
102 | public static class EmptyKeys extends KeyCollection {
103 |
104 | public EmptyKeys() {
105 | super(TYPE_EMPTY);
106 | }
107 |
108 | @Override
109 | public boolean equals(Object obj) {
110 | return (obj instanceof EmptyKeys);
111 | }
112 |
113 | @Override
114 | public boolean contains(long key) {
115 | return false;
116 | }
117 |
118 | @Override
119 | public KeyCollection intersect(KeyCollection keys) {
120 | return this;
121 | }
122 |
123 | @Override
124 | public Iterator<Long> iterator() {
125 | return new Iterator<Long>() {
126 | public boolean hasNext() { return false; }
127 | public Long next() { return -1L; }
128 | public void remove() { }
129 | };
130 | }
131 |
132 | @Override
133 | public boolean isEmpty() {
134 | return true;
135 | }
136 | /*
137 | @Override
138 | public PartitionInfo partitionEqually(int hostNum) {
139 | throw new UnsupportedOperationException("This is an EMPTY_KEYS instance, not partitionable");
140 | }
141 |
142 | @Override
143 | public KeyCollection[] split(int hostNum) {
144 | throw new UnsupportedOperationException("This is an EMPTY_KEYS instance, not partitionable");
145 | }
146 | */
147 | @Override
148 | public long size() {
149 | return 0;
150 | }
151 | };
152 |
153 | public static class AllKeys extends KeyCollection {
154 |
155 | public AllKeys() {
156 | super(TYPE_ALL);
157 | }
158 |
159 |
160 | @Override
161 | public boolean equals(Object obj) {
162 | return (obj instanceof AllKeys);
163 | }
164 |
165 | @Override
166 | public boolean contains(long key) {
167 | return true;
168 | }
169 |
170 | @Override
171 | public KeyCollection intersect(KeyCollection keys) {
172 | return keys;
173 | }
174 |
175 | @Override
176 | public Iterator<Long> iterator() {
177 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, not iterable.");
178 | }
179 |
180 | @Override
181 | public boolean isEmpty() {
182 | return false;
183 | }
184 | /*
185 | @Override
186 | public PartitionInfo partitionEqually(int hostNum) {
187 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, not partitionable");
188 | }
189 |
190 | @Override
191 | public KeyCollection[] split(int hostNum) {
192 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, not partitionable");
193 | }
194 | */
195 | @Override
196 | public long size() {
197 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, size unknown");
198 | }
199 | };
200 | }
201 |
202 |
203 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/KeyHash.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.util.Iterator;
7 |
8 | /**
9 | * Created by yunlong on 12/11/14.
10 | */
11 | public class KeyHash extends KeyCollection {
12 |
13 | public int hashQuato;
14 | public int hashIndex;
15 |
16 | public long minKey;
17 | public long maxKey;
18 | //public long totalKeyNum;
19 |
20 | private long first, last;
21 |
22 | KeyHash() {
23 | super(KeyCollection.TYPE_HASH);
24 | }
25 |
26 | public KeyHash(int hashQuato, int hashIndex, long minKey, long maxKey) {
27 | super(KeyCollection.TYPE_HASH);
28 |
29 | this.hashQuato = hashQuato;
30 | this.hashIndex = hashIndex;
31 | this.minKey = minKey;
32 | this.maxKey = maxKey;
33 |
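// first/last are the smallest and largest keys in [minKey, maxKey] congruent
// to hashIndex modulo hashQuato; they bound the set the iterator walks.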
34 | first = minKey - minKey%hashQuato + hashIndex;
35 | if (first < minKey) {
36 | first += hashQuato;
37 | }
38 | last = maxKey - maxKey%hashQuato + hashIndex;
39 | if (last > maxKey) {
40 | last -= hashQuato;
41 | }
42 | }
43 |
44 | @Override
45 | public boolean equals(Object obj) {
46 | if (!(obj instanceof KeyHash)) {
47 | return false;
48 | }
49 |
50 | KeyHash o = (KeyHash)obj;
51 | return (hashQuato == o.hashQuato) && (hashIndex == o.hashIndex) && (minKey == o.minKey) && (maxKey == o.maxKey);
52 | }
53 |
54 | @Override
55 | public int sizeAsBytes(DataDesc format) {
56 | return super.sizeAsBytes(format) + 8 + 4 * format.keySize;
57 | }
58 |
59 | @Override
60 | public void write(AbstractDataWriter out, DataDesc format) throws Exception {
61 | super.write(out, format);
62 |
63 | out.writeInt(hashQuato);
64 | out.writeInt(hashIndex);
65 |
66 | if (format.keyType == DataDesc.KEY_TYPE_INT) {
67 | out.writeInt((int)minKey);
68 | out.writeInt((int)maxKey);
69 | out.writeInt((int)first);
70 | out.writeInt((int)last);
71 | }
72 | else {
73 | out.writeLong(minKey);
74 | out.writeLong(maxKey);
75 | out.writeLong(first);
76 | out.writeLong(last);
77 | }
78 | }
79 |
80 | @Override
81 | public void read(AbstractDataReader in, DataDesc format) throws Exception {
82 | //super.read(in);
83 | hashQuato = in.readInt();
84 | hashIndex = in.readInt();
85 | if (format.keyType == DataDesc.KEY_TYPE_INT) {
86 | minKey = in.readInt();
87 | maxKey = in.readInt();
88 | first = in.readInt();
89 | last = in.readInt();
90 | }
91 | else {
92 | minKey = in.readLong();
93 | maxKey = in.readLong();
94 | first = in.readLong();
95 | last = in.readLong();
96 | }
97 | }
98 |
99 |
100 | public long size() {
101 | return (last - first) / hashQuato + 1;
102 | //return (totalKeyNum /hashQuato) + (((totalKeyNum % hashQuato) > hashIndex)? 1 : 0);
103 | }
104 |
105 | @Override
106 | public boolean contains(long key) {
107 | if ((key > last) || (key < first)) {
108 | //throw new RuntimeException("unexpected key: " + key + " >= " + totalKeyNum);
109 | return false;
110 | }
111 |
112 | //System.out.println("check contains: " + key);
113 | return key % hashQuato == hashIndex;
114 | }
115 |
116 | @Override
117 | public boolean isEmpty() {
118 | return size() == 0;
119 | }
120 |
121 | @Override
122 | public String toString() {
123 | return "[KeyHash: quato=" + hashQuato + ", index=" + hashIndex + ", min=" + minKey + ", max=" + maxKey + "]";
124 | }
125 |
126 | @Override
127 | public KeyCollection intersect(KeyCollection keys) {
128 |
129 | if (keys.equals(KeyCollection.ALL)) {
130 | return this;
131 | }
132 |
133 | if (keys.equals(KeyCollection.EMPTY)) {
134 | return keys;
135 | }
136 |
137 | if (keys instanceof KeyRange) {
138 | KeyRange r = (KeyRange) keys;
139 | if ((r.firstKey <= first) && (r.lastKey >= last)) {
140 | return this;
141 | }
142 |
143 | KeyHash newKeys = new KeyHash(hashQuato, hashIndex, Math.max(minKey, r.firstKey), Math.min(maxKey, r.lastKey));
144 | return newKeys;
145 | }
146 |
147 | KeyList list = new KeyList();
148 | Iterator<Long> it = keys.iterator();
149 | while(it.hasNext()) {
150 | long key = it.next();
151 | if (contains(key)) {
152 | list.addKey(key);
153 | }
154 | }
155 |
156 | if (list.isEmpty()) {
157 | return KeyCollection.EMPTY;
158 | }
159 |
160 | return list;
161 | }
162 |
163 |
164 | @Override
165 | public Iterator<Long> iterator() {
166 | return new _Iterator(this);
167 | }
168 |
169 | static class _Iterator implements Iterator<Long> {
170 |
171 | long currentKey;
172 | KeyHash keys;
173 |
174 | public _Iterator(KeyHash keys) {
175 | this.keys = keys;
176 | this.currentKey = keys.first;
177 | }
178 |
179 | public boolean hasNext() {
180 | return currentKey <= keys.last;
181 | }
182 |
183 | public Long next() {
184 | long k = currentKey;
185 | currentKey += keys.hashQuato;
186 | return k;
187 | }
188 |
189 | public void remove() {
190 | throw new RuntimeException("Not supported.");
191 | }
192 |
193 | }
194 | }
195 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/KeyList.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.util.HashSet;
7 | import java.util.Iterator;
8 |
9 | /**
10 | * Created by yunlong on 2/1/15.
11 | */
12 | public class KeyList extends KeyCollection {
13 |
14 | public HashSet<Long> keys;
15 |
16 | public KeyList() {
17 | super(KeyCollection.TYPE_LIST);
18 | keys = new HashSet<Long>();
19 | }
20 |
21 | public long size() {
22 | return keys.size();
23 | }
24 |
25 | @Override
26 | public int sizeAsBytes(DataDesc format) {
27 | return super.sizeAsBytes(format) + 4 + keys.size() * format.keySize;
28 | }
29 |
30 | @Override
31 | public void write(AbstractDataWriter out, DataDesc format) throws Exception {
32 | super.write(out, format);
33 |
34 | out.writeInt(keys.size());
35 | if (format.keyType == DataDesc.KEY_TYPE_INT) {
36 | for (long key : keys)
37 | out.writeInt((int)key);
38 | }
39 | else {
40 | for (long key : keys)
41 | out.writeLong(key);
42 | }
43 | }
44 |
45 | @Override
46 | public void read(AbstractDataReader in, DataDesc format) throws Exception {
47 |
48 | int count = in.readInt();
49 | if (format.keyType == DataDesc.KEY_TYPE_INT) {
50 | for (int i = 0; i < count; i++)
51 | keys.add(new Long(in.readInt()));
52 | }
53 | else {
54 | for (int i = 0; i < count; i++)
55 | keys.add(in.readLong());
56 | }
57 | }
58 |
59 | public void addKey(long k) {
60 | keys.add(k);
61 | }
62 |
63 | @Override
64 | public boolean contains(long key) {
65 | return keys.contains(key);
66 | }
67 |
68 | @Override
69 | public boolean isEmpty() {
70 | return keys.isEmpty();
71 | }
72 |
73 | @Override
74 | public KeyCollection intersect(KeyCollection keys) {
75 | KeyList list = new KeyList();
76 |
77 | Iterator<Long> it = keys.iterator();
78 | while(it.hasNext()) {
79 | long k = it.next();
80 | if (contains(k)) {
81 | list.keys.add(k);
82 | }
83 | }
84 |
85 | return list;
86 | }
87 |
88 | @Override
89 | public Iterator<Long> iterator() {
90 | return keys.iterator();
91 | }
92 |
93 | @Override
94 | public String toString() {
95 | String s = "[KeyList: size=" + keys.size();
96 | if (keys.size() > 0) {
97 | s += " first=" + keys.iterator().next();
98 | }
99 | s += "]";
100 |
101 | return s;
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/KeyRange.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.DataOutputStream;
5 | import java.io.IOException;
6 | import java.util.Iterator;
7 |
8 | /**
9 | * Created by yunlong on 12/11/14.
10 | */
11 | public class KeyRange extends KeyCollection {
12 |
13 | public static final KeyRange Single = new KeyRange(0, 0);
14 |
15 | public long firstKey, lastKey;
16 |
17 | KeyRange() {
18 | super(KeyCollection.TYPE_RANGE);
19 | }
20 |
21 | public KeyRange(long f, long l) {
22 | super(KeyCollection.TYPE_RANGE);
23 | firstKey = f;
24 | lastKey = l;
25 | }
26 |
27 | @Override
28 | public boolean equals(Object obj) {
29 | if (!(obj instanceof KeyRange)) {
30 | return false;
31 | }
32 |
33 | KeyRange range = (KeyRange)obj;
34 | return (firstKey == range.firstKey) && (lastKey == range.lastKey);
35 | }
36 |
37 | @Override
38 | public int sizeAsBytes(DataDesc format) {
39 | return super.sizeAsBytes(format) + 2 * format.keySize;
40 | }
41 |
42 | @Override
43 | public void write(AbstractDataWriter out, DataDesc format) throws Exception {
44 | super.write(out, format);
45 |
46 | if (format.keyType == DataDesc.KEY_TYPE_INT) {
47 | out.writeInt((int)firstKey);
48 | out.writeInt((int)lastKey);
49 | }
50 | else {
51 | out.writeLong(firstKey);
52 | out.writeLong(lastKey);
53 | }
54 | }
55 |
56 | @Override
57 | public void read(AbstractDataReader in, DataDesc format) throws Exception {
58 | if (format.keyType == DataDesc.KEY_TYPE_INT) {
59 | firstKey = in.readInt();
60 | lastKey = in.readInt();
61 | }
62 | else {
63 | firstKey = in.readLong();
64 | lastKey = in.readLong();
65 | }
66 | }
67 |
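// Splits this range into hostNum contiguous sub-ranges, e.g.
// new KeyRange(0, 9).linearSplit(2) yields [0, 4] and [5, 9].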
68 | public KeyCollection[] linearSplit(int hostNum) {
69 | KeyCollection[] sets = new KeyRange[hostNum];
70 |
71 | long start = firstKey;
72 | long step = (lastKey - firstKey + hostNum) / hostNum;
73 | for (int i = 0; i < hostNum; i++) {
74 | long end = Math.min(start + step - 1, lastKey);
75 | sets[i] = new KeyRange(start, end);
76 | start += step;
77 | }
78 |
79 | return sets;
80 | }
81 |
82 | public KeyCollection[] hashSplit(int hostNum) {
83 | KeyCollection[] sets = new KeyHash[hostNum];
84 |
85 | for (int i = 0; i < hostNum; i++) {
86 | sets[i] = new KeyHash(hostNum, i, firstKey, lastKey);
87 | }
88 |
89 | return sets;
90 | }
91 |
92 | public long size() {
93 | return lastKey - firstKey + 1;
94 | }
95 |
96 | @Override
97 | public boolean contains(long key) {
98 | return (key >= firstKey) && (key <= lastKey);
99 | }
100 |
101 | @Override
102 | public boolean isEmpty() {
103 | return firstKey > lastKey;
104 | }
105 |
106 | public boolean containsAll(KeyCollection keys) {
107 | KeyRange keyRange = (KeyRange)keys;
108 | return (keyRange.firstKey >= firstKey) && (keyRange.lastKey <= lastKey);
109 | }
110 |
111 | @Override
112 | public String toString() {
113 | return "[" + firstKey + ", " + lastKey + "]";
114 | }
115 |
116 | public KeyRange FetchSame(KeyRange kr) {
117 | long NewFirst = kr.firstKey > this.firstKey ? kr.firstKey : this.firstKey;
118 | long NewLast = kr.lastKey < this.lastKey ? kr.lastKey : this.lastKey;
119 | if (NewFirst > NewLast) return null;
120 | return new KeyRange(NewFirst, NewLast);
121 | }
122 |
123 | @Override
124 | public KeyCollection intersect(KeyCollection keys) {
125 |
126 | if (keys instanceof KeyRange) {
127 | KeyRange r = (KeyRange)keys;
128 | long min = Math.max(r.firstKey, firstKey);
129 | long max = Math.min(r.lastKey, lastKey);
130 |
131 | if (min > max) {
132 | return KeyCollection.EMPTY;
133 | }
134 |
135 | return new KeyRange(min, max);
136 | }
137 |
138 | if (keys instanceof KeyHash) {
139 | // warning: hope KeyHash don't call back
140 | return keys.intersect(this);
141 | }
142 |
143 | return super.intersect(keys);
144 | }
145 |
146 | public boolean mergeFrom(KeyRange keys) {
147 | if ((keys.firstKey > lastKey + 1) || (keys.lastKey < firstKey - 1)) {
148 | return false;
149 | }
150 |
151 | firstKey = Math.min(keys.firstKey, firstKey);
152 | lastKey = Math.max(keys.lastKey, lastKey);
153 | return true;
154 | }
155 |
156 |
157 | @Override
158 | public Iterator<Long> iterator() {
159 | return new _Iterator(this);
160 | }
161 |
162 | static class _Iterator implements Iterator<Long> {
163 |
164 | long index;
165 | KeyRange range;
166 |
167 | public _Iterator(KeyRange range) {
168 | this.range = range;
169 | this.index = range.firstKey;
170 | }
171 |
172 | public boolean hasNext() {
173 | return index <= range.lastKey;
174 | }
175 |
176 | public Long next() {
177 | return index++;
178 | }
179 |
180 | public void remove() {
181 | throw new RuntimeException("Not supported.");
182 | }
183 |
184 | }
185 | }
186 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/Logger.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.text.SimpleDateFormat;
4 | import java.util.Date;
5 |
6 | /**
7 | * Created by taotao on 15-1-14.
8 | *
9 | * Simple log system.
10 | */
11 | public class Logger {
12 |
13 | private final static String[] levels = {"NOTUSED", "DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"};
14 |
15 | private final static int DEBUG = 1;
16 | private final static int INFO = 2;
17 | private final static int WARN = 3;
18 | private final static int ERROR = 4;
19 | private final static int CRITICAL = 5;
20 |
21 | private static SimpleDateFormat DATE_FORMAT = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss.SSS");
22 |
23 | private static int outputLevel = INFO;
24 |
25 | public static void Log(int level, String msg, String module) {
26 | if (level < outputLevel) return;
27 |
28 | String typeStr = levels[level];
29 | System.out.println("==== [" + typeStr + "] [" + DATE_FORMAT.format(new Date()) + "] " +
30 | "[" + module + "] " + msg);
31 | }
32 |
33 | public static void debug(String msg, String module) {
34 | Log(DEBUG, msg, module);
35 | }
36 |
37 | public static void info(String msg, String module) {
38 | Log(INFO, msg, module);
39 | }
40 |
41 | public static void warn(String msg, String module) {
42 | Log(WARN, msg, module);
43 | }
44 |
45 | public static void error(String msg, String module) {
46 | Log(ERROR, msg, module);
47 | }
48 |
49 | public static void critical(String msg, String module) {
50 | Log(CRITICAL, msg, module);
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/SparseArray.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import com.intel.distml.api.DMatrix;
4 | import com.intel.distml.api.Session;
5 |
6 | import java.util.HashMap;
7 | import java.util.Map;
8 |
9 | /**
10 | * Created by yunlong on 12/31/15.
11 | */
12 | public abstract class SparseArray<K, V> extends DMatrix {
13 |
14 | public SparseArray(long dim, int keyType, int valueType) {
15 | super(dim);
16 | format = new DataDesc(DataDesc.DATA_TYPE_ARRAY, keyType, valueType);
17 | }
18 |
19 | public HashMap<K, V> fetch(KeyCollection rows, Session session) {
20 | HashMap<K, V> result = new HashMap<K, V>();
21 | byte[][] data = session.dataBus.fetch(name, rows, format);
22 | for (byte[] obj : data) {
23 | HashMap<K, V> m = readMap(obj);
24 | result.putAll(m);
25 | }
26 |
27 | return result;
28 | }
29 |
30 | public void push(HashMap<K, V> data, Session session) {
31 | byte[][] bufs = new byte[partitions.length][];
32 | for (int i = 0; i < partitions.length; ++i) {
33 | KeyCollection p = partitions[i];
34 | HashMap<K, V> m = new HashMap<K, V>();
35 | for (K key : data.keySet()) {
36 | long k = (key instanceof Integer)? ((Integer)key).longValue() : ((Long)key).longValue();
37 | if (p.contains(k))
38 | m.put(key, data.get(key));
39 | }
40 | bufs[i] = writeMap(m);
41 | }
42 |
43 | session.dataBus.push(name, format, bufs);
44 | }
45 |
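// Records are encoded back-to-back as fixed-size (key, value) pairs of
// format.keySize + format.valueSize bytes each.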
46 | private HashMap<K, V> readMap(byte[] buf) {
47 | HashMap<K, V> data = new HashMap<K, V>();
48 |
49 | int offset = 0;
50 | while (offset < buf.length) {
51 | K key = (K) format.readKey(buf, offset);
52 | offset += format.keySize;
53 |
54 | V value = (V) format.readValue(buf, offset);
55 | offset += format.valueSize;
56 |
57 | data.put(key, value);
58 | }
59 |
60 | return data;
61 | }
62 |
63 | private byte[] writeMap(HashMap<K, V> data) {
64 | int recordLen = format.keySize + format.valueSize;
65 | byte[] buf = new byte[recordLen * data.size()];
66 |
67 | int offset = 0;
68 | for (Map.Entry<K, V> entry : data.entrySet()) {
69 | format.writeKey((Number) entry.getKey(), buf, offset);
70 | offset += format.keySize;
71 |
72 | offset = format.writeValue(entry.getValue(), buf, offset);
73 | }
74 |
75 | return buf;
76 | }
77 |
78 | }
79 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/SparseMatrix.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import com.intel.distml.api.DMatrix;
4 | import com.intel.distml.api.Session;
5 |
6 | import java.lang.reflect.Array;
7 | import java.util.HashMap;
8 | import java.util.Map;
9 |
10 | /**
11 | * Created by yunlong on 12/31/15.
12 | */
13 | public abstract class SparseMatrix<K, V> extends DMatrix {
14 |
15 | protected KeyRange colKeys;
16 |
17 | public SparseMatrix(long dim, int cols, int keyType, int valueType) {
18 | super(dim);
19 | format = new DataDesc(DataDesc.DATA_TYPE_MATRIX, keyType, valueType);
20 | colKeys = new KeyRange(0, cols-1);
21 | }
22 |
23 | public SparseMatrix(long dim, int cols, int keyType, int valueType,
24 | boolean denseColumn) {
25 | super(dim);
26 | format = new DataDesc(DataDesc.DATA_TYPE_MATRIX, keyType, valueType, false, denseColumn);
27 | colKeys = new KeyRange(0, cols-1);
28 | }
29 |
30 | public KeyCollection getColKeys() {
31 | return colKeys;
32 | }
33 |
34 |
35 | public HashMap<K, V[]> fetch(KeyCollection rows, Session session) {
36 | HashMap<K, V[]> result = new HashMap<K, V[]>();
37 | byte[][] data = session.dataBus.fetch(name, rows, format);
38 | for (byte[] obj : data) {
39 | HashMap<K, V[]> m = readMap(obj);
40 | result.putAll(m);
41 | }
42 |
43 | return result;
44 | }
45 |
46 | public void push(HashMap<K, V[]> data, Session session) {
47 | byte[][] bufs = new byte[partitions.length][];
48 | for (int i = 0; i < partitions.length; ++i) {
49 | KeyCollection p = partitions[i];
50 | HashMap<K, V[]> m = new HashMap<K, V[]>();
51 | for (K key : data.keySet()) {
52 | long k = (key instanceof Integer)? ((Integer)key).longValue() : ((Long)key).longValue();
53 | if (p.contains(k))
54 | m.put(key, data.get(key));
55 | }
56 | bufs[i] = writeMap(m);
57 | }
58 |
59 | session.dataBus.push(name, format, bufs);
60 | }
61 |
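// Row encoding: a key followed by all column values when denseColumn is set;
// otherwise a key, a 4-byte count, then (4-byte column index, value) pairs
// for the non-zero entries only.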
62 | private HashMap<K, V[]> readMap(byte[] buf) {
63 | HashMap<K, V[]> data = new HashMap<K, V[]>();
64 |
65 | int offset = 0;
66 | while (offset < buf.length) {
67 | K key = (K) format.readKey(buf, offset);
68 | offset += format.keySize;
69 |
70 | V[] values = createValueArray((int) getColKeys().size());
71 | if (format.denseColumn) {
72 | for (int i = 0; i < getColKeys().size(); i++) {
73 | values[i] = (V) format.readValue(buf, offset);
74 | offset += format.valueSize;
75 | }
76 | }
77 | else {
78 | int count = format.readInt(buf, offset);
79 | offset += 4;
80 |
81 | for (int i = 0; i < count; i++) {
82 | int index = format.readInt(buf, offset);
83 | offset += 4;
84 | V value = (V) format.readValue(buf, offset);
85 | offset += format.valueSize;
86 |
87 | values[index] = value;
88 | }
89 | }
90 | data.put(key, values);
91 | }
92 |
93 | return data;
94 | }
95 |
96 | private byte[] writeMap(HashMap<K, V[]> data) {
97 |
98 | byte[] buf;
99 | if (format.denseColumn) {
100 | int len = (int) (format.valueSize * data.size() * colKeys.size());
101 | buf = new byte[format.keySize * data.size() + len];
102 | }
103 | else {
104 | int nzcount = 0;
105 | for (V[] values : data.values()) {
106 | for (V value : values) {
107 | if (!isZero(value)) {
108 | nzcount++;
109 | }
110 | }
111 | }
112 | int len = (int) ((format.valueSize + 4) * nzcount);
113 | buf = new byte[format.keySize * data.size() + len];
114 | }
115 |
116 | int offset = 0;
117 | for (Map.Entry<K, V[]> entry : data.entrySet()) {
118 | format.writeKey((Number)entry.getKey(), buf, offset);
119 | offset += format.keySize;
120 | V[] values = entry.getValue();
121 | if (format.denseColumn) {
122 | for (int i = 0; i < getColKeys().size(); i++) {
123 | format.writeValue(values[i], buf, offset);
124 | offset += format.valueSize;
125 | }
126 | }
127 | else {
128 | int counterIndex = offset;
129 | offset += 4;
130 |
131 | int counter = 0;
132 | for (int i = 0; i < values.length; i++) {
133 | V value = values[i];
134 | if (!isZero(value)) {
135 | format.write(i, buf, offset);
136 | offset += 4;
137 | format.writeValue(value, buf, offset);
138 | offset += format.valueSize;
139 | counter++; // count only the non-zero entries actually written
140 | }
141 | }
142 | format.write(counter, buf, counterIndex);
144 | }
145 | }
146 |
147 | return buf;
148 | }
149 |
150 | abstract protected boolean isZero(V value);
151 | abstract protected V[] createValueArray(int size);
152 | }
153 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/Utils.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util;
2 |
3 | import java.io.DataInputStream;
4 | import java.io.IOException;
5 | import java.net.Inet4Address;
6 | import java.net.InetAddress;
7 | import java.net.NetworkInterface;
8 | import java.net.SocketException;
9 | import java.util.Date;
10 | import java.util.Enumeration;
11 |
12 | /**
13 | * Created by spark on 7/2/14.
14 | */
15 | public class Utils {
16 |
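// Returns {ip, hostname} of an IPv4 interface, preferring one whose address
// starts with networkPrefix; without a prefix, loopback (127.*) is skipped.
// If several interfaces match, the last one enumerated wins.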
17 | public static String[] getNetworkAddress(String networkPrefix) throws SocketException {
18 | String[] addr = new String[2];
19 |
20 | Enumeration allNetInterfaces = NetworkInterface.getNetworkInterfaces();
21 | InetAddress ip = null;
22 | while (allNetInterfaces.hasMoreElements())
23 | {
24 | NetworkInterface netInterface = (NetworkInterface) allNetInterfaces.nextElement();
25 | Enumeration addresses = netInterface.getInetAddresses();
26 | while (addresses.hasMoreElements())
27 | {
28 | ip = (InetAddress) addresses.nextElement();
29 | if (ip != null && ip instanceof Inet4Address)
30 | {
31 | if (networkPrefix != null) {
32 | if (ip.getHostAddress().startsWith(networkPrefix)) {
33 | addr[0] = ip.getHostAddress();
34 | addr[1] = ip.getHostName();
35 | System.out.println("Server address = " + addr[0] + ", " + addr[1]);
36 | }
37 | }
38 | else {
39 | if (!ip.getHostAddress().startsWith("127")) {
40 | addr[0] = ip.getHostAddress();
41 | addr[1] = ip.getHostName();
42 | System.out.println("Server address = " + addr[0] + ", " + addr[1]);
43 | }
44 | }
45 | }
46 | }
47 | }
48 |
49 | return addr;
50 | }
51 |
52 | public static void waitUntil(DataInputStream is, int size) throws IOException {
53 | while(is.available() < size) { try { Thread.sleep(1); } catch (Exception e){}}
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/store/DoubleArrayStore.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util.store;
2 |
3 | import com.intel.distml.util.*;
4 |
5 | import java.io.DataInputStream;
6 | import java.io.DataOutputStream;
7 | import java.io.IOException;
8 | import java.util.HashMap;
9 | import java.util.Iterator;
10 | import java.util.Map;
11 |
12 | /**
13 | * Created by yunlong on 12/8/15.
14 | */
15 | public class DoubleArrayStore extends DataStore {
16 |
17 | public static final int VALUE_SIZE = 8;
18 |
19 | transient KeyCollection localRows;
20 | transient double[] localData;
21 |
22 | public KeyCollection rows() {
23 | return localRows;
24 | }
25 |
26 | public int rowSize() {
27 | return 1;
28 | }
29 |
30 | public void init(KeyCollection keys) {
31 | this.localRows = keys;
32 | localData = new double[(int)keys.size()];
33 | for (int i = 0; i < keys.size(); i++)
34 | localData[i] = 0.0;
35 | }
36 |
37 | public int indexOf(long key) {
38 | if (localRows instanceof KeyRange) {
39 | return (int) (key - ((KeyRange)localRows).firstKey);
40 | }
41 | else if (localRows instanceof KeyHash) {
42 | KeyHash hash = (KeyHash) localRows;
43 | return (int) ((key - hash.minKey) % hash.hashQuato);
44 | }
45 |
46 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
47 | }
48 |
49 | public long keyOf(int index) {
50 | if (localRows instanceof KeyRange) {
51 | return ((KeyRange)localRows).firstKey + index;
52 | }
53 | else if (localRows instanceof KeyHash) {
54 | KeyHash hash = (KeyHash) localRows;
55 | return hash.minKey + index * hash.hashQuato;
56 | }
57 |
58 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
59 | }
60 |
61 | @Override
62 | public void writeAll(DataOutputStream os) throws IOException {
63 | System.out.println("saving to: " + localData.length);
64 | for (int i = 0; i < localData.length; i++) {
65 | os.writeDouble(localData[i]);
66 | }
67 | }
68 |
69 | @Override
70 | public void readAll(DataInputStream is) throws IOException {
71 | for (int i = 0; i < localData.length; i++) {
72 | localData[i] = is.readDouble();
73 | }
74 | }
75 |
76 | @Override
77 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException {
78 | for (int i = fromRow; i <= toRow; i++) {
79 | os.writeDouble(localData[i]);
80 | }
81 | System.out.println("sync done");
82 | }
83 |
84 | @Override
85 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException {
86 | for (int i = fromRow; i <= toRow; i++) {
87 | localData[i] = is.readDouble();
88 | }
89 | System.out.println("sync done");
90 | }
91 |
92 | @Override
93 | public byte[] handleFetch(DataDesc format, KeyCollection rows) {
94 |
95 | KeyCollection keys = localRows.intersect(rows);
96 |
97 | int len = (int) ((format.keySize + VALUE_SIZE) * keys.size());
98 | byte[] buf = new byte[len];
99 |
100 | Iterator<Long> it = keys.iterator();
101 | int offset = 0;
102 | while(it.hasNext()) {
103 | long k = it.next();
104 |
105 | format.writeKey((Number) k, buf, offset);
106 | offset += format.keySize;
107 |
108 | double value = localData[indexOf(k)];
109 | format.writeValue(value, buf, offset);
110 | offset += VALUE_SIZE;
111 | }
112 | return buf;
113 | }
114 |
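// Pushed values are deltas: each update is accumulated into the local value.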
115 | public void handlePush(DataDesc format, byte[] data) {
116 |
117 | int offset = 0;
118 | while (offset < data.length) {
119 | long key = format.readKey(data, offset).longValue();
120 | offset += format.keySize;
121 |
122 | double update = format.readDouble(data, offset);
123 | offset += VALUE_SIZE;
124 |
125 | localData[indexOf(key)] += update;
126 | }
127 | }
128 |
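// Usage sketch: Iter it = store.iter(); while (it.next()) { use it.key() / it.value(); }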
129 | public Iter iter() {
130 | return new Iter();
131 | }
132 |
133 | public class Iter {
134 |
135 | int p;
136 |
137 | public Iter() {
138 | p = -1;
139 | }
140 |
141 | public boolean hasNext() {
142 | return p < localData.length - 1;
143 | }
144 |
145 | public long key() {
146 | return keyOf(p);
147 | }
148 |
149 | public double value() {
150 | return localData[p];
151 | }
152 |
153 | public boolean next() {
154 | p++;
155 | return p < localData.length;
156 | }
157 | }
158 |
159 | }
160 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/store/DoubleMatrixStore.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util.store;
2 |
3 | import com.intel.distml.util.*;
4 |
5 | import java.io.DataInputStream;
6 | import java.io.DataOutputStream;
7 | import java.io.IOException;
8 | import java.util.Iterator;
9 | import java.util.Random;
10 |
11 | /**
12 | * Created by jimmy on 16-1-4.
13 | */
14 | public class DoubleMatrixStore extends DataStore {
15 | public static final int VALUE_SIZE = 8;
16 |
17 | transient KeyCollection localRows;
18 | transient int rowSize;
19 | transient double[][] localData;
20 |
21 | public KeyCollection rows() {
22 | return localRows;
23 | }
24 | public int rowSize() {
25 | return rowSize;
26 | }
27 |
28 | public void init(KeyCollection rows, int cols) {
29 | this.localRows = rows;
30 | this.rowSize = cols;
31 |
32 | localData = new double[(int)localRows.size()][rowSize];
33 |
34 | Runtime r = Runtime.getRuntime();
35 | System.out.println("memory: " + r.freeMemory() + ", " + r.totalMemory() + ", needed: " + localRows.size() * cols);
36 | for (int i = 0; i < localRows.size(); i++)
37 | for (int j = 0; j < rowSize; j++)
38 | localData[i][j] = 0.0;
39 |
40 | }
41 |
42 | @Override
43 | public void writeAll(DataOutputStream os) throws IOException {
44 | for (int i = 0; i < localData.length; i++) {
45 | for (int j = 0; j < rowSize; j++) {
46 | os.writeDouble(localData[i][j]);
47 | }
48 | }
49 | }
50 |
51 | @Override
52 | public void readAll(DataInputStream is) throws IOException {
53 | for (int i = 0; i < localData.length; i++) {
54 | for (int j = 0; j < rowSize; j++) {
55 | localData[i][j] = is.readDouble();
56 | }
57 | }
58 | }
59 |
60 | @Override
61 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException {
62 | for (int i = fromRow; i <= toRow; i++) {
63 | for (int j = 0; j < rowSize; j++) {
64 | os.writeDouble(localData[i][j]);
65 | }
66 | }
67 | }
68 |
69 | @Override
70 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException {
71 | // use the column count (the rowSize field); localRows.size() is the row count
72 | for (int i = fromRow; i <= toRow; i++) {
73 | for (int j = 0; j < rowSize; j++) {
74 | localData[i][j] = is.readDouble();
75 | }
76 | }
77 | }
78 |
79 | @Override
80 | public byte[] handleFetch(DataDesc format, KeyCollection rows) {
81 |
82 | KeyCollection keys = localRows.intersect(rows);
83 | byte[] buf;
84 | if (format.denseColumn) {
85 | int keySpace = (int) (format.keySize * keys.size());
86 | int valueSpace = (int) (VALUE_SIZE * keys.size() * localData[0].length);
87 | buf = new byte[keySpace + valueSpace];
88 | }
89 | else {
90 | int nzcount = 0;
91 | Iterator<Long> it = keys.iterator();
92 | while (it.hasNext()) {
93 | long k = it.next();
94 | double[] values = localData[indexOf(k)];
95 | for (int i = 0; i < values.length; i++) {
96 | if (values[i] != 0.0) {
97 | nzcount++;
98 | }
99 | }
100 | }
101 | int len = (VALUE_SIZE + 4) * nzcount;
102 | buf = new byte[format.keySize * (int)keys.size() + len];
103 | }
104 |
105 | Iterator<Long> it = keys.iterator();
106 | int offset = 0;
107 | while(it.hasNext()) {
108 | long k = it.next();
109 | format.writeKey((Number)k, buf, offset);
110 | offset += format.keySize;
111 |
112 | double[] values = localData[indexOf(k)];
113 | if (format.denseColumn) {
114 | for (int i = 0; i < values.length; i++) {
115 | format.writeValue(values[i], buf, offset);
116 | offset += VALUE_SIZE;
117 | }
118 | }
119 | else {
120 | int counterIndex = offset;
121 | offset += 4;
122 |
123 | int counter = 0;
124 | for (int i = 0; i < values.length; i++) {
125 | if (values[i] != 0) {
126 | format.write(i, buf, offset);
127 | offset += 4;
128 | format.write(values[i], buf, offset);
129 | offset += VALUE_SIZE;
130 | counter++; // count only the non-zero entries actually written
131 | }
132 | }
133 | format.write(counter, buf, counterIndex);
135 | }
136 | }
137 |
138 | return buf;
139 | }
140 |
141 | public int indexOf(long key) {
142 | if (localRows instanceof KeyRange) {
143 | return (int) (key - ((KeyRange)localRows).firstKey);
144 | }
145 | else if (localRows instanceof KeyHash) {
146 | KeyHash hash = (KeyHash) localRows;
147 | return (int) ((key - hash.minKey) % hash.hashQuato);
148 | }
149 |
150 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
151 | }
152 |
153 | public void handlePush(DataDesc format, byte[] data) {
154 |
155 | int offset = 0;
156 | while (offset < data.length) {
157 | long key = format.readKey(data, offset).longValue();
158 | offset += format.keySize;
159 | offset = updateRow(key, data, offset, format);
160 | }
161 | }
162 |
163 | private int updateRow(long key, byte[] data, int start, DataDesc format) {
164 | assert(localRows.contains(key));
165 |
166 | int index = indexOf(key);
167 | double[] row = localData[index];
168 | int offset = start;
169 | if (format.denseColumn) {
170 | for (int i = 0; i < row.length; i++) {
171 | double update = format.readDouble(data, offset);
172 | row[i] += update;
173 | offset += VALUE_SIZE;
174 | }
175 | }
176 | else {
177 | int count = format.readInt(data, offset);
178 | offset += 4;
179 | for (int i = 0; i < count; i++) {
180 | int col = format.readInt(data, offset);
181 | offset += 4;
182 | assert(col < row.length);
183 |
184 | double update = format.readDouble(data, offset);
185 | row[col] += update;
186 | offset += VALUE_SIZE;
187 | }
188 | }
189 |
190 | return offset;
191 | }
192 | public void rand() {
193 | long seed = 1L;
194 | Random random = new Random(seed);
195 | int cols = this.localData[0].length;
196 | for (int i = 0; i < this.localRows.size(); i++) {
197 | double sum = 0.0;
198 | for (int j = 0; j < cols; j++) {
199 | localData[i][j] = Math.abs(random.nextGaussian());
200 | sum += localData[i][j] * localData[i][j];
201 | }
202 | sum = Math.sqrt(sum);
203 | for (int j = 0; j < cols; j++)
204 | localData[i][j] /= sum; // normalize each row to unit L2 norm
205 | }
206 | }
207 | }
208 |
--------------------------------------------------------------------------------
/src/main/java/com/intel/distml/util/store/FloatArrayStore.java:
--------------------------------------------------------------------------------
95 | Iterator<Long> it = keys.iterator();
96 | int offset = 0;
97 | while(it.hasNext()) {
98 | long k = it.next();
99 |
100 | format.writeKey((Number) k, buf, offset);
101 | offset += format.keySize;
102 |
103 | float value = localData[indexOf(k)];
104 | format.writeValue(value, buf, offset);
105 | offset += VALUE_SIZE;
106 | }
107 | return buf;
108 | }
109 |
110 | public void handlePush(DataDesc format, byte[] data) {
111 |
112 | int offset = 0;
113 | while (offset < data.length) {
114 | long key = format.readKey(data, offset).longValue();
115 | offset += format.keySize;
116 |
117 | float update = format.readFloat(data, offset);
118 | offset += VALUE_SIZE;
119 |
120 | localData[indexOf(key)] += update;
121 | }
122 | }
123 |
124 | public Iter iter() {
125 | return new Iter();
126 | }
127 |
128 | public class Iter {
129 |
130 | int p;
131 |
132 | public Iter() {
133 | p = -1;
134 | }
135 |
136 | public boolean hasNext() {
137 | return p < localData.length - 1;
138 | }
139 |
140 | public long key() {
141 | return keyOf(p);
142 | }
143 |
144 | public float value() {
145 | return localData[p];
146 | }
147 |
148 | public boolean next() {
149 | p++;
150 | return p < localData.length;
151 | }
152 | }
153 |
154 | }
155 |
--------------------------------------------------------------------------------
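
A minimal usage sketch for FloatArrayStore's cursor-style Iter, written in Scala as the repository's examples are. The opening lines of FloatArrayStore.java are missing from this dump, so the init(KeyCollection) call is assumed by analogy with IntArrayStore, and the KeyRange(first, last) constructor is likewise an assumption:

    import com.intel.distml.util.KeyRange
    import com.intel.distml.util.store.FloatArrayStore

    object FloatArrayScan {
      def main(args: Array[String]): Unit = {
        val store = new FloatArrayStore
        store.init(new KeyRange(0, 9)) // assumed signature: 10 local keys, zero-initialized

        // Iter is a cursor: next() advances and reports whether a slot remains.
        val it = store.iter()
        while (it.next()) {
          println(s"key=${it.key()} value=${it.value()}")
        }
      }
    }
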
/src/main/java/com/intel/distml/util/store/FloatMatrixStore.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util.store;
2 |
3 | import com.intel.distml.util.*;
4 |
5 | import java.io.DataInputStream;
6 | import java.io.DataOutputStream;
7 | import java.io.IOException;
8 | import java.util.Iterator;
9 | import java.util.Random;
10 |
11 | /**
12 | * Created by yunlong on 1/3/16.
13 | */
14 | public class FloatMatrixStore extends DataStore {
15 | public static final int VALUE_SIZE = 4;
16 |
17 | transient KeyCollection localRows;
18 | transient int rowSize;
19 | transient float[][] localData;
20 |
21 | public KeyCollection rows() {
22 | return localRows;
23 | }
24 | public int rowSize() {
25 | return rowSize;
26 | }
27 |
28 | public void init(KeyCollection keys, int cols) {
29 | this.localRows = keys;
30 | this.rowSize = cols;
31 |
32 | localData = new float[(int)localRows.size()][rowSize];
33 |
34 | for (int i = 0; i < localRows.size(); i++)
35 | for (int j = 0; j < rowSize; j++)
36 | localData[i][j] = 0.0f;
37 | }
38 |
39 | public void rand() {
40 | System.out.println("init with random values");
41 |
42 | int rows = (int) localRows.size();
43 |
44 | Random r = new Random();
45 | for (int i = 0; i < rows; i++) {
46 | for (int j = 0; j < rowSize; j++) {
47 | int a = r.nextInt(100);
48 | localData[i][j] = (a / 100.0f - 0.5f) / rowSize;
49 | }
50 | }
51 | }
52 |
53 | public void set(String value) {
54 | setValue(Float.parseFloat(value));
55 | }
56 |
57 | public void zero(String value) {
58 | setValue(0f);
59 | }
60 |
61 | private void setValue(float v) {
62 | System.out.println("init with value: " + v);
63 |
64 | int rows = (int) localRows.size();
65 |
66 | for (int i = 0; i < rows; i++) {
67 | for (int j = 0; j < rowSize; j++) {
68 | localData[i][j] = v;
69 | }
70 | }
71 | }
72 |
73 | @Override
74 | public void writeAll(DataOutputStream os) throws IOException {
75 |
76 | for (int i = 0; i < localData.length; i++) {
77 | for (int j = 0; j < rowSize; j++) {
78 | os.writeFloat(localData[i][j]);
79 | }
80 | }
81 | }
82 |
83 | @Override
84 | public void readAll(DataInputStream is) throws IOException {
85 |
86 | for (int i = 0; i < localData.length; i++) {
87 | for (int j = 0; j < rowSize; j++) {
88 | localData[i][j] = is.readFloat();
89 | }
90 | }
91 | }
92 |
93 | @Override
94 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException {
95 | for (int i = fromRow; i <= toRow; i++) {
96 | for (int j = 0; j < rowSize; j++) {
97 | os.writeFloat(localData[i][j]);
98 | }
99 | }
100 | }
101 |
102 | @Override
103 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException {
104 | // rowSize (the field) is the column count; each row carries exactly rowSize floats
105 | for (int i = fromRow; i <= toRow; i++) {
106 | for (int j = 0; j < rowSize; j++) {
107 | localData[i][j] = is.readFloat();
108 | }
109 | }
110 | }
111 |
112 | @Override
113 | public byte[] handleFetch(DataDesc format, KeyCollection rows) {
114 |
115 | System.out.println("handle fetch request: " + rows);
116 | KeyCollection keys = localRows.intersect(rows);
117 | byte[] buf;
118 | if (format.denseColumn) {
119 | int keySpace = (int) (format.keySize * keys.size());
120 | int valueSpace = (int) (VALUE_SIZE * keys.size() * localData[0].length);
121 | buf = new byte[keySpace + valueSpace];
122 | System.out.println("buf size: " + buf.length);
123 | }
124 | else {
125 | int nzcount = 0;
126 | Iterator<Long> it = keys.iterator();
127 | while (it.hasNext()) {
128 | long k = it.next();
129 | float[] values = localData[indexOf(k)];
130 | for (int i = 0; i < values.length; i++) {
131 | if (values[i] != 0.0) {
132 | nzcount++;
133 | }
134 | }
135 | }
136 | int len = (VALUE_SIZE + 4) * nzcount;
137 | buf = new byte[(format.keySize + 4) * (int)keys.size() + len]; // +4 per key for its pair counter
138 | }
139 |
140 | Iterator<Long> it = keys.iterator();
141 | int offset = 0;
142 | while(it.hasNext()) {
143 | long k = it.next();
144 | format.writeKey((Number)k, buf, offset);
145 | offset += format.keySize;
146 |
147 | float[] values = localData[indexOf(k)];
148 | if (format.denseColumn) {
149 | for (int i = 0; i < values.length; i++) {
150 | format.writeValue(values[i], buf, offset);
151 | offset += VALUE_SIZE;
152 | }
153 | }
154 | else {
155 | int counterIndex = offset;
156 | offset += 4;
157 |
158 | int counter = 0;
159 | for (int i = 0; i < values.length; i++) {
160 | if (values[i] != 0) {
161 | format.write(i, buf, offset);
162 | offset += 4;
163 | format.write(values[i], buf, offset);
164 | offset += VALUE_SIZE;
165 | counter++; // count only the non-zero entries actually written
166 | }
167 | }
168 | // record how many (col, value) pairs follow this key
169 | format.write(counter, buf, counterIndex);
170 | }
171 | }
172 |
173 | return buf;
174 | }
175 |
176 | public int indexOf(long key) {
177 | if (localRows instanceof KeyRange) {
178 | return (int) (key - ((KeyRange)localRows).firstKey);
179 | }
180 | else if (localRows instanceof KeyHash) {
181 | KeyHash hash = (KeyHash) localRows;
182 | return (int) ((key - hash.minKey) % hash.hashQuato);
183 | }
184 |
185 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
186 | }
187 |
188 | public long keyOf(int index) {
189 | if (localRows instanceof KeyRange) {
190 | return ((KeyRange)localRows).firstKey + index;
191 | }
192 | else if (localRows instanceof KeyHash) {
193 | KeyHash hash = (KeyHash) localRows;
194 | return hash.minKey + index * hash.hashQuato;
195 | }
196 |
197 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
198 | }
199 |
200 | public void handlePush(DataDesc format, byte[] data) {
201 |
202 | int offset = 0;
203 | while (offset < data.length) {
204 | long key = format.readKey(data, offset).longValue();
205 | offset += format.keySize;
206 | offset = updateRow(key, data, offset, format);
207 | }
208 | }
209 |
210 | private int updateRow(long key, byte[] data, int start, DataDesc format) {
211 | assert(localRows.contains(key));
212 |
213 | int index = indexOf(key);
214 | float[] row = localData[index];
215 | int offset = start;
216 | if (format.denseColumn) {
217 | for (int i = 0; i < row.length; i++) {
218 | float update = format.readFloat(data, offset);
219 | row[i] += update;
220 | offset += VALUE_SIZE;
221 | }
222 | }
223 | else {
224 | int count = format.readInt(data, offset);
225 | offset += 4;
226 | for (int i = 0; i < count; i++) {
227 | int col = format.readInt(data, offset);
228 | offset += 4;
229 | assert(col < row.length);
230 |
231 | float update = format.readFloat(data, offset);
232 | row[col] += update;
233 | offset += VALUE_SIZE;
234 | }
235 | }
236 |
237 | return offset;
238 | }
239 |
240 |
241 | public Iter iter() {
242 | return new Iter();
243 | }
244 |
245 | public class Iter {
246 |
247 | int p;
248 |
249 | public Iter() {
250 | p = -1;
251 | }
252 |
253 | public boolean hasNext() {
254 | return p < localData.length - 1;
255 | }
256 |
257 | public long key() {
258 | return keyOf(p);
259 | }
260 |
261 | public float[] value() {
262 | return localData[p];
263 | }
264 |
265 | public boolean next() {
266 | p++;
267 | return p < localData.length;
268 | }
269 | }
270 | }
271 |
--------------------------------------------------------------------------------
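
A usage sketch for FloatMatrixStore, assuming the same hypothetical KeyRange(first, last) constructor as above; init(rows, cols), rand() and iter() are taken directly from the listing:

    import com.intel.distml.util.KeyRange
    import com.intel.distml.util.store.FloatMatrixStore

    object FloatMatrixDemo {
      def main(args: Array[String]): Unit = {
        val store = new FloatMatrixStore
        store.init(new KeyRange(0, 9), 4) // 10 local rows of 4 columns, all zero
        store.rand()                      // small uniform values scaled by 1/rowSize

        val it = store.iter()
        while (it.next()) {
          println(s"row ${it.key()}: ${it.value().mkString(", ")}")
        }
      }
    }
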
/src/main/java/com/intel/distml/util/store/IntArrayStore.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util.store;
2 |
3 | import com.intel.distml.util.*;
4 |
5 | import java.io.DataInputStream;
6 | import java.io.DataOutputStream;
7 | import java.io.IOException;
8 | import java.util.HashMap;
9 | import java.util.Iterator;
10 | import java.util.Map;
11 |
12 | /**
13 | * Created by yunlong on 12/8/15.
14 | */
15 | public class IntArrayStore extends DataStore {
16 |
17 | public static final int VALUE_SIZE = 4;
18 |
19 | transient KeyCollection localRows;
20 | transient int[] localData;
21 |
22 | public KeyCollection rows() {
23 | return localRows;
24 | }
25 | public int rowSize() {
26 | return 1;
27 | }
28 |
29 | public void init(KeyCollection keys) {
30 | this.localRows = keys;
31 | localData = new int[(int)keys.size()];
32 |
33 | for (int i = 0; i < keys.size(); i++)
34 | localData[i] = 0;
35 | }
36 |
37 | public int indexOf(long key) {
38 | if (localRows instanceof KeyRange) {
39 | return (int) (key - ((KeyRange)localRows).firstKey);
40 | }
41 | else if (localRows instanceof KeyHash) {
42 | KeyHash hash = (KeyHash) localRows;
43 | return (int) ((key - hash.minKey) % hash.hashQuato);
44 | }
45 |
46 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
47 | }
48 |
49 | @Override
50 | public void writeAll(DataOutputStream os) throws IOException {
51 | for (int i = 0; i < localData.length; i++) {
52 | os.writeInt(localData[i]);
53 | }
54 | }
55 |
56 | @Override
57 | public void readAll(DataInputStream is) throws IOException {
58 | for (int i = 0; i < localData.length; i++) {
59 | localData[i] = is.readInt();
60 | }
61 | }
62 |
63 | @Override
64 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException {
65 | for (int i = fromRow; i <= toRow; i++) {
66 | os.writeInt(localData[i]);
67 | }
68 | }
69 |
70 | @Override
71 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException {
72 | for (int i = fromRow; i <= toRow; i++) {
73 | localData[i] = is.readInt();
74 | }
75 | }
76 |
77 | @Override
78 | public byte[] handleFetch(DataDesc format, KeyCollection rows) {
79 |
80 | KeyCollection keys = localRows.intersect(rows);
81 | int len = (int) ((format.keySize + VALUE_SIZE) * keys.size());
82 | byte[] buf = new byte[len];
83 |
84 | Iterator<Long> it = keys.iterator();
85 | int offset = 0;
86 | while(it.hasNext()) {
87 | long k = it.next();
88 | format.writeKey((Number) k, buf, offset);
89 | offset += format.keySize;
90 | format.writeValue(localData[indexOf(k)], buf, offset);
91 | offset += VALUE_SIZE;
92 | }
93 |
94 | return buf;
95 | }
96 |
97 | public void handlePush(DataDesc format, byte[] data) {
98 |
99 | int offset = 0;
100 | while (offset < data.length) {
101 | long key = format.readKey(data, offset).longValue();
102 | offset += format.keySize;
103 |
104 | int update = format.readInt(data, offset);
105 | offset += VALUE_SIZE;
106 |
107 | int index = indexOf(key); // resolve the local slot once
108 | localData[index] += update;
109 | if (localData[index] < 0) {
110 | throw new IllegalStateException("invalid k counter: " + key + ", " + localData[index]);
111 | }
112 | }
113 | }
114 |
115 | }
116 |
--------------------------------------------------------------------------------
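
The indexOf translation above is the heart of how a parameter server addresses its shard: for a KeyRange partition, the local slot is simply the key's offset from the range's first key, and keyOf is its inverse. A standalone sketch of that arithmetic (the values are illustrative):

    object KeyRangeArithmetic {
      def main(args: Array[String]): Unit = {
        val firstKey = 1000L // first global row id held by this server

        def indexOf(key: Long): Int = (key - firstKey).toInt // global key -> local slot
        def keyOf(index: Int): Long = firstKey + index       // local slot -> global key

        assert(indexOf(keyOf(42)) == 42)
        println(s"global key 1005 -> local slot ${indexOf(1005L)}") // prints 5
      }
    }
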
/src/main/java/com/intel/distml/util/store/IntMatrixStore.java:
--------------------------------------------------------------------------------
1 | package com.intel.distml.util.store;
2 |
3 | import com.intel.distml.util.*;
4 |
5 | import java.io.DataInputStream;
6 | import java.io.DataOutputStream;
7 | import java.io.IOException;
8 | import java.util.HashMap;
9 | import java.util.Iterator;
10 | import java.util.Map;
11 |
12 | /**
13 | * Created by yunlong on 12/8/15.
14 | */
15 | public class IntMatrixStore extends DataStore {
16 |
17 | public static final int VALUE_SIZE = 4;
18 |
19 | transient KeyCollection localRows;
20 | transient int rowSize;
21 | transient int[][] localData;
22 |
23 | public KeyCollection rows() {
24 | return localRows;
25 | }
26 | public int rowSize() {
27 | return rowSize;
28 | }
29 |
30 | public void init(KeyCollection keys, int cols) {
31 | this.localRows = keys;
32 | this.rowSize = cols;
33 | localData = new int[(int)keys.size()][cols];
34 |
35 | Runtime r = Runtime.getRuntime();
36 | System.out.println("memory: " + r.freeMemory() + ", " + r.totalMemory() + ", needed: " + localRows.size() * rowSize);
37 | for (int i = 0; i < localRows.size(); i++)
38 | for (int j = 0; j < rowSize; j++)
39 | localData[i][j] = 0;
40 |
41 | }
42 |
43 | @Override
44 | public void writeAll(DataOutputStream os) throws IOException {
45 | for (int i = 0; i < localData.length; i++) {
46 | for (int j = 0; j < rowSize; j++) {
47 | os.writeInt(localData[i][j]);
48 | }
49 | }
50 | }
51 |
52 | @Override
53 | public void readAll(DataInputStream is) throws IOException {
54 | for (int i = 0; i < localData.length; i++) {
55 | for (int j = 0; j < rowSize; j++) {
56 | localData[i][j] = is.readInt();
57 | }
58 | }
59 | }
60 |
61 | @Override
62 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException {
63 | for (int i = fromRow; i <= toRow; i++) {
64 | for (int j = 0; j < rowSize; j++) {
65 | os.writeInt(localData[i][j]);
66 | }
67 | }
68 | }
69 |
70 | @Override
71 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException {
72 | // rowSize (the field) is the column count; each row carries exactly rowSize ints
73 | for (int i = fromRow; i <= toRow; i++) {
74 | for (int j = 0; j < rowSize; j++) {
75 | localData[i][j] = is.readInt();
76 | }
77 | }
78 | }
79 |
80 | @Override
81 | public byte[] handleFetch(DataDesc format, KeyCollection rows) {
82 |
83 | KeyCollection keys = localRows.intersect(rows);
84 | byte[] buf;
85 | if (format.denseColumn) {
86 | int keySpace = (int) (format.keySize * keys.size());
87 | int valueSpace = (int) (VALUE_SIZE * keys.size() * localData[0].length);
88 | buf = new byte[keySpace + valueSpace];
89 | }
90 | else {
91 | int nzcount = 0;
92 | Iterator<Long> it = keys.iterator();
93 | while (it.hasNext()) {
94 | long k = it.next();
95 | int[] values = localData[indexOf(k)];
96 | for (int i = 0; i < values.length; i++) {
97 | if (values[i] != 0) {
98 | nzcount++;
99 | }
100 | }
101 | }
102 | int len = (VALUE_SIZE + 4) * nzcount;
103 | buf = new byte[(format.keySize + 4) * (int)keys.size() + len]; // +4 per key for its pair counter
104 | }
105 |
106 | Iterator<Long> it = keys.iterator();
107 | int offset = 0;
108 | while(it.hasNext()) {
109 | long k = it.next();
110 | format.writeKey((Number)k, buf, offset);
111 | offset += format.keySize;
112 |
113 | int[] values = localData[indexOf(k)];
114 | if (format.denseColumn) {
115 | for (int i = 0; i < values.length; i++) {
116 | format.writeValue(values[i], buf, offset);
117 | offset += VALUE_SIZE;
118 | }
119 | }
120 | else {
121 | int counterIndex = offset;
122 | offset += 4;
123 |
124 | int counter = 0;
125 | for (int i = 0; i < values.length; i++) {
126 | if (values[i] != 0) {
127 | format.write(i, buf, offset);
128 | offset += 4;
129 | format.write(values[i], buf, offset);
130 | offset += VALUE_SIZE;
131 | counter++; // count only the non-zero entries actually written
132 | }
133 | }
134 | // record how many (col, value) pairs follow this key
135 | format.write(counter, buf, counterIndex);
136 | }
137 | }
138 |
139 | return buf;
140 | }
141 |
142 | public int indexOf(long key) {
143 | if (localRows instanceof KeyRange) {
144 | return (int) (key - ((KeyRange)localRows).firstKey);
145 | }
146 | else if (localRows instanceof KeyHash) {
147 | KeyHash hash = (KeyHash) localRows;
148 | return (int) ((key - hash.minKey) % hash.hashQuato);
149 | }
150 |
151 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage");
152 | }
153 |
154 | public void handlePush(DataDesc format, byte[] data) {
155 |
156 | int offset = 0;
157 | while (offset < data.length) {
158 | long key = format.readKey(data, offset).longValue();
159 | offset += format.keySize;
160 | offset = updateRow(key, data, offset, format);
161 | }
162 | }
163 |
164 | private int updateRow(long key, byte[] data, int start, DataDesc format) {
165 | assert(localRows.contains(key));
166 |
167 | int index = indexOf(key);
168 | int[] row = localData[index];
169 | int offset = start;
170 | if (format.denseColumn) {
171 | for (int i = 0; i < row.length; i++) {
172 | int update = format.readInt(data, offset);
173 | row[i] += update;
174 | if (row[i] < 0) {
175 | throw new IllegalStateException("invalid counter: " + key + ", " + i + ", " + row[i]);
176 | }
177 | offset += 4;
178 | }
179 | }
180 | else {
181 | int count = format.readInt(data, offset);
182 | offset += 4;
183 | for (int i = 0; i < count; i++) {
184 | int col = format.readInt(data, offset);
185 | offset += 4;
186 | assert(col < row.length);
187 |
188 | int update = format.readInt(data, offset);
189 | row[col] += update;
190 | offset += 4;
191 | }
192 | }
193 |
194 | return offset;
195 | }
196 |
197 | }
198 |
--------------------------------------------------------------------------------
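
The sparse branch of updateRow above consumes records laid out as [count][col, value] * count after each row key. A hand-rolled sketch of that layout using java.nio.ByteBuffer (DistML's real encoding lives in DataDesc and may differ in key width and endianness; this is only an illustration):

    import java.nio.ByteBuffer

    object SparsePushLayout {
      def main(args: Array[String]): Unit = {
        val updates = Seq((2, 5), (7, -1)) // (column, delta) pairs for one row

        val buf = ByteBuffer.allocate(4 + updates.size * 8)
        buf.putInt(updates.size)                             // leading count
        for ((col, v) <- updates) { buf.putInt(col); buf.putInt(v) }

        buf.flip() // decode it back, mirroring updateRow's parse loop
        val count = buf.getInt()
        for (_ <- 0 until count)
          println(s"col=${buf.getInt()} delta=${buf.getInt()}")
      }
    }
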
/src/main/main.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/src/main/scala/com/intel/distml/Dict.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml
2 |
3 | import java.io.Serializable
4 |
5 | import scala.collection.immutable
6 |
7 | /**
8 | * Created by yunlong on 12/23/15.
9 | */
10 | class Dict extends Serializable {
11 |
12 | var word2id = new immutable.HashMap[String, Int]
13 | var id2word = new immutable.HashMap[Int, String]
14 |
15 | def getSize: Int = {
16 | return word2id.size
17 | }
18 |
19 | def getWord(id: Int): String = {
20 | return id2word.get(id).get
21 | }
22 |
23 | def getID(word: String): Integer = {
24 | return word2id.get(word).get
25 | }
26 |
27 | /**
28 | * check if this dictionary contains a specified word
29 | */
30 | def contains(word: String): Boolean = {
31 | return word2id.contains(word)
32 | }
33 |
34 | def contains(id: Int): Boolean = {
35 | return id2word.contains(id)
36 | }
37 |
38 | def put(word : String, id : Int): Unit = {
39 | word2id += (word -> id)
40 | id2word += (id -> word)
41 | }
42 |
43 | /**
44 | * add a word into this dictionary
45 | * return the corresponding id
46 | */
47 | def addWord(word: String): Int = {
48 | if (!contains(word)) {
49 | val id: Int = word2id.size
50 | word2id += (word -> id)
51 | id2word += (id -> word)
52 | return id
53 | }
54 | else return getID(word)
55 | }
56 | }
--------------------------------------------------------------------------------
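
Dict assigns ids in insertion order starting at 0 and keeps both directions of the mapping; a minimal example:

    import com.intel.distml.Dict

    object DictDemo {
      def main(args: Array[String]): Unit = {
        val dic = new Dict
        val id = dic.addWord("spark")  // 0
        dic.addWord("distml")          // 1
        println(dic.getSize)           // 2
        println(dic.getWord(id))       // "spark"
        println(dic.addWord("spark")) // still 0: duplicates keep their first id
      }
    }
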
/src/main/scala/com/intel/distml/clustering/AliasTable.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.clustering
2 |
3 | import java.util
4 | import java.util.Random
5 |
6 | /**
7 | * Created by yunlong on 12/19/15.
8 | */
9 | class AliasTable (K: Int) extends Serializable {
10 |
11 | var L: Array[Int] = new Array[Int](K)
12 | var H: Array[Int] = new Array[Int](K)
13 | var P: Array[Float] = new Array[Float](K)
14 |
15 | def sampleAlias(gen: Random): Int = {
16 | val s = gen.nextFloat() * K
17 | var t = s.toInt
18 | if (t >= K) t = K - 1
19 | if ((s - t) < K * P(t)) L(t) else H(t) // pick the low item w.p. K*P(t), else its alias
20 | }
21 |
22 | def init(probs : Array[(Int, Float)], mass : Float): Unit = {
23 |
24 | val m = 1.0 / K
25 | val lq = new util.LinkedList[(Int, Float)]()
26 | val hq = new util.LinkedList[(Int, Float)]()
27 |
28 | for ((k, _p) <- probs) {
29 | val p = _p / mass
30 | if (p < m) lq.add((k, p))
31 | else hq.add((k, p))
32 |
33 | }
34 |
35 | var index = 0
36 | while (!lq.isEmpty && !hq.isEmpty) {
37 | val (l, pl) = lq.removeFirst()
38 | val (h, ph) = hq.removeFirst()
39 | L(index) = l
40 | H(index) = h
41 | P(index) = pl
42 | val pd = ph - (m - pl)
43 | if (pd >= m) hq.add((h, pd.toFloat))
44 | else lq.add((h, pd.toFloat))
45 | index += 1
46 | }
47 |
48 | while (!hq.isEmpty) {
49 | val (h, ph) = hq.removeFirst()
50 | L(index) = h
51 | H(index) = h
52 | P(index) = ph
53 | index += 1
54 | }
55 |
56 | while (!lq.isEmpty) {
57 | val (l, pl) = lq.removeLast()
58 | L(index) = l
59 | H(index) = l
60 | P(index) = pl
61 | index += 1
62 | }
63 | }
64 | }
--------------------------------------------------------------------------------
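
AliasTable turns a discrete distribution into O(1) sampling: init splits the (topic, weight) pairs into equal-mass buckets, and sampleAlias picks a bucket uniformly and then the low item or its alias. A sketch with three topics of unnormalized weights 1 : 2 : 7 (mass is their sum):

    import java.util.Random
    import com.intel.distml.clustering.AliasTable

    object AliasDemo {
      def main(args: Array[String]): Unit = {
        val probs = Array((0, 1.0f), (1, 2.0f), (2, 7.0f))
        val mass  = probs.map(_._2).sum

        val table = new AliasTable(probs.length)
        table.init(probs, mass)

        val gen = new Random(42)
        val counts = new Array[Int](probs.length)
        for (_ <- 0 until 100000) counts(table.sampleAlias(gen)) += 1
        println(counts.mkString(", ")) // roughly 10000, 20000, 70000
      }
    }
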
/src/main/scala/com/intel/distml/clustering/LDAModel.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.clustering
2 |
3 | import com.intel.distml.api.Model
4 | import com.intel.distml.util.{IntMatrixWithIntKey, IntArrayWithIntKey}
5 |
6 | /**
7 | * Created by yunlong on 16-3-29.
8 | */
9 | class LDAModel(
10 | val V : Int = 0,
11 | val K: Int = 20,
12 | val alpha : Double = 0.01,
13 | val beta : Double = 0.01
14 | ) extends Model {
15 |
16 | val alpha_sum = alpha * K
17 | val beta_sum = beta * V
18 |
19 | registerMatrix("doc-topics", new IntArrayWithIntKey(K))
20 | registerMatrix("word-topics", new IntMatrixWithIntKey(V, K))
21 |
22 |
23 | }
24 |
--------------------------------------------------------------------------------
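
Constructing an LDAModel registers its two parameter matrices with the parameter servers; a small sketch (vocabulary and topic sizes are illustrative):

    import com.intel.distml.clustering.LDAModel

    object LDAModelDemo {
      def main(args: Array[String]): Unit = {
        val model = new LDAModel(V = 10000, K = 50)
        println(model.alpha_sum) // alpha * K = 0.5
        println(model.beta_sum)  // beta * V = 100.0
      }
    }
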
/src/main/scala/com/intel/distml/clustering/LDAParams.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.clustering
2 |
3 | /**
4 | * Created by yunlong on 12/23/15.
5 | */
6 | case class LDAParams(
7 | var psCount : Int = 2,
8 | var batchSize : Int = 100,
9 | var input: String = null,
10 | var k: Int = 20,
11 | var alpha : Double = 0.01,
12 | var beta : Double = 0.01,
13 | var maxIterations: Int = 10,
14 | val partitions : Int = 2,
15 | var showPlexity: Boolean = true)
16 | {
17 |
18 | var V : Int = 0
19 | var alpha_sum : Double = 0.0
20 | var beta_sum : Double = 0.0
21 |
22 | def init(V : Int): Unit = {
23 | this.V = V
24 | alpha_sum = alpha * k
25 | beta_sum = beta * V
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
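
LDAParams collects the training configuration; init(V) derives the Dirichlet normalizers once the vocabulary size is known. For example (the input path is illustrative):

    import com.intel.distml.clustering.LDAParams

    object LDAParamsDemo {
      def main(args: Array[String]): Unit = {
        val p = LDAParams(k = 50, maxIterations = 20, input = "hdfs:///corpus")
        p.init(10000)
        println(p.alpha_sum) // alpha * k = 0.5
        println(p.beta_sum)  // beta * V = 100.0
      }
    }
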
/src/main/scala/com/intel/distml/example/clustering/LDAExample.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.example.clustering
2 |
3 | import com.intel.distml.Dict
4 | import com.intel.distml.api.Model
5 | import com.intel.distml.clustering.{LightLDA, LDAParams}
6 | import com.intel.distml.platform.DistML
7 | import com.intel.distml.util.{IntMatrixWithIntKey, IntArrayWithIntKey}
8 | import org.apache.spark.broadcast.Broadcast
9 | import org.apache.spark.storage.StorageLevel
10 | import org.apache.spark.{SparkContext, SparkConf}
11 | import scopt.OptionParser
12 |
13 | import scala.collection.mutable.ListBuffer
14 |
15 | /**
16 | * Created by yunlong on 16-3-29.
17 | */
18 | object LDAExample {
19 |
20 | def normalizeString(src : String) : String = {
21 | src.replaceAll("[^A-Za-z]", " ").trim().toLowerCase()
22 | }
23 |
24 | def fromWordsToIds(bdic : Broadcast[Dict])(line : String) : Array[Int] = {
25 |
26 | val dic = bdic.value
27 |
28 | val words = line.split(" ")
29 |
30 | var wordIDs = new ListBuffer[Int]();
31 |
32 | for (w <- words) {
33 | val wn = normalizeString(w)
34 | if (dic.contains(wn)) {
35 | wordIDs.append(dic.getID(wn))
36 | }
37 | }
38 |
39 | wordIDs.toArray
40 | }
41 |
42 | def main(args: Array[String]) {
43 |
44 | val defaultParams = new LDAParams()
45 |
46 | val parser = new OptionParser[LDAParams]("LDAExample") {
47 | head("LDAExample: an example LDA app for plain text data.")
48 | opt[Int]("k")
49 | .text(s"number of topics. default: ${defaultParams.k}")
50 | .action((x, c) => c.copy(k = x))
51 | opt[Int]("batchSize")
52 | .text(s"number of samples used in one computing. default: ${defaultParams.batchSize}")
53 | .action((x, c) => c.copy(batchSize = x))
54 | opt[Int]("psCount")
55 | .text(s"number of parameter servers. default: ${defaultParams.psCount}")
56 | .action((x, c) => c.copy(psCount = x))
57 | opt[Double]("alpha")
58 | .text(s"super parameter for sampling. default: ${defaultParams.alpha}")
59 | .action((x, c) => c.copy(alpha = x))
60 | opt[Double]("beta")
61 | .text(s"super parameter for sampling. default: ${defaultParams.beta}")
62 | .action((x, c) => c.copy(beta = x))
63 | opt[Int]("maxIterations")
64 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
65 | .action((x, c) => c.copy(maxIterations = x))
66 | opt[Int]("partitions")
67 | .text(s"number of partitions to train the model. default: ${defaultParams.partitions}")
68 | .action((x, c) => c.copy(partitions = x))
69 | opt[Boolean]("showPlexity")
70 | .text(s"Show plexity after each iteration." +
71 | s" default: ${defaultParams.showPlexity}")
72 | .action((x, c) => c.copy(showPlexity = x))
73 | arg[String]("...")
74 | .text("input paths (directories) to plain text corpora.")
75 | .unbounded()
76 | .required()
77 | .action((x, c) => c.copy(input = x))
78 | }
79 | parser.parse(args, defaultParams).map { params =>
80 | run(params)
81 | }.getOrElse {
82 | parser.showUsageAsError
83 | sys.exit(1)
84 | }
85 | }
86 |
87 | private def run(p: LDAParams) {
88 | val conf = new SparkConf().setAppName(s"DistML.Example.LDA")
89 | conf.set("spark.driver.maxResultSize", "5g")
90 |
91 | val sc = new SparkContext(conf)
92 | Thread.sleep(3000)
93 |
94 | var rawLines = sc.textFile(p.input).filter(s => s.trim().length > 0)
95 |
96 | var dic = new Dict()
97 |
98 | val words = rawLines.flatMap(line => line.split(" ")).map(normalizeString).filter(s => s.trim().length > 0).distinct().collect()
99 | words.foreach(x => dic.addWord(x))
100 |
101 | p.init(dic.getSize)
102 |
103 | val bdic = sc.broadcast(dic)
104 | var data = rawLines.map(fromWordsToIds(bdic)).map(ids => {
105 | // println("=============== random initializing start ==================")
106 | val topics = new ListBuffer[(Int, Int)]
107 | ids.foreach(x => topics.append((x, Math.floor(Math.random() * p.k).toInt)))
108 |
109 | val doctopic = new Array[Int](p.k)
110 | topics.foreach(x => doctopic(x._2) = doctopic(x._2) + 1)
111 |
112 | // println("=============== random initializing done ==================")
113 | (doctopic, topics.toArray)
114 | }).repartition(p.partitions).persist(StorageLevel.MEMORY_AND_DISK)
115 |
116 | var statistics = data.map(d => (1, d._2.length)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
117 |
118 | println("=============== Corpus Info Begin ================")
119 | println("Vocaulary: " + dic.getSize)
120 | println("Docs: " + statistics._1)
121 | println("Tokens: " + statistics._2)
122 | println("Topics: " + p.k)
123 | println("=============== Corpus Info End ================")
124 |
125 |
126 | val dm = LightLDA.train(sc, data, dic.getSize, p)
127 |
128 | dm.recycle()
129 | sc.stop
130 | }
131 | }
132 |
--------------------------------------------------------------------------------
/src/main/scala/com/intel/distml/example/feature/MllibWord2Vec.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.example.feature
2 |
3 | import org.apache.spark._
4 | import org.apache.spark.mllib.feature.Word2Vec
5 |
6 | /**
7 | * Created by yunlong on 1/29/16.
8 | */
9 | object MllibWord2Vec {
10 |
11 | def main(args: Array[String]) {
12 |
13 | val p = Integer.parseInt(args(1))
14 | println("partitions: " + p)
15 |
16 | val conf = new SparkConf().setAppName("Word2Vec")
17 | val sc = new SparkContext(conf)
18 |
19 | val input = sc.textFile(args(0)).map(line => line.split(" ").toSeq)
20 |
21 | val word2vec = new Word2Vec()
22 | word2vec.setNumPartitions(p)
23 |
24 | val model = word2vec.fit(input)
25 | println("vocab size: " + model.getVectors.size)
26 | val synonyms = model.findSynonyms("black", 40)
27 |
28 | for ((synonym, cosineSimilarity) <- synonyms) {
29 | println(s"$synonym $cosineSimilarity")
30 | }
31 |
32 | sc.stop()
33 |
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/src/main/scala/com/intel/distml/example/feature/Word2VecExample.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.example.feature
2 |
3 | import com.intel.distml.Dict
4 | import com.intel.distml.api.{Session, Model}
5 | import com.intel.distml.feature.Word2Vec
6 | import org.apache.spark.storage.StorageLevel
7 | import org.apache.spark.{SparkContext, SparkConf}
8 | import org.apache.spark.broadcast.Broadcast
9 | import scopt.OptionParser
10 |
11 | import scala.collection.mutable
12 | import scala.collection.mutable.ListBuffer
13 |
14 | /**
15 | * Created by yunlong on 16-3-23.
16 | */
17 | object Word2VecExample {
18 | private case class Params(
19 | psCount: Int = 1,
20 | numPartitions : Int = 1,
21 | input: String = null,
22 | cbow: Boolean = true,
23 | alpha: Double = 0.0f,
24 | alphaFactor: Double = 10.0f,
25 | window : Int = 7,
26 | batchSize : Int = 100,
27 | vectorSize : Int = 10,
28 | minFreq : Int = 50,
29 | maxIterations: Int = 100 ) {
30 |
31 | def show(): Unit = {
32 | println("=========== params =============")
33 | println("psCount: " + psCount)
34 | println("numPartitions: " + psCount)
35 | println("input: " + input)
36 | println("cbow: " + cbow)
37 | println("alpha: " + alpha)
38 | println("alphaFactor: " + alphaFactor)
39 | println("window: " + window)
40 | println("batchSize: " + batchSize)
41 | println("vectorSize: " + vectorSize)
42 | println("minFreq: " + minFreq)
43 | println("maxIterations: " + maxIterations)
44 | println("=========== params =============")
45 | }
46 | }
47 |
48 | def main(args: Array[String]) {
49 |
50 | val defaultParams = Params()
51 |
52 | val parser = new OptionParser[Params]("Word2VecExample") {
53 | head("Word2VecExample: an example Word2Vec app for plain text data.")
54 | opt[Int]("psCount")
55 | .text(s"number of parameter servers. default: ${defaultParams.psCount}")
56 | .action((x, c) => c.copy(psCount = x))
57 | opt[Int]("numPartitions")
58 | .text(s"number of partitions for the training data. default: ${defaultParams.numPartitions}")
59 | .action((x, c) => c.copy(numPartitions = x))
60 | opt[Double]("alpha")
61 | .text(s"initiali learning rate. default: ${defaultParams.alpha}")
62 | .action((x, c) => c.copy(alpha = x))
63 | opt[Double]("alphaFactor")
64 | .text(s"factor to decrease adaptive learning rate. default: ${defaultParams.alphaFactor}")
65 | .action((x, c) => c.copy(alphaFactor = x))
66 | opt[Int]("batchSize")
67 | .text(s"number of samples computed in a round. default: ${defaultParams.batchSize}")
68 | .action((x, c) => c.copy(batchSize = x))
69 | opt[Int]("vectorSize")
70 | .text(s"vector size for a single word: ${defaultParams.vectorSize}")
71 | .action((x, c) => c.copy(vectorSize = x))
72 | opt[Int]("maxIterations")
73 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
74 | .action((x, c) => c.copy(maxIterations = x))
75 | opt[Int]("minFreq")
76 | .text(s"minimum word frequency. default: ${defaultParams.minFreq}")
77 | .action((x, c) => c.copy(minFreq = x))
78 | opt[Boolean]("cbow")
79 | .text(s"true if use cbow, false if use skipgram. default: ${defaultParams.cbow}")
80 | .action((x, c) => c.copy(cbow = x))
81 | arg[String]("...")
82 | .text("input paths (directories) to plain text corpora." +
83 | " Each text file line should hold 1 document.")
84 | .unbounded()
85 | .required()
86 | .action((x, c) => c.copy(input = x))
87 | }
88 | parser.parse(args, defaultParams).map { params =>
89 | run(params)
90 | }.getOrElse {
91 | parser.showUsageAsError
92 | sys.exit(1)
93 | }
94 | }
95 |
96 | def normalizeString(src : String) : String = {
97 | //src.filter( c => ((c >= '0') && (c <= '9')) )
98 | src.filter( c => (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) || (c == ' '))).toLowerCase
99 | }
100 |
101 | def toWords(line : String) : Array[String] = {
102 | val words = line.split(" ").map(normalizeString)
103 | val list = new mutable.MutableList[String]()
104 | for (w <- words) {
105 | if (w.length > 0)
106 | list += w
107 | }
108 |
109 | list.toArray
110 | }
111 |
112 | def fromWordsToIds(bdic : Broadcast[Dict])(words : Array[String]) : Array[Int] = {
113 |
114 | val dic = bdic.value
115 |
116 | var wordIDs = new ListBuffer[Int]()
117 |
118 | for (w <- words) {
119 | val wn = normalizeString(w)
120 | if (dic.contains(wn)) {
121 | wordIDs.append(dic.getID(wn))
122 | }
123 | }
124 |
125 | wordIDs.toArray
126 | }
127 |
128 | def run(p : Params): Unit = {
129 |
130 | p.show
131 |
132 | val conf = new SparkConf().setAppName("Word2Vec")
133 | val sc = new SparkContext(conf)
134 |
135 | val rawLines = sc.textFile(p.input)
136 | val lines = rawLines.filter(s => s.length > 0).map(toWords).persist(StorageLevel.MEMORY_AND_DISK)
137 |
138 | val words = lines.flatMap(line => line.iterator).map((_, 1L))
139 |
140 | val countedWords = words.reduceByKey(_ + _).filter(f => f._2 > p.minFreq).sortBy(f => f._2).collect
141 | println("========== countedWords=" + countedWords.length + " ==================")
142 |
143 | var wordMap = new Dict
144 |
145 | var totalWords = 0L
146 | for (i <- 0 to countedWords.length - 1) {
147 | var item = countedWords(i)
148 | wordMap.put(item._1, i)
149 | totalWords += item._2
150 | }
151 | var wordTree = Word2Vec.createBinaryTree(countedWords)
152 | wordTree.tokens = totalWords
153 |
154 | var initialAlpha = 0.0f
155 | if (p.alpha < 10e-6) {
156 | if (p.cbow) {
157 | initialAlpha = 0.05f
158 | }
159 | else {
160 | initialAlpha = 0.0025f
161 | }
162 | }
163 | else {
164 | initialAlpha = p.alpha.toFloat
165 | }
166 |
167 | println("=============== Corpus Info Begin ================")
168 | println("Vocaulary: " + wordTree.vocabSize)
169 | println("Tokens: " + totalWords)
170 | println("Vector size: " + p.vectorSize)
171 | println("ps count: " + p.psCount)
172 | println("=============== Corpus Info End ================")
173 |
174 | val bdic = sc.broadcast(wordMap)
175 | //var data = lines.map(fromWordsToIds(bdic)).repartition(1).persist(StorageLevel.MEMORY_AND_DISK)
176 | var data = lines.map(fromWordsToIds(bdic)).repartition(p.numPartitions).persist(StorageLevel.MEMORY_AND_DISK)
177 |
178 | var alpha = initialAlpha
179 | val dm = Word2Vec.train(sc, p.psCount, data, p.vectorSize, wordTree, initialAlpha, p.alphaFactor, p.cbow, p.maxIterations, p.window, p.batchSize)
180 |
181 | dm.recycle()
182 | val result = Word2Vec.collect(dm)
183 |
184 | val vectors = new Array[Array[Float]](result.size)
185 | for (w <- result) {
186 | vectors(w._1) = w._2
187 | }
188 |
189 | val blackIndex = wordMap.getID("black")
190 | val sims = Word2Vec.findSynonyms(vectors, vectors(blackIndex), 10)
191 | for (s <- sims) {
192 | println(wordMap.getWord(s._1) + ", " + s._2)
193 | }
194 |
195 |
196 | sc.stop()
197 |
198 | System.out.println("===== Finished ====")
199 | }
200 |
201 | }
202 |
--------------------------------------------------------------------------------
/src/main/scala/com/intel/distml/example/regression/LargeLRTest.scala:
--------------------------------------------------------------------------------
1 | package com.intel.distml.example.regression
2 |
3 | import com.intel.distml.platform.DistML
4 | import com.intel.distml.regression.{LogisticRegression => LR}
5 | import com.intel.distml.util.DataStore
6 | import org.apache.spark.rdd.RDD
7 | import org.apache.spark.{SparkConf, SparkContext}
8 | import scopt.OptionParser
9 |
10 | import scala.collection.mutable.HashMap
11 |
12 | /**
13 | * Created by yunlong on 3/11/16.
14 | */
15 | object LargeLRTest {
16 |
17 | private case class Params(
18 | runType: String = "train",
19 | psCount: Int = 1,
20 | trainType : String = "ssp",
21 | maxIterations: Int = 100,
22 | batchSize: Int = 100, // for asgd only
23 | maxLag : Int = 2, // for ssp only
24 | dim: Int = 10000000,
25 | eta: Double = 0.0001,
26 | partitions : Int = 1,
27 | input: String = null,
28 | modelPath : String = null
29 | )
30 |
31 | def parseBlanc(line: String): (HashMap[Int, Double], Int) = {
32 | val s = line.split(" ")
33 |
34 | // the first field is the label; the remaining fields
35 | // are "index:value" feature pairs
36 | val label = java.lang.Float.parseFloat(s(0)).toInt
37 |
38 |
39 | val x = new HashMap[Int, Double]();
40 | for (i <- 1 to s.length - 1) {
41 | val f = s(i).split(":")
42 | val v = java.lang.Double.parseDouble(f(1))
43 | x.put(Integer.parseInt(f(0)), v)
44 | }
45 |
46 | (x, label)
47 | }
48 |
49 | def main(args: Array[String]) {
50 |
51 | val defaultParams = Params()
52 |
53 | val parser = new OptionParser[Params]("LargeLRExample") {
54 | head("LargeLRExample: an example for logistic regression with big dimension.")
55 | opt[String]("runType")
56 | .text(s"train for training model, test for testing existed model. default: ${defaultParams.runType}")
57 | .action((x, c) => c.copy(runType = x))
58 | opt[Int]("psCount")
59 | .text(s"number of parameter servers. default: ${defaultParams.psCount}")
60 | .action((x, c) => c.copy(psCount = x))
61 | opt[String]("trainType")
62 | .text(s"how to train your model, asgd or ssg. default: ${defaultParams.trainType}")
63 | .action((x, c) => c.copy(trainType = x))
64 | opt[Int]("maxIterations")
65 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
66 | .action((x, c) => c.copy(maxIterations = x))
67 | opt[Int]("batchSize")
68 | .text(s"number of samples computed in a round. default: ${defaultParams.batchSize}")
69 | .action((x, c) => c.copy(batchSize = x))
70 | opt[Int]("maxLag")
71 | .text(s"maximum number of iterations between the fast worker and the slowest worker. default: ${defaultParams.maxLag}")
72 | .action((x, c) => c.copy(maxLag = x))
73 | opt[Int]("dim")
74 | .text(s"dimension of features. default: ${defaultParams.dim}")
75 | .action((x, c) => c.copy(dim = x))
76 | opt[Double]("eta")
77 | .text(s"learning rate. default: ${defaultParams.eta}")
78 | .action((x, c) => c.copy(eta = x))
79 | opt[Int]("partitions")
80 | .text(s"number of partitions for training data. default: ${defaultParams.partitions}")
81 | .action((x, c) => c.copy(partitions = x))
82 | arg[String]("...")
83 | .text("path to train the model")
84 | .required()
85 | .action((x, c) => c.copy(input = x))
86 | arg[String]("