├── .gitignore ├── CHANGES.txt ├── DistML.iml ├── LICENSE ├── README.md ├── data ├── MNISTReader.java └── mnist_prepare.sh ├── doc ├── architect.png ├── build.md ├── lr-implementation.md ├── runtime.png ├── sample_lda.md ├── sample_lr.md └── src_tree.md ├── pom.xml └── src └── main ├── java └── com │ └── intel │ └── distml │ ├── api │ ├── DMatrix.java │ ├── DataBus.java │ ├── Model.java │ └── Session.java │ ├── platform │ ├── DataBusProtocol.java │ ├── MonitorActor.java │ ├── PSActor.java │ ├── PSAgent.java │ ├── PSManager.java │ ├── PSSync.java │ ├── SSP.java │ ├── WorkerActor.java │ └── WorkerAgent.java │ └── util │ ├── AbstractDataReader.java │ ├── AbstractDataWriter.java │ ├── ByteBufferDataReader.java │ ├── ByteBufferDataWriter.java │ ├── Constants.java │ ├── DataDesc.java │ ├── DataStore.java │ ├── DefaultDataReader.java │ ├── DefaultDataWriter.java │ ├── DoubleArray.java │ ├── DoubleMatrix.java │ ├── IntArray.java │ ├── IntArrayWithIntKey.java │ ├── IntMatrix.java │ ├── IntMatrixWithIntKey.java │ ├── KeyCollection.java │ ├── KeyHash.java │ ├── KeyList.java │ ├── KeyRange.java │ ├── Logger.java │ ├── SparseArray.java │ ├── SparseMatrix.java │ ├── Utils.java │ └── store │ ├── DoubleArrayStore.java │ ├── DoubleMatrixStore.java │ ├── FloatArrayStore.java │ ├── FloatMatrixStore.java │ ├── FloatMatrixStoreAdaGrad.java │ ├── IntArrayStore.java │ └── IntMatrixStore.java ├── main.iml └── scala └── com └── intel └── distml ├── Dict.scala ├── clustering ├── AliasTable.scala ├── LDAModel.scala ├── LDAParams.scala └── LightLDA.scala ├── example ├── SimpleALS.scala ├── clustering │ └── LDAExample.scala ├── feature │ ├── MllibWord2Vec.scala │ └── Word2VecExample.scala └── regression │ ├── LargeLRTest.scala │ ├── MelBlanc.scala │ └── Mnist.scala ├── feature └── Word2Vec.scala ├── platform ├── Clock.scala ├── DistML.scala ├── ParamServerDriver.scala └── PipeLine.scala ├── regression ├── LogisticRegression.scala └── MLR.scala └── util └── scala ├── DoubleArray.scala ├── DoubleArrayWithIntKey.scala ├── DoubleMatrixWithIntKey.scala ├── FloatArray.scala ├── FloatMatrix.scala ├── FloatMatrixAdapGradWithIntKey.scala ├── FloatMatrixWithIntKey.scala ├── IntArray.scala ├── IntMatrix.scala ├── SparseArray.scala ├── SparseMatrix.scala └── SparseMatrixAdapGrad.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | .idea 3 | bin/ 4 | target 5 | -------------------------------------------------------------------------------- /CHANGES.txt: -------------------------------------------------------------------------------- 1 | [2016-01-18] 2 | 1. Reimplemented data transfer with pure socket 3 | 4 | [2016-01-08] 5 | 1. Implemented LightLDA and MLR, now LR/LightLDA/MLR all works fine 6 | 2. Add api support for Double Array/Matrix and Int Array/Matrix 7 | 3. Add scala api for DoubleArray 8 | 9 | [2015-12-14] 10 | 1. Simplify APIs, allow users to use naive data for customized algorithms 11 | 2. With API change, only LR and LDA are remained to proove API is working well 12 | 3. Run examples on yarn and can use spark submit 13 | 14 | [2015-10-15] 15 | 1. Add environment variable "PS_NETWORK_PREFIX", which is used by parameter servers 16 | 2. Fix issue in handling messages after iteration done 17 | 18 | [2015-10-12] 19 | 1. Add serialization support for CNN classes 20 | 21 | [2015-08-21] 22 | 1. add draft lda support 23 | 2. support iterative UDF in serverside 24 | 25 | [2015-07-01] 26 | 1. 
removed the worker group feature to reduce code complexity; a new branch "workergrou" was created to keep it 27 | 2. use akka-io for data transport, which allows transferring large data blocks 28 | 3. start parameter servers in a separate thread, which allows iterative training without restarting the parameter servers 29 | -------------------------------------------------------------------------------- /DistML.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DistML (Distributed Machine Learning platform) 2 | 3 | DistML is a machine learning tool that allows training very large models on Spark; it is fully compatible with Spark (tested on 1.2 and above). 4 | 5 | 6 | 7 | Reference paper: [Large Scale Distributed Deep Networks](http://research.google.com/archive/large_deep_networks_nips2012.html) 8 | 9 | 10 | Runtime view: 11 | 12 | 13 | 14 | DistML provides several algorithms (LR, LDA, Word2Vec, ALS) to demonstrate its scalability. However, you may need to write your own algorithms based on the DistML APIs (Model, Session, Matrix, DataStore...). Porting an existing algorithm to DistML is generally simple; a minimal usage sketch appears at the end of this README, and LR is covered as a worked example: [How to implement logistic regression on DistML](https://github.com/intel-machine-learning/DistML/tree/master/doc/lr-implementation.md). 15 | 16 | ### User Guide 17 | 1. [Download and build DistML](https://github.com/intel-machine-learning/DistML/tree/master/doc/build.md). 18 | 2. [Typical options](https://github.com/intel-machine-learning/DistML/tree/master/doc/options.md). 19 | 3. [Run Sample - LR](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_lr.md). 20 | 4. [Run Sample - MLR](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_mlr.md). 21 | 5. [Run Sample - LDA](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_lda.md). 22 | 6. [Run Sample - Word2Vec](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_word2vec.md). 23 | 7. [Run Sample - ALS](https://github.com/intel-machine-learning/DistML/tree/master/doc/sample_als.md). 24 | 8. [Benchmarks](https://github.com/intel-machine-learning/DistML/tree/master/doc/benchmarks.md). 25 | 9. [FAQ](https://github.com/intel-machine-learning/DistML/tree/master/doc/faq.md). 26 | 27 | ### API Document 28 | 1. [Source Tree](https://github.com/intel-machine-learning/DistML/tree/master/doc/src_tree.md). 29 | 2. [DistML API](https://github.com/intel-machine-learning/DistML/tree/master/doc/api.md). 30 | 31 | 32 | ## Contributors 33 | He Yunlong (Intel)
34 | Sun Yongjie (Intel)
35 | Liu Lantao (Intern, Graduated)
36 | Hao Ruixiang (Intern, Graduated)
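### A minimal API sketch
The sketch below is adapted from [doc/lr-implementation.md](https://github.com/intel-machine-learning/DistML/tree/master/doc/lr-implementation.md) and only illustrates the usage pattern mentioned above; `sc`, `samples`, `dim`, `psCount` and the update computation are placeholders you supply.
```scala
// Define the model: one parameter matrix named "weights".
val m = new Model() {
  registerMatrix("weights", new DoubleArrayWithIntKey(dim + 1))
}

// Distribute the parameters over psCount parameter server nodes.
val dm = DistML.distribute(sc, m, psCount, DistML.defaultF)
val monitorPath = dm.monitorPath

samples.mapPartitionsWithIndex((index, it) => {
  // Each worker opens a Session, which connects it to the parameter servers.
  val session = new Session(m, monitorPath, index)
  val wd = m.getMatrix("weights").asInstanceOf[DoubleArrayWithIntKey]

  val keys = new KeyList()          // fill with the parameter indices this partition needs
  val w = wd.fetch(keys, session)   // pull the current values
  // ... compute updates into w from the samples in `it` ...
  wd.push(w, session)               // push the updates back
  session.disconnect()
  Iterator.empty
})
```
As in the LR example, an action (such as reduce or count) on the resulting RDD triggers the computation; see doc/lr-implementation.md for the full training loop.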
37 | -------------------------------------------------------------------------------- /data/MNISTReader.java: -------------------------------------------------------------------------------- 1 | 2 | import java.io.*; 3 | 4 | /** 5 | * This class implements a reader for the MNIST dataset of handwritten digits. The dataset is found 6 | * at http://yann.lecun.com/exdb/mnist/. 7 | * 8 | * @author Gabe Johnson 9 | */ 10 | public class MNISTReader { 11 | 12 | /** 13 | * @param args 14 | * args[0]: label file; args[1]: data file. 15 | * @throws IOException 16 | */ 17 | public static void main(String[] args) throws IOException { 18 | 19 | DataOutputStream dos = new DataOutputStream(new FileOutputStream(args[2])); 20 | 21 | DataInputStream labels = new DataInputStream(new FileInputStream(args[0])); 22 | DataInputStream images = new DataInputStream(new FileInputStream(args[1])); 23 | 24 | int magicNumber = labels.readInt(); 25 | if (magicNumber != 2049) { 26 | System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)"); 27 | System.exit(0); 28 | } 29 | magicNumber = images.readInt(); 30 | if (magicNumber != 2051) { 31 | System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)"); 32 | System.exit(0); 33 | } 34 | int numLabels = labels.readInt(); 35 | int numImages = images.readInt(); 36 | int numRows = images.readInt(); 37 | int numCols = images.readInt(); 38 | if (numLabels != numImages) { 39 | System.err.println("Image file and label file do not contain the same number of entries."); 40 | System.err.println(" Label file contains: " + numLabels); 41 | System.err.println(" Image file contains: " + numImages); 42 | System.exit(0); 43 | } 44 | 45 | long start = System.currentTimeMillis(); 46 | int numLabelsRead = 0; 47 | int numImagesRead = 0; 48 | while (labels.available() > 0 && numLabelsRead < numLabels) { 49 | StringBuilder buf = new StringBuilder(); 50 | byte label = labels.readByte(); 51 | numLabelsRead++; 52 | int[][] image = new int[numCols][numRows]; 53 | for (int colIdx = 0; colIdx < numCols; colIdx++) { 54 | for (int rowIdx = 0; rowIdx < numRows; rowIdx++) { 55 | int point = images.readUnsignedByte(); 56 | image[colIdx][rowIdx] = point; 57 | buf.append("" + point + " "); 58 | } 59 | } 60 | numImagesRead++; 61 | 62 | buf.append("" + label + "\n"); 63 | dos.write(buf.toString().getBytes()); 64 | 65 | 66 | // At this point, 'label' and 'image' agree and you can do whatever you like with them. 
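          // Each line written to the output file is the image's numRows * numCols pixel values
          // (0-255, space separated) followed by its label, which is the mnist_train.txt format
          // produced by data/mnist_prepare.sh.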
67 | 68 | if (numLabelsRead % 10 == 0) { 69 | System.out.print("."); 70 | } 71 | if ((numLabelsRead % 800) == 0) { 72 | System.out.print(" " + numLabelsRead + " / " + numLabels); 73 | long end = System.currentTimeMillis(); 74 | long elapsed = end - start; 75 | long minutes = elapsed / (1000 * 60); 76 | long seconds = (elapsed / 1000) - (minutes * 60); 77 | System.out.println(" " + minutes + " m " + seconds + " s "); 78 | } 79 | } 80 | System.out.println(); 81 | long end = System.currentTimeMillis(); 82 | long elapsed = end - start; 83 | long minutes = elapsed / (1000 * 60); 84 | long seconds = (elapsed / 1000) - (minutes * 60); 85 | System.out 86 | .println("Read " + numLabelsRead + " samples in " + minutes + " m " + seconds + " s "); 87 | 88 | dos.close(); 89 | } 90 | 91 | } 92 | 93 | -------------------------------------------------------------------------------- /data/mnist_prepare.sh: -------------------------------------------------------------------------------- 1 | 2 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 3 | wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 4 | 5 | gunzip train-images-idx3-ubyte.gz 6 | gunzip train-labels-idx1-ubyte.gz 7 | 8 | javac MNISTReader.java 9 | java MNISTReader train-labels-idx1-ubyte train-images-idx3-ubyte mnist_train.txt 10 | -------------------------------------------------------------------------------- /doc/architect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/doc/architect.png -------------------------------------------------------------------------------- /doc/build.md: -------------------------------------------------------------------------------- 1 | 2 | ## Guide to download and build DistML 3 | 4 | 5 | ### Download DistML 6 | ```sh 7 | $ git clone https://github.com/intel-machine-learning/DistML.git 8 | ``` 9 | 10 | ### Configure and build 11 | You may run DistML against a different Spark version than the one it was built with; although DistML is compatible with recent Spark releases, we encourage you to build it against the same Spark version that you are running. The Spark dependencies you may need to change in pom.xml are shown after the build steps below. 12 | ```sh 13 | $ cd DistML 14 | $ vi pom.xml 15 | $ mvn package 16 | ``` 17 | 18 | After a successful build you will find target/distml-0.2.jar (the version matches pom.xml); refer to the sample guides to run the examples.
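For reference, the Spark entries currently declared in this repository's pom.xml are reproduced below (DistML 0.2 builds against Spark 1.4.0 on Scala 2.10); edit the versions to match your cluster before running `mvn package`.
```xml
<!-- Excerpt from pom.xml: change these versions to match the Spark/Scala you run on. -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.10</artifactId>
    <version>1.4.0</version>
    <scope>compile</scope>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>1.4.0</version>
    <scope>compile</scope>
</dependency>
```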
19 | -------------------------------------------------------------------------------- /doc/lr-implementation.md: -------------------------------------------------------------------------------- 1 | 2 | ## A sample shows how write algorithms based on DistML APIs 3 | 4 | ```scala 5 | val m = new Model() { 6 | registerMatrix("weights", new DoubleArrayWithIntKey(dim + 1)) 7 | } 8 | 9 | val dm = DistML.distribute(sc, m, psCount, DistML.defaultF) 10 | val monitorPath = dm.monitorPath 11 | 12 | dm.setTrainSetSize(samples.count()) 13 | 14 | for (iter <- 0 to maxIterations - 1) { 15 | println("============ Iteration: " + iter + " ==============") 16 | 17 | val t = samples.mapPartitionsWithIndex((index, it) => { 18 | println("--- connecting to PS ---") 19 | val session = new Session(m, monitorPath, index) 20 | val wd = m.getMatrix("weights").asInstanceOf[DoubleArrayWithIntKey] 21 | 22 | val batch = new util.LinkedList[(mutable.HashMap[Int, Double], Int)] 23 | while (it.hasNext) { 24 | batch.clear() 25 | var count = 0 26 | while ((count < batchSize) && it.hasNext) { 27 | batch.add(it.next()) 28 | count = count + 1 29 | } 30 | 31 | val keys = new KeyList() 32 | for ((x, label) <- batch) { 33 | for (key <- x.keySet) { 34 | keys.addKey(key) 35 | } 36 | } 37 | 38 | val w = wd.fetch(keys, session) 39 | val w_old = new util.HashMap[Long, Double] 40 | for ((key, value) <- w) { 41 | w_old.put(key, value) 42 | } 43 | 44 | for ((x, label) <- batch) { 45 | var sum = 0.0 46 | for ((k, v) <- x) { 47 | sum += w(k) * v 48 | } 49 | val h = 1.0 / (1.0 + Math.exp(-sum)) 50 | 51 | val err = eta * (h - label) 52 | for ((k, v) <- x) { 53 | w.put(k, w(k) - err * v) 54 | } 55 | 56 | cost = cost + label * Math.log(h) + (1 - label) * Math.log(1 - h) 57 | } 58 | 59 | cost /= batch.size() 60 | for (key <- w.keySet) { 61 | val grad: Double = w(key) - w_old(key) 62 | w.put(key, grad) 63 | } 64 | wd.push(w, session) 65 | } 66 | 67 | session.disconnect() 68 | 69 | val r = new Array[Double](1) 70 | r(0) = -cost 71 | r.iterator 72 | }) 73 | 74 | val totalCost = t.reduce(_+_) 75 | println("============ Iteration done, Total Cost: " + totalCost + " ============") 76 | } 77 | ``` 78 | 79 | ## Instructions 80 | 81 | Firstly define your model with parameter type and dimension, for logistic regression, we need a double vector, DistML provides Array/Matrix for int/long/float/double. 82 | ```scala 83 | val m = new Model() { 84 | registerMatrix("weights", new DoubleArrayWithIntKey(dim + 1)) 85 | } 86 | ``` 87 | 88 | Before training the model, we need to distributed the parameters to several parameter server nodes, the number of parameter servers is specified by psCount. 89 | ```scala 90 | val dm = DistML.distribute(sc, m, psCount, DistML.defaultF) 91 | val monitorPath = dm.monitorPath 92 | 93 | dm.setTrainSetSize(samples.count()) 94 | ``` 95 | 96 | In each worker doing training jobs, we need to setup a session, which helps to setup databuses between workers and parameter servers. 97 | ```scala 98 | val session = new Session(m, monitorPath, index) 99 | val wd = m.getMatrix("weights").asInstanceOf[DoubleArrayWithIntKey] 100 | ``` 101 | 102 | After connected to parameter servers, we can fetch the parameters now. Note that w_old is used to calculate updates after each iteration. 
103 | ```scala 104 | val w = wd.fetch(keys, session) 105 | val w_old = new util.HashMap[Long, Double] 106 | for ((key, value) <- w) { 107 | w_old.put(key, value) 108 | } 109 | ``` 110 | 111 | As training proceeds the parameters are updated; we compute the updates here and then push them to the parameter servers. 112 | ```scala 113 | for (key <- w.keySet) { 114 | val grad: Double = w(key) - w_old(key) 115 | w.put(key, grad) 116 | } 117 | wd.push(w, session) 118 | ``` 119 | When the worker finishes training for an iteration, it disconnects from the parameter servers. 120 | ```scala 121 | session.disconnect() 122 | ``` 123 | 124 | -------------------------------------------------------------------------------- /doc/runtime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intel-machine-learning/DistML/10aaa292d8d48158f97b7f2a439711d1aec42279/doc/runtime.png -------------------------------------------------------------------------------- /doc/sample_lda.md: -------------------------------------------------------------------------------- 1 | 2 | == Sample command to run LightLDA 3 | 4 | ```sh 5 | #./spark-submit --class com.intel.distml.example.LightLDA --master yarn://dl-s1:8088 /home/spark/yunlong/distml-0.2.jar --psCount 1 --maxIterations 100 --k 2 --alpha 0.01 --beta 0.01 --batchSize 100 /data/lda/short.txt 6 | ``` 7 | 8 | == About LDA parameters 9 | ``` 10 | k: topic number, default value 20 11 | alpha: hyper parameter, default value 0.01 12 | beta: hyper parameter, default value 0.01 13 | showPlexity: whether to show perplexity after training, default value true 14 | ``` 15 | -------------------------------------------------------------------------------- /doc/sample_lr.md: -------------------------------------------------------------------------------- 1 | 2 | == Run sparse logistic regression 3 | 4 | ```sh 5 | #./spark-submit --num-executors 8 --executor-cores 3 --executor-memory 10g --class com.intel.distml.example.regression.MelBlanc --master yarn://dl-s1:8088 /home/spark/yunlong/distml-0.2.jar --psCount 1 --trainType ssp --maxIterations 150 --maxLag 2 --dim 1000000 --partitions 8 --eta 0.00001 /data/lr/Blanc-Mel.txt /models/lr/Blanc 6 | ``` 7 | 8 | === Options 9 | runType: train or test, default value train 10 | psCount: number of parameter servers, default value 1 11 | psBackup: whether to enable parameter server fault tolerance, default value false 12 | trainType: how to train the LR model, "ssp" or "asgd", default value "ssp" 13 | maxIterations: how many iterations to train the model, only applicable to runType=train, default value 100 14 | batchSize: batch size when using "asgd" as train type, default value 100 15 | maxLag: max iteration difference between the fastest and slowest workers, only applicable to trainType=ssp, default value 2 16 | dim: dimension of the LR weights, default value 10000000 17 | eta: learning rate, default value 0.0001 18 | partitions: how many partitions for the training data, default value 1 19 | input: dataset for training, mandatory 20 | modelPath: where to save the model, mandatory 21 | -------------------------------------------------------------------------------- /doc/src_tree.md: -------------------------------------------------------------------------------- 1 | 2 | ## DistML source tree reading guide 3 | Although it currently runs only on Spark, DistML is designed to run either on Spark or directly on YARN. For that reason the main engine is implemented in Java, while the Spark-related support and the sample algorithms are written in Scala.
4 | 5 | --src 6 | main 7 | java 8 | com 9 | intel 10 | distml 11 | api 12 | platform 13 | util 14 | scala 15 | com 16 | intel 17 | distml 18 | clustering 19 | example 20 | feature 21 | platform 22 | regression 23 | util 24 | 25 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | intel 8 | distml 9 | 0.2 10 | 11 | 2.10.4 12 | 2.11.1 13 | 14 | 15 | 16 | 17 | com.typesafe.akka 18 | akka-remote_2.10 19 | 2.3.4 20 | 21 | 22 | com.typesafe.akka 23 | akka-actor_2.10 24 | 2.3.4 25 | 26 | 27 | org.apache.hadoop 28 | hadoop-hdfs 29 | 2.2.0 30 | compile 31 | 32 | 33 | org.apache.spark 34 | spark-core_2.10 35 | 1.4.0 36 | compile 37 | 38 | 39 | org.apache.spark 40 | spark-mllib_2.10 41 | 1.4.0 42 | compile 43 | 44 | 45 | com.google.guava 46 | guava 47 | 18.0 48 | compile 49 | 50 | 51 | com.github.fommil.netlib 52 | core 53 | 1.1.2 54 | 55 | 56 | org.jblas 57 | jblas 58 | 1.2.3 59 | 60 | 61 | com.github.scopt 62 | scopt_2.10 63 | 3.2.0 64 | 65 | 66 | 67 | org.scala-lang 68 | scala-library 69 | 2.10.4 70 | compile 71 | 72 | 73 | 74 | org.scala-lang 75 | scala-compiler 76 | 2.10.4 77 | compile 78 | 79 | 80 | org.apache.commons 81 | commons-math3 82 | 3.0 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | net.alchim31.maven 91 | scala-maven-plugin 92 | 3.2.0 93 | 94 | 95 | org.apache.maven.plugins 96 | maven-compiler-plugin 97 | 98 | 99 | 100 | 101 | src 102 | 103 | 104 | net.alchim31.maven 105 | scala-maven-plugin 106 | 107 | 108 | scala-compile-first 109 | process-resources 110 | 111 | add-source 112 | compile 113 | 114 | 115 | 116 | 117 | 118 | org.apache.maven.plugins 119 | maven-compiler-plugin 120 | 121 | 122 | compile 123 | 124 | compile 125 | 126 | 127 | 1.7 128 | 1.7 129 | 1.7 130 | UTF-8 131 | 132 | 133 | 134 | 135 | 136 | org.apache.maven.plugins 137 | maven-shade-plugin 138 | 139 | false 140 | 141 | 142 | org.jblas:jblas 143 | com.github.scopt:scopt_2.10 144 | org.apache.commons:commons-math3 145 | 146 | 147 | 148 | 149 | 150 | package 151 | 152 | shade 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/api/DMatrix.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.api; 2 | 3 | import com.intel.distml.util.DataDesc; 4 | import com.intel.distml.util.DataStore; 5 | import com.intel.distml.util.KeyCollection; 6 | import com.intel.distml.util.KeyRange; 7 | 8 | import java.io.Serializable; 9 | 10 | public class DMatrix implements Serializable { 11 | 12 | public static final int PARTITION_STRATEGY_LINEAR = 0; 13 | public static final int PARTITION_STRATEGY_HASH = 1; 14 | 15 | public KeyCollection[] partitions; 16 | 17 | protected KeyRange rowKeys; 18 | 19 | protected int partitionStrategy; 20 | 21 | protected String name; 22 | 23 | protected DataDesc format; 24 | 25 | protected DataStore store; 26 | 27 | public DMatrix(long rows) { 28 | this.partitionStrategy = PARTITION_STRATEGY_LINEAR; 29 | 30 | rowKeys = new KeyRange(0, rows-1); 31 | } 32 | 33 | public KeyCollection getRowKeys() { 34 | return rowKeys; 35 | } 36 | 37 | public KeyCollection getColKeys() { 38 | return KeyRange.Single; 39 | } 40 | 41 | public DataDesc getFormat() { 42 | return format; 43 | } 44 | 45 | public void setPartitionStrategy(int strategy) { 46 | if ((strategy != PARTITION_STRATEGY_LINEAR) && (strategy != 
PARTITION_STRATEGY_HASH)) { 47 | throw new IllegalArgumentException("partiton strategy must be SPLIT or HASH."); 48 | } 49 | 50 | partitionStrategy = strategy; 51 | } 52 | 53 | void partition(int serverNum) { 54 | 55 | KeyCollection[] keySets; 56 | if (partitionStrategy == PARTITION_STRATEGY_LINEAR) { 57 | keySets = rowKeys.linearSplit(serverNum); 58 | } 59 | else { 60 | keySets = rowKeys.hashSplit(serverNum); 61 | } 62 | 63 | partitions = keySets; 64 | } 65 | 66 | public DataStore getStore() { 67 | return store; 68 | } 69 | 70 | public DataStore setStore() { 71 | return store; 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/api/DataBus.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.api; 2 | 3 | import com.intel.distml.util.DataDesc; 4 | import com.intel.distml.util.KeyCollection; 5 | 6 | import java.util.HashMap; 7 | 8 | /** 9 | * Created by yunlong on 6/3/15. 10 | */ 11 | public interface DataBus { 12 | 13 | public byte[][] fetch(String matrixName, KeyCollection rowKeys, DataDesc format); 14 | 15 | public void push(String matrixName, DataDesc format, byte[][] data); 16 | 17 | public void disconnect(); 18 | 19 | public void psAvailable(int index, String addr); 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/api/Model.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.api; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | 6 | /** 7 | * Created by yunlong on 12/13/14. 8 | */ 9 | public class Model implements Serializable { 10 | 11 | public HashMap dataMap; 12 | 13 | public String monitorPath; 14 | 15 | public int psCount; 16 | public boolean psReady; 17 | 18 | public Model() { 19 | dataMap = new HashMap(); 20 | psReady = false; 21 | } 22 | 23 | public void registerMatrix(String name, DMatrix matrix) { 24 | if (dataMap.containsKey(name)) { 25 | throw new IllegalArgumentException("Matrix already exist: " + name); 26 | } 27 | matrix.name = name; 28 | dataMap.put(name, matrix); 29 | } 30 | 31 | public DMatrix getMatrix(String matrixName) { 32 | return dataMap.get(matrixName); 33 | } 34 | 35 | public void autoPartition(int serverNum) { 36 | this.psCount = serverNum; 37 | 38 | for (String matrixName: dataMap.keySet()) { 39 | DMatrix m = dataMap.get(matrixName); 40 | m.partition(serverNum); 41 | } 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/api/Session.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.api; 2 | 3 | import akka.actor.ActorRef; 4 | import akka.actor.ActorSystem; 5 | import com.intel.distml.platform.MonitorActor; 6 | import com.intel.distml.platform.WorkerActor; 7 | import com.typesafe.config.Config; 8 | import com.typesafe.config.ConfigFactory; 9 | 10 | /** 11 | * Created by yunlong on 12/9/15. 
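 * A worker-side handle to DistML: it starts a local actor system, registers a WorkerActor
 * with the monitor, and exposes a DataBus for fetching and pushing parameters.
 *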
12 | */ 13 | public class Session { 14 | 15 | static final String ACTOR_SYSTEM_CONFIG = 16 | "akka.actor.provider=\"akka.remote.RemoteActorRefProvider\"\n" + 17 | "akka.remote.netty.tcp.port=0\n" + 18 | "akka.remote.log-remote-lifecycle-events=off\n" + 19 | "akka.log-dead-letters=off\n" + 20 | "akka.io.tcp.direct-buffer-size = 2 MB\n" + 21 | "akka.io.tcp.trace-logging=off\n" + 22 | "akka.remote.netty.tcp.maximum-frame-size=4126935"; 23 | 24 | 25 | public ActorSystem workerActorSystem; 26 | public DataBus dataBus; 27 | public ActorRef monitor; 28 | public ActorRef worker; 29 | public Model model; 30 | 31 | public Session(Model model, String monitorPath, int workerIndex) { 32 | this.model = model; 33 | connect(monitorPath, workerIndex); 34 | } 35 | 36 | public void connect(String monitorPath, int workerIndex) { 37 | dataBus = null; 38 | 39 | String WORKER_ACTOR_SYSTEM_NAME = "worker"; 40 | String WORKER_ACTOR_NAME = "worker"; 41 | 42 | Config cfg = ConfigFactory.parseString(ACTOR_SYSTEM_CONFIG); 43 | workerActorSystem = ActorSystem.create(WORKER_ACTOR_SYSTEM_NAME, ConfigFactory.load(cfg)); 44 | worker = workerActorSystem.actorOf(WorkerActor.props(this, model, monitorPath, workerIndex), WORKER_ACTOR_NAME); 45 | 46 | while(dataBus == null) { 47 | try {Thread.sleep(10); } catch (InterruptedException e) {} 48 | } 49 | } 50 | 51 | public void progress(int sampleCount) { 52 | worker.tell(new WorkerActor.Progress(sampleCount), null); 53 | } 54 | 55 | public void iterationDone(int iteration) { 56 | iterationDone(iteration, 0.0); 57 | } 58 | 59 | public void iterationDone(int iteration, double cost) { 60 | WorkerActor.IterationDone req = new WorkerActor.IterationDone(iteration, cost); 61 | worker.tell(req, null); 62 | while(!req.done) { 63 | try { Thread.sleep(100); } catch (Exception e) {} 64 | } 65 | } 66 | 67 | public void disconnect() { 68 | dataBus.disconnect(); 69 | workerActorSystem.stop(worker); 70 | workerActorSystem.shutdown(); 71 | dataBus = null; 72 | } 73 | 74 | public void discard() { 75 | worker.tell(new WorkerActor.Command(WorkerActor.CMD_DISCONNECT), null); 76 | } 77 | 78 | @Override 79 | public void finalize() { 80 | if (dataBus != null) 81 | disconnect(); 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/platform/PSActor.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform; 2 | 3 | import akka.actor.ActorSelection; 4 | import akka.actor.Props; 5 | import akka.japi.Creator; 6 | import com.intel.distml.api.Model; 7 | 8 | import akka.actor.UntypedActor; 9 | import com.intel.distml.util.*; 10 | import com.intel.distml.util.store.FloatMatrixStoreAdaGrad; 11 | import org.apache.hadoop.conf.Configuration; 12 | import org.apache.hadoop.fs.FileSystem; 13 | import org.apache.hadoop.fs.Path; 14 | 15 | import java.io.DataInputStream; 16 | import java.io.DataOutputStream; 17 | import java.io.IOException; 18 | import java.io.Serializable; 19 | import java.net.Socket; 20 | import java.net.SocketAddress; 21 | import java.net.URI; 22 | import java.util.HashMap; 23 | 24 | public class PSActor extends UntypedActor { 25 | 26 | public static int MIN_REPORT_INTERVAL = 1000; 27 | 28 | public static final int OP_LOAD = 0; 29 | public static final int OP_SAVE = 1; 30 | public static final int OP_ZERO = 2; 31 | public static final int OP_RAND = 3; 32 | public static final int OP_SET = 4; 33 | public static final int OP_SET_ALPHA = 5; 34 | 35 | public 
static class RegisterRequest implements Serializable { 36 | private static final long serialVersionUID = 1L; 37 | 38 | final public int index; 39 | final public String executorId; 40 | final public String addr; 41 | final public String hostName; 42 | final public long freeMemory; 43 | final long totalMemory; 44 | 45 | public RegisterRequest(int index, String executorId, String hostName, String addr) { 46 | this.index = index; 47 | this.executorId = executorId; 48 | this.addr = addr; 49 | this.hostName = hostName; 50 | totalMemory = Runtime.getRuntime().totalMemory(); 51 | freeMemory = Runtime.getRuntime().freeMemory(); 52 | } 53 | } 54 | 55 | public static class SyncServerInfo implements Serializable { 56 | private static final long serialVersionUID = 1L; 57 | 58 | final public String addr; 59 | public SyncServerInfo(String addr) { 60 | this.addr = addr; 61 | } 62 | 63 | @Override 64 | public String toString() { 65 | return "SyncServerInfo[" + addr + "]"; 66 | } 67 | } 68 | 69 | public static class ModelSetup implements Serializable { 70 | private static final long serialVersionUID = 1L; 71 | 72 | int op; 73 | String path; 74 | String value; 75 | public ModelSetup(int op, String path) { 76 | this.op = op; 77 | this.path = path; 78 | this.value = null; 79 | } 80 | public ModelSetup(int op, String path, String value) { 81 | this.op = op; 82 | this.path = path; 83 | this.value = value; 84 | } 85 | } 86 | 87 | public static class ModelSetupDone implements Serializable { 88 | private static final long serialVersionUID = 1L; 89 | 90 | public ModelSetupDone() { 91 | } 92 | } 93 | 94 | public static class Stop implements Serializable { 95 | private static final long serialVersionUID = 1L; 96 | 97 | public Stop() { 98 | } 99 | } 100 | 101 | public static class AgentMessage implements Serializable { 102 | private static final long serialVersionUID = 1L; 103 | 104 | final public long freeMemory; 105 | final public long totalMemory; 106 | 107 | public AgentMessage(long freeMemory, long totalMemory) { 108 | this.freeMemory = freeMemory; 109 | this.totalMemory = totalMemory; 110 | } 111 | } 112 | 113 | public static class Report implements Serializable { 114 | private static final long serialVersionUID = 1L; 115 | 116 | final public long freeMemory; 117 | final public long totalMemory; 118 | 119 | public Report(long freeMemory, long totalMemory) { 120 | this.freeMemory = freeMemory; 121 | this.totalMemory = totalMemory; 122 | } 123 | } 124 | 125 | private Model model; 126 | private HashMap stores; 127 | 128 | private ActorSelection monitor; 129 | private int serverIndex; 130 | private String executorId; 131 | 132 | private PSAgent agent; 133 | private PSSync syncThread; 134 | 135 | private long lastReportTime; 136 | 137 | public static Props props(final Model model, final HashMap stores, final String monitorPath, 138 | final int parameterServerIndex, final String executorId, final String psNetwordPrefix) { 139 | return Props.create(new Creator() { 140 | private static final long serialVersionUID = 1L; 141 | public PSActor create() throws Exception { 142 | return new PSActor(model, stores, monitorPath, parameterServerIndex, executorId, psNetwordPrefix); 143 | } 144 | }); 145 | } 146 | 147 | PSActor(Model model, HashMap stores, String monitorPath, int serverIndex, String executorId, String psNetwordPrefix) { 148 | this.monitor = getContext().actorSelection(monitorPath); 149 | this.serverIndex = serverIndex; 150 | this.executorId = executorId; 151 | this.model = model; 152 | this.stores = stores; 153 | 
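        // Used to throttle memory reports to the monitor (see MIN_REPORT_INTERVAL).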
this.lastReportTime = 0; 154 | 155 | agent = new PSAgent(getSelf(), model, stores, psNetwordPrefix); 156 | agent.start(); 157 | this.monitor.tell(new RegisterRequest(serverIndex, executorId, agent.hostName(), agent.addr()), getSelf()); 158 | } 159 | 160 | 161 | @Override 162 | public void onReceive(Object msg) throws Exception { 163 | log("onReceive: " + msg); 164 | if (msg instanceof SyncServerInfo) { 165 | SyncServerInfo info = (SyncServerInfo) msg; 166 | String[] s = info.addr.split(":"); 167 | Socket sck = new Socket(s[0], Integer.parseInt(s[1])); 168 | syncThread = new PSSync(getSelf(), model, stores); 169 | syncThread.asStandBy(sck); 170 | } 171 | else if (msg instanceof ModelSetup) { 172 | ModelSetup req = (ModelSetup) msg; 173 | String path = req.path; 174 | switch (req.op) { 175 | case OP_LOAD: 176 | load(path); 177 | break; 178 | case OP_SAVE: 179 | save(path); 180 | break; 181 | case OP_RAND: 182 | DataStore store = stores.get(path); 183 | store.rand(); 184 | break; 185 | case OP_ZERO: 186 | store = stores.get(path); 187 | store.zero(); 188 | break; 189 | case OP_SET: 190 | store = stores.get(path); 191 | store.set(req.value); 192 | break; 193 | } 194 | monitor.tell(new ModelSetupDone(), getSelf()); 195 | } 196 | else if (msg instanceof MonitorActor.SetAlpha) { 197 | MonitorActor.SetAlpha req = (MonitorActor.SetAlpha) msg; 198 | FloatMatrixStoreAdaGrad store = (FloatMatrixStoreAdaGrad) stores.get(((MonitorActor.SetAlpha) msg).matrixName); 199 | store.setAlpha(req.initialAlpha, req.minAlpha, req.factor); 200 | monitor.tell(new ModelSetupDone(), getSelf()); 201 | } 202 | else if (msg instanceof AgentMessage) { 203 | AgentMessage m = (AgentMessage) msg; 204 | long now = System.currentTimeMillis(); 205 | if ((now - lastReportTime) > MIN_REPORT_INTERVAL) { 206 | monitor.tell(new Report(m.freeMemory, m.totalMemory), getSelf()); 207 | lastReportTime = now; 208 | } 209 | } 210 | else if (msg instanceof MonitorActor.IterationDone) { 211 | agent.closeClients(); 212 | monitor.tell(new ModelSetupDone(), getSelf()); 213 | } 214 | else if (msg instanceof Stop) { 215 | agent.disconnect(); 216 | getContext().stop(self()); 217 | } 218 | else unhandled(msg); 219 | } 220 | 221 | private void load(String path) throws IOException { 222 | Configuration conf = new Configuration(); 223 | FileSystem fs = FileSystem.get(URI.create(path), conf); 224 | 225 | for (String name : stores.keySet()) { 226 | DataStore store = stores.get(name); 227 | Path dst = new Path(path + "/" + name + "." + serverIndex); 228 | DataInputStream in = fs.open(dst); 229 | 230 | store.readAll(in); 231 | in.close(); 232 | } 233 | } 234 | 235 | private void save(String path) throws IOException { 236 | log("saving model: " + path); 237 | 238 | Configuration conf = new Configuration(); 239 | FileSystem fs = FileSystem.get(URI.create(path), conf); 240 | 241 | for (String name : stores.keySet()) { 242 | DataStore store = stores.get(name); 243 | Path dst = new Path(path + "/" + name + "." 
+ serverIndex); 244 | DataOutputStream out = fs.create(dst); 245 | 246 | log("saving to: " + dst.getName()); 247 | store.writeAll(out); 248 | out.flush(); 249 | out.close(); 250 | } 251 | } 252 | 253 | @Override 254 | public void postStop() { 255 | if (syncThread != null) { 256 | syncThread.disconnect(); 257 | try { 258 | syncThread.join(); 259 | } 260 | catch (Exception e) {} 261 | } 262 | getContext().system().shutdown(); 263 | log("Parameter server stopped"); 264 | } 265 | 266 | private void log(String msg) { 267 | Logger.info(msg, "PS-" + serverIndex); 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/platform/PSManager.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform; 2 | 3 | import akka.actor.ActorRef; 4 | import com.intel.distml.util.Logger; 5 | 6 | import java.util.LinkedList; 7 | 8 | /** 9 | * Created by yunlong on 4/28/16. 10 | */ 11 | public class PSManager { 12 | 13 | static class PSNode { 14 | String addr; 15 | String executorId; 16 | ActorRef actor; 17 | 18 | PSNode(String addr, String executorId, ActorRef actor) { 19 | this.addr = addr; 20 | this.executorId = executorId; 21 | this.actor = actor; 22 | } 23 | 24 | public String toString() { 25 | return "[ps " + addr + "]"; 26 | } 27 | } 28 | 29 | static interface PSMonitor { 30 | void switchServer(int index, String addr, ActorRef actor); 31 | void psFail(); 32 | } 33 | 34 | LinkedList[] servers; 35 | PSMonitor monitor; 36 | 37 | public PSManager(PSMonitor monitor, int psCount) { 38 | this.monitor = monitor; 39 | this.servers = new LinkedList[psCount]; 40 | for (int i = 0; i < servers.length; i++) { 41 | servers[i] = new LinkedList(); 42 | } 43 | } 44 | 45 | public LinkedList getAllActors() { 46 | LinkedList actors = new LinkedList(); 47 | for (int i = 0; i < servers.length; i++) { 48 | if (servers[i] != null) { 49 | for (PSNode node : servers[i]) { 50 | if (node != null) { 51 | actors.add(node.actor); 52 | } 53 | } 54 | } 55 | } 56 | return actors; 57 | } 58 | 59 | public ActorRef[] getActors() { 60 | ActorRef[] actors = new ActorRef[servers.length]; 61 | for (int i = 0; i < servers.length; i++) { 62 | if (servers[i] != null) { 63 | PSNode node = servers[i].getFirst(); 64 | if (node != null) { 65 | actors[i] = node.actor; 66 | } 67 | } 68 | } 69 | return actors; 70 | } 71 | 72 | public String[] getAddrs() { 73 | String[] addrs = new String[servers.length]; 74 | for (int i = 0; i < servers.length; i++) { 75 | if (servers[i] != null) { 76 | PSNode node = servers[i].getFirst(); 77 | if (node != null) { 78 | addrs[i] = node.addr; 79 | } 80 | } 81 | } 82 | return addrs; 83 | } 84 | 85 | public ActorRef getActor(int index) { 86 | if (servers[index] != null) { 87 | PSNode node = servers[index].getFirst(); 88 | if (node != null) { 89 | return node.actor; 90 | } 91 | } 92 | return null; 93 | } 94 | 95 | public String getAddr(int index) { 96 | if (servers[index] != null) { 97 | PSNode node = servers[index].getFirst(); 98 | if (node != null) { 99 | return node.addr; 100 | } 101 | } 102 | return null; 103 | } 104 | 105 | public void serverTerminated(ActorRef actor) { 106 | PSNode node = null; 107 | int index = -1; 108 | for (int i = 0; i < servers.length; i++) { 109 | for (PSNode n : servers[i]) { 110 | if (n.actor.equals(actor)) { 111 | node = n; 112 | index = i; 113 | break; 114 | } 115 | } 116 | if (node != null) break; 117 | } 118 | 119 | serverTerminated(index, node); 120 | } 
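    // Same lookup as above, but matching on the Spark executor id reported at registration.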
121 | 122 | public void serverTerminated(String executorId) { 123 | PSNode node = null; 124 | int index = -1; 125 | for (int i = 0; i < servers.length; i++) { 126 | for (PSNode n : servers[i]) { 127 | if (n.executorId.equals(executorId)) { 128 | node = n; 129 | index = i; 130 | break; 131 | } 132 | } 133 | if (node != null) break; 134 | } 135 | 136 | serverTerminated(index, node); 137 | } 138 | 139 | public void serverTerminated(int index, PSNode node) { 140 | log("parameter server terminated: " + index + ", " + node); 141 | boolean isPrimary = (servers[index].getFirst() == node); 142 | 143 | servers[index].remove(node); 144 | if (servers[index].size() == 0) { 145 | monitor.psFail(); 146 | } 147 | else if (isPrimary) { 148 | node = servers[index].getFirst(); 149 | monitor.switchServer(index, node.addr, node.actor); 150 | } 151 | } 152 | 153 | 154 | /** 155 | * add new parameter server to the list. 156 | * 157 | * 158 | * @param index 159 | * @param addr 160 | * @param ref 161 | * @return whether the server be primary 162 | */ 163 | public boolean serverAvailable(int index, String executorId, String addr, ActorRef ref) { 164 | servers[index].add(new PSNode(addr, executorId, ref)); 165 | return (servers[index].size() == 1); 166 | } 167 | /* 168 | private PSNode getNode(int executorId) { 169 | for (int i = 0; i < servers.length; i++) { 170 | for (PSNode node : servers[i]) { 171 | if (node.executorId == executorId) { 172 | return node; 173 | } 174 | } 175 | } 176 | 177 | return null; 178 | } 179 | 180 | private PSNode getNode(ActorRef a) { 181 | for (int i = 0; i < servers.length; i++) { 182 | for (PSNode node : servers[i]) { 183 | if (node.actor.equals(a)) { 184 | return node; 185 | } 186 | } 187 | } 188 | 189 | return null; 190 | } 191 | 192 | private int getIndex(ActorRef a) { 193 | for (int i = 0; i < servers.length; i++) { 194 | for (PSNode node : servers[i]) { 195 | if (node.actor.equals(a)) { 196 | return i; 197 | } 198 | } 199 | } 200 | 201 | return -1; 202 | } 203 | 204 | private int getIndex(int executorId) { 205 | for (int i = 0; i < servers.length; i++) { 206 | for (PSNode node : servers[i]) { 207 | if (node.executorId == executorId) { 208 | return i; 209 | } 210 | } 211 | } 212 | 213 | return -1; 214 | } 215 | 216 | private int getIndexAsPrimary(ActorRef a) { 217 | for (int i = 0; i < servers.length; i++) { 218 | if (servers[i].size() == 0) continue; 219 | 220 | PSNode node = servers[i].getFirst(); 221 | if (node.actor.equals(a)) { 222 | return i; 223 | } 224 | } 225 | 226 | return -1; 227 | } 228 | 229 | private int getIndexAsPrimary(int executorId) { 230 | for (int i = 0; i < servers.length; i++) { 231 | if (servers[i].size() == 0) continue; 232 | 233 | PSNode node = servers[i].getFirst(); 234 | if (node.executorId == executorId) { 235 | return i; 236 | } 237 | } 238 | 239 | return -1; 240 | } 241 | */ 242 | 243 | private void log(String msg) { 244 | Logger.info(msg, "PSManager"); 245 | } 246 | private void warn(String msg) { 247 | Logger.warn(msg, "PSManager"); 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/platform/PSSync.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform; 2 | 3 | import akka.actor.ActorRef; 4 | import com.intel.distml.api.DMatrix; 5 | import com.intel.distml.api.Model; 6 | import com.intel.distml.util.*; 7 | 8 | import java.io.*; 9 | import java.net.InetSocketAddress; 10 | import java.net.Socket; 11 
| import java.nio.ByteBuffer; 12 | import java.nio.channels.SelectionKey; 13 | import java.nio.channels.Selector; 14 | import java.nio.channels.ServerSocketChannel; 15 | import java.nio.channels.SocketChannel; 16 | import java.util.HashMap; 17 | import java.util.Iterator; 18 | 19 | /** 20 | * Created by yunlong on 1/18/16. 21 | */ 22 | public class PSSync extends Thread { 23 | 24 | static final int ROLE_PRIMARY = 0; 25 | static final int ROLE_STANDBY = 1; 26 | 27 | static final int MAX_SYNC_BLOCK_SIZE = 1000000; 28 | static final int SYNC_INTERVAL = 10000; 29 | 30 | ActorRef owner; 31 | int role; 32 | 33 | boolean running = false; 34 | Socket socket; 35 | DataOutputStream dos; 36 | DataInputStream dis; 37 | 38 | Model model; 39 | HashMap stores; 40 | 41 | public PSSync(ActorRef owner, Model model, HashMap stores) { 42 | this.owner = owner; 43 | this.model = model; 44 | this.stores = stores; 45 | } 46 | 47 | public void asPrimary(Socket socket) { 48 | log("start PS sync as primary"); 49 | this.socket = socket; 50 | role = ROLE_PRIMARY; 51 | try { 52 | dos = new DataOutputStream(socket.getOutputStream()); 53 | dis = new DataInputStream(socket.getInputStream()); 54 | } 55 | catch (Exception e) { 56 | e.printStackTrace(); 57 | } 58 | start(); 59 | } 60 | 61 | public void asStandBy(Socket socket) { 62 | log("start PS sync as standby"); 63 | this.socket = socket; 64 | role = ROLE_STANDBY; 65 | 66 | log("connecting to primvary server"); 67 | DataBusProtocol.SyncRequest req = new DataBusProtocol.SyncRequest(); 68 | try { 69 | dos = new DataOutputStream(socket.getOutputStream()); 70 | dis = new DataInputStream(socket.getInputStream()); 71 | DefaultDataWriter out = new DefaultDataWriter(dos); 72 | out.writeInt(req.sizeAsBytes(model)); 73 | req.write(out, model); 74 | dos.flush(); 75 | log("request to sync"); 76 | } 77 | catch (Exception e) { 78 | e.printStackTrace(); 79 | } 80 | 81 | start(); 82 | } 83 | 84 | public void disconnect() { 85 | running = false; 86 | try { 87 | socket.close(); 88 | } 89 | catch (Exception e) { 90 | } 91 | } 92 | 93 | @Override 94 | public void run() { 95 | running = true; 96 | 97 | try { 98 | if (role == ROLE_PRIMARY) { 99 | syncService(); 100 | } else { 101 | syncClient(); 102 | } 103 | } catch (Exception e) { 104 | if (running) 105 | e.printStackTrace(); 106 | } 107 | } 108 | 109 | private void syncService() throws Exception { 110 | while(running) { 111 | log("new round of sync"); 112 | for (String name : model.dataMap.keySet()) { 113 | byte[] nameBytes = name.getBytes(); 114 | 115 | DataStore store = stores.get(name); 116 | int rows = (int) store.rows().size(); 117 | int maxRowsToSync = MAX_SYNC_BLOCK_SIZE / store.rowSize(); 118 | 119 | int startRow = 0; 120 | int leftRows = rows; 121 | while (leftRows > 0) { 122 | int rowsToSync = (leftRows > maxRowsToSync) ? 
maxRowsToSync : leftRows; 123 | int endRow = rowsToSync + startRow - 1; 124 | 125 | log("sync: " + name + ", " + startRow + ", " + endRow); 126 | dos.writeInt(nameBytes.length); 127 | dos.write(nameBytes); 128 | dos.writeInt(startRow); 129 | dos.writeInt(endRow); 130 | 131 | store.syncTo(dos, startRow, endRow); 132 | startRow += rowsToSync; 133 | leftRows -= rowsToSync; 134 | 135 | try { 136 | Thread.sleep(SYNC_INTERVAL); 137 | } catch (InterruptedException e) { 138 | } 139 | } 140 | } 141 | } 142 | } 143 | 144 | private void syncClient() throws Exception { 145 | while (running) { 146 | log("handling sync ..."); 147 | Utils.waitUntil(dis, 4); 148 | int nameLen = dis.readInt(); 149 | log("name len: " + nameLen); 150 | Utils.waitUntil(dis, nameLen); 151 | byte[] nameBytes = new byte[nameLen]; 152 | dis.read(nameBytes); 153 | String name = new String(nameBytes); 154 | log("name: " + name); 155 | 156 | DataStore store = stores.get(name); 157 | int startRow = dis.readInt(); 158 | int endRow = dis.readInt(); 159 | log("rows: " + startRow + ", " + endRow); 160 | store.syncFrom(dis, startRow, endRow); 161 | } 162 | } 163 | 164 | private void log(String msg) { 165 | Logger.info(msg, "PSAgent"); 166 | } 167 | private void warn(String msg) { 168 | Logger.warn(msg, "PSAgent"); 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/platform/SSP.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * Created by yunlong on 16-4-7. 7 | */ 8 | public class SSP { 9 | 10 | class Worker { 11 | int index; 12 | int progress; 13 | boolean waiting; 14 | 15 | public Worker(int index, int progress) { 16 | this.index = index; 17 | this.progress = progress; 18 | waiting = false; 19 | } 20 | } 21 | 22 | class SortHelper implements Comparator { 23 | 24 | public int compare(Worker o1, Worker o2) { 25 | return o2.progress - o1.progress; 26 | } 27 | } 28 | 29 | 30 | public class CheckResult { 31 | public int index; 32 | public int progress; 33 | public boolean waiting; 34 | public LinkedList workersToNotify; 35 | 36 | public CheckResult(int index, int progress) { 37 | this.index = index; 38 | this.progress = progress; 39 | waiting = false; 40 | 41 | } 42 | 43 | } 44 | 45 | LinkedList workers; 46 | int maxIterations; 47 | int maxLag; 48 | SortHelper sorter = new SortHelper(); 49 | 50 | public SSP(int maxIterations, int maxLag) { 51 | workers = new LinkedList(); 52 | this.maxIterations = maxIterations; 53 | this.maxLag = maxLag; 54 | } 55 | 56 | /** 57 | * set progress and check whether to wait 58 | * 59 | * 60 | * @param worker 61 | * @param iter 62 | * @return 63 | */ 64 | public CheckResult progress(int worker, int iter) { 65 | 66 | CheckResult result = new CheckResult(worker, iter); 67 | Worker w = null; 68 | 69 | boolean found = false; 70 | for (int i = 0; i < workers.size(); i++) { 71 | w = workers.get(i); 72 | if (w.index == worker) { 73 | assert (w.progress == iter - 1); 74 | w.progress++; 75 | found = true; 76 | break; 77 | } 78 | } 79 | if (!found) 80 | workers.add(new Worker(worker, iter)); 81 | 82 | Collections.sort(workers, sorter); 83 | 84 | Worker last = workers.getLast(); 85 | if (worker != last.index) { 86 | if (iter - last.progress >= maxLag) { 87 | result.waiting = true; 88 | w.waiting = true; 89 | //System.out.println("worker " + worker + " is waiting for " + last.index); 90 | } 91 | } 92 | 
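        // Collect any faster workers that were blocked on the old slowest progress and can now resume.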
result.workersToNotify = workersToNotify(w); 93 | 94 | //show(); 95 | 96 | return result; 97 | } 98 | 99 | public LinkedList workersToNotify(Worker p) { 100 | Iterator it = workers.iterator(); 101 | Worker last = workers.getLast(); 102 | 103 | LinkedList workersToNotify = new LinkedList(); 104 | while(it.hasNext()) { 105 | Worker w = it.next(); 106 | if (w == p) break; 107 | 108 | //System.out.println("check notify: " + w.progress + ", " + last.progress); 109 | if (w.progress - last.progress < maxLag) { 110 | if (w.waiting) { 111 | workersToNotify.add(w.index); 112 | w.waiting = false; 113 | } 114 | } 115 | } 116 | 117 | return workersToNotify; 118 | } 119 | 120 | private void show() { 121 | for (int i = 0;i < workers.size(); i++) { 122 | Worker w = workers.get(i); 123 | System.out.print("w[" + i + "]=" + w.progress + "\t"); 124 | } 125 | System.out.println(); 126 | } 127 | 128 | } 129 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/platform/WorkerActor.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform; 2 | 3 | import akka.actor.*; 4 | import akka.japi.Creator; 5 | import com.intel.distml.api.Session; 6 | import com.intel.distml.api.Model; 7 | import com.intel.distml.util.Logger; 8 | 9 | import java.io.Serializable; 10 | 11 | /** 12 | * Created by yunlong on 12/13/14. 13 | */ 14 | public class WorkerActor extends UntypedActor { 15 | 16 | public static final int CMD_DISCONNECT = 1; 17 | public static final int CMD_STOP = 2; 18 | public static final int CMD_PS_TERMINATED = 3; 19 | public static final int CMD_PS_AVAILABLE = 4; 20 | 21 | public static class Progress implements Serializable { 22 | private static final long serialVersionUID = 1L; 23 | 24 | final public int sampleCount; 25 | public Progress(int sampleCount) { 26 | this.sampleCount = sampleCount; 27 | } 28 | } 29 | 30 | public static class Command implements Serializable { 31 | private static final long serialVersionUID = 1L; 32 | 33 | final public int cmd; 34 | public Command(int cmd) { 35 | this.cmd = cmd; 36 | } 37 | } 38 | 39 | public static class PsTerminated extends Command { 40 | private static final long serialVersionUID = 1L; 41 | 42 | final public int index; 43 | public PsTerminated(int index) { 44 | super(CMD_PS_TERMINATED); 45 | this.index = index; 46 | } 47 | } 48 | 49 | public static class PsAvailable extends Command { 50 | private static final long serialVersionUID = 1L; 51 | 52 | final public int index; 53 | final public String addr; 54 | public PsAvailable(int index, String addr) { 55 | super(CMD_PS_AVAILABLE); 56 | this.index = index; 57 | this.addr = addr; 58 | } 59 | } 60 | 61 | public static class RegisterRequest implements Serializable { 62 | private static final long serialVersionUID = 1L; 63 | 64 | final public int workerIndex; 65 | public RegisterRequest(int workerIndex) { 66 | this.workerIndex = workerIndex; 67 | } 68 | } 69 | 70 | public static class AppRequest implements Serializable { 71 | private static final long serialVersionUID = 1L; 72 | 73 | public boolean done; 74 | public AppRequest() { 75 | done = false; 76 | } 77 | 78 | public String toString() { 79 | return "DriverRequest"; 80 | } 81 | } 82 | 83 | public static class IterationDone extends AppRequest { 84 | private static final long serialVersionUID = 1L; 85 | 86 | int iteration; 87 | double cost; 88 | public IterationDone(int iteration, double cost) { 89 | this.iteration = iteration; 90 | this.cost = 
cost; 91 | } 92 | 93 | public String toString() { 94 | return "IterationDone"; 95 | } 96 | } 97 | 98 | private ActorSelection monitor; 99 | private Model model; 100 | private int psCount; 101 | private Session de; 102 | 103 | private String[] psAddrs; 104 | private int workerIndex; 105 | 106 | private AppRequest pendingRequest; 107 | 108 | public WorkerActor(final Session de, Model model, String monitorPath, int workerIndex) { 109 | this.monitor = getContext().actorSelection(monitorPath); 110 | this.workerIndex = workerIndex; 111 | this.model = model; 112 | this.de = de; 113 | 114 | this.monitor.tell(new RegisterRequest(this.workerIndex), getSelf()); 115 | log("Worker " + workerIndex + " register to monitor: " + monitorPath); 116 | } 117 | 118 | public static Props props(final Session de,final Model model, final String monitorPath, final int index) { 119 | return Props.create(new Creator() { 120 | private static final long serialVersionUID = 1L; 121 | public WorkerActor create() throws Exception { 122 | return new WorkerActor(de, model, monitorPath, index); 123 | } 124 | }); 125 | } 126 | 127 | @Override 128 | public void onReceive(Object msg) { 129 | log("onReceive: " + msg); 130 | 131 | if (msg instanceof MonitorActor.RegisterResponse) { 132 | MonitorActor.RegisterResponse res = (MonitorActor.RegisterResponse) msg; 133 | this.psAddrs = res.psAddrs; 134 | psCount = psAddrs.length; 135 | 136 | de.dataBus = new WorkerAgent(model, psAddrs); 137 | } 138 | else if (msg instanceof Progress) { 139 | this.monitor.tell(msg, getSelf()); 140 | } 141 | else if (msg instanceof IterationDone) { 142 | IterationDone req = (IterationDone) msg; 143 | pendingRequest = (AppRequest)msg; 144 | monitor.tell(new MonitorActor.SSP_IterationDone(workerIndex, req.iteration, req.cost), getSelf()); 145 | } 146 | else if (msg instanceof MonitorActor.SSP_IterationNext) { 147 | assert(pendingRequest != null); 148 | pendingRequest.done = true; 149 | } 150 | else if (msg instanceof PsAvailable) { 151 | PsAvailable ps = (PsAvailable) msg; 152 | de.dataBus.psAvailable(ps.index, ps.addr); 153 | } 154 | else unhandled(msg); 155 | } 156 | 157 | private void log(String msg) { 158 | Logger.info(msg, "Worker-" + workerIndex); 159 | } 160 | } 161 | 162 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/AbstractDataReader.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | 5 | /** 6 | * Created by yunlong on 4/28/16. 7 | */ 8 | public interface AbstractDataReader { 9 | 10 | int readInt() throws Exception; 11 | 12 | long readLong() throws Exception; 13 | 14 | float readFloat() throws Exception; 15 | 16 | double readDouble() throws Exception; 17 | 18 | void waitUntil(int size) throws Exception; 19 | 20 | int readBytes(byte[] bytes) throws Exception; 21 | 22 | void readFully(byte[] bytes) throws Exception; 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/AbstractDataWriter.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | /** 4 | * Created by yunlong on 4/28/16. 
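 * Write-side counterpart of AbstractDataReader (ByteBufferDataWriter is the ByteBuffer-backed implementation).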
5 | */ 6 | public interface AbstractDataWriter { 7 | 8 | void writeInt(int value) throws Exception; 9 | 10 | void writeLong(long value) throws Exception; 11 | 12 | void writeFloat(float value) throws Exception; 13 | 14 | void writeDouble(double value) throws Exception; 15 | 16 | void writeBytes(byte[] bytes) throws Exception; 17 | 18 | void flush() throws Exception; 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/ByteBufferDataReader.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.nio.ByteBuffer; 4 | 5 | /** 6 | * Created by yunlong on 4/28/16. 7 | */ 8 | public class ByteBufferDataReader implements AbstractDataReader { 9 | 10 | ByteBuffer buf; 11 | 12 | public ByteBufferDataReader(ByteBuffer buf) { 13 | this.buf = buf; 14 | } 15 | 16 | public int readInt() { 17 | return buf.getInt(); 18 | } 19 | 20 | public long readLong() { 21 | return buf.getLong(); 22 | } 23 | 24 | public float readFloat() { 25 | return buf.getFloat(); 26 | } 27 | 28 | public double readDouble() { 29 | return buf.getDouble(); 30 | } 31 | 32 | public int readBytes(byte[] bytes) throws Exception { 33 | buf.get(bytes); 34 | 35 | return bytes.length; 36 | } 37 | 38 | public void readFully(byte[] bytes) throws Exception { 39 | buf.get(bytes); 40 | } 41 | 42 | 43 | public void waitUntil(int size) throws Exception { 44 | 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/ByteBufferDataWriter.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.nio.ByteBuffer; 4 | 5 | /** 6 | * Created by yunlong on 4/28/16. 7 | */ 8 | public class ByteBufferDataWriter implements AbstractDataWriter { 9 | 10 | 11 | ByteBuffer buf; 12 | 13 | public ByteBufferDataWriter(ByteBuffer buf) { 14 | this.buf = buf; 15 | } 16 | 17 | public void writeInt(int value) throws Exception { 18 | buf.putInt(value); 19 | } 20 | 21 | public void writeLong(long value) throws Exception { 22 | buf.putLong(value); 23 | } 24 | 25 | public void writeFloat(float value) throws Exception { 26 | buf.putFloat(value); 27 | 28 | } 29 | 30 | public void writeDouble(double value) throws Exception { 31 | buf.putDouble(value); 32 | } 33 | 34 | public void writeBytes(byte[] bytes) throws Exception { 35 | buf.put(bytes); 36 | } 37 | 38 | public void flush() throws Exception { 39 | 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/Constants.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import akka.util.Timeout; 4 | import scala.concurrent.duration.FiniteDuration; 5 | 6 | import java.util.concurrent.TimeUnit; 7 | 8 | /** 9 | * Created by taotao on 15-1-30. 10 | */ 11 | public class Constants { 12 | 13 | // Data may be transferred for a long time, and we must ensure that data is successfully transferred, 14 | // so we set a much longer timeout for data future. 
15 | public static final Long DATA_FUTURE_TIME = 1000000L; // Unit: ms, 1000s 16 | public static final Timeout DATA_FUTURE_TIMEOUT = new Timeout(DATA_FUTURE_TIME, TimeUnit.MILLISECONDS); 17 | public static final FiniteDuration DATA_FUTURE_TIMEOUT_DURATION = DATA_FUTURE_TIMEOUT.duration(); 18 | 19 | // Stop should be done in a short period, so we set a shorter timeout for stop future. 20 | public static final Long STOP_FUTURE_TIME = 30000L; // Unit: ms, 30s 21 | public static final Timeout STOP_FUTURE_TIMEOUT = new Timeout(STOP_FUTURE_TIME, TimeUnit.MILLISECONDS); 22 | public static final FiniteDuration STOP_FUTURE_TIMEOUT_DURATION = STOP_FUTURE_TIMEOUT.duration(); 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/DataDesc.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.*; 4 | 5 | /** 6 | * Created by yunlong on 12/30/15. 7 | */ 8 | public final class DataDesc implements Serializable { 9 | 10 | public static final int DATA_TYPE_ARRAY = 0; 11 | public static final int DATA_TYPE_MATRIX = 1; 12 | 13 | public static final int KEY_TYPE_INT = 0; 14 | public static final int KEY_TYPE_LONG = 1; 15 | 16 | public static final int ELEMENT_TYPE_INT = 0; 17 | public static final int ELEMENT_TYPE_FLOAT = 1; 18 | public static final int ELEMENT_TYPE_LONG = 2; 19 | public static final int ELEMENT_TYPE_DOUBLE = 3; 20 | 21 | public int dataType; 22 | public int keyType; 23 | public int valueType; 24 | 25 | public boolean denseRow; 26 | public boolean denseColumn; 27 | public boolean adaGrad; 28 | 29 | public int keySize; 30 | public int valueSize; 31 | 32 | public DataDesc() { 33 | } 34 | 35 | public DataDesc(int dataType, int keyType, int valueType) { 36 | this(dataType, keyType, valueType, false, true); 37 | } 38 | 39 | public DataDesc(int dataType, int keyType, int valueType, boolean denseRow, boolean denseColumn) { 40 | this(dataType, keyType, valueType, denseRow, denseColumn, false); 41 | } 42 | public DataDesc(int dataType, int keyType, int valueType, boolean denseRow, boolean denseColumn, boolean adaGrade) { 43 | this.dataType = dataType; 44 | this.valueType = valueType; 45 | this.keyType = keyType; 46 | this.denseRow = denseRow; 47 | this.denseColumn = denseColumn; 48 | this.adaGrad = adaGrade; 49 | 50 | this.keySize = (keyType == KEY_TYPE_INT)? 4 : 8; 51 | this.valueSize = ((valueType == ELEMENT_TYPE_INT) || (valueType == ELEMENT_TYPE_FLOAT))? 4 : 8; 52 | } 53 | 54 | public String toString() { 55 | return "" + dataType + ", " + keyType + ", " + keySize + ", " + valueType + ", " + valueSize; 56 | } 57 | 58 | public int sizeAsBytes() { 59 | return 24; // keySize and valueSize are calculated in fly 60 | } 61 | 62 | public void write(AbstractDataWriter out) throws Exception { 63 | out.writeInt(dataType); 64 | out.writeInt(keyType); 65 | out.writeInt(valueType); 66 | out.writeInt(denseRow ? 1 : 0); 67 | out.writeInt(denseColumn ? 1 : 0); 68 | out.writeInt(adaGrad ? 1 : 0); 69 | } 70 | 71 | public void read(AbstractDataReader in) throws Exception { 72 | dataType = in.readInt(); 73 | keyType = in.readInt(); 74 | valueType = in.readInt(); 75 | denseRow = in.readInt() == 1; 76 | denseColumn = in.readInt() == 1; 77 | adaGrad = in.readInt() == 1; 78 | 79 | this.keySize = (keyType == KEY_TYPE_INT)? 4 : 8; 80 | this.valueSize = ((valueType == ELEMENT_TYPE_INT) || (valueType == ELEMENT_TYPE_FLOAT))? 
4 : 8; 81 | } 82 | 83 | public Number readKey(AbstractDataReader is) throws Exception { 84 | if (keyType == KEY_TYPE_INT) { 85 | return is.readInt(); 86 | } 87 | else { 88 | return is.readLong(); 89 | } 90 | } 91 | 92 | 93 | public void writeKey(Number v, AbstractDataWriter os) throws Exception { 94 | if (keyType == KEY_TYPE_INT) { 95 | os.writeInt(v.intValue()); 96 | } 97 | else { 98 | os.writeLong(v.longValue()); 99 | } 100 | } 101 | 102 | public Object readValue(AbstractDataReader is) throws Exception { 103 | switch(valueType) { 104 | case ELEMENT_TYPE_INT: 105 | return is.readInt(); 106 | case ELEMENT_TYPE_FLOAT: 107 | return is.readFloat(); 108 | case ELEMENT_TYPE_LONG: 109 | return is.readLong(); 110 | case ELEMENT_TYPE_DOUBLE: 111 | return is.readDouble(); 112 | } 113 | 114 | throw new IllegalStateException("invalid value type: " + valueType); 115 | } 116 | 117 | public void writeValue(Object value, AbstractDataWriter os) throws Exception { 118 | switch(valueType) { 119 | case ELEMENT_TYPE_INT: 120 | os.writeInt((Integer) value); 121 | case ELEMENT_TYPE_FLOAT: 122 | os.writeFloat((Float) value); 123 | case ELEMENT_TYPE_LONG: 124 | os.writeLong((Long) value); 125 | case ELEMENT_TYPE_DOUBLE: 126 | os.writeDouble((Double)value); 127 | } 128 | throw new IllegalStateException("invalid value type: " + valueType); 129 | } 130 | 131 | public Number readKey(byte[] data, int offset) { 132 | if (keyType == KEY_TYPE_INT) { 133 | return readInt(data, offset); 134 | } 135 | else { 136 | return readLong(data, offset); 137 | } 138 | } 139 | 140 | public int writeKey(Number v, byte[] data, int offset) { 141 | if (keyType == KEY_TYPE_INT) { 142 | write(v.intValue(), data, offset); 143 | return offset + 4; 144 | } 145 | else { 146 | write(v.longValue(), data, offset); 147 | return offset + 8; 148 | } 149 | } 150 | 151 | public Object readValue(byte[] buf, int offset) { 152 | switch(valueType) { 153 | case ELEMENT_TYPE_INT: 154 | return readInt(buf, offset); 155 | case ELEMENT_TYPE_FLOAT: 156 | return readFloat(buf, offset); 157 | case ELEMENT_TYPE_LONG: 158 | return readLong(buf, offset); 159 | case ELEMENT_TYPE_DOUBLE: 160 | return readDouble(buf, offset); 161 | } 162 | 163 | throw new IllegalStateException("invalid value type: " + valueType); 164 | } 165 | 166 | public int writeValue(Object value, byte[] buf, int offset) { 167 | switch(valueType) { 168 | case ELEMENT_TYPE_INT: 169 | return write((Integer)value, buf, offset); 170 | case ELEMENT_TYPE_FLOAT: 171 | return write((Float)value, buf, offset); 172 | case ELEMENT_TYPE_LONG: 173 | return write((Long)value, buf, offset); 174 | case ELEMENT_TYPE_DOUBLE: 175 | return write((Double)value, buf, offset); 176 | } 177 | throw new IllegalStateException("invalid value type: " + valueType); 178 | } 179 | 180 | public int readInt(byte[] data, int offset) { 181 | int targets = 182 | (data[offset ] & 0x000000ff) 183 | | ((data[offset+1] << 8) & 0x0000ff00) 184 | | ((data[offset+2] << 16) & 0x00ff0000) 185 | | ((data[offset+3] << 24) & 0xff000000); 186 | 187 | return targets; 188 | } 189 | public float readFloat(byte[] data, int offset) { 190 | int targets = readInt(data, offset); 191 | return Float.intBitsToFloat(targets); 192 | } 193 | 194 | public long readLong(byte[] data, int offset) { 195 | long targets = 196 | (data[offset ] & 0x00000000000000ffL) 197 | | ((data[offset+1] << 8) & 0x000000000000ff00L) 198 | | ((data[offset+2] << 16) & 0x0000000000ff0000L) 199 | | ((data[offset+3] << 24) & 0x00000000ff000000L) 200 | | ((((long)data[offset+4]) << 32) & 
0x000000ff00000000L) 201 | | ((((long)data[offset+5]) << 40) & 0x0000ff0000000000L) 202 | | ((((long)data[offset+6] << 48)) & 0x00ff000000000000L) 203 | | ((((long)data[offset+7] << 56)) & 0xff00000000000000L); 204 | 205 | return targets; 206 | } 207 | public double readDouble(byte[] data, int offset) { 208 | long targets = readLong(data, offset); 209 | double value = Double.longBitsToDouble(targets); 210 | //System.out.println("read double: " + value); 211 | return value; 212 | } 213 | 214 | public int write(double v, byte[] data, int offset) { 215 | long value = Double.doubleToLongBits(v); 216 | //System.out.println("write double: " + v); 217 | return write(value, data, offset); 218 | } 219 | 220 | public int write(long value, byte[] data, int offset) { 221 | data[offset] = (byte) (value & 0xff); 222 | data[offset+1] = (byte) ((value >> 8) & 0xff); 223 | data[offset+2] = (byte) ((value >> 16) & 0xff); 224 | data[offset+3] = (byte) ((value >> 24) & 0xff); 225 | data[offset+4] = (byte) ((value >> 32) & 0xff); 226 | data[offset+5] = (byte) ((value >> 40) & 0xff); 227 | data[offset+6] = (byte) ((value >> 48) & 0xff); 228 | data[offset+7] = (byte) ((value >> 56) & 0xff); 229 | return offset + 8; 230 | } 231 | 232 | public int write(float v, byte[] data, int offset) { 233 | int value = Float.floatToIntBits(v); 234 | return write(value, data, offset); 235 | } 236 | 237 | public int write(int value, byte[] data, int offset) { 238 | data[offset] = (byte) (value & 0xff); 239 | data[offset+1] = (byte) ((value >> 8) & 0xff); 240 | data[offset+2] = (byte) ((value >> 16) & 0xff); 241 | data[offset+3] = (byte) (value >>> 24); 242 | return offset + 4; 243 | } 244 | 245 | } 246 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/DataStore.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import com.intel.distml.api.DMatrix; 4 | import com.intel.distml.api.Model; 5 | import com.intel.distml.util.store.*; 6 | 7 | import java.io.DataInputStream; 8 | import java.io.DataOutputStream; 9 | import java.io.IOException; 10 | import java.io.OutputStream; 11 | import java.util.HashMap; 12 | import java.util.Map; 13 | 14 | /** 15 | * Created by yunlong on 12/8/15. 
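 *
 * Server-side storage for one partition of a DMatrix. Each parameter server keeps one
 * DataStore per model matrix: handleFetch/handlePush serve the raw byte buffers exchanged
 * over the data bus, writeAll/readAll persist the whole partition, and syncTo/syncFrom
 * stream a row range to or from a peer store. createStore picks the concrete store
 * (int/float/double, array vs. matrix, optionally AdaGrad) from the matrix's DataDesc.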
16 | */ 17 | public abstract class DataStore { 18 | 19 | public abstract KeyCollection rows(); 20 | public abstract int rowSize(); 21 | 22 | public void rand() {}; 23 | 24 | public void zero() {}; 25 | 26 | public void set(String value) {}; 27 | 28 | public abstract byte[] handleFetch(DataDesc format, KeyCollection rows); 29 | 30 | public abstract void handlePush(DataDesc format, byte[] data); 31 | 32 | public abstract void writeAll(DataOutputStream os) throws IOException; 33 | 34 | public abstract void readAll(DataInputStream is) throws IOException; 35 | 36 | public abstract void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException; 37 | 38 | public abstract void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException; 39 | 40 | 41 | public static HashMap createStores(Model model, int serverIndex) { 42 | HashMap stores = new HashMap(); 43 | for (Map.Entry m : model.dataMap.entrySet()) { 44 | stores.put(m.getKey(), DataStore.createStore(serverIndex, m.getValue())); 45 | } 46 | 47 | return stores; 48 | } 49 | 50 | public static DataStore createStore(int serverIndex, DMatrix matrix) { 51 | System.out.println("create store: " + serverIndex + ", cols: " + matrix.getColKeys().size()); 52 | DataDesc format = matrix.getFormat(); 53 | if (format.dataType == DataDesc.DATA_TYPE_ARRAY) { 54 | if (format.valueType == DataDesc.ELEMENT_TYPE_INT) { 55 | IntArrayStore store = new IntArrayStore(); 56 | store.init(matrix.partitions[serverIndex]); 57 | return store; 58 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_DOUBLE) { 59 | DoubleArrayStore store = new DoubleArrayStore(); 60 | store.init(matrix.partitions[serverIndex]); 61 | return store; 62 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_FLOAT) { 63 | FloatArrayStore store = new FloatArrayStore(); 64 | store.init(matrix.partitions[serverIndex]); 65 | return store; 66 | } 67 | } 68 | else { 69 | if (format.valueType == DataDesc.ELEMENT_TYPE_INT) { 70 | IntMatrixStore store = new IntMatrixStore(); 71 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size()); 72 | return store; 73 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_DOUBLE) { 74 | DoubleMatrixStore store = new DoubleMatrixStore(); 75 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size()); 76 | return store; 77 | } else if (format.valueType == DataDesc.ELEMENT_TYPE_FLOAT) { 78 | if (format.adaGrad) { 79 | FloatMatrixStoreAdaGrad store = new FloatMatrixStoreAdaGrad(); 80 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size()); 81 | return store; 82 | } 83 | else { 84 | FloatMatrixStore store = new FloatMatrixStore(); 85 | store.init(matrix.partitions[serverIndex], (int) matrix.getColKeys().size()); 86 | return store; 87 | } 88 | } 89 | } 90 | 91 | throw new IllegalArgumentException("Unrecognized matrix type: " + matrix.getClass().getName()); 92 | } 93 | 94 | } 95 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/DefaultDataReader.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.IOException; 5 | import java.nio.ByteBuffer; 6 | 7 | /** 8 | * Created by yunlong on 4/28/16. 
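 *
 * AbstractDataReader backed by a DataInputStream, typically wrapping a socket's input
 * stream. waitUntil(n) spin-waits (with 1 ms sleeps) until at least n bytes are buffered
 * before a framed read. Illustrative sketch only; the stream is a placeholder:
 * <pre>
 *   AbstractDataReader in = new DefaultDataReader(new DataInputStream(inputStream));
 *   in.waitUntil(4);
 *   int count = in.readInt();
 * </pre>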
9 | */ 10 | public class DefaultDataReader implements AbstractDataReader { 11 | 12 | DataInputStream dis; 13 | 14 | public DefaultDataReader(DataInputStream dis) { 15 | this.dis = dis; 16 | } 17 | 18 | public int readInt() throws Exception { 19 | return dis.readInt(); 20 | } 21 | 22 | public long readLong() throws Exception { 23 | return dis.readLong(); 24 | } 25 | 26 | public float readFloat() throws Exception { 27 | return dis.readFloat(); 28 | } 29 | 30 | public double readDouble() throws Exception { 31 | return dis.readDouble(); 32 | } 33 | 34 | public int readBytes(byte[] bytes) throws Exception { 35 | return dis.read(bytes); 36 | } 37 | 38 | public void readFully(byte[] bytes) throws Exception { 39 | dis.readFully(bytes); 40 | } 41 | 42 | public void waitUntil(int size) throws Exception { 43 | while(dis.available() < size) { try { Thread.sleep(1); } catch (Exception e){}} 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/DefaultDataWriter.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataOutputStream; 4 | import java.nio.ByteBuffer; 5 | 6 | /** 7 | * Created by yunlong on 4/28/16. 8 | */ 9 | public class DefaultDataWriter implements AbstractDataWriter { 10 | 11 | 12 | DataOutputStream dos; 13 | 14 | public DefaultDataWriter(DataOutputStream dos) { 15 | this.dos = dos; 16 | } 17 | 18 | public void writeInt(int value) throws Exception { 19 | dos.writeInt(value); 20 | } 21 | 22 | public void writeLong(long value) throws Exception { 23 | dos.writeLong(value); 24 | } 25 | 26 | public void writeFloat(float value) throws Exception { 27 | dos.writeFloat(value); 28 | 29 | } 30 | 31 | public void writeDouble(double value) throws Exception { 32 | dos.writeDouble(value); 33 | } 34 | 35 | public void writeBytes(byte[] bytes) throws Exception { 36 | dos.write(bytes); 37 | } 38 | 39 | public void flush() throws Exception { 40 | dos.flush(); 41 | } 42 | 43 | 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/DoubleArray.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import com.intel.distml.api.DMatrix; 4 | import com.intel.distml.api.Session; 5 | import com.intel.distml.util.KeyCollection; 6 | 7 | import java.util.HashMap; 8 | 9 | /** 10 | * Created by yunlong on 12/8/15. 11 | */ 12 | 13 | public class DoubleArray extends SparseArray { 14 | 15 | public DoubleArray(long dim) { 16 | super(dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_DOUBLE); 17 | } 18 | 19 | } -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/DoubleMatrix.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import com.intel.distml.api.DMatrix; 4 | import com.intel.distml.api.Session; 5 | 6 | import java.util.HashMap; 7 | 8 | /** 9 | * Created by jimmy on 15-12-29. 
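 *
 * SparseMatrix specialization with long row keys and double values; each fetched row
 * arrives as a Double[] of the matrix's column count. Illustrative sketch only
 * (dimensions and the session are placeholders):
 * <pre>
 *   DoubleMatrix weights = new DoubleMatrix(numRows, numCols);
 *   HashMap rows = weights.fetch(new KeyRange(0, 9), session);
 * </pre>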
10 | */ 11 | public class DoubleMatrix extends SparseMatrix { 12 | public DoubleMatrix(long rows, int cols) { 13 | super(rows, cols, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_DOUBLE); 14 | } 15 | 16 | protected boolean isZero(Double value) { 17 | return value == 0.0; 18 | } 19 | protected Double[] createValueArray(int size) { 20 | return new Double [size]; 21 | } 22 | public HashMap fetch(KeyCollection rows, Session session) { 23 | HashMap result; 24 | result = super.fetch(rows, session); 25 | //System.out.println(result.get(1).length); 26 | return result; 27 | } 28 | } 29 | 30 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/IntArray.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import akka.pattern.Patterns; 4 | import com.intel.distml.api.DMatrix; 5 | import com.intel.distml.api.Session; 6 | import com.intel.distml.platform.DataBusProtocol; 7 | import com.intel.distml.util.KeyCollection; 8 | import scala.concurrent.Future; 9 | 10 | import java.util.HashMap; 11 | import java.util.Iterator; 12 | import java.util.LinkedList; 13 | import java.util.Map; 14 | 15 | /** 16 | * Created by yunlong on 12/8/15. 17 | */ 18 | public class IntArray extends SparseArray { 19 | 20 | public IntArray(long dim) { 21 | super(dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_INT); 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/IntArrayWithIntKey.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | /** 4 | * Created by yunlong on 1/2/16. 5 | */ 6 | public class IntArrayWithIntKey extends SparseArray { 7 | 8 | public IntArrayWithIntKey(long dim) { 9 | super(dim, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_INT); 10 | } 11 | 12 | } -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/IntMatrix.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import com.intel.distml.api.DMatrix; 4 | import com.intel.distml.api.Session; 5 | import com.intel.distml.util.KeyCollection; 6 | 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | /** 11 | * Created by yunlong on 12/8/15. 12 | */ 13 | public class IntMatrix extends SparseMatrix { 14 | 15 | public IntMatrix(long rows, int cols) { 16 | super(rows, cols, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_INT); 17 | } 18 | 19 | protected boolean isZero(Integer value) { 20 | return value == 0; 21 | } 22 | 23 | protected Integer[] createValueArray(int size) { 24 | return new Integer[size]; 25 | } 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/IntMatrixWithIntKey.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | /** 4 | * Created by yunlong on 12/8/15. 
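 *
 * Same layout as IntMatrix, but row keys are declared as KEY_TYPE_INT so they are
 * serialized as 4-byte ints instead of 8-byte longs, halving the per-row key overhead
 * on the wire.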
5 | */ 6 | public class IntMatrixWithIntKey extends SparseMatrix { 7 | 8 | public IntMatrixWithIntKey(long rows, int cols) { 9 | super(rows, cols, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_INT); 10 | } 11 | 12 | protected boolean isZero(Integer value) { 13 | return value == 0; 14 | } 15 | 16 | protected Integer[] createValueArray(int size) { 17 | return new Integer[size]; 18 | } 19 | 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/KeyCollection.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.IOException; 6 | import java.io.Serializable; 7 | import java.util.Iterator; 8 | 9 | /** 10 | * Created by yunlong on 12/11/14. 11 | */ 12 | public abstract class KeyCollection implements Serializable { 13 | 14 | public static final int TYPE_ALL = 0; 15 | public static final int TYPE_EMPTY = 1; 16 | public static final int TYPE_RANGE = 2; 17 | public static final int TYPE_LIST = 3; 18 | public static final int TYPE_HASH = 4; 19 | 20 | public static final KeyCollection EMPTY = new EmptyKeys(); 21 | public static final KeyCollection SINGLE = new KeyRange(0, 0); 22 | public static final KeyCollection ALL = new AllKeys(); 23 | 24 | public int type; 25 | 26 | public KeyCollection(int type) { 27 | this.type = type; 28 | } 29 | 30 | public int sizeAsBytes(DataDesc format) { 31 | return 4; 32 | } 33 | 34 | public void write(AbstractDataWriter out, DataDesc format) throws Exception { 35 | out.writeInt(type); 36 | } 37 | 38 | public void read(AbstractDataReader in, DataDesc format) throws Exception { 39 | } 40 | 41 | public static KeyCollection readKeyCollection(AbstractDataReader in, DataDesc format) throws Exception { 42 | in.waitUntil(4); 43 | int type = in.readInt(); 44 | KeyCollection ks; 45 | switch(type) { 46 | case TYPE_ALL: 47 | return ALL; 48 | case TYPE_EMPTY: 49 | return EMPTY; 50 | 51 | case TYPE_LIST: 52 | ks = new KeyList(); 53 | break; 54 | 55 | case TYPE_RANGE: 56 | ks = new KeyRange(); 57 | break; 58 | 59 | default: 60 | ks = new KeyHash(); 61 | break; 62 | 63 | } 64 | 65 | in.waitUntil(ks.sizeAsBytes(format)); 66 | ks.read(in, format); 67 | 68 | return ks; 69 | } 70 | 71 | public abstract boolean contains(long key); 72 | 73 | public abstract Iterator iterator(); 74 | 75 | public abstract boolean isEmpty(); 76 | 77 | public abstract long size(); 78 | 79 | public KeyCollection intersect(KeyCollection keys) { 80 | 81 | if (keys.equals(KeyCollection.ALL)) { 82 | return this; 83 | } 84 | 85 | if (keys.equals(KeyCollection.EMPTY)) { 86 | return keys; 87 | } 88 | 89 | KeyList result = new KeyList(); 90 | 91 | Iterator it = keys.iterator(); 92 | while(it.hasNext()) { 93 | long key = it.next(); 94 | if (contains(key)) { 95 | result.addKey(key); 96 | } 97 | } 98 | 99 | return result; 100 | } 101 | 102 | public static class EmptyKeys extends KeyCollection { 103 | 104 | public EmptyKeys() { 105 | super(TYPE_EMPTY); 106 | } 107 | 108 | @Override 109 | public boolean equals(Object obj) { 110 | return (obj instanceof EmptyKeys); 111 | } 112 | 113 | @Override 114 | public boolean contains(long key) { 115 | return false; 116 | } 117 | 118 | @Override 119 | public KeyCollection intersect(KeyCollection keys) { 120 | return this; 121 | } 122 | 123 | @Override 124 | public Iterator iterator() { 125 | return new Iterator() { 126 | public boolean hasNext() { 
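                    // the empty key set has nothing to iterate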
return false; } 127 | public Long next() { return -1L; } 128 | public void remove() { } 129 | }; 130 | } 131 | 132 | @Override 133 | public boolean isEmpty() { 134 | return true; 135 | } 136 | /* 137 | @Override 138 | public PartitionInfo partitionEqually(int hostNum) { 139 | throw new UnsupportedOperationException("This is an EMPTY_KEYS instance, not partitionable"); 140 | } 141 | 142 | @Override 143 | public KeyCollection[] split(int hostNum) { 144 | throw new UnsupportedOperationException("This is an EMPTY_KEYS instance, not partitionable"); 145 | } 146 | */ 147 | @Override 148 | public long size() { 149 | return 0; 150 | } 151 | }; 152 | 153 | public static class AllKeys extends KeyCollection { 154 | 155 | public AllKeys() { 156 | super(TYPE_ALL); 157 | } 158 | 159 | 160 | @Override 161 | public boolean equals(Object obj) { 162 | return (obj instanceof AllKeys); 163 | } 164 | 165 | @Override 166 | public boolean contains(long key) { 167 | return true; 168 | } 169 | 170 | @Override 171 | public KeyCollection intersect(KeyCollection keys) { 172 | return keys; 173 | } 174 | 175 | @Override 176 | public Iterator iterator() { 177 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, not iterable."); 178 | } 179 | 180 | @Override 181 | public boolean isEmpty() { 182 | return false; 183 | } 184 | /* 185 | @Override 186 | public PartitionInfo partitionEqually(int hostNum) { 187 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, not partitionable"); 188 | } 189 | 190 | @Override 191 | public KeyCollection[] split(int hostNum) { 192 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, not partitionable"); 193 | } 194 | */ 195 | @Override 196 | public long size() { 197 | throw new UnsupportedOperationException("This is an ALL_KEYS instance, size unknown"); 198 | } 199 | }; 200 | } 201 | 202 | 203 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/KeyHash.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.IOException; 6 | import java.util.Iterator; 7 | 8 | /** 9 | * Created by yunlong on 12/11/14. 
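 *
 * One hash partition of a key range: it contains every key k in [minKey, maxKey] with
 * k % hashQuato == hashIndex. The constructor caches the smallest (first) and largest
 * (last) such keys so that size() and iteration need no scanning. For example, with
 * hashQuato = 4 and hashIndex = 1 over [0, 9], the partition is {1, 5, 9}.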
10 | */ 11 | public class KeyHash extends KeyCollection { 12 | 13 | public int hashQuato; 14 | public int hashIndex; 15 | 16 | public long minKey; 17 | public long maxKey; 18 | //public long totalKeyNum; 19 | 20 | private long first, last; 21 | 22 | KeyHash() { 23 | super(KeyCollection.TYPE_HASH); 24 | } 25 | 26 | public KeyHash(int hashQuato, int hashIndex, long minKey, long maxKey) { 27 | super(KeyCollection.TYPE_HASH); 28 | 29 | this.hashQuato = hashQuato; 30 | this.hashIndex = hashIndex; 31 | this.minKey = minKey; 32 | this.maxKey = maxKey; 33 | 34 | first = minKey - minKey%hashQuato + hashIndex; 35 | if (first < minKey) { 36 | first += hashQuato; 37 | } 38 | last = maxKey - maxKey%hashQuato + hashIndex; 39 | if (last > maxKey) { 40 | last -= hashQuato; 41 | } 42 | } 43 | 44 | @Override 45 | public boolean equals(Object obj) { 46 | if (!(obj instanceof KeyHash)) { 47 | return false; 48 | } 49 | 50 | KeyHash o = (KeyHash)obj; 51 | return (hashQuato == o.hashQuato) && (hashIndex == o.hashIndex) && (minKey == o.minKey) && (maxKey == o.maxKey); 52 | } 53 | 54 | @Override 55 | public int sizeAsBytes(DataDesc format) { 56 | return super.sizeAsBytes(format) + 8 + 4 * format.keySize; 57 | } 58 | 59 | @Override 60 | public void write(AbstractDataWriter out, DataDesc format) throws Exception { 61 | super.write(out, format); 62 | 63 | out.writeInt(hashQuato); 64 | out.writeInt(hashIndex); 65 | 66 | if (format.keyType == DataDesc.KEY_TYPE_INT) { 67 | out.writeInt((int)minKey); 68 | out.writeInt((int)maxKey); 69 | out.writeInt((int)first); 70 | out.writeInt((int)last); 71 | } 72 | else { 73 | out.writeLong(minKey); 74 | out.writeLong(maxKey); 75 | out.writeLong(first); 76 | out.writeLong(last); 77 | } 78 | } 79 | 80 | @Override 81 | public void read(AbstractDataReader in, DataDesc format) throws Exception { 82 | //super.read(in); 83 | hashQuato = in.readInt(); 84 | hashIndex = in.readInt(); 85 | if (format.keyType == DataDesc.KEY_TYPE_INT) { 86 | minKey = in.readInt(); 87 | maxKey = in.readInt(); 88 | first = in.readInt(); 89 | last = in.readInt(); 90 | } 91 | else { 92 | minKey = in.readLong(); 93 | maxKey = in.readLong(); 94 | first = in.readLong(); 95 | last = in.readLong(); 96 | } 97 | } 98 | 99 | 100 | public long size() { 101 | return (last - first) / hashQuato + 1; 102 | //return (totalKeyNum /hashQuato) + (((totalKeyNum % hashQuato) > hashIndex)? 
1 : 0); 103 | } 104 | 105 | @Override 106 | public boolean contains(long key) { 107 | if ((key > last) || (key < first)) { 108 | //throw new RuntimeException("unexpected key: " + key + " >= " + totalKeyNum); 109 | return false; 110 | } 111 | 112 | //System.out.println("check contains: " + key); 113 | return key % hashQuato == hashIndex; 114 | } 115 | 116 | @Override 117 | public boolean isEmpty() { 118 | return size() == 0; 119 | } 120 | 121 | @Override 122 | public String toString() { 123 | return "[KeyHash: quato=" + hashQuato + ", index=" + hashIndex + ", min=" + minKey + ", max=" + maxKey + "]"; 124 | } 125 | 126 | @Override 127 | public KeyCollection intersect(KeyCollection keys) { 128 | 129 | if (keys.equals(KeyCollection.ALL)) { 130 | return this; 131 | } 132 | 133 | if (keys.equals(KeyCollection.EMPTY)) { 134 | return keys; 135 | } 136 | 137 | if (keys instanceof KeyRange) { 138 | KeyRange r = (KeyRange) keys; 139 | if ((r.firstKey <= first) && (r.lastKey >= last)) { 140 | return this; 141 | } 142 | 143 | KeyHash newKeys = new KeyHash(hashQuato, hashIndex, Math.max(minKey, r.firstKey), Math.min(maxKey, r.lastKey)); 144 | return newKeys; 145 | } 146 | 147 | KeyList list = new KeyList(); 148 | Iterator it = keys.iterator(); 149 | while(it.hasNext()) { 150 | long key = it.next(); 151 | if (contains(key)) { 152 | list.addKey(key); 153 | } 154 | } 155 | 156 | if (list.isEmpty()) { 157 | return KeyCollection.EMPTY; 158 | } 159 | 160 | return list; 161 | } 162 | 163 | 164 | @Override 165 | public Iterator iterator() { 166 | return new _Iterator(this); 167 | } 168 | 169 | static class _Iterator implements Iterator { 170 | 171 | long currentKey; 172 | KeyHash keys; 173 | 174 | public _Iterator(KeyHash keys) { 175 | this.keys = keys; 176 | this.currentKey = keys.first; 177 | } 178 | 179 | public boolean hasNext() { 180 | return currentKey <= keys.last; 181 | } 182 | 183 | public Long next() { 184 | long k = currentKey; 185 | currentKey += keys.hashQuato; 186 | return k; 187 | } 188 | 189 | public void remove() { 190 | throw new RuntimeException("Not supported."); 191 | } 192 | 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/KeyList.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.IOException; 6 | import java.util.HashSet; 7 | import java.util.Iterator; 8 | 9 | /** 10 | * Created by yunlong on 2/1/15. 
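 *
 * An explicit, unordered set of keys backed by a HashSet; used when a selection of rows
 * is neither a contiguous range nor a hash partition (for example, the result of
 * intersecting two arbitrary KeyCollections). Serialized as a count followed by the keys.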
11 | */ 12 | public class KeyList extends KeyCollection { 13 | 14 | public HashSet keys; 15 | 16 | public KeyList() { 17 | super(KeyCollection.TYPE_LIST); 18 | keys = new HashSet(); 19 | } 20 | 21 | public long size() { 22 | return keys.size(); 23 | } 24 | 25 | @Override 26 | public int sizeAsBytes(DataDesc format) { 27 | return super.sizeAsBytes(format) + 4 + keys.size() * format.keySize; 28 | } 29 | 30 | @Override 31 | public void write(AbstractDataWriter out, DataDesc format) throws Exception { 32 | super.write(out, format); 33 | 34 | out.writeInt(keys.size()); 35 | if (format.keyType == DataDesc.KEY_TYPE_INT) { 36 | for (long key : keys) 37 | out.writeInt((int)key); 38 | } 39 | else { 40 | for (long key : keys) 41 | out.writeLong(key); 42 | } 43 | } 44 | 45 | @Override 46 | public void read(AbstractDataReader in, DataDesc format) throws Exception { 47 | 48 | int count = in.readInt(); 49 | if (format.keyType == DataDesc.KEY_TYPE_INT) { 50 | for (int i = 0; i < count; i++) 51 | keys.add(new Long(in.readInt())); 52 | } 53 | else { 54 | for (int i = 0; i < count; i++) 55 | keys.add(in.readLong()); 56 | } 57 | } 58 | 59 | public void addKey(long k) { 60 | keys.add(k); 61 | } 62 | 63 | @Override 64 | public boolean contains(long key) { 65 | return keys.contains(key); 66 | } 67 | 68 | @Override 69 | public boolean isEmpty() { 70 | return keys.isEmpty(); 71 | } 72 | 73 | @Override 74 | public KeyCollection intersect(KeyCollection keys) { 75 | KeyList list = new KeyList(); 76 | 77 | Iterator it = keys.iterator(); 78 | while(it.hasNext()) { 79 | long k = it.next(); 80 | if (contains(k)) { 81 | list.keys.add(k); 82 | } 83 | } 84 | 85 | return list; 86 | } 87 | 88 | @Override 89 | public Iterator iterator() { 90 | return keys.iterator(); 91 | } 92 | 93 | @Override 94 | public String toString() { 95 | String s = "[KeyList: size=" + keys.size(); 96 | if (keys.size() > 0) { 97 | s += " first=" + keys.iterator().next(); 98 | } 99 | s += "]"; 100 | 101 | return s; 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/KeyRange.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.IOException; 6 | import java.util.Iterator; 7 | 8 | /** 9 | * Created by yunlong on 12/11/14. 
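 *
 * A contiguous, inclusive key range [firstKey, lastKey]. A model dimension is typically
 * described by one KeyRange and then partitioned across parameter servers, either into
 * contiguous sub-ranges (linearSplit) or into interleaved KeyHash partitions (hashSplit).
 * For example, new KeyRange(0, 9).linearSplit(2) yields [0, 4] and [5, 9].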
10 | */ 11 | public class KeyRange extends KeyCollection { 12 | 13 | public static final KeyRange Single = new KeyRange(0, 0); 14 | 15 | public long firstKey, lastKey; 16 | 17 | KeyRange() { 18 | super(KeyCollection.TYPE_RANGE); 19 | } 20 | 21 | public KeyRange(long f, long l) { 22 | super(KeyCollection.TYPE_RANGE); 23 | firstKey = f; 24 | lastKey = l; 25 | } 26 | 27 | @Override 28 | public boolean equals(Object obj) { 29 | if (!(obj instanceof KeyRange)) { 30 | return false; 31 | } 32 | 33 | KeyRange range = (KeyRange)obj; 34 | return (firstKey == range.firstKey) && (lastKey == range.lastKey); 35 | } 36 | 37 | @Override 38 | public int sizeAsBytes(DataDesc format) { 39 | return super.sizeAsBytes(format) + 2 * format.keySize; 40 | } 41 | 42 | @Override 43 | public void write(AbstractDataWriter out, DataDesc format) throws Exception { 44 | super.write(out, format); 45 | 46 | if (format.keyType == DataDesc.KEY_TYPE_INT) { 47 | out.writeInt((int)firstKey); 48 | out.writeInt((int)lastKey); 49 | } 50 | else { 51 | out.writeLong(firstKey); 52 | out.writeLong(lastKey); 53 | } 54 | } 55 | 56 | @Override 57 | public void read(AbstractDataReader in, DataDesc format) throws Exception { 58 | if (format.keyType == DataDesc.KEY_TYPE_INT) { 59 | firstKey = in.readInt(); 60 | lastKey = in.readInt(); 61 | } 62 | else { 63 | firstKey = in.readLong(); 64 | lastKey = in.readLong(); 65 | } 66 | } 67 | 68 | public KeyCollection[] linearSplit(int hostNum) { 69 | KeyCollection[] sets = new KeyRange[hostNum]; 70 | 71 | long start = firstKey; 72 | long step = (lastKey - firstKey + hostNum) / hostNum; 73 | for (int i = 0; i < hostNum; i++) { 74 | long end = Math.min(start + step - 1, lastKey); 75 | sets[i] = new KeyRange(start, end); 76 | start += step; 77 | } 78 | 79 | return sets; 80 | } 81 | 82 | public KeyCollection[] hashSplit(int hostNum) { 83 | KeyCollection[] sets = new KeyHash[hostNum]; 84 | 85 | for (int i = 0; i < hostNum; i++) { 86 | sets[i] = new KeyHash(hostNum, i, firstKey, lastKey); 87 | } 88 | 89 | return sets; 90 | } 91 | 92 | public long size() { 93 | return lastKey - firstKey + 1; 94 | } 95 | 96 | @Override 97 | public boolean contains(long key) { 98 | return (key >= firstKey) && (key <= lastKey); 99 | } 100 | 101 | @Override 102 | public boolean isEmpty() { 103 | return firstKey > lastKey; 104 | } 105 | 106 | public boolean containsAll(KeyCollection keys) { 107 | KeyRange keyRange = (KeyRange)keys; 108 | return (keyRange.lastKey >= firstKey && keyRange.firstKey <= lastKey); 109 | } 110 | 111 | @Override 112 | public String toString() { 113 | return "[" + firstKey + ", " + lastKey + "]"; 114 | } 115 | 116 | public KeyRange FetchSame(KeyRange kr) { 117 | long NewFirst = kr.firstKey > this.firstKey ? kr.firstKey : this.firstKey; 118 | long NewLast = kr.lastKey < this.lastKey ? 
kr.lastKey : this.lastKey; 119 | if (NewFirst > NewLast) return null; 120 | return new KeyRange(NewFirst, NewLast); 121 | } 122 | 123 | @Override 124 | public KeyCollection intersect(KeyCollection keys) { 125 | 126 | if (keys instanceof KeyRange) { 127 | KeyRange r = (KeyRange)keys; 128 | long min = Math.max(r.firstKey, firstKey); 129 | long max = Math.min(r.lastKey, lastKey); 130 | 131 | if (min > max) { 132 | return KeyCollection.EMPTY; 133 | } 134 | 135 | return new KeyRange(min, max); 136 | } 137 | 138 | if (keys instanceof KeyHash) { 139 | // warning: hope KeyHash don't call back 140 | return keys.intersect(this); 141 | } 142 | 143 | return super.intersect(keys); 144 | } 145 | 146 | public boolean mergeFrom(KeyRange keys) { 147 | if ((keys.firstKey > lastKey) && (keys.lastKey < firstKey)) { 148 | return false; 149 | } 150 | 151 | firstKey = Math.min(keys.firstKey, firstKey); 152 | lastKey = Math.max(keys.lastKey, lastKey); 153 | return true; 154 | } 155 | 156 | 157 | @Override 158 | public Iterator iterator() { 159 | return new _Iterator(this); 160 | } 161 | 162 | static class _Iterator implements Iterator { 163 | 164 | long index; 165 | KeyRange range; 166 | 167 | public _Iterator(KeyRange range) { 168 | this.range = range; 169 | this.index = range.firstKey; 170 | } 171 | 172 | public boolean hasNext() { 173 | return index <= range.lastKey; 174 | } 175 | 176 | public Long next() { 177 | return index++; 178 | } 179 | 180 | public void remove() { 181 | throw new RuntimeException("Not supported."); 182 | } 183 | 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/Logger.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.text.SimpleDateFormat; 4 | import java.util.Date; 5 | 6 | /** 7 | * Created by taotao on 15-1-14. 8 | * 9 | * Simple log system. 
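 *
 * Messages below the static output level (INFO by default) are dropped; everything else
 * goes to stdout with a level, timestamp and module tag. Typical call:
 * Logger.info("iteration done", "Worker-0");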
10 | */ 11 | public class Logger { 12 | 13 | private final static String[] levels = {"NOTUSED", "DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"}; 14 | 15 | private final static int DEBUG = 1; 16 | private final static int INFO = 2; 17 | private final static int WARN = 3; 18 | private final static int ERROR = 4; 19 | private final static int CRITICAL = 5; 20 | 21 | private static SimpleDateFormat DATE_FORMAT = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss.SSS"); 22 | 23 | private static int outputLevel = INFO; 24 | 25 | public static void Log(int level, String msg, String module) { 26 | if (level < outputLevel) return; 27 | 28 | String typeStr = levels[level]; 29 | System.out.println("==== [" + typeStr + "] [" + DATE_FORMAT.format(new Date()) + "] " + 30 | "[" + module + "] " + msg); 31 | } 32 | 33 | public static void debug(String msg, String module) { 34 | Log(DEBUG, msg, module); 35 | } 36 | 37 | public static void info(String msg, String module) { 38 | Log(INFO, msg, module); 39 | } 40 | 41 | public static void warn(String msg, String module) { 42 | Log(WARN, msg, module); 43 | } 44 | 45 | public static void error(String msg, String module) { 46 | Log(ERROR, msg, module); 47 | } 48 | 49 | public static void critical(String msg, String module) { 50 | Log(CRITICAL, msg, module); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/SparseArray.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import com.intel.distml.api.DMatrix; 4 | import com.intel.distml.api.Session; 5 | 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by yunlong on 12/31/15. 11 | */ 12 | public abstract class SparseArray extends DMatrix { 13 | 14 | public SparseArray(long dim, int keyType, int valueType) { 15 | super(dim); 16 | format = new DataDesc(DataDesc.DATA_TYPE_ARRAY, keyType, valueType); 17 | } 18 | 19 | public HashMap fetch(KeyCollection rows, Session session) { 20 | HashMap result = new HashMap(); 21 | byte[][] data = session.dataBus.fetch(name, rows, format); 22 | for (byte[] obj : data) { 23 | HashMap m = readMap(obj); 24 | result.putAll(m); 25 | } 26 | 27 | return result; 28 | } 29 | 30 | public void push(HashMap data, Session session) { 31 | byte[][] bufs = new byte[partitions.length][]; 32 | for (int i = 0; i < partitions.length; ++i) { 33 | KeyCollection p = partitions[i]; 34 | HashMap m = new HashMap(); 35 | for (K key : data.keySet()) { 36 | long k = (key instanceof Integer)? 
((Integer)key).longValue() : ((Long)key).longValue(); 37 | if (p.contains(k)) 38 | m.put(key, data.get(key)); 39 | } 40 | bufs[i] = writeMap(m); 41 | } 42 | 43 | session.dataBus.push(name, format, bufs); 44 | } 45 | 46 | private HashMap readMap(byte[] buf) { 47 | HashMap data = new HashMap(); 48 | 49 | int offset = 0; 50 | while (offset < buf.length) { 51 | K key = (K) format.readKey(buf, offset); 52 | offset += format.keySize; 53 | 54 | V value = (V) format.readValue(buf, offset); 55 | offset += format.valueSize; 56 | 57 | data.put(key, value); 58 | } 59 | 60 | return data; 61 | } 62 | 63 | private byte[] writeMap(HashMap data) { 64 | int recordLen = format.keySize + format.valueSize; 65 | byte[] buf = new byte[recordLen * data.size()]; 66 | 67 | int offset = 0; 68 | for (Map.Entry entry : data.entrySet()) { 69 | format.writeKey((Number) entry.getKey(), buf, offset); 70 | offset += format.keySize; 71 | 72 | offset = format.writeValue(entry.getValue(), buf, offset); 73 | } 74 | 75 | return buf; 76 | } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/SparseMatrix.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import com.intel.distml.api.DMatrix; 4 | import com.intel.distml.api.Session; 5 | 6 | import java.lang.reflect.Array; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | /** 11 | * Created by yunlong on 12/31/15. 12 | */ 13 | public abstract class SparseMatrix extends DMatrix { 14 | 15 | protected KeyRange colKeys; 16 | 17 | public SparseMatrix(long dim, int cols, int keyType, int valueType) { 18 | super(dim); 19 | format = new DataDesc(DataDesc.DATA_TYPE_MATRIX, keyType, valueType); 20 | colKeys = new KeyRange(0, cols-1); 21 | } 22 | 23 | public SparseMatrix(long dim, int cols, int keyType, int valueType, 24 | boolean denseColumn) { 25 | super(dim); 26 | format = new DataDesc(DataDesc.DATA_TYPE_MATRIX, keyType, valueType, false, denseColumn); 27 | colKeys = new KeyRange(0, cols-1); 28 | } 29 | 30 | public KeyCollection getColKeys() { 31 | return colKeys; 32 | } 33 | 34 | 35 | public HashMap fetch(KeyCollection rows, Session session) { 36 | HashMap result = new HashMap(); 37 | byte[][] data = session.dataBus.fetch(name, rows, format); 38 | for (byte[] obj : data) { 39 | HashMap m = readMap(obj); 40 | result.putAll(m); 41 | } 42 | 43 | return result; 44 | } 45 | 46 | public void push(HashMap data, Session session) { 47 | byte[][] bufs = new byte[partitions.length][]; 48 | for (int i = 0; i < partitions.length; ++i) { 49 | KeyCollection p = partitions[i]; 50 | HashMap m = new HashMap(); 51 | for (K key : data.keySet()) { 52 | long k = (key instanceof Integer)? 
((Integer)key).longValue() : ((Long)key).longValue(); 53 | if (p.contains(k)) 54 | m.put(key, data.get(key)); 55 | } 56 | bufs[i] = writeMap(m); 57 | } 58 | 59 | session.dataBus.push(name, format, bufs); 60 | } 61 | 62 | private HashMap readMap(byte[] buf) { 63 | HashMap data = new HashMap(); 64 | 65 | int offset = 0; 66 | while (offset < buf.length) { 67 | K key = (K) format.readKey(buf, offset); 68 | offset += format.keySize; 69 | 70 | V[] values = createValueArray((int) getColKeys().size()); 71 | if (format.denseColumn) { 72 | for (int i = 0; i < getColKeys().size(); i++) { 73 | values[i] = (V) format.readValue(buf, offset); 74 | offset += format.valueSize; 75 | } 76 | } 77 | else { 78 | int count = format.readInt(buf, offset); 79 | offset += 4; 80 | 81 | for (int i = 0; i < count; i++) { 82 | int index = format.readInt(buf, offset); 83 | offset += 4; 84 | V value = (V) format.readValue(buf, offset); 85 | offset += format.valueSize; 86 | 87 | values[index] = value; 88 | } 89 | } 90 | data.put(key, values); 91 | } 92 | 93 | return data; 94 | } 95 | 96 | private byte[] writeMap(HashMap data) { 97 | 98 | byte[] buf; 99 | if (format.denseColumn) { 100 | int len = (int) (format.valueSize * data.size() * colKeys.size()); 101 | buf = new byte[format.keySize * data.size() + len]; 102 | } 103 | else { 104 | int nzcount = 0; 105 | for (V[] values : data.values()) { 106 | for (V value : values) { 107 | if (!isZero(value)) { 108 | nzcount++; 109 | } 110 | } 111 | } 112 | int len = (int) ((format.valueSize + 4) * nzcount); 113 | buf = new byte[format.keySize * data.size() + len]; 114 | } 115 | 116 | int offset = 0; 117 | for (Map.Entry entry : data.entrySet()) { 118 | format.writeKey((Number)entry.getKey(), buf, offset); 119 | offset += format.keySize; 120 | V[] values = entry.getValue(); 121 | if (format.denseColumn) { 122 | for (int i = 0; i < getColKeys().size(); i++) { 123 | format.writeValue(values[i], buf, offset); 124 | offset += format.valueSize; 125 | } 126 | } 127 | else { 128 | int counterIndex = offset; 129 | offset += 4; 130 | 131 | int counter = 0; 132 | for (int i = 0; i < values.length; i++) { 133 | V value = values[i]; 134 | if (!isZero(value)) { 135 | format.write(i, buf, offset); 136 | offset += 4; 137 | format.writeValue(value, buf, offset); 138 | offset += format.valueSize; 139 | } 140 | 141 | counter++; 142 | } 143 | format.write(counter, buf, counterIndex); 144 | } 145 | } 146 | 147 | return buf; 148 | } 149 | 150 | abstract protected boolean isZero(V value); 151 | abstract protected V[] createValueArray(int size); 152 | } 153 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/Utils.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.IOException; 5 | import java.net.Inet4Address; 6 | import java.net.InetAddress; 7 | import java.net.NetworkInterface; 8 | import java.net.SocketException; 9 | import java.util.Date; 10 | import java.util.Enumeration; 11 | 12 | /** 13 | * Created by spark on 7/2/14. 
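 *
 * getNetworkAddress scans the local network interfaces for an IPv4 address, preferring
 * one whose host address starts with the given prefix (or any non-loopback address when
 * the prefix is null), and returns { ipAddress, hostName }.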
14 | */ 15 | public class Utils { 16 | 17 | public static String[] getNetworkAddress(String networkPrefix) throws SocketException { 18 | String[] addr = new String[2]; 19 | 20 | Enumeration allNetInterfaces = NetworkInterface.getNetworkInterfaces(); 21 | InetAddress ip = null; 22 | while (allNetInterfaces.hasMoreElements()) 23 | { 24 | NetworkInterface netInterface = (NetworkInterface) allNetInterfaces.nextElement(); 25 | Enumeration addresses = netInterface.getInetAddresses(); 26 | while (addresses.hasMoreElements()) 27 | { 28 | ip = (InetAddress) addresses.nextElement(); 29 | if (ip != null && ip instanceof Inet4Address) 30 | { 31 | if (networkPrefix != null) { 32 | if (ip.getHostAddress().startsWith(networkPrefix)) { 33 | addr[0] = ip.getHostAddress(); 34 | addr[1] = ip.getHostName(); 35 | System.out.println("Server address = " + addr[0] + ", " + addr[1]); 36 | } 37 | } 38 | else { 39 | if (!ip.getHostAddress().startsWith("127")) { 40 | addr[0] = ip.getHostAddress(); 41 | addr[1] = ip.getHostName(); 42 | System.out.println("Server address = " + addr[0] + ", " + addr[1]); 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | return addr; 50 | } 51 | 52 | public static void waitUntil(DataInputStream is, int size) throws IOException { 53 | while(is.available() < size) { try { Thread.sleep(1); } catch (Exception e){}} 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/store/DoubleArrayStore.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.store; 2 | 3 | import com.intel.distml.util.*; 4 | 5 | import java.io.DataInputStream; 6 | import java.io.DataOutputStream; 7 | import java.io.IOException; 8 | import java.util.HashMap; 9 | import java.util.Iterator; 10 | import java.util.Map; 11 | 12 | /** 13 | * Created by yunlong on 12/8/15. 
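 *
 * Dense double[] backing one partition of a parameter array on a server. indexOf maps a
 * global key to a local slot for either a KeyRange or a KeyHash partition, handleFetch
 * serializes the requested keys and values, and handlePush adds the received updates to
 * the stored values (updates are deltas, not replacements).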
14 | */ 15 | public class DoubleArrayStore extends DataStore { 16 | 17 | public static final int VALUE_SIZE = 8; 18 | 19 | transient KeyCollection localRows; 20 | transient double[] localData; 21 | 22 | public KeyCollection rows() { 23 | return localRows; 24 | } 25 | 26 | public int rowSize() { 27 | return 1; 28 | } 29 | 30 | public void init(KeyCollection keys) { 31 | this.localRows = keys; 32 | localData = new double[(int)keys.size()]; 33 | for (int i = 0; i < keys.size(); i++) 34 | localData[i] = 0.0; 35 | } 36 | 37 | public int indexOf(long key) { 38 | if (localRows instanceof KeyRange) { 39 | return (int) (key - ((KeyRange)localRows).firstKey); 40 | } 41 | else if (localRows instanceof KeyHash) { 42 | KeyHash hash = (KeyHash) localRows; 43 | return (int) ((key - hash.minKey) % hash.hashQuato); 44 | } 45 | 46 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 47 | } 48 | 49 | public long keyOf(int index) { 50 | if (localRows instanceof KeyRange) { 51 | return ((KeyRange)localRows).firstKey + index; 52 | } 53 | else if (localRows instanceof KeyHash) { 54 | KeyHash hash = (KeyHash) localRows; 55 | return hash.minKey + index * hash.hashQuato; 56 | } 57 | 58 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 59 | } 60 | 61 | @Override 62 | public void writeAll(DataOutputStream os) throws IOException { 63 | System.out.println("saving to: " + localData.length); 64 | for (int i = 0; i < localData.length; i++) { 65 | os.writeDouble(localData[i]); 66 | } 67 | } 68 | 69 | @Override 70 | public void readAll(DataInputStream is) throws IOException { 71 | for (int i = 0; i < localData.length; i++) { 72 | localData[i] = is.readDouble(); 73 | } 74 | } 75 | 76 | @Override 77 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException { 78 | for (int i = fromRow; i <= toRow; i++) { 79 | os.writeDouble(localData[i]); 80 | } 81 | System.out.println("sync done"); 82 | } 83 | 84 | @Override 85 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException { 86 | for (int i = fromRow; i <= toRow; i++) { 87 | localData[i] = is.readDouble(); 88 | } 89 | System.out.println("sync done"); 90 | } 91 | 92 | @Override 93 | public byte[] handleFetch(DataDesc format, KeyCollection rows) { 94 | 95 | KeyCollection keys = localRows.intersect(rows); 96 | 97 | int len = (int) ((format.keySize + VALUE_SIZE) * keys.size()); 98 | byte[] buf = new byte[len]; 99 | 100 | Iterator it = keys.iterator(); 101 | int offset = 0; 102 | while(it.hasNext()) { 103 | long k = it.next(); 104 | 105 | format.writeKey((Number) k, buf, offset); 106 | offset += format.keySize; 107 | 108 | double value = localData[indexOf(k)]; 109 | format.writeValue(value, buf, offset); 110 | offset += VALUE_SIZE; 111 | } 112 | return buf; 113 | } 114 | 115 | public void handlePush(DataDesc format, byte[] data) { 116 | 117 | int offset = 0; 118 | while (offset < data.length) { 119 | long key = format.readKey(data, offset).longValue(); 120 | offset += format.keySize; 121 | 122 | double update = format.readDouble(data, offset); 123 | offset += VALUE_SIZE; 124 | 125 | localData[indexOf(key)] += update; 126 | } 127 | } 128 | 129 | public Iter iter() { 130 | return new Iter(); 131 | } 132 | 133 | public class Iter { 134 | 135 | int p; 136 | 137 | public Iter() { 138 | p = -1; 139 | } 140 | 141 | public boolean hasNext() { 142 | return p < localData.length - 1; 143 | } 144 | 145 | public long key() { 146 | return keyOf(p); 147 | } 148 | 149 | public 
double value() { 150 | return localData[p]; 151 | } 152 | 153 | public boolean next() { 154 | p++; 155 | return p < localData.length; 156 | } 157 | } 158 | 159 | } 160 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/store/DoubleMatrixStore.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.store; 2 | 3 | import com.intel.distml.util.*; 4 | 5 | import java.io.DataInputStream; 6 | import java.io.DataOutputStream; 7 | import java.io.IOException; 8 | import java.util.Iterator; 9 | import java.util.Random; 10 | 11 | /** 12 | * Created by jimmy on 16-1-4. 13 | */ 14 | public class DoubleMatrixStore extends DataStore { 15 | public static final int VALUE_SIZE = 8; 16 | 17 | transient KeyCollection localRows; 18 | transient int rowSize; 19 | transient double[][] localData; 20 | 21 | public KeyCollection rows() { 22 | return localRows; 23 | } 24 | public int rowSize() { 25 | return rowSize; 26 | } 27 | 28 | public void init(KeyCollection rows, int cols) { 29 | this.localRows = rows; 30 | this.rowSize = cols; 31 | 32 | localData = new double[(int)localRows.size()][rowSize]; 33 | 34 | Runtime r = Runtime.getRuntime(); 35 | System.out.println("memory: " + r.freeMemory() + ", " + r.totalMemory() + ", needed: " + localRows.size() * cols); 36 | for (int i = 0; i < localRows.size(); i++) 37 | for (int j = 0; j < rowSize; j++) 38 | localData[i][j] = 0.0; 39 | 40 | } 41 | 42 | @Override 43 | public void writeAll(DataOutputStream os) throws IOException { 44 | for (int i = 0; i < localData.length; i++) { 45 | for (int j = 0; j < rowSize; j++) { 46 | os.writeDouble(localData[i][j]); 47 | } 48 | } 49 | } 50 | 51 | @Override 52 | public void readAll(DataInputStream is) throws IOException { 53 | for (int i = 0; i < localData.length; i++) { 54 | for (int j = 0; j < rowSize; j++) { 55 | localData[i][j] = is.readDouble(); 56 | } 57 | } 58 | } 59 | 60 | @Override 61 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException { 62 | for (int i = fromRow; i <= toRow; i++) { 63 | for (int j = 0; j < rowSize; j++) { 64 | os.writeDouble(localData[i][j]); 65 | } 66 | } 67 | } 68 | 69 | @Override 70 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException { 71 | int rowSize = (int) localRows.size(); 72 | for (int i = fromRow; i <= toRow; i++) { 73 | for (int j = 0; j < rowSize; j++) { 74 | localData[i][j] = is.readDouble(); 75 | } 76 | } 77 | } 78 | 79 | @Override 80 | public byte[] handleFetch(DataDesc format, KeyCollection rows) { 81 | 82 | KeyCollection keys = localRows.intersect(rows); 83 | byte[] buf; 84 | if (format.denseColumn) { 85 | int keySpace = (int) (format.keySize * keys.size()); 86 | int valueSpace = (int) (VALUE_SIZE * keys.size() * localData[0].length); 87 | buf = new byte[keySpace + valueSpace]; 88 | } 89 | else { 90 | int nzcount = 0; 91 | Iterator it = keys.iterator(); 92 | while (it.hasNext()) { 93 | long k = it.next(); 94 | double[] values = localData[indexOf(k)]; 95 | for (int i = 0; i < values.length; i++) { 96 | if (values[i] != 0.0) { 97 | nzcount++; 98 | } 99 | } 100 | } 101 | int len = (VALUE_SIZE + 4) * nzcount; 102 | buf = new byte[format.keySize * (int)keys.size() + len]; 103 | } 104 | 105 | Iterator it = keys.iterator(); 106 | int offset = 0; 107 | while(it.hasNext()) { 108 | long k = it.next(); 109 | format.writeKey((Number)k, buf, offset); 110 | offset += format.keySize; 111 | 112 | double[] values 
= localData[indexOf(k)]; 113 | if (format.denseColumn) { 114 | for (int i = 0; i < values.length; i++) { 115 | format.writeValue(values[i], buf, offset); 116 | offset += VALUE_SIZE; 117 | } 118 | } 119 | else { 120 | int counterIndex = offset; 121 | offset += 4; 122 | 123 | int counter = 0; 124 | for (int i = 0; i < values.length; i++) { 125 | if (values[i] != 0) { 126 | format.write(i, buf, offset); 127 | offset += 4; 128 | format.write(values[i], buf, offset); 129 | offset += VALUE_SIZE; 130 | } 131 | 132 | counter++; 133 | } 134 | format.write(counter, buf, counterIndex); 135 | } 136 | } 137 | 138 | return buf; 139 | } 140 | 141 | public int indexOf(long key) { 142 | if (localRows instanceof KeyRange) { 143 | return (int) (key - ((KeyRange)localRows).firstKey); 144 | } 145 | else if (localRows instanceof KeyHash) { 146 | KeyHash hash = (KeyHash) localRows; 147 | return (int) ((key - hash.minKey) % hash.hashQuato); 148 | } 149 | 150 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 151 | } 152 | 153 | public void handlePush(DataDesc format, byte[] data) { 154 | 155 | int offset = 0; 156 | while (offset < data.length) { 157 | long key = format.readKey(data, offset).longValue(); 158 | offset += format.keySize; 159 | offset = updateRow(key, data, offset, format); 160 | } 161 | } 162 | 163 | private int updateRow(long key, byte[] data, int start, DataDesc format) { 164 | assert(localRows.contains(key)); 165 | 166 | int index = indexOf(key); 167 | double[] row = localData[index]; 168 | int offset = start; 169 | if (format.denseColumn) { 170 | for (int i = 0; i < row.length; i++) { 171 | double update = format.readDouble(data, offset); 172 | row[i] += update; 173 | offset += VALUE_SIZE; 174 | } 175 | } 176 | else { 177 | int count = format.readInt(data, offset); 178 | offset += 4; 179 | for (int i = 0; i < count; i++) { 180 | int col = format.readInt(data, offset); 181 | offset += 4; 182 | assert(col < row.length); 183 | 184 | double update = format.readDouble(data, offset); 185 | row[col] += update; 186 | offset += VALUE_SIZE; 187 | } 188 | } 189 | 190 | return offset; 191 | } 192 | public void rand() { 193 | long seed = 1L; 194 | Random random = new Random(seed); 195 | int cols = this.localData[0].length; 196 | for (int i = 0; i < this.localRows.size(); i++) { 197 | double sum = 0.0; 198 | for (int j = 0; j < cols; j++) { 199 | localData[i][j] = Math.abs(random.nextGaussian()); 200 | sum += localData[i][j] * localData[i][j]; 201 | } 202 | sum = Math.sqrt(sum); 203 | for (int j=0; j it = keys.iterator(); 96 | int offset = 0; 97 | while(it.hasNext()) { 98 | long k = it.next(); 99 | 100 | format.writeKey((Number) k, buf, offset); 101 | offset += format.keySize; 102 | 103 | float value = localData[indexOf(k)]; 104 | format.writeValue(value, buf, offset); 105 | offset += VALUE_SIZE; 106 | } 107 | return buf; 108 | } 109 | 110 | public void handlePush(DataDesc format, byte[] data) { 111 | 112 | int offset = 0; 113 | while (offset < data.length) { 114 | long key = format.readKey(data, offset).longValue(); 115 | offset += format.keySize; 116 | 117 | float update = format.readFloat(data, offset); 118 | offset += VALUE_SIZE; 119 | 120 | localData[indexOf(key)] += update; 121 | } 122 | } 123 | 124 | public Iter iter() { 125 | return new Iter(); 126 | } 127 | 128 | public class Iter { 129 | 130 | int p; 131 | 132 | public Iter() { 133 | p = -1; 134 | } 135 | 136 | public boolean hasNext() { 137 | return p < localData.length - 1; 138 | } 139 | 140 | public long key() 
{ 141 | return keyOf(p); 142 | } 143 | 144 | public float value() { 145 | return localData[p]; 146 | } 147 | 148 | public boolean next() { 149 | p++; 150 | return p < localData.length; 151 | } 152 | } 153 | 154 | } 155 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/store/FloatMatrixStore.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.store; 2 | 3 | import com.intel.distml.util.*; 4 | 5 | import java.io.DataInputStream; 6 | import java.io.DataOutputStream; 7 | import java.io.IOException; 8 | import java.util.Iterator; 9 | import java.util.Random; 10 | 11 | /** 12 | * Created by yunlong on 1/3/16. 13 | */ 14 | public class FloatMatrixStore extends DataStore { 15 | public static final int VALUE_SIZE = 4; 16 | 17 | transient KeyCollection localRows; 18 | transient int rowSize; 19 | transient float[][] localData; 20 | 21 | public KeyCollection rows() { 22 | return localRows; 23 | } 24 | public int rowSize() { 25 | return rowSize; 26 | } 27 | 28 | public void init(KeyCollection keys, int cols) { 29 | this.localRows = keys; 30 | this.rowSize = cols; 31 | 32 | localData = new float[(int)localRows.size()][rowSize]; 33 | 34 | for (int i = 0; i < localRows.size(); i++) 35 | for (int j = 0; j < rowSize; j++) 36 | localData[i][j] = 0.0f; 37 | } 38 | 39 | public void rand() { 40 | System.out.println("init with random values"); 41 | 42 | int rows = (int) localRows.size(); 43 | 44 | Random r = new Random(); 45 | for (int i = 0; i < rows; i++) { 46 | for (int j = 0; j < rowSize; j++) { 47 | int a = r.nextInt(100); 48 | localData[i][j] = (a / 100.0f - 0.5f) / rowSize; 49 | } 50 | } 51 | } 52 | 53 | public void set(String value) { 54 | setValue(Float.parseFloat(value)); 55 | } 56 | 57 | public void zero(String value) { 58 | setValue(0f); 59 | } 60 | 61 | private void setValue(float v) { 62 | System.out.println("init with value: " + v); 63 | 64 | int rows = (int) localRows.size(); 65 | 66 | for (int i = 0; i < rows; i++) { 67 | for (int j = 0; j < rowSize; j++) { 68 | localData[i][j] = v; 69 | } 70 | } 71 | } 72 | 73 | @Override 74 | public void writeAll(DataOutputStream os) throws IOException { 75 | 76 | for (int i = 0; i < localData.length; i++) { 77 | for (int j = 0; j < rowSize; j++) { 78 | os.writeFloat(localData[i][j]); 79 | } 80 | } 81 | } 82 | 83 | @Override 84 | public void readAll(DataInputStream is) throws IOException { 85 | 86 | for (int i = 0; i < localData.length; i++) { 87 | for (int j = 0; j < rowSize; j++) { 88 | localData[i][j] = is.readFloat(); 89 | } 90 | } 91 | } 92 | 93 | @Override 94 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException { 95 | for (int i = fromRow; i <= toRow; i++) { 96 | for (int j = 0; j < rowSize; j++) { 97 | os.writeFloat(localData[i][j]); 98 | } 99 | } 100 | } 101 | 102 | @Override 103 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException { 104 | int rowSize = (int) localRows.size(); 105 | for (int i = fromRow; i <= toRow; i++) { 106 | for (int j = 0; j < rowSize; j++) { 107 | localData[i][j] = is.readFloat(); 108 | } 109 | } 110 | } 111 | 112 | @Override 113 | public byte[] handleFetch(DataDesc format, KeyCollection rows) { 114 | 115 | System.out.println("handle fetch request: " + rows); 116 | KeyCollection keys = localRows.intersect(rows); 117 | byte[] buf; 118 | if (format.denseColumn) { 119 | int keySpace = (int) (format.keySize * keys.size()); 120 | int 
valueSpace = (int) (VALUE_SIZE * keys.size() * localData[0].length); 121 | buf = new byte[keySpace + valueSpace]; 122 | System.out.println("buf size: " + buf.length); 123 | } 124 | else { 125 | int nzcount = 0; 126 | Iterator it = keys.iterator(); 127 | while (it.hasNext()) { 128 | long k = it.next(); 129 | float[] values = localData[indexOf(k)]; 130 | for (int i = 0; i < values.length; i++) { 131 | if (values[i] != 0.0) { 132 | nzcount++; 133 | } 134 | } 135 | } 136 | int len = (VALUE_SIZE + 4) * nzcount; 137 | buf = new byte[format.keySize * (int)keys.size() + len]; 138 | } 139 | 140 | Iterator it = keys.iterator(); 141 | int offset = 0; 142 | while(it.hasNext()) { 143 | long k = it.next(); 144 | format.writeKey((Number)k, buf, offset); 145 | offset += format.keySize; 146 | 147 | float[] values = localData[indexOf(k)]; 148 | if (format.denseColumn) { 149 | for (int i = 0; i < values.length; i++) { 150 | format.writeValue(values[i], buf, offset); 151 | offset += VALUE_SIZE; 152 | } 153 | } 154 | else { 155 | int counterIndex = offset; 156 | offset += 4; 157 | 158 | int counter = 0; 159 | for (int i = 0; i < values.length; i++) { 160 | if (values[i] != 0) { 161 | format.write(i, buf, offset); 162 | offset += 4; 163 | format.write(values[i], buf, offset); 164 | offset += VALUE_SIZE; 165 | } 166 | 167 | counter++; 168 | } 169 | format.write(counter, buf, counterIndex); 170 | } 171 | } 172 | 173 | return buf; 174 | } 175 | 176 | public int indexOf(long key) { 177 | if (localRows instanceof KeyRange) { 178 | return (int) (key - ((KeyRange)localRows).firstKey); 179 | } 180 | else if (localRows instanceof KeyHash) { 181 | KeyHash hash = (KeyHash) localRows; 182 | return (int) ((key - hash.minKey) % hash.hashQuato); 183 | } 184 | 185 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 186 | } 187 | 188 | public long keyOf(int index) { 189 | if (localRows instanceof KeyRange) { 190 | return ((KeyRange)localRows).firstKey + index; 191 | } 192 | else if (localRows instanceof KeyHash) { 193 | KeyHash hash = (KeyHash) localRows; 194 | return hash.minKey + index * hash.hashQuato; 195 | } 196 | 197 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 198 | } 199 | 200 | public void handlePush(DataDesc format, byte[] data) { 201 | 202 | int offset = 0; 203 | while (offset < data.length) { 204 | long key = format.readKey(data, offset).longValue(); 205 | offset += format.keySize; 206 | offset = updateRow(key, data, offset, format); 207 | } 208 | } 209 | 210 | private int updateRow(long key, byte[] data, int start, DataDesc format) { 211 | assert(localRows.contains(key)); 212 | 213 | int index = indexOf(key); 214 | float[] row = localData[index]; 215 | int offset = start; 216 | if (format.denseColumn) { 217 | for (int i = 0; i < row.length; i++) { 218 | float update = format.readFloat(data, offset); 219 | row[i] += update; 220 | offset += VALUE_SIZE; 221 | } 222 | } 223 | else { 224 | int count = format.readInt(data, offset); 225 | offset += 4; 226 | for (int i = 0; i < count; i++) { 227 | int col = format.readInt(data, offset); 228 | offset += 4; 229 | assert(col < row.length); 230 | 231 | float update = format.readFloat(data, offset); 232 | row[col] += update; 233 | offset += VALUE_SIZE; 234 | } 235 | } 236 | 237 | return offset; 238 | } 239 | 240 | 241 | public Iter iter() { 242 | return new Iter(); 243 | } 244 | 245 | public class Iter { 246 | 247 | int p; 248 | 249 | public Iter() { 250 | p = -1; 251 | } 252 | 253 | public boolean 
hasNext() { 254 | return p < localData.length - 1; 255 | } 256 | 257 | public long key() { 258 | return keyOf(p); 259 | } 260 | 261 | public float[] value() { 262 | return localData[p]; 263 | } 264 | 265 | public boolean next() { 266 | p++; 267 | return p < localData.length; 268 | } 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/store/IntArrayStore.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.store; 2 | 3 | import com.intel.distml.util.*; 4 | 5 | import java.io.DataInputStream; 6 | import java.io.DataOutputStream; 7 | import java.io.IOException; 8 | import java.util.HashMap; 9 | import java.util.Iterator; 10 | import java.util.Map; 11 | 12 | /** 13 | * Created by yunlong on 12/8/15. 14 | */ 15 | public class IntArrayStore extends DataStore { 16 | 17 | public static final int VALUE_SIZE = 4; 18 | 19 | transient KeyCollection localRows; 20 | transient int[] localData; 21 | 22 | public KeyCollection rows() { 23 | return localRows; 24 | } 25 | public int rowSize() { 26 | return 1; 27 | } 28 | 29 | public void init(KeyCollection keys) { 30 | this.localRows = keys; 31 | localData = new int[(int)keys.size()]; 32 | 33 | for (int i = 0; i < keys.size(); i++) 34 | localData[i] = 0; 35 | } 36 | 37 | public int indexOf(long key) { 38 | if (localRows instanceof KeyRange) { 39 | return (int) (key - ((KeyRange)localRows).firstKey); 40 | } 41 | else if (localRows instanceof KeyHash) { 42 | KeyHash hash = (KeyHash) localRows; 43 | return (int) ((key - hash.minKey) % hash.hashQuato); 44 | } 45 | 46 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 47 | } 48 | 49 | @Override 50 | public void writeAll(DataOutputStream os) throws IOException { 51 | for (int i = 0; i < localData.length; i++) { 52 | os.writeInt(localData[i]); 53 | } 54 | } 55 | 56 | @Override 57 | public void readAll(DataInputStream is) throws IOException { 58 | for (int i = 0; i < localData.length; i++) { 59 | localData[i] = is.readInt(); 60 | } 61 | } 62 | 63 | @Override 64 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException { 65 | for (int i = fromRow; i <= toRow; i++) { 66 | os.writeInt(localData[i]); 67 | } 68 | } 69 | 70 | @Override 71 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException { 72 | for (int i = fromRow; i <= toRow; i++) { 73 | localData[i] = is.readInt(); 74 | } 75 | } 76 | 77 | @Override 78 | public byte[] handleFetch(DataDesc format, KeyCollection rows) { 79 | 80 | KeyCollection keys = localRows.intersect(rows); 81 | int len = (int) ((format.keySize + VALUE_SIZE) * keys.size()); 82 | byte[] buf = new byte[len]; 83 | 84 | Iterator it = keys.iterator(); 85 | int offset = 0; 86 | while(it.hasNext()) { 87 | long k = it.next(); 88 | format.writeKey((Number) k, buf, offset); 89 | offset += format.keySize; 90 | format.writeValue(localData[indexOf(k)], buf, offset); 91 | offset += VALUE_SIZE; 92 | } 93 | 94 | return buf; 95 | } 96 | 97 | public void handlePush(DataDesc format, byte[] data) { 98 | 99 | int offset = 0; 100 | while (offset < data.length) { 101 | long key = format.readKey(data, offset).longValue(); 102 | offset += format.keySize; 103 | 104 | int update = format.readInt(data, offset); 105 | offset += VALUE_SIZE; 106 | 107 | localData[indexOf(key)] += update; 108 | if (localData[indexOf(key)] < 0) { 109 | throw new IllegalStateException("invalid k 
counter: " + key + ", " + localData[indexOf(key)]); 110 | } 111 | 112 | } 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /src/main/java/com/intel/distml/util/store/IntMatrixStore.java: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.store; 2 | 3 | import com.intel.distml.util.*; 4 | 5 | import java.io.DataInputStream; 6 | import java.io.DataOutputStream; 7 | import java.io.IOException; 8 | import java.util.HashMap; 9 | import java.util.Iterator; 10 | import java.util.Map; 11 | 12 | /** 13 | * Created by yunlong on 12/8/15. 14 | */ 15 | public class IntMatrixStore extends DataStore { 16 | 17 | public static final int VALUE_SIZE = 4; 18 | 19 | transient KeyCollection localRows; 20 | transient int rowSize; 21 | transient int[][] localData; 22 | 23 | public KeyCollection rows() { 24 | return localRows; 25 | } 26 | public int rowSize() { 27 | return rowSize; 28 | } 29 | 30 | public void init(KeyCollection keys, int cols) { 31 | this.localRows = keys; 32 | this.rowSize = cols; 33 | localData = new int[(int)keys.size()][cols]; 34 | 35 | Runtime r = Runtime.getRuntime(); 36 | System.out.println("memory: " + r.freeMemory() + ", " + r.totalMemory() + ", needed: " + localRows.size() * rowSize); 37 | for (int i = 0; i < localRows.size(); i++) 38 | for (int j = 0; j < rowSize; j++) 39 | localData[i][j] = 0; 40 | 41 | } 42 | 43 | @Override 44 | public void writeAll(DataOutputStream os) throws IOException { 45 | for (int i = 0; i < localData.length; i++) { 46 | for (int j = 0; j < rowSize; j++) { 47 | os.writeInt(localData[i][j]); 48 | } 49 | } 50 | } 51 | 52 | @Override 53 | public void readAll(DataInputStream is) throws IOException { 54 | for (int i = 0; i < localData.length; i++) { 55 | for (int j = 0; j < rowSize; j++) { 56 | localData[i][j] = is.readInt(); 57 | } 58 | } 59 | } 60 | 61 | @Override 62 | public void syncTo(DataOutputStream os, int fromRow, int toRow) throws IOException { 63 | for (int i = fromRow; i <= toRow; i++) { 64 | for (int j = 0; j < rowSize; j++) { 65 | os.writeInt(localData[i][j]); 66 | } 67 | } 68 | } 69 | 70 | @Override 71 | public void syncFrom(DataInputStream is, int fromRow, int toRow) throws IOException { 72 | int rowSize = (int) localRows.size(); 73 | for (int i = fromRow; i <= toRow; i++) { 74 | for (int j = 0; j < rowSize; j++) { 75 | localData[i][j] = is.readInt(); 76 | } 77 | } 78 | } 79 | 80 | @Override 81 | public byte[] handleFetch(DataDesc format, KeyCollection rows) { 82 | 83 | KeyCollection keys = localRows.intersect(rows); 84 | byte[] buf; 85 | if (format.denseColumn) { 86 | int keySpace = (int) (format.keySize * keys.size()); 87 | int valueSpace = (int) (VALUE_SIZE * keys.size() * localData[0].length); 88 | buf = new byte[keySpace + valueSpace]; 89 | } 90 | else { 91 | int nzcount = 0; 92 | Iterator it = keys.iterator(); 93 | while (it.hasNext()) { 94 | long k = it.next(); 95 | int[] values = localData[indexOf(k)]; 96 | for (int i = 0; i < values.length; i++) { 97 | if (values[i] != 0) { 98 | nzcount++; 99 | } 100 | } 101 | } 102 | int len = (VALUE_SIZE + 4) * nzcount; 103 | buf = new byte[format.keySize * (int)keys.size() + len]; 104 | } 105 | 106 | Iterator it = keys.iterator(); 107 | int offset = 0; 108 | while(it.hasNext()) { 109 | long k = it.next(); 110 | format.writeKey((Number)k, buf, offset); 111 | offset += format.keySize; 112 | 113 | int[] values = localData[indexOf(k)]; 114 | if (format.denseColumn) { 115 | for (int i = 
0; i < values.length; i++) { 116 | format.writeValue(values[i], buf, offset); 117 | offset += VALUE_SIZE; 118 | } 119 | } 120 | else { 121 | int counterIndex = offset; 122 | offset += 4; 123 | 124 | int counter = 0; 125 | for (int i = 0; i < values.length; i++) { 126 | if (values[i] != 0) { 127 | format.write(i, buf, offset); 128 | offset += 4; 129 | format.write(values[i], buf, offset); 130 | offset += VALUE_SIZE; 131 | } 132 | 133 | counter++; 134 | } 135 | format.write(counter, buf, counterIndex); 136 | } 137 | } 138 | 139 | return buf; 140 | } 141 | 142 | public int indexOf(long key) { 143 | if (localRows instanceof KeyRange) { 144 | return (int) (key - ((KeyRange)localRows).firstKey); 145 | } 146 | else if (localRows instanceof KeyHash) { 147 | KeyHash hash = (KeyHash) localRows; 148 | return (int) ((key - hash.minKey) % hash.hashQuato); 149 | } 150 | 151 | throw new RuntimeException("Only KeyRange or KeyHash is allowed in server storage"); 152 | } 153 | 154 | public void handlePush(DataDesc format, byte[] data) { 155 | 156 | int offset = 0; 157 | while (offset < data.length) { 158 | long key = format.readKey(data, offset).longValue(); 159 | offset += format.keySize; 160 | offset = updateRow(key, data, offset, format); 161 | } 162 | } 163 | 164 | private int updateRow(long key, byte[] data, int start, DataDesc format) { 165 | assert(localRows.contains(key)); 166 | 167 | int index = indexOf(key); 168 | int[] row = localData[index]; 169 | int offset = start; 170 | if (format.denseColumn) { 171 | for (int i = 0; i < row.length; i++) { 172 | int update = format.readInt(data, offset); 173 | row[i] += update; 174 | if (row[i] < 0) { 175 | throw new IllegalStateException("invalid counter: " + key + ", " + i + ", " + row[i]); 176 | } 177 | offset += 4; 178 | } 179 | } 180 | else { 181 | int count = format.readInt(data, offset); 182 | offset += 4; 183 | for (int i = 0; i < count; i++) { 184 | int col = format.readInt(data, offset); 185 | offset += 4; 186 | assert(col < row.length); 187 | 188 | int update = format.readInt(data, offset); 189 | row[col] += update; 190 | offset += 4; 191 | } 192 | } 193 | 194 | return offset; 195 | } 196 | 197 | } 198 | -------------------------------------------------------------------------------- /src/main/main.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/Dict.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml 2 | 3 | import java.io.Serializable 4 | 5 | import scala.collection.immutable 6 | 7 | /** 8 | * Created by yunlong on 12/23/15. 
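 *
 * Bidirectional word/id dictionary shared by the LDA and Word2Vec examples:
 * `word2id` and `id2word` are kept in sync by `put` and `addWord`.
 *
 * A minimal usage sketch (only the methods defined below are used):
 *   val dic = new Dict()
 *   val id  = dic.addWord("spark")   // adds the word if absent, returns its id
 *   dic.getWord(id)                  // => "spark"
 *   dic.contains("spark")            // => true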
9 | */ 10 | class Dict extends Serializable { 11 | 12 | var word2id = new immutable.HashMap[String, Int] 13 | var id2word = new immutable.HashMap[Int, String] 14 | 15 | def getSize: Int = { 16 | return word2id.size 17 | } 18 | 19 | def getWord(id: Int): String = { 20 | return id2word.get(id).get 21 | } 22 | 23 | def getID(word: String): Integer = { 24 | return word2id.get(word).get 25 | } 26 | 27 | /** 28 | * check if this dictionary contains a specified word 29 | */ 30 | def contains(word: String): Boolean = { 31 | return word2id.contains(word) 32 | } 33 | 34 | def contains(id: Int): Boolean = { 35 | return id2word.contains(id) 36 | } 37 | 38 | def put(word : String, id : Int): Unit = { 39 | word2id += (word -> id) 40 | id2word += (id -> word) 41 | } 42 | 43 | /** 44 | * add a word into this dictionary 45 | * return the corresponding id 46 | */ 47 | def addWord(word: String): Int = { 48 | if (!contains(word)) { 49 | val id: Int = word2id.size 50 | word2id += (word -> id) 51 | id2word += (id -> word) 52 | return id 53 | } 54 | else return getID(word) 55 | } 56 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/clustering/AliasTable.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.clustering 2 | 3 | import java.util 4 | import java.util.Random 5 | 6 | /** 7 | * Created by yunlong on 12/19/15. 8 | */ 9 | class AliasTable (K: Int) extends Serializable { 10 | 11 | var L: Array[Int] = new Array[Int](K) 12 | var H: Array[Int] = new Array[Int](K) 13 | var P: Array[Float] = new Array[Float](K) 14 | 15 | def sampleAlias(gen: Random): Int = { 16 | val s = gen.nextFloat() * K 17 | var t = s.toInt 18 | if (t >= K) t = K - 1 19 | if (K * P(t) < (s-t)) L(t) else H(t) 20 | } 21 | 22 | def init(probs : Array[(Int, Float)], mass : Float): Unit = { 23 | 24 | val m = 1.0 / K 25 | val lq = new util.LinkedList[(Int, Float)]() 26 | val hq = new util.LinkedList[(Int, Float)]() 27 | 28 | for ((k, _p) <- probs) { 29 | val p = _p / mass 30 | if (p < m) lq.add((k, p)) 31 | else hq.add((k, p)) 32 | 33 | } 34 | 35 | var index = 0 36 | while (!lq.isEmpty & !hq.isEmpty) { 37 | val (l, pl) = lq.removeFirst() 38 | val (h, ph) = hq.removeFirst() 39 | L(index) = l 40 | H(index) = h 41 | P(index) = pl 42 | val pd = ph - (m - pl) 43 | if (pd >= m) hq.add((h, pd.toFloat)) 44 | else lq.add((h, pd.toFloat)) 45 | index += 1 46 | } 47 | 48 | while (!hq.isEmpty) { 49 | val (h, ph) = hq.removeFirst() 50 | L(index) = h 51 | H(index) = h 52 | P(index) = ph 53 | index += 1 54 | } 55 | 56 | while (!lq.isEmpty) { 57 | val (l, pl) = lq.removeLast() 58 | L(index) = l 59 | H(index) = l 60 | P(index) = pl 61 | index += 1 62 | } 63 | } 64 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/clustering/LDAModel.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.clustering 2 | 3 | import com.intel.distml.api.Model 4 | import com.intel.distml.util.{IntMatrixWithIntKey, IntArrayWithIntKey} 5 | 6 | /** 7 | * Created by yunlong on 16-3-29. 
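 *
 * Parameter-server model for LDA: registers a K-sized "doc-topics" counter array
 * and a V x K "word-topics" count matrix, and derives alpha_sum / beta_sum from
 * alpha, beta, K and V.
 *
 * A hedged construction sketch (constructor arguments as declared below):
 *   val model = new LDAModel(V = vocabularySize, K = 20, alpha = 0.01, beta = 0.01)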
8 | */ 9 | class LDAModel( 10 | val V : Int = 0, 11 | val K: Int = 20, 12 | val alpha : Double = 0.01, 13 | val beta : Double = 0.01 14 | ) extends Model { 15 | 16 | val alpha_sum = alpha * K 17 | val beta_sum = beta * V 18 | 19 | registerMatrix("doc-topics", new IntArrayWithIntKey(K)) 20 | registerMatrix("word-topics", new IntMatrixWithIntKey(V, K)) 21 | 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/clustering/LDAParams.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.clustering 2 | 3 | /** 4 | * Created by yunlong on 12/23/15. 5 | */ 6 | case class LDAParams( 7 | var psCount : Int = 2, 8 | var batchSize : Int = 100, 9 | var input: String = null, 10 | var k: Int = 20, 11 | var alpha : Double = 0.01, 12 | var beta : Double = 0.01, 13 | var maxIterations: Int = 10, 14 | val partitions : Int = 2, 15 | var showPlexity: Boolean = true) 16 | { 17 | 18 | var V : Int = 0 19 | var alpha_sum : Double = 0.0 20 | var beta_sum : Double = 0.0 21 | 22 | def init(V : Int): Unit = { 23 | this.V = V 24 | alpha_sum = alpha * k 25 | beta_sum = beta * V 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/example/clustering/LDAExample.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.example.clustering 2 | 3 | import com.intel.distml.Dict 4 | import com.intel.distml.api.Model 5 | import com.intel.distml.clustering.{LightLDA, LDAParams} 6 | import com.intel.distml.platform.DistML 7 | import com.intel.distml.util.{IntMatrixWithIntKey, IntArrayWithIntKey} 8 | import org.apache.spark.broadcast.Broadcast 9 | import org.apache.spark.storage.StorageLevel 10 | import org.apache.spark.{SparkContext, SparkConf} 11 | import scopt.OptionParser 12 | 13 | import scala.collection.mutable.ListBuffer 14 | 15 | /** 16 | * Created by yunlong on 16-3-29. 17 | */ 18 | object LDAExample { 19 | 20 | def normalizeString(src : String) : String = { 21 | src.replaceAll("[^A-Z^a-z]", " ").trim().toLowerCase(); 22 | } 23 | 24 | def fromWordsToIds(bdic : Broadcast[Dict])(line : String) : Array[Int] = { 25 | 26 | val dic = bdic.value 27 | 28 | val words = line.split(" ") 29 | 30 | var wordIDs = new ListBuffer[Int](); 31 | 32 | for (w <- words) { 33 | val wn = normalizeString(w) 34 | if (dic.contains(wn)) { 35 | wordIDs.append(dic.getID(wn)) 36 | } 37 | } 38 | 39 | wordIDs.toArray 40 | } 41 | 42 | def main(args: Array[String]) { 43 | 44 | val defaultParams = new LDAParams() 45 | 46 | val parser = new OptionParser[LDAParams]("LDAExample") { 47 | head("LDAExample: an example LDA app for plain text data.") 48 | opt[Int]("k") 49 | .text(s"number of topics. default: ${defaultParams.k}") 50 | .action((x, c) => c.copy(k = x)) 51 | opt[Int]("batchSize") 52 | .text(s"number of samples used in one computing. default: ${defaultParams.batchSize}") 53 | .action((x, c) => c.copy(batchSize = x)) 54 | opt[Int]("psCount") 55 | .text(s"number of parameter servers. default: ${defaultParams.psCount}") 56 | .action((x, c) => c.copy(psCount = x)) 57 | opt[Double]("alpha") 58 | .text(s"super parameter for sampling. default: ${defaultParams.alpha}") 59 | .action((x, c) => c.copy(alpha = x)) 60 | opt[Double]("beta") 61 | .text(s"super parameter for sampling. 
default: ${defaultParams.beta}") 62 | .action((x, c) => c.copy(beta = x)) 63 | opt[Int]("maxIterations") 64 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") 65 | .action((x, c) => c.copy(maxIterations = x)) 66 | opt[Int]("partitions") 67 | .text(s"number of partitions to train the model. default: ${defaultParams.partitions}") 68 | .action((x, c) => c.copy(partitions = x)) 69 | opt[Boolean]("showPlexity") 70 | .text(s"Show plexity after each iteration." + 71 | s" default: ${defaultParams.showPlexity}") 72 | .action((x, c) => c.copy(showPlexity = x)) 73 | arg[String]("...") 74 | .text("input paths (directories) to plain text corpora.") 75 | .unbounded() 76 | .required() 77 | .action((x, c) => c.copy(input = x)) 78 | } 79 | parser.parse(args, defaultParams).map { params => 80 | run(params) 81 | }.getOrElse { 82 | parser.showUsageAsError 83 | sys.exit(1) 84 | } 85 | } 86 | 87 | private def run(p: LDAParams) { 88 | val conf = new SparkConf().setAppName(s"DistML.Example.LDA") 89 | conf.set("spark.driver.maxResultSize", "5g") 90 | 91 | val sc = new SparkContext(conf) 92 | Thread.sleep(3000) 93 | 94 | var rawLines = sc.textFile(p.input).filter(s => s.trim().length > 0) 95 | 96 | var dic = new Dict() 97 | 98 | val words = rawLines.flatMap(line => line.split(" ")).map(normalizeString).filter(s => s.trim().length > 0).distinct().collect() 99 | words.foreach(x => dic.addWord(x)) 100 | 101 | p.init(dic.getSize) 102 | 103 | val bdic = sc.broadcast(dic) 104 | var data = rawLines.map(fromWordsToIds(bdic)).map(ids => { 105 | // println("=============== random initializing start ==================") 106 | val topics = new ListBuffer[(Int, Int)] 107 | ids.foreach(x => topics.append((x, Math.floor(Math.random() * p.k).toInt))) 108 | 109 | val doctopic = new Array[Int](p.k) 110 | topics.foreach(x => doctopic(x._2) = doctopic(x._2) + 1) 111 | 112 | // println("=============== random initializing done ==================") 113 | (doctopic, topics.toArray) 114 | }).repartition(p.partitions).persist(StorageLevel.MEMORY_AND_DISK) 115 | 116 | var statistics = data.map(d => (1, d._2.length)).reduce((a, b) => (a._1 + b._1, a._2 + b._2)) 117 | 118 | println("=============== Corpus Info Begin ================") 119 | println("Vocaulary: " + dic.getSize) 120 | println("Docs: " + statistics._1) 121 | println("Tokens: " + statistics._2) 122 | println("Topics: " + p.k) 123 | println("=============== Corpus Info End ================") 124 | 125 | 126 | val dm = LightLDA.train(sc, data, dic.getSize, p) 127 | 128 | dm.recycle() 129 | sc.stop 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/example/feature/MllibWord2Vec.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.example.feature 2 | 3 | import org.apache.spark._ 4 | import org.apache.spark.mllib.feature.Word2Vec 5 | 6 | /** 7 | * Created by yunlong on 1/29/16. 
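 *
 * Baseline example that trains Spark MLlib's Word2Vec on a plain text file and
 * prints the 40 nearest synonyms of the word "black"; useful for comparing
 * against the DistML Word2Vec implementation. Expected arguments:
 * <input path> <number of partitions>.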
8 | */ 9 | object MllibWord2Vec { 10 | 11 | def main(args: Array[String]) { 12 | 13 | val p = Integer.parseInt(args(1)) 14 | println("partitions: " + p) 15 | 16 | val conf = new SparkConf().setAppName("Word2Vec") 17 | val sc = new SparkContext(conf) 18 | 19 | val input = sc.textFile(args(0)).map(line => line.split(" ").toSeq) 20 | 21 | val word2vec = new Word2Vec() 22 | word2vec.setNumPartitions(p) 23 | 24 | val model = word2vec.fit(input) 25 | println("vocab size: " + model.getVectors.size) 26 | val synonyms = model.findSynonyms("black", 40) 27 | 28 | for ((synonym, cosineSimilarity) <- synonyms) { 29 | println(s"$synonym $cosineSimilarity") 30 | } 31 | 32 | sc.stop() 33 | 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/example/feature/Word2VecExample.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.example.feature 2 | 3 | import com.intel.distml.Dict 4 | import com.intel.distml.api.{Session, Model} 5 | import com.intel.distml.feature.Word2Vec 6 | import org.apache.spark.storage.StorageLevel 7 | import org.apache.spark.{SparkContext, SparkConf} 8 | import org.apache.spark.broadcast.Broadcast 9 | import scopt.OptionParser 10 | 11 | import scala.collection.mutable 12 | import scala.collection.mutable.ListBuffer 13 | 14 | /** 15 | * Created by yunlong on 16-3-23. 16 | */ 17 | object Word2VecExample { 18 | private case class Params( 19 | psCount: Int = 1, 20 | numPartitions : Int = 1, 21 | input: String = null, 22 | cbow: Boolean = true, 23 | alpha: Double = 0.0f, 24 | alphaFactor: Double = 10.0f, 25 | window : Int = 7, 26 | batchSize : Int = 100, 27 | vectorSize : Int = 10, 28 | minFreq : Int = 50, 29 | maxIterations: Int = 100 ) { 30 | 31 | def show(): Unit = { 32 | println("=========== params =============") 33 | println("psCount: " + psCount) 34 | println("numPartitions: " + psCount) 35 | println("input: " + input) 36 | println("cbow: " + cbow) 37 | println("alpha: " + alpha) 38 | println("alphaFactor: " + alphaFactor) 39 | println("window: " + window) 40 | println("batchSize: " + batchSize) 41 | println("vectorSize: " + vectorSize) 42 | println("minFreq: " + minFreq) 43 | println("maxIterations: " + maxIterations) 44 | println("=========== params =============") 45 | } 46 | } 47 | 48 | def main(args: Array[String]) { 49 | 50 | val defaultParams = Params() 51 | 52 | val parser = new OptionParser[Params]("LDAExample") { 53 | head("LDAExample: an example LDA app for plain text data.") 54 | opt[Int]("psCount") 55 | .text(s"number of parameter servers. default: ${defaultParams.psCount}") 56 | .action((x, c) => c.copy(psCount = x)) 57 | opt[Int]("numPartitions") 58 | .text(s"number of partitions for the training data. default: ${defaultParams.numPartitions}") 59 | .action((x, c) => c.copy(numPartitions = x)) 60 | opt[Double]("alpha") 61 | .text(s"initiali learning rate. default: ${defaultParams.alpha}") 62 | .action((x, c) => c.copy(alpha = x)) 63 | opt[Double]("alphaFactor") 64 | .text(s"factor to decrease adaptive learning rate. default: ${defaultParams.alphaFactor}") 65 | .action((x, c) => c.copy(alphaFactor = x)) 66 | opt[Int]("batchSize") 67 | .text(s"number of samples computed in a round. 
default: ${defaultParams.batchSize}") 68 | .action((x, c) => c.copy(batchSize = x)) 69 | opt[Int]("vectorSize") 70 | .text(s"vector size for a single word: ${defaultParams.vectorSize}") 71 | .action((x, c) => c.copy(vectorSize = x)) 72 | opt[Int]("maxIterations") 73 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") 74 | .action((x, c) => c.copy(maxIterations = x)) 75 | opt[Int]("minFreq") 76 | .text(s"minimum word frequency. default: ${defaultParams.minFreq}") 77 | .action((x, c) => c.copy(minFreq = x)) 78 | opt[Boolean]("cbow") 79 | .text(s"true if use cbow, false if use skipgram. default: ${defaultParams.cbow}") 80 | .action((x, c) => c.copy(cbow = x)) 81 | arg[String]("...") 82 | .text("input paths (directories) to plain text corpora." + 83 | " Each text file line should hold 1 document.") 84 | .unbounded() 85 | .required() 86 | .action((x, c) => c.copy(input = x)) 87 | } 88 | parser.parse(args, defaultParams).map { params => 89 | run(params) 90 | }.getOrElse { 91 | parser.showUsageAsError 92 | sys.exit(1) 93 | } 94 | } 95 | 96 | def normalizeString(src : String) : String = { 97 | //src.filter( c => ((c >= '0') && (c <= '9')) ) 98 | src.filter( c => (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) || (c == ' '))).toLowerCase 99 | } 100 | 101 | def toWords(line : String) : Array[String] = { 102 | val words = line.split(" ").map(normalizeString) 103 | val list = new mutable.MutableList[String]() 104 | for (w <- words) { 105 | if (w.length > 0) 106 | list += w 107 | } 108 | 109 | list.toArray 110 | } 111 | 112 | def fromWordsToIds(bdic : Broadcast[Dict])(words : Array[String]) : Array[Int] = { 113 | 114 | val dic = bdic.value 115 | 116 | var wordIDs = new ListBuffer[Int]() 117 | 118 | for (w <- words) { 119 | val wn = normalizeString(w) 120 | if (dic.contains(wn)) { 121 | wordIDs.append(dic.getID(wn)) 122 | } 123 | } 124 | 125 | wordIDs.toArray 126 | } 127 | 128 | def run(p : Params): Unit = { 129 | 130 | p.show 131 | 132 | val conf = new SparkConf().setAppName("Word2Vec") 133 | val sc = new SparkContext(conf) 134 | 135 | val rawLines = sc.textFile(p.input) 136 | val lines = rawLines.filter(s => s.length > 0).map(toWords).persist(StorageLevel.MEMORY_AND_DISK) 137 | 138 | val words = lines.flatMap(line => line.iterator).map((_, 1L)) 139 | 140 | val countedWords = words.reduceByKey(_ + _).filter(f => f._2 > p.minFreq).sortBy(f => f._2).collect 141 | println("========== countedWords=" + countedWords.length + " ==================") 142 | 143 | var wordMap = new Dict 144 | 145 | var totalWords = 0L 146 | for (i <- 0 to countedWords.length - 1) { 147 | var item = countedWords(i) 148 | wordMap.put(item._1, i) 149 | totalWords += item._2 150 | } 151 | var wordTree = Word2Vec.createBinaryTree(countedWords) 152 | wordTree.tokens = totalWords 153 | 154 | var initialAlpha = 0.0f 155 | if (p.alpha < 10e-6) { 156 | if (p.cbow) { 157 | initialAlpha = 0.05f 158 | } 159 | else { 160 | initialAlpha = 0.0025f 161 | } 162 | } 163 | else { 164 | initialAlpha = p.alpha.toFloat 165 | } 166 | 167 | println("=============== Corpus Info Begin ================") 168 | println("Vocaulary: " + wordTree.vocabSize) 169 | println("Tokens: " + totalWords) 170 | println("Vector size: " + p.vectorSize) 171 | println("ps count: " + p.psCount) 172 | println("=============== Corpus Info End ================") 173 | 174 | val bdic = sc.broadcast(wordMap) 175 | //var data = lines.map(fromWordsToIds(bdic)).repartition(1).persist(StorageLevel.MEMORY_AND_DISK) 176 | var data = 
lines.map(fromWordsToIds(bdic)).repartition(p.numPartitions).persist(StorageLevel.MEMORY_AND_DISK) 177 | 178 | var alpha = initialAlpha 179 | val dm = Word2Vec.train(sc, p.psCount, data, p.vectorSize, wordTree, initialAlpha, p.alphaFactor, p.cbow, p.maxIterations, p.window, p.batchSize) 180 | 181 | dm.recycle() 182 | val result = Word2Vec.collect(dm) 183 | 184 | val vectors = new Array[Array[Float]](result.size) 185 | for (w <- result) { 186 | vectors(w._1) = w._2 187 | } 188 | 189 | val blackIndex = wordMap.getID("black") 190 | val sims = Word2Vec.findSynonyms(vectors, vectors(blackIndex), 10) 191 | for (s <- sims) { 192 | println(wordMap.getWord(s._1) + ", " + s._2) 193 | } 194 | 195 | 196 | sc.stop() 197 | 198 | System.out.println("===== Finished ====") 199 | } 200 | 201 | } 202 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/example/regression/LargeLRTest.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.example.regression 2 | 3 | import com.intel.distml.platform.DistML 4 | import com.intel.distml.regression.{LogisticRegression => LR} 5 | import com.intel.distml.util.DataStore 6 | import org.apache.spark.rdd.RDD 7 | import org.apache.spark.{SparkConf, SparkContext} 8 | import scopt.OptionParser 9 | 10 | import scala.collection.mutable.HashMap 11 | 12 | /** 13 | * Created by yunlong on 3/11/16. 14 | */ 15 | object LargeLRTest { 16 | 17 | private case class Params( 18 | runType: String = "train", 19 | psCount: Int = 1, 20 | trainType : String = "ssp", 21 | maxIterations: Int = 100, 22 | batchSize: Int = 100, // for asgd only 23 | maxLag : Int = 2, // for ssp only 24 | dim: Int = 10000000, 25 | eta: Double = 0.0001, 26 | partitions : Int = 1, 27 | input: String = null, 28 | modelPath : String = null 29 | ) 30 | 31 | def parseBlanc(line: String): (HashMap[Int, Double], Int) = { 32 | val s = line.split(" ") 33 | 34 | var a = java.lang.Float.parseFloat("1.0") 35 | var b = a.toInt 36 | var c = java.lang.Float.parseFloat(s(0)) 37 | var label = java.lang.Float.parseFloat(s(0)).toInt 38 | 39 | val x = new HashMap[Int, Double](); 40 | for (i <- 1 to s.length - 1) { 41 | val f = s(i).split(":") 42 | val v = java.lang.Double.parseDouble(f(1)) 43 | x.put(Integer.parseInt(f(0)), v) 44 | } 45 | 46 | (x, label) 47 | } 48 | 49 | def main(args: Array[String]) { 50 | 51 | val defaultParams = Params() 52 | 53 | val parser = new OptionParser[Params]("LargeLRExample") { 54 | head("LargeLRExample: an example for logistic regression with big dimension.") 55 | opt[String]("runType") 56 | .text(s"train for training model, test for testing existed model. default: ${defaultParams.runType}") 57 | .action((x, c) => c.copy(runType = x)) 58 | opt[Int]("psCount") 59 | .text(s"number of parameter servers. default: ${defaultParams.psCount}") 60 | .action((x, c) => c.copy(psCount = x)) 61 | opt[String]("trainType") 62 | .text(s"how to train your model, asgd or ssg. default: ${defaultParams.trainType}") 63 | .action((x, c) => c.copy(trainType = x)) 64 | opt[Int]("maxIterations") 65 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") 66 | .action((x, c) => c.copy(maxIterations = x)) 67 | opt[Int]("batchSize") 68 | .text(s"number of samples computed in a round. default: ${defaultParams.batchSize}") 69 | .action((x, c) => c.copy(batchSize = x)) 70 | opt[Int]("maxLag") 71 | .text(s"maximum number of iterations between the fast worker and the slowest worker. 
default: ${defaultParams.maxLag}") 72 | .action((x, c) => c.copy(maxLag = x)) 73 | opt[Int]("dim") 74 | .text(s"dimension of features. default: ${defaultParams.dim}") 75 | .action((x, c) => c.copy(dim = x)) 76 | opt[Double]("eta") 77 | .text(s"learning rate. default: ${defaultParams.eta}") 78 | .action((x, c) => c.copy(eta = x)) 79 | opt[Int]("partitions") 80 | .text(s"number of partitions for training data. default: ${defaultParams.partitions}") 81 | .action((x, c) => c.copy(partitions = x)) 82 | arg[String]("...") 83 | .text("path to train the model") 84 | .required() 85 | .action((x, c) => c.copy(input = x)) 86 | arg[String]("...") 87 | .text("path to save the model.") 88 | .required() 89 | .action((x, c) => c.copy(modelPath = x)) 90 | } 91 | parser.parse(args, defaultParams).map { params => 92 | run(params) 93 | }.getOrElse { 94 | parser.showUsageAsError 95 | sys.exit(1) 96 | } 97 | } 98 | 99 | def run(p: Params): Unit = { 100 | 101 | println("batchSize: " + p.batchSize) 102 | println("input: " + p.input) 103 | println("maxIterations: " + p.maxIterations) 104 | 105 | val conf = new SparkConf().setAppName("SparseLR") 106 | val sc = new SparkContext(conf) 107 | 108 | val samples = sc.textFile(p.input).map(parseBlanc) 109 | 110 | // val ratio = new Array[Double](2) 111 | // ratio(0) = 0.9 112 | // ratio(1) = 0.1 113 | // val t = samples.randomSplit(ratio) 114 | // val trainSet = t(0).repartition(p.partitions) 115 | // val testSet = t(1) 116 | 117 | if (p.runType.equals("train")) { 118 | train(sc, samples, p) 119 | } 120 | else { 121 | var auc = verify(sc, samples, p.modelPath) 122 | println("auc: " + auc) 123 | } 124 | 125 | // trainAgain(sc, trainSet, p.maxIterations, p.batchSize, p.modelPath) 126 | // auc = verify(sc, testSet, p.modelPath) 127 | // println("auc: " + auc) 128 | 129 | sc.stop() 130 | } 131 | 132 | def train(sc : SparkContext, samples : RDD[(HashMap[Int, Double], Int)], p : Params): Unit = { 133 | var dm : DistML[Iterator[(Int, String, DataStore)]] = null 134 | if (p.trainType.equals("bsp")) { 135 | dm = LR.trainSSP(sc, samples, p.psCount, p.dim, p.eta, p.maxIterations, 0) 136 | } 137 | else if (p.trainType.equals("ssp")) { 138 | dm = LR.trainSSP(sc, samples, p.psCount, p.dim, p.eta, p.maxIterations, p.maxLag) 139 | } 140 | else if (p.trainType.equals("asgd")) { 141 | dm = LR.trainASGD(sc, samples, p.psCount, p.dim, p.eta, p.maxIterations, p.batchSize) 142 | } 143 | LR.save(dm, p.modelPath, "") 144 | dm.recycle() 145 | } 146 | 147 | def verify(sc : SparkContext, samples : RDD[(HashMap[Int, Double], Int)], modelPath : String): Double = { 148 | 149 | val dm = LR.load(sc, modelPath) 150 | 151 | val auc = LR.auc(samples, dm) 152 | 153 | dm.recycle() 154 | 155 | auc 156 | } 157 | 158 | def trainAgain(sc : SparkContext, samples : RDD[(HashMap[Int, Double], Int)], eta : Double, maxIterations : Int, batchSize : Int, modelPath : String): Unit = { 159 | val dm = LR.load(sc, modelPath) 160 | LR.trainASGD(samples, dm, eta, maxIterations, batchSize) 161 | LR.save(dm, modelPath, "") 162 | dm.recycle() 163 | } 164 | 165 | } 166 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/example/regression/MelBlanc.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.example.regression 2 | 3 | import java.util 4 | 5 | import com.intel.distml.api.Model 6 | import com.intel.distml.platform.DistML 7 | import com.intel.distml.util.DataStore 8 | import 
com.intel.distml.util.scala.DoubleArrayWithIntKey 9 | import com.intel.distml.util.store.DoubleArrayStore 10 | import org.apache.spark.rdd.RDD 11 | import org.apache.spark.{SparkContext, SparkConf} 12 | import scopt.OptionParser 13 | 14 | import scala.collection.mutable.HashMap 15 | 16 | import com.intel.distml.regression.{LogisticRegression => LR} 17 | 18 | /** 19 | * Created by yunlong on 3/11/16. 20 | */ 21 | object MelBlanc { 22 | 23 | private case class Params( 24 | runType: String = "train", 25 | psCount: Int = 1, 26 | psBackup : Boolean = false, 27 | trainType : String = "ssp", 28 | maxIterations: Int = 100, 29 | batchSize: Int = 100, // for asgd only 30 | maxLag : Int = 2, // for ssp only 31 | dim: Int = 10000000, 32 | eta: Double = 0.0001, 33 | partitions : Int = 1, 34 | input: String = null, 35 | modelPath : String = null 36 | ) 37 | 38 | def parseBlanc(line: String): (HashMap[Int, Double], Int) = { 39 | val s = line.split(" ") 40 | 41 | var label = Integer.parseInt(s(0)) 42 | 43 | val x = new HashMap[Int, Double](); 44 | for (i <- 1 to s.length - 1) { 45 | val f = s(i).split(":") 46 | val v = java.lang.Double.parseDouble(f(1)) 47 | x.put(Integer.parseInt(f(0)), v) 48 | } 49 | 50 | (x, label) 51 | } 52 | 53 | def main(args: Array[String]) { 54 | 55 | val defaultParams = Params() 56 | 57 | val parser = new OptionParser[Params]("MelBlanc") { 58 | head("MelBlanc: an example of small logistic regression.") 59 | opt[String]("runType") 60 | .text(s"train for training model, test for testing existed model. default: ${defaultParams.runType}") 61 | .action((x, c) => c.copy(runType = x)) 62 | opt[Int]("psCount") 63 | .text(s"number of parameter servers. default: ${defaultParams.psCount}") 64 | .action((x, c) => c.copy(psCount = x)) 65 | opt[Boolean]("psBackup") 66 | .text(s"whether to run with parameter server fault tolerance. default: ${defaultParams.psBackup}") 67 | .action((x, c) => c.copy(psBackup = x)) 68 | opt[String]("trainType") 69 | .text(s"how to train your model, asgd or ssg. default: ${defaultParams.trainType}") 70 | .action((x, c) => c.copy(trainType = x)) 71 | opt[Int]("maxIterations") 72 | .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") 73 | .action((x, c) => c.copy(maxIterations = x)) 74 | opt[Int]("batchSize") 75 | .text(s"number of samples computed in a round. default: ${defaultParams.batchSize}") 76 | .action((x, c) => c.copy(batchSize = x)) 77 | opt[Int]("maxLag") 78 | .text(s"maximum number of iterations between the fast worker and the slowest worker. default: ${defaultParams.maxLag}") 79 | .action((x, c) => c.copy(maxLag = x)) 80 | opt[Int]("dim") 81 | .text(s"dimension of features. default: ${defaultParams.dim}") 82 | .action((x, c) => c.copy(dim = x)) 83 | opt[Double]("eta") 84 | .text(s"learning rate. default: ${defaultParams.eta}") 85 | .action((x, c) => c.copy(eta = x)) 86 | opt[Int]("partitions") 87 | .text(s"number of partitions for training data. 
default: ${defaultParams.partitions}") 88 | .action((x, c) => c.copy(partitions = x)) 89 | arg[String]("...") 90 | .text("path to train the model") 91 | .required() 92 | .action((x, c) => c.copy(input = x)) 93 | arg[String]("...") 94 | .text("path to save the model.") 95 | .required() 96 | .action((x, c) => c.copy(modelPath = x)) 97 | } 98 | parser.parse(args, defaultParams).map { params => 99 | run(params) 100 | }.getOrElse { 101 | parser.showUsageAsError 102 | sys.exit(1) 103 | } 104 | } 105 | 106 | def run(p: Params): Unit = { 107 | 108 | println("batchSize: " + p.batchSize) 109 | println("input: " + p.input) 110 | println("maxIterations: " + p.maxIterations) 111 | 112 | val conf = new SparkConf().setAppName("SparseLR") 113 | val sc = new SparkContext(conf) 114 | 115 | val samples = sc.textFile(p.input).map(parseBlanc) 116 | 117 | // val ratio = new Array[Double](2) 118 | // ratio(0) = 0.9 119 | // ratio(1) = 0.1 120 | // val t = samples.randomSplit(ratio) 121 | // val trainSet = t(0).repartition(p.partitions) 122 | // val testSet = t(1) 123 | 124 | if (p.runType.equals("train")) { 125 | train(sc, samples, p) 126 | } 127 | else { 128 | var auc = verify(sc, samples, p.modelPath) 129 | println("auc: " + auc) 130 | } 131 | 132 | // trainAgain(sc, trainSet, p.maxIterations, p.batchSize, p.modelPath) 133 | // auc = verify(sc, testSet, p.modelPath) 134 | // println("auc: " + auc) 135 | 136 | sc.stop() 137 | } 138 | 139 | def train(sc : SparkContext, samples : RDD[(HashMap[Int, Double], Int)], p : Params): Unit = { 140 | var dm : DistML[Iterator[(Int, String, DataStore)]] = null 141 | if (p.trainType.equals("bsp")) { 142 | dm = LR.trainSSP(sc, samples, p.psCount, p.dim, p.eta, p.maxIterations, 0) 143 | } 144 | else if (p.trainType.equals("ssp")) { 145 | dm = LR.trainSSP(sc, samples, p.psCount, p.dim, p.eta, p.maxIterations, p.maxLag) 146 | } 147 | else if (p.trainType.equals("asgd")) { 148 | dm = LR.trainASGD(sc, samples, p.psCount, p.psBackup, p.dim, p.eta, p.maxIterations, p.batchSize) 149 | } 150 | LR.save(dm, p.modelPath, "") 151 | dm.recycle() 152 | } 153 | 154 | def verify(sc : SparkContext, samples : RDD[(HashMap[Int, Double], Int)], modelPath : String): Double = { 155 | 156 | val dm = LR.load(sc, modelPath) 157 | 158 | val auc = LR.auc(samples, dm) 159 | 160 | dm.recycle() 161 | 162 | auc 163 | } 164 | 165 | def trainAgain(sc : SparkContext, samples : RDD[(HashMap[Int, Double], Int)], eta : Double, maxIterations : Int, batchSize : Int, modelPath : String): Unit = { 166 | val dm = LR.load(sc, modelPath) 167 | LR.trainASGD(samples, dm, eta, maxIterations, batchSize) 168 | LR.save(dm, modelPath, "") 169 | dm.recycle() 170 | } 171 | 172 | } 173 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/example/regression/Mnist.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.example.regression 2 | 3 | import java.util 4 | 5 | import com.intel.distml.api.{Model, Session} 6 | import com.intel.distml.platform.DistML 7 | import com.intel.distml.regression.MLR 8 | import com.intel.distml.util.{DoubleMatrix, KeyCollection, KeyList} 9 | import org.apache.spark.{SparkConf, SparkContext} 10 | import scopt.OptionParser 11 | 12 | import scala.collection.mutable 13 | import scala.collection.JavaConversions._ 14 | 15 | /** 16 | * Created by jimmy on 15-12-14. 
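 *
 * MLR (softmax regression) example on MNIST-style data: each input line holds
 * the feature values followed by the class label, which is converted to a
 * one-hot label vector of length `outputDim` before training with `MLR.train`
 * and validating with `MLR.validate`.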
17 | */ 18 | object Mnist { 19 | val BATCH_SIZE = 100 20 | private case class Params( 21 | input: String = null, 22 | inputDim: Long = 0, 23 | outputDim: Int = 0, 24 | partitions: Int = 1, 25 | psCount: Int = 1, 26 | batchSize: Int = 100, 27 | maxIterations: Int = 500 28 | ) 29 | 30 | def main(args: Array[String]): Unit = { 31 | val mlrParams = Params() 32 | val parser = new OptionParser[Params]("MLRExample") { 33 | head("MLRExample: an example MLR(softmax) app for plain text data.") 34 | opt[Int]("inputDim") 35 | .text(s"dimensions of features. default: ${mlrParams.inputDim}") 36 | .action((x, c) => c.copy(inputDim = x)) 37 | opt[Int]("outputDim") 38 | .text(s"dimensions of classification. default: ${mlrParams.outputDim}") 39 | .action((x, c) => c.copy(outputDim = x)) 40 | opt[Int]("batchSize") 41 | .text(s"number of samples computed in a round. default: ${mlrParams.batchSize}") 42 | .action((x, c) => c.copy(batchSize = x)) 43 | opt[Int]("psCount") 44 | .text(s"number of parameter servers. default: ${mlrParams.psCount}") 45 | .action((x, c) => c.copy(psCount = x)) 46 | opt[Int]("partitions") 47 | .text(s"number of partitions for training data. default: ${mlrParams.partitions}") 48 | .action((x, c) => c.copy(partitions = x)) 49 | opt[Int]("maxIterations") 50 | .text(s"number of iterations of training. default: ${mlrParams.maxIterations}") 51 | .action((x, c) => c.copy(maxIterations = x)) 52 | arg[String]("...") 53 | .text(s"input paths (directories) to plain text corpora." + 54 | s"Each text file line is one sample. default: ${mlrParams.input}") 55 | .unbounded() 56 | .required() 57 | .action((x, c) => c.copy(input = x)) 58 | 59 | } 60 | 61 | parser.parse(args, mlrParams).map { 62 | params => run(params) 63 | }.getOrElse { 64 | parser.showUsageAsError 65 | sys.exit(1) 66 | } 67 | } 68 | 69 | def parseLine(line: String, dim: Long): Unit = { 70 | val items = line.split(" ") 71 | val labels = new Array[Double](dim.toInt) 72 | val len = items.length 73 | //val data = new Array[Double](len-1) 74 | val data = new util.HashMap[Long, Double]() 75 | for (i <- 0 to len-2) { 76 | val tmp = items(i).toDouble 77 | if (tmp != 0.0) { 78 | data(i) = tmp 79 | } 80 | } 81 | labels(items(len-1).toInt) = 1.0 82 | (data, labels) 83 | } 84 | 85 | 86 | 87 | def dumpweights(weights: util.HashMap[java.lang.Long, Array[java.lang.Double]], keyList: KeyList): Unit = { 88 | for (i<-0 until 2) { 89 | println("for line " + i) 90 | for (key <- keyList.iterator()) { 91 | print(" " + weights.get(key)(i)) 92 | } 93 | } 94 | } 95 | 96 | def dumpweights(weights: util.HashMap[java.lang.Long, Array[java.lang.Double]], keySet: util.Set[Long]): Unit = { 97 | for (i<-0 until 2) { 98 | println("for line " + i) 99 | for (key <- keySet.iterator()) { 100 | print(" " + weights.get(key)(i)) 101 | } 102 | } 103 | } 104 | def run(p: Params): Unit = { 105 | println("batchSize: " + p.batchSize) 106 | println("input: " + p.input) 107 | println("maxIterations: " + p.maxIterations) 108 | 109 | val conf = new SparkConf().setAppName("SparkMLR") 110 | 111 | val sc = new SparkContext(conf) 112 | val dim = p.outputDim 113 | val samples = sc.textFile(p.input).map(line => { 114 | val items = line.split(" ") 115 | val labels = new Array[Double](p.outputDim.toInt) 116 | val len = items.length 117 | //val data = new Array[Double](len-1) 118 | val data = new mutable.HashMap[Long, Double]() 119 | for (i <- 0 to len-2) { 120 | val tmp = items(i).toDouble 121 | if (tmp != 0.0) { 122 | data(i) = tmp 123 | } 124 | } 125 | labels(items(len-1).toInt) = 1.0 126 
| (data, labels) 127 | }).repartition(p.partitions) 128 | 129 | val dm = MLR.train(sc, samples, p.psCount, p.inputDim, p.outputDim, p.maxIterations, p.batchSize) 130 | 131 | val correct = MLR.validate(samples, dm) 132 | println("Total Correct " + correct) 133 | 134 | dm.recycle() 135 | sc.stop() 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/platform/Clock.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform 2 | 3 | /** 4 | * Created by yunlong on 2/3/16. 5 | */ 6 | class Clock(name : String) { 7 | 8 | var startTime = 0L 9 | var total = 0L 10 | 11 | def reset(): Unit = { 12 | total = 0L 13 | startTime = 0L 14 | } 15 | 16 | def start(): Unit = { 17 | total = 0L 18 | startTime = System.currentTimeMillis() 19 | } 20 | 21 | def stop(): Unit = { 22 | total += System.currentTimeMillis() - startTime 23 | println("[" + name + "] " + total + " ms") 24 | reset 25 | } 26 | 27 | def pause(): Unit = { 28 | total += System.currentTimeMillis() - startTime 29 | } 30 | 31 | def resume(): Unit = { 32 | startTime = System.currentTimeMillis() 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/platform/ParamServerDriver.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform 2 | 3 | import akka.actor.ActorSystem 4 | import com.intel.distml.api.Model 5 | import com.intel.distml.util.DataStore 6 | import com.typesafe.config.ConfigFactory 7 | import org.apache.spark.broadcast.Broadcast 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.storage.StorageLevel 10 | import org.apache.spark.{SparkEnv, SparkConf, SparkContext} 11 | 12 | import scala.collection.mutable 13 | import scala.reflect.ClassTag 14 | 15 | 16 | /** 17 | * Created by yunlong on 6/3/15. 
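 *
 * Runs the parameter servers as a separate Spark job in its own thread: each
 * task starts an actor system named "parameter-server-system" + index, creates
 * the data stores for the broadcast model, hosts a PSActor until the actor
 * system terminates, and finally applies the user function `f` to produce the
 * result RDD.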
18 | */ 19 | class ParamServerDriver[T : ClassTag] (@transient spark : SparkContext, 20 | modelBroadcast : Broadcast[Model], 21 | actorSystemConfig : String, 22 | monitor : String, 23 | psCount : Int, 24 | backup : Boolean, 25 | f : Function3[Model, Int, java.util.HashMap[String, DataStore], T]) extends Thread with Serializable { 26 | 27 | var finalResult : RDD[T] = null 28 | 29 | def paramServerTask(modelBroadcast : Broadcast[Model], prefix : String, f : Function3[Model, Int, java.util.HashMap[String, DataStore], T])(index: Int) : T = { 30 | 31 | val eId = SparkEnv.get.executorId 32 | println("starting server task: " + index + ", " + eId) 33 | 34 | val PARAMETER_SERVER_ACTOR_SYSTEM_NAME = "parameter-server-system" 35 | val PARAMETER_SERVER_ACTOR_NAME = "parameter-server" 36 | 37 | // Start actor system 38 | val parameterServerRemoteConfig = ConfigFactory.parseString(actorSystemConfig) 39 | val parameterServerActorSystem = ActorSystem(PARAMETER_SERVER_ACTOR_SYSTEM_NAME + index, 40 | ConfigFactory.load(parameterServerRemoteConfig)) 41 | 42 | // Start parameter server 43 | val model = modelBroadcast.value 44 | val stores = DataStore.createStores(model, index) 45 | 46 | val parameterServer = parameterServerActorSystem.actorOf(PSActor.props(model, stores, monitor, index, eId, prefix), 47 | PARAMETER_SERVER_ACTOR_NAME) 48 | 49 | parameterServerActorSystem.awaitTermination() 50 | 51 | println("stopping server task") 52 | 53 | f(model, index, stores) 54 | } 55 | 56 | override def run() { 57 | var prefix = System.getenv("PS_NETWORK_PREFIX") 58 | 59 | var da : Array[Int] = null 60 | if (backup) { 61 | println("start parameter servers with backup") 62 | da = new Array[Int](psCount * 2) 63 | for (i <- 0 to psCount*2 - 1) 64 | da(i) = i % psCount 65 | } 66 | else { 67 | println("start parameter servers") 68 | da = new Array[Int](psCount) 69 | for (i <- 0 to psCount - 1) 70 | da(i) = i 71 | } 72 | 73 | val data = spark.parallelize(da, psCount*2) 74 | println("prepare to start parameter servers: " + data.partitions.length) 75 | finalResult = data.map(paramServerTask(modelBroadcast, prefix, f)).cache 76 | finalResult.count() 77 | println("parameter servers finish their work.") 78 | } 79 | 80 | } 81 | 82 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/platform/PipeLine.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.platform 2 | 3 | /** 4 | * Created by yunlong on 2/4/16. 
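 *
 * A simple two-resource pipeline: a linked list of `Task`s is executed
 * repeatedly, each task first acquiring either the CPU or the network slot from
 * the shared `Scheduler` (polling every second), running, then releasing the
 * slot. The loop stops once a task's `run()` returns true.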
5 | */ 6 | class PipeLine(tasks : Task, scheduler : Scheduler) extends Thread { 7 | 8 | override def run(): Unit = { 9 | var finished = false 10 | while(!finished) { 11 | var t = tasks 12 | 13 | println("run piple line now: " + t.resource) 14 | while((!finished) && (t != null)) { 15 | while(!scheduler.getResource(t)) { 16 | Thread.sleep(1000) 17 | } 18 | println("starting task: " + t.name) 19 | finished = t.run() 20 | println("task: " + t.name + "done with result: " + finished) 21 | scheduler.releaseResource(t) 22 | t = t.nextTask 23 | } 24 | } 25 | } 26 | } 27 | 28 | class Scheduler(var cpu : Int = 1, var network : Int = 1) { 29 | 30 | def getResource(t : Task) : Boolean = { 31 | //println("get resource: " + cpu + ", " + network + ", " + t.resource + ", " + Thread.currentThread().getId) 32 | if (t.resource == Task.CPU) 33 | getCPU() 34 | else { 35 | getNetwork() 36 | } 37 | } 38 | 39 | def releaseResource(t : Task) : Unit = { 40 | if (t.resource == Task.CPU) 41 | releaseCPU() 42 | else 43 | releaseNetwork() 44 | } 45 | 46 | private[Scheduler] def getCPU() : Boolean = { 47 | synchronized { 48 | if (cpu > 0) { 49 | cpu -= 1 50 | true 51 | } 52 | else false 53 | } 54 | } 55 | 56 | private[Scheduler] def releaseCPU() : Unit = { 57 | synchronized { 58 | cpu += 1 59 | } 60 | } 61 | 62 | private[Scheduler] def getNetwork() : Boolean = { 63 | synchronized { 64 | //println("get network: " + network) 65 | if (network > 0) { 66 | network -= 1 67 | true 68 | } 69 | else false 70 | } 71 | } 72 | 73 | private[Scheduler] def releaseNetwork() : Unit = { 74 | synchronized { 75 | network += 1 76 | } 77 | } 78 | } 79 | 80 | abstract class Task(val name : String, val resource : Int) extends Serializable { 81 | 82 | var nextTask : Task = null 83 | 84 | def run() : Boolean 85 | } 86 | 87 | object Task { 88 | final val NETWORK = 1 89 | final val CPU = 0 90 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/regression/MLR.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.regression 2 | 3 | import java.util 4 | 5 | import com.intel.distml.api.{Model, Session} 6 | import com.intel.distml.platform.DistML 7 | import com.intel.distml.util.{KeyCollection, DoubleMatrix, KeyList, DataStore} 8 | import com.intel.distml.util.scala.DoubleArrayWithIntKey 9 | import org.apache.spark.SparkContext 10 | import org.apache.spark.rdd.RDD 11 | 12 | import scala.collection.JavaConversions._ 13 | import scala.collection.mutable 14 | 15 | /** 16 | * Created by yunlong on 16-3-30. 
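 *
 * Multinomial logistic regression (softmax) trained with mini-batch gradient
 * descent against a parameter-server hosted "weights" DoubleMatrix: each worker
 * fetches the rows needed by its batch, applies lr * (label - softmax(w.x)) * x,
 * and pushes back the weight deltas.
 *
 * A hedged usage sketch (signatures as defined below; `sc` and `samples` of type
 * RDD[(mutable.HashMap[Long, Double], Array[Double])] are assumed):
 *   val dm = MLR.train(sc, samples, psCount, inputDim, outputDim, maxIterations, batchSize)
 *   val correct = MLR.validate(samples, dm)
 *   dm.recycle()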
17 | */ 18 | 19 | class MLRModel( 20 | val inputDim : Long, 21 | val outputDim : Int 22 | ) extends Model { 23 | 24 | registerMatrix("weights", new DoubleMatrix(inputDim, outputDim)) 25 | 26 | } 27 | 28 | object MLR { 29 | 30 | private[MLR] def softmax(x: Array[Double]): Unit = { 31 | val max = x.max 32 | var sum = 0.0 33 | 34 | for (i <- 0 until x.length) { 35 | x(i) = math.exp(x(i) - max) 36 | sum += x(i) 37 | } 38 | 39 | for (i <- 0 until x.length) x(i) /= sum 40 | } 41 | 42 | def train(data: RDD[(mutable.HashMap[Long, Double], Array[Double])], dm : DistML[Iterator[(Int, String, DataStore)]], 43 | maxIterations : Int, batchSize : Int): Unit = { 44 | 45 | val m = dm.model.asInstanceOf[MLRModel] 46 | val monitorPath = dm.monitorPath 47 | 48 | var lr = 0.1 49 | 50 | for (iter <- 0 to maxIterations) { 51 | val t = data.mapPartitionsWithIndex((index, it) => { 52 | 53 | println("Worker: tid=" + Thread.currentThread.getId + ", " + index) 54 | println("--- connecting to PS ---") 55 | val session = new Session(m, monitorPath, index) 56 | 57 | val weights = m.getMatrix("weights").asInstanceOf[DoubleMatrix] 58 | val samples = new util.LinkedList[(mutable.HashMap[Long, Double], Array[Double])] 59 | 60 | var cost = 0.0 61 | 62 | while(it.hasNext) { 63 | samples.clear() 64 | var count = 0 65 | while((count < batchSize) && it.hasNext) { 66 | samples.add(it.next()) 67 | count += 1 68 | } 69 | val keys = new KeyList() 70 | for ((x, label) <- samples) { 71 | for (key <- x.keySet) { 72 | keys.addKey(key) 73 | } 74 | } 75 | val w = weights.fetch(keys, session) 76 | 77 | val w_old = new util.HashMap[Long, Array[Double]] 78 | for ((key, value) <- w) { 79 | var tmp:Array[Double] = new Array[Double](value.length) 80 | for (i <- 0 until value.length) { 81 | tmp(i) = value(i) 82 | } 83 | w_old.put(key, tmp) 84 | } 85 | 86 | for ((x, label) <- samples) { 87 | val p_y_given_x: Array[Double] = new Array[Double](m.outputDim) 88 | val dy: Array[Double] = new Array[Double](m.outputDim) 89 | 90 | val i:Long = 0 91 | for (i<-0 until m.outputDim) { 92 | p_y_given_x(i) = 0.0 93 | for (key <- x.keySet) { 94 | p_y_given_x(i) += (w.get(key))(i) * x.get(key).get 95 | } 96 | } 97 | 98 | softmax(p_y_given_x) 99 | 100 | for (i <- 0 until m.outputDim) { 101 | dy(i) = label(i) - p_y_given_x(i) 102 | for (key <- x.keySet) { 103 | (w.get(key)) (i.toInt) += lr * dy(i) * x.get(key).get 104 | } 105 | if (label(i) > 0.0) { 106 | cost = cost + label(i) * Math.log(p_y_given_x(i)) 107 | } 108 | } 109 | } 110 | 111 | cost /= samples.size() 112 | 113 | // tuning in future 114 | for (key <- w.keySet()) { 115 | val grad: Array[java.lang.Double] = new Array[java.lang.Double](m.outputDim) 116 | for (i <- 0 until m.outputDim) { 117 | grad(i) = w.get(key)(i) - w_old.get(key)(i) 118 | } 119 | w.put(key, grad) 120 | } 121 | 122 | weights.push(w, session) 123 | 124 | //session.progress(samples.size()) 125 | } 126 | 127 | println("--- disconnect ---") 128 | session.disconnect() 129 | 130 | val r = new Array[Double](1) 131 | r(0) = cost 132 | r.iterator 133 | }) 134 | 135 | dm.iterationDone() 136 | 137 | 138 | val totalCost = t.reduce(_+_) 139 | println("Total Cost: " + totalCost) 140 | } 141 | } 142 | 143 | def train(sc : SparkContext, samples: RDD[(mutable.HashMap[Long, Double], Array[Double])], psCount : Int, inputDim : Long, outputDim : Int, 144 | maxIterations : Int, batchSize : Int): DistML[Iterator[(Int, String, DataStore)]] = { 145 | 146 | val m = new MLRModel(inputDim, outputDim) 147 | 148 | val dm = DistML.distribute(sc, m, psCount, 
DistML.defaultF) 149 | val monitorPath = dm.monitorPath 150 | 151 | train(samples, dm, maxIterations, batchSize) 152 | 153 | dm 154 | } 155 | 156 | def validate(data : RDD[(mutable.HashMap[Long, Double], Array[Double])], dm : DistML[Iterator[(Int, String, DataStore)]]): Int = { 157 | val m = dm.model.asInstanceOf[MLRModel] 158 | val monitorPath = dm.monitorPath 159 | 160 | val t1 = data.mapPartitionsWithIndex( (index, it) => { 161 | val session = new Session(m, monitorPath, index) 162 | val weights = m.getMatrix("weights").asInstanceOf[DoubleMatrix] 163 | val w = weights.fetch(KeyCollection.ALL, session) 164 | 165 | var correct = 0 166 | var error = 0 167 | for((x, label) <- it) { 168 | val p_y_given_x: Array[Double] = new Array[Double](m.outputDim) 169 | val i:Long = 0 170 | for (i<-0 until m.outputDim) { 171 | p_y_given_x(i) = 0.0 172 | for (key <- x.keySet) { 173 | p_y_given_x(i) += (w.get(key))(i) * x.get(key).get 174 | } 175 | //add bias 176 | } 177 | 178 | softmax(p_y_given_x) 179 | val max = p_y_given_x.max 180 | var c_tmp = 0 181 | for (i <- 0 until m.outputDim) { 182 | if(label(i) > 0.0 && p_y_given_x(i) == max) { 183 | c_tmp = 1; 184 | } 185 | } 186 | if(c_tmp == 1) { 187 | correct += 1 188 | } else { 189 | error += 1 190 | } 191 | } 192 | 193 | session.disconnect() 194 | 195 | val r = new Array[Int](1) 196 | r(0) = correct 197 | //r(1) = error 198 | r.iterator 199 | }) 200 | 201 | t1.reduce(_+_) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/DoubleArray.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/3/16. 7 | */ 8 | class DoubleArray ( 9 | dim : Long) extends SparseArray[Long, Double](dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_DOUBLE) { 10 | 11 | protected def toLong(k : Long) : Long = k 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/DoubleArrayWithIntKey.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/3/16. 7 | */ 8 | class DoubleArrayWithIntKey( 9 | dim : Long) extends SparseArray[Int, Double](dim, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_DOUBLE) { 10 | 11 | protected def toLong(k : Int) : Long = k 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/DoubleMatrixWithIntKey.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by jimmy on 16-4-7. 
7 | */ 8 | class DoubleMatrixWithIntKey 9 | ( 10 | dim: Long, 11 | cols: Int 12 | ) extends SparseMatrix [Int,Double](dim, cols, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_DOUBLE){ 13 | 14 | override protected def isZero (value: Double): Boolean = { 15 | return Math.abs(value) < 10e-6 16 | } 17 | 18 | override protected def subtract(value: Double, delta : Double): Double = { value - delta } 19 | 20 | override protected def createValueArray (size: Int): Array[Double] = { 21 | return new Array[Double] (size) 22 | } 23 | 24 | override protected def toLong(k : Int) : Long = { k } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/FloatArray.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/3/16. 7 | */ 8 | class FloatArray ( 9 | dim : Long) extends SparseArray[Long, Float](dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_FLOAT) { 10 | 11 | protected def toLong(k : Long) : Long = k 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/FloatMatrix.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/26/16. 7 | */ 8 | class FloatMatrix ( 9 | dim : Long, 10 | cols: Int 11 | ) extends SparseMatrix[Long, Float](dim, cols, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_FLOAT) { 12 | 13 | override protected def isZero (value: Float): Boolean = { 14 | return Math.abs(value) < 10e-6 15 | } 16 | 17 | override protected def subtract(value: Float, delta : Float): Float = { value - delta } 18 | 19 | override protected def createValueArray (size: Int): Array[Float] = { 20 | return new Array[Float] (size) 21 | } 22 | 23 | override protected def toLong(k : Long) : Long = { k } 24 | 25 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/FloatMatrixAdapGradWithIntKey.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/26/16. 
7 | */ 8 | class FloatMatrixAdapGradWithIntKey ( 9 | dim : Long, 10 | cols: Int 11 | ) extends SparseMatrixAdapGrad[Int, Float](dim, cols, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_FLOAT) { 12 | 13 | override protected def isZero (value: Float): Boolean = { 14 | return Math.abs(value) < 10e-8 15 | } 16 | 17 | override protected def subtract(value: Float, delta : Float): Float = { value - delta } 18 | 19 | override protected def createValueArray (size: Int): Array[Float] = { 20 | return new Array[Float] (size) 21 | } 22 | 23 | override protected def createValueArrayWithAlpha (size: Int): Array[(Float, Float)] = { 24 | return new Array[(Float, Float)] (size) 25 | } 26 | 27 | override protected def toLong(k : Int) : Long = { k } 28 | 29 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/FloatMatrixWithIntKey.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/26/16. 7 | */ 8 | class FloatMatrixWithIntKey ( 9 | dim : Long, 10 | cols: Int 11 | ) extends SparseMatrix[Int, Float](dim, cols, DataDesc.KEY_TYPE_INT, DataDesc.ELEMENT_TYPE_FLOAT) { 12 | 13 | override protected def isZero (value: Float): Boolean = { 14 | return Math.abs(value) < 10e-6 15 | } 16 | 17 | override protected def subtract(value: Float, delta : Float): Float = { value - delta } 18 | 19 | override protected def createValueArray (size: Int): Array[Float] = { 20 | return new Array[Float] (size) 21 | } 22 | 23 | override protected def toLong(k : Int) : Long = { k } 24 | 25 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/IntArray.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/3/16. 7 | */ 8 | class IntArray ( 9 | dim : Long) extends SparseArray[Long, Int](dim, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_INT) { 10 | 11 | protected def toLong(k : Long) : Long = k 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/IntMatrix.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.util.DataDesc 4 | 5 | /** 6 | * Created by yunlong on 1/26/16. 
7 | */ 8 | class IntMatrix ( 9 | dim : Long, 10 | cols: Int 11 | ) extends SparseMatrix[Long, Int](dim, cols, DataDesc.KEY_TYPE_LONG, DataDesc.ELEMENT_TYPE_INT) { 12 | 13 | protected def isZero (value: Int): Boolean = { 14 | return value == 0 15 | } 16 | 17 | override protected def subtract(value: Int, delta : Int): Int = { value - delta } 18 | 19 | protected def createValueArray (size: Int): Array[Int] = { 20 | return new Array[Int] (size) 21 | } 22 | 23 | protected def toLong(k : Long) : Long = { k } 24 | 25 | } -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/SparseArray.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.api.{Session, DMatrix} 4 | import com.intel.distml.util.{KeyCollection, DataDesc} 5 | 6 | import scala.collection.mutable 7 | 8 | /** 9 | * Created by yunlong on 1/3/16. 10 | */ 11 | abstract class SparseArray[K, V] ( 12 | dim : Long, 13 | keyType : Int, 14 | valueType : Int) extends DMatrix(dim) { 15 | 16 | format = new DataDesc(DataDesc.DATA_TYPE_ARRAY, keyType, valueType) 17 | 18 | def fetch(rows: KeyCollection, session: Session): mutable.HashMap[K, V] = { 19 | val result = new mutable.HashMap[K, V] 20 | val data: Array[Array[Byte]] = session.dataBus.fetch(name, rows, format) 21 | for (obj <- data) { 22 | val m: mutable.HashMap[K, V] = readMap(obj) 23 | m.foreach( f => result += f ) 24 | } 25 | return result 26 | } 27 | 28 | def push(data: mutable.HashMap[K, V], session: Session) { 29 | val bufs = new Array[Array[Byte]](partitions.length) 30 | 31 | for (i <- 0 to partitions.length - 1) { 32 | val p: KeyCollection = partitions(i) 33 | val m: mutable.HashMap[K, V] = new mutable.HashMap[K, V] 34 | for (key <- data.keySet) { 35 | if (p.contains(toLong(key))) m.put(key, data(key)) 36 | } 37 | bufs(i) = writeMap(m) 38 | } 39 | session.dataBus.push(name, format, bufs) 40 | } 41 | 42 | private def readMap(buf: Array[Byte]): mutable.HashMap[K, V] = { 43 | val data = new mutable.HashMap[K, V] 44 | var offset: Int = 0 45 | while (offset < buf.length) { 46 | val key: K = format.readKey(buf, offset).asInstanceOf[K] 47 | offset += format.keySize 48 | val value: V = format.readValue(buf, offset).asInstanceOf[V] 49 | offset += format.valueSize 50 | data.put(key, value) 51 | } 52 | return data 53 | } 54 | 55 | private def writeMap(data: mutable.HashMap[K, V]): Array[Byte] = { 56 | val recordLen: Int = format.keySize + format.valueSize 57 | val buf: Array[Byte] = new Array[Byte](recordLen * data.size) 58 | var offset: Int = 0 59 | for ((k, v) <- data) { 60 | format.writeKey(k.asInstanceOf[Number], buf, offset) 61 | offset += format.keySize 62 | offset = format.writeValue(v, buf, offset) 63 | } 64 | return buf 65 | } 66 | 67 | protected def toLong(k : K) : Long 68 | 69 | } 70 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/SparseMatrix.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import java.util 4 | 5 | import com.intel.distml.api.{DMatrix, Session} 6 | import com.intel.distml.util.{KeyRange, DataDesc, KeyCollection} 7 | 8 | import scala.collection.mutable 9 | 10 | /** 11 | * Created by yunlong on 1/3/16. 
12 | */ 13 | abstract class SparseMatrix[K, V] ( 14 | dim : Long, 15 | cols: Int, 16 | keyType : Int, 17 | valueType : Int) extends DMatrix(dim) { 18 | 19 | format = new DataDesc(DataDesc.DATA_TYPE_MATRIX, keyType, valueType) 20 | val colKeys = new KeyRange(0, cols - 1) 21 | 22 | override def getColKeys() = colKeys 23 | 24 | def fetch(rows: KeyCollection, session: Session): mutable.HashMap[K, Array[V]] = { 25 | val result = new mutable.HashMap[K, Array[V]] 26 | val data = session.dataBus.fetch(name, rows, format) 27 | for (obj <- data) { 28 | val m: mutable.HashMap[K, Array[V]] = readMap(obj) 29 | m.foreach( f => result += f ) 30 | } 31 | return result 32 | } 33 | 34 | def push(data: mutable.HashMap[K, Array[V]], session: Session) { 35 | val bufs = new Array[Array[Byte]](partitions.length) 36 | 37 | for (i <- 0 to partitions.length - 1) { 38 | val p = partitions(i) 39 | val m = new mutable.HashMap[K, Array[V]] 40 | for (key <- data.keySet) { 41 | if (p.contains(toLong(key))) m.put(key, data(key)) 42 | } 43 | bufs(i) = writeMap(m) 44 | } 45 | session.dataBus.push(name, format, bufs) 46 | } 47 | 48 | def cloneData(data: mutable.HashMap[K, Array[V]]): mutable.HashMap[K, Array[V]] = { 49 | val t = new mutable.HashMap[K, Array[V]] 50 | for ((k, v) <- data) { 51 | val tv = createValueArray(v.length) 52 | for (i <- 0 to v.length-1) { 53 | tv(i) = v(i) 54 | } 55 | t.put(k, tv) 56 | } 57 | t 58 | } 59 | 60 | def getUpdate(data: mutable.HashMap[K, Array[V]], delta: mutable.HashMap[K, Array[V]]) { 61 | for ((k, v) <- data) { 62 | val tv = delta.get(k).get 63 | for (i <- 0 to v.length-1) { 64 | tv(i) = subtract(v(i), tv(i)) 65 | } 66 | } 67 | } 68 | 69 | private def readMap(buf: Array[Byte]): mutable.HashMap[K, Array[V]] = { 70 | //println("read map: " + buf.length) 71 | val data = new mutable.HashMap[K, Array[V]] 72 | var offset: Int = 0 73 | while (offset < buf.length) { 74 | //println("read offset: " + offset) 75 | val key: K = format.readKey(buf, offset).asInstanceOf[K] 76 | offset += format.keySize 77 | val values = createValueArray(colKeys.size.toInt) 78 | if (format.denseColumn) { 79 | for (i <- 0 to colKeys.size.toInt - 1) { 80 | //println("read offset: " + offset + ", " + i) 81 | values(i) = format.readValue(buf, offset).asInstanceOf[V] 82 | offset += format.valueSize 83 | } 84 | } 85 | else { 86 | val count: Int = format.readInt(buf, offset) 87 | offset += 4 88 | for (i <- 0 to count -1 ) { 89 | val index: Int = format.readInt(buf, offset) 90 | offset += 4 91 | val value: V = format.readValue(buf, offset).asInstanceOf[V] 92 | offset += format.valueSize 93 | values(index) = value 94 | } 95 | } 96 | data.put(key, values) 97 | } 98 | return data 99 | } 100 | 101 | private def writeMap(data: mutable.HashMap[K, Array[V]]): Array[Byte] = { 102 | 103 | var buf: Array[Byte] = null 104 | 105 | if (format.denseColumn) { 106 | val len: Int = (format.valueSize * data.size * colKeys.size).toInt 107 | buf = new Array[Byte](format.keySize * data.size + len) 108 | } 109 | else { 110 | var nzcount: Int = 0 111 | for (values <- data.values) { 112 | for (value <- values) { 113 | if (!isZero(value)) { 114 | nzcount += 1 115 | } 116 | } 117 | } 118 | val len: Int = ((format.valueSize + 4) * nzcount).toInt 119 | buf = new Array[Byte](format.keySize * data.size + len) 120 | } 121 | 122 | var offset: Int = 0 123 | for ((k, v) <- data) { 124 | format.writeKey(k.asInstanceOf[Number], buf, offset) 125 | offset += format.keySize 126 | val values: Array[V] = v 127 | if (format.denseColumn) { 128 | for ( i <- 0 to 
colKeys.size.toInt - 1) { 129 | format.writeValue(values(i), buf, offset) 130 | offset += format.valueSize 131 | } 132 | } 133 | else { 134 | val counterIndex: Int = offset 135 | offset += 4 136 | var counter: Int = 0 137 | for (i <- 0 to values.length - 1) { 138 | val value: V = values(i) 139 | if (!isZero(value)) { 140 | format.write(i, buf, offset) 141 | offset += 4 142 | format.writeValue(value, buf, offset) 143 | offset += format.valueSize 144 | } 145 | counter += 1 146 | } 147 | format.write(counter, buf, counterIndex) 148 | } 149 | } 150 | return buf 151 | } 152 | 153 | protected def toLong(k : K) : Long 154 | 155 | protected def isZero(value: V): Boolean 156 | 157 | protected def subtract(value: V, delta : V): V 158 | 159 | protected def createValueArray(size: Int): Array[V] 160 | } 161 | -------------------------------------------------------------------------------- /src/main/scala/com/intel/distml/util/scala/SparseMatrixAdapGrad.scala: -------------------------------------------------------------------------------- 1 | package com.intel.distml.util.scala 2 | 3 | import com.intel.distml.api.{DMatrix, Session} 4 | import com.intel.distml.util.{DataDesc, KeyCollection, KeyRange} 5 | 6 | import scala.collection.mutable 7 | 8 | /** 9 | * Created by yunlong on 1/3/16. 10 | */ 11 | abstract class SparseMatrixAdapGrad[K, V] ( 12 | dim : Long, 13 | cols: Int, 14 | keyType : Int, 15 | valueType : Int) extends DMatrix(dim) { 16 | 17 | format = new DataDesc(DataDesc.DATA_TYPE_MATRIX, keyType, valueType, false, true, true) 18 | val colKeys = new KeyRange(0, cols - 1) 19 | 20 | override def getColKeys() = colKeys 21 | 22 | def fetch(rows: KeyCollection, session: Session): mutable.HashMap[K, (Array[V], Array[Float])] = { 23 | val result = new mutable.HashMap[K, (Array[V], Array[Float])] 24 | val data = session.dataBus.fetch(name, rows, format) 25 | for (obj <- data) { 26 | val m: mutable.HashMap[K, (Array[V], Array[Float])] = readMap(obj) 27 | m.foreach( f => result += f ) 28 | } 29 | return result 30 | } 31 | 32 | def push(data: mutable.HashMap[K, Array[V]], session: Session) { 33 | val bufs = new Array[Array[Byte]](partitions.length) 34 | 35 | for (i <- 0 to partitions.length - 1) { 36 | val p = partitions(i) 37 | val m = new mutable.HashMap[K, Array[V]] 38 | for (key <- data.keySet) { 39 | if (p.contains(toLong(key))) m.put(key, data(key)) 40 | } 41 | bufs(i) = writeMap(m) 42 | } 43 | session.dataBus.push(name, format, bufs) 44 | } 45 | 46 | def cloneData(data: mutable.HashMap[K, (Array[V], Array[Float])]): mutable.HashMap[K, Array[V]] = { 47 | val t = new mutable.HashMap[K, Array[V]] 48 | for ((k, v) <- data) { 49 | 50 | val tv = createValueArray(v._1.length) 51 | for (i <- 0 to v._1.length-1) 52 | tv(i) = v._1(i) 53 | // System.arraycopy(tv, 0, v._1, 0, v._1.length) 54 | t.put(k, tv) 55 | } 56 | t 57 | } 58 | 59 | def getUpdate(newData: mutable.HashMap[K, (Array[V], Array[Float])], oldData: mutable.HashMap[K, Array[V]]) { 60 | //println("get update: " + oldData.size) 61 | for ((k, v) <- newData) { 62 | val tv = oldData.get(k).get 63 | for (i <- 0 to v._1.length-1) { 64 | tv(i) = subtract(v._1(i), tv(i)) 65 | } 66 | } 67 | 68 | var keys = oldData.keySet 69 | for (k <- keys) { 70 | var v = oldData.get(k).get 71 | 72 | var i = 0 73 | while((i < v.length) && isZero(v(i))) { 74 | i += 1 75 | } 76 | if (i == v.length) { 77 | oldData.remove(k) 78 | } 79 | // else { 80 | // println("update(" + k + ")(0) = " + v(0)) 81 | // } 82 | } 83 | //println("get update done: " + oldData.size) 84 | } 85 | 86 
| private def readMap(buf: Array[Byte]): mutable.HashMap[K, (Array[V], Array[Float])] = { 87 | //println("read map: " + buf.length) 88 | val data = new mutable.HashMap[K, (Array[V], Array[Float])] 89 | var offset: Int = 0 90 | while (offset < buf.length) { 91 | //println("read offset: " + offset) 92 | val key: K = format.readKey(buf, offset).asInstanceOf[K] 93 | offset += format.keySize 94 | val values = createValueArray(colKeys.size.toInt) 95 | val alphas = new Array[Float](colKeys.size.toInt) 96 | if (format.denseColumn) { 97 | for (i <- 0 to colKeys.size.toInt - 1) { 98 | //println("read offset: " + offset + ", " + i) 99 | values(i) = format.readValue(buf, offset).asInstanceOf[V] 100 | offset += format.valueSize 101 | alphas(i) = format.readFloat(buf, offset) 102 | offset += 4 103 | } 104 | } 105 | else { 106 | val count: Int = format.readInt(buf, offset) 107 | offset += 4 108 | for (i <- 0 to count -1 ) { 109 | val index: Int = format.readInt(buf, offset) 110 | offset += 4 111 | values(i) = format.readValue(buf, offset).asInstanceOf[V] 112 | offset += format.valueSize 113 | alphas(i) = format.readFloat(buf, offset) 114 | offset += 4 115 | } 116 | } 117 | 118 | data.put(key, (values, alphas)) 119 | } 120 | return data 121 | } 122 | 123 | private def writeMap(data: mutable.HashMap[K, Array[V]]): Array[Byte] = { 124 | 125 | var buf: Array[Byte] = null 126 | 127 | if (format.denseColumn) { 128 | val len: Int = (format.valueSize * data.size * colKeys.size).toInt 129 | buf = new Array[Byte](format.keySize * data.size + len) 130 | } 131 | else { 132 | var nzcount: Int = 0 133 | for (values <- data.values) { 134 | for (value <- values) { 135 | if (!isZero(value)) { 136 | nzcount += 1 137 | } 138 | } 139 | } 140 | val len: Int = ((format.valueSize + 4) * nzcount).toInt 141 | buf = new Array[Byte](format.keySize * data.size + len) 142 | } 143 | 144 | var offset: Int = 0 145 | for ((k, v) <- data) { 146 | format.writeKey(k.asInstanceOf[Number], buf, offset) 147 | offset += format.keySize 148 | val values: Array[V] = v 149 | if (format.denseColumn) { 150 | for ( i <- 0 to colKeys.size.toInt - 1) { 151 | format.writeValue(values(i), buf, offset) 152 | offset += format.valueSize 153 | } 154 | } 155 | else { 156 | val counterIndex: Int = offset 157 | offset += 4 158 | var counter: Int = 0 159 | for (i <- 0 to values.length - 1) { 160 | val value: V = values(i) 161 | if (!isZero(value)) { 162 | format.write(i, buf, offset) 163 | offset += 4 164 | format.writeValue(value, buf, offset) 165 | offset += format.valueSize 166 | } 167 | counter += 1 168 | } 169 | format.write(counter, buf, counterIndex) 170 | } 171 | } 172 | return buf 173 | } 174 | 175 | protected def toLong(k : K) : Long 176 | 177 | protected def isZero(value: V): Boolean 178 | 179 | protected def subtract(value: V, delta : V): V 180 | 181 | protected def createValueArrayWithAlpha(size: Int): Array[(V, Float)] 182 | 183 | protected def createValueArray(size: Int): Array[(V)] 184 | } 185 | --------------------------------------------------------------------------------
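Editor's sketch (hedged, not part of the repository): the SparseMatrix subclasses above are normally driven from a worker closure in the same fetch / local-update / push-delta cycle that MLR.train follows. The sketch assumes a model with a FloatMatrixWithIntKey registered under the name "weights", and that `model`, `monitorPath`, `index` and a KeyList `keys` of the rows touched by the current mini-batch are already in scope; the 0.01f update is a placeholder for whatever gradient the algorithm computes.

import com.intel.distml.api.Session
import com.intel.distml.util.KeyList
import com.intel.distml.util.scala.FloatMatrixWithIntKey

val session = new Session(model, monitorPath, index)
val weights = model.getMatrix("weights").asInstanceOf[FloatMatrixWithIntKey]

val w    = weights.fetch(keys, session)   // mutable.HashMap[Int, Array[Float]] for the requested rows
val wOld = weights.cloneData(w)           // deep copy kept so the delta can be derived later

for ((row, values) <- w; i <- values.indices)
  values(i) += 0.01f                      // placeholder local update

weights.getUpdate(w, wOld)                // wOld(k)(i) becomes new - old, i.e. the delta
weights.push(wOld, session)               // only the delta is sent back to the servers
session.disconnect()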
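A parallel hedged note on the AdaGrad variant: SparseMatrixAdapGrad.fetch returns, per row, a pair of the values and their per-element AdaGrad state (alphas), while push still takes plain value maps; cloneData and getUpdate bridge the two shapes, and getUpdate drops rows whose delta is all zero. Sketch under the same assumptions as above, with a FloatMatrixAdapGradWithIntKey registered as "weights":

val adagrad = model.getMatrix("weights").asInstanceOf[FloatMatrixAdapGradWithIntKey]

val fetched = adagrad.fetch(keys, session)   // HashMap[Int, (Array[Float], Array[Float])]: (values, alphas)
val old     = adagrad.cloneData(fetched)     // keeps only the values, for delta computation

for ((row, (values, alphas)) <- fetched; i <- values.indices)
  values(i) += 0.01f                         // placeholder; a real update would use alphas(i), the per-element AdaGrad state

adagrad.getUpdate(fetched, old)              // old now holds (new - old); all-zero rows are removed
adagrad.push(old, session)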