├── .gitignore ├── LICENSE ├── README.md ├── assembly └── pom.xml ├── dev ├── change-version-to-2.10.sh └── change-version-to-2.11.sh ├── examples ├── pom.xml └── src │ └── main │ └── scala │ └── com │ └── github │ └── cloudml │ └── zen │ └── examples │ └── ml │ ├── AbstractParams.scala │ ├── BinaryClassification.scala │ ├── LDADriver.scala │ ├── LambdaMARTRunner.scala │ ├── MovieLensBSFM.scala │ ├── MovieLensFM.scala │ ├── MovieLensMVM.scala │ ├── MovieLensUtils.scala │ ├── NetflixPrizeFM.scala │ ├── NetflixPrizeMVM.scala │ └── NetflixPrizeUtils.scala ├── ml ├── pom.xml └── src │ ├── main │ └── scala │ │ ├── com │ │ └── github │ │ │ └── cloudml │ │ │ └── zen │ │ │ └── ml │ │ │ ├── clustering │ │ │ ├── LDA.scala │ │ │ ├── LDADefines.scala │ │ │ ├── LDAMetrics.scala │ │ │ ├── LDAModel.scala │ │ │ ├── README.md │ │ │ └── algorithm │ │ │ │ ├── AliasLDA.scala │ │ │ │ ├── FPlusLDA.scala │ │ │ │ ├── LDAAlgorithm.scala │ │ │ │ ├── LDAInferrer.scala │ │ │ │ ├── LDATrainer.scala │ │ │ │ ├── LDATrainerByDoc.scala │ │ │ │ ├── LDATrainerByWord.scala │ │ │ │ ├── LightLDA.scala │ │ │ │ ├── SparseLDA.scala │ │ │ │ └── ZenLDA.scala │ │ │ ├── linalg │ │ │ └── BLAS.scala │ │ │ ├── neuralNetwork │ │ │ ├── AdaDeltaUpdater.scala │ │ │ ├── AdaGradUpdater.scala │ │ │ ├── DBN.scala │ │ │ ├── EquilibratedUpdater.scala │ │ │ ├── Layer.scala │ │ │ ├── MLP.scala │ │ │ ├── MLPModel.scala │ │ │ ├── MomentumUpdater.scala │ │ │ ├── NNUtil.scala │ │ │ ├── RBM.scala │ │ │ ├── RBMModel.scala │ │ │ ├── README.md │ │ │ └── StackedRBM.scala │ │ │ ├── optimization │ │ │ ├── Gradient.scala │ │ │ ├── GradientDescent.scala │ │ │ ├── LBFGS.scala │ │ │ ├── Optimizer.scala │ │ │ └── Updater.scala │ │ │ ├── partitioner │ │ │ ├── BBRPartitioner.scala │ │ │ ├── DBHPartitioner.scala │ │ │ ├── EdgeDstPartitioner.scala │ │ │ ├── LBVertexRDDBuilder.scala │ │ │ └── VSDLPPartitioner.scala │ │ │ ├── recommendation │ │ │ ├── BSFM.scala │ │ │ ├── BSFMModel.scala │ │ │ ├── FM.scala │ │ │ ├── FMModel.scala │ │ │ ├── MVM.scala │ │ │ ├── MVMModel.scala │ │ │ └── README.md │ │ │ ├── regression │ │ │ ├── LogisticRegression.scala │ │ │ └── README.md │ │ │ ├── sampler │ │ │ ├── AliasTable.scala │ │ │ ├── CompositeSampler.scala │ │ │ ├── CumulativeDist.scala │ │ │ ├── DiscreteSampler.scala │ │ │ ├── FTree.scala │ │ │ ├── FlatDist.scala │ │ │ ├── MetropolisHastings.scala │ │ │ └── Sampler.scala │ │ │ ├── tree │ │ │ ├── DerivativeCalculator.scala │ │ │ ├── Histogram.scala │ │ │ ├── LambdaMART.scala │ │ │ ├── LambdaMARTDecisionTree.scala │ │ │ ├── Node.scala │ │ │ ├── ProbabilityFunctions.scala │ │ │ ├── SplitInfo.scala │ │ │ ├── TreeUtils.scala │ │ │ └── treeAggregatorFormat.scala │ │ │ └── util │ │ │ ├── CompressedVector.scala │ │ │ ├── Concurrent.scala │ │ │ ├── Logging.scala │ │ │ ├── SparkUtils.scala │ │ │ ├── TimeTracker.scala │ │ │ ├── Utils.scala │ │ │ ├── XORShiftRandom.scala │ │ │ └── modelSaveLoad.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── graphx2 │ │ ├── Edge.scala │ │ ├── EdgeContext.scala │ │ ├── EdgeDirection.scala │ │ ├── EdgeRDD.scala │ │ ├── EdgeTriplet.scala │ │ ├── Graph.scala │ │ ├── GraphLoader.scala │ │ ├── GraphOps.scala │ │ ├── GraphXUtils.scala │ │ ├── PartitionStrategy.scala │ │ ├── TripletFields.java │ │ ├── VertexRDD.scala │ │ ├── impl │ │ ├── EdgeActiveness.java │ │ ├── EdgePartition.scala │ │ ├── EdgePartitionBuilder.scala │ │ ├── EdgeRDDImpl.scala │ │ ├── GraphImpl.scala │ │ ├── ReplicatedVertexView.scala │ │ ├── RoutingTablePartition.scala │ │ ├── ShippableVertexPartition.scala │ │ ├── VertexPartition.scala │ │ 
├── VertexPartitionBase.scala │ │ ├── VertexPartitionBaseOps.scala │ │ ├── VertexRDDImpl.scala │ │ └── package.scala │ │ ├── package-info.java │ │ ├── package.scala │ │ └── util │ │ ├── BytecodeUtils.scala │ │ ├── GraphGenerators.scala │ │ ├── collection │ │ └── GraphXPrimitiveKeyOpenHashMap.scala │ │ ├── package-info.java │ │ └── package.scala │ └── test │ ├── resources │ ├── binary_classification_data.txt │ ├── log4j.properties │ └── regression_data.txt │ └── scala │ └── com │ └── github │ └── cloudml │ └── zen │ └── ml │ ├── clustering │ └── LDASuite.scala │ ├── neuralNetwork │ ├── DBNSuite.scala │ ├── MLPSuite.scala │ ├── RBMSuite.scala │ └── StackedRBMSuite.scala │ ├── recommendation │ ├── FMSuite.scala │ └── MVMSuite.scala │ ├── regression │ └── LogisticRegressionSuite.scala │ └── util │ ├── LocalSparkContext.scala │ ├── MinstDatasetReader.scala │ ├── MnistDatasetSuite.scala │ └── SharedSparkContext.scala ├── pom.xml └── scalastyle-config.xml /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.#* 3 | *#*# 4 | *.swp 5 | *.ipr 6 | *.iml 7 | *.iws 8 | *.pyc 9 | .idea/ 10 | .idea_modules/ 11 | build/*.jar 12 | .settings 13 | .cache 14 | cache 15 | .generated-mima* 16 | work/ 17 | out/ 18 | .DS_Store 19 | third_party/libmesos.so 20 | third_party/libmesos.dylib 21 | build/apache-maven* 22 | build/zinc* 23 | build/scala* 24 | conf/java-opts 25 | conf/*.sh 26 | conf/*.cmd 27 | conf/*.properties 28 | conf/*.conf 29 | conf/*.xml 30 | conf/slaves 31 | docs/_site 32 | docs/api 33 | target/ 34 | reports/ 35 | .project 36 | .classpath 37 | .scala_dependencies 38 | lib_managed/ 39 | src_managed/ 40 | project/boot/ 41 | project/plugins/project/build.properties 42 | project/build/target/ 43 | project/plugins/target/ 44 | project/plugins/lib_managed/ 45 | project/plugins/src_managed/ 46 | logs/ 47 | log/ 48 | spark-tests.log 49 | streaming-tests.log 50 | dependency-reduced-pom.xml 51 | .ensime 52 | .ensime_lucene 53 | checkpoint 54 | derby.log 55 | dist/ 56 | dev/create-release/*txt 57 | dev/create-release/*final 58 | spark-*-bin-*.tgz 59 | unit-tests.log 60 | /lib/ 61 | ec2/lib/ 62 | rat-results.txt 63 | scalastyle.txt 64 | scalastyle-output.xml 65 | tmp/ 66 | .cache-main 67 | .cache-tests 68 | 69 | # For Hive 70 | metastore_db/ 71 | metastore/ 72 | warehouse/ 73 | TempStatsStore/ 74 | sql/hive-thriftserver/test_warehouses 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Zen 2 | 3 | Zen aims to provide the largest scale and the most efficient machine learning platform on top of Spark, including but not limited to logistic regression, latent dirichilet allocation, factorization machines, and DNN. 4 | 5 | Zen is based on Apache Spark, MLlib and GraphX, but with sophisticated optimizations and newly-added features to optimize and scale up the machine learning training. Zen is developed with the mind that a successful machine learning platform should and must combine both data insight, ml algorithm and system experience together. 
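## Quick example

A minimal sketch (e.g. in `spark-shell`) of training a binary logistic regression with Zen, adapted from the bundled `BinaryClassification` example driver; the data path and hyper-parameters below are illustrative only:

```scala
import com.github.cloudml.zen.ml.regression.LogisticRegression
import org.apache.spark.graphx2.GraphXUtils
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("ZenQuickStart")
GraphXUtils.registerKryoClasses(conf) // optional: Kryo serialization
val sc = new SparkContext(conf)

// Binary-labeled LIBSVM data, keyed by a unique sample id
val dataSet = MLUtils.loadLibSVMFile(sc, "data/mllib/kdda.txt")
  .zipWithUniqueId().map(_.swap).cache()

// trainMIS(data, numIterations, stepSize, l1, epsilon, useAdaGrad)
val model = LogisticRegression.trainMIS(dataSet, 200, 1.0, 1e-2, 1e-4, false)
```

See `examples/src/main/scala/com/github/cloudml/zen/examples/ml/` for runnable drivers covering LDA, factorization machines (FM, MVM, BSFM) and LambdaMART.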
6 | 7 | ## Contributors 8 | 9 | * Bo Zhao ([@bhoppi](https://github.com/bhoppi)) 10 | 11 | * Guoqiang Li ([@witgo](https://github.com/witgo)) 12 | 13 | * Hucheng Zhou ([@hucheng](https://github.com/hucheng)) 14 | 15 | * Sendong Li ([@lisendong](https://github.com/lisendong)) 16 | -------------------------------------------------------------------------------- /dev/change-version-to-2.10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X) 21 | # Copy from spark 22 | BASEDIR=$(dirname $0)/.. 23 | find $BASEDIR -name 'pom.xml' | grep -v target \ 24 | | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.11/\1_2.10/g' {} 25 | 26 | # Also update in parent POM 27 | sed -i -e '0,/2.112.10 in parent POM 27 | sed -i -e '0,/2.102.11 m 44 | } 45 | val mirror = runtimeMirror(getClass.getClassLoader) 46 | val instanceMirror = mirror.reflect(this) 47 | allAccessors.map { f => 48 | val paramName = f.name.toString 49 | val fieldMirror = instanceMirror.reflectField(f) 50 | val paramValue = fieldMirror.get 51 | s" $paramName:\t$paramValue" 52 | }.mkString("{\n", ",\n", "\n}") 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /examples/src/main/scala/com/github/cloudml/zen/examples/ml/BinaryClassification.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.examples.ml 19 | 20 | import com.github.cloudml.zen.ml.regression.LogisticRegression 21 | import org.apache.spark.graphx2.GraphXUtils 22 | import org.apache.spark.mllib.classification.LogisticRegressionModel 23 | import org.apache.spark.mllib.util.MLUtils 24 | import org.apache.spark.{SparkConf, SparkContext} 25 | import scopt.OptionParser 26 | 27 | object BinaryClassification { 28 | 29 | case class Params( 30 | input: String = null, 31 | out: String = null, 32 | numIterations: Int = 200, 33 | stepSize: Double = 1.0, 34 | l1: Double = 1e-2, 35 | epsilon: Double = 1e-4, 36 | useAdaGrad: Boolean = false, 37 | kryo: Boolean = false) extends AbstractParams[Params] 38 | 39 | def main(args: Array[String]) { 40 | val defaultParams = Params() 41 | val parser = new OptionParser[Params]("BinaryClassification") { 42 | head("BinaryClassification: an example app for LogisticRegression.") 43 | opt[Int]("numIterations") 44 | .text(s"number of iterations, default: ${defaultParams.numIterations}") 45 | .action((x, c) => c.copy(numIterations = x)) 46 | opt[Double]("epsilon") 47 | .text(s"epsilon (smoothing constant) for MIS, default: ${defaultParams.epsilon}") 48 | .action((x, c) => c.copy(epsilon = x)) 49 | opt[Unit]("kryo") 50 | .text("use Kryo serialization") 51 | .action((_, c) => c.copy(kryo = true)) 52 | opt[Double]("stepSize") 53 | .text(s"stepSize, default: ${defaultParams.stepSize}") 54 | .action((x, c) => c.copy(stepSize = x)) 55 | opt[Double]("l1") 56 | .text(s"L1 Regularization, default: ${defaultParams.l1} (auto)") 57 | .action((x, c) => c.copy(l1 = x)) 58 | opt[Unit]("adagrad") 59 | .text("use AdaGrad") 60 | .action((_, c) => c.copy(useAdaGrad = true)) 61 | arg[String]("") 62 | .required() 63 | .text("input paths (binary labeled data in the LIBSVM format)") 64 | .action((x, c) => c.copy(input = x)) 65 | arg[String]("") 66 | .required() 67 | .text("out paths (model)") 68 | .action((x, c) => c.copy(out = x)) 69 | note( 70 | """ 71 | |For example, the following command runs this app on a synthetic dataset: 72 | | 73 | | bin/spark-submit --class com.github.cloudml.zen.examples.ml.LogisticRegression \ 74 | | examples/target/scala-*/zen-examples-*.jar \ 75 | | --numIterations 200 --lambda 1.0 --kryo \ 76 | | data/mllib/kdda.txt 77 | | data/mllib/lr_model.txt 78 | """.stripMargin) 79 | } 80 | 81 | parser.parse(args, defaultParams).map { params => 82 | run(params) 83 | } getOrElse { 84 | System.exit(1) 85 | } 86 | } 87 | 88 | def run(params: Params): Unit = { 89 | val Params(input, out, numIterations, stepSize, l1, epsilon, useAdaGrad, useKryo) = params 90 | val conf = new SparkConf().setAppName(s"LogisticRegression with $params") 91 | if (useKryo) { 92 | GraphXUtils.registerKryoClasses(conf) 93 | // conf.set("spark.kryoserializer.buffer.mb", "8") 94 | } 95 | val sc = new SparkContext(conf) 96 | val dataSet = MLUtils.loadLibSVMFile(sc, input).zipWithUniqueId().map(_.swap).cache() 97 | val model = LogisticRegression.trainMIS(dataSet, numIterations, stepSize, l1, epsilon, useAdaGrad) 98 | val lm = new LogisticRegressionModel(model.weights, model.intercept, model.weights.size, 2) 99 | lm.save(sc, out) 100 | sc.stop() 101 | } 102 | 103 | } 104 | -------------------------------------------------------------------------------- /examples/src/main/scala/com/github/cloudml/zen/examples/ml/NetflixPrizeFM.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one 
or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package com.github.cloudml.zen.examples.ml 18 | 19 | import com.github.cloudml.zen.ml.recommendation._ 20 | import com.github.cloudml.zen.ml.util.Logging 21 | import org.apache.spark.graphx2.GraphXUtils 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | import scopt.OptionParser 24 | 25 | object NetflixPrizeFM extends Logging { 26 | 27 | case class Params( 28 | input: String = null, 29 | out: String = null, 30 | numIterations: Int = 200, 31 | numPartitions: Int = -1, 32 | stepSize: Double = 0.05, 33 | regular: String = "0.01,0.01,0.01", 34 | rank: Int = 64, 35 | useAdaGrad: Boolean = false, 36 | kryo: Boolean = false) extends AbstractParams[Params] 37 | 38 | def main(args: Array[String]) { 39 | val defaultParams = Params() 40 | val parser = new OptionParser[Params]("NetflixPrizeFM") { 41 | head("NetflixPrizeFM: an example app for FM.") 42 | opt[Int]("numIterations") 43 | .text(s"number of iterations, default: ${defaultParams.numIterations}") 44 | .action((x, c) => c.copy(numIterations = x)) 45 | opt[Int]("numPartitions") 46 | .text(s"number of partitions, default: ${defaultParams.numPartitions}") 47 | .action((x, c) => c.copy(numPartitions = x)) 48 | opt[Int]("rank") 49 | .text(s"dim of 2,3-way interactions, default: ${defaultParams.rank}") 50 | .action((x, c) => c.copy(rank = x)) 51 | opt[Unit]("kryo") 52 | .text("use Kryo serialization") 53 | .action((_, c) => c.copy(kryo = true)) 54 | opt[Double]("stepSize") 55 | .text(s"stepSize, default: ${defaultParams.stepSize}") 56 | .action((x, c) => c.copy(stepSize = x)) 57 | opt[String]("regular") 58 | .text( 59 | s""" 60 | |'r0,r1,r2' for SGD: r0=bias regularization, 61 | |r1=1-way regularization, r2=2-way and 3-way regularization, default: ${defaultParams.regular} (auto) 62 | """.stripMargin) 63 | .action((x, c) => c.copy(regular = x)) 64 | opt[Unit]("adagrad") 65 | .text("use AdaGrad") 66 | .action((_, c) => c.copy(useAdaGrad = true)) 67 | arg[String]("") 68 | .required() 69 | .text("input paths") 70 | .action((x, c) => c.copy(input = x)) 71 | arg[String]("") 72 | .required() 73 | .text("out paths (model)") 74 | .action((x, c) => c.copy(out = x)) 75 | note( 76 | """ 77 | |For example, the following command runs this app on a synthetic dataset: 78 | | 79 | | bin/spark-submit --class com.github.cloudml.zen.examples.ml.NetflixPrizeFM \ 80 | | examples/target/scala-*/zen-examples-*.jar \ 81 | | --rank 20 --numIterations 200 --regular 0.01,0.01,0.01 --kryo \ 82 | | data/mllib/nf_prize_dataset 83 | | data/mllib/MVM_model 84 | """.stripMargin) 85 | } 86 | 87 | parser.parse(args, defaultParams).map { params => 88 | run(params) 89 | } getOrElse { 90 | System.exit(1) 91 | } 92 | } 93 | 94 | def run(params: Params): Unit = { 95 | val Params(input, out, numIterations, 
numPartitions, stepSize, regular, 96 | rank, useAdaGrad, kryo) = params 97 | val regs = regular.split(",").map(_.toDouble) 98 | val l2 = (regs(0), regs(1), regs(2)) 99 | val checkpointDir = s"$out/checkpoint" 100 | val conf = new SparkConf().setAppName(s"FM with $params") 101 | if (kryo) { 102 | GraphXUtils.registerKryoClasses(conf) 103 | // conf.set("spark.kryoserializer.buffer.mb", "8") 104 | } 105 | val sc = new SparkContext(conf) 106 | sc.setCheckpointDir(checkpointDir) 107 | val (trainSet, testSet, _) = NetflixPrizeUtils.genSamplesWithTime(sc, input, numPartitions) 108 | val model = FM.trainRegression(trainSet, numIterations, stepSize, l2, rank, useAdaGrad, 1.0) 109 | model.save(sc, out) 110 | val rmse = model.loss(testSet) 111 | logInfo(f"Test RMSE: $rmse%1.4f") 112 | println(f"Test RMSE: $rmse%1.4f") 113 | sc.stop() 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /examples/src/main/scala/com/github/cloudml/zen/examples/ml/NetflixPrizeMVM.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | package com.github.cloudml.zen.examples.ml 18 | 19 | import com.github.cloudml.zen.ml.recommendation._ 20 | import com.github.cloudml.zen.ml.util.Logging 21 | import org.apache.spark.graphx2.GraphXUtils 22 | import org.apache.spark.storage.StorageLevel 23 | import org.apache.spark.{SparkConf, SparkContext} 24 | import scopt.OptionParser 25 | 26 | object NetflixPrizeMVM extends Logging { 27 | 28 | case class Params( 29 | input: String = null, 30 | out: String = null, 31 | numIterations: Int = 200, 32 | numPartitions: Int = -1, 33 | stepSize: Double = 0.05, 34 | regular: Double = 0.01, 35 | rank: Int = 64, 36 | useAdaGrad: Boolean = false, 37 | useWeightedLambda: Boolean = false, 38 | kryo: Boolean = false) extends AbstractParams[Params] 39 | 40 | def main(args: Array[String]) { 41 | val defaultParams = Params() 42 | val parser = new OptionParser[Params]("MVM") { 43 | head("NetflixPrizeMVM: an example app for MVM.") 44 | opt[Int]("numIterations") 45 | .text(s"number of iterations, default: ${defaultParams.numIterations}") 46 | .action((x, c) => c.copy(numIterations = x)) 47 | opt[Int]("numPartitions") 48 | .text(s"number of partitions, default: ${defaultParams.numPartitions}") 49 | .action((x, c) => c.copy(numPartitions = x)) 50 | opt[Int]("rank") 51 | .text(s"dim of 2-way interactions, default: ${defaultParams.rank}") 52 | .action((x, c) => c.copy(rank = x)) 53 | opt[Unit]("kryo") 54 | .text("use Kryo serialization") 55 | .action((_, c) => c.copy(kryo = true)) 56 | opt[Double]("stepSize") 57 | .text(s"stepSize, default: ${defaultParams.stepSize}") 58 | .action((x, c) => c.copy(stepSize = x)) 59 | opt[Double]("regular") 60 | .text( 61 | s"L2 regularization, default: ${defaultParams.regular}".stripMargin) 62 | .action((x, c) => c.copy(regular = x)) 63 | opt[Unit]("adagrad") 64 | .text("use AdaGrad") 65 | .action((_, c) => c.copy(useAdaGrad = true)) 66 | opt[Unit]("weightedLambda") 67 | .text("use weighted lambda regularization") 68 | .action((_, c) => c.copy(useWeightedLambda = true)) 69 | arg[String]("") 70 | .required() 71 | .text("input paths") 72 | .action((x, c) => c.copy(input = x)) 73 | arg[String]("") 74 | .required() 75 | .text("out paths (model)") 76 | .action((x, c) => c.copy(out = x)) 77 | note( 78 | """ 79 | |For example, the following command runs this app on a synthetic dataset: 80 | | 81 | | bin/spark-submit --class com.github.cloudml.zen.examples.ml.NetflixPrizeMVM \ 82 | | examples/target/scala-*/zen-examples-*.jar \ 83 | | --rank 20 --numIterations 200 --regular 0.01 --kryo \ 84 | | data/mllib/nf_prize_dataset 85 | | data/mllib/MVM_model 86 | """.stripMargin) 87 | } 88 | 89 | parser.parse(args, defaultParams).map { params => 90 | run(params) 91 | } getOrElse { 92 | System.exit(1) 93 | } 94 | } 95 | 96 | def run(params: Params): Unit = { 97 | val Params(input, out, numIterations, numPartitions, stepSize, regular, 98 | rank, useAdaGrad, useWeightedLambda, kryo) = params 99 | val checkpointDir = s"$out/checkpoint" 100 | val conf = new SparkConf().setAppName(s"MVM with $params") 101 | if (kryo) { 102 | GraphXUtils.registerKryoClasses(conf) 103 | // conf.set("spark.kryoserializer.buffer.mb", "8") 104 | } 105 | val sc = new SparkContext(conf) 106 | sc.setCheckpointDir(checkpointDir) 107 | val (trainSet, testSet, views) = NetflixPrizeUtils.genSamplesWithTime(sc, input, numPartitions) 108 | val fm = new MVMRegression(trainSet, stepSize, views, regular, 0.0, rank, 109 | useAdaGrad, useWeightedLambda, 1.0, StorageLevel.MEMORY_AND_DISK) 110 | fm.run(numIterations) 111 | 
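// Extract the learned model from the trained MVM so it can be saved below and scored (RMSE) against the probe/test set.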
val model = fm.saveModel() 112 | model.save(sc, out) 113 | val rmse = model.loss(testSet) 114 | logInfo(f"Test RMSE: $rmse%1.4f") 115 | println(f"Test RMSE: $rmse%1.4f") 116 | sc.stop() 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /examples/src/main/scala/com/github/cloudml/zen/examples/ml/NetflixPrizeUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.examples.ml 19 | 20 | import java.text.SimpleDateFormat 21 | import java.util.{Locale, TimeZone} 22 | 23 | import breeze.linalg.{SparseVector => BSV} 24 | import org.apache.spark.SparkContext 25 | import org.apache.spark.mllib.linalg.{SparseVector => SSV} 26 | import org.apache.spark.mllib.regression.LabeledPoint 27 | import org.apache.spark.rdd.RDD 28 | import org.apache.spark.storage.StorageLevel 29 | 30 | import scala.collection.mutable.ArrayBuffer 31 | 32 | object NetflixPrizeUtils { 33 | 34 | def genSamplesWithTime( 35 | sc: SparkContext, 36 | input: String, 37 | numPartitions: Int = -1, 38 | newLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): 39 | (RDD[(Long, LabeledPoint)], RDD[(Long, LabeledPoint)], Array[Long]) = { 40 | 41 | val probeFile = s"$input/probe.txt" 42 | val dataSetFile = s"$input/training_set/*" 43 | val probe = sc.wholeTextFiles(probeFile).flatMap { case (fileName, txt) => 44 | val ab = new ArrayBuffer[(Int, Int)] 45 | var lastMovieId = -1 46 | var lastUserId = -1 47 | txt.split("\n").filter(_.nonEmpty).foreach { line => 48 | if (line.endsWith(":")) { 49 | lastMovieId = line.split(":").head.toInt 50 | } else { 51 | lastUserId = line.toInt 52 | val pair = (lastUserId, lastMovieId) 53 | ab += pair 54 | } 55 | } 56 | ab.toSeq 57 | }.collect().toSet 58 | 59 | val simpleDateFormat = new SimpleDateFormat("yyyy-MM-dd", Locale.ROOT) 60 | simpleDateFormat.setTimeZone(TimeZone.getTimeZone("GMT+08:00")) 61 | var nfPrize = sc.wholeTextFiles(dataSetFile, sc.defaultParallelism).flatMap { case (fileName, txt) => 62 | val Array(movieId, csv) = txt.split(":") 63 | csv.split("\n").filter(_.nonEmpty).map { line => 64 | val Array(userId, rating, timestamp) = line.split(",") 65 | val day = simpleDateFormat.parse(timestamp).getTime / (1000L * 60 * 60 * 24) 66 | ((userId.toInt, movieId.toInt), rating.toDouble, day.toInt) 67 | } 68 | } 69 | nfPrize = if (numPartitions > 0) { 70 | nfPrize.repartition(numPartitions) 71 | } else { 72 | nfPrize.repartition(sc.defaultParallelism) 73 | } 74 | nfPrize.persist(newLevel).count() 75 | 76 | val maxUserId = nfPrize.map(_._1._1).max + 1 77 | val maxMovieId = nfPrize.map(_._1._2).max + 1 78 | val maxTime = nfPrize.map(_._3).max() 
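    // The sparse feature vectors built below concatenate three one-hot blocks:
    //   [0, maxUserId)                        -> user id
    //   [maxUserId, maxUserId + maxMovieId)   -> movie id
    //   [maxUserId + maxMovieId, numFeatures) -> day index (timestamp - minTime)
    // so numFeatures = maxUserId + maxMovieId + maxDay, matching the `views` boundaries returned at the end.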
79 | val minTime = nfPrize.map(_._3).min() 80 | val maxDay = maxTime - minTime + 1 81 | val numFeatures = maxUserId + maxMovieId + maxDay 82 | 83 | val testSet = nfPrize.mapPartitions { iter => 84 | iter.filter(t => probe.contains(t._1)).map { 85 | case ((userId, movieId), rating, timestamp) => 86 | val sv = BSV.zeros[Double](numFeatures) 87 | sv(userId) = 1.0 88 | sv(movieId + maxUserId) = 1.0 89 | sv(timestamp - minTime + maxUserId + maxMovieId) = 1.0 90 | new LabeledPoint(rating, new SSV(sv.length, sv.index.slice(0, sv.used), sv.data.slice(0, sv.used))) 91 | } 92 | }.zipWithIndex().map(_.swap).persist(newLevel) 93 | testSet.count() 94 | 95 | val trainSet = nfPrize.mapPartitions { iter => 96 | iter.filter(t => !probe.contains(t._1)).map { 97 | case ((userId, movieId), rating, timestamp) => 98 | val sv = BSV.zeros[Double](numFeatures) 99 | sv(userId) = 1.0 100 | sv(movieId + maxUserId) = 1.0 101 | sv(timestamp - minTime + maxUserId + maxMovieId) = 1.0 102 | new LabeledPoint(rating, new SSV(sv.length, sv.index.slice(0, sv.used), sv.data.slice(0, sv.used))) 103 | } 104 | }.zipWithIndex().map(_.swap).persist(newLevel) 105 | trainSet.count() 106 | nfPrize.unpersist() 107 | /** 108 | * The first view contains [0,maxUserId),The second view contains [maxUserId, maxMovieId + maxUserId)... 109 | * The third contains [maxMovieId + maxUserId, numFeatures) The last id equals the number of features 110 | */ 111 | val views = Array(maxUserId, maxMovieId + maxUserId, numFeatures).map(_.toLong) 112 | 113 | (trainSet, testSet, views) 114 | 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /ml/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 18 | 20 | 4.0.0 21 | 22 | com.github.cloudml.zen 23 | zen-parent_2.11 24 | 0.4-SNAPSHOT 25 | ../pom.xml 26 | 27 | zen-ml_2.11 28 | Zen Project ML Library 29 | https://github.com/cloudml/zen/ 30 | 31 | UTF-8 32 | 33 | 34 | 35 | org.apache.spark 36 | spark-mllib_${scala.binary.version} 37 | 38 | 39 | org.scalanlp 40 | breeze_${scala.binary.version} 41 | 42 | 43 | com.github.fommil.netlib 44 | native_system-java 45 | 1.1 46 | 47 | 48 | com.github.fommil.netlib 49 | native_ref-java 50 | 1.1 51 | 52 | 53 | me.lemire.integercompression 54 | JavaFastPFOR 55 | 0.1.6 56 | 57 | 58 | org.apache.commons 59 | commons-math3 60 | 61 | 62 | 63 | junit 64 | junit 65 | test 66 | 67 | 68 | com.novocode 69 | junit-interface 70 | test 71 | 72 | 73 | org.scalatest 74 | scalatest_${scala.binary.version} 75 | test 76 | 77 | 78 | 79 | target/scala-${scala.binary.version}/classes 80 | target/scala-${scala.binary.version}/test-classes 81 | 82 | 83 | org.scalatest 84 | scalatest-maven-plugin 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/clustering/LDADefines.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.clustering 19 | 20 | import java.util.Random 21 | 22 | import breeze.collection.mutable.SparseArray 23 | import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} 24 | import com.github.cloudml.zen.ml.sampler._ 25 | import com.github.cloudml.zen.ml.util.{BVCompressor, BVDecompressor, CompressedVector} 26 | import org.apache.spark.SparkConf 27 | import org.apache.spark.graphx2._ 28 | import org.apache.spark.rdd.RDD 29 | 30 | 31 | object LDADefines { 32 | type DocId = VertexId 33 | type WordId = VertexId 34 | type Count = Int 35 | type TC = CompressedVector 36 | type TA = Int 37 | type BOW = (Long, BSV[Count]) 38 | type Nwk = BV[Count] 39 | type Ndk = BSV[Count] 40 | type Nvk = BV[Count] 41 | type NwkPair = (VertexId, Nwk) 42 | type NvkPair = (VertexId, Nvk) 43 | 44 | val sv_formatVersionV2_0 = "2.0" 45 | val sv_classNameV2_0 = "com.github.cloudml.zen.ml.clustering.DistributedLDAModel" 46 | val cs_numTopics = "zen.lda.numTopics" 47 | val cs_numPartitions = "zen.lda.numPartitions" 48 | val cs_sampleRate = "zen.lda.sampleRate" 49 | val cs_LDAAlgorithm = "zen.lda.LDAAlgorithm" 50 | val cs_storageLevel = "zen.lda.storageLevel" 51 | val cs_partStrategy = "zen.lda.partStrategy" 52 | val cs_initStrategy = "zen.lda.initStrategy" 53 | val cs_chkptInterval = "zen.lda.chkptInterval" 54 | val cs_evalMetric = "zen.lda.evalMetric" 55 | val cs_saveInterval = "zen.lda.saveInterval" 56 | val cs_inputPath = "zen.lda.inputPath" 57 | val cs_outputpath = "zen.lda.outputPath" 58 | val cs_saveAsSolid = "zen.lda.saveAsSolid" 59 | val cs_numThreads = "zen.lda.numThreads" 60 | val cs_ignoreDocId = "zen.lda.ignoreDocId" 61 | val cs_saveTransposed = "zen.lda.saveTransposed" 62 | 63 | // make docId always be negative, so that the doc vertex always be the dest vertex 64 | @inline def genNewDocId(docId: Long): VertexId = { 65 | assert(docId >= 0) 66 | -(docId + 1L) 67 | } 68 | 69 | @inline def isDocId(vid: VertexId): Boolean = vid < 0L 70 | 71 | @inline def isTermId(vid: VertexId): Boolean = vid >= 0L 72 | 73 | def uniformDistSampler(gen: Random, 74 | tokens: Array[Int], 75 | topics: Array[Int], 76 | numTopics: Int): BSV[Count] = { 77 | val docTopics = BSV.zeros[Count](numTopics) 78 | var i = 0 79 | while (i < tokens.length) { 80 | val topic = gen.nextInt(numTopics) 81 | topics(i) = topic 82 | docTopics(topic) += 1 83 | i += 1 84 | } 85 | docTopics 86 | } 87 | 88 | def compressCounterRDD(model: RDD[NvkPair], numTopics: Int): RDD[(VertexId, TC)] = { 89 | model.mapPartitions(iter => { 90 | val comp = new BVCompressor(numTopics) 91 | iter.map(Function.tupled((vid, counter) => 92 | (vid, comp.BV2CV(counter)) 93 | )) 94 | }, preservesPartitioning = true) 95 | } 96 | 97 | def decompressVertexRDD(verts: RDD[(VertexId, TC)], numTopics: Int): RDD[NvkPair] = { 98 | verts.mapPartitions(iter => { 99 | val decomp = new BVDecompressor(numTopics) 100 | iter.map(Function.tupled((vid, cv) => 101 | (vid, decomp.CV2BV(cv)) 102 | )) 103 | }, preservesPartitioning = true) 104 | } 105 | 106 | def registerKryoClasses(conf: SparkConf): Unit 
= { 107 | conf.registerKryoClasses(Array( 108 | classOf[TC], 109 | classOf[BOW], 110 | classOf[NwkPair], 111 | classOf[AliasTable[Object]], classOf[FTree[Object]], // for some partitioners 112 | classOf[BSV[Object]], classOf[BDV[Object]], 113 | classOf[SparseArray[Object]], // member of BSV 114 | classOf[Array[Int]] 115 | )) 116 | } 117 | 118 | def toBDV(bv: BV[Count]): BDV[Count] = bv match { 119 | case v: BDV[Count] => v 120 | case v: BSV[Count] => 121 | val arr = new Array[Count](bv.length) 122 | val used = v.used 123 | val index = v.index 124 | val data = v.data 125 | var i = 0 126 | while (i < used) { 127 | arr(index(i)) = data(i) 128 | i += 1 129 | } 130 | new BDV(arr) 131 | } 132 | 133 | def toBSV(bv: BV[Count], used: Int): BSV[Count] = bv match { 134 | case v: BSV[Count] => v 135 | case v: BDV[Count] => 136 | val index = new Array[Int](used) 137 | val data = new Array[Count](used) 138 | val arr = v.data 139 | var i = 0 140 | var j = 0 141 | while (i < used) { 142 | val cnt = arr(j) 143 | if (cnt > 0) { 144 | index(i) = j 145 | data(i) = cnt 146 | i += 1 147 | } 148 | j += 1 149 | } 150 | new BSV(index, data, used, bv.length) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/clustering/LDAMetrics.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.clustering 19 | 20 | trait LDAMetrics { 21 | def getTotal: Double 22 | def getWord: Double 23 | def getDoc: Double 24 | def output(writer: String => Unit): Unit 25 | } 26 | 27 | class LDAPerplexity(val pplx: Double, val wpplx: Double, val dpplx: Double) extends LDAMetrics { 28 | override def getTotal: Double = pplx 29 | 30 | override def getWord: Double = wpplx 31 | 32 | override def getDoc: Double = dpplx 33 | 34 | override def output(writer: String => Unit): Unit = { 35 | val o = s"perplexity=$getTotal, word pplx=$getWord, doc pplx=$getDoc" 36 | writer(o) 37 | } 38 | } 39 | 40 | class LDALogLikelihood(val wllh: Double, val dllh: Double) extends LDAMetrics { 41 | override def getTotal: Double = wllh + dllh 42 | 43 | override def getWord: Double = wllh 44 | 45 | override def getDoc: Double = dllh 46 | 47 | override def output(writer: String => Unit): Unit = { 48 | val o = s"total llh=$getTotal, word llh=$getWord, doc llh=$getDoc" 49 | writer(o) 50 | } 51 | } 52 | 53 | object LDAMetrics { 54 | def apply(evalMetric: String, lda: LDA): LDAMetrics = { 55 | val verts = lda.verts 56 | val topicCounters = lda.topicCounters 57 | val numTokens = lda.numTokens 58 | val numTerms = lda.numTerms 59 | val alpha = lda.alpha 60 | val alphaAS = lda.alphaAS 61 | val beta = lda.beta 62 | evalMetric match { 63 | case "pplx" => 64 | lda.algo.calcPerplexity(lda.edges, verts, topicCounters, numTokens, numTerms, alpha, alphaAS, beta) 65 | case "llh" => 66 | lda.algo.calcLogLikelihood(verts, topicCounters, numTokens, lda.numDocs, numTerms, alpha, alphaAS, beta) 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/clustering/README.md: -------------------------------------------------------------------------------- 1 | # LDA Highlights: 2 | * Industry strength scalability and efficiency 3 | 4 | * Support billions documents, 100s millions words, with tens of millions of topics 5 | 6 | * Support asymmetric Dirichlet prior over the document topic distributions ("Rethinking LDA: Why Priors Matters") 7 | 8 | * Support multiple gibbs sampling algorithms, SparseLDA, AliasLDA and LightLDA, while with a little bit difference in sampling formula decomposition 9 | 10 | * Support multiple graph partitioning strategies beyond that provided by GraphX, such as Degree-Based Hashing, Hybrid-Cut, Greedy strategy 11 | 12 | * Support duplicate topic merge 13 | 14 | ## Road map 15 | 16 | * Propose new proposal distribution in Hasting sampling that gets the sweat point between system performance and model convergence 17 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/AdaDeltaUpdater.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.linalg.BLAS 21 | import com.github.cloudml.zen.ml.util.SparkUtils 22 | import org.apache.spark.annotation.Experimental 23 | import org.apache.spark.mllib.linalg.{Vector => SV, DenseVector => SDV} 24 | import com.github.cloudml.zen.ml.optimization._ 25 | 26 | @Experimental 27 | private[ml] class AdaDeltaUpdater( 28 | val rho: Double, 29 | val epsilon: Double, 30 | val momentum: Double) extends Updater { 31 | require(rho > 0 && rho < 1) 32 | require(momentum >= 0 && momentum < 1) 33 | @transient private var gradientSum: SDV = null 34 | @transient private var deltaSum: SDV = null 35 | @transient private var momentumSum: SDV = null 36 | 37 | protected def l2( 38 | weightsOld: SV, 39 | gradient: SV, 40 | stepSize: Double, 41 | iter: Int, 42 | regParam: Double): Double = { 43 | 0D 44 | } 45 | 46 | override def compute( 47 | weightsOld: SV, 48 | gradient: SV, 49 | stepSize: Double, 50 | iter: Int, 51 | regParam: Double): (SV, Double) = { 52 | if (momentum > 0 && momentumSum == null) { 53 | momentumSum = new SDV(new Array[Double](weightsOld.size)) 54 | } 55 | if (deltaSum == null) { 56 | deltaSum = new SDV(new Array[Double](weightsOld.size)) 57 | gradientSum = new SDV(new Array[Double](weightsOld.size)) 58 | } 59 | 60 | val reg = l2(weightsOld, gradient, stepSize, iter, regParam) 61 | if (momentum > 0) { 62 | BLAS.axpy(momentum, momentumSum, gradient) 63 | this.synchronized { 64 | BLAS.copy(gradient, momentumSum) 65 | } 66 | } 67 | 68 | val grad = SparkUtils.toBreeze(gradient) 69 | val g2 = grad :* grad 70 | this.synchronized { 71 | BLAS.scal(rho, gradientSum) 72 | BLAS.axpy(1 - rho, SparkUtils.fromBreeze(g2), gradientSum) 73 | } 74 | 75 | for (i <- 0 until grad.length) { 76 | val rmsDelta = math.sqrt(epsilon + deltaSum(i)) 77 | val rmsGrad = math.sqrt(epsilon + gradientSum(i)) 78 | grad(i) *= rmsDelta / rmsGrad 79 | } 80 | 81 | val d2 = grad :* grad 82 | this.synchronized { 83 | BLAS.scal(rho, deltaSum) 84 | BLAS.axpy(1 - rho, SparkUtils.fromBreeze(d2), deltaSum) 85 | } 86 | 87 | BLAS.axpy(-stepSize, gradient, weightsOld) 88 | (weightsOld, reg) 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/AdaGradUpdater.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.linalg.BLAS 21 | import com.github.cloudml.zen.ml.util.SparkUtils 22 | import org.apache.spark.annotation.Experimental 23 | import org.apache.spark.mllib.linalg.{Vector => SV, DenseVector => SDV} 24 | import com.github.cloudml.zen.ml.optimization._ 25 | 26 | @Experimental 27 | class AdaGradUpdater( 28 | val rho: Double, 29 | val epsilon: Double, 30 | val gamma: Double, 31 | val momentum: Double) extends Updater { 32 | require(rho >= 0 && rho < 1) 33 | require(momentum >= 0 && momentum < 1) 34 | @transient private var etaSum: SDV = null 35 | @transient private var momentumSum: SDV = null 36 | 37 | protected def l2( 38 | weightsOld: SV, 39 | gradient: SV, 40 | stepSize: Double, 41 | iter: Int, 42 | regParam: Double): Double = { 43 | 0D 44 | } 45 | 46 | override def compute( 47 | weightsOld: SV, 48 | gradient: SV, 49 | stepSize: Double, 50 | iter: Int, 51 | regParam: Double): (SV, Double) = { 52 | if (momentum > 0 && momentumSum == null) { 53 | momentumSum = new SDV(new Array[Double](weightsOld.size)) 54 | } 55 | if (etaSum == null) { 56 | etaSum = new SDV(new Array[Double](weightsOld.size)) 57 | } 58 | val reg = l2(weightsOld, gradient, stepSize, iter, regParam) 59 | if (momentum > 0) { 60 | BLAS.axpy(momentum, momentumSum, gradient) 61 | this.synchronized { 62 | BLAS.copy(gradient, momentumSum) 63 | } 64 | } 65 | 66 | val grad = SparkUtils.toBreeze(gradient) 67 | val g2 = grad :* grad 68 | this.synchronized { 69 | if (rho > 0D && rho < 1D) { 70 | BLAS.scal(rho, etaSum) 71 | } 72 | BLAS.axpy(1D, SparkUtils.fromBreeze(g2), etaSum) 73 | } 74 | 75 | for (i <- 0 until grad.length) { 76 | grad(i) *= gamma / (epsilon + math.sqrt(etaSum(i))) 77 | } 78 | BLAS.axpy(-stepSize, SparkUtils.fromBreeze(grad), weightsOld) 79 | (weightsOld, reg) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/DBN.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.util.Logging 21 | import org.apache.spark.annotation.Experimental 22 | import org.apache.spark.mllib.linalg.{Vector => SV} 23 | import org.apache.spark.rdd.RDD 24 | 25 | @Experimental 26 | class DBN(val stackedRBM: StackedRBM) 27 | extends Logging with Serializable { 28 | lazy val mlp: MLPModel = { 29 | val nn = stackedRBM.toMLP() 30 | val lastLayer = nn.innerLayers(nn.numLayer - 1) 31 | NNUtil.initUniformDistWeight(lastLayer.weight, 0.01) 32 | nn.innerLayers(nn.numLayer - 1) = new SoftMaxLayer(lastLayer.weight, lastLayer.bias) 33 | nn 34 | } 35 | 36 | def this(topology: Array[Int]) { 37 | this(new StackedRBM(topology)) 38 | } 39 | } 40 | 41 | @Experimental 42 | object DBN extends Logging { 43 | def train( 44 | data: RDD[(SV, SV)], 45 | batchSize: Int, 46 | numIteration: Int, 47 | topology: Array[Int], 48 | fraction: Double, 49 | learningRate: Double, 50 | weightCost: Double): DBN = { 51 | val dbn = new DBN(topology) 52 | pretrain(data, batchSize, numIteration, dbn, fraction, learningRate, weightCost) 53 | finetune(data, batchSize, numIteration, dbn, fraction, learningRate, weightCost) 54 | dbn 55 | } 56 | 57 | def pretrain( 58 | data: RDD[(SV, SV)], 59 | batchSize: Int, 60 | numIteration: Int, 61 | dbn: DBN, 62 | fraction: Double, 63 | learningRate: Double, 64 | weightCost: Double): DBN = { 65 | val stackedRBM = dbn.stackedRBM 66 | val numLayer = stackedRBM.innerRBMs.length 67 | StackedRBM.train(data.map(_._1), batchSize, numIteration, stackedRBM, 68 | fraction, learningRate, weightCost, numLayer - 1) 69 | dbn 70 | } 71 | 72 | def finetune(data: RDD[(SV, SV)], 73 | batchSize: Int, 74 | numIteration: Int, 75 | dbn: DBN, 76 | fraction: Double, 77 | learningRate: Double, 78 | weightCost: Double): DBN = { 79 | MLP.train(data, batchSize, numIteration, dbn.mlp, 80 | fraction, learningRate, weightCost) 81 | dbn 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/EquilibratedUpdater.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.linalg.BLAS 21 | import com.github.cloudml.zen.ml.util.SparkUtils._ 22 | import com.github.cloudml.zen.ml.util.Utils 23 | import org.apache.spark.annotation.Experimental 24 | import org.apache.spark.mllib.linalg.{Vector => SV, DenseVector => SDV, Vectors} 25 | import com.github.cloudml.zen.ml.optimization._ 26 | 27 | /** 28 | * Equilibrated Gradient Descent the paper: 29 | * RMSProp and equilibrated adaptive learning rates for non-convex optimization 30 | * @param epsilon 31 | * @param momentum 32 | */ 33 | @Experimental 34 | class EquilibratedUpdater( 35 | val epsilon: Double, 36 | val gamma: Double, 37 | val momentum: Double) extends Updater { 38 | require(momentum >= 0 && momentum < 1) 39 | @transient private var etaSum: SDV = null 40 | @transient private var momentumSum: SDV = null 41 | 42 | protected def l2( 43 | weightsOld: SV, 44 | gradient: SV, 45 | stepSize: Double, 46 | iter: Int, 47 | regParam: Double): Double = { 48 | 0D 49 | } 50 | 51 | override def compute( 52 | weightsOld: SV, 53 | gradient: SV, 54 | stepSize: Double, 55 | iter: Int, 56 | regParam: Double): (SV, Double) = { 57 | if (etaSum == null) etaSum = new SDV(new Array[Double](weightsOld.size)) 58 | val reg = l2(weightsOld, gradient, stepSize, iter, regParam) 59 | 60 | val grad = toBreeze(gradient) 61 | val e = toBreeze(etaSum) 62 | for (i <- 0 until grad.length) { 63 | e(i) += math.pow(grad(i) * Utils.random.nextGaussian(), 2) 64 | } 65 | 66 | etaSum.synchronized { 67 | for (i <- 0 until grad.length) { 68 | grad(i) = gamma * grad(i) / (epsilon + math.sqrt(etaSum(i) / iter)) 69 | } 70 | } 71 | 72 | if (momentum > 0) { 73 | if (momentumSum == null) momentumSum = new SDV(new Array[Double](weightsOld.size)) 74 | momentumSum.synchronized { 75 | BLAS.axpy(momentum, momentumSum, gradient) 76 | BLAS.copy(gradient, momentumSum) 77 | } 78 | } 79 | 80 | BLAS.axpy(-stepSize, gradient, weightsOld) 81 | (weightsOld, reg) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/MomentumUpdater.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.linalg.BLAS 21 | import com.github.cloudml.zen.ml.util.SparkUtils 22 | import org.apache.spark.annotation.Experimental 23 | 24 | import org.apache.spark.mllib.linalg.{Vector => SV, DenseVector => SDV} 25 | import com.github.cloudml.zen.ml.optimization._ 26 | 27 | @Experimental 28 | class MomentumUpdater(val momentum: Double) extends Updater { 29 | 30 | assert(momentum > 0 && momentum < 1) 31 | 32 | @transient private var momentumSum: SDV = null 33 | 34 | protected def l2( 35 | weightsOld: SV, 36 | gradient: SV, 37 | stepSize: Double, 38 | iter: Int, 39 | regParam: Double): Double = { 40 | 0D 41 | } 42 | 43 | override def compute( 44 | weightsOld: SV, 45 | gradient: SV, 46 | stepSize: Double, 47 | iter: Int, 48 | regParam: Double): (SV, Double) = { 49 | if (momentumSum == null) { 50 | momentumSum = new SDV(new Array[Double](weightsOld.size)) 51 | } 52 | val reg = l2(weightsOld, gradient, stepSize, iter, regParam) 53 | if (momentum > 0) { 54 | BLAS.axpy(momentum, momentumSum, gradient) 55 | this.synchronized { 56 | BLAS.copy(gradient, momentumSum) 57 | } 58 | } 59 | BLAS.axpy(-stepSize, gradient, weightsOld) 60 | (weightsOld, reg) 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/README.md: -------------------------------------------------------------------------------- 1 | # ANN 2 | 3 | ## Road map 4 | 5 | * Support Parameter Server(Asynchronous Parallel) 6 | 7 | * Support CNN 8 | 9 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/neuralNetwork/StackedRBM.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import breeze.linalg.{DenseMatrix => BDM} 21 | import com.github.cloudml.zen.ml.util.Logging 22 | import com.github.cloudml.zen.ml.util.SparkUtils._ 23 | import org.apache.spark.annotation.Experimental 24 | import org.apache.spark.broadcast.Broadcast 25 | import org.apache.spark.mllib.linalg.{Vector => SV} 26 | import org.apache.spark.rdd.RDD 27 | 28 | @Experimental 29 | class StackedRBM(val innerRBMs: Array[RBMModel]) 30 | extends Logging with Serializable { 31 | def this(topology: Array[Int]) { 32 | this(StackedRBM.initializeRBMs(topology)) 33 | } 34 | 35 | def numLayer: Int = innerRBMs.length 36 | 37 | def numInput: Int = innerRBMs.head.numIn 38 | 39 | def numOut: Int = innerRBMs.last.numOut 40 | 41 | def forward(visible: BDM[Double], toLayer: Int): BDM[Double] = { 42 | var x = visible 43 | for (layer <- 0 until toLayer) { 44 | x = innerRBMs(layer).forward(x) 45 | } 46 | x 47 | } 48 | 49 | def forward(visible: BDM[Double]): BDM[Double] = { 50 | forward(visible, numLayer) 51 | } 52 | 53 | def topology: Array[Int] = { 54 | val topology = new Array[Int](numLayer + 1) 55 | topology(0) = numInput 56 | for (i <- 1 to numLayer) { 57 | topology(i) = innerRBMs(i - 1).numOut 58 | } 59 | topology 60 | } 61 | 62 | def toMLP(): MLPModel = { 63 | val layers = new Array[Layer](numLayer) 64 | for (layer <- 0 until numLayer) { 65 | layers(layer) = innerRBMs(layer).hiddenLayer 66 | } 67 | new MLPModel(layers, innerRBMs.map(_.dropoutRate)) 68 | } 69 | } 70 | 71 | object StackedRBM extends Logging { 72 | def train( 73 | data: RDD[SV], 74 | batchSize: Int, 75 | numIteration: Int, 76 | topology: Array[Int], 77 | fraction: Double, 78 | learningRate: Double, 79 | weightCost: Double): StackedRBM = { 80 | train(data, batchSize, numIteration, new StackedRBM(topology), fraction, learningRate, weightCost) 81 | } 82 | 83 | def train( 84 | data: RDD[SV], 85 | batchSize: Int, 86 | numIteration: Int, 87 | stackedRBM: StackedRBM, 88 | fraction: Double, 89 | learningRate: Double, 90 | weightCost: Double, 91 | maxLayer: Int = -1): StackedRBM = { 92 | val trainLayer = if (maxLayer > -1D) { 93 | maxLayer 94 | } else { 95 | stackedRBM.numLayer 96 | } 97 | 98 | for (layer <- 0 until trainLayer) { 99 | logInfo(s"Train ($layer/$trainLayer)") 100 | val broadcast = data.context.broadcast(stackedRBM) 101 | val dataBatch = forward(data, broadcast, layer) 102 | val rbm = stackedRBM.innerRBMs(layer) 103 | RBM.train(dataBatch, batchSize, numIteration, rbm, 104 | fraction, learningRate, weightCost) 105 | // broadcast.destroy(blocking = false) 106 | } 107 | stackedRBM 108 | } 109 | 110 | private def forward( 111 | data: RDD[SV], 112 | broadcast: Broadcast[StackedRBM], 113 | toLayer: Int): RDD[SV] = { 114 | if (toLayer > 0) { 115 | data.mapPartitions { itr => 116 | val stackedRBM = broadcast.value 117 | itr.map { data => 118 | val input = new BDM(data.size, 1, data.toArray) 119 | val x = stackedRBM.forward(input, toLayer) 120 | fromBreeze(x(::, 0)) 121 | } 122 | } 123 | } else { 124 | data 125 | } 126 | } 127 | 128 | def initializeRBMs(topology: Array[Int]): Array[RBMModel] = { 129 | val numLayer = topology.length - 1 130 | val innerRBMs = new Array[RBMModel](numLayer) 131 | for (layer <- 0 until numLayer) { 132 | val dropout = if (layer == 0) { 133 | 0.2 134 | } else if (layer < numLayer - 1) { 135 | 0.5 136 | } else { 137 | 0.0 138 | } 139 | innerRBMs(layer) = new RBMModel(topology(layer), topology(layer + 1), dropout) 140 | println(s"innerRBMs($layer) = 
${innerRBMs(layer).numIn} * ${innerRBMs(layer).numOut}") 141 | } 142 | innerRBMs 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/optimization/Gradient.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.optimization 19 | 20 | import org.apache.spark.annotation.DeveloperApi 21 | import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} 22 | 23 | /** 24 | * :: DeveloperApi :: 25 | * Class used to compute the gradient for a loss function, given a single data point. 26 | */ 27 | @DeveloperApi 28 | abstract class Gradient extends Serializable { 29 | /** 30 | * Compute the gradient and loss given the features of a single data point. 31 | * 32 | * @param data features for one data point 33 | * @param label label for this data point 34 | * @param weights weights/coefficients corresponding to features 35 | * 36 | * @return (gradient: Vector, loss: Double) 37 | */ 38 | def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { 39 | val gradient = Vectors.zeros(weights.size) 40 | val loss = compute(data, label, weights, gradient) 41 | (gradient, loss) 42 | } 43 | 44 | /** 45 | * Compute the gradient and loss given the features of a single data point, 46 | * add the gradient to a provided vector to avoid creating new objects, and return loss. 47 | * 48 | * @param data features for one data point 49 | * @param label label for this data point 50 | * @param weights weights/coefficients corresponding to features 51 | * @param cumGradient the computed gradient will be added to this vector 52 | * 53 | * @return loss 54 | */ 55 | def compute(data: Vector, label: Double, weights: Vector, cumGradient: Vector): Double 56 | 57 | /** 58 | * Compute the gradient and loss given the iterator. 
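* The default implementation below simply folds the single-point compute over the iterator, adding each gradient into cumGradient and summing the per-point losses.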
59 | * 60 | * @param iter Iterator for (label, data) pair 61 | * @param weights weights/coefficients corresponding to features 62 | * @param cumGradient the computed gradient will be added to this vector 63 | * 64 | * @return (count: Long, loss: Double) 65 | */ 66 | def compute( 67 | iter: Iterator[(Double, Vector)], 68 | weights: Vector, 69 | cumGradient: Vector): (Long, Double) = { 70 | var loss = 0D 71 | var count = 0L 72 | iter.foreach { t => 73 | loss += compute(t._2, t._1, weights, cumGradient) 74 | count += 1 75 | } 76 | (count, loss) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/optimization/Optimizer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.optimization 19 | 20 | 21 | import org.apache.spark.rdd.RDD 22 | 23 | import org.apache.spark.annotation.DeveloperApi 24 | import org.apache.spark.mllib.linalg.Vector 25 | 26 | /** 27 | * :: DeveloperApi :: 28 | * Trait for optimization problem solvers. 29 | */ 30 | @DeveloperApi 31 | trait Optimizer extends Serializable { 32 | 33 | /** 34 | * Solve the provided convex optimization problem. 35 | */ 36 | def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector 37 | } 38 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/optimization/Updater.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.optimization 19 | 20 | 21 | import scala.math._ 22 | 23 | import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} 24 | 25 | import org.apache.spark.annotation.DeveloperApi 26 | import org.apache.spark.mllib.linalg.{Vectors, Vector} 27 | 28 | /** 29 | * :: DeveloperApi :: 30 | * Class used to perform steps (weight update) using Gradient Descent methods. 31 | * 32 | * For general minimization problems, or for regularized problems of the form 33 | * min L(w) + regParam * R(w), 34 | * the compute function performs the actual update step, when given some 35 | * (e.g. stochastic) gradient direction for the loss L(w), 36 | * and a desired step-size (learning rate). 37 | * 38 | * The updater is responsible to also perform the update coming from the 39 | * regularization term R(w) (if any regularization is used). 40 | */ 41 | @DeveloperApi 42 | abstract class Updater extends Serializable { 43 | /** 44 | * Compute an updated value for weights given the gradient, stepSize, iteration number and 45 | * regularization parameter. Also returns the regularization value regParam * R(w) 46 | * computed using the *updated* weights. 47 | * 48 | * @param weightsOld - Column matrix of size dx1 where d is the number of features. 49 | * @param gradient - Column matrix of size dx1 where d is the number of features. 50 | * @param stepSize - step size across iterations 51 | * @param iter - Iteration number 52 | * @param regParam - Regularization parameter 53 | * 54 | * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, 55 | * and the second element is the regularization value computed using updated weights. 56 | */ 57 | def compute( 58 | weightsOld: Vector, 59 | gradient: Vector, 60 | stepSize: Double, 61 | iter: Int, 62 | regParam: Double): (Vector, Double) 63 | } 64 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/partitioner/BBRPartitioner.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.partitioner 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import com.github.cloudml.zen.ml.clustering.LDADefines._ 23 | import com.github.cloudml.zen.ml.sampler.AliasTable 24 | import com.github.cloudml.zen.ml.util.XORShiftRandom 25 | import breeze.linalg.{SparseVector => BSV} 26 | import org.apache.spark.Partitioner 27 | import org.apache.spark.graphx2._ 28 | import org.apache.spark.graphx2.impl.GraphImpl 29 | import org.apache.spark.storage.StorageLevel 30 | 31 | 32 | private[ml] class BBRPartitioner(val partitions: Int) extends Partitioner { 33 | 34 | override def numPartitions: Int = partitions 35 | 36 | def getKey(et: EdgeTriplet[Int, _]): VertexId = { 37 | if (et.srcAttr >= et.dstAttr) et.srcId else et.dstId 38 | } 39 | 40 | def getPartition(key: Any): PartitionID = { 41 | key.asInstanceOf[PartitionID] % numPartitions 42 | } 43 | 44 | override def equals(other: Any): Boolean = other match { 45 | case bbr: BBRPartitioner => 46 | bbr.numPartitions == numPartitions 47 | case _ => 48 | false 49 | } 50 | 51 | override def hashCode: Int = numPartitions 52 | } 53 | 54 | /** 55 | * Bounded & Balanced Rearranger Partitioner 56 | */ 57 | object BBRPartitioner { 58 | private[zen] def partitionByBBR[VD: ClassTag, ED: ClassTag]( 59 | input: Graph[VD, ED], 60 | storageLevel: StorageLevel): Graph[VD, ED] = { 61 | val edges = input.edges 62 | val conf = edges.context.getConf 63 | val numPartitions = conf.getInt(cs_numPartitions, edges.partitions.length) 64 | val bbr = new BBRPartitioner(numPartitions) 65 | val degGraph = GraphImpl(input.degrees, edges) 66 | val assnGraph = degGraph.mapTriplets((pid, iter) => 67 | iter.map(et => (bbr.getKey(et), Edge(et.srcId, et.dstId, et.attr))), TripletFields.All) 68 | assnGraph.persist(storageLevel) 69 | 70 | val assnVerts = assnGraph.aggregateMessages[Long](ect => { 71 | if (ect.attr._1 == ect.srcId) { 72 | ect.sendToSrc(1L) 73 | } else { 74 | ect.sendToDst(1L) 75 | } 76 | }, _ + _, TripletFields.EdgeOnly) 77 | val (kids, koccurs) = assnVerts.filter(_._2 > 0L).collect().unzip 78 | val partRdd = edges.context.parallelize(kids.zip(rearrage(koccurs, numPartitions))) 79 | val rearrGraph = assnGraph.mapVertices((_, _) => null.asInstanceOf[AliasTable[Long]]) 80 | .joinVertices(partRdd)((_, _, arr) => AliasTable.generateAlias(arr)) 81 | 82 | val newEdges = rearrGraph.triplets.mapPartitions(iter => { 83 | val gen = new XORShiftRandom() 84 | iter.map(et => { 85 | val (kid, edge) = et.attr 86 | val table = if (kid == et.srcId) et.srcAttr else et.dstAttr 87 | (table.sampleRandom(gen), edge) 88 | }) 89 | }).partitionBy(bbr).map(_._2) 90 | GraphImpl(input.vertices, newEdges, null.asInstanceOf[VD], storageLevel, storageLevel) 91 | } 92 | 93 | private def rearrage(koccurs: IndexedSeq[Long], numPartitions: Int): IndexedSeq[BSV[Long]] = { 94 | val numKeys = koccurs.length 95 | val numEdges = koccurs.sum 96 | val npp = numEdges / numPartitions 97 | val rpn = numEdges - npp * numPartitions 98 | @inline def nrpp(pi: Int): Long = npp + (if (pi < rpn) 1L else 0L) 99 | @inline def kbn(ki: Int): Long = if (ki < numKeys) koccurs(ki) else 0L 100 | val keyPartCount = koccurs.map(t => BSV.zeros[Long](numPartitions)) 101 | def put(ki: Int, krest: Long, pi: Int, prest: Long): Unit = { 102 | if (ki < numKeys) { 103 | if (krest == prest) { 104 | keyPartCount(ki)(pi) = krest 105 | put(ki + 1, kbn(ki + 1), pi + 1, nrpp(pi + 1)) 106 | } else if (krest < prest) { 107 | keyPartCount(ki)(pi) = krest 108 | put(ki + 1, kbn(ki + 1), pi, prest - 
krest) 109 | } else { 110 | keyPartCount(ki)(pi) = prest 111 | put(ki, krest - prest, pi + 1, nrpp(pi + 1)) 112 | } 113 | } 114 | } 115 | put(0, kbn(0), 0, nrpp(0)) 116 | keyPartCount 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/partitioner/DBHPartitioner.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.partitioner 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import com.github.cloudml.zen.ml.clustering.LDADefines._ 23 | 24 | import org.apache.spark.HashPartitioner 25 | import org.apache.spark.graphx2._ 26 | import org.apache.spark.graphx2.impl.GraphImpl 27 | import org.apache.spark.storage.StorageLevel 28 | 29 | /** 30 | * Degree-Based Hashing, the paper: 31 | * Distributed Power-law Graph Computing: Theoretical and Empirical Analysis 32 | */ 33 | class DBHPartitioner(val partitions: Int, val threshold: Int = 0) 34 | extends HashPartitioner(partitions) { 35 | /** 36 | * Default DBH doesn't consider the situation where both the degree of src and 37 | * dst vertices are both small than a given threshold value 38 | */ 39 | def getKey(et: EdgeTriplet[Int, _]): Long = { 40 | val srcId = et.srcId 41 | val dstId = et.dstId 42 | val srcDeg = et.srcAttr 43 | val dstDeg = et.dstAttr 44 | val maxDeg = math.max(srcDeg, dstDeg) 45 | val minDegId = if (maxDeg == srcDeg) dstId else srcId 46 | val maxDegId = if (maxDeg == srcDeg) srcId else dstId 47 | if (maxDeg < threshold) { 48 | maxDegId 49 | } else { 50 | minDegId 51 | } 52 | } 53 | 54 | override def equals(other: Any): Boolean = other match { 55 | case dbh: DBHPartitioner => 56 | dbh.numPartitions == numPartitions 57 | case _ => 58 | false 59 | } 60 | } 61 | 62 | object DBHPartitioner { 63 | def partitionByDBH[VD: ClassTag, ED: ClassTag](input: Graph[VD, ED], 64 | storageLevel: StorageLevel): Graph[VD, ED] = { 65 | val edges = input.edges 66 | val conf = edges.context.getConf 67 | val numPartitions = conf.getInt(cs_numPartitions, edges.partitions.length) 68 | val dbh = new DBHPartitioner(numPartitions, 0) 69 | val degGraph = GraphImpl(input.degrees, edges) 70 | val newEdges = degGraph.triplets.mapPartitions(_.map(et => 71 | (dbh.getKey(et), Edge(et.srcId, et.dstId, et.attr)) 72 | )).partitionBy(dbh).map(_._2) 73 | GraphImpl(input.vertices, newEdges, null.asInstanceOf[VD], storageLevel, storageLevel) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/partitioner/EdgeDstPartitioner.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.partitioner 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import com.github.cloudml.zen.ml.clustering.LDADefines._ 23 | 24 | import org.apache.spark.HashPartitioner 25 | import org.apache.spark.graphx2._ 26 | import org.apache.spark.graphx2.impl.GraphImpl 27 | import org.apache.spark.storage.StorageLevel 28 | 29 | 30 | class EdgeDstPartitioner(val partitions: Int) extends HashPartitioner(partitions) { 31 | 32 | @inline def getKey(et: EdgeTriplet[_, _]): Long = et.dstId 33 | 34 | override def equals(other: Any): Boolean = other match { 35 | case edp: EdgeDstPartitioner => 36 | edp.numPartitions == numPartitions 37 | case _ => 38 | false 39 | } 40 | } 41 | 42 | object EdgeDstPartitioner { 43 | def partitionByEDP[VD: ClassTag, ED: ClassTag](input: Graph[VD, ED], 44 | storageLevel: StorageLevel): Graph[VD, ED] = { 45 | val edges = input.edges 46 | val conf = edges.context.getConf 47 | val numPartitions = conf.getInt(cs_numPartitions, edges.partitions.length) 48 | val edp = new EdgeDstPartitioner(numPartitions) 49 | val newEdges = input.triplets.mapPartitions(_.map(et => 50 | (edp.getKey(et), Edge(et.srcId, et.dstId, et.attr)) 51 | )).partitionBy(edp).map(_._2) 52 | GraphImpl(input.vertices, newEdges, null.asInstanceOf[VD], storageLevel, storageLevel) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/partitioner/LBVertexRDDBuilder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.partitioner 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import org.apache.spark.graphx2._ 23 | import org.apache.spark.graphx2.impl._ 24 | import org.apache.spark.rdd.RDD 25 | import org.apache.spark.storage.StorageLevel 26 | 27 | 28 | object LBVertexRDDBuilder { 29 | def fromEdgeRDD[VD: ClassTag, ED: ClassTag](edges: EdgeRDD[ED], 30 | storageLevel: StorageLevel): GraphImpl[VD, ED] = { 31 | val eimpl = edges.asInstanceOf[EdgeRDDImpl[ED, VD]] 32 | GraphImpl.fromEdgeRDD(eimpl, null.asInstanceOf[VD], storageLevel, storageLevel) 33 | } 34 | 35 | def fromEdges[VD: ClassTag, ED: ClassTag](edges: RDD[Edge[ED]], 36 | storageLevel: StorageLevel): GraphImpl[VD, ED] = { 37 | fromEdgeRDD[VD, ED](EdgeRDD.fromEdges[ED, VD](edges), storageLevel) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/recommendation/README.md: -------------------------------------------------------------------------------- 1 | # Factorization Machines 2 | 3 | ## Road map 4 | * Support hundreds of billions of features 5 | 6 | * Support online learning 7 | 8 | * Support TBs of training data 9 | 10 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/regression/README.md: -------------------------------------------------------------------------------- 1 | # LogisticRegression 2 | 3 | ## Road map 4 | 5 | * Support hundreds of billions of features 6 | 7 | * Support online learning 8 | 9 | * Support hundreds of billions of samples 10 | 11 | ## TODO 12 | 13 | * Add LogisticRegressionModel 14 | 15 | * Add examples 16 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/sampler/CompositeSampler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.sampler 19 | 20 | import java.util.Random 21 | 22 | import spire.math.{Numeric => spNum} 23 | 24 | 25 | class CompositeSampler(implicit ev: spNum[Double]) 26 | extends Sampler[Double] { 27 | private var samplers: Seq[Sampler[_]] = _ 28 | 29 | protected def numer: spNum[Double] = ev 30 | 31 | def apply(state: Int): Double = samplers.iterator.map(_.applyDouble(state)).sum 32 | 33 | def norm: Double = samplers.iterator.map(_.normDouble).sum 34 | 35 | def sampleFrom(base: Double, gen: Random): Int = { 36 | val sampIter = samplers.iterator 37 | var curSampler = sampIter.next() 38 | var subNorm = curSampler.normDouble 39 | var remain = base 40 | while (remain >= subNorm) { 41 | remain -= subNorm 42 | curSampler = sampIter.next() 43 | subNorm = curSampler.normDouble 44 | } 45 | curSampler.sampleFromDouble(remain, gen) 46 | } 47 | 48 | def resetComponents(samplers: Sampler[_]*): CompositeSampler = { 49 | this.samplers = samplers 50 | this 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/sampler/CumulativeDist.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.sampler 19 | 20 | import java.util.Random 21 | import scala.reflect.ClassTag 22 | 23 | import CumulativeDist._ 24 | 25 | import breeze.linalg.StorageVector 26 | import spire.math.{Numeric => spNum} 27 | 28 | 29 | class CumulativeDist[@specialized(Double, Int, Float, Long) T: ClassTag](implicit ev: spNum[T]) 30 | extends DiscreteSampler[T] with Serializable { 31 | var _cdf: Array[T] = _ 32 | var _space: Array[Int] = _ 33 | var _used: Int = _ 34 | 35 | protected def numer: spNum[T] = ev 36 | 37 | def length: Int = _cdf.length 38 | 39 | def size: Int = length 40 | 41 | def used: Int = _used 42 | 43 | def norm: T = { 44 | if (_used == 0) { 45 | ev.zero 46 | } else { 47 | _cdf(_used - 1) 48 | } 49 | } 50 | 51 | def sampleFrom(base: T, gen: Random): Int = { 52 | // assert(ev.lt(base, _cdf(_used - 1))) 53 | if (_used == 1) { 54 | _space(0) 55 | } else { 56 | val i = binarySelect(_cdf, base, 0, _used, greater=true) 57 | _space(i) 58 | } 59 | } 60 | 61 | def apply(state: Int): T = { 62 | val i = binarySelect(_space, state, 0, _used, greater=true) 63 | if (_space(i) == state) { 64 | if (i == 0) _cdf(0) else ev.minus(_cdf(i), _cdf(i - 1)) 65 | } else { 66 | ev.zero 67 | } 68 | } 69 | 70 | def update(state: Int, value: => T): Unit = {} 71 | 72 | def deltaUpdate(state: Int, delta: => T): Unit = {} 73 | 74 | def resetDist(probs: Array[T], space: Array[Int], psize: Int): CumulativeDist[T] = { 75 | resetDist(space.iterator.zip(probs.iterator), psize) 76 | } 77 | 78 | def resetDist(distIter: Iterator[(Int, T)], psize: Int): CumulativeDist[T] = { 79 | reset(psize) 80 | var sum = ev.zero 81 | var i = 0 82 | while (i < psize) { 83 | val (state, prob) = distIter.next() 84 | sum = ev.plus(sum, prob) 85 | _cdf(i) = sum 86 | _space(i) = state 87 | i += 1 88 | } 89 | this 90 | } 91 | 92 | def reset(newSize: Int): CumulativeDist[T] = { 93 | if (_cdf == null || _cdf.length < newSize) { 94 | _cdf = new Array[T](newSize) 95 | _space = new Array[Int](newSize) 96 | } 97 | _used = newSize 98 | this 99 | } 100 | 101 | def data: Array[T] = _cdf 102 | } 103 | 104 | object CumulativeDist { 105 | def generateCdf[@specialized(Double, Int, Float, Long) T: ClassTag: spNum] 106 | (sv: StorageVector[T]): CumulativeDist[T] = { 107 | val used = sv.activeSize 108 | val cdf = new CumulativeDist[T] 109 | cdf.resetDist(sv.activeIterator, used) 110 | } 111 | 112 | def binarySelect[@specialized(Double, Int, Float, Long) T](arr: Array[T], key: T, 113 | begin: Int, end: Int, greater: Boolean)(implicit ev: spNum[T]): Int = { 114 | if (begin == end) { 115 | return if (greater) end else begin - 1 116 | } 117 | var b = begin 118 | var e = end - 1 119 | 120 | var mid: Int = (e + b) >> 1 121 | while (b <= e) { 122 | mid = (e + b) >> 1 123 | val v = arr(mid) 124 | if (ev.lt(v, key)) { 125 | b = mid + 1 126 | } else if (ev.gt(v, key)) { 127 | e = mid - 1 128 | } else { 129 | return mid 130 | } 131 | } 132 | val v = arr(mid) 133 | mid = if ((greater && ev.gteqv(v, key)) || (!greater && ev.lteqv(v, key))) { 134 | mid 135 | } else if (greater) { 136 | mid + 1 137 | } else { 138 | mid - 1 139 | } 140 | 141 | // if (greater) { 142 | // if (mid < end) assert(ev.gteqv(arr(mid), key)) 143 | // if (mid > 0) assert(ev.lteqv(arr(mid - 1), key)) 144 | // } else { 145 | // if (mid > 0) assert(ev.lteqv(arr(mid), key)) 146 | // if (mid < end - 1) assert(ev.gteqv(arr(mid + 1), key)) 147 | // } 148 | mid 149 | } 150 | } 151 | -------------------------------------------------------------------------------- 
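A minimal usage sketch of the cumulative-distribution sampler above. The weights and seed are invented for illustration, and spire's implicit Numeric[Double] instance is assumed to resolve, as it does elsewhere in this package:

import java.util.Random
import breeze.linalg.{SparseVector => BSV}
import com.github.cloudml.zen.ml.sampler.CumulativeDist

// an unnormalized sparse distribution over 8 states; only states 1, 4 and 6 carry mass
val weights = BSV.zeros[Double](8)
weights(1) = 2.0
weights(4) = 5.0
weights(6) = 3.0

// generateCdf stores the running sums [2.0, 7.0, 10.0] over the compacted space [1, 4, 6]
val cdf = CumulativeDist.generateCdf(weights)
val gen = new Random(42L)
// sampleRandom draws a base uniformly in [0, norm) and binary-searches the running sums,
// so each state is returned with probability proportional to its weight (5/10 for state 4)
val state = cdf.sampleRandom(gen)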
/ml/src/main/scala/com/github/cloudml/zen/ml/sampler/DiscreteSampler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.sampler 19 | 20 | import java.util.Random 21 | import scala.annotation.tailrec 22 | 23 | import spire.math.{Numeric => spNum} 24 | 25 | 26 | trait DiscreteSampler[@specialized(Double, Int, Float, Long) T] extends Sampler[T] { 27 | def length: Int 28 | def used: Int 29 | def update(state: Int, value: => T): Unit 30 | def deltaUpdate(state: Int, delta: => T): Unit 31 | def resetDist(probs: Array[T], space: Array[Int], psize: Int): DiscreteSampler[T] 32 | def resetDist(distIter: Iterator[(Int, T)], psize: Int): DiscreteSampler[T] 33 | def reset(newSize: Int): DiscreteSampler[T] 34 | 35 | @tailrec final def resampleRandom(gen: Random, 36 | state: Int, 37 | residualRate: Double, 38 | numResampling: Int = 2)(implicit ev: spNum[T]): Int = { 39 | val newState = sampleRandom(gen) 40 | if (newState == state && numResampling >= 0 && used > 1 && 41 | (residualRate >= 1.0 || gen.nextDouble() < residualRate)) { 42 | resampleRandom(gen, state, residualRate, numResampling - 1) 43 | } else { 44 | newState 45 | } 46 | } 47 | 48 | @tailrec final def resampleFrom(base: T, 49 | gen: Random, 50 | state: Int, 51 | residualRate: Double, 52 | numResampling: Int = 2)(implicit ev: spNum[T]): Int = { 53 | val newState = sampleFrom(base, gen) 54 | if (newState == state && numResampling >= 0 && used > 1 && 55 | (residualRate >= 1.0 || gen.nextDouble() < residualRate)) { 56 | val newBase = ev.fromDouble(gen.nextDouble() * ev.toDouble(norm)) 57 | resampleFrom(newBase, gen, state, residualRate, numResampling - 1) 58 | } else { 59 | newState 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/sampler/FlatDist.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.sampler 19 | 20 | import java.util.Random 21 | import scala.reflect.ClassTag 22 | 23 | import breeze.linalg.{SparseVector => brSV, DenseVector => brDV, StorageVector, Vector => brV} 24 | import breeze.storage.Zero 25 | import spire.math.{Numeric => spNum} 26 | 27 | 28 | class FlatDist[@specialized(Double, Int, Float, Long) T: ClassTag](val isSparse: Boolean) 29 | (implicit ev: spNum[T]) extends DiscreteSampler[T] with Serializable { 30 | private var _dist: StorageVector[T] = _ 31 | private var _norm: T = _ 32 | 33 | protected def numer: spNum[T] = ev 34 | 35 | def length: Int = _dist.length 36 | 37 | def used: Int = _dist.activeSize 38 | 39 | def norm: T = _norm 40 | 41 | def sampleFrom(base: T, gen: Random): Int = { 42 | assert(ev.lt(base, _norm)) 43 | val idx = if (used == 1) { 44 | 0 45 | } else { 46 | var i = 0 47 | var cdf = ev.zero 48 | var found = false 49 | do { 50 | cdf = ev.plus(cdf, _dist.valueAt(i)) 51 | if (ev.lt(base, cdf)) { 52 | found = true 53 | } else { 54 | i += 1 55 | } 56 | } while (!found && i < used - 1) 57 | i 58 | } 59 | _dist.indexAt(idx) 60 | } 61 | 62 | def apply(state: Int): T = _dist(state) 63 | 64 | def update(state: Int, value: => T): Unit = { 65 | val prev = _dist(state) 66 | _dist(state) = value 67 | val newNorm = ev.plus(_norm, ev.minus(value, prev)) 68 | setNorm(newNorm) 69 | } 70 | 71 | def deltaUpdate(state: Int, delta: => T): Unit = { 72 | _dist(state) = ev.plus(_dist(state), delta) 73 | val newNorm = ev.plus(_norm, delta) 74 | setNorm(newNorm) 75 | } 76 | 77 | def resetDist(probs: Array[T], space: Array[Int], psize: Int): FlatDist[T] = { 78 | reset(psize) 79 | implicit val zero = Zero(ev.zero) 80 | _dist = if (isSparse) { 81 | new brSV[T](space, probs, psize, _dist.length) 82 | } else { 83 | new brDV[T](probs) 84 | } 85 | var sum = ev.zero 86 | var i = 0 87 | while (i < psize) { 88 | sum = ev.plus(sum, probs(i)) 89 | i += 1 90 | } 91 | setNorm(sum) 92 | } 93 | 94 | def resetDist(distIter: Iterator[(Int, T)], psize: Int): FlatDist[T] = { 95 | reset(psize) 96 | var sum = ev.zero 97 | while (distIter.hasNext) { 98 | val (state, prob) = distIter.next() 99 | _dist(state) = prob 100 | sum = ev.plus(sum, prob) 101 | } 102 | setNorm(sum) 103 | this 104 | } 105 | 106 | def reset(newSize: Int): FlatDist[T] = { 107 | if (_dist == null || _dist.length < newSize) { 108 | implicit val zero = Zero(ev.zero) 109 | _dist = if (isSparse) brSV.zeros[T](newSize) else brDV.zeros[T](newSize) 110 | } 111 | _norm = ev.zero 112 | this 113 | } 114 | 115 | private def setNorm(norm: T): FlatDist[T] = { 116 | _norm = norm 117 | this 118 | } 119 | } 120 | 121 | object FlatDist { 122 | def generateFlat[@specialized(Double, Int, Float, Long) T: ClassTag: spNum] 123 | (sv: brV[T]): FlatDist[T] = { 124 | val used = sv.activeSize 125 | val flat = sv match { 126 | case v: brDV[T] => new FlatDist[T](isSparse=false) 127 | case v: brSV[T] => new FlatDist[T](isSparse=true) 128 | } 129 | flat.resetDist(sv.activeIterator, used) 130 | } 131 | } 132 | 
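The flat distribution above is typically kept in sync with a changing count vector by pairing draws with deltaUpdate calls. A small sketch with invented counts and a dense backing vector, again assuming spire's implicit Numeric[Double] is in scope:

import java.util.Random
import com.github.cloudml.zen.ml.sampler.FlatDist

val numTopics = 16
val dist = new FlatDist[Double](isSparse = false)
// start from a uniform, unnormalized distribution; the space array is ignored for dense storage
dist.resetDist(Array.fill(numTopics)(1.0), Array.range(0, numTopics), numTopics)

val gen = new Random(17L)
val oldTopic = dist.sampleRandom(gen)
dist.deltaUpdate(oldTopic, -1.0) // retract the old assignment
val newTopic = dist.sampleRandom(gen)
dist.deltaUpdate(newTopic, 1.0)  // record the new assignment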
-------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/sampler/MetropolisHastings.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.sampler 19 | 20 | import java.util.Random 21 | 22 | import spire.math.{Numeric => spNum} 23 | 24 | 25 | class MetropolisHastings(implicit ev: spNum[Double]) 26 | extends Sampler[Double] { 27 | type TransProb = Int => Double 28 | 29 | private var origFunc: TransProb = _ 30 | private var proposal: Sampler[Double] = _ 31 | private var state: Int = _ 32 | 33 | protected def numer: spNum[Double] = ev 34 | 35 | def apply(state: Int): Double = origFunc(state) 36 | 37 | def norm: Double = proposal.norm 38 | 39 | def sampleFrom(base: Double, gen: Random): Int = { 40 | val newState = proposal.sampleFrom(base, gen) 41 | if (newState != state) { 42 | val ar = acceptRate(newState) 43 | if (ar >= 1.0 || gen.nextDouble() < ar) { 44 | state = newState 45 | } 46 | } 47 | state 48 | } 49 | 50 | private def acceptRate(newState:Int): Double = { 51 | origFunc(newState) * proposal(state) / 52 | (origFunc(state) * proposal(newState)) 53 | } 54 | 55 | def resetProb(origFunc: TransProb, 56 | proposal: Sampler[Double], 57 | initState: Int): MetropolisHastings = { 58 | this.origFunc = origFunc 59 | this.proposal = proposal 60 | this.state = initState 61 | this 62 | } 63 | 64 | def resetProb(origFunc: TransProb, 65 | proposal: Sampler[Double], 66 | gen: Random): MetropolisHastings = { 67 | this.origFunc = origFunc 68 | this.proposal = proposal 69 | this.state = proposal.sampleRandom(gen) 70 | this 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/sampler/Sampler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.sampler 19 | 20 | import java.util.Random 21 | 22 | import spire.math.{Numeric => spNum} 23 | 24 | 25 | trait Sampler[@specialized(Double, Int, Float, Long) T] { 26 | protected def numer: spNum[T] 27 | def apply(state: Int): T 28 | def norm: T 29 | def sampleFrom(base: T, gen: Random): Int 30 | 31 | def applyDouble(state: Int): Double = numer.toDouble(apply(state)) 32 | 33 | def normDouble: Double = numer.toDouble(norm) 34 | 35 | def sampleFromDouble(base: Double, gen: Random): Int = sampleFrom(numer.fromDouble(base), gen) 36 | 37 | def sampleRandom(gen: Random): Int = { 38 | val u = gen.nextDouble() * normDouble 39 | sampleFromDouble(u, gen) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/tree/Histogram.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.tree 19 | 20 | class Histogram(val numBins: Int) { 21 | private val _counts = new Array[Double](numBins) 22 | private val _scores = new Array[Double](numBins) 23 | private val _squares = new Array[Double](numBins) 24 | private val _scoreWeights = new Array[Double](numBins) 25 | 26 | @inline def counts: Array[Double] = _counts 27 | 28 | @inline def scores: Array[Double] = _scores 29 | 30 | @inline def squares: Array[Double] = _squares 31 | 32 | @inline def scoreWeights: Array[Double] = _scoreWeights 33 | 34 | def weightedUpdate(bin: Int, score: Double, scoreWeight: Double, weight: Double = 1.0): Unit = { 35 | _counts(bin) += weight 36 | _scores(bin) += score * weight 37 | _squares(bin) += score * score * weight 38 | _scoreWeights(bin) += scoreWeight 39 | } 40 | 41 | def update(bin: Int, score: Double, scoreWeight: Double): Unit = { 42 | _counts(bin) += 1 43 | _scores(bin) += score 44 | _squares(bin) += score * score 45 | _scoreWeights(bin) += scoreWeight 46 | } 47 | 48 | def cumulateLeft(): Histogram = { 49 | var bin = 1 50 | while (bin < numBins) { 51 | _counts(bin) += _counts(bin-1) 52 | _scores(bin) += _scores(bin-1) 53 | _squares(bin) += _squares(bin-1) 54 | _scoreWeights(bin) += _scoreWeights(bin-1) 55 | bin += 1 56 | } 57 | this 58 | } 59 | 60 | def cumulate(info: NodeInfoStats): Histogram = { 61 | // cumulate from right to left 62 | var bin = numBins-2 63 | while (bin >0) { 64 | val binRight = bin + 1 65 | _counts(bin) += _counts(binRight) 66 | _scores(bin) += _scores(binRight) 67 | _squares(bin) += _squares(binRight) 68 | _scoreWeights(bin) += _scoreWeights(binRight) 69 | bin -= 1 70 | } 71 | 72 | // fill in Entry(0) with node sum information 73 | _counts(0)=info.sumCount 74 | _scores(0)=info.sumScores 75 | _squares(0)=info.sumSquares 76 | _scoreWeights(0)=info.sumScoreWeights 77 | 78 | this 79 | } 80 | } 81 | 82 | class NodeInfoStats(var sumCount: Int, 83 | var sumScores: Double, 84 | var sumSquares: Double, 85 | var sumScoreWeights: Double)extends Serializable { 86 | 87 | override def toString: String = s"NodeInfoStats($sumCount, $sumScores, $sumSquares, $sumScoreWeights)" 88 | 89 | def canEqual(other: Any): Boolean = other.isInstanceOf[NodeInfoStats] 90 | 91 | override def equals(other: Any): Boolean = other match { 92 | case that: NodeInfoStats => 93 | (that canEqual this) && 94 | sumCount == that.sumCount && 95 | sumScores == that.sumScores && 96 | sumSquares == that.sumSquares && 97 | sumScoreWeights == that.sumScoreWeights 98 | case _ => false 99 | } 100 | 101 | override def hashCode(): Int = { 102 | val state = Seq(sumCount, sumScores, sumSquares, sumScoreWeights) 103 | state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/tree/Node.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.tree 19 | 20 | import org.apache.spark.mllib.tree.model.{Node, Predict} 21 | 22 | object Node { 23 | 24 | /** 25 | * Return a node with the given node id (but nothing else set). 26 | */ 27 | def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, 28 | false, None, None, None, None) 29 | 30 | /** 31 | * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. 32 | * This is used in `DecisionTree.findBestSplits` to construct child nodes 33 | * after finding the best splits for parent nodes. 34 | * Other fields are set at next level. 35 | * 36 | * @param nodeIndex integer node id, from 1 37 | * @param predict predicted value at the node 38 | * @param impurity current node impurity 39 | * @param isLeaf whether the node is a leaf 40 | * @return new node instance 41 | */ 42 | def apply( 43 | nodeIndex: Int, 44 | predict: Predict, 45 | impurity: Double, 46 | isLeaf: Boolean): Node = { 47 | new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) 48 | } 49 | 50 | /** 51 | * Return the index of the left child of this node. 52 | */ 53 | def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 54 | 55 | /** 56 | * Return the index of the right child of this node. 57 | */ 58 | def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 59 | 60 | /** 61 | * Get the parent index of the given node, or 0 if it is the root. 62 | */ 63 | def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 64 | 65 | /** 66 | * Return the level of a tree which the given node is in. 67 | */ 68 | def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { 69 | throw new IllegalArgumentException(s"0 is not a valid node index.") 70 | } else { 71 | java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) 72 | } 73 | 74 | /** 75 | * Returns true if this is a left child. 76 | * Note: Returns false for the root. 77 | */ 78 | def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 79 | 80 | /** 81 | * Return the maximum number of nodes which can be in the given level of the tree. 82 | * 83 | * @param level Level of tree (0 = root). 84 | */ 85 | def maxNodesInLevel(level: Int): Int = 1 << level 86 | 87 | /** 88 | * Return the index of the first node in the given level. 89 | * 90 | * @param level Level of tree (0 = root). 91 | */ 92 | def startIndexInLevel(level: Int): Int = 1 << level 93 | 94 | /** 95 | * Traces down from a root node to get the node with the given node index. 96 | * This assumes the node exists. 
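* The bits of nodeIndex below its highest set bit encode the path from the root, read from the most significant of those bits downward: a 0 bit descends to the left child and a 1 bit to the right child.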
97 | */ 98 | def getNode(nodeIndex: Int, rootNode: Node): Node = { 99 | var tmpNode: Node = rootNode 100 | var levelsToGo = indexToLevel(nodeIndex) 101 | while (levelsToGo > 0) { 102 | if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { 103 | tmpNode = tmpNode.leftNode.get 104 | } else { 105 | tmpNode = tmpNode.rightNode.get 106 | } 107 | levelsToGo -= 1 108 | } 109 | tmpNode 110 | } 111 | 112 | } 113 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/tree/ProbabilityFunctions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.tree 19 | 20 | 21 | object ProbabilityFunctions{ 22 | // probit function 23 | val ProbA = Array(3.3871328727963666080e0, 1.3314166789178437745e+2, 1.9715909503065514427e+3, 24 | 1.3731693765509461125e+4, 4.5921953931549871457e+4, 6.7265770927008700853e+4, 3.3430575583588128105e+4, 25 | 2.5090809287301226727e+3) 26 | val ProbB = Array(4.2313330701600911252e+1, 6.8718700749205790830e+2, 5.3941960214247511077e+3, 27 | 2.1213794301586595867e+4, 3.9307895800092710610e+4, 2.8729085735721942674e+4, 5.2264952788528545610e+3) 28 | 29 | val ProbC = Array(1.42343711074968357734e0, 4.63033784615654529590e0, 5.76949722146069140550e0, 30 | 3.64784832476320460504e0, 1.27045825245236838258e0, 2.41780725177450611770e-1, 2.27238449892691845833e-2, 31 | 7.74545014278341407640e-4) 32 | val ProbD = Array(2.05319162663775882187e0, 1.67638483018380384940e0, 6.89767334985100004550e-1, 33 | 1.48103976427480074590e-1, 1.51986665636164571966e-2, 5.47593808499534494600e-4, 1.05075007164441684324e-9) 34 | 35 | val ProbE = Array(6.65790464350110377720e0, 5.46378491116411436990e0, 1.78482653991729133580e0, 36 | 2.96560571828504891230e-1, 2.65321895265761230930e-2, 1.24266094738807843860e-3, 2.71155556874348757815e-5, 37 | 2.01033439929228813265e-7) 38 | val ProbF = Array(5.99832206555887937690e-1, 1.36929880922735805310e-1, 1.48753612908506148525e-2, 39 | 7.86869131145613259100e-4, 1.84631831751005468180e-5, 1.42151175831644588870e-7, 2.04426310338993978564e-15) 40 | 41 | def Probit(p: Double): Double ={ 42 | val q = p - 0.5 43 | var r = 0.0 44 | if (math.abs(q) < 0.425) { 45 | r = 0.180625 - q * q 46 | q * coeff(ProbA, ProbB, r) 47 | } else { 48 | r = if (q < 0) p else 1 - p 49 | r = math.sqrt(-math.log(r)) 50 | var retval = 0.0 51 | if(r < 5) { 52 | r = r - 1.6 53 | retval = coeff(ProbC, ProbD, r) 54 | } else { 55 | r = r - 5 56 | retval = coeff(ProbE, ProbF, r) 57 | } 58 | if (q >= 0) retval else -retval 59 | } 60 | } 61 | 62 | def coeff(p1: Array[Double], p2: Array[Double], r: Double): Double = { 63 | 
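// Horner evaluation of the rational approximation P1(r) / (1 + r * P2(r)), with the 8 numerator coefficients in p1 and the 7 denominator coefficients in p2: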
(((((((p1(7) * r + p1(6)) * r + p1(5)) * r + p1(4)) * r + p1(3)) * r + p1(2)) * r + p1(1)) * r + p1(0)) / 64 | (((((((p2(6) * r + p2(5)) * r + p2(4)) * r + p2(3)) * r + p2(2)) * r + p2(1)) * r + p2(0)) * r + 1.0) 65 | } 66 | 67 | // The approximate complimentary error function (i.e., 1-erf). 68 | def erfc(x: Double): Double = { 69 | if (x.isInfinity) { 70 | if(x.isPosInfinity) 1.0 else -1.0 71 | } else { 72 | val p = 0.3275911 73 | val a1 = 0.254829592 74 | val a2 = -0.284496736 75 | val a3 = 1.421413741 76 | val a4 = -1.453152027 77 | val a5 = 1.061405429 78 | 79 | val t = 1.0 / (1.0 + p * math.abs(x)) 80 | val ev = ((((((((a5 * t) + a4) * t) + a3) * t) + a2) * t + a1) * t) * scala.math.exp(-(x * x)) 81 | if (x >= 0) ev else 2 - ev 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/tree/SplitInfo.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.tree 19 | 20 | import org.apache.spark.mllib.tree.configuration.FeatureType 21 | import org.apache.spark.mllib.tree.model.Split 22 | 23 | class SplitInfo(feature: Int, threshold: Double) 24 | extends Split(feature, threshold, FeatureType.Continuous, List()) 25 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/tree/TreeUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.tree 19 | 20 | import org.apache.hadoop.fs.{FileSystem, Path} 21 | import org.apache.spark.SparkConf 22 | import org.apache.spark.deploy.SparkHadoopUtil 23 | 24 | object TreeUtils { 25 | def getFileSystem(conf: SparkConf, path: Path): FileSystem = { 26 | val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) 27 | if (sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")) { 28 | val hdfsConfPath = if (sys.env.get("HADOOP_CONF_DIR").isDefined) { 29 | sys.env.get("HADOOP_CONF_DIR").get + "/core-site.xml" 30 | } else { 31 | sys.env.get("YARN_CONF_DIR").get + "/core-site.xml" 32 | } 33 | hadoopConf.addResource(new Path(hdfsConfPath)) 34 | } 35 | path.getFileSystem(hadoopConf) 36 | } 37 | 38 | def getPartitionOffsets(upper: Int, numPartitions: Int): (Array[Int], Array[Int]) = { 39 | val npp = upper / numPartitions 40 | val nppp = npp + 1 41 | val residual = upper - npp * numPartitions 42 | val boundary = residual * nppp 43 | val startPP = new Array[Int](numPartitions) 44 | val lcLenPP = new Array[Int](numPartitions) 45 | var i = 0 46 | while(i < numPartitions) { 47 | if (i < residual) { 48 | startPP(i) = nppp * i 49 | lcLenPP(i) = nppp 50 | } 51 | else{ 52 | startPP(i) = boundary + (i - residual) * npp 53 | lcLenPP(i) = npp 54 | } 55 | i += 1 56 | } 57 | (startPP, lcLenPP) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/tree/treeAggregatorFormat.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.tree 19 | 20 | 21 | import java.io.{File, PrintWriter, FileOutputStream} 22 | 23 | import org.apache.spark.mllib.tree.model._ 24 | import scala.collection.mutable 25 | 26 | object treeAggregatorFormat{ 27 | type Lists = (List[String], List[Double], List[Double], List[Int], List[Int], List[Double], List[Double]) 28 | 29 | def reformatted(topNode: Node): Lists = { 30 | val splitFeatures = new mutable.MutableList[String] 31 | val splitGains = new mutable.MutableList[Double] 32 | val gainPValues = new mutable.MutableList[Double] 33 | val lteChildren = new mutable.MutableList[Int] 34 | val gtChildren = new mutable.MutableList[Int] 35 | val thresholds = new mutable.MutableList[Double] 36 | val outputs = new mutable.MutableList[Double] 37 | 38 | var curNonLeafIdx = 0 39 | var curLeafIdx = 0 40 | val childIdx = (child: Node) => if (child.isLeaf) { 41 | curLeafIdx -= 1 42 | curLeafIdx 43 | } else { 44 | curNonLeafIdx += 1 45 | curNonLeafIdx 46 | } 47 | 48 | val q = new mutable.Queue[Node] 49 | q.enqueue(topNode) 50 | while (q.nonEmpty) { 51 | val node = q.dequeue() 52 | if (!node.isLeaf) { 53 | val split = node.split.get 54 | val stats = node.stats.get 55 | splitFeatures += s"I:${split.feature}" 56 | splitGains += stats.gain 57 | gainPValues += 0.0 58 | thresholds += split.threshold 59 | val left = node.leftNode.get 60 | val right = node.rightNode.get 61 | lteChildren += childIdx(left) 62 | gtChildren += childIdx(right) 63 | q.enqueue(left) 64 | q.enqueue(right) 65 | } else { 66 | outputs += node.predict.predict 67 | } 68 | } 69 | (splitFeatures.toList, splitGains.toList, gainPValues.toList, lteChildren.toList, gtChildren.toList, 70 | thresholds.toList, outputs.toList) 71 | } 72 | 73 | def sequence(path: String, model: DecisionTreeModel, modelId: Int): Unit = { 74 | val topNode = model.topNode 75 | val (splitFeatures, splitGains, gainPValues, lteChildren, gtChildren, thresholds, outputs) = reformatted(topNode) 76 | val numInternalNodes = splitFeatures.length 77 | 78 | val pw = new PrintWriter(new FileOutputStream(new File(path), true)) 79 | pw.write(s"[Evaluator:$modelId]\n") 80 | pw.write("EvaluatorType=DecisionTree\n") 81 | pw.write(s"NumInternalNodes=$numInternalNodes\n") 82 | 83 | var str = splitFeatures.mkString("\t") 84 | pw.write(s"SplitFeatures=$str\n") 85 | str = splitGains.mkString("\t") 86 | pw.write(s"SplitGain=$str\n") 87 | str = gainPValues.mkString("\t") 88 | pw.write(s"GainPValue=$str\n") 89 | str = lteChildren.mkString("\t") 90 | pw.write(s"LTEChild=$str\n") 91 | str = gtChildren.mkString("\t") 92 | pw.write(s"GTChild=$str\n") 93 | str = thresholds.mkString("\t") 94 | pw.write(s"Threshold=$str\n") 95 | str = outputs.mkString("\t") 96 | pw.write(s"Output=$str\n") 97 | 98 | pw.write("\n") 99 | pw.close() 100 | println(s"save succeed") 101 | } 102 | 103 | def appendTreeAggregator(filePath: String, 104 | index: Int, 105 | evalNodes: Array[Int], 106 | evalWeights: Array[Double] = null, 107 | bias: Double = 0.0, 108 | Type: String = "Linear"): Unit = { 109 | val pw = new PrintWriter(new FileOutputStream(new File(filePath), true)) 110 | 111 | pw.append(s"[Evaluator:$index]").write("\r\n") 112 | pw.append(s"EvaluatorType=Aggregator").write("\r\n") 113 | 114 | val numNodes = evalNodes.length 115 | val defaultWeight = 1.0 116 | if (evalNodes == null) { 117 | throw new IllegalArgumentException("there is no evaluators to be aggregated") 118 | } else { 119 | pw.append(s"NumNodes=$numNodes").write("\r\n") 120 | pw.append(s"Nodes=").write("") 
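// each aggregated evaluator is referenced as a tab-separated "E:<node id>" entry on the Nodes= line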
121 | for (eval <- evalNodes) { 122 | pw.append(s"E:$eval").write("\t") 123 | } 124 | pw.write("\r\n") 125 | } 126 | 127 | var weights = new Array[Double](numNodes) 128 | if (evalWeights == null) { 129 | for (i <- 0 until numNodes) { 130 | weights(i) = defaultWeight 131 | } 132 | } else { 133 | weights = evalWeights 134 | } 135 | 136 | pw.append(s"Weights=").write("") 137 | for (weight <- weights) { 138 | pw.append(s"$weight").write("\t") 139 | } 140 | 141 | pw.write("\r\n") 142 | 143 | pw.append(s"Bias=$bias").write("\r\n") 144 | pw.append(s"Type=$Type").write("\r\n") 145 | 146 | pw.close() 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/CompressedVector.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import java.util 21 | 22 | import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} 23 | import me.lemire.integercompression._ 24 | import me.lemire.integercompression.differential._ 25 | 26 | 27 | class CompressedVector(val used: Int, 28 | val cdata: Array[Int], 29 | val cindex: Array[Int]) extends Serializable 30 | 31 | // optimized for performance, not thread-safe 32 | class BVCompressor(numTopics: Int) { 33 | val dataCodec = new SkippableComposition(new BinaryPacking, new VariableByte) 34 | val indexCodec = new SkippableIntegratedComposition(new IntegratedBinaryPacking, new IntegratedVariableByte) 35 | val buf = new Array[Int](numTopics + 1024) 36 | val inPos = new IntWrapper 37 | val outPos = new IntWrapper 38 | val initValue = new IntWrapper 39 | 40 | def BV2CV(bv: BV[Int]): CompressedVector = bv match { 41 | case v: BDV[Int] => 42 | val cdata = compressData(v.data, numTopics) 43 | new CompressedVector(numTopics, cdata, null) 44 | case v: BSV[Int] => 45 | val used = v.used 46 | val index = v.index 47 | val data = v.data 48 | if (used <= 4) { 49 | new CompressedVector(used, data, index) 50 | } else { 51 | val cdata = compressData(data, used) 52 | val cindex = compressIndex(index, used) 53 | new CompressedVector(used, cdata, cindex) 54 | } 55 | } 56 | 57 | def compressData(data: Array[Int], len: Int): Array[Int] = { 58 | inPos.set(0) 59 | outPos.set(0) 60 | dataCodec.headlessCompress(data, inPos, len, buf, outPos) 61 | util.Arrays.copyOf(buf, outPos.get) 62 | } 63 | 64 | def compressIndex(index: Array[Int], len: Int): Array[Int] = { 65 | buf(0) = index.length 66 | inPos.set(0) 67 | outPos.set(1) 68 | initValue.set(0) 69 | indexCodec.headlessCompress(index, inPos, len, buf, outPos, initValue) 70 | util.Arrays.copyOf(buf, outPos.get) 
71 | } 72 | } 73 | 74 | // optimized for performance, not thread-safe 75 | class BVDecompressor(numTopics: Int) { 76 | val dataCodec = new SkippableComposition(new BinaryPacking, new VariableByte) 77 | val indexCodec = new SkippableIntegratedComposition(new IntegratedBinaryPacking, new IntegratedVariableByte) 78 | val inPos = new IntWrapper 79 | val outPos = new IntWrapper 80 | val initValue = new IntWrapper 81 | 82 | def CV2BV(cv: CompressedVector): BV[Int] = { 83 | val cdata = cv.cdata 84 | val cindex = cv.cindex 85 | if (cindex == null) { 86 | val data = decompressData(cdata, numTopics) 87 | new BDV(data) 88 | } else { 89 | val used = cv.used 90 | if (used <= 4) { 91 | new BSV(cindex, cdata, used, numTopics) 92 | } else { 93 | val data = decompressData(cdata, used) 94 | val index= decompressIndex(cindex, used) 95 | new BSV(index, data, used, numTopics) 96 | } 97 | } 98 | } 99 | 100 | def decompressData(cdata: Array[Int], rawLen: Int): Array[Int] = { 101 | val data = new Array[Int](rawLen) 102 | inPos.set(0) 103 | outPos.set(0) 104 | dataCodec.headlessUncompress(cdata, inPos, cdata.length, data, outPos, rawLen) 105 | data 106 | } 107 | 108 | def decompressIndex(cindex: Array[Int], rawLen: Int): Array[Int] = { 109 | val index = new Array[Int](rawLen) 110 | inPos.set(1) 111 | outPos.set(0) 112 | initValue.set(0) 113 | indexCodec.headlessUncompress(cindex, inPos, cindex.length - 1, index, outPos, rawLen, initValue) 114 | index 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/Concurrent.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import java.util.concurrent.{Executors, LinkedBlockingQueue, ThreadPoolExecutor} 21 | 22 | import scala.concurrent._ 23 | import scala.concurrent.duration._ 24 | 25 | 26 | object Concurrent extends Serializable { 27 | @inline def withFuture[T](body: => T)(implicit es: ExecutionContextExecutorService): Future[T] = { 28 | Future(body)(es) 29 | } 30 | 31 | @inline def withAwaitReady[T](future: Future[T]): Unit = { 32 | Await.ready(future, 1.hour) 33 | } 34 | 35 | def withAwaitReadyAndClose[T](future: Future[T])(implicit es: ExecutionContextExecutorService): Unit = { 36 | Await.ready(future, 1.hour) 37 | closeExecutionContext(es) 38 | } 39 | 40 | @inline def withAwaitResult[T](future: Future[T]): T = { 41 | Await.result(future, 1.hour) 42 | } 43 | 44 | def withAwaitResultAndClose[T](future: Future[T])(implicit es: ExecutionContextExecutorService): T = { 45 | val res = Await.result(future, 1.hour) 46 | closeExecutionContext(es) 47 | res 48 | } 49 | 50 | @inline def initExecutionContext(numThreads: Int): ExecutionContextExecutorService = { 51 | ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(numThreads)) 52 | } 53 | 54 | @inline def closeExecutionContext(es: ExecutionContextExecutorService): Unit = { 55 | es.shutdown() 56 | } 57 | } 58 | 59 | object DebugConcurrent extends Serializable { 60 | def withFuture[T](body: => T)(implicit es: ExecutionContextExecutorService): Future[T] = { 61 | val future = Future(body)(es) 62 | future.onFailure { case e => 63 | e.printStackTrace() 64 | }(scala.concurrent.ExecutionContext.Implicits.global) 65 | future 66 | } 67 | 68 | def withAwaitReady[T](future: Future[T]): Unit = { 69 | Await.ready(future, 1.hour) 70 | } 71 | 72 | def withAwaitReadyAndClose[T](future: Future[T])(implicit es: ExecutionContextExecutorService): Unit = { 73 | future.onComplete { _ => 74 | closeExecutionContext(es) 75 | }(scala.concurrent.ExecutionContext.Implicits.global) 76 | Await.ready(future, 1.hour) 77 | } 78 | 79 | def withAwaitResult[T](future: Future[T]): T = { 80 | Await.result(future, 1.hour) 81 | } 82 | 83 | def withAwaitResultAndClose[T](future: Future[T])(implicit es: ExecutionContextExecutorService): T = { 84 | future.onComplete { _ => 85 | closeExecutionContext(es) 86 | }(scala.concurrent.ExecutionContext.Implicits.global) 87 | Await.result(future, 1.hour) 88 | } 89 | 90 | def initExecutionContext(numThreads: Int): ExecutionContextExecutorService = { 91 | val es = new ThreadPoolExecutor(numThreads, numThreads, 0L, MILLISECONDS, new LinkedBlockingQueue[Runnable], 92 | Executors.defaultThreadFactory, new ThreadPoolExecutor.AbortPolicy) 93 | ExecutionContext.fromExecutorService(es) 94 | } 95 | 96 | def closeExecutionContext(es: ExecutionContextExecutorService): Unit = { 97 | es.shutdown() 98 | if (!es.awaitTermination(1L, SECONDS)) { 99 | System.err.println("Error: ExecutorService does not exit itself, force to terminate.") 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/SparkUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import breeze.linalg.{Vector => BV, SparseVector => BSV, DenseVector => BDV} 21 | import breeze.storage.Zero 22 | import org.apache.hadoop.fs.{FileSystem, Path} 23 | import org.apache.spark.SparkConf 24 | import org.apache.spark.deploy.SparkHadoopUtil 25 | import org.apache.spark.mllib.linalg.{DenseVector => SDV, Vector => SV, SparseVector => SSV} 26 | import scala.language.implicitConversions 27 | import scala.reflect.ClassTag 28 | 29 | 30 | private[zen] object SparkUtils { 31 | implicit def toBreeze(sv: SV): BV[Double] = { 32 | sv match { 33 | case SDV(data) => 34 | new BDV(data) 35 | case SSV(size, indices, values) => 36 | new BSV(indices, values, size) 37 | } 38 | } 39 | 40 | implicit def fromBreeze(breezeVector: BV[Double]): SV = { 41 | breezeVector match { 42 | case v: BDV[Double] => 43 | if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { 44 | new SDV(v.data) 45 | } else { 46 | new SDV(v.toArray) // Can't use underlying array directly, so make a new one 47 | } 48 | case v: BSV[Double] => 49 | if (v.index.length == v.used) { 50 | new SSV(v.length, v.index, v.data) 51 | } else { 52 | new SSV(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used)) 53 | } 54 | case v: BV[_] => 55 | sys.error("Unsupported Breeze vector type: " + v.getClass.getName) 56 | } 57 | } 58 | 59 | def toBreezeConv[T: ClassTag](sv: SV)(implicit num: Numeric[T]): BV[T] = { 60 | val zero = num.zero 61 | implicit val conv: Array[Double] => Array[T] = (data) => { 62 | data.map(ele => (zero match { 63 | case zero: Double => ele 64 | case zero: Float => ele.toFloat 65 | case zero: Int => ele.toInt 66 | case zero: Long => ele.toLong 67 | }).asInstanceOf[T]).array 68 | } 69 | sv match { 70 | case SDV(data) => 71 | new BDV[T](data) 72 | case SSV(size, indices, values) => 73 | new BSV[T](indices, values, size)(Zero[T](zero)) 74 | } 75 | } 76 | 77 | def fromBreezeConv[T: ClassTag](breezeVector: BV[T])(implicit num: Numeric[T]): SV = { 78 | implicit val conv: Array[T] => Array[Double] = (data) => { 79 | data.map(num.toDouble).array 80 | } 81 | breezeVector match { 82 | case v: BDV[T] => 83 | if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { 84 | new SDV(v.data) 85 | } else { 86 | new SDV(v.toArray) // Can't use underlying array directly, so make a new one 87 | } 88 | case v: BSV[T] => 89 | if (v.index.length == v.used) { 90 | new SSV(v.length, v.index, v.data) 91 | } else { 92 | new SSV(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used)) 93 | } 94 | case v: BV[T] => 95 | sys.error("Unsupported Breeze vector type: " + v.getClass.getName) 96 | } 97 | } 98 | 99 | def getFileSystem(conf: SparkConf, path: Path): FileSystem = { 100 | val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) 101 | if (sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")) { 102 | val hdfsConfPath = if 
(sys.env.get("HADOOP_CONF_DIR").isDefined) { 103 | sys.env.get("HADOOP_CONF_DIR").get + "/core-site.xml" 104 | } else { 105 | sys.env.get("YARN_CONF_DIR").get + "/core-site.xml" 106 | } 107 | hadoopConf.addResource(new Path(hdfsConfPath)) 108 | } 109 | path.getFileSystem(hadoopConf) 110 | } 111 | 112 | def deleteChkptDirs(conf: SparkConf, dirs: Array[String]): Unit = { 113 | val fs = getFileSystem(conf, new Path(dirs(0))) 114 | dirs.foreach(dir => { 115 | fs.delete(new Path(dir), true) 116 | }) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/TimeTracker.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import scala.collection.mutable.{HashMap => MutableHashMap} 21 | 22 | /** 23 | * Time tracker implementation which holds labeled timers. 24 | */ 25 | class TimeTracker extends Serializable { 26 | 27 | private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() 28 | 29 | private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() 30 | 31 | /** 32 | * Starts a new timer, or re-starts a stopped timer. 33 | */ 34 | def start(timerLabel: String): Unit = { 35 | val currentTime = System.nanoTime() 36 | if (starts.contains(timerLabel)) { 37 | throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + 38 | s" timerLabel = $timerLabel before that timer was stopped.") 39 | } 40 | starts(timerLabel) = currentTime 41 | } 42 | 43 | /** 44 | * Stops a timer and returns the elapsed time in seconds. 45 | */ 46 | def stop(timerLabel: String): Double = { 47 | val currentTime = System.nanoTime() 48 | if (!starts.contains(timerLabel)) { 49 | throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + 50 | s" timerLabel = $timerLabel, but that timer was not started.") 51 | } 52 | val elapsed = currentTime - starts(timerLabel) 53 | starts.remove(timerLabel) 54 | if (totals.contains(timerLabel)) { 55 | totals(timerLabel) += elapsed 56 | } else { 57 | totals(timerLabel) = elapsed 58 | } 59 | elapsed / 1e9 60 | } 61 | 62 | /** 63 | * Print all timing results in seconds. 
64 | */ 65 | override def toString: String = { 66 | totals.map { case (label, elapsed) => 67 | s" $label: ${elapsed / 1e9}" 68 | }.mkString("\n") 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import java.util.Random 21 | 22 | object Utils { 23 | val random = new Random() 24 | def log1pExp(x: Double): Double = { 25 | if (x > 0) { 26 | x + math.log1p(math.exp(-x)) 27 | } else { 28 | math.log1p(math.exp(x)) 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/XORShiftRandom.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | // Copy from spark 19 | 20 | package com.github.cloudml.zen.ml.util 21 | 22 | import java.nio.ByteBuffer 23 | import java.util.{Random => JavaRandom} 24 | 25 | import scala.util.hashing.MurmurHash3 26 | 27 | /** 28 | * This class implements a XORShift random number generator algorithm 29 | * Source: 30 | * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. 31 | * @see Paper 32 | * This implementation is approximately 3.5 times faster than 33 | * { @link java.util.Random java.util.Random}, partly because of the algorithm, but also due 34 | * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class 35 | * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG 36 | * for each thread. 
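 * A minimal usage sketch (illustrative; the seed value is arbitrary):
 * {{{
 *   val rng = new XORShiftRandom(42L)
 *   val d = rng.nextDouble()   // same call surface as java.util.Random
 *   val n = rng.nextInt(100)
 * }}}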
37 | */ 38 | class XORShiftRandom(init: Long) extends JavaRandom(init) { 39 | 40 | def this() = this(System.nanoTime) 41 | 42 | private var seed = XORShiftRandom.hashSeed(init) 43 | 44 | // we need to just override next - this will be called by nextInt, nextDouble, 45 | // nextGaussian, nextLong, etc. 46 | override protected def next(bits: Int): Int = { 47 | var nextSeed = seed ^ (seed << 21) 48 | nextSeed ^= (nextSeed >>> 35) 49 | nextSeed ^= (nextSeed << 4) 50 | seed = nextSeed 51 | (nextSeed & ((1L << bits) - 1)).asInstanceOf[Int] 52 | } 53 | 54 | override def setSeed(s: Long) { 55 | seed = XORShiftRandom.hashSeed(s) 56 | } 57 | } 58 | 59 | /** Contains benchmark method and main method to run benchmark of the RNG */ 60 | object XORShiftRandom { 61 | 62 | /** Hash seeds to have 0/1 bits throughout. */ 63 | private def hashSeed(seed: Long): Long = { 64 | val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() 65 | MurmurHash3.bytesHash(bytes) 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /ml/src/main/scala/com/github/cloudml/zen/ml/util/modelSaveLoad.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import org.apache.hadoop.fs._ 21 | import org.apache.spark.SparkContext 22 | import org.apache.spark.rdd.RDD 23 | import org.apache.spark.sql.catalyst.ScalaReflection 24 | import org.apache.spark.sql.types.{DataType, StructField, StructType} 25 | import org.json4s._ 26 | import org.json4s.jackson.JsonMethods._ 27 | 28 | import scala.reflect.ClassTag 29 | import scala.reflect.runtime.universe.TypeTag 30 | 31 | // copy form Spark MLlib 32 | /** 33 | * Helper methods for loading models from files. 34 | */ 35 | private[ml] object LoaderUtils { 36 | 37 | /** Returns URI for path/data using the Hadoop filesystem */ 38 | def dataPath(path: String): String = new Path(path, "data").toUri.toString 39 | 40 | /** Returns URI for path/metadata using the Hadoop filesystem */ 41 | def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString 42 | 43 | /** 44 | * Check the schema of loaded model data. 45 | * 46 | * This checks every field in the expected schema to make sure that a field with the same 47 | * name and DataType appears in the loaded schema. Note that this does NOT check metadata 48 | * or containsNull. 49 | * 50 | * @param loadedSchema Schema for model data loaded from file. 51 | * @tparam Data Expected data type from which an expected schema can be derived. 
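   *              For example (hypothetical model schema), with
   *              `case class Data(label: Double, features: Seq[Double])` the loaded schema must
   *              contain fields named `label` and `features`.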
52 | */ 53 | def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { 54 | // Check schema explicitly since erasure makes it hard to use match-case for checking. 55 | val expectedFields: Array[StructField] = 56 | ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields 57 | val loadedFields: Map[String, DataType] = 58 | loadedSchema.map(field => field.name -> field.dataType).toMap 59 | expectedFields.foreach { field => 60 | assert(loadedFields.contains(field.name), s"Unable to parse model data." + 61 | s" Expected field with name ${field.name} was missing in loaded schema:" + 62 | s" ${loadedFields.mkString(", ")}") 63 | } 64 | } 65 | 66 | /** 67 | * Load metadata from the given path. 68 | * @return (class name, version, metadata) 69 | */ 70 | def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { 71 | implicit val formats = DefaultFormats 72 | val metadata = parse(sc.textFile(metadataPath(path)).first()) 73 | val clazz = (metadata \ "class").extract[String] 74 | val version = (metadata \ "version").extract[String] 75 | (clazz, version, metadata) 76 | } 77 | 78 | /** 79 | * Save an RDD to one HDFS file 80 | * @param sc SparkContext 81 | * @param rdd The RDD to save 82 | * @param outPathStr The HDFS file path of String 83 | * @param header Header line of HDFS file, used for storing some metadata 84 | * @param mapEle The function mapping each element of RDD to a line of String 85 | */ 86 | def RDD2HDFSFile[T: ClassTag](sc: SparkContext, 87 | rdd: RDD[T], 88 | outPathStr: String, 89 | header: => String, 90 | mapEle: T => String): Unit = { 91 | val hdpconf = sc.hadoopConfiguration 92 | val fs = FileSystem.get(hdpconf) 93 | val outPath = new Path(outPathStr) 94 | if (fs.exists(outPath)) { 95 | throw new InvalidPathException(s"Output path $outPathStr already exists.") 96 | } 97 | val fout = fs.create(outPath) 98 | fout.write(header.getBytes) 99 | fout.write("\n".getBytes) 100 | rdd.toLocalIterator.foreach(e => { 101 | fout.write(mapEle(e).getBytes) 102 | fout.write("\n".getBytes) 103 | }) 104 | fout.close() 105 | } 106 | 107 | /** 108 | * Load an RDD from one HDFS file 109 | * @param sc SparkContext 110 | * @param inPathStr The HDFS file path of String 111 | * @param init_f The function used for initialization after reading header 112 | * @param lineParser The function parses each line in HDFS file to an element of RDD 113 | */ 114 | def HDFSFile2RDD[T: ClassTag, M: ClassTag](sc: SparkContext, 115 | inPathStr: String, 116 | init_f: String => M, 117 | lineParser: (M, String) => T): (M, RDD[T]) = { 118 | val rawrdd = sc.textFile(inPathStr) 119 | val header = rawrdd.first() 120 | val meta = init_f(header) 121 | val rdd: RDD[T] = rawrdd.mapPartitions(iter => { 122 | val first = iter.next() 123 | if (first == header) { 124 | iter 125 | } else { 126 | Iterator.single(first) ++ iter 127 | } 128 | }.map(lineParser(meta, _))) 129 | (meta, rdd) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/Edge.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | import org.apache.spark.util.collection.SortDataFormat 21 | 22 | /** 23 | * A single directed edge consisting of a source id, target id, 24 | * and the data associated with the edge. 25 | * 26 | * @tparam ED type of the edge attribute 27 | * 28 | * @param srcId The vertex id of the source vertex 29 | * @param dstId The vertex id of the target vertex 30 | * @param attr The attribute associated with the edge 31 | */ 32 | case class Edge[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED] ( 33 | var srcId: VertexId = 0, 34 | var dstId: VertexId = 0, 35 | var attr: ED = null.asInstanceOf[ED]) 36 | extends Serializable { 37 | 38 | /** 39 | * Given one vertex in the edge return the other vertex. 40 | * 41 | * @param vid the id one of the two vertices on the edge. 42 | * @return the id of the other vertex on the edge. 43 | */ 44 | def otherVertexId(vid: VertexId): VertexId = 45 | if (srcId == vid) dstId else { assert(dstId == vid); srcId } 46 | 47 | /** 48 | * Return the relative direction of the edge to the corresponding 49 | * vertex. 50 | * 51 | * @param vid the id of one of the two vertices in the edge. 52 | * @return the relative direction of the edge to the corresponding 53 | * vertex. 
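   *         For example (illustrative ids), `Edge(1L, 2L, "a").relativeDirection(1L)` is
   *         `EdgeDirection.Out`, while `relativeDirection(2L)` is `EdgeDirection.In`.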
54 | */ 55 | def relativeDirection(vid: VertexId): EdgeDirection = 56 | if (vid == srcId) EdgeDirection.Out else { assert(vid == dstId); EdgeDirection.In } 57 | } 58 | 59 | object Edge { 60 | // scalastyle:off 61 | def lexicographicOrdering[ED] = new Ordering[Edge[ED]] { 62 | override def compare(a: Edge[ED], b: Edge[ED]): Int = { 63 | if (a.srcId == b.srcId) { 64 | if (a.dstId == b.dstId) 0 65 | else if (a.dstId < b.dstId) -1 66 | else 1 67 | } else if (a.srcId < b.srcId) -1 68 | else 1 69 | } 70 | } 71 | 72 | def edgeArraySortDataFormat[ED] = new SortDataFormat[Edge[ED], Array[Edge[ED]]] { 73 | override def getKey(data: Array[Edge[ED]], pos: Int): Edge[ED] = { 74 | data(pos) 75 | } 76 | 77 | override def swap(data: Array[Edge[ED]], pos0: Int, pos1: Int): Unit = { 78 | val tmp = data(pos0) 79 | data(pos0) = data(pos1) 80 | data(pos1) = tmp 81 | } 82 | 83 | override def copyElement( 84 | src: Array[Edge[ED]], srcPos: Int, 85 | dst: Array[Edge[ED]], dstPos: Int) { 86 | dst(dstPos) = src(srcPos) 87 | } 88 | 89 | override def copyRange( 90 | src: Array[Edge[ED]], srcPos: Int, 91 | dst: Array[Edge[ED]], dstPos: Int, length: Int) { 92 | System.arraycopy(src, srcPos, dst, dstPos, length) 93 | } 94 | 95 | override def allocate(length: Int): Array[Edge[ED]] = { 96 | new Array[Edge[ED]](length) 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/EdgeContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | /** 21 | * Represents an edge along with its neighboring vertices and allows sending messages along the 22 | * edge. Used in [[Graph#aggregateMessages]]. 23 | */ 24 | abstract class EdgeContext[VD, ED, A] { 25 | /** The vertex id of the edge's source vertex. */ 26 | def srcId: VertexId 27 | /** The vertex id of the edge's destination vertex. */ 28 | def dstId: VertexId 29 | /** The vertex attribute of the edge's source vertex. */ 30 | def srcAttr: VD 31 | /** The vertex attribute of the edge's destination vertex. */ 32 | def dstAttr: VD 33 | /** The attribute associated with the edge. */ 34 | def attr: ED 35 | 36 | /** Sends a message to the source vertex. */ 37 | def sendToSrc(msg: A): Unit 38 | /** Sends a message to the destination vertex. */ 39 | def sendToDst(msg: A): Unit 40 | 41 | /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. 
*/ 42 | def toEdgeTriplet: EdgeTriplet[VD, ED] = { 43 | val et = new EdgeTriplet[VD, ED] 44 | et.srcId = srcId 45 | et.srcAttr = srcAttr 46 | et.dstId = dstId 47 | et.dstAttr = dstAttr 48 | et.attr = attr 49 | et 50 | } 51 | } 52 | 53 | object EdgeContext { 54 | 55 | /** 56 | * Extractor mainly used for Graph#aggregateMessages*. 57 | * Example: 58 | * {{{ 59 | * val messages = graph.aggregateMessages( 60 | * case ctx @ EdgeContext(_, _, _, _, attr) => 61 | * ctx.sendToDst(attr) 62 | * , _ + _) 63 | * }}} 64 | */ 65 | def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = 66 | Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) 67 | } 68 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/EdgeDirection.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | /** 21 | * The direction of a directed edge relative to a vertex. 22 | */ 23 | class EdgeDirection private (private val name: String) extends Serializable { 24 | /** 25 | * Reverse the direction of an edge. An in becomes out, 26 | * out becomes in and both and either remain the same. 27 | */ 28 | def reverse: EdgeDirection = this match { 29 | case EdgeDirection.In => EdgeDirection.Out 30 | case EdgeDirection.Out => EdgeDirection.In 31 | case EdgeDirection.Either => EdgeDirection.Either 32 | case EdgeDirection.Both => EdgeDirection.Both 33 | } 34 | 35 | override def toString: String = "EdgeDirection." + name 36 | 37 | override def equals(o: Any): Boolean = o match { 38 | case other: EdgeDirection => other.name == name 39 | case _ => false 40 | } 41 | 42 | override def hashCode: Int = name.hashCode 43 | } 44 | 45 | 46 | /** 47 | * A set of [[EdgeDirection]]s. 48 | */ 49 | object EdgeDirection { 50 | /** Edges arriving at a vertex. */ 51 | final val In: EdgeDirection = new EdgeDirection("In") 52 | 53 | /** Edges originating from a vertex. */ 54 | final val Out: EdgeDirection = new EdgeDirection("Out") 55 | 56 | /** Edges originating from *or* arriving at a vertex of interest. */ 57 | final val Either: EdgeDirection = new EdgeDirection("Either") 58 | 59 | /** Edges originating from *and* arriving at a vertex of interest. 
*/ 60 | final val Both: EdgeDirection = new EdgeDirection("Both") 61 | } 62 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/EdgeRDD.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | import scala.language.existentials 21 | import scala.reflect.ClassTag 22 | 23 | import org.apache.spark.Dependency 24 | import org.apache.spark.Partition 25 | import org.apache.spark.SparkContext 26 | import org.apache.spark.TaskContext 27 | import org.apache.spark.rdd.RDD 28 | import org.apache.spark.storage.StorageLevel 29 | 30 | import org.apache.spark.graphx2.impl.EdgePartition 31 | import org.apache.spark.graphx2.impl.EdgePartitionBuilder 32 | import org.apache.spark.graphx2.impl.EdgeRDDImpl 33 | 34 | /** 35 | * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each 36 | * partition for performance. It may additionally store the vertex attributes associated with each 37 | * edge to provide the triplet view. Shipping of the vertex attributes is managed by 38 | * `impl.ReplicatedVertexView`. 39 | */ 40 | abstract class EdgeRDD[ED]( 41 | @transient sc: SparkContext, 42 | @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { 43 | 44 | // scalastyle:off structural.type 45 | def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } 46 | // scalastyle:on structural.type 47 | 48 | override protected def getPartitions: Array[Partition] = partitionsRDD.partitions 49 | 50 | override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { 51 | val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) 52 | if (p.hasNext) { 53 | p.next()._2.iterator.map(_.copy()) 54 | } else { 55 | Iterator.empty 56 | } 57 | } 58 | 59 | /** 60 | * Map the values in an edge partitioning preserving the structure but changing the values. 61 | * 62 | * @tparam ED2 the new edge value type 63 | * @param f the function from an edge to a new edge value 64 | * @return a new EdgeRDD containing the new edge values 65 | */ 66 | def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] 67 | 68 | /** 69 | * Reverse all the edges in this RDD. 70 | * 71 | * @return a new EdgeRDD containing all the edges reversed 72 | */ 73 | def reverse: EdgeRDD[ED] 74 | 75 | /** 76 | * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same 77 | * [[PartitionStrategy]]. 
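   * A sketch of a typical call (illustrative; `edgesA` and `edgesB` are assumed to be
   * co-partitioned EdgeRDDs):
   * {{{
   *   val joined = edgesA.innerJoin(edgesB) { (srcId, dstId, a, b) => (a, b) }
   * }}}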
78 | * 79 | * @param other the EdgeRDD to join with 80 | * @param f the join function applied to corresponding values of `this` and `other` 81 | * @return a new EdgeRDD containing only edges that appear in both `this` and `other`, 82 | * with values supplied by `f` 83 | */ 84 | def innerJoin[ED2: ClassTag, ED3: ClassTag] 85 | (other: EdgeRDD[ED2]) 86 | (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] 87 | 88 | /** 89 | * Changes the target storage level while preserving all other properties of the 90 | * EdgeRDD. Operations on the returned EdgeRDD will preserve this storage level. 91 | * 92 | * This does not actually trigger a cache; to do this, call 93 | * [[org.apache.spark.graphx2.EdgeRDD#cache]] on the returned EdgeRDD. 94 | */ 95 | def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED] 96 | } 97 | 98 | object EdgeRDD { 99 | /** 100 | * Creates an EdgeRDD from a set of edges. 101 | * 102 | * @tparam ED the edge attribute type 103 | * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD 104 | */ 105 | def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDDImpl[ED, VD] = { 106 | val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) => 107 | val builder = new EdgePartitionBuilder[ED, VD] 108 | iter.foreach { e => 109 | builder.add(e.srcId, e.dstId, e.attr) 110 | } 111 | Iterator((pid, builder.toEdgePartition)) 112 | } 113 | EdgeRDD.fromEdgePartitions(edgePartitions) 114 | } 115 | 116 | /** 117 | * Creates an EdgeRDD from already-constructed edge partitions. 118 | * 119 | * @tparam ED the edge attribute type 120 | * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD 121 | */ 122 | def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( 123 | edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDDImpl[ED, VD] = { 124 | new EdgeRDDImpl(edgePartitions) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/EdgeTriplet.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | /** 21 | * An edge triplet represents an edge along with the vertex attributes of its neighboring vertices. 22 | * 23 | * @tparam VD the type of the vertex attribute. 
24 | * @tparam ED the type of the edge attribute 25 | */ 26 | class EdgeTriplet[VD, ED] extends Edge[ED] { 27 | /** 28 | * The source vertex attribute 29 | */ 30 | var srcAttr: VD = _ // nullValue[VD] 31 | 32 | /** 33 | * The destination vertex attribute 34 | */ 35 | var dstAttr: VD = _ // nullValue[VD] 36 | 37 | /** 38 | * Set the edge properties of this triplet. 39 | */ 40 | protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { 41 | srcId = other.srcId 42 | dstId = other.dstId 43 | attr = other.attr 44 | this 45 | } 46 | 47 | /** 48 | * Given one vertex in the edge return the other vertex. 49 | * 50 | * @param vid the id one of the two vertices on the edge 51 | * @return the attribute for the other vertex on the edge 52 | */ 53 | def otherVertexAttr(vid: VertexId): VD = 54 | if (srcId == vid) dstAttr else { assert(dstId == vid); srcAttr } 55 | 56 | /** 57 | * Get the vertex object for the given vertex in the edge. 58 | * 59 | * @param vid the id of one of the two vertices on the edge 60 | * @return the attr for the vertex with that id 61 | */ 62 | def vertexAttr(vid: VertexId): VD = 63 | if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } 64 | 65 | override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() 66 | 67 | def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) 68 | } 69 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/GraphLoader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | import com.github.cloudml.zen.ml.util.Logging 21 | import org.apache.spark.storage.StorageLevel 22 | import org.apache.spark.SparkContext 23 | import org.apache.spark.graphx2.impl.{EdgePartitionBuilder, GraphImpl} 24 | 25 | /** 26 | * Provides utilities for loading [[Graph]]s from files. 27 | */ 28 | object GraphLoader extends Logging { 29 | 30 | /** 31 | * Loads a graph from an edge list formatted file where each line contains two integers: a source 32 | * id and a target id. Skips lines that begin with `#`. 33 | * 34 | * If desired the edges can be automatically oriented in the positive 35 | * direction (source Id < target Id) by setting `canonicalOrientation` to 36 | * true. 
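   * A typical call (illustrative path and settings):
   * {{{
   *   val graph = GraphLoader.edgeListFile(sc, "hdfs://.../edges.txt", canonicalOrientation = true)
   * }}}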
37 | * 38 | * @example Loads a file in the following format: 39 | * {{{ 40 | * # Comment Line 41 | * # Source Id <\t> Target Id 42 | * 1 -5 43 | * 1 2 44 | * 2 7 45 | * 1 8 46 | * }}} 47 | * 48 | * @param sc SparkContext 49 | * @param path the path to the file (e.g., /home/data/file or hdfs://file) 50 | * @param canonicalOrientation whether to orient edges in the positive 51 | * direction 52 | * @param numEdgePartitions the number of partitions for the edge RDD 53 | * Setting this value to -1 will use the default parallelism. 54 | * @param edgeStorageLevel the desired storage level for the edge partitions 55 | * @param vertexStorageLevel the desired storage level for the vertex partitions 56 | */ 57 | def edgeListFile( 58 | sc: SparkContext, 59 | path: String, 60 | canonicalOrientation: Boolean = false, 61 | numEdgePartitions: Int = -1, 62 | edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY, 63 | vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) 64 | : Graph[Int, Int] = 65 | { 66 | val startTime = System.currentTimeMillis 67 | 68 | // Parse the edge data table directly into edge partitions 69 | val lines = 70 | if (numEdgePartitions > 0) { 71 | sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions) 72 | } else { 73 | sc.textFile(path) 74 | } 75 | val edges = lines.mapPartitionsWithIndex { (pid, iter) => 76 | val builder = new EdgePartitionBuilder[Int, Int] 77 | iter.foreach { line => 78 | if (!line.isEmpty && line(0) != '#') { 79 | val lineArray = line.split("\\s+") 80 | if (lineArray.length < 2) { 81 | throw new IllegalArgumentException("Invalid line: " + line) 82 | } 83 | val srcId = lineArray(0).toLong 84 | val dstId = lineArray(1).toLong 85 | if (canonicalOrientation && srcId > dstId) { 86 | builder.add(dstId, srcId, 1) 87 | } else { 88 | builder.add(srcId, dstId, 1) 89 | } 90 | } 91 | } 92 | Iterator((pid, builder.toEdgePartition)) 93 | }.persist(edgeStorageLevel).setName("GraphLoader.edgeListFile - edges (%s)".format(path)) 94 | edges.count() 95 | 96 | logInfo("It took %d ms to load the edges".format(System.currentTimeMillis - startTime)) 97 | 98 | GraphImpl.fromEdgePartitions(edges, defaultVertexAttr = 1, edgeStorageLevel = edgeStorageLevel, 99 | vertexStorageLevel = vertexStorageLevel) 100 | } // end of edgeListFile 101 | 102 | } 103 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/GraphXUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | import org.apache.spark.SparkConf 21 | 22 | import org.apache.spark.graphx2.impl._ 23 | import org.apache.spark.graphx2.util.collection.GraphXPrimitiveKeyOpenHashMap 24 | 25 | import org.apache.spark.util.collection.{OpenHashSet, BitSet} 26 | import org.apache.spark.util.BoundedPriorityQueue 27 | 28 | object GraphXUtils { 29 | /** 30 | * Registers classes that GraphX uses with Kryo. 31 | */ 32 | def registerKryoClasses(conf: SparkConf) { 33 | conf.registerKryoClasses(Array( 34 | classOf[Edge[Object]], 35 | classOf[(VertexId, Object)], 36 | classOf[EdgePartition[Object, Object]], 37 | classOf[BitSet], 38 | classOf[VertexIdToIndexMap], 39 | classOf[VertexAttributeBlock[Object]], 40 | classOf[PartitionStrategy], 41 | classOf[BoundedPriorityQueue[Object]], 42 | classOf[EdgeDirection], 43 | classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]], 44 | classOf[OpenHashSet[Int]], 45 | classOf[OpenHashSet[Long]])) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/TripletFields.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2; 19 | 20 | import java.io.Serializable; 21 | 22 | /** 23 | * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the 24 | * system to populate only those fields for efficiency. 25 | */ 26 | public class TripletFields implements Serializable { 27 | 28 | /** Indicates whether the source vertex attribute is included. */ 29 | public final boolean useSrc; 30 | 31 | /** Indicates whether the destination vertex attribute is included. */ 32 | public final boolean useDst; 33 | 34 | /** Indicates whether the edge attribute is included. */ 35 | public final boolean useEdge; 36 | 37 | /** Constructs a default TripletFields in which all fields are included. */ 38 | public TripletFields() { 39 | this(true, true, true); 40 | } 41 | 42 | public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) { 43 | this.useSrc = useSrc; 44 | this.useDst = useDst; 45 | this.useEdge = useEdge; 46 | } 47 | 48 | /** 49 | * None of the triplet fields are exposed. 50 | */ 51 | public static final TripletFields None = new TripletFields(false, false, false); 52 | 53 | /** 54 | * Expose only the edge field and not the source or destination field. 55 | */ 56 | public static final TripletFields EdgeOnly = new TripletFields(false, false, true); 57 | 58 | /** 59 | * Expose the source and edge fields but not the destination field. 
(Same as Src) 60 | */ 61 | public static final TripletFields Src = new TripletFields(true, false, true); 62 | 63 | /** 64 | * Expose the destination and edge fields but not the source field. (Same as Dst) 65 | */ 66 | public static final TripletFields Dst = new TripletFields(false, true, true); 67 | 68 | /** 69 | * Expose all the fields (source, edge, and destination). 70 | */ 71 | public static final TripletFields All = new TripletFields(true, true, true); 72 | } 73 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/impl/EdgeActiveness.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2.impl; 19 | 20 | /** 21 | * Criteria for filtering edges based on activeness. For internal use only. 22 | */ 23 | public enum EdgeActiveness { 24 | /** Neither the source vertex nor the destination vertex need be active. */ 25 | Neither, 26 | /** The source vertex must be active. */ 27 | SrcOnly, 28 | /** The destination vertex must be active. */ 29 | DstOnly, 30 | /** Both vertices must be active. */ 31 | Both, 32 | /** At least one vertex must be active. */ 33 | Either 34 | } 35 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/impl/EdgeRDDImpl.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.graphx2.impl 19 | 20 | import scala.reflect.{classTag, ClassTag} 21 | 22 | import org.apache.spark.{OneToOneDependency, HashPartitioner} 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.spark.storage.StorageLevel 25 | 26 | import org.apache.spark.graphx2._ 27 | 28 | class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] ( 29 | @transient override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], 30 | val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) 31 | extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { 32 | 33 | override def setName(_name: String): this.type = { 34 | if (partitionsRDD.name != null) { 35 | partitionsRDD.setName(partitionsRDD.name + ", " + _name) 36 | } else { 37 | partitionsRDD.setName(_name) 38 | } 39 | this 40 | } 41 | setName("EdgeRDD") 42 | 43 | /** 44 | * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the 45 | * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new 46 | * partitioner that allows co-partitioning with `partitionsRDD`. 47 | */ 48 | override val partitioner = 49 | partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) 50 | 51 | override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() 52 | 53 | /** 54 | * Persists the edge partitions at the specified storage level, ignoring any existing target 55 | * storage level. 56 | */ 57 | override def persist(newLevel: StorageLevel): this.type = { 58 | partitionsRDD.persist(newLevel) 59 | this 60 | } 61 | 62 | override def unpersist(blocking: Boolean = true): this.type = { 63 | partitionsRDD.unpersist(blocking) 64 | this 65 | } 66 | 67 | /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ 68 | override def cache(): this.type = { 69 | partitionsRDD.persist(targetStorageLevel) 70 | this 71 | } 72 | 73 | override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel 74 | 75 | override def checkpoint(): Unit = { 76 | partitionsRDD.checkpoint() 77 | } 78 | 79 | override def isCheckpointed: Boolean = { 80 | firstParent[(PartitionID, EdgePartition[ED, VD])].isCheckpointed 81 | } 82 | 83 | override def getCheckpointFile: Option[String] = { 84 | partitionsRDD.getCheckpointFile 85 | } 86 | 87 | /** The number of edges in the RDD. 
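   * Computed by summing the sizes of the edge partitions rather than iterating over the edges.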
*/ 88 | override def count(): Long = { 89 | partitionsRDD.map(_._2.size.toLong).reduce(_ + _) 90 | } 91 | 92 | override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] = 93 | mapEdgePartitions((pid, part) => part.map(f)) 94 | 95 | override def reverse: EdgeRDDImpl[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) 96 | 97 | def filter( 98 | epred: EdgeTriplet[VD, ED] => Boolean, 99 | vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] = { 100 | mapEdgePartitions((pid, part) => part.filter(epred, vpred)) 101 | } 102 | 103 | override def innerJoin[ED2: ClassTag, ED3: ClassTag] 104 | (other: EdgeRDD[ED2]) 105 | (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = { 106 | val ed2Tag = classTag[ED2] 107 | val ed3Tag = classTag[ED3] 108 | this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { 109 | (thisIter, otherIter) => 110 | val (pid, thisEPart) = thisIter.next() 111 | val (_, otherEPart) = otherIter.next() 112 | Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) 113 | }) 114 | } 115 | 116 | def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( 117 | f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = { 118 | this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => 119 | if (iter.hasNext) { 120 | val (pid, ep) = iter.next() 121 | Iterator(Tuple2(pid, f(pid, ep))) 122 | } else { 123 | Iterator.empty 124 | } 125 | }, preservesPartitioning = true)) 126 | } 127 | 128 | def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( 129 | partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] = { 130 | new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel) 131 | } 132 | 133 | override def withTargetStorageLevel( 134 | targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] = { 135 | new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel) 136 | } 137 | 138 | } 139 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/impl/VertexPartition.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2.impl 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import org.apache.spark.util.collection.BitSet 23 | 24 | import org.apache.spark.graphx2._ 25 | import org.apache.spark.graphx2.util.collection.GraphXPrimitiveKeyOpenHashMap 26 | 27 | object VertexPartition { 28 | /** Construct a `VertexPartition` from the given vertices. 
*/ 29 | def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]) 30 | : VertexPartition[VD] = { 31 | val (index, values, mask) = VertexPartitionBase.initFrom(iter) 32 | new VertexPartition(index, values, mask) 33 | } 34 | 35 | import scala.language.implicitConversions 36 | 37 | /** 38 | * Implicit conversion to allow invoking `VertexPartitionBase` operations directly on a 39 | * `VertexPartition`. 40 | */ 41 | implicit def partitionToOps[VD: ClassTag](partition: VertexPartition[VD]) 42 | : VertexPartitionOps[VD] = new VertexPartitionOps(partition) 43 | 44 | /** 45 | * Implicit evidence that `VertexPartition` is a member of the `VertexPartitionBaseOpsConstructor` 46 | * typeclass. This enables invoking `VertexPartitionBase` operations on a `VertexPartition` via an 47 | * evidence parameter, as in [[VertexPartitionBaseOps]]. 48 | */ 49 | implicit object VertexPartitionOpsConstructor 50 | extends VertexPartitionBaseOpsConstructor[VertexPartition] { 51 | def toOps[VD: ClassTag](partition: VertexPartition[VD]) 52 | : VertexPartitionBaseOps[VD, VertexPartition] = partitionToOps(partition) 53 | } 54 | } 55 | 56 | /** A map from vertex id to vertex attribute. */ 57 | class VertexPartition[VD: ClassTag]( 58 | val index: VertexIdToIndexMap, 59 | val values: Array[VD], 60 | val mask: BitSet) 61 | extends VertexPartitionBase[VD] 62 | 63 | class VertexPartitionOps[VD: ClassTag](self: VertexPartition[VD]) 64 | extends VertexPartitionBaseOps[VD, VertexPartition](self) { 65 | 66 | def withIndex(index: VertexIdToIndexMap): VertexPartition[VD] = { 67 | new VertexPartition(index, self.values, self.mask) 68 | } 69 | 70 | def withValues[VD2: ClassTag](values: Array[VD2]): VertexPartition[VD2] = { 71 | new VertexPartition(self.index, values, self.mask) 72 | } 73 | 74 | def withMask(mask: BitSet): VertexPartition[VD] = { 75 | new VertexPartition(self.index, self.values, mask) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/impl/VertexPartitionBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2.impl 19 | 20 | import scala.language.higherKinds 21 | import scala.reflect.ClassTag 22 | 23 | import org.apache.spark.util.collection.BitSet 24 | 25 | import org.apache.spark.graphx2._ 26 | import org.apache.spark.graphx2.util.collection.GraphXPrimitiveKeyOpenHashMap 27 | 28 | object VertexPartitionBase { 29 | /** 30 | * Construct the constituents of a VertexPartitionBase from the given vertices, merging duplicate 31 | * entries arbitrarily. 
32 | */ 33 | def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)]) 34 | : (VertexIdToIndexMap, Array[VD], BitSet) = { 35 | val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] 36 | iter.foreach { pair => 37 | map(pair._1) = pair._2 38 | } 39 | (map.keySet, map._values, map.keySet.getBitSet) 40 | } 41 | 42 | /** 43 | * Construct the constituents of a VertexPartitionBase from the given vertices, merging duplicate 44 | * entries using `mergeFunc`. 45 | */ 46 | def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)], mergeFunc: (VD, VD) => VD) 47 | : (VertexIdToIndexMap, Array[VD], BitSet) = { 48 | val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] 49 | iter.foreach { pair => 50 | map.setMerge(pair._1, pair._2, mergeFunc) 51 | } 52 | (map.keySet, map._values, map.keySet.getBitSet) 53 | } 54 | } 55 | 56 | /** 57 | * An abstract map from vertex id to vertex attribute. [[VertexPartition]] is the corresponding 58 | * concrete implementation. [[VertexPartitionBaseOps]] provides a variety of operations for 59 | * VertexPartitionBase and subclasses that provide implicit evidence of membership in the 60 | * `VertexPartitionBaseOpsConstructor` typeclass (for example, 61 | * [[VertexPartition.VertexPartitionOpsConstructor]]). 62 | */ 63 | abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag] 64 | extends Serializable { 65 | 66 | def index: VertexIdToIndexMap 67 | def values: Array[VD] 68 | def mask: BitSet 69 | 70 | val capacity: Int = index.capacity 71 | 72 | def size: Int = mask.cardinality() 73 | 74 | /** Return the vertex attribute for the given vertex ID. */ 75 | def apply(vid: VertexId): VD = values(index.getPos(vid)) 76 | 77 | def isDefined(vid: VertexId): Boolean = { 78 | val pos = index.getPos(vid) 79 | pos >= 0 && mask.get(pos) 80 | } 81 | 82 | def iterator: Iterator[(VertexId, VD)] = 83 | mask.iterator.map(ind => (index.getValue(ind), values(ind))) 84 | } 85 | 86 | /** 87 | * A typeclass for subclasses of `VertexPartitionBase` representing the ability to wrap them in a 88 | * `VertexPartitionBaseOps`. 89 | */ 90 | trait VertexPartitionBaseOpsConstructor[T[X] <: VertexPartitionBase[X]] { 91 | def toOps[VD: ClassTag](partition: T[VD]): VertexPartitionBaseOps[VD, T] 92 | } 93 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/impl/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | import org.apache.spark.util.collection.OpenHashSet 21 | 22 | package object impl { 23 | type VertexIdToIndexMap = OpenHashSet[VertexId] 24 | } 25 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/package-info.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * ALPHA COMPONENT 20 | * GraphX is a graph processing framework built on top of Spark. 21 | */ 22 | package org.apache.spark.graphx2; -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark 19 | 20 | import org.apache.spark.util.collection.OpenHashSet 21 | 22 | /** 23 | * ALPHA COMPONENT 24 | * GraphX is a graph processing framework built on top of Spark. 25 | */ 26 | package object graphx2 { 27 | /** 28 | * A 64-bit vertex identifier that uniquely identifies a vertex within a graph. It does not need 29 | * to follow any ordering or any constraints other than uniqueness. 30 | */ 31 | type VertexId = Long 32 | 33 | /** Integer identifier of a graph partition. Must be less than 2^30. */ 34 | // TODO: Consider using Char. 35 | type PartitionID = Int 36 | 37 | type VertexSet = OpenHashSet[VertexId] 38 | } 39 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/util/package-info.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements.
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /** 19 | * Collections of utilities used by graphx2. 20 | */ 21 | package org.apache.spark.graphx2.util; -------------------------------------------------------------------------------- /ml/src/main/scala/org/apache/spark/graphx2/util/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.graphx2 19 | 20 | /** 21 | * Collections of utilities used by graphx2. 22 | */ 23 | package object util 24 | -------------------------------------------------------------------------------- /ml/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | # Set everything to be logged to the file target/unit-tests.log 19 | log4j.rootCategory=INFO, file 20 | log4j.appender.file=org.apache.log4j.FileAppender 21 | log4j.appender.file.append=true 22 | log4j.appender.file.file=target/unit-tests.log 23 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 24 | log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n 25 | 26 | # Ignore messages below warning level from Jetty, because it's a bit verbose 27 | log4j.logger.org.spark-project.jetty=WARN 28 | 29 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/neuralNetwork/DBNSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.util.MnistDatasetSuite 21 | import org.scalatest.{FunSuite, Matchers} 22 | 23 | class DBNSuite extends FunSuite with MnistDatasetSuite with Matchers { 24 | 25 | ignore("DBN") { 26 | val (data, numVisible) = mnistTrainDataset(2500) 27 | val dbn = new DBN(Array(numVisible, 500, 10)) 28 | DBN.pretrain(data, 100, 1000, dbn, 0.1, 0.05, 0.0) 29 | DBN.finetune(data, 100, 1000, dbn, 0.02, 0.05, 0.0) 30 | val (dataTest, _) = mnistTrainDataset(5000, 2500) 31 | println("Error: " + MLP.error(dataTest, dbn.mlp, 100)) 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/neuralNetwork/MLPSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | 21 | import com.github.cloudml.zen.ml.util.{Utils, SparkUtils, MnistDatasetSuite} 22 | import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM} 23 | import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics 24 | import org.apache.spark.mllib.linalg.{Vector => SV} 25 | import org.apache.spark.mllib.regression.LabeledPoint 26 | import org.apache.spark.mllib.util.MLUtils 27 | import org.scalatest.{FunSuite, Matchers} 28 | 29 | class MLPSuite extends FunSuite with MnistDatasetSuite with Matchers { 30 | ignore("MLP") { 31 | val (data, numVisible) = mnistTrainDataset(5000) 32 | val topology = Array(numVisible, 500, 10) 33 | val nn = MLP.train(data, 20, 1000, topology, fraction = 0.02, 34 | learningRate = 0.1, weightCost = 0.0) 35 | 36 | // val nn = MLP.runLBFGS(data, topology, 100, 4000, 1e-5, 0.001) 37 | // MLP.runSGD(data, nn, 37, 6000, 0.1, 0.5, 0.0) 38 | 39 | val (dataTest, _) = mnistTrainDataset(10000, 5000) 40 | println("Error: " + MLP.error(dataTest, nn, 100)) 41 | } 42 | 43 | ignore("binary classification") { 44 | val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) 45 | val dataSetFile = s"$sparkHome/data/a5a" 46 | val checkpoint = s"$sparkHome/target/tmp" 47 | sc.setCheckpointDir(checkpoint) 48 | val data = MLUtils.loadLibSVMFile(sc, dataSetFile).map { 49 | case LabeledPoint(label, features) => 50 | val y = BDV.zeros[Double](2) 51 | y := 0.04 / y.length 52 | y(if (label > 0) 0 else 1) += 0.96 53 | (features, SparkUtils.fromBreeze(y)) 54 | }.persist() 55 | val trainSet = data.filter(_._1.hashCode().abs % 5 == 3).persist() 56 | val testSet = data.filter(_._1.hashCode().abs % 5 != 3).persist() 57 | 58 | val numVisible = trainSet.first()._1.size 59 | val topology = Array(numVisible, 30, 2) 60 | var nn = MLP.train(trainSet, 100, 1000, topology, fraction = 0.02, 61 | learningRate = 0.05, weightCost = 0.0) 62 | 63 | val modelPath = s"$checkpoint/model" 64 | nn.save(sc, modelPath) 65 | nn = MLP.load(sc, modelPath) 66 | val scoreAndLabels = testSet.map { case (features, label) => 67 | val out = nn.predict(SparkUtils.toBreeze(features).toDenseVector.asDenseMatrix.t) 68 | // Utils.random.nextInt(2).toDouble 69 | (out(0, 0), if (label(0) > 0.5) 1.0 else 0.0) 70 | }.persist() 71 | scoreAndLabels.repartition(1).map(t => s"${t._1}\t${t._2}"). 72 | saveAsTextFile(s"$checkpoint/mlp/${System.currentTimeMillis()}") 73 | val testAccuracy = new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC() 74 | println(f"Test AUC = $testAccuracy%1.6f") 75 | 76 | } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/neuralNetwork/RBMSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.util.MnistDatasetSuite 21 | import org.scalatest.{FunSuite, Matchers} 22 | 23 | class RBMSuite extends FunSuite with MnistDatasetSuite with Matchers { 24 | 25 | ignore("RBM") { 26 | val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) 27 | val checkpoint = s"$sparkHome/target/tmp/rmb/${System.currentTimeMillis()}" 28 | sc.setCheckpointDir(checkpoint) 29 | val (data, numVisible) = mnistTrainDataset(2500) 30 | val rbm = RBM.train(data.map(_._1), 100, 1000, numVisible, 256, 0.1, 0.05, 0.0) 31 | val modelPath = s"$checkpoint/model" 32 | rbm.save(sc, modelPath) 33 | val newRBM = RBM.load(sc, modelPath) 34 | assert(rbm.equals(newRBM)) 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/neuralNetwork/StackedRBMSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.neuralNetwork 19 | 20 | import com.github.cloudml.zen.ml.util.MnistDatasetSuite 21 | import org.scalatest.{FunSuite, Matchers} 22 | 23 | class StackedRBMSuite extends FunSuite with MnistDatasetSuite with Matchers { 24 | 25 | ignore("StackedRBM") { 26 | val (data, numVisible) = mnistTrainDataset(5000) 27 | data.cache() 28 | val topology = Array(numVisible, 300, 300, 500) 29 | val stackedRBM = StackedRBM.train(data.map(_._1), 100, 1200, topology, 0.01, 0.1, 0.0) 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/recommendation/FMSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.recommendation 19 | 20 | import com.github.cloudml.zen.ml.util._ 21 | import org.apache.spark.mllib.regression.LabeledPoint 22 | import com.google.common.io.Files 23 | import org.apache.spark.mllib.util.MLUtils 24 | 25 | import org.scalatest.{Matchers, FunSuite} 26 | 27 | class FMSuite extends FunSuite with SharedSparkContext with Matchers { 28 | test("binary classification") { 29 | val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) 30 | val dataSetFile = classOf[FMSuite].getClassLoader().getResource("binary_classification_data.txt").toString() 31 | val checkpoint = s"$sparkHome/target/tmp" 32 | sc.setCheckpointDir(checkpoint) 33 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map { 34 | case (LabeledPoint(label, features), id) => 35 | val newLabel = if (label > 0.0) 1.0 else 0.0 36 | (id, LabeledPoint(newLabel, features)) 37 | } 38 | val stepSize = 0.1 39 | val regParam = 1e-4 40 | val l2 = (regParam, regParam, regParam) 41 | val rank = 5 42 | val useAdaGrad = true 43 | val trainSet = dataSet.cache() 44 | val fm = new FMClassification(trainSet, stepSize, l2, rank, useAdaGrad) 45 | 46 | val maxIter = 10 47 | val pps = new Array[Double](maxIter) 48 | var i = 0 49 | val startedAt = System.currentTimeMillis() 50 | while (i < maxIter) { 51 | fm.run(1) 52 | pps(i) = fm.saveModel().loss(trainSet) 53 | i += 1 54 | } 55 | println((System.currentTimeMillis() - startedAt) / 1e3) 56 | pps.foreach(println) 57 | 58 | val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs } 59 | assert(ppsDiff.count(_ < 0).toDouble / ppsDiff.size > 0.05) 60 | 61 | 62 | val fmModel = fm.saveModel() 63 | val tempDir = Files.createTempDir() 64 | tempDir.deleteOnExit() 65 | val path = tempDir.toURI.toString 66 | fmModel.save(sc, path) 67 | val sameModel = FMModel.load(sc, path) 68 | assert(sameModel.k === fmModel.k) 69 | assert(sameModel.intercept === fmModel.intercept) 70 | assert(sameModel.classification === fmModel.classification) 71 | assert(sameModel.factors.sortByKey().map(_._2).collect() === 72 | fmModel.factors.sortByKey().map(_._2).collect()) 73 | } 74 | 75 | ignore("regression") { 76 | val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) 77 | val dataSetFile = classOf[FMSuite].getClassLoader().getResource("regression_data.txt").toString() 78 | val checkpoint = s"$sparkHome/target/tmp" 79 | sc.setCheckpointDir(checkpoint) 80 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map { 81 | case (labeledPoint, id) => 82 | (id, labeledPoint) 83 | } 84 | val stepSize = 0.1 85 | val numIterations = 200 86 | val regParam = 1e-3 87 | val l2 = (regParam, regParam, regParam) 88 | val rank = 20 89 | val useAdaGrad = true 90 | val miniBatchFraction = 1.0 91 | val Array(trainSet, testSet) = dataSet.randomSplit(Array(0.8, 0.2)) 92 | val fm = new FMRegression(trainSet.cache(), stepSize, l2, rank, useAdaGrad, miniBatchFraction) 93 | fm.run(numIterations) 94 | val model = fm.saveModel() 95 | println(f"Test 
loss: ${model.loss(testSet.cache())}%1.4f") 96 | } 97 | 98 | ignore("url_combined dataSet") { 99 | // val dataSetFile = "/input/lbs/recommend/kdda/*" 100 | val dataSetFile = "/input/lbs/recommend/url_combined/*" 101 | val checkpointDir = "/input/lbs/recommend/toona/als/checkpointDir" 102 | sc.setCheckpointDir(checkpointDir) 103 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map { 104 | case (LabeledPoint(label, features), id) => 105 | val newLabel = if (label > 0.0) 1.0 else 0.0 106 | (id, LabeledPoint(newLabel, features)) 107 | }.repartition(72).cache() 108 | val stepSize = 0.1 109 | val numIterations = 500 110 | val regParam = 0.0 111 | val l2 = (regParam, regParam, regParam) 112 | val rank = 20 113 | val useAdaGrad = true 114 | val miniBatchFraction = 0.1 115 | val Array(trainSet, testSet) = dataSet.randomSplit(Array(0.8, 0.2)) 116 | val fm = new FMClassification(trainSet.cache(), stepSize, l2, rank, useAdaGrad, miniBatchFraction) 117 | fm.run(numIterations) 118 | val model = fm.saveModel() 119 | println(f"Test loss: ${model.loss(testSet.cache())}%1.4f") 120 | 121 | } 122 | 123 | } 124 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/recommendation/MVMSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.recommendation 19 | 20 | import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, sum => brzSum} 21 | import com.github.cloudml.zen.ml.util._ 22 | import com.google.common.io.Files 23 | import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV} 24 | import org.apache.spark.mllib.regression.LabeledPoint 25 | import org.apache.spark.mllib.util.MLUtils 26 | import org.scalatest.{FunSuite, Matchers} 27 | 28 | class MVMSuite extends FunSuite with SharedSparkContext with Matchers { 29 | test("binary classification") { 30 | val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) 31 | val dataSetFile = classOf[MVMSuite].getClassLoader().getResource("binary_classification_data.txt").toString() 32 | val checkpoint = s"$sparkHome/target/tmp" 33 | sc.setCheckpointDir(checkpoint) 34 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map { 35 | case (LabeledPoint(label, features), id) => 36 | val newLabel = if (label > 0.0) 1.0 else 0.0 37 | (id, LabeledPoint(newLabel, features)) 38 | } 39 | val stepSize = 0.1 40 | val regParam = 1e-2 41 | val l2 = (regParam, regParam, regParam) 42 | val rank = 20 43 | val useAdaGrad = true 44 | val trainSet = dataSet.cache() 45 | val fm = new FMClassification(trainSet, stepSize, l2, rank, useAdaGrad) 46 | 47 | val maxIter = 10 48 | val pps = new Array[Double](maxIter) 49 | var i = 0 50 | val startedAt = System.currentTimeMillis() 51 | while (i < maxIter) { 52 | fm.run(1) 53 | pps(i) = fm.saveModel().loss(trainSet) 54 | i += 1 55 | } 56 | println((System.currentTimeMillis() - startedAt) / 1e3) 57 | pps.foreach(println) 58 | 59 | val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs } 60 | assert(ppsDiff.count(_ < 0).toDouble / ppsDiff.size > 0.05) 61 | 62 | val fmModel = fm.saveModel() 63 | val tempDir = Files.createTempDir() 64 | tempDir.deleteOnExit() 65 | val path = tempDir.toURI.toString 66 | fmModel.save(sc, path) 67 | val sameModel = FMModel.load(sc, path) 68 | assert(sameModel.k === fmModel.k) 69 | assert(sameModel.classification === fmModel.classification) 70 | assert(sameModel.factors.sortByKey().map(_._2).collect() === 71 | fmModel.factors.sortByKey().map(_._2).collect()) 72 | } 73 | 74 | ignore("url_combined classification") { 75 | val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) 76 | val dataSetFile = classOf[MVMSuite].getClassLoader().getResource("binary_classification_data.txt").toString() 77 | val checkpointDir = s"$sparkHome/target/tmp" 78 | sc.setCheckpointDir(checkpointDir) 79 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map { 80 | case (LabeledPoint(label, features), id) => 81 | val newLabel = if (label > 0.0) 1.0 else 0.0 82 | (id, LabeledPoint(newLabel, features)) 83 | }.cache() 84 | val numFeatures = dataSet.first()._2.features.size 85 | val stepSize = 0.1 86 | val numIterations = 500 87 | val regParam = 1e-3 88 | val rank = 20 89 | val views = Array(20, numFeatures / 2, numFeatures).map(_.toLong) 90 | val useAdaGrad = true 91 | val useWeightedLambda = true 92 | val miniBatchFraction = 1 93 | val Array(trainSet, testSet) = dataSet.randomSplit(Array(0.8, 0.2)) 94 | trainSet.cache() 95 | testSet.cache() 96 | 97 | val fm = new MVMClassification(trainSet, stepSize, views, regParam, 0.0, rank, 98 | useAdaGrad, useWeightedLambda, miniBatchFraction) 99 | fm.run(numIterations) 100 | val model = fm.saveModel() 
101 | println(f"Test loss: ${model.loss(testSet.cache())}%1.4f") 102 | 103 | } 104 | 105 | } 106 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/regression/LogisticRegressionSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.regression 19 | 20 | import com.github.cloudml.zen.ml.util._ 21 | import org.apache.spark.mllib.regression.LabeledPoint 22 | import org.apache.spark.mllib.util.MLUtils 23 | import org.scalatest.{Matchers, FunSuite} 24 | import com.github.cloudml.zen.ml.util.SparkUtils._ 25 | 26 | 27 | class LogisticRegressionSuite extends FunSuite with SharedSparkContext with Matchers { 28 | 29 | test("LogisticRegression MIS") { 30 | val zenHome = sys.props.getOrElse("zen.test.home", fail("zen.test.home is not set!")) 31 | val dataSetFile = classOf[LogisticRegressionSuite].getClassLoader().getResource("binary_classification_data.txt").toString() 32 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile) 33 | val max = dataSet.map(_.features.activeValuesIterator.map(_.abs).sum + 1L).max 34 | 35 | val maxIter = 10 36 | val stepSize = 1 / (2 * max) 37 | val trainDataSet = dataSet.zipWithUniqueId().map { case (LabeledPoint(label, features), id) => 38 | val newLabel = if (label > 0.0) 1.0 else -1.0 39 | (id, LabeledPoint(newLabel, features)) 40 | } 41 | val lr = new LogisticRegressionMIS(trainDataSet, stepSize) 42 | val pps = new Array[Double](maxIter) 43 | var i = 0 44 | val startedAt = System.currentTimeMillis() 45 | while (i < maxIter) { 46 | lr.run(1) 47 | val q = lr.forward(i) 48 | pps(i) = lr.loss(q) 49 | i += 1 50 | } 51 | println((System.currentTimeMillis() - startedAt) / 1e3) 52 | pps.foreach(println) 53 | 54 | val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs } 55 | assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.05) 56 | assert(pps.head - pps.last > 0) 57 | } 58 | 59 | test("LogisticRegression SGD") { 60 | val zenHome = sys.props.getOrElse("zen.test.home", fail("zen.test.home is not set!")) 61 | val dataSetFile = classOf[LogisticRegressionSuite].getClassLoader().getResource("binary_classification_data.txt").toString() 62 | val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile) 63 | val maxIter = 10 64 | val stepSize = 1 65 | val trainDataSet = dataSet.zipWithIndex().map { case (LabeledPoint(label, features), id) => 66 | val newLabel = if (label > 0.0) 1.0 else 0 67 | (id, LabeledPoint(newLabel, features)) 68 | } 69 | val lr = new LogisticRegressionSGD(trainDataSet, stepSize) 70 | val pps = new Array[Double](maxIter) 71 | var i = 0 72 | val startedAt = 
System.currentTimeMillis() 73 | while (i < maxIter) { 74 | lr.run(1) 75 | val margin = lr.forward(i) 76 | pps(i) = lr.loss(margin) 77 | i += 1 78 | } 79 | println((System.currentTimeMillis() - startedAt) / 1e3) 80 | pps.foreach(println) 81 | 82 | val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs } 83 | assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.05) 84 | assert(pps.head - pps.last > 0) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/util/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import org.apache.spark.SparkContext 21 | import org.scalatest.Suite 22 | import org.scalatest.BeforeAndAfterEach 23 | import org.scalatest.BeforeAndAfterAll 24 | 25 | import org.jboss.netty.logging.InternalLoggerFactory 26 | import org.jboss.netty.logging.Slf4JLoggerFactory 27 | 28 | trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { 29 | self: Suite => 30 | 31 | @transient var sc: SparkContext = _ 32 | 33 | override def beforeAll() { 34 | InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) 35 | super.beforeAll() 36 | } 37 | 38 | override def afterEach() { 39 | resetSparkContext() 40 | super.afterEach() 41 | } 42 | 43 | def resetSparkContext() = { 44 | LocalSparkContext.stop(sc) 45 | sc = null 46 | } 47 | 48 | } 49 | 50 | object LocalSparkContext { 51 | def stop(sc: SparkContext) { 52 | if (sc != null) { 53 | sc.stop() 54 | } 55 | // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown 56 | System.clearProperty("spark.driver.port") 57 | System.clearProperty("spark.hostPort") 58 | } 59 | 60 | /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ 61 | def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { 62 | try { 63 | f(sc) 64 | } finally { 65 | stop(sc) 66 | } 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/util/MinstDatasetReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import java.io.{Closeable, DataInputStream, FileInputStream, IOException} 21 | import java.util.zip.GZIPInputStream 22 | 23 | import org.apache.spark.mllib.linalg.{DenseVector => SDV, Vector => SV} 24 | 25 | case class MinstItem(label: Int, data: Array[Int]) { 26 | def binaryVector: SV = { 27 | new SDV(data.map { i => 28 | if (i > 30) { 29 | 1D 30 | } else { 31 | 0D 32 | } 33 | }) 34 | } 35 | } 36 | 37 | class MinstDatasetReader(labelsFile: String, imagesFile: String) 38 | extends java.util.Iterator[MinstItem] with Closeable with Logging { 39 | 40 | val labelsBuf: DataInputStream = new DataInputStream(new GZIPInputStream( 41 | new FileInputStream(labelsFile))) 42 | var magic = labelsBuf.readInt() 43 | val labelCount = labelsBuf.readInt() 44 | logInfo(s"Labels magic=$magic count= $labelCount") 45 | 46 | val imagesBuf: DataInputStream = new DataInputStream(new GZIPInputStream( 47 | new FileInputStream(imagesFile))) 48 | magic = imagesBuf.readInt() 49 | val imageCount = imagesBuf.readInt() 50 | val rows = imagesBuf.readInt() 51 | val cols = imagesBuf.readInt() 52 | logInfo(s"Images magic=$magic count=$imageCount rows=$rows cols=$cols") 53 | assert(imageCount == labelCount) 54 | 55 | var current = 0 56 | 57 | override def next(): MinstItem = { 58 | try { 59 | val data = new Array[Int](rows * cols) 60 | for (i <- 0 until data.length) { 61 | data(i) = imagesBuf.readUnsignedByte() 62 | } 63 | val label = labelsBuf.readUnsignedByte() 64 | MinstItem(label, data) 65 | } catch { 66 | case e: IOException => 67 | current = imageCount 68 | throw e 69 | } 70 | finally { 71 | current += 1 72 | } 73 | } 74 | 75 | override def hasNext = current < imageCount 76 | 77 | override def close: Unit = { 78 | imagesBuf.close() 79 | labelsBuf.close() 80 | } 81 | 82 | override def remove { 83 | throw new UnsupportedOperationException("remove") 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/util/MnistDatasetSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import breeze.linalg.{DenseVector => BDV} 21 | import org.apache.spark.mllib.linalg.{Vector => SV} 22 | import org.apache.spark.rdd.RDD 23 | import org.scalatest.Suite 24 | 25 | import scala.collection.JavaConversions._ 26 | 27 | trait MnistDatasetSuite extends SharedSparkContext { 28 | self: Suite => 29 | 30 | def mnistTrainDataset(size: Int = 5000, dropN: Int = 0): (RDD[(SV, SV)], Int) = { 31 | val zenHome = sys.props.getOrElse("zen.test.home", fail("zen.test.home is not set!")) 32 | // http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 33 | val labelsFile = s"$zenHome/data/mnist/train-labels-idx1-ubyte.gz" 34 | // http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 35 | val imagesFile = s"$zenHome/data/mnist/train-images-idx3-ubyte.gz" 36 | val minstReader = new MinstDatasetReader(labelsFile, imagesFile) 37 | val numVisible = minstReader.rows * minstReader.cols 38 | val minstData = minstReader.slice(dropN, dropN + size).map { case m@MinstItem(label, data) => 39 | assert(label < 10) 40 | val y = BDV.zeros[Double](10) 41 | y := 0.1 / y.length 42 | y(label) += 0.9 43 | val x = m.binaryVector 44 | (x, SparkUtils.fromBreeze(y)) 45 | } 46 | val data: RDD[(SV, SV)] = sc.parallelize(minstData.toSeq) 47 | (data, numVisible) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /ml/src/test/scala/com/github/cloudml/zen/ml/util/SharedSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cloudml.zen.ml.util 19 | 20 | import org.apache.spark.{SparkConf, SparkContext} 21 | import org.scalatest.Suite 22 | import org.scalatest.BeforeAndAfterAll 23 | 24 | /** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ 25 | trait SharedSparkContext extends BeforeAndAfterAll { 26 | self: Suite => 27 | 28 | @transient private var _sc: SparkContext = _ 29 | 30 | def sc: SparkContext = _sc 31 | 32 | override def beforeAll() { 33 | val conf = new SparkConf().setAppName(s"zen-test") 34 | conf.set("spark.cleaner.referenceTracking.blocking", "true") 35 | conf.set("spark.cleaner.referenceTracking.blocking.shuffle", "true") 36 | _sc = new SparkContext("local[3]", "test", conf) 37 | super.beforeAll() 38 | } 39 | 40 | override def afterAll() { 41 | LocalSparkContext.stop(_sc) 42 | _sc = null 43 | super.afterAll() 44 | } 45 | } 46 | --------------------------------------------------------------------------------