├── .gitignore
├── ml
├── pom.xml
└── src
│ └── main
│ ├── resources
│ └── log4j.properties
│ └── scala
│ └── org
│ └── dma
│ └── sketchml
│ └── ml
│ ├── SketchML.scala
│ ├── algorithm
│ ├── GeneralizedLinearModel.scala
│ ├── LRModel.scala
│ ├── LinearRegModel.scala
│ └── SVMModel.scala
│ ├── common
│ └── Constants.scala
│ ├── conf
│ └── MLConf.scala
│ ├── data
│ ├── DataSet.scala
│ ├── LabeledData.scala
│ └── Parser.scala
│ ├── gradient
│ ├── DenseDoubleGradient.scala
│ ├── DenseFloatGradient.scala
│ ├── FixedPointGradient.scala
│ ├── Gradient.scala
│ ├── Kind.scala
│ ├── SketchGradient.scala
│ ├── SparseDoubleGradient.scala
│ ├── SparseFloatGradient.scala
│ └── ZipGradient.scala
│ ├── objective
│ ├── Adam.scala
│ ├── GradientDescent.scala
│ └── Loss.scala
│ └── util
│ ├── Maths.scala
│ └── ValidationUtil.scala
├── pom.xml
└── sketch
├── pom.xml
└── src
└── main
├── java
└── org
│ └── dma
│ └── sketchml
│ └── sketch
│ ├── base
│ ├── BinaryEncoder.java
│ ├── Int2IntHash.java
│ ├── QuantileSketch.java
│ ├── Quantizer.java
│ ├── SketchMLException.java
│ └── VectorCompressor.java
│ ├── binary
│ ├── BinaryUtils.java
│ ├── DeltaAdaptiveEncoder.java
│ ├── DeltaBinaryEncoder.java
│ └── HuffmanEncoder.java
│ ├── common
│ └── Constants.java
│ ├── hash
│ ├── BJHash.java
│ ├── BKDRHash.java
│ ├── HashFactory.java
│ ├── Mix64Hash.java
│ └── TWHash.java
│ ├── quantization
│ ├── QuantileQuantizer.java
│ └── UniformQuantizer.java
│ ├── sample
│ ├── App.java
│ ├── DenseVectorCompressor.java
│ └── SparseVectorCompressor.java
│ ├── sketch
│ ├── frequency
│ │ ├── FSketchUtils.java
│ │ ├── GroupedMinMaxSketch.java
│ │ └── MinMaxSketch.java
│ └── quantile
│ │ ├── HeapQuantileSketch.java
│ │ ├── QSketchUtils.java
│ │ └── QuantileSketchException.java
│ └── util
│ ├── Maths.java
│ ├── Sort.java
│ └── Utils.java
└── resources
└── log4j.properties
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled class file
2 | *.class
3 |
4 | # Log file
5 | *.log
6 |
7 | # BlueJ files
8 | *.ctxt
9 |
10 | # Mobile Tools for Java (J2ME)
11 | .mtj.tmp/
12 |
13 | # Package Files #
14 | *.jar
15 | *.war
16 | *.ear
17 | *.zip
18 | *.tar.gz
19 | *.rar
20 |
21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
22 | hs_err_pid*
23 |
24 | # IntelliJ project files
25 | .idea
26 | *.iml
27 | target/
28 | out
29 | gen
   | ### Java template
30 |
31 | # DS_STORE
32 | *.DS_Store
33 |
--------------------------------------------------------------------------------
/ml/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | sketchml
7 | org.dma.sketchml
8 | 1.0.0
9 | ../pom.xml
10 |
11 | 4.0.0
12 |
13 | ml
14 |
15 |
16 |
17 | org.dma.sketchml
18 | sketch
19 | 1.0.0
20 |
21 |
22 |
23 | junit
24 | junit
25 | 4.10
26 | test
27 |
28 |
29 |
30 |
31 | net.jcip
32 | jcip-annotations
33 | 1.0
34 |
35 |
36 |
37 |
38 | com.twitter
39 | chill_2.11
40 | 0.8.0
41 | test
42 |
43 |
44 |
45 |
46 | com.twitter
47 | chill-java
48 | 0.8.0
49 | test
50 |
51 |
52 |
53 | org.scala-lang
54 | scala-library
55 | ${scalaVersion}
56 | provided
57 |
58 |
59 |
60 | org.scala-tools
61 | maven-scala-plugin
62 | 2.11
63 | provided
64 |
65 |
66 |
67 | org.apache.spark
68 | spark-mllib_2.11
69 | 2.2.0
70 |
71 |
72 |
73 | org.apache.spark
74 | spark-mllib-local_2.11
75 | 2.2.0
76 |
77 |
78 |
79 |
80 |
81 |
82 | net.alchim31.maven
83 | scala-maven-plugin
84 | 3.2.2
85 |
86 |
87 | scala-compile-first
88 | process-resources
89 |
90 | compile
91 |
92 |
93 |
94 | scala-test-compile
95 | process-test-resources
96 |
97 | testCompile
98 |
99 |
100 |
101 |
102 | ${scalaVersion}
103 | incremental
104 |
105 |
106 |
107 | org.scala-tools
108 | maven-scala-plugin
109 | 2.15.0
110 |
111 |
112 |
113 | compile
114 | testCompile
115 |
116 |
117 |
118 | -dependencyfile
119 | ${project.build.directory}/.scala_dependencies
120 |
121 |
122 |
123 |
124 |
125 |
126 | org.apache.maven.plugins
127 | maven-surefire-plugin
128 |
129 |
130 | org.scalatest
131 | scalatest-maven-plugin
132 |
133 |
134 | org.apache.maven.plugins
135 | maven-compiler-plugin
136 | 3.3
137 |
138 | ${jdkVersion}
139 | ${jdkVersion}
140 | false
141 |
142 |
143 |
144 | org.apache.maven.plugins
145 | maven-assembly-plugin
146 | 2.5.4
147 |
148 |
149 | jar-with-dependencies
150 |
151 |
152 |
153 |
154 | make-assembly
155 | package
156 |
157 | single
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
--------------------------------------------------------------------------------
/ml/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootLogger=INFO, STDOUT
2 | log4j.logger.deng=INFO
3 | log4j.appender.STDOUT=org.apache.log4j.ConsoleAppender
4 | log4j.appender.STDOUT.layout=org.apache.log4j.PatternLayout
5 | log4j.appender.STDOUT.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %C %p - %m%n
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/SketchML.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml
2 |
3 | import org.apache.spark.{SparkConf, SparkContext}
4 | import org.dma.sketchml.ml.algorithm._
5 | import org.dma.sketchml.ml.common.Constants
6 | import org.dma.sketchml.ml.conf.MLConf
7 |
object SketchML {

  /**
   * Entry point: builds an [[MLConf]] from the Spark configuration, instantiates the requested
   * generalized linear model, loads its data and trains it.
   *
   * The SparkContext is stopped once training finishes (or fails), so executors are released
   * and the application shuts down cleanly.
   *
   * @param args unused; all settings are read from the SparkConf by `MLConf(sparkConf)`
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("SketchML")
    implicit val sc: SparkContext = SparkContext.getOrCreate(sparkConf)
    try {
      val mlConf = MLConf(sparkConf)
      val model = mlConf.algo match {
        case Constants.ML_LOGISTIC_REGRESSION => LRModel(mlConf)
        case Constants.ML_SUPPORT_VECTOR_MACHINE => SVMModel(mlConf)
        case Constants.ML_LINEAR_REGRESSION => LinearRegModel(mlConf)
        // Defensive only: MLConf already rejects unknown algorithms. A bad argument is not a
        // JVM failure, so this throws IllegalArgumentException instead of java.lang.UnknownError.
        case other => throw new IllegalArgumentException("Unsupported algorithm: " + other)
      }

      model.loadData()
      model.train()

      // TODO: test data
    } finally {
      sc.stop()
    }
  }

}
27 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/algorithm/GeneralizedLinearModel.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.algorithm
2 |
3 | import org.apache.spark.broadcast.Broadcast
4 | import org.apache.spark.ml.linalg.DenseVector
5 | import org.apache.spark.rdd.RDD
6 | import org.apache.spark.storage.StorageLevel
7 | import org.apache.spark.{SparkContext, SparkEnv}
8 | import org.dma.sketchml.ml.data.{DataSet, Parser}
9 | import org.dma.sketchml.ml.conf.MLConf
10 | import org.dma.sketchml.ml.gradient.Gradient
11 | import org.dma.sketchml.ml.objective.{GradientDescent, Loss}
12 | import org.dma.sketchml.ml.util.ValidationUtil
13 | import org.slf4j.{Logger, LoggerFactory}
14 |
15 | import scala.collection.mutable.ArrayBuffer
16 | import scala.util.Random
17 |
object GeneralizedLinearModel {
  private val logger: Logger = LoggerFactory.getLogger(GeneralizedLinearModel.getClass)

  /**
   * Mutable, JVM-global model state. The fields start out as `null` and are assigned from
   * inside Spark tasks (e.g. by the models' `initModel` implementations), so every JVM that
   * runs those tasks holds its own copy.
   */
  object Model {
    var weights: DenseVector = null
    var optimizer: GradientDescent = null
    var loss: Loss = null
    var gradient: Gradient = null
  }

  /**
   * JVM-global training / validation samples, filled by `loadData` inside a Spark task.
   * Both are `null` until that task has run in this JVM.
   */
  object Data {
    var trainData: DataSet = null
    var validData: DataSet = null
  }

}
34 |
35 | import GeneralizedLinearModel.Model._
36 | import GeneralizedLinearModel.Data._
37 |
/**
 * Base class for distributed generalized linear models trained with mini-batch gradient descent.
 *
 * Model state and data live in the JVM-global `GeneralizedLinearModel.Model` / `.Data` objects
 * and are accessed from tasks launched over `executors`, an RDD with one element per data
 * partition.
 *
 * @param conf training configuration (also broadcast to executors as `bcConf`)
 */
abstract class GeneralizedLinearModel(@transient protected val conf: MLConf) extends Serializable {
  @transient protected val logger: Logger = GeneralizedLinearModel.logger

  @transient protected implicit val sc: SparkContext = SparkContext.getOrCreate()
  // One element per data partition: the executor id that computed it (partition index on the driver).
  @transient protected var executors: RDD[Int] = _
  protected val bcConf: Broadcast[MLConf] = sc.broadcast(conf)

  /**
   * Loads `conf.input` and randomly splits each partition into train / validation sets, using
   * `conf.validRatio`. The resulting sets are stored in `GeneralizedLinearModel.Data` of the JVM
   * that processed the partition. Also builds and caches `executors`.
   */
  def loadData(): Unit = {
    val startTime = System.currentTimeMillis()
    // Labels are only mapped to {-1, +1} for classifiers; regression targets must be kept as-is.
    val negY = conf.algo != org.dma.sketchml.ml.common.Constants.ML_LINEAR_REGRESSION
    val dataRdd = Parser.loadData(conf.input, conf.format, conf.featureNum, conf.workerNum, negY)
      .persist(StorageLevel.MEMORY_AND_DISK)
    val (trainDataNum, validDataNum) = dataRdd.mapPartitions(iterator => {
      trainData = new DataSet
      validData = new DataSet
      while (iterator.hasNext) {
        if (Random.nextDouble() > bcConf.value.validRatio)
          trainData += iterator.next()
        else
          validData += iterator.next()
      }
      Seq((trainData.size, validData.size)).iterator
    }, preservesPartitioning = true)
      .reduce((c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))
    // Resolve the executor of every partition while `dataRdd` is still cached, so tasks prefer the
    // executors holding those cached partitions. Cache the (tiny) result as well: otherwise every
    // later job on `executors` would recompute `dataRdd` after it is unpersisted below.
    executors = dataRdd.mapPartitionsWithIndex((partId, _) => {
      val exeId = SparkEnv.get.executorId match {
        case "driver" => partId
        case exeStr => exeStr.toInt
      }
      Seq(exeId).iterator
    }, preservesPartitioning = true).persist(StorageLevel.MEMORY_ONLY)
    executors.count()
    dataRdd.unpersist()
    logger.info(s"Load data cost ${System.currentTimeMillis() - startTime} ms, " +
      s"$trainDataNum train data, $validDataNum valid data")
  }

  /** Initializes the JVM-global model state on every executor (weights, optimizer, loss). */
  protected def initModel(): Unit

  /**
   * Trains for `conf.epochNum` epochs and logs the per-epoch train loss, validation loss and
   * elapsed time.
   */
  def train(): Unit = {
    logger.info(s"Start to train a $getName model")
    logger.info(s"Configuration: $conf")
    val startTime = System.currentTimeMillis()
    initModel()

    // `new ArrayBuffer(n)` only reserves capacity. `ArrayBuffer[Double](n)` would instead create a
    // buffer already holding the element `n`, which then shows up in the logged histories.
    val trainLosses = new ArrayBuffer[Double](conf.epochNum)
    val validLosses = new ArrayBuffer[Double](conf.epochNum)
    val timeElapsed = new ArrayBuffer[Long](conf.epochNum)
    val batchNum = Math.ceil(1.0 / conf.batchSpRatio).toInt
    for (epoch <- 0 until conf.epochNum) {
      logger.info(s"Epoch[$epoch] start training")
      trainLosses += trainOneEpoch(epoch, batchNum)
      validLosses += validate(epoch)
      timeElapsed += System.currentTimeMillis() - startTime
      logger.info(s"Epoch[$epoch] done, ${timeElapsed.last} ms elapsed")
    }

    logger.info(s"Train done, total cost ${System.currentTimeMillis() - startTime} ms")
    logger.info(s"Train loss: [${trainLosses.mkString(", ")}]")
    logger.info(s"Valid loss: [${validLosses.mkString(", ")}]")
    logger.info(s"Time: [${timeElapsed.mkString(", ")}]")
  }

  /** Runs `batchNum` iterations and returns the mean batch loss of the epoch. */
  protected def trainOneEpoch(epoch: Int, batchNum: Int): Double = {
    val epochStart = System.currentTimeMillis()
    var trainLoss = 0.0
    for (batch <- 0 until batchNum) {
      val batchLoss = trainOneIteration(epoch, batch)
      trainLoss += batchLoss
    }
    val epochCost = System.currentTimeMillis() - epochStart
    logger.info(s"Epoch[$epoch] train cost $epochCost ms, loss=${trainLoss / batchNum}")
    trainLoss / batchNum
  }

  /** One mini-batch step: compute local gradients, then aggregate and apply them. */
  protected def trainOneIteration(epoch: Int, batch: Int): Double = {
    val batchStart = System.currentTimeMillis()
    val batchLoss = computeGradient(epoch, batch)
    aggregateAndUpdate(epoch, batch)
    logger.info(s"Epoch[$epoch] batch $batch train cost "
      + s"${System.currentTimeMillis() - batchStart} ms")
    batchLoss
  }

  /**
   * Computes a mini-batch gradient on every executor, stores it in `Model.gradient`, and returns
   * the batch loss: the summed objective loss divided by the total batch size, plus the summed
   * regularization loss divided by `conf.workerNum`.
   */
  protected def computeGradient(epoch: Int, batch: Int): Double = {
    val miniBatchGDStart = System.currentTimeMillis()
    val (batchSize, objLoss, regLoss) = executors.aggregate((0, 0.0, 0.0))(
      seqOp = (_, _) => {
        val (grad, batchSize, objLoss, regLoss) =
          optimizer.miniBatchGradientDescent(weights, trainData, loss)
        gradient = grad
        (batchSize, objLoss, regLoss)
      },
      combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2, c1._3 + c2._3)
    )
    val batchLoss = objLoss / batchSize + regLoss / conf.workerNum
    logger.info(s"Epoch[$epoch] batch $batch compute gradient cost "
      + s"${System.currentTimeMillis() - miniBatchGDStart} ms, "
      + s"batch size=$batchSize, batch loss=$batchLoss")
    batchLoss
  }

  /**
   * Collects the compressed local gradients and sums them on the driver. The sum is
   * re-compressed, scaled by 1 / workerNum, then broadcast, and every executor applies it to
   * its weights.
   */
  protected def aggregateAndUpdate(epoch: Int, batch: Int): Unit = {
    val aggrStart = System.currentTimeMillis()
    val sum = Gradient.sum(
      conf.featureNum,
      executors.map(_ => Gradient.compress(gradient, bcConf.value)).collect()
    )
    val grad = Gradient.compress(sum, conf)
    grad.timesBy(1.0 / conf.workerNum)
    logger.info(s"Epoch[$epoch] batch $batch aggregate gradients cost "
      + s"${System.currentTimeMillis() - aggrStart} ms")

    val updateStart = System.currentTimeMillis()
    val bcGrad = sc.broadcast(grad)
    executors.foreach(_ => optimizer.update(bcGrad.value, weights))
    // The broadcast is only needed for this update; release it so one broadcast per batch does
    // not accumulate on the driver and executors for the whole training run.
    bcGrad.destroy()
    logger.info(s"Epoch[$epoch] batch $batch update weights cost "
      + s"${System.currentTimeMillis() - updateStart} ms")
  }

  /**
   * Evaluates the current weights on the validation sets and logs loss, precision and recalls.
   * NOTE: yields NaN/Infinity values when there are no validation samples (e.g. validRatio == 0).
   */
  protected def validate(epoch: Int): Double = {
    val validStart = System.currentTimeMillis()
    val (sumLoss, truePos, trueNeg, falsePos, falseNeg, validNum) =
      executors.aggregate((0.0, 0, 0, 0, 0, 0))(
        seqOp = (_, _) => ValidationUtil.calLossPrecision(weights, validData, loss),
        combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2, c1._3 + c2._3,
          c1._4 + c2._4, c1._5 + c2._5, c1._6 + c2._6)
      )
    val validLoss = sumLoss / validNum
    val precision = 1.0 * (truePos + trueNeg) / validNum
    val trueRecall = 1.0 * truePos / (truePos + falseNeg)
    val falseRecall = 1.0 * trueNeg / (trueNeg + falsePos)
    logger.info(s"Epoch[$epoch] validation cost ${System.currentTimeMillis() - validStart} ms, "
      + s"valid size=$validNum, loss=$validLoss, precision=$precision, "
      + s"trueRecall=$trueRecall, falseRecall=$falseRecall")
    validLoss
  }

  /** Human-readable algorithm name, used in log messages. */
  def getName: String

}
182 |
183 |
184 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/algorithm/LRModel.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.algorithm
2 |
3 | import org.apache.spark.ml.linalg.DenseVector
4 | import org.dma.sketchml.ml.algorithm.GeneralizedLinearModel.Model._
5 | import org.dma.sketchml.ml.common.Constants
6 | import org.dma.sketchml.ml.conf.MLConf
7 | import org.dma.sketchml.ml.objective.{Adam, L2LogLoss}
8 | import org.slf4j.{Logger, LoggerFactory}
9 |
object LRModel {
  private val logger: Logger = LoggerFactory.getLogger(LRModel.getClass)

  /** Builds a logistic-regression model for `conf`. */
  def apply(conf: MLConf): LRModel = new LRModel(conf)

  /** Algorithm name; the same constant that `SketchML.main` matches `MLConf.algo` against. */
  def getName: String = Constants.ML_LOGISTIC_REGRESSION
}

/** Logistic regression: Adam optimizer combined with `L2LogLoss` (parameterized by `l2Reg`). */
class LRModel(_conf: MLConf) extends GeneralizedLinearModel(_conf) {
  @transient override protected val logger: Logger = LRModel.logger

  /**
   * For every element of `executors`, (re)sets that JVM's global model state: zero weights of
   * size `featureNum`, a fresh Adam optimizer, and an `L2LogLoss`.
   */
  override protected def initModel(): Unit =
    executors.foreach { _ =>
      val c = bcConf.value
      weights = new DenseVector(new Array[Double](c.featureNum))
      optimizer = Adam(c)
      loss = new L2LogLoss(c.l2Reg)
    }

  override def getName: String = LRModel.getName
}
31 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/algorithm/LinearRegModel.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.algorithm
2 |
3 | import org.apache.spark.ml.linalg.DenseVector
4 | import org.dma.sketchml.ml.algorithm.GeneralizedLinearModel.Model._
5 | import org.dma.sketchml.ml.common.Constants
6 | import org.dma.sketchml.ml.conf.MLConf
7 | import org.dma.sketchml.ml.objective.{Adam, L2SquareLoss}
8 | import org.slf4j.{Logger, LoggerFactory}
9 |
object LinearRegModel {
  private val logger: Logger = LoggerFactory.getLogger(LinearRegModel.getClass)

  /** Builds a linear-regression model for `conf`. */
  def apply(conf: MLConf): LinearRegModel = new LinearRegModel(conf)

  /** Algorithm name; the same constant that `SketchML.main` matches `MLConf.algo` against. */
  def getName: String = Constants.ML_LINEAR_REGRESSION
}

/** Linear regression: Adam optimizer combined with `L2SquareLoss` (parameterized by `l2Reg`). */
class LinearRegModel(_conf: MLConf) extends GeneralizedLinearModel(_conf) {
  @transient override protected val logger: Logger = LinearRegModel.logger

  /**
   * For every element of `executors`, (re)sets that JVM's global model state: zero weights of
   * size `featureNum`, a fresh Adam optimizer, and an `L2SquareLoss`.
   */
  override protected def initModel(): Unit =
    executors.foreach { _ =>
      val c = bcConf.value
      weights = new DenseVector(new Array[Double](c.featureNum))
      optimizer = Adam(c)
      loss = new L2SquareLoss(c.l2Reg)
    }

  override def getName: String = LinearRegModel.getName

}
32 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/algorithm/SVMModel.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.algorithm
2 |
3 | import org.apache.spark.ml.linalg.DenseVector
4 | import org.dma.sketchml.ml.algorithm.GeneralizedLinearModel.Model._
5 | import org.dma.sketchml.ml.common.Constants
6 | import org.dma.sketchml.ml.conf.MLConf
7 | import org.dma.sketchml.ml.objective.{Adam, L2HingeLoss}
8 | import org.slf4j.{Logger, LoggerFactory}
9 |
object SVMModel {
  private val logger: Logger = LoggerFactory.getLogger(SVMModel.getClass)

  /** Builds a support-vector-machine model for `conf`. */
  def apply(conf: MLConf): SVMModel = new SVMModel(conf)

  /** Algorithm name; the same constant that `SketchML.main` matches `MLConf.algo` against. */
  def getName: String = Constants.ML_SUPPORT_VECTOR_MACHINE
}

/** Linear SVM: Adam optimizer combined with `L2HingeLoss` (parameterized by `l2Reg`). */
class SVMModel(_conf: MLConf) extends GeneralizedLinearModel(_conf) {
  @transient override protected val logger: Logger = SVMModel.logger

  /**
   * For every element of `executors`, (re)sets that JVM's global model state: zero weights of
   * size `featureNum`, a fresh Adam optimizer, and an `L2HingeLoss`.
   */
  override protected def initModel(): Unit =
    executors.foreach { _ =>
      val c = bcConf.value
      weights = new DenseVector(new Array[Double](c.featureNum))
      optimizer = Adam(c)
      loss = new L2HingeLoss(c.l2Reg)
    }

  override def getName: String = SVMModel.getName

}
32 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/common/Constants.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.common
2 |
/** String constants shared by the configuration, the entry point and the data parser. */
object Constants {
  // Algorithm names (accepted values of MLConf.algo)
  val ML_LOGISTIC_REGRESSION: String = "LogisticRegression"
  val ML_SUPPORT_VECTOR_MACHINE: String = "SupportVectorMachine"
  val ML_LINEAR_REGRESSION: String = "LinearRegression"

  // Input file formats (accepted values of MLConf.format)
  val FORMAT_LIBSVM: String = "libsvm"
  val FORMAT_CSV: String = "csv"
  val FORMAT_DUMMY: String = "dummy"

  // Gradient compressors (accepted values of MLConf.compressor)
  val GRADIENT_COMPRESSOR_NONE: String = "None"
  val GRADIENT_COMPRESSOR_FLOAT: String = "Float"
  val GRADIENT_COMPRESSOR_SKETCH: String = "Sketch"
  val GRADIENT_COMPRESSOR_FIXED_POINT: String = "FixedPoint"
  val GRADIENT_COMPRESSOR_ZIP: String = "Zip"
}
17 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/conf/MLConf.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.conf
2 |
3 | import org.apache.spark.SparkConf
4 | import org.dma.sketchml.ml.common.Constants._
5 | import org.dma.sketchml.sketch.base.{Quantizer, SketchMLException}
6 | import org.dma.sketchml.sketch.sketch.frequency.{GroupedMinMaxSketch, MinMaxSketch}
7 |
/** SparkConf keys and defaults for SketchML, plus the factory that reads them. */
object MLConf {
  // ---- ML settings ----
  val ML_ALGORITHM: String = "spark.sketchml.algo"
  val ML_INPUT_PATH: String = "spark.sketchml.input.path"
  val ML_INPUT_FORMAT: String = "spark.sketchml.input.format"
  //val ML_TEST_DATA_PATH: String = "spark.sketchml.test.path"
  //val ML_NUM_CLASS: String = "spark.sketchml.class.num"
  //val DEFAULT_ML_NUM_CLASS: Int = 2
  val ML_NUM_WORKER: String = "spark.sketchml.worker.num"
  val ML_NUM_FEATURE: String = "spark.sketchml.feature.num"
  val ML_VALID_RATIO: String = "spark.sketchml.valid.ratio"
  val DEFAULT_ML_VALID_RATIO: Double = 0.25
  val ML_EPOCH_NUM: String = "spark.sketchml.epoch.num"
  val DEFAULT_ML_EPOCH_NUM: Int = 100
  val ML_BATCH_SAMPLE_RATIO: String = "spark.sketchml.batch.sample.ratio"
  val DEFAULT_ML_BATCH_SAMPLE_RATIO: Double = 0.1
  val ML_LEARN_RATE: String = "spark.sketchml.learn.rate"
  val DEFAULT_ML_LEARN_RATE: Double = 0.1
  val ML_LEARN_DECAY: String = "spark.sketchml.learn.decay"
  val DEFAULT_ML_LEARN_DECAY: Double = 0.9
  val ML_REG_L1: String = "spark.sketchml.reg.l1"
  val DEFAULT_ML_REG_L1: Double = 0.1
  val ML_REG_L2: String = "spark.sketchml.reg.l2"
  val DEFAULT_ML_REG_L2: Double = 0.1
  // ---- Sketch settings ----
  val SKETCH_GRADIENT_COMPRESSOR: String = "spark.sketchml.gradient.compressor"
  val DEFAULT_SKETCH_GRADIENT_COMPRESSOR: String = GRADIENT_COMPRESSOR_SKETCH
  val SKETCH_QUANTIZATION_BIN_NUM: String = "spark.sketchml.quantization.bin.num"
  val DEFAULT_SKETCH_QUANTIZATION_BIN_NUM: Int = Quantizer.DEFAULT_BIN_NUM
  val SKETCH_MINMAXSKETCH_GROUP_NUM: String = "spark.sketchml.minmaxsketch.group.num"
  val DEFAULT_SKETCH_MINMAXSKETCH_GROUP_NUM: Int = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_GROUP_NUM
  val SKETCH_MINMAXSKETCH_ROW_NUM: String = "spark.sketchml.minmaxsketch.row.num"
  val DEFAULT_SKETCH_MINMAXSKETCH_ROW_NUM: Int = MinMaxSketch.DEFAULT_MINMAXSKETCH_ROW_NUM
  val SKETCH_MINMAXSKETCH_COL_RATIO: String = "spark.sketchml.minmaxsketch.col.ratio"
  val DEFAULT_SKETCH_MINMAXSKETCH_COL_RATIO: Double = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_COL_RATIO
  // ---- Fixed-point settings ----
  val FIXED_POINT_BIT_NUM: String = "spark.sketchml.fixed.point.bit.num"
  val DEFAULT_FIXED_POINT_BIT_NUM: Int = 8

  /**
   * Reads an [[MLConf]] from `sparkConf`.
   *
   * The algorithm, input path, input format, worker number and feature number are mandatory:
   * `SparkConf.get(key)` is used for them without a default. Every other setting falls back to
   * its `DEFAULT_*` value above.
   */
  def apply(sparkConf: SparkConf): MLConf = new MLConf(
    algo = sparkConf.get(ML_ALGORITHM),
    input = sparkConf.get(ML_INPUT_PATH),
    format = sparkConf.get(ML_INPUT_FORMAT),
    workerNum = sparkConf.get(ML_NUM_WORKER).toInt,
    featureNum = sparkConf.get(ML_NUM_FEATURE).toInt,
    validRatio = sparkConf.getDouble(ML_VALID_RATIO, DEFAULT_ML_VALID_RATIO),
    epochNum = sparkConf.getInt(ML_EPOCH_NUM, DEFAULT_ML_EPOCH_NUM),
    batchSpRatio = sparkConf.getDouble(ML_BATCH_SAMPLE_RATIO, DEFAULT_ML_BATCH_SAMPLE_RATIO),
    learnRate = sparkConf.getDouble(ML_LEARN_RATE, DEFAULT_ML_LEARN_RATE),
    learnDecay = sparkConf.getDouble(ML_LEARN_DECAY, DEFAULT_ML_LEARN_DECAY),
    l1Reg = sparkConf.getDouble(ML_REG_L1, DEFAULT_ML_REG_L1),
    l2Reg = sparkConf.getDouble(ML_REG_L2, DEFAULT_ML_REG_L2),
    compressor = sparkConf.get(SKETCH_GRADIENT_COMPRESSOR, DEFAULT_SKETCH_GRADIENT_COMPRESSOR),
    quantBinNum = sparkConf.getInt(SKETCH_QUANTIZATION_BIN_NUM, DEFAULT_SKETCH_QUANTIZATION_BIN_NUM),
    sketchGroupNum = sparkConf.getInt(SKETCH_MINMAXSKETCH_GROUP_NUM, DEFAULT_SKETCH_MINMAXSKETCH_GROUP_NUM),
    sketchRowNum = sparkConf.getInt(SKETCH_MINMAXSKETCH_ROW_NUM, DEFAULT_SKETCH_MINMAXSKETCH_ROW_NUM),
    sketchColRatio = sparkConf.getDouble(SKETCH_MINMAXSKETCH_COL_RATIO, DEFAULT_SKETCH_MINMAXSKETCH_COL_RATIO),
    fixedPointBitNum = sparkConf.getInt(FIXED_POINT_BIT_NUM, DEFAULT_FIXED_POINT_BIT_NUM)
  )

}
69 |
/**
 * Immutable training configuration. Construction validates the settings and throws
 * `SketchMLException` for unsupported or out-of-range values.
 */
case class MLConf(algo: String, input: String, format: String, workerNum: Int,
                  featureNum: Int, validRatio: Double, epochNum: Int, batchSpRatio: Double,
                  learnRate: Double, learnDecay: Double, l1Reg: Double, l2Reg: Double,
                  compressor: String, quantBinNum: Int, sketchGroupNum: Int,
                  sketchRowNum: Int, sketchColRatio: Double, fixedPointBitNum: Int) {
  // Checks throw SketchMLException directly. The original routed the throw through `require`'s
  // message argument, which only worked because that argument happens to be by-name.
  if (!Seq(ML_LOGISTIC_REGRESSION, ML_SUPPORT_VECTOR_MACHINE, ML_LINEAR_REGRESSION).contains(algo))
    throw new SketchMLException(s"Unsupported algorithm: $algo")
  if (!Seq(FORMAT_LIBSVM, FORMAT_CSV, FORMAT_DUMMY).contains(format))
    throw new SketchMLException(s"Unrecognizable file format: $format")
  if (!Seq(GRADIENT_COMPRESSOR_SKETCH, GRADIENT_COMPRESSOR_FIXED_POINT, GRADIENT_COMPRESSOR_ZIP,
    GRADIENT_COMPRESSOR_FLOAT, GRADIENT_COMPRESSOR_NONE).contains(compressor))
    throw new SketchMLException(s"Unrecognizable gradient compressor: $compressor")
  // workerNum is used as a partition count and as a divisor during training.
  if (workerNum <= 0)
    throw new SketchMLException(s"Worker number must be positive: $workerNum")
  // featureNum is the size of the weight vector.
  if (featureNum <= 0)
    throw new SketchMLException(s"Feature number must be positive: $featureNum")
  // A sample goes to training only if a draw from [0, 1) exceeds validRatio, so validRatio >= 1
  // would leave no training data. Negated comparisons also reject NaN.
  if (!(validRatio >= 0.0 && validRatio < 1.0))
    throw new SketchMLException(s"Validation ratio must be in [0, 1): $validRatio")
  // The number of batches is ceil(1 / batchSpRatio).
  if (!(batchSpRatio > 0.0))
    throw new SketchMLException(s"Batch sample ratio must be positive: $batchSpRatio")

}
84 |
85 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/data/DataSet.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.data
2 |
3 | import scala.collection.mutable.ArrayBuffer
4 |
/** Append-only, in-memory collection of [[LabeledData]] with a cursor for cyclic reads. */
class DataSet {
  private val data = ArrayBuffer[LabeledData]()
  private var readIndex = 0

  /** Number of stored instances. */
  def size: Int = data.size

  /** Appends `ins`. */
  def add(ins: LabeledData): Unit = data += ins

  /** Returns the instance stored at position `i`. */
  def get(i: Int): LabeledData = data(i)

  /**
   * Returns the next instance and advances the cursor. After the last instance it wraps back to
   * the first. On an empty data set the underlying buffer access fails.
   */
  def loopingRead: LabeledData = {
    val pos = if (readIndex < size) readIndex else 0
    val next = data(pos)
    readIndex = pos + 1
    next
  }

  /** Operator alias for [[add]]. */
  def +=(ins: LabeledData): Unit = add(ins)

}
26 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/data/LabeledData.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.data
2 |
3 | import org.apache.spark.ml.linalg.Vector
4 |
/** A single sample: its label and its feature vector. */
case class LabeledData(label: Double, feature: Vector) {

  /** Renders as `(<label> <feature>)`. */
  override def toString: String = "(" + label + " " + feature + ")"
}
9 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/data/Parser.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.data
2 |
3 | import org.apache.spark.SparkContext
4 | import org.apache.spark.ml.linalg.Vectors
5 | import org.apache.spark.rdd.RDD
6 | import org.dma.sketchml.ml.common.Constants
7 | import org.dma.sketchml.ml.util.Maths
8 |
/** Parsers turning text lines into [[LabeledData]]; blank lines are skipped. */
object Parser {

  /**
   * Reads `input` with `SparkContext.textFile` and parses each line with the parser for `format`.
   *
   * @param input        path handed to `sc.textFile`
   * @param format       `Constants.FORMAT_LIBSVM`, `FORMAT_CSV` or `FORMAT_DUMMY`
   * @param maxDim       size of the sparse vectors built by the libsvm / dummy parsers
   * @param numPartition number of partitions of the returned RDD
   * @param negY         when true, every label not equal to 1 (within `Maths.EPS`) becomes -1
   * @param sc           Spark context used to read the input
   * @return the parsed samples; blank lines are dropped instead of yielding `null`
   * @throws IllegalArgumentException if `format` is not recognized
   */
  def loadData(input: String, format: String, maxDim: Int, numPartition: Int,
               negY: Boolean = true)(implicit sc: SparkContext): RDD[LabeledData] = {
    val parse: (String, Int, Boolean) => LabeledData = format match {
      case Constants.FORMAT_LIBSVM => Parser.parseLibSVM
      case Constants.FORMAT_CSV => Parser.parseCSV
      case Constants.FORMAT_DUMMY => Parser.parseDummy
      case _ => throw new IllegalArgumentException("Unknown file format: " + format)
    }
    sc.textFile(input)
      .map(line => parse(line, maxDim, negY))
      // The parsers return null for blank lines; never let nulls reach the training sets.
      .filter(_ != null)
      .repartition(numPartition)
  }

  /**
   * Parses `label idx:value idx:value ...` into a sparse vector of size `maxDim`.
   * Tokens may be separated by any run of whitespace. Returns null for a blank line.
   * NOTE(review): indices are used as given. Standard LibSVM files are 1-based; confirm the input
   * is 0-based, or that maxDim leaves room for the largest index.
   */
  def parseLibSVM(line: String, maxDim: Int, negY: Boolean = true): LabeledData =
    tokenize(line, "\\s+").map { splits =>
      val y = parseLabel(splits(0), negY)
      val nnz = splits.length - 1
      val indices = new Array[Int](nnz)
      val values = new Array[Double](nnz)
      for (i <- 0 until nnz) {
        val kv = splits(i + 1).split(":")
        indices(i) = kv(0).toInt
        values(i) = kv(1).toDouble
      }
      LabeledData(y, Vectors.sparse(maxDim, indices, values))
    }.orNull

  /**
   * Parses `label,v1,v2,...` into a dense vector holding every value after the label.
   * `maxDim` is not used: the vector length equals the number of columns.
   * Returns null for a blank line.
   */
  def parseCSV(line: String, maxDim: Int, negY: Boolean = true): LabeledData =
    tokenize(line, ",").map { splits =>
      val y = parseLabel(splits(0), negY)
      val values = splits.drop(1).map(_.trim.toDouble)
      LabeledData(y, Vectors.dense(values))
    }.orNull

  /**
   * Parses `label,idx,idx,...` into a sparse binary vector of size `maxDim`: value 1.0 at every
   * listed index. Returns null for a blank line.
   */
  def parseDummy(line: String, maxDim: Int, negY: Boolean = true): LabeledData =
    tokenize(line, ",").map { splits =>
      val y = parseLabel(splits(0), negY)
      val indices = splits.drop(1).map(_.trim.toInt)
      val values = Array.fill(indices.length)(1.0)
      LabeledData(y, Vectors.sparse(maxDim, indices, values))
    }.orNull

  /**
   * Trims `line` and splits it on the regex `sep`. Returns None when nothing is left to parse.
   * This replaces the old `splits.length < 1` guard, which never fired: `"".split(sep)` returns
   * `Array("")`.
   */
  private def tokenize(line: String, sep: String): Option[Array[String]] = {
    val trimmed = line.trim
    if (trimmed.isEmpty) None
    else Option(trimmed.split(sep)).filter(_.nonEmpty)
  }

  /** Parses a label token. With `negY`, any label other than 1 (within `Maths.EPS`) maps to -1. */
  private def parseLabel(token: String, negY: Boolean): Double = {
    val y = token.trim.toDouble
    if (negY && Math.abs(y - 1) > Maths.EPS) -1.0 else y
  }

}
79 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/DenseDoubleGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
4 | import org.dma.sketchml.ml.gradient.Kind.Kind
5 | import org.dma.sketchml.ml.util.Maths
6 |
/**
 * Gradient stored densely, one `Double` per dimension.
 *
 * @param d      dimension of the gradient
 * @param values backing array; entries `0 until d` are read and updated in place
 */
class DenseDoubleGradient(d: Int, val values: Array[Double]) extends Gradient(d) {
  /** Creates an all-zero gradient of dimension `d`. */
  def this(d: Int) = this(d, new Array[Double](d))

  /** Element-wise in-place addition of another dense double gradient. */
  override def plusBy(dense: DenseDoubleGradient): Gradient = {
    val src = dense.values
    val n = dim
    var i = 0
    while (i < n) {
      values(i) += src(i)
      i += 1
    }
    this
  }

  /** Adds each (index, value) pair of a sparse double gradient, in place. */
  override def plusBy(sparse: SparseDoubleGradient): Gradient = {
    val idx = sparse.indices
    val vals = sparse.values
    var j = 0
    while (j < idx.length) {
      values(idx(j)) += vals(j)
      j += 1
    }
    this
  }

  /** Element-wise in-place addition of a dense float gradient (widened to Double). */
  override def plusBy(dense: DenseFloatGradient): Gradient = {
    val src = dense.values
    val n = dim
    var i = 0
    while (i < n) {
      values(i) += src(i)
      i += 1
    }
    this
  }

  /** Adds each (index, value) pair of a sparse float gradient (widened to Double), in place. */
  override def plusBy(sparse: SparseFloatGradient): Gradient = {
    val idx = sparse.indices
    val vals = sparse.values
    var j = 0
    while (j < idx.length) {
      values(idx(j)) += vals(j)
      j += 1
    }
    this
  }

  // Compressed gradients are first decoded with `toAuto`, then added via the generic overload.
  override def plusBy(sketchGrad: SketchGradient): Gradient = plusBy(sketchGrad.toAuto)

  override def plusBy(fpGrad: FixedPointGradient): Gradient = plusBy(fpGrad.toAuto)

  override def plusBy(zipGrad: ZipGradient): Gradient = plusBy(zipGrad.toAuto)

  /** In-place `this += x * dense`. */
  override def plusBy(dense: DenseVector, x: Double): Gradient = {
    val src = dense.values
    val n = dim
    var i = 0
    while (i < n) {
      values(i) += src(i) * x
      i += 1
    }
    this
  }

  /** In-place `this += x * sparse`, touching only the vector's stored indices. */
  override def plusBy(sparse: SparseVector, x: Double): Gradient = {
    val idx = sparse.indices
    val vals = sparse.values
    var j = 0
    while (j < idx.length) {
      values(idx(j)) += vals(j) * x
      j += 1
    }
    this
  }

  /** Scales every component by `x`, in place. */
  override def timesBy(x: Double): Unit = {
    val n = dim
    var i = 0
    while (i < n) {
      values(i) *= x
      i += 1
    }
  }

  /** Number of components whose magnitude exceeds `Maths.EPS`. */
  override def countNNZ: Int =
    (0 until dim).count(i => Math.abs(values(i)) > Maths.EPS)

  override def toDense: DenseDoubleGradient = this

  override def toSparse: SparseDoubleGradient = toSparse(countNNZ)

  /** Copies the first `nnz` components with magnitude above `Maths.EPS` into a sparse gradient. */
  private def toSparse(nnz: Int): SparseDoubleGradient = {
    val keys = new Array[Int](nnz)
    val vals = new Array[Double](nnz)
    var pos = 0
    var filled = 0
    while (pos < dim && filled < nnz) {
      val v = values(pos)
      if (Math.abs(v) > Maths.EPS) {
        keys(filled) = pos
        vals(filled) = v
        filled += 1
      }
      pos += 1
    }
    new SparseDoubleGradient(dim, keys, vals)
  }

  /** Stays dense when more than two thirds of the components are non-zero; otherwise goes sparse. */
  override def toAuto: Gradient = {
    val nnz = countNNZ
    val denseThreshold = dim * 2 / 3
    if (nnz <= denseThreshold) toSparse(nnz) else toDense
  }

  override def kind: Kind = Kind.DenseDouble

}
100 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/DenseFloatGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
4 | import org.dma.sketchml.ml.gradient.Kind.Kind
5 | import org.dma.sketchml.ml.util.Maths
6 | import org.dma.sketchml.sketch.base.SketchMLException
7 |
/**
 * Dense gradient whose entries are stored as single-precision floats.
 *
 * All mutating operations (`plusBy`, `timesBy`) update `values` in place and return `this`.
 *
 * @param d      dimension of the gradient
 * @param values one entry per dimension
 */
class DenseFloatGradient(d: Int, val values: Array[Float]) extends Gradient(d) {
  /** Creates an all-zero gradient of dimension `d`. */
  def this(d: Int) = this(d, new Array[Float](d))

  /**
   * Creates a single-precision copy of a dense or sparse double-precision gradient.
   *
   * @throws SketchMLException if `grad` is of any other kind
   */
  def this(grad: Gradient) {
    this(grad.dim, new Array[Float](grad.dim))
    grad.kind match {
      case Kind.DenseDouble => fromDense(grad.asInstanceOf[DenseDoubleGradient])
      case Kind.SparseDouble => fromSparse(grad.asInstanceOf[SparseDoubleGradient])
      case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}")
    }
  }

  /** Overwrites the first `dim` entries with the (narrowed) values of `dense`. */
  def fromDense(dense: DenseDoubleGradient): Unit = {
    val dv = dense.values
    for (i <- 0 until dim)
      values(i) = dv(i).toFloat
  }

  /** Writes the (narrowed) non-zero entries of `sparse` at their indices; other entries are untouched. */
  def fromSparse(sparse: SparseDoubleGradient): Unit = {
    val k = sparse.indices
    val v = sparse.values
    for (i <- k.indices)
      values(k(i)) = v(i).toFloat
  }

  override def plusBy(dense: DenseDoubleGradient): Gradient = {
    for (i <- 0 until dim)
      values(i) += dense.values(i).toFloat
    this
  }

  override def plusBy(sparse: SparseDoubleGradient): Gradient = {
    val k = sparse.indices
    val v = sparse.values
    for (i <- k.indices)
      values(k(i)) += v(i).toFloat
    this
  }

  override def plusBy(dense: DenseFloatGradient): Gradient = {
    for (i <- 0 until dim)
      values(i) += dense.values(i)
    this
  }

  override def plusBy(sparse: SparseFloatGradient): Gradient = {
    val k = sparse.indices
    val v = sparse.values
    for (i <- k.indices)
      values(k(i)) += v(i)
    this
  }

  // Compressed gradients are first decoded to a dense/sparse double gradient.
  override def plusBy(sketchGrad: SketchGradient): Gradient = plusBy(sketchGrad.toAuto)

  override def plusBy(fpGrad: FixedPointGradient): Gradient = plusBy(fpGrad.toAuto)

  override def plusBy(zipGrad: ZipGradient): Gradient = plusBy(zipGrad.toAuto)

  /** Adds `x * dense` to this gradient (computed in single precision). */
  override def plusBy(dense: DenseVector, x: Double): Gradient = {
    val v = dense.values
    val x_ = x.toFloat
    for (i <- 0 until dim)
      values(i) += v(i).toFloat * x_
    this
  }

  /** Adds `x * sparse` to this gradient (computed in single precision). */
  override def plusBy(sparse: SparseVector, x: Double): Gradient = {
    val k = sparse.indices
    val v = sparse.values
    val x_ = x.toFloat
    for (i <- k.indices)
      values(k(i)) += v(i).toFloat * x_
    this
  }

  override def timesBy(x: Double): Unit = {
    val x_ = x.toFloat
    for (i <- 0 until dim)
      values(i) *= x_
  }

  /** Counts entries whose magnitude exceeds `Maths.EPS`. */
  override def countNNZ: Int = {
    var nnz = 0
    for (i <- 0 until dim)
      if (Math.abs(values(i)) > Maths.EPS)
        nnz += 1
    nnz
  }

  override def toDense: DenseDoubleGradient =
    new DenseDoubleGradient(dim, values.map(_.toDouble))

  override def toSparse: SparseDoubleGradient = toSparse(countNNZ)

  /** Builds a sparse copy holding the first `nnz` entries whose magnitude exceeds `Maths.EPS`. */
  private def toSparse(nnz: Int): SparseDoubleGradient = {
    val k = new Array[Int](nnz)
    val v = new Array[Double](nnz)
    var i = 0
    var j = 0
    while (i < dim && j < nnz) {
      if (Math.abs(values(i)) > Maths.EPS) {
        k(j) = i
        v(j) = values(i)
        j += 1
      }
      i += 1
    }
    new SparseDoubleGradient(dim, k, v)
  }

  /** Dense when more than two thirds of the entries are non-zero, sparse otherwise. */
  override def toAuto: Gradient = {
    val nnz = countNNZ
    // Widen to Long: `dim * 2` overflows Int for dim > Int.MaxValue / 2.
    if (nnz > 2L * dim / 3) toDense else toSparse(nnz)
  }

  override def kind: Kind = Kind.DenseFloat

}
126 | }
127 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/FixedPointGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 | import java.util
3 |
4 | import breeze.stats.distributions.Bernoulli
5 | import org.dma.sketchml.ml.gradient.Kind.Kind
6 | import org.dma.sketchml.sketch.base.SketchMLException
7 | import org.dma.sketchml.sketch.binary.BinaryUtils
8 |
object FixedPointGradient {
  // Fair coin used to randomly round each quantized magnitude up by one step.
  private val bernoulli = new Bernoulli(0.5)
}

/**
 * Gradient compressed to `numBits`-bit fixed-point codes: the top bit stores the sign and the
 * remaining `numBits - 1` bits store |v| / norm scaled to [0, 2^(numBits-1) - 1], where `norm`
 * is the L2 norm of the encoded values.
 *
 * @param d       dimension of the gradient
 * @param numBits bits per encoded value; must be in [2, 29]
 */
class FixedPointGradient(d: Int, val numBits: Int) extends Gradient(d) {
  import FixedPointGradient._

  // numBits == 1 leaves no magnitude bits (max == 0), which decodes to NaN.
  require(numBits > 1 && numBits < 30, s"Bit num out of range: $numBits")

  /**
   * Encodes a dense or sparse double-precision gradient.
   *
   * @throws SketchMLException if `grad` is of any other kind
   */
  def this(grad: Gradient, numBits: Int) {
    this(grad.dim, numBits)
    grad.kind match {
      case Kind.DenseDouble => fromDense(grad.asInstanceOf[DenseDoubleGradient])
      case Kind.SparseDouble => fromSparse(grad.asInstanceOf[SparseDoubleGradient])
      case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}")
    }
  }

  // number of encoded values (dim when built from dense, nnz when built from sparse)
  var size: Int = _
  // L2 norm of the encoded values; multiplied in by `timesBy`
  var norm: Double = _
  // null when built from a dense gradient
  var indices: Array[Int] = _
  var bitset: util.BitSet = _

  def fromDense(dense: DenseDoubleGradient): Unit = fromArray(dense.values)

  def fromSparse(sparse: SparseDoubleGradient): Unit = {
    indices = sparse.indices
    fromArray(sparse.values)
  }

  private def fromArray(values: Array[Double]): Unit = {
    size = values.length
    norm = 0.0
    values.foreach(v => norm += v * v)
    norm = Math.sqrt(norm)
    bitset = new util.BitSet(numBits * size)
    val max = (1 << (numBits - 1)) - 1
    val sign = 1 << (numBits - 1)
    for (i <- values.indices) {
      val sigma = if (bernoulli.draw()) 1 else 0
      // |v| / norm <= 1, so the scaled magnitude can reach `max`; the random round-up could then
      // produce `max + 1 == sign`, which would flip the sign bit and erase the magnitude. Clamp it.
      var x = Math.min(Math.floor(Math.abs(values(i)) / norm * max).toInt + sigma, max)
      if (values(i) < 0) x |= sign
      BinaryUtils.setBits(bitset, i * numBits, x, numBits)
    }
  }

  override def timesBy(x: Double): Unit = norm *= x

  override def countNNZ: Int = size

  /** Decodes into a dense gradient, scattering values to their indices when built from sparse. */
  override def toDense: DenseDoubleGradient = {
    val decoded = toArray
    if (indices == null) {
      new DenseDoubleGradient(dim, decoded)
    } else {
      val dense = new Array[Double](dim)
      for (i <- indices.indices)
        dense(indices(i)) = decoded(i)
      new DenseDoubleGradient(dim, dense)
    }
  }

  /** Decodes into a sparse gradient; a dense-built gradient gets every index 0 until dim. */
  override def toSparse: SparseDoubleGradient = {
    val keys = if (indices != null) indices else (0 until dim).toArray
    new SparseDoubleGradient(dim, keys, toArray)
  }

  /** Decodes the `size` stored codes back into doubles. */
  private def toArray: Array[Double] = {
    val values = new Array[Double](size)
    val max = (1 << (numBits - 1)) - 1
    val mask = max
    val sign = 1 << (numBits - 1)
    for (i <- 0 until size) {
      val x = BinaryUtils.getBits(bitset, i * numBits, numBits)
      var v = (x & mask).toDouble / max * norm
      if ((x & sign) != 0) v = -v
      values(i) = v
    }
    values
  }

  override def toAuto: Gradient = (if (indices == null) toDense else toSparse).toAuto

  override def kind: Kind = Kind.FixedPoint
}
81 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/Gradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import javax.inject.Singleton
4 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
5 | import org.dma.sketchml.ml.common.Constants
6 | import org.dma.sketchml.ml.conf.MLConf
7 | import org.dma.sketchml.ml.gradient.Kind.Kind
8 | import org.dma.sketchml.ml.util.Maths
9 | import org.dma.sketchml.sketch.base.SketchMLException
10 | import org.dma.sketchml.sketch.util.Utils
11 | import org.slf4j.{Logger, LoggerFactory}
12 |
object Gradient {
  /** The shared zero gradient (identity for `plusBy`). */
  def zero: ZeroGradient = ZeroGradient.getInstance()

  // Created once instead of on every access.
  private val logger: Logger = LoggerFactory.getLogger(getClass)

  /**
   * Compresses `grad` with the compressor selected by `conf.compressor`.
   *
   * @throws SketchMLException if the compressor is unknown, or if the float compressor
   *                           receives a gradient that is neither dense nor sparse double
   */
  def compress(grad: Gradient, conf: MLConf): Gradient = {
    val startTime = System.currentTimeMillis()
    val res = conf.compressor match {
      case Constants.GRADIENT_COMPRESSOR_SKETCH =>
        new SketchGradient(grad, conf.quantBinNum, conf.sketchGroupNum,
          conf.sketchRowNum, conf.sketchColRatio)
      case Constants.GRADIENT_COMPRESSOR_FIXED_POINT =>
        new FixedPointGradient(grad, conf.fixedPointBitNum)
      case Constants.GRADIENT_COMPRESSOR_ZIP =>
        new ZipGradient(grad, conf.quantBinNum)
      case Constants.GRADIENT_COMPRESSOR_FLOAT =>
        grad.kind match {
          case Kind.DenseDouble => new DenseFloatGradient(grad)
          case Kind.SparseDouble => SparseFloatGradient(grad)
          // previously fell through to a MatchError
          case other => throw new SketchMLException(
            s"Cannot compress $other gradients with compressor ${conf.compressor}")
        }
      case Constants.GRADIENT_COMPRESSOR_NONE => grad
      case _ => throw new SketchMLException(
        "Unrecognizable compressor: " + conf.compressor)
    }
    logger.info(s"Gradient compression from ${grad.kind} to ${res.kind} cost " +
      s"${System.currentTimeMillis() - startTime} ms")
    // uncomment to evaluate the performance of compression
    //evaluateCompression(grad, res)
    res
  }

  /**
   * Sums `grads` into a fresh dense accumulator and returns its `toAuto` form.
   * Zero gradients are accepted regardless of their (placeholder) dimension.
   */
  def sum(dim: Int, grads: Array[Gradient]): Gradient = {
    require(grads.forall(g => g.kind == Kind.ZeroGradient || g.dim == dim),
      s"Cannot sum gradients whose dimension differs from $dim")
    val sum = new DenseDoubleGradient(dim)
    grads.foreach(sum.plusBy)
    sum.toAuto
  }

  /**
   * Logs euclidean/cosine distances and the serialized-size ratio between `origin` and its
   * compressed form. Only dense and sparse double origins are supported; others are skipped.
   */
  def evaluateCompression(origin: Gradient, comp: Gradient): Unit = {
    logger.info(s"Evaluating compression from ${origin.kind} to ${comp.kind}, " +
      s"sparsity[${origin.countNNZ.toDouble / origin.dim}]")
    val vectors: Option[(Array[Double], Array[Double])] = origin.kind match {
      case Kind.DenseDouble =>
        Some((origin.asInstanceOf[DenseDoubleGradient].values, comp.toDense.values))
      case Kind.SparseDouble =>
        Some((origin.asInstanceOf[SparseDoubleGradient].values, comp.toSparse.values))
      case other => None
    }
    vectors match {
      case None =>
        logger.warn(s"Cannot evaluate compression of ${origin.kind} gradients")
      case Some((vOrig, vComp)) =>
        // distances
        logger.info(s"Distances: euclidean[${Maths.euclidean(vOrig, vComp)}], " +
          s"cosine[${Maths.cosine(vOrig, vComp)}]")
        // size
        val sizeOrig = Utils.sizeof(origin)
        val sizeComp = Utils.sizeof(comp)
        val rate = 1.0 * sizeOrig / sizeComp
        logger.info(s"Sizeof gradients: nnz[${vOrig.length}], " +
          s"origin[$sizeOrig bytes], comp[$sizeComp bytes], rate[$rate]")
    }
  }
}
69 |
/**
 * Base class of all gradient representations.
 *
 * @param dim dimension of the gradient; must be positive
 */
abstract class Gradient(val dim: Int) extends Serializable {
  require(dim > 0, s"Dimension is non-positive: $dim")

  /**
   * Adds `o` to this gradient by dispatching on its runtime kind. Adding the zero gradient
   * returns `this` unchanged.
   *
   * @throws IllegalArgumentException      if the dimensions differ
   * @throws UnsupportedOperationException if this kind does not support adding `o`'s kind
   */
  def plusBy(o: Gradient): Gradient = {
    if (o.kind == Kind.ZeroGradient)
      this
    else {
      require(dim == o.dim, s"Adding gradients with " +
        s"different dimensions: $dim, ${o.dim}")
      o.kind match {
        case Kind.DenseDouble => plusBy(o.asInstanceOf[DenseDoubleGradient])
        case Kind.SparseDouble => plusBy(o.asInstanceOf[SparseDoubleGradient])
        case Kind.DenseFloat => plusBy(o.asInstanceOf[DenseFloatGradient])
        case Kind.SparseFloat => plusBy(o.asInstanceOf[SparseFloatGradient])
        case Kind.Sketch => plusBy(o.asInstanceOf[SketchGradient])
        case Kind.FixedPoint => plusBy(o.asInstanceOf[FixedPointGradient])
        case Kind.Zip => plusBy(o.asInstanceOf[ZipGradient])
        case _ => throw new ClassNotFoundException(o.getClass.getName)
      }
    }
  }

  // Per-kind additions; subclasses override those they support.

  def plusBy(dense: DenseDoubleGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${dense.kind} to ${this.kind}")

  def plusBy(sparse: SparseDoubleGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${sparse.kind} to ${this.kind}")

  def plusBy(dense: DenseFloatGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${dense.kind} to ${this.kind}")

  def plusBy(sparse: SparseFloatGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${sparse.kind} to ${this.kind}")

  def plusBy(sketchGrad: SketchGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${sketchGrad.kind} to ${this.kind}")

  def plusBy(fpGrad: FixedPointGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${fpGrad.kind} to ${this.kind}")

  def plusBy(zipGrad: ZipGradient): Gradient = throw new
      UnsupportedOperationException(s"Cannot add ${zipGrad.kind} to ${this.kind}")

  /** Adds `x * v` to this gradient, dispatching on the vector's storage. */
  def plusBy(v: Vector, x: Double): Gradient = {
    v match {
      case dense: DenseVector => plusBy(dense, x)
      case sparse: SparseVector => plusBy(sparse, x)
    }
  }

  def plusBy(dense: DenseVector, x: Double): Gradient = throw new
      UnsupportedOperationException(s"Cannot add DenseVector to ${this.kind}")

  def plusBy(sparse: SparseVector, x: Double): Gradient = throw new
      UnsupportedOperationException(s"Cannot add SparseVector to ${this.kind}")

  /** Scales this gradient in place by `x`. */
  def timesBy(x: Double): Unit

  /** Number of (non-zero) entries held by this gradient. */
  def countNNZ: Int

  def toDense: DenseDoubleGradient

  def toSparse: SparseDoubleGradient

  /** Converts to whichever double-precision representation the implementation deems cheaper. */
  def toAuto: Gradient

  def kind: Kind

  /** Alias of `plusBy(o: Gradient)`. */
  def +=(o: Gradient): Gradient = plusBy(o)

}
141 |
/**
 * Holder of the one shared [[ZeroGradient]] value; adding it to any gradient yields that
 * gradient unchanged.
 */
@Singleton
object ZeroGradient {
  private val sharedZero: ZeroGradient = new ZeroGradient()

  /** Returns the shared zero gradient. */
  def getInstance(): ZeroGradient = sharedZero
}
151 |
/**
 * The additive identity among gradients: every `plusBy` returns its argument.
 * Its dimension (1) is only a placeholder, so it has no dense/sparse materialization.
 */
class ZeroGradient private extends Gradient(1) {
  override def plusBy(o: Gradient): Gradient = o

  override def timesBy(x: Double): Unit = {}

  override def countNNZ: Int = 0

  // Explicit, descriptive failures instead of `???` (scala.NotImplementedError).
  override def toDense: DenseDoubleGradient = throw new UnsupportedOperationException(
    "ZeroGradient has no real dimension and cannot be converted to a dense gradient")

  override def toSparse: SparseDoubleGradient = throw new UnsupportedOperationException(
    "ZeroGradient has no real dimension and cannot be converted to a sparse gradient")

  /** The zero gradient is already in its most compact form. */
  override def toAuto: Gradient = this

  override def kind: Kind = Kind.ZeroGradient

  override def plusBy(dense: DenseDoubleGradient): Gradient = dense

  override def plusBy(sparse: SparseDoubleGradient): Gradient = sparse

  override def plusBy(dense: DenseFloatGradient): Gradient = dense

  override def plusBy(sparse: SparseFloatGradient): Gradient = sparse

  override def plusBy(sketchGrad: SketchGradient): Gradient = sketchGrad

  override def plusBy(fpGrad: FixedPointGradient): Gradient = fpGrad

  override def plusBy(zipGrad: ZipGradient): Gradient = zipGrad

}
182 |
183 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/Kind.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
/** Tags identifying the concrete representation of a gradient. */
object Kind extends Enumeration {
  type Kind = Value

  // Declaration order determines each value's `id`; do not reorder.
  val ZeroGradient = Value
  val DenseDouble = Value
  val SparseDouble = Value
  val DenseFloat = Value
  val SparseFloat = Value
  val Sketch = Value
  val FixedPoint = Value
  val Zip = Value
}
7 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/SketchGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import org.dma.sketchml.ml.gradient.Kind.Kind
4 | import org.dma.sketchml.sketch.base.SketchMLException
5 | import org.dma.sketchml.sketch.quantization.QuantileQuantizer
6 | import org.dma.sketchml.sketch.sketch.frequency.GroupedMinMaxSketch
7 |
/**
 * Gradient compressed by quantile quantization. A dense source keeps one bin index per
 * dimension in `bins`; a sparse source stores (index, bin) pairs in a `GroupedMinMaxSketch`.
 * Exactly one of `bins` / `sketch` is non-null after construction from a gradient.
 */
class SketchGradient(d: Int, binNum: Int, groupNum: Int, rowNum: Int, colRatio: Double) extends Gradient(d) {

  /**
   * Compresses a dense or sparse double-precision gradient.
   *
   * @throws SketchMLException if `grad` is of any other kind
   */
  def this(grad: Gradient, binNum: Int, groupNum: Int, rowNum: Int, colRatio: Double) {
    this(grad.dim, binNum, groupNum, rowNum, colRatio)
    grad.kind match {
      case Kind.DenseDouble => fromDense(grad.asInstanceOf[DenseDoubleGradient])
      case Kind.SparseDouble => fromSparse(grad.asInstanceOf[SparseDoubleGradient])
      case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}")
    }
  }

  private var nnz: Int = 0
  var bucketValues: Array[Double] = _
  var bins: Array[Int] = _
  var sketch: GroupedMinMaxSketch = _

  def fromDense(dense: DenseDoubleGradient): Unit = {
    val values = dense.values
    val quantizer = new QuantileQuantizer(binNum)
    quantizer.quantize(values)
    //quantizer.parallelQuantize(values)
    bucketValues = quantizer.getValues
    bins = quantizer.getBins
    sketch = null
    nnz = dim
  }

  def fromSparse(sparse: SparseDoubleGradient): Unit = {
    // 1. quantize into bin indexes
    val quantizer = new QuantileQuantizer(binNum)
    quantizer.quantize(sparse.values)
    //quantizer.parallelQuantize(sparse.values)
    bucketValues = quantizer.getValues
    // 2. encode bins and keys
    sketch = new GroupedMinMaxSketch(groupNum, rowNum, colRatio, quantizer.getBinNum, quantizer.getZeroIdx)
    sketch.create(sparse.indices, quantizer.getBins)
    bins = null
    //sketch.parallelCreate(sparse.indices, quantizer.getBins)
    // 3. set nnz
    nnz = sparse.indices.length
  }

  /** Scaling only touches the bucket representatives. */
  override def timesBy(x: Double): Unit = {
    for (i <- bucketValues.indices)
      bucketValues(i) *= x
  }

  override def countNNZ: Int = nnz

  /** Decodes to dense; a sketch-backed (sparse-built) gradient is restored first, then densified. */
  override def toDense: DenseDoubleGradient =
    if (bins != null) new DenseDoubleGradient(dim, bins.map(bin => bucketValues(bin)))
    else toSparse.toDense

  /** Decodes to sparse; a dense-built gradient gets every index 0 until dim. */
  override def toSparse: SparseDoubleGradient =
    if (sketch != null) {
      val kb = sketch.restore()
      val indices = kb.getLeft
      val restoredBins = kb.getRight
      new SparseDoubleGradient(dim, indices, restoredBins.map(bin => bucketValues(bin)))
    } else {
      new SparseDoubleGradient(dim, (0 until dim).toArray, bins.map(bin => bucketValues(bin)))
    }

  override def toAuto: Gradient = (if (bins != null) toDense else toSparse).toAuto

  override def kind: Kind = Kind.Sketch
}
74 |
75 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/SparseDoubleGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import org.dma.sketchml.ml.gradient.Kind.Kind
4 | import org.dma.sketchml.ml.util.Maths
5 |
/**
 * Sparse double-precision gradient.
 *
 * @param d       dimension of the gradient
 * @param indices strictly increasing indices in [0, d); may be empty for an all-zero gradient
 * @param values  value of each index, same length as `indices`
 */
class SparseDoubleGradient(d: Int, val indices: Array[Int],
                           val values: Array[Double]) extends Gradient(d) {
  {
    require(indices.length == values.length,
      s"Sizes of indices and values not match: ${indices.length} & ${values.length}")
    // An all-zero gradient (e.g. a dense gradient with nnz == 0 converted to sparse) has no
    // entries; `head`/`last` would throw NoSuchElementException on it.
    if (indices.nonEmpty) {
      require(indices.head >= 0, s"Negative index: ${indices.head}.")
      for (i <- 1 until indices.length)
        require(indices(i - 1) < indices(i), s"Indices are not strictly increasing")
      require(indices.last < dim, s"Index ${indices.last} out of bounds for gradient of dimension $dim")
    }
  }

  //override def plusBy(sparse: SparseDoubleGradient): Gradient = {
  //  val kv = Maths.add(this.indices, this.values, sparse.indices, sparse.values)
  //  new SparseDoubleGradient(dim, kv._1, kv._2)
  //}

  override def timesBy(x: Double): Unit = {
    for (i <- values.indices)
      values(i) *= x
  }

  /** Counts stored values whose magnitude exceeds `Maths.EPS`. */
  override def countNNZ: Int = {
    var nnz = 0
    for (i <- values.indices)
      if (Math.abs(values(i)) > Maths.EPS)
        nnz += 1
    nnz
  }

  override def toDense: DenseDoubleGradient = {
    val dense = new Array[Double](dim)
    for (i <- values.indices)
      if (Math.abs(values(i)) > Maths.EPS)
        dense(indices(i)) = values(i)
    new DenseDoubleGradient(dim, dense)
  }

  override def toSparse: SparseDoubleGradient = this

  /** Dense when more than two thirds of the entries are non-zero, otherwise `this`. */
  override def toAuto: Gradient = {
    val nnz = countNNZ
    // Widen to Long: `dim * 2` overflows Int for dim > Int.MaxValue / 2.
    if (nnz > 2L * dim / 3) toDense else toSparse
  }

  override def kind: Kind = Kind.SparseDouble
}
52 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/SparseFloatGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import org.dma.sketchml.ml.gradient.Kind.Kind
4 | import org.dma.sketchml.ml.util.Maths
5 | import org.dma.sketchml.sketch.base.SketchMLException
6 |
object SparseFloatGradient {
  /**
   * Builds a single-precision sparse copy of a dense or sparse double gradient. A dense source
   * keeps every index (including zero entries).
   *
   * @throws SketchMLException if `grad` is of any other kind
   */
  def apply(grad: Gradient): SparseFloatGradient = {
    val (indices, values) = grad.kind match {
      case Kind.DenseDouble => {
        val dense = grad.asInstanceOf[DenseDoubleGradient]
        ((0 until dense.dim).toArray, dense.values)
      }
      case Kind.SparseDouble => {
        val sparse = grad.asInstanceOf[SparseDoubleGradient]
        (sparse.indices, sparse.values)
      }
      case _ => throw new SketchMLException(s"Cannot create ${Kind.SparseFloat} from ${grad.kind}")
    }
    new SparseFloatGradient(grad.dim, indices, values.map(_.toFloat))
  }
}

/**
 * Sparse single-precision gradient.
 *
 * @param d       dimension of the gradient
 * @param indices strictly increasing indices in [0, d); may be empty for an all-zero gradient
 * @param values  value of each index, same length as `indices`
 */
class SparseFloatGradient(d: Int, val indices: Array[Int],
                          val values: Array[Float]) extends Gradient(d) {
  {
    require(indices.length == values.length,
      s"Sizes of indices and values not match: ${indices.length} & ${values.length}")
    // `head`/`last` throw on an empty (all-zero) gradient, so only validate when non-empty.
    if (indices.nonEmpty) {
      require(indices.head >= 0, s"Negative index: ${indices.head}.")
      for (i <- 1 until indices.length)
        require(indices(i - 1) < indices(i), s"Indices are not strictly increasing")
      require(indices.last < dim, s"Index ${indices.last} out of bounds for gradient of dimension $dim")
    }
  }

  override def timesBy(x: Double): Unit = {
    val x_ = x.toFloat
    for (i <- values.indices)
      values(i) *= x_
  }

  /** Counts stored values whose magnitude exceeds `Maths.EPS`. */
  override def countNNZ: Int = {
    var nnz = 0
    for (i <- values.indices)
      if (Math.abs(values(i)) > Maths.EPS)
        nnz += 1
    nnz
  }

  override def toDense: DenseDoubleGradient = {
    val dense = new Array[Double](dim)
    for (i <- values.indices)
      if (Math.abs(values(i)) > Maths.EPS)
        dense(indices(i)) = values(i)
    new DenseDoubleGradient(dim, dense)
  }

  override def toSparse: SparseDoubleGradient =
    new SparseDoubleGradient(dim, indices, values.map(_.toDouble))

  /** Dense when more than two thirds of the entries are non-zero, sparse otherwise. */
  override def toAuto: Gradient = {
    val nnz = countNNZ
    // Widen to Long: `dim * 2` overflows Int for dim > Int.MaxValue / 2.
    if (nnz > 2L * dim / 3) toDense else toSparse
  }

  override def kind: Kind = Kind.SparseFloat
}
67 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/gradient/ZipGradient.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.gradient
2 |
3 | import java.util
4 | import java.util.concurrent.ExecutionException
5 |
6 | import org.dma.sketchml.ml.gradient.Kind.Kind
7 | import org.dma.sketchml.sketch.base.{Quantizer, SketchMLException}
8 | import org.dma.sketchml.sketch.util.Sort
9 | import org.slf4j.{Logger, LoggerFactory}
10 |
/**
 * Gradient compressed with [[ZipMLQuantizer]]: each value is replaced by its bin's representative.
 * `indices` is null when built from a dense gradient.
 */
class ZipGradient(d: Int, binNum: Int) extends Gradient(d) {
  private var size: Int = 0
  var indices: Array[Int] = _
  var quantizer: ZipMLQuantizer = _

  /**
   * Compresses a dense or sparse double-precision gradient.
   *
   * @throws SketchMLException if `grad` is of any other kind
   */
  def this(grad: Gradient, binNum: Int) {
    this(grad.dim, binNum)
    grad.kind match {
      case Kind.DenseDouble => fromDense(grad.asInstanceOf[DenseDoubleGradient])
      case Kind.SparseDouble => fromSparse(grad.asInstanceOf[SparseDoubleGradient])
      case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}")
    }
  }

  def fromDense(dense: DenseDoubleGradient): Unit = fromArray(dense.values)

  def fromSparse(sparse: SparseDoubleGradient): Unit = {
    indices = sparse.indices
    fromArray(sparse.values)
  }

  private def fromArray(values: Array[Double]): Unit = {
    size = values.length
    quantizer = new ZipMLQuantizer(binNum)
    quantizer.quantize(values)
    //quantizer.parallelQuantize(values)
  }

  override def timesBy(x: Double): Unit = quantizer.timesBy(x)

  override def countNNZ: Int = size

  /** Decodes to dense, scattering values to their indices when built from sparse. */
  override def toDense: DenseDoubleGradient = {
    val decoded = decode
    if (indices == null) {
      new DenseDoubleGradient(dim, decoded)
    } else {
      val dense = new Array[Double](dim)
      for (i <- indices.indices)
        dense(indices(i)) = decoded(i)
      new DenseDoubleGradient(dim, dense)
    }
  }

  /** Decodes to sparse; a dense-built gradient gets every index 0 until dim. */
  override def toSparse: SparseDoubleGradient = {
    val keys = if (indices != null) indices else (0 until dim).toArray
    new SparseDoubleGradient(dim, keys, decode)
  }

  /** Maps every stored bin index back to its bucket value. */
  private def decode: Array[Double] = {
    val bucketValues = quantizer.getValues
    quantizer.getBins.map(bin => bucketValues(bin))
  }

  override def toAuto: Gradient = (if (indices == null) toDense else toSparse).toAuto

  override def kind: Kind = Kind.Zip
}
59 |
60 |
object ZipMLQuantizer {
  private val logger: Logger = LoggerFactory.getLogger(classOf[ZipMLQuantizer])
}

/**
 * Quantizer that greedily merges adjacent buckets of the sorted values. Each round merges
 * neighbouring bucket pairs whose merge would add the least squared error, until at most
 * `binNum` buckets remain.
 */
class ZipMLQuantizer(b: Int) extends Quantizer(b) {
  import ZipMLQuantizer._

  def this() = this(Quantizer.DEFAULT_BIN_NUM)

  /**
   * Computes split points for `values` and assigns each value to a bin.
   *
   * @throws IllegalArgumentException if `values` is empty or `binNum` is not positive
   */
  override def quantize(values: Array[Double]): Unit = {
    require(values.nonEmpty, "Cannot quantize an empty array")
    require(binNum > 0, s"Bin num must be positive: $binNum")
    val startTime = System.currentTimeMillis
    n = values.length
    // 1. prefix sums of the sorted values and of their squares, so that the squared error of
    //    any contiguous segment can be computed in O(1)
    val sortedValues = values.clone
    util.Arrays.sort(sortedValues)
    val r = new Array[Double](n)
    val t = new Array[Double](n)
    r(0) = sortedValues(0)
    t(0) = sortedValues(0) * sortedValues(0)
    for (i <- 1 until n) {
      r(i) = r(i - 1) + sortedValues(i)
      t(i) = t(i - 1) + sortedValues(i) * sortedValues(i)
    }
    // 2. find split points: start with one bucket per value, then repeatedly merge pairs
    var splitNum = n
    val splitIndex = (0 until n).toArray
    while (splitNum > binNum) {
      val pairNum = splitNum / 2
      val odd = splitNum % 2 == 1
      // squared error of the segment that would result from merging pair i
      val errors = new Array[Double](pairNum)
      for (i <- 0 until pairNum) {
        val L = splitIndex(2 * i)
        val R = (if (2 * i + 2 >= splitNum) n else splitIndex(2 * i + 2)) - 1
        val l1Sum = r(R) - (if (L == 0) 0 else r(L - 1))
        val l2Sum = t(R) - (if (L == 0) 0 else t(L - 1))
        val mean = l1Sum / (R - L + 1)
        errors(i) = l2Sum + mean * mean * (R - L + 1) - 2 * mean * l1Sum
      }

      // Pairs left unmerged; in the odd case the trailing unpaired bucket is always kept too.
      val keepNum = Math.max(0, binNum / 2 - (if (odd) 1 else 0))
      val keep = selectPairsToKeep(errors, keepNum)

      // Compact in place: reads at 2i / 2i+1 always happen at or after the write position.
      var newSplitNum = 0
      for (i <- 0 until pairNum) {
        splitIndex(newSplitNum) = splitIndex(2 * i); newSplitNum += 1
        if (keep(i)) {
          splitIndex(newSplitNum) = splitIndex(2 * i + 1); newSplitNum += 1
        }
      }

      if (odd) {
        // previously incremented `splitNum`, which dropped this trailing split point
        splitIndex(newSplitNum) = splitIndex(splitNum - 1); newSplitNum += 1
      }
      splitNum = newSplitNum
    }

    min = sortedValues(0)
    max = sortedValues(n - 1)
    binNum = splitNum
    splits = new Array[Double](binNum - 1)
    for (i <- 0 until binNum - 1)
      splits(i) = sortedValues(splitIndex(i + 1))
    // 3. find the zero index
    findZeroIdx()
    // 4. find index of each value
    quantizeToBins(values)
    logger.debug(s"ZipML quantization for $n items cost " +
      s"${System.currentTimeMillis - startTime} ms")
  }

  /**
   * Marks at most `keepNum` pairs (those with the largest merge errors) to stay split.
   * Ties at the threshold are broken by position, so the count never exceeds `keepNum`.
   * Keeping every tied pair (the former `>= threshold` rule) could stop all merging when
   * many errors are equal, so the outer loop never terminated.
   */
  private def selectPairsToKeep(errors: Array[Double], keepNum: Int): Array[Boolean] = {
    val keep = new Array[Boolean](errors.length)
    if (keepNum >= errors.length) {
      util.Arrays.fill(keep, true)
    } else if (keepNum > 0) {
      val threshold = Sort.selectKthLargest(errors.clone, keepNum)
      var budget = keepNum
      var i = 0
      // strictly larger errors first
      while (budget > 0 && i < errors.length) {
        if (errors(i) > threshold) { keep(i) = true; budget -= 1 }
        i += 1
      }
      // then fill the remaining budget with errors equal to the threshold
      i = 0
      while (budget > 0 && i < errors.length) {
        if (!keep(i) && errors(i) == threshold) { keep(i) = true; budget -= 1 }
        i += 1
      }
    }
    keep
  }

  @throws[InterruptedException]
  @throws[ExecutionException]
  override def parallelQuantize(values: Array[Double]): Unit = {
    logger.warn(s"ZipML quantization should be sequential")
    quantize(values)
  }

  override def quantizationType: Quantizer.QuantizationType = ???
}
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/objective/Adam.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.objective
2 |
3 | import org.apache.spark.ml.linalg.DenseVector
4 | import org.dma.sketchml.ml.conf.MLConf
5 | import org.dma.sketchml.ml.gradient._
6 | import org.dma.sketchml.ml.util.Maths
7 | import org.slf4j.{Logger, LoggerFactory}
8 |
object Adam {
  private val logger: Logger = LoggerFactory.getLogger(Adam.getClass)

  /** Builds an Adam optimizer from the feature count and learning-rate settings in `conf`. */
  def apply(conf: MLConf): GradientDescent =
    new Adam(dim = conf.featureNum, lr_0 = conf.learnRate,
      decay = conf.learnDecay, batchSpRatio = conf.batchSpRatio)
}

/**
 * Adam optimizer. The bias-correction powers `beta1_t` / `beta2_t` advance only when
 * `epoch > 0 && batch == 0`, and every step uses `lr_0` directly. For sparse gradients only
 * the listed coordinates' moments are updated.
 */
class Adam(dim: Int, lr_0: Double, decay: Double, batchSpRatio: Double)
  extends GradientDescent(dim, lr_0, decay, batchSpRatio) {
  override protected val logger = Adam.logger

  val beta1 = 0.9
  val beta2 = 0.999
  var beta1_t = 0.9
  var beta2_t = 0.999
  // first and second moment estimates, one per dimension
  val m = new Array[Double](dim)
  val v = new Array[Double](dim)

  override def update(grad: Gradient, weight: DenseVector): Unit = {
    val startTime = System.currentTimeMillis()
    val advanceBiasCorrection = epoch > 0 && batch == 0
    if (advanceBiasCorrection) {
      beta1_t *= beta1
      beta2_t *= beta2
    }
    update0(grad, weight)
    logger.info(s"Update weight cost ${System.currentTimeMillis() - startTime} ms")
  }

  /** Applies `grad` to `weight`, decoding compressed kinds first. */
  private def update0(grad: Gradient, weight: DenseVector): Unit = {
    val w = weight.values
    grad match {
      case dense: DenseDoubleGradient =>
        val g = dense.values
        for (i <- w.indices) adamStep(w, i, g(i))
      case sparse: SparseDoubleGradient =>
        val keys = sparse.indices
        val g = sparse.values
        for (j <- keys.indices) adamStep(w, keys(j), g(j))
      case dense: DenseFloatGradient =>
        val g = dense.values
        for (i <- w.indices) adamStep(w, i, g(i).toDouble)
      case sparse: SparseFloatGradient =>
        val keys = sparse.indices
        val g = sparse.values
        for (j <- keys.indices) adamStep(w, keys(j), g(j).toDouble)
      case sketchGrad: SketchGradient => update0(sketchGrad.toAuto, weight)
      case fpGrad: FixedPointGradient => update0(fpGrad.toAuto, weight)
      case zipGrad: ZipGradient => update0(zipGrad.toAuto, weight)
      case _ => throw new ClassNotFoundException(grad.getClass.getName)
    }
  }

  /** One bias-corrected Adam step for coordinate `idx` with gradient value `gradValue`. */
  private def adamStep(w: Array[Double], idx: Int, gradValue: Double): Unit = {
    val mNext = beta1 * m(idx) + (1 - beta1) * gradValue
    val vNext = beta2 * v(idx) + (1 - beta2) * gradValue * gradValue
    val step = (Math.sqrt(1 - beta2_t) * mNext) / ((1 - beta1_t) * (Math.sqrt(vNext) + Maths.EPS))
    w(idx) -= step * lr_0
    m(idx) = mNext
    v(idx) = vNext
  }

}
111 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/objective/GradientDescent.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.objective
2 |
3 | import org.apache.spark.ml.linalg.DenseVector
4 | import org.dma.sketchml.ml.conf.MLConf
5 | import org.dma.sketchml.ml.data.DataSet
6 | import org.dma.sketchml.ml.gradient._
7 | import org.slf4j.{Logger, LoggerFactory}
8 |
object GradientDescent {
  private val logger: Logger = LoggerFactory.getLogger(GradientDescent.getClass)

  /** Builds a plain mini-batch SGD optimizer from the settings in `conf`. */
  def apply(conf: MLConf): GradientDescent =
    new GradientDescent(dim = conf.featureNum, lr_0 = conf.learnRate,
      decay = conf.learnDecay, batchSpRatio = conf.batchSpRatio)
}
15 |
16 | class GradientDescent(dim: Int, lr_0: Double, decay: Double, batchSpRatio: Double) {
17 | protected val logger = GradientDescent.logger
18 |
19 | var epoch: Int = 0
20 | var batch: Int = 0
21 | val batchNum: Double = Math.ceil(1.0 / batchSpRatio).toInt
22 |
/**
 * Computes the averaged, regularized gradient over one mini-batch and advances the
 * batch/epoch counters.
 *
 * @return (gradient, batch size, summed objective loss over the batch, regularization loss)
 */
def miniBatchGradientDescent(weight: DenseVector, dataSet: DataSet, loss: Loss): (Gradient, Int, Double, Double) = {
  val startTime = System.currentTimeMillis()

  val denseGrad = new DenseDoubleGradient(dim)
  var objLoss = 0.0
  // At least one instance per batch: a batch size of 0 would divide the gradient and the
  // reported loss by zero.
  val batchSize = Math.max(1, (dataSet.size * batchSpRatio).toInt)
  for (_ <- 0 until batchSize) {
    val ins = dataSet.loopingRead
    val pre = loss.predict(weight, ins.feature)
    val gradScala = loss.grad(pre, ins.label)
    denseGrad.plusBy(ins.feature, -1.0 * gradScala)
    objLoss += loss.loss(pre, ins.label)
  }
  val grad = denseGrad.toAuto
  grad.timesBy(1.0 / batchSize)

  // NOTE(review): with alpha = 0, l1Reg leaves every value unchanged; confirm the intended
  // shrinkage amount (possibly the regularization parameter).
  if (loss.isL1Reg)
    l1Reg(grad, 0, loss.getRegParam)
  if (loss.isL2Reg)
    l2Reg(grad, weight, loss.getRegParam)
  val regLoss = loss.getReg(weight)

  logger.info(s"Epoch[$epoch] batch $batch gradient " +
    s"cost ${System.currentTimeMillis() - startTime} ms, "
    + s"batch size=$batchSize, obj loss=${objLoss / batchSize}, reg loss=$regLoss")
  batch += 1
  if (batch == batchNum) { epoch += 1; batch = 0 }
  (grad, batchSize, objLoss, regLoss)
}
52 |
/**
 * Truncated-gradient style L1 shrinkage, applied in place: values within [-theta, theta] are
 * moved `alpha` toward zero and clipped at zero; values outside that band are left unchanged.
 *
 * @throws UnsupportedOperationException for gradients other than dense/sparse double
 */
private def l1Reg(grad: Gradient, alpha: Double, theta: Double): Unit = {
  val values = grad match {
    case dense: DenseDoubleGradient => dense.values
    case sparse: SparseDoubleGradient => sparse.values
    case _ => throw new UnsupportedOperationException(
      s"Cannot regularize ${grad.kind} kind of gradients")
  }
  if (values != null) {
    for (i <- values.indices) {
      val g = values(i)
      if (g >= 0 && g <= theta)
        values(i) = (g - alpha) max 0
      else if (g < 0 && g >= -theta)
        // shrink toward zero: negative values move up by alpha (was `g - alpha`, which pushed
        // them further from zero and could never reach the `min 0` clip)
        values(i) = (g + alpha) min 0
    }
  }
}
69 |
70 | private def l2Reg(grad: Gradient, weight: DenseVector, lambda: Double): Unit = {
71 | val w = weight.values
72 | grad match {
73 | case dense: DenseDoubleGradient => {
74 | val v = dense.values
75 | for (i <- v.indices)
76 | v(i) += w(i) * lambda
77 | }
78 | case sparse: SparseDoubleGradient => {
79 | val k = sparse.indices
80 | val v = sparse.values
81 | for (i <- k.indices)
82 | v(i) += w(k(i)) * lambda
83 | }
84 | case _ => throw new UnsupportedOperationException(
85 | s"Cannot regularize ${grad.kind} kind of gradients")
86 | }
87 | }
88 |
89 | def update(grad: Gradient, weight: DenseVector): Unit = {
90 | val startTime = System.currentTimeMillis()
91 | val lr = lr_0 / Math.sqrt(1.0 + decay * epoch)
92 | grad match {
93 | case dense: DenseDoubleGradient => update(dense, weight, lr)
94 | case sparse: SparseDoubleGradient => update(sparse, weight, lr)
95 | case dense: DenseFloatGradient => update(dense, weight, lr)
96 | case sparse: SparseFloatGradient => update(sparse, weight, lr)
97 | case sketchGrad: SketchGradient => update(sketchGrad.toAuto, weight)
98 | case fpGrad: FixedPointGradient => update(fpGrad.toAuto, weight)
99 | case zipGrad: ZipGradient => update(zipGrad.toAuto, weight)
100 | }
101 | logger.info(s"Update weight cost ${System.currentTimeMillis() - startTime} ms")
102 | }
103 |
104 | private def update(grad: DenseDoubleGradient, weight: DenseVector, lr: Double): Unit = {
105 | val g = grad.values
106 | val w = weight.values
107 | for (i <- w.indices)
108 | w(i) -= g(i) * lr
109 | }
110 |
111 | private def update(grad: SparseDoubleGradient, weight: DenseVector, lr: Double): Unit = {
112 | val k = grad.indices
113 | val v = grad.values
114 | val w = weight.values
115 | for (i <- k.indices)
116 | w(k(i)) -= v(i) * lr
117 | }
118 |
119 | private def update(grad: DenseFloatGradient, weight: DenseVector, lr: Double): Unit = {
120 | val g = grad.values
121 | val w = weight.values
122 | for (i <- w.indices)
123 | w(i) -= g(i) * lr
124 | }
125 |
126 | private def update(grad: SparseFloatGradient, weight: DenseVector, lr: Double): Unit = {
127 | val k = grad.indices
128 | val v = grad.values
129 | val w = weight.values
130 | for (i <- k.indices)
131 | w(k(i)) -= v(i) * lr
132 | }
133 |
134 | }
135 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/objective/Loss.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.objective
2 |
3 | import org.apache.spark.ml.linalg.{Vector, Vectors}
4 | import org.dma.sketchml.ml.util.Maths
5 |
/**
 * Per-example objective: a prediction rule, the loss of a prediction, the scalar used to build
 * the gradient, and the regularisation term with its strength.
 */
trait Loss extends Serializable {
  /** Model prediction for features `x` under weights `w`. */
  def predict(w: Vector, x: Vector): Double

  /** Loss of prediction `pre` against label `y`. */
  def loss(pre: Double, y: Double): Double

  /**
   * Scalar gradient factor for prediction `pre` and label `y`; for every implementation in this
   * file it is the negative derivative of `loss` with respect to `pre`.
   */
  def grad(pre: Double, y: Double): Double

  /** Value of the regularisation term for weights `w`. */
  def getReg(w: Vector): Double

  /** Regularisation strength. */
  def getRegParam: Double

  /** Whether an L1 penalty is active. */
  def isL1Reg: Boolean

  /** Whether an L2 penalty is active. */
  def isL2Reg: Boolean
}
21 |
/** Base for losses with an L1 penalty `lambda * ||w||_1`; active only when lambda > Maths.EPS. */
abstract class L1Loss extends Loss {
  /** Regularisation strength, supplied by the concrete loss. */
  protected var lambda: Double

  override def getRegParam: Double = lambda

  override def isL2Reg: Boolean = false

  override def isL1Reg: Boolean = lambda > Maths.EPS

  override def getReg(w: Vector): Double =
    if (!isL1Reg) 0.0 else lambda * Vectors.norm(w, 1)
}
38 |
/**
 * Base for losses with an L2 penalty 0.5 * lambda * ||w||_2^2, active only when lambda > Maths.EPS.
 *
 * The reported penalty is the one whose gradient is `lambda * w`, i.e. the term that
 * GradientDescent.l2Reg adds to the gradient, so loss and update describe the same objective.
 */
abstract class L2Loss extends Loss {
  /** Regularisation strength, supplied by the concrete loss. */
  protected var lambda: Double

  override def isL1Reg: Boolean = false

  override def isL2Reg: Boolean = lambda > Maths.EPS

  override def getRegParam: Double = lambda

  override def getReg(w: Vector): Double = {
    if (isL2Reg) {
      val norm = Vectors.norm(w, 2)
      0.5 * lambda * norm * norm
    } else
      0.0
  }
}
55 |
/** Logistic loss log(1 + exp(-y * pre)) with an L1 penalty of strength `l`. */
class L1LogLoss(l: Double) extends L1Loss {
  override protected var lambda: Double = l

  override def predict(w: Vector, x: Vector): Double = Maths.dot(w, x)

  override def loss(pre: Double, y: Double): Double = (pre * y) match {
    case margin if margin > 18 => Math.exp(-margin) // log(1 + e^-z) ~ e^-z for large z
    case margin if margin < -18 => -margin          // log(1 + e^-z) ~ -z for very negative z
    case margin => Math.log(1 + Math.exp(-margin))
  }

  /** Negative derivative of `loss` with respect to `pre`, with the same asymptotic cut-offs. */
  override def grad(pre: Double, y: Double): Double = (pre * y) match {
    case margin if margin > 18 => y * Math.exp(-margin)
    case margin if margin < -18 => y
    case margin => y / (1.0 + Math.exp(margin))
  }
}
81 |
/** Hinge loss max(0, 1 - y * pre) with an L2 penalty of strength `l`. */
class L2HingeLoss(l: Double) extends L2Loss {
  override protected var lambda: Double = l

  override def predict(w: Vector, x: Vector): Double = Maths.dot(w, x)

  override def loss(pre: Double, y: Double): Double = (pre * y) match {
    case margin if margin < 1.0 => 1.0 - margin
    case _ => 0.0
  }

  /** Sub-gradient factor: `y` inside the margin (including margin == 1), zero outside. */
  override def grad(pre: Double, y: Double): Double = (pre * y) match {
    case margin if margin <= 1.0 => y
    case _ => 0.0
  }
}
102 |
/** Logistic loss log(1 + exp(-y * pre)) with an L2 penalty of strength `l`. */
class L2LogLoss(l: Double) extends L2Loss {
  override protected var lambda: Double = l

  override def predict(w: Vector, x: Vector): Double = Maths.dot(w, x)

  override def loss(pre: Double, y: Double): Double = (pre * y) match {
    case z if z > 18 => Math.exp(-z)   // asymptote for large margins
    case z if z < -18 => -z            // asymptote for very negative margins
    case z => Math.log(1.0 + Math.exp(-z))
  }

  /** Negative derivative of `loss` with respect to `pre`. */
  override def grad(pre: Double, y: Double): Double = (pre * y) match {
    case z if z > 18 => y * Math.exp(-z)
    case z if z < -18 => y
    case z => y / (1.0 + Math.exp(z))
  }
}
128 |
/** Squared error 0.5 * (pre - y)^2 with an L2 penalty of strength `l`. */
class L2SquareLoss(l: Double) extends L2Loss {
  override protected var lambda: Double = l

  override def predict(w: Vector, x: Vector): Double = Maths.dot(w, x)

  override def loss(pre: Double, y: Double): Double = {
    val residual = pre - y
    0.5 * residual * residual
  }

  /** Negative derivative of `loss` with respect to `pre`. */
  override def grad(pre: Double, y: Double): Double = y - pre
}
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/util/Maths.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.util
2 |
3 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
4 |
5 | import scala.collection.mutable.ArrayBuffer
6 |
object Maths {
  /** Small tolerance for treating a value as zero. */
  val EPS = 1e-8

  /**
   * Adds two sparse vectors given as parallel (index, value) arrays.
   *
   * Both index arrays are expected to be sorted ascending without duplicates; the result is then
   * sorted too. Indices present in only one input are copied through unchanged.
   */
  def add(k1: Array[Int], v1: Array[Double], k2: Array[Int],
          v2: Array[Double]): (Array[Int], Array[Double]) = {
    val k = new ArrayBuffer[Int](k1.length + k2.length)
    val v = new ArrayBuffer[Double](k1.length + k2.length)
    var i = 0
    var j = 0
    while (i < k1.length && j < k2.length) {
      if (k1(i) < k2(j)) {
        k += k1(i)
        v += v1(i)
        i += 1
      } else if (k1(i) > k2(j)) {
        k += k2(j)
        v += v2(j)
        j += 1
      } else {
        k += k1(i)
        v += v1(i) + v2(j)
        i += 1
        j += 1
      }
    }
    // The merge loop stops once either side is exhausted; keep the remaining tail of the other.
    while (i < k1.length) {
      k += k1(i)
      v += v1(i)
      i += 1
    }
    while (j < k2.length) {
      k += k2(j)
      v += v2(j)
      j += 1
    }
    (k.toArray, v.toArray)
  }

  /** Dot product of two vectors of equal size, dispatching on dense/sparse representations. */
  def dot(a: Vector, b: Vector): Double = {
    (a, b) match {
      case (a: DenseVector, b: DenseVector) => dot(a, b)
      case (a: DenseVector, b: SparseVector) => dot(a, b)
      case (a: SparseVector, b: DenseVector) => dot(a, b)
      case (a: SparseVector, b: SparseVector) => dot(a, b)
    }
  }

  def dot(a: DenseVector, b: DenseVector): Double = {
    require(a.size == b.size, s"Dot between vectors of size ${a.size} and ${b.size}")
    val size = a.size
    val aValues = a.values
    val bValues = b.values
    var dot = 0.0
    for (i <- 0 until size) {
      dot += aValues(i) * bValues(i)
    }
    dot
  }

  def dot(a: DenseVector, b: SparseVector): Double = {
    require(a.size == b.size, s"Dot between vectors of size ${a.size} and ${b.size}")
    val aValues = a.values
    val bIndices = b.indices
    val bValues = b.values
    val size = b.numActives
    var dot = 0.0
    for (i <- 0 until size) {
      val ind = bIndices(i)
      dot += aValues(ind) * bValues(i)
    }
    dot
  }

  def dot(a: SparseVector, b: DenseVector): Double = dot(b, a)

  /**
   * Sparse-sparse dot product by merging the two index lists; assumes both are sorted ascending.
   * Neither input is modified.
   */
  def dot(a: SparseVector, b: SparseVector): Double = {
    require(a.size == b.size, s"Dot between vectors of size ${a.size} and ${b.size}")
    val aIndices = a.indices
    val aValues = a.values
    val aNumActives = a.numActives
    val bIndices = b.indices
    val bValues = b.values
    val bNumActives = b.numActives
    var aOff = 0
    var bOff = 0
    var dot = 0.0
    while (aOff < aNumActives && bOff < bNumActives) {
      if (aIndices(aOff) < bIndices(bOff)) {
        aOff += 1
      } else if (aIndices(aOff) > bIndices(bOff)) {
        bOff += 1
      } else {
        dot += aValues(aOff) * bValues(bOff)
        aOff += 1
        bOff += 1
      }
    }
    dot
  }

  /**
   * Sum of squared element-wise differences, i.e. the SQUARED Euclidean distance
   * (no square root is taken).
   */
  def euclidean(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length)
    (a, b).zipped.map((x, y) => (x - y) * (x - y)).sum
  }

  /** Cosine similarity of `a` and `b`; NaN when either array is all zeros (0 / 0). */
  def cosine(a: Array[Double], b: Array[Double]): Double = {
    val va = new DenseVector(a)
    val vb = new DenseVector(b)
    dot(va, vb) / (Vectors.norm(va, 2) * Vectors.norm(vb, 2))
  }

}
110 |
--------------------------------------------------------------------------------
/ml/src/main/scala/org/dma/sketchml/ml/util/ValidationUtil.scala:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.ml.util
2 |
3 | import org.apache.spark.ml.linalg.Vector
4 | import org.dma.sketchml.ml.data.DataSet
5 | import org.dma.sketchml.ml.objective.Loss
6 | import org.dma.sketchml.sketch.util.Sort
7 | import org.slf4j.{Logger, LoggerFactory}
8 |
object ValidationUtil {
  private val logger: Logger = LoggerFactory.getLogger(ValidationUtil.getClass)

  /**
   * Evaluates `weights` on every example of `validData` and logs the summed loss, the accuracy
   * (logged as "precision") and the recall of each class.
   *
   * @return (summed loss, truePos, trueNeg, falsePos, falseNeg, number of examples)
   */
  def calLossPrecision(weights: Vector, validData: DataSet, loss: Loss): (Double, Int, Int, Int, Int, Int) = {
    val validStart = System.currentTimeMillis()
    val validNum = validData.size
    val confusion = new Confusion
    var validLoss = 0.0

    for (i <- 0 until validNum) {
      val ins = validData.get(i)
      val pre = loss.predict(weights, ins.feature)
      confusion.record(pre, ins.label)
      validLoss += loss.loss(pre, ins.label)
    }

    val precision = confusion.accuracy(validNum)
    val trueRecall = confusion.trueRecall
    val falseRecall = confusion.falseRecall
    logger.info(s"validation cost ${System.currentTimeMillis() - validStart} ms, "
      + s"loss=$validLoss, precision=$precision, "
      + s"trueRecall=$trueRecall, falseRecall=$falseRecall")
    (validLoss, confusion.truePos, confusion.trueNeg, confusion.falsePos, confusion.falseNeg, validNum)
  }

  /**
   * Same as [[calLossPrecision]] but additionally computes and logs the ROC AUC.
   *
   * @return (summed loss, truePos, trueNeg, falsePos, falseNeg, number of examples)
   */
  def calLossAucPrecision(weights: Vector, validData: DataSet, loss: Loss): (Double, Int, Int, Int, Int, Int) = {
    val validStart = System.currentTimeMillis()
    val validNum = validData.size
    val scoresArray = new Array[Double](validNum)
    val labelsArray = new Array[Double](validNum)
    val confusion = new Confusion
    var validLoss = 0.0

    for (i <- 0 until validNum) {
      val ins = validData.get(i)
      val pre = loss.predict(weights, ins.feature)
      confusion.record(pre, ins.label)
      scoresArray(i) = pre
      labelsArray(i) = ins.label
      validLoss += loss.loss(pre, ins.label)
    }

    val aucResult = computeAuc(scoresArray, labelsArray)
    val precision = confusion.accuracy(validNum)
    val trueRecall = confusion.trueRecall
    val falseRecall = confusion.falseRecall

    logger.info(s"validation cost ${System.currentTimeMillis() - validStart} ms, "
      + s"loss=$validLoss, auc=$aucResult, precision=$precision, "
      + s"trueRecall=$trueRecall, falseRecall=$falseRecall")
    (validLoss, confusion.truePos, confusion.trueNeg, confusion.falsePos, confusion.falseNeg, validNum)
  }

  /**
   * ROC AUC via the rank-sum (Mann-Whitney) formula: (sum of 1-based ranks of positives
   * - M(M+1)/2) / (M * N). Both arrays are sorted in place by Sort.quickSort (assumed to sort
   * ascending by score — TODO confirm). An example is positive iff its label equals 1.0; tied
   * scores are not rank-averaged. Returns NaN when only one class is present.
   */
  private def computeAuc(scores: Array[Double], labels: Array[Double]): Double = {
    Sort.quickSort(scores, labels, 0, scores.length)
    var numPos = 0L
    var rankSum = 0.0
    for (i <- labels.indices) {
      if (labels(i) == 1.0) {
        numPos += 1
        rankSum += i + 1 // ranks in the formula are 1-based
      }
    }
    val numNeg = labels.length - numPos
    if (numPos == 0 || numNeg == 0) Double.NaN
    else (rankSum - (numPos + 1) * numPos / 2) / numPos / numNeg
  }

  /**
   * Confusion-matrix tally. An example counts as correct when pre * label > 0 and as wrong when
   * pre * label < 0; examples with a zero product are not counted.
   */
  private final class Confusion {
    var truePos = 0  // ground truth: positive, prediction: positive
    var falsePos = 0 // ground truth: negative, prediction: positive
    var trueNeg = 0  // ground truth: negative, prediction: negative
    var falseNeg = 0 // ground truth: positive, prediction: negative

    def record(pre: Double, label: Double): Unit = {
      val agreement = pre * label
      if (agreement > 0) {
        if (pre > 0) truePos += 1 else trueNeg += 1
      } else if (agreement < 0) {
        if (pre > 0) falsePos += 1 else falseNeg += 1
      }
    }

    /** Fraction of correctly classified examples out of `total`. */
    def accuracy(total: Int): Double = 1.0 * (truePos + trueNeg) / total

    def trueRecall: Double = 1.0 * truePos / (truePos + falseNeg)

    def falseRecall: Double = 1.0 * trueNeg / (trueNeg + falsePos)
  }
}
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | org.dma.sketchml
8 | sketchml
9 | pom
10 | 1.0.0
11 |
12 | sketch
13 | ml
14 |
15 |
16 |
17 | 1.8
18 | 2.11.7
19 | 1.7.25
20 |
21 |
22 |
23 |
24 | org.slf4j
25 | slf4j-log4j12
26 | ${slf4jVersion}
27 |
28 |
29 |
30 |
31 | src/main/java
32 | src/test/java
33 |
34 |
35 | org.apache.maven.plugins
36 | maven-compiler-plugin
37 | 3.3
38 |
39 | ${jdkVersion}
40 | ${jdkVersion}
41 | false
42 |
43 |
44 |
45 | org.apache.maven.plugins
46 | maven-assembly-plugin
47 | 2.5.4
48 |
49 |
50 | jar-with-dependencies
51 |
52 |
53 |
54 |
55 | make-assembly
56 | package
57 |
58 | single
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/sketch/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | sketchml
7 | org.dma.sketchml
8 | 1.0.0
9 | ../pom.xml
10 |
11 | 4.0.0
12 |
13 | sketch
14 |
15 |
16 |
17 | it.unimi.dsi
18 | fastutil
19 | 7.1.0
20 |
21 |
22 |
23 | org.apache.commons
24 | commons-lang3
25 | 3.0
26 |
27 |
28 |
29 |
30 |
31 |
32 | org.apache.maven.plugins
33 | maven-compiler-plugin
34 | 3.3
35 |
36 | ${jdkVersion}
37 | ${jdkVersion}
38 | false
39 |
40 |
41 |
42 | org.apache.maven.plugins
43 | maven-assembly-plugin
44 | 2.5.4
45 |
46 |
47 | jar-with-dependencies
48 |
49 |
50 |
51 |
52 | make-assembly
53 | package
54 |
55 | single
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/base/BinaryEncoder.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.base;
2 |
3 | import java.io.Serializable;
4 | import java.util.stream.IntStream;
5 |
/**
 * Codec for int arrays: {@link #encode(int[])} stores a compact representation inside the
 * encoder and {@link #decode()} rebuilds the array from it.
 */
public interface BinaryEncoder extends Serializable {

    /** Encodes {@code values} into this encoder's internal representation. */
    void encode(int[] values);

    /** Reconstructs the array from the encoder's current internal representation. */
    int[] decode();
}
12 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/base/Int2IntHash.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.base;
2 |
3 | import java.io.Serializable;
4 |
/**
 * Base class for hash functions from an int key to an int, configured with a table size
 * (presumably the output range is [0, size) — confirm in subclasses).
 */
public abstract class Int2IntHash implements Serializable {
    /** Table size this hash function is configured for. */
    protected int size;

    public Int2IntHash(int size) {
        this.size = size;
    }

    public int getSize() {
        return size;
    }

    public void setSize(int size) {
        this.size = size;
    }

    /** Hashes {@code key}. */
    public abstract int hash(int key);

    /** Returns a copy of this hash function. */
    @Override
    public abstract Int2IntHash clone();
}
24 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/base/QuantileSketch.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.base;
2 |
3 |
4 | import java.io.Serializable;
5 |
/** Base class for streaming quantile sketches over double values. */
public abstract class QuantileSketch implements Serializable {
    /** Total number of data items inserted so far. */
    protected long n;
    /**
     * Estimated total number of data items, or -1 when unknown; when positive, sufficient
     * space may be allocated at once.
     */
    protected long estimateN;

    protected double minValue;
    protected double maxValue;

    /** Creates a sketch without a size estimate. */
    public QuantileSketch() {
        this(-1L);
    }

    /** Creates a sketch; non-positive estimates are normalised to -1 ("unknown"). */
    public QuantileSketch(long estimateN) {
        if (estimateN > 0) {
            this.estimateN = estimateN;
        } else {
            this.estimateN = -1L;
        }
    }

    public abstract void update(double value);

    public abstract void merge(QuantileSketch other);

    public abstract void reset();

    public abstract double getQuantile(double fraction);

    public abstract double[] getQuantiles(double[] fractions);

    public abstract double[] getQuantiles(int evenPartition);

    /** True while no item has been counted. */
    public boolean isEmpty() {
        return n == 0;
    }

    public long getN() {
        return n;
    }

    public long getEstimateN() {
        return estimateN;
    }

    public double getMinValue() {
        return minValue;
    }

    public double getMaxValue() {
        return maxValue;
    }
}
54 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/base/Quantizer.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.base;
2 |
3 | import org.dma.sketchml.sketch.common.Constants;
4 | import org.dma.sketchml.sketch.quantization.QuantileQuantizer;
5 | import org.dma.sketchml.sketch.quantization.UniformQuantizer;
6 | import org.slf4j.Logger;
7 | import org.slf4j.LoggerFactory;
8 |
9 | import java.io.IOException;
10 | import java.io.ObjectInputStream;
11 | import java.io.ObjectOutputStream;
12 | import java.io.Serializable;
13 | import java.util.concurrent.Callable;
14 | import java.util.concurrent.ExecutionException;
15 | import java.util.concurrent.ExecutorService;
16 | import java.util.concurrent.Future;
17 |
18 | public abstract class Quantizer implements Serializable {
19 | public static Logger LOG = LoggerFactory.getLogger(Quantizer.class);
20 |
21 | protected int binNum;
22 | protected int n;
23 | protected double[] splits;
24 | protected int zeroIdx;
25 | protected double min;
26 | protected double max;
27 |
28 | protected int[] bins;
29 | public static final int DEFAULT_BIN_NUM = 256;
30 |
31 | public Quantizer(int binNum) {
32 | this.binNum = binNum;
33 | }
34 |
35 | public abstract void quantize(double[] values);
36 |
37 | public abstract void parallelQuantize(double[] values) throws InterruptedException, ExecutionException;
38 |
39 | public double[] getValues() {
40 | double[] res = new double[binNum];
41 | int splitNum = binNum - 1;
42 | res[0] = 0.5 * (min + splits[0]);
43 | for (int i = 1; i < splitNum; i++)
44 | res[i] = 0.5 * (splits[i - 1] + splits[i]);
45 | res[splitNum] = 0.5 * (splits[splitNum - 1] + max);
46 | return res;
47 | }
48 |
49 | public int indexOf(double x) {
50 | if (x < splits[0]) {
51 | return 0;
52 | } else if (x >= splits[binNum - 2]) {
53 | return binNum - 1;
54 | } else {
55 | int l = zeroIdx, r = zeroIdx;
56 | if (x < 0.0) l = 0;
57 | else r = binNum - 2;
58 | while (l + 1 < r) {
59 | int mid = (l + r) >> 1;
60 | if (splits[mid] > x) {
61 | if (mid == 0 || splits[mid - 1] <= x)
62 | return mid;
63 | else
64 | r = mid;
65 | } else {
66 | l = mid;
67 | }
68 | }
69 | int mid = (l + r) >> 1;
70 | return splits[mid] <= x ? mid + 1 : mid;
71 | }
72 | }
73 |
74 | protected void findZeroIdx() {
75 | if (min > 0.0)
76 | zeroIdx = 0;
77 | else if (max < 0.0)
78 | zeroIdx = binNum - 1;
79 | else {
80 | int t = 0;
81 | while (t < binNum - 1 && splits[t] < 0.0)
82 | t++;
83 | zeroIdx = t;
84 | }
85 | }
86 |
87 | protected void quantizeToBins(double[] values) {
88 | int size = values.length;
89 | bins = new int[size];
90 | for (int i = 0; i < size; i++)
91 | bins[i] = indexOf(values[i]);
92 | }
93 |
94 | protected void parallelQuantizeToBins(double[] values) throws InterruptedException, ExecutionException {
95 | int size = values.length;
96 | int threadNum = Constants.Parallel.getParallelism();
97 | ExecutorService threadPool = Constants.Parallel.getThreadPool();
98 | Future[] futures = new Future[threadNum];
99 | bins = new int[size];
100 | for (int i = 0; i < threadNum; i++) {
101 | int threadId = i;
102 | futures[threadId] = threadPool.submit(new Callable() {
103 | @Override
104 | public Void call() throws Exception {
105 | int elementPerThread = n / threadNum;
106 | int from = threadId * elementPerThread;
107 | int to = threadId + 1 == threadNum ? size : from + elementPerThread;
108 | for (int itemId = from; itemId < to; itemId++)
109 | bins[itemId] = indexOf(values[itemId]);
110 | return null;
111 | }
112 | });
113 | }
114 | for (int i = 0; i < threadNum; i++) {
115 | futures[i].get();
116 | }
117 | }
118 |
119 | public void timesBy(double x) {
120 | min *= x;
121 | max *= x;
122 | for (int i = 0; i < splits.length; i++)
123 | splits[i] *= x;
124 | }
125 |
126 | public static Quantizer newQuantizer(Quantizer.QuantizationType type, int binNum) {
127 | switch (type) {
128 | case QUANTILE:
129 | return new QuantileQuantizer(binNum);
130 | case UNIFORM:
131 | return new UniformQuantizer(binNum);
132 | default:
133 | throw new SketchMLException(
134 | "Unrecognizable quantization type: " + type);
135 | }
136 | }
137 |
138 | public enum QuantizationType {
139 | UNIFORM("UNIFORM"),
140 | QUANTILE("QUANTILE");
141 |
142 | private final String type;
143 |
144 | QuantizationType(String type) {
145 | this.type = type;
146 | }
147 |
148 | @Override
149 | public String toString() {
150 | return type;
151 | }
152 | }
153 |
154 | public abstract QuantizationType quantizationType();
155 |
156 | public int getBinNum() {
157 | return binNum;
158 | }
159 |
160 | public int getN() {
161 | return n;
162 | }
163 |
164 | public double[] getSplits() {
165 | return splits;
166 | }
167 |
168 | public int[] getBins() {
169 | return bins;
170 | }
171 |
172 | public int getZeroIdx() {
173 | return zeroIdx;
174 | }
175 |
176 | public double getMax() {
177 | return max;
178 | }
179 |
180 | public double getMin() {
181 | return min;
182 | }
183 |
184 | private void writeObject(ObjectOutputStream oos) throws IOException {
185 | oos.writeInt(binNum);
186 | oos.writeInt(n);
187 | for (double split : splits)
188 | oos.writeDouble(split);
189 | oos.writeInt(zeroIdx);
190 | oos.writeDouble(min);
191 | oos.writeDouble(max);
192 | oos.writeInt(bins.length);
193 | if (binNum <= 256) {
194 | for (int bin : bins)
195 | oos.writeByte(bin + Byte.MIN_VALUE);
196 | } else if (binNum <= 65536) {
197 | for (int bin : bins)
198 | oos.writeShort(bin + Short.MIN_VALUE);
199 | } else {
200 | for (int bin : bins)
201 | oos.writeInt(bin);
202 | }
203 | }
204 |
205 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
206 | binNum = ois.readInt();
207 | n = ois.readInt();
208 | int splitNum = binNum - 1;
209 | splits = new double[splitNum];
210 | for (int i = 0; i < splitNum; i++)
211 | splits[i] = ois.readDouble();
212 | zeroIdx = ois.readInt();
213 | min = ois.readDouble();
214 | max = ois.readDouble();
215 | bins = new int[ois.readInt()];
216 | if (binNum <= 256) {
217 | for (int i = 0; i < bins.length; i++)
218 | bins[i] = ((int) ois.readByte()) - Byte.MIN_VALUE;
219 | } else if (binNum <= 65536) {
220 | for (int i = 0; i < bins.length; i++)
221 | bins[i] = ((int) ois.readShort()) - Short.MIN_VALUE;
222 | } else {
223 | for (int i = 0; i < bins.length; i++)
224 | bins[i] = ois.readInt();
225 | }
226 | }
227 |
228 | }
229 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/base/SketchMLException.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.base;
2 |
/** Unchecked exception raised by the sketch library for invalid arguments or states. */
public class SketchMLException extends RuntimeException {

    /** @param message description of the failure
     *  @param cause   underlying exception */
    public SketchMLException(String message, Throwable cause) {
        super(message, cause);
    }

    /** @param message description of the failure */
    public SketchMLException(String message) {
        super(message);
    }

    /** @param cause underlying exception; its toString() becomes the message */
    public SketchMLException(Throwable cause) {
        super(cause);
    }
}
16 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/base/VectorCompressor.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.base;
2 |
3 | import org.apache.commons.lang3.tuple.Pair;
4 |
5 | import java.io.IOException;
6 | import java.io.Serializable;
7 | import java.util.concurrent.ExecutionException;
8 |
9 | public interface VectorCompressor extends Serializable {
10 | void compressDense(double[] values);
11 |
12 | void compressSparse(int[] keys, double[] values);
13 |
14 | void parallelCompressDense(double[] values) throws InterruptedException, ExecutionException;
15 |
16 | void parallelCompressSparse(int[] keys, double[] values) throws InterruptedException, ExecutionException;
17 |
18 | double[] decompressDense();
19 |
20 | Pair decompressSparse();
21 |
22 | void timesBy(double x);
23 |
24 | double size();
25 |
26 | int memoryBytes() throws IOException;
27 | }
28 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/binary/BinaryUtils.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.binary;
2 |
3 | import java.util.BitSet;
4 |
5 | public class BinaryUtils {
6 | public static void setBits(BitSet bitSet, int offset, int value, int numBits) {
7 | for (int i = numBits - 1; i >= 0; i--) {
8 | int t = value - (1 << i);
9 | if (t >= 0) {
10 | bitSet.set(offset + numBits - 1 - i);
11 | value = t;
12 | }
13 | }
14 | }
15 |
16 | public static void setBytes(BitSet bitSet, int offset, int value, int numBytes) {
17 | setBits(bitSet, offset, value, numBytes * 8);
18 | }
19 |
20 | public static int getBits(BitSet bitSet, int offset, int numBits) {
21 | int res = 0;
22 | for (int i = 0; i < numBits; i++) {
23 | res <<= 1;
24 | res |= bitSet.get(offset + i) ? 1 : 0;
25 | if (bitSet.get(offset + i))
26 | res |= 1;
27 | }
28 | return res;
29 | }
30 |
31 | public static int getBytes(BitSet bitSet, int offset, int numBytes) {
32 | return getBits(bitSet, offset, numBytes * 8);
33 | }
34 |
35 | public static String bits2String(int value, int numBits) {
36 | StringBuilder sb = new StringBuilder();
37 | for (int i = numBits - 1; i >= 0; i--) {
38 | int t = value - (1 << i);
39 | if (t < 0) {
40 | sb.append("0");
41 | } else {
42 | sb.append("1");
43 | value = t;
44 | }
45 | }
46 | return sb.toString();
47 | }
48 |
49 | public static String bits2String(BitSet bitset, int from, int length) {
50 | StringBuilder sb = new StringBuilder();
51 | for (int i = from; i < from + length; i++)
52 | sb.append(bitset.get(i) ? 1 : 0);
53 | return sb.toString();
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/binary/DeltaAdaptiveEncoder.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.binary;
2 |
3 | import org.dma.sketchml.sketch.base.BinaryEncoder;
4 | import org.dma.sketchml.sketch.util.Maths;
5 | import org.slf4j.Logger;
6 | import org.slf4j.LoggerFactory;
7 |
8 | import java.io.IOException;
9 | import java.io.ObjectInputStream;
10 | import java.io.ObjectOutputStream;
11 | import java.util.BitSet;
12 |
13 | public class DeltaAdaptiveEncoder implements BinaryEncoder {
14 | private static final Logger LOG = LoggerFactory.getLogger(DeltaAdaptiveEncoder.class);
15 |
16 | private int size;
17 | private int numIntervals; // how many number of intervals it splits [0, 31]
18 | // should be exponential to 2
19 | private boolean flagKind; // whether the number of flag bits is dynamic to different interval
20 | private BitSet deltaBits;
21 | private BitSet flagBits;
22 |
23 | private void calOptimalIntervals(double[] prob) {
24 | double optBitsPerKey = 32.0;
25 | numIntervals = 1;
26 | flagKind = false;
27 | for (int m = 2; m <= 16; m *= 2) {
28 | double[] intervalProb = new double[m];
29 | int b = 32 / m;
30 | double sum = 0.0;
31 | for (int i = 0; i < m; i++) {
32 | for (int j = 0; j < b; j++)
33 | intervalProb[i] += prob[i * b + j];
34 | sum += (i + 1) * intervalProb[i];
35 | }
36 | // all flags have the same number of bits
37 | double t1 = sum * b + Maths.log2nlz(m);
38 | if (t1 < optBitsPerKey) {
39 | optBitsPerKey = t1;
40 | numIntervals = m;
41 | flagKind = false;
42 | }
43 | // one bit for each interval
44 | double t2 = sum * (b + 1) + 1;
45 | if (t2 < optBitsPerKey) {
46 | optBitsPerKey = t2;
47 | numIntervals = m;
48 | flagKind = true;
49 | }
50 | }
51 | }
52 |
53 | @Override
54 | public void encode(int[] values) {
55 | size = values.length;
56 | // 1. get probabilities of each range [2^i, 2^(i+1))
57 | int[] delta = new int[size];
58 | int[] bitsNeeded = new int[size];
59 | double[] prob = new double[32];
60 | delta[0] = values[0];
61 | if (delta[0] == 0)
62 | bitsNeeded[0] = 1;
63 | else
64 | bitsNeeded[0] = Maths.log2nlz(delta[0]) + 1;
65 | prob[bitsNeeded[0]]++;
66 | for (int i = 1; i < size; i++) {
67 | delta[i] = values[i] - values[i - 1];
68 | bitsNeeded[i] = Maths.log2nlz(delta[i]) + 1;
69 | prob[bitsNeeded[i]]++;
70 | }
71 | for (int i = 0; i < prob.length; i++)
72 | prob[i] /= size;
73 | // 2. get the optimal number of intervals, and the kind of flag bits
74 | calOptimalIntervals(prob);
75 | // 3. encode deltas
76 | deltaBits = new BitSet();
77 | flagBits = new BitSet();
78 | int bitsPerInterval = 32 / numIntervals;
79 | int bitsShift = Maths.log2nlz(bitsPerInterval);
80 | int flagOffset = 0, deltaOffset = 0;
81 | if (!flagKind) {
82 | int numBitsPerFlag = Maths.log2nlz(numIntervals);
83 | for (int i = 0; i < size; i++) {
84 | // ceil(bitsNeeded / bitsPerInterval)
85 | int intervalNeeded = (bitsNeeded[i] + bitsPerInterval - 1) >> bitsShift;
86 | // set flag
87 | BinaryUtils.setBits(flagBits, flagOffset, intervalNeeded - 1, numBitsPerFlag);
88 | flagOffset += numBitsPerFlag;
89 | // set delta
90 | BinaryUtils.setBits(deltaBits, deltaOffset, delta[i], bitsPerInterval * intervalNeeded);
91 | deltaOffset += bitsPerInterval * intervalNeeded;
92 | }
93 | } else {
94 | int[] flagCandidates = new int[numIntervals + 1];
95 | for (int i = 1; i <= numIntervals; i++) {
96 | // 0b1110 = 0b10000 - 2
97 | flagCandidates[i] = (1 << (i + 1)) - 2;
98 | }
99 | for (int i = 0; i < size; i++) {
100 | // ceil(bitsNeeded / bitsPerInterval)
101 | int intervalNeeded = (bitsNeeded[i] + bitsPerInterval - 1) >> bitsShift;
102 | // set flag
103 | BinaryUtils.setBits(flagBits, flagOffset, flagCandidates[intervalNeeded], intervalNeeded + 1);
104 | flagOffset += intervalNeeded + 1;
105 | // set delta
106 | BinaryUtils.setBits(deltaBits, deltaOffset, delta[i], bitsPerInterval * intervalNeeded);
107 | deltaOffset += bitsPerInterval * intervalNeeded;
108 | }
109 | }
110 | //LOG.info(String.format("BitsPerKey[%f], flag[%f], delta[%f]", (flagOffset + deltaOffset) * 1. / size,
111 | // flagOffset * 1. / size, deltaOffset * 1. / size));
112 | }
113 |
114 | @Override
115 | public int[] decode() {
116 | int[] res = new int[size];
117 | int bitsPerInterval = 32 / numIntervals;
118 | int flagOffset = 0, deltaOffset = 0, prev = 0;;
119 | if (!flagKind) {
120 | int numBitsPerFlag = Maths.log2nlz(numIntervals);
121 | for (int i = 0; i < size; i++) {
122 | // get flag
123 | int intervalNeeded = BinaryUtils.getBits(flagBits, flagOffset, numBitsPerFlag) + 1;
124 | flagOffset += numBitsPerFlag;
125 | // get delta
126 | int delta = BinaryUtils.getBits(deltaBits, deltaOffset, bitsPerInterval * intervalNeeded);
127 | deltaOffset += bitsPerInterval * intervalNeeded;
128 | // set value
129 | res[i] = prev + delta;
130 | prev = res[i];
131 | }
132 | } else {
133 | for (int i = 0; i < size; i++) {
134 | // get flag
135 | int intervalNeeded = 0;
136 | while (flagBits.get(flagOffset++)) intervalNeeded++;
137 | // get delta
138 | int delta = BinaryUtils.getBits(deltaBits, deltaOffset, bitsPerInterval * intervalNeeded);
139 | deltaOffset += bitsPerInterval * intervalNeeded;
140 | // set value
141 | res[i] = prev + delta;
142 | prev = res[i];
143 | }
144 | }
145 | return res;
146 | }
147 |
148 | private void writeObject(ObjectOutputStream oos) throws IOException {
149 | oos.writeInt(size);
150 | oos.writeInt(numIntervals);
151 | oos.writeBoolean(flagKind);
152 | if (flagBits == null) {
153 | oos.writeInt(0);
154 | } else {
155 | long[] flags = flagBits.toLongArray();
156 | oos.writeInt(flags.length);
157 | for (long l : flags) {
158 | oos.writeLong(l);
159 | }
160 | }
161 | if (deltaBits == null) {
162 | oos.writeInt(0);
163 | } else {
164 | long[] delta = deltaBits.toLongArray();
165 | oos.writeInt(delta.length);
166 | for (long l : delta) {
167 | oos.writeLong(l);
168 | }
169 | }
170 | }
171 |
172 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
173 | size = ois.readInt();
174 | numIntervals = ois.readInt();
175 | flagKind = ois.readBoolean();
176 | int flagsLength = ois.readInt();
177 | long[] flags = new long[flagsLength];
178 | for (int i = 0; i < flagsLength; i++) {
179 | flags[i] = ois.readLong();
180 | }
181 | flagBits = BitSet.valueOf(flags);
182 | int deltaLength = ois.readInt();
183 | long[] delta = new long[deltaLength];
184 | for (int i = 0; i < deltaLength; i++) {
185 | delta[i] = ois.readLong();
186 | }
187 | deltaBits = BitSet.valueOf(delta);
188 | }
189 | }
190 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/binary/DeltaBinaryEncoder.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.binary;
2 |
3 | import org.dma.sketchml.sketch.base.BinaryEncoder;
4 | import org.dma.sketchml.sketch.base.SketchMLException;
5 | import org.slf4j.Logger;
6 | import org.slf4j.LoggerFactory;
7 |
8 | import java.io.IOException;
9 | import java.io.ObjectInputStream;
10 | import java.io.ObjectOutputStream;
11 | import java.util.BitSet;
12 |
13 | /**
14 | * This is the special case for DeltaAdaptiveEncoder
15 | * where numIntervals equals to 4 and number of flag bits is constant
16 | *
17 | * */
18 | public class DeltaBinaryEncoder implements BinaryEncoder {
19 | private static final Logger LOG = LoggerFactory.getLogger(DeltaBinaryEncoder.class);
20 |
21 | private int size;
22 | private BitSet deltaBits;
23 | private BitSet flagBits;
24 |
25 | @Override
26 | public void encode(int[] values) {
27 | size = values.length;
28 | flagBits = new BitSet(size * 2);
29 | deltaBits = new BitSet(size * 12);
30 | int offset = 0, prev = 0;
31 | for (int i = 0; i < size; i++) {
32 | int delta = values[i] - prev;
33 | int bytesNeeded = needBytes(delta);
34 | BinaryUtils.setBits(flagBits, 2 * i, bytesNeeded - 1, 2);
35 | BinaryUtils.setBytes(deltaBits, offset, delta, bytesNeeded);
36 | prev = values[i];
37 | offset += bytesNeeded * 8;
38 | }
39 | }
40 |
41 | @Override
42 | public int[] decode() {
43 | int[] res = new int[size];
44 | int offset = 0, prev = 0;
45 | for (int i = 0; i < size; i++) {
46 | int bytesNeeded = BinaryUtils.getBits(flagBits, i * 2, 2) + 1;
47 | int delta = BinaryUtils.getBytes(deltaBits, offset, bytesNeeded);
48 | res[i] = prev + delta;
49 | prev = res[i];
50 | offset += bytesNeeded * 8;
51 | }
52 | return res;
53 | }
54 |
55 | public static int needBytes(int x) {
56 | if (x < 0) {
57 | throw new SketchMLException("Input of DeltaBinaryEncoder should be sorted");
58 | } else if (x < 256) {
59 | return 1;
60 | } else if (x < 65536) {
61 | return 2;
62 | } else {
63 | return 4;
64 | }
65 | }
66 |
67 | private void writeObject(ObjectOutputStream oos) throws IOException {
68 | oos.writeInt(size);
69 | if (flagBits == null) {
70 | oos.writeInt(0);
71 | } else {
72 | long[] flags = flagBits.toLongArray();
73 | oos.writeInt(flags.length);
74 | for (long l : flags) {
75 | oos.writeLong(l);
76 | }
77 | }
78 | if (deltaBits == null) {
79 | oos.writeInt(0);
80 | } else {
81 | long[] delta = deltaBits.toLongArray();
82 | oos.writeInt(delta.length);
83 | for (long l : delta) {
84 | oos.writeLong(l);
85 | }
86 | }
87 | }
88 |
89 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
90 | size = ois.readInt();
91 | int flagsLength = ois.readInt();
92 | long[] flags = new long[flagsLength];
93 | for (int i = 0; i < flagsLength; i++) {
94 | flags[i] = ois.readLong();
95 | }
96 | flagBits = BitSet.valueOf(flags);
97 | int deltaLength = ois.readInt();
98 | long[] delta = new long[deltaLength];
99 | for (int i = 0; i < deltaLength; i++) {
100 | delta[i] = ois.readLong();
101 | }
102 | deltaBits = BitSet.valueOf(delta);
103 | }
104 | }
105 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/binary/HuffmanEncoder.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.binary;
2 |
3 | import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
4 | import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
5 | import org.dma.sketchml.sketch.base.BinaryEncoder;
6 | import org.slf4j.Logger;
7 | import org.slf4j.LoggerFactory;
8 |
9 | import java.io.IOException;
10 | import java.io.ObjectInputStream;
11 | import java.io.ObjectOutputStream;
12 | import java.util.*;
13 | import java.util.stream.IntStream;
14 |
15 | public class HuffmanEncoder implements BinaryEncoder {
16 | private static final Logger LOG = LoggerFactory.getLogger(HuffmanEncoder.class);
17 |
18 | private Item[] items;
19 | private BitSet bitset;
20 | private int size;
21 |
22 |
23 | private class Node {
24 | int value;
25 | int occurrence;
26 | Node leftChild;
27 | Node rightChild;
28 | boolean isLeaf;
29 |
30 | Node(int value, int occurrence, Node leftChild, Node rightChild, boolean isLeaf) {
31 | this.value = value;
32 | this.occurrence = occurrence;
33 | this.leftChild = leftChild;
34 | this.rightChild = rightChild;
35 | this.isLeaf = isLeaf;
36 | }
37 |
38 | Node(int value, int occurrence) {
39 | this(value, occurrence, null, null, false);
40 | }
41 |
42 | Node() {
43 | this(-1, -1, null, null, false);
44 | }
45 |
46 | int getValue() {
47 | return value;
48 | }
49 |
50 | int getOccurrence() {
51 | return occurrence;
52 | }
53 | }
54 |
55 | private class Item {
56 | int value;
57 | int bits;
58 | int numBits;
59 |
60 | Item(int value, int bits, int numBits) {
61 | this.value = value;
62 | this.bits = bits;
63 | this.numBits = numBits;
64 | }
65 |
66 | @Override
67 | public String toString() {
68 | StringBuilder sb = new StringBuilder();
69 | sb.append("(");
70 | sb.append(value);
71 | sb.append(" --> ");
72 | sb.append(BinaryUtils.bits2String(bits, numBits));
73 | sb.append(")");
74 | return sb.toString();
75 | }
76 | }
77 |
78 | private void traverse(Node node, Int2ObjectMap- mapping, int bits, int depth) {
79 | if (node.isLeaf) {
80 | mapping.put(node.value, new Item(node.value, bits, depth == 0 ? 1 : depth));
81 | } else {
82 | traverse(node.leftChild, mapping, bits << 1, depth + 1);
83 | traverse(node.rightChild, mapping, (bits << 1) | 1, depth + 1);
84 | }
85 | }
86 |
87 | @Override
88 | public void encode(int[] values) {
89 | long startTime = System.currentTimeMillis();
90 | // 1. count occurrences
91 | Int2ObjectRBTreeMap freq = new Int2ObjectRBTreeMap<>();
92 | for (int v : values) {
93 | Node node = freq.get(v);
94 | if (node != null)
95 | node.occurrence++;
96 | else
97 | freq.put(v, new Node(v, 1, null, null, true));
98 | }
99 | // 2. build tree
100 | PriorityQueue heap = new PriorityQueue<>(freq.size(),
101 | Comparator.comparing(Node::getOccurrence));
102 | heap.addAll(freq.values());
103 | while (heap.size() > 1) {
104 | Node x = heap.poll();
105 | Node y = heap.poll();
106 | Node p = new Node(-1, x.occurrence + y.occurrence, x, y, false);
107 | heap.add(p);
108 | }
109 | Int2ObjectMap
- mapping = new Int2ObjectRBTreeMap<>();
110 | traverse(heap.peek(), mapping, 0, 0);
111 | items = new Item[mapping.size()];
112 | mapping.values().toArray(items);
113 | // 3. encode values
114 | bitset = new BitSet();
115 | int offset = 0;
116 | for (int v : values) {
117 | Item item = mapping.get(v);
118 | BinaryUtils.setBits(bitset, offset, item.bits, item.numBits);
119 | offset += item.numBits;
120 | }
121 | size = values.length;
122 | LOG.debug(String.format("Huffman encoding for %d values cost %d ms",
123 | values.length, System.currentTimeMillis() - startTime));
124 | }
125 |
126 | @Override
127 | public int[] decode() {
128 | if (size == 0)
129 | return new int[0];
130 |
131 | // 1. build Huffman tree
132 | Node root = new Node();
133 | for (Item item : items) {
134 | int bits = item.bits;
135 | int numBits = item.numBits;
136 | Node cur = root;
137 | for (int i = numBits - 1; i >= 0; i--) {
138 | int t = bits - (1 << i);
139 | if (t >= 0) {
140 | if (cur.rightChild == null)
141 | cur.rightChild = new Node();
142 | cur = cur.rightChild;
143 | bits = t;
144 | } else {
145 | if (cur.leftChild == null)
146 | cur.leftChild = new Node();
147 | cur = cur.leftChild;
148 | }
149 | }
150 | cur.value = item.value;
151 | cur.isLeaf = true;
152 | }
153 | // 2. decode bits
154 | int[] res = new int[size];
155 | int cnt = 0;
156 | Node cur = root;
157 | int idx = 0;
158 | while (cnt < size) {
159 | cur = bitset.get(idx++) ? cur.rightChild : cur.leftChild;
160 | if (cur.isLeaf) {
161 | res[cnt++] = cur.value;
162 | cur = root;
163 | }
164 | }
165 | return res;
166 | }
167 |
168 | private void writeObject(ObjectOutputStream oos) throws IOException {
169 | // items
170 | if (items == null) {
171 | oos.writeInt(0);
172 | } else {
173 | oos.writeInt(items.length);
174 | for (Item item : items) {
175 | oos.writeInt(item.value);
176 | oos.writeInt(item.bits);
177 | oos.writeInt(item.numBits);
178 | }
179 | }
180 | // bit set
181 | if (bitset == null) {
182 | oos.writeInt(0);
183 | } else {
184 | long[] bits = bitset.toLongArray();
185 | oos.writeInt(bits.length);
186 | for (long l : bits)
187 | oos.writeLong(l);
188 | }
189 | // size
190 | oos.writeInt(size);
191 | }
192 |
193 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
194 | // items
195 | int numItems = ois.readInt();
196 | items = new Item[numItems];
197 | for (int i = 0; i < numItems; i++)
198 | items[i] = new Item(ois.readInt(), ois.readInt(), ois.readInt());
199 | // bit set
200 | int numLongs = ois.readInt();
201 | long[] bits = new long[numLongs];
202 | for (int i = 0; i < numLongs; i++)
203 | bits[i] = ois.readLong();
204 | bitset = BitSet.valueOf(bits);
205 | // size
206 | size = ois.readInt();
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/common/Constants.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.common;
2 |
3 | import org.dma.sketchml.sketch.base.SketchMLException;
4 |
5 | import java.util.concurrent.ExecutorService;
6 | import java.util.concurrent.Executors;
7 |
8 | public class Constants {
9 | public static class Parallel {
10 | private static int parallelism;
11 |
12 | private static ExecutorService threadPool;
13 |
14 | static {
15 | parallelism = 0;
16 | threadPool = null;
17 | }
18 |
19 | public static void setParallelism(int parallelism) {
20 | if (parallelism < 1)
21 | throw new SketchMLException("Invalid parallelism: " + parallelism);
22 | Parallel.parallelism = parallelism;
23 | Parallel.threadPool = Executors.newFixedThreadPool(parallelism);
24 | }
25 |
26 | public static int getParallelism() {
27 | if (parallelism <= 0)
28 | throw new SketchMLException("Parallelism is not set yet");
29 | return parallelism;
30 | }
31 |
32 | public static ExecutorService getThreadPool() {
33 | if (threadPool == null)
34 | throw new SketchMLException("Parallelism is not set yet");
35 | return threadPool;
36 | }
37 |
38 | public static void shutdown() {
39 | if (threadPool != null)
40 | threadPool.shutdown();
41 | }
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/hash/BJHash.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.hash;
2 |
3 | import org.dma.sketchml.sketch.base.Int2IntHash;
4 |
5 | public class BJHash extends Int2IntHash {
6 | public BJHash(int size) {
7 | super(size);
8 | }
9 |
10 | public int hash(int key) {
11 | int code = key;
12 | code = (code + 0x7ed55d16) + (code << 12);
13 | code = (code ^ 0xc761c23c) ^ (code >> 19);
14 | code = (code + 0x165667b1) + (code << 5);
15 | code = (code + 0xd3a2646c) ^ (code << 9);
16 | code = (code + 0xfd7046c5) + (code << 3);
17 | code = (code ^ 0xb55a4f09) ^ (code >> 16);
18 | code %= size;
19 | return code >= 0 ? code : code + size;
20 | }
21 |
22 | @Override
23 | public Int2IntHash clone() {
24 | return new BJHash(size);
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/hash/BKDRHash.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.hash;
2 |
3 | import org.dma.sketchml.sketch.base.Int2IntHash;
4 |
5 | public class BKDRHash extends Int2IntHash {
6 | private int seed;
7 |
8 | public BKDRHash(int size, int seed) {
9 | super(size);
10 | this.seed = seed;
11 | }
12 |
13 | public int hash(int key) {
14 | int code = 0;
15 | while (key != 0) {
16 | code = seed * code + (key % 10);
17 | key /= 10;
18 | }
19 | code %= size;
20 | return code >= 0 ? code : code + size;
21 | }
22 |
23 | @Override
24 | public Int2IntHash clone() {
25 | return new BKDRHash(size, seed);
26 | }
27 |
28 | public int getSeed() {
29 | return seed;
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/hash/HashFactory.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.hash;
2 |
3 | import org.dma.sketchml.sketch.base.Int2IntHash;
4 | import org.dma.sketchml.sketch.base.SketchMLException;
5 | import org.dma.sketchml.sketch.util.Maths;
6 |
7 | import java.util.*;
8 |
9 | public class HashFactory {
10 | private static final Int2IntHash[] int2intHashes =
11 | new Int2IntHash[]{new BJHash(0), new Mix64Hash(0),
12 | new TWHash(0), new BKDRHash(0, 31), new BKDRHash(0, 131),
13 | new BKDRHash(0, 267), new BKDRHash(0, 1313), new BKDRHash(0, 13131)};
14 | private static final Random random = new Random();
15 |
16 | public static Int2IntHash getRandomInt2IntHash(int size) {
17 | int idx = random.nextInt(int2intHashes.length);
18 | Int2IntHash res = int2intHashes[idx].clone();
19 | res.setSize(size);
20 | return res;
21 | }
22 |
23 | public static Int2IntHash[] getRandomInt2IntHashes(int hashNum, int size) {
24 | if (hashNum > int2intHashes.length) {
25 | throw new SketchMLException(String.format("Currently only %d " +
26 | "hash functions are available", int2intHashes.length));
27 | } else {
28 | Int2IntHash[] res = new Int2IntHash[hashNum];
29 | int[] indexes = new int[int2intHashes.length];
30 | Arrays.setAll(indexes, i -> i);
31 | Maths.shuffle(indexes);
32 | for (int i = 0; i < hashNum; i++) {
33 | res[i] = int2intHashes[indexes[i]].clone();
34 | res[i].setSize(size);
35 | }
36 | return res;
37 | }
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/hash/Mix64Hash.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.hash;
2 |
3 | import org.dma.sketchml.sketch.base.Int2IntHash;
4 |
5 | public class Mix64Hash extends Int2IntHash {
6 | public Mix64Hash(int size) {
7 | super(size);
8 | }
9 |
10 | public int hash(int key) {
11 | int code = key;
12 | code = (~code) + (code << 21); // code = (code << 21) - code - 1;
13 | code = code ^ (code >> 24);
14 | code = (code + (code << 3)) + (code << 8); // code * 265
15 | code = code ^ (code >> 14);
16 | code = (code + (code << 2)) + (code << 4); // code * 21
17 | code = code ^ (code >> 28);
18 | code = code + (code << 31);
19 | code %= size;
20 | return code >= 0 ? code : code + size;
21 | }
22 |
23 | @Override
24 | public Int2IntHash clone() {
25 | return new Mix64Hash(size);
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/hash/TWHash.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.hash;
2 |
3 | import org.dma.sketchml.sketch.base.Int2IntHash;
4 |
5 | public class TWHash extends Int2IntHash {
6 | public TWHash(int size) {
7 | super(size);
8 | }
9 |
10 | public int hash(int key) {
11 | int code = key;
12 | code = ~code + (code << 15);
13 | code = code ^ (code >> 12);
14 | code = code + (code << 2);
15 | code = code ^ (code >> 4);
16 | code = code * 2057;
17 | code = code ^ (code >> 16);
18 | code %= size;
19 | return code >= 0 ? code : code + size;
20 | }
21 |
22 | @Override
23 | public Int2IntHash clone() {
24 | return new TWHash(size);
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/quantization/QuantileQuantizer.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.quantization;
2 |
3 | import org.dma.sketchml.sketch.base.Quantizer;
4 | import org.dma.sketchml.sketch.common.Constants;
5 | import org.dma.sketchml.sketch.sketch.quantile.HeapQuantileSketch;
6 | import org.dma.sketchml.sketch.util.Maths;
7 | import org.slf4j.Logger;
8 | import org.slf4j.LoggerFactory;
9 |
10 | import java.util.concurrent.Callable;
11 | import java.util.concurrent.ExecutionException;
12 | import java.util.concurrent.ExecutorService;
13 | import java.util.concurrent.Future;
14 |
15 | public class QuantileQuantizer extends Quantizer {
16 | private static final Logger LOG = LoggerFactory.getLogger(QuantileQuantizer.class);
17 |
18 | public QuantileQuantizer(int binNum) {
19 | super(binNum);
20 | }
21 |
22 | public QuantileQuantizer() {
23 | this(Quantizer.DEFAULT_BIN_NUM);
24 | }
25 |
26 | @Override
27 | public void quantize(double[] values) {
28 | long startTime = System.currentTimeMillis();
29 | // 1. create quantile sketch summary
30 | n = values.length;
31 | HeapQuantileSketch qSketch = new HeapQuantileSketch((long) n);
32 | for (double v : values) {
33 | qSketch.update(v);
34 | }
35 | min = qSketch.getMinValue();
36 | max = qSketch.getMaxValue();
37 | // 2. query quantiles, set them as bin edges
38 | splits = Maths.unique(qSketch.getQuantiles(binNum));
39 | if (splits.length + 1 != binNum) {
40 | LOG.warn(String.format("Actual bin num %d not equal to %d",
41 | splits.length + 1, binNum));
42 | binNum = splits.length + 1;
43 | }
44 | // 3. find the zero index
45 | findZeroIdx();
46 | // 4. find index of each value
47 | quantizeToBins(values);
48 | LOG.debug(String.format("Quantile quantization for %d items cost %d ms",
49 | n, System.currentTimeMillis() - startTime));
50 | }
51 |
52 | @Override
53 | public void parallelQuantize(double[] values) throws InterruptedException, ExecutionException {
54 | long startTime = System.currentTimeMillis();
55 | // 1. create quantile sketch summary in parallel
56 | n = values.length;
57 | // 1.1. each thread create a quantile sketch based on a portion of data
58 | int threadNum = Constants.Parallel.getParallelism();
59 | ExecutorService threadPool = Constants.Parallel.getThreadPool();
60 | Future[] futures = new Future[threadNum];
61 | for (int i = 0; i < threadNum; i++) {
62 | int threadId = i;
63 | futures[threadId] = threadPool.submit(new Callable() {
64 | @Override
65 | public HeapQuantileSketch call() throws Exception {
66 | int elementPerThread = n / threadNum;
67 | int from = threadId * elementPerThread;
68 | int to = threadId + 1 == threadNum ? n : from + elementPerThread;
69 | HeapQuantileSketch qSketch = new HeapQuantileSketch((long) (to - from));
70 | for (int itemId = from; itemId < to; itemId++) {
71 | qSketch.update(values[itemId]);
72 | }
73 | return qSketch;
74 | }
75 | });
76 | }
77 | // 1.2. merge all quantile sketches together
78 | HeapQuantileSketch qSketch = futures[0].get();
79 | for (int i = 1; i < threadNum; i++) {
80 | qSketch.merge(futures[i].get());
81 | }
82 | min = qSketch.getMinValue();
83 | max = qSketch.getMaxValue();
84 | // 2. query quantiles, set them as bin edges
85 | splits = qSketch.getQuantiles(binNum);
86 | // 3. find the zero index
87 | findZeroIdx();
88 | // 4. find index of each value
89 | parallelQuantizeToBins(values);
90 | LOG.debug(String.format("Quantile quantization for %d items cost %d ms",
91 | n, System.currentTimeMillis() - startTime));
92 | }
93 |
94 | @Override
95 | public QuantizationType quantizationType() {
96 | return QuantizationType.QUANTILE;
97 | }
98 |
99 | }
100 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/quantization/UniformQuantizer.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.quantization;
2 |
3 | import org.dma.sketchml.sketch.base.Quantizer;
4 | import org.slf4j.Logger;
5 | import org.slf4j.LoggerFactory;
6 |
7 | import java.util.concurrent.ExecutionException;
8 |
9 | public class UniformQuantizer extends Quantizer {
10 | public static final Logger LOG = LoggerFactory.getLogger(UniformQuantizer.class);
11 |
12 | public UniformQuantizer(int binNum) {
13 | super(binNum);
14 | }
15 |
16 | public UniformQuantizer() {
17 | super(Quantizer.DEFAULT_BIN_NUM);
18 | }
19 |
20 | @Override
21 | public void quantize(double[] values) {
22 | long startTime = System.currentTimeMillis();
23 | n = values.length;
24 | min = Double.MAX_VALUE;
25 | max = Double.MIN_VALUE;
26 | for (double v : values) {
27 | if (v < min) min = v;
28 | if (v > max) max = v;
29 | }
30 | // 1. uniformly split the range of values
31 | double step = (max - min) / binNum;
32 | int splitNum = binNum - 1;
33 | splits = new double[splitNum];
34 | splits[0] = min + step;
35 | for (int i = 1; i < splitNum; i++) {
36 | splits[i] = splits[i - 1] + step;
37 | }
38 | // 3. find the zero index
39 | findZeroIdx();
40 | // 4. find index of each value
41 | quantizeToBins(values);
42 | LOG.debug(String.format("Uniform quantization for %d items cost %d ms",
43 | n, System.currentTimeMillis() - startTime));
44 | }
45 |
46 | @Override
47 | public void parallelQuantize(double[] values) throws InterruptedException, ExecutionException {
48 | long startTime = System.currentTimeMillis();
49 | n = values.length;
50 | min = Double.MAX_VALUE;
51 | max = Double.MIN_VALUE;
52 | for (double v : values) {
53 | if (v < min) min = v;
54 | if (v > max) max = v;
55 | }
56 | // 1. uniformly split the range of values
57 | double step = (max - min) / binNum;
58 | int splitNum = binNum - 1;
59 | splits = new double[splitNum];
60 | splits[0] = min + step;
61 | for (int i = 1; i < splitNum; i++) {
62 | splits[i] = splits[i - 1] + step;
63 | }
64 | // 3. find the zero index
65 | findZeroIdx();
66 | // 4. find index of each value
67 | parallelQuantizeToBins(values);
68 | LOG.debug(String.format("Uniform quantization for %d items cost %d ms",
69 | n, System.currentTimeMillis() - startTime));
70 | }
71 |
72 | @Override
73 | public QuantizationType quantizationType() {
74 | return QuantizationType.UNIFORM;
75 | }
76 |
77 | }
78 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sample/App.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sample;
2 |
3 | import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
4 | import it.unimi.dsi.fastutil.ints.IntArrayList;
5 | import org.apache.commons.lang3.tuple.Pair;
6 | import org.dma.sketchml.sketch.base.QuantileSketch;
7 | import org.dma.sketchml.sketch.base.Quantizer;
8 | import org.dma.sketchml.sketch.base.VectorCompressor;
9 | import org.dma.sketchml.sketch.common.Constants;
10 | import org.dma.sketchml.sketch.sketch.frequency.GroupedMinMaxSketch;
11 | import org.dma.sketchml.sketch.sketch.frequency.MinMaxSketch;
12 | import org.dma.sketchml.sketch.sketch.quantile.HeapQuantileSketch;
13 | import org.dma.sketchml.sketch.util.Utils;
14 | import org.slf4j.Logger;
15 | import org.slf4j.LoggerFactory;
16 |
17 | import java.util.Arrays;
18 | import java.util.Random;
19 |
20 | public class App {
21 | private static final Logger LOG = LoggerFactory.getLogger(App.class);
22 |
23 | private static Random random = new Random();
24 |
25 | public static void main(String[] args) throws Exception {
26 | Constants.Parallel.setParallelism(4);
27 | dense();
28 | sparse();
29 | Constants.Parallel.shutdown();
30 | }
31 |
32 | private static void dense() throws Exception {
33 | int n = 1000000;
34 | double density = 0.9;
35 | double[] values = new double[n];
36 | for (int i = 0; i < n; i++) {
37 | if (random.nextDouble() < density) {
38 | values[i] = random.nextGaussian();
39 | }
40 | }
41 | Quantizer.QuantizationType quantType = Quantizer.QuantizationType.QUANTILE;
42 | int binNum = Quantizer.DEFAULT_BIN_NUM;
43 | VectorCompressor compressor = new DenseVectorCompressor(quantType, binNum);
44 | //compressor.compressDense(values);
45 | compressor.parallelCompressDense(values);
46 | compressor = (VectorCompressor) Utils.testSerialization(compressor);
47 | double[] dValues = compressor.decompressDense();
48 | LOG.info("First 10 values before: " + Arrays.toString(Arrays.copyOf(values, 10)));
49 | LOG.info("First 10 values after: " + Arrays.toString(Arrays.copyOf(dValues, 10)));
50 | QuantileSketch qSketch = new HeapQuantileSketch((long) n);
51 | double rmse = 0.0;
52 | for (int i = 0; i < n; i++) {
53 | qSketch.update(dValues[i] - values[i]);
54 | rmse += (dValues[i] - values[i]) *(dValues[i] - values[i]);
55 | }
56 | double[] err = qSketch.getQuantiles(100);
57 | rmse = Math.sqrt(rmse / n);
58 | LOG.info("Quantiles of errors: " + Arrays.toString(err));
59 | LOG.info("RMSE: " + rmse);
60 | int originBytes = values.length * 8;
61 | int compressBytes = compressor.memoryBytes();
62 | LOG.info(String.format("Compress %d bytes into %d bytes, compression rate: %f",
63 | originBytes, compressBytes, 1.0 * originBytes / compressBytes));
64 | }
65 |
66 | private static void sparse() throws Exception {
67 | int n = 100000;
68 | double sparsity = 0.9;
69 | IntArrayList keyList = new IntArrayList();
70 | DoubleArrayList valueList = new DoubleArrayList();
71 | for (int i = 0; i < n; i++) {
72 | if (random.nextDouble() > sparsity) {
73 | keyList.add(i);
74 | valueList.add(random.nextGaussian());
75 | }
76 | }
77 | int nnz = keyList.size();
78 | int[] keys = keyList.toIntArray();
79 | double[] values = valueList.toDoubleArray();
80 | Quantizer.QuantizationType quantType = Quantizer.QuantizationType.QUANTILE;
81 | int binNum = Quantizer.DEFAULT_BIN_NUM;
82 | int groupNum = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_GROUP_NUM;
83 | int rowNum = MinMaxSketch.DEFAULT_MINMAXSKETCH_ROW_NUM;
84 | double colRatio = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_COL_RATIO;
85 | VectorCompressor compressor = new SparseVectorCompressor(
86 | quantType, binNum, groupNum, rowNum, colRatio);
87 | compressor = (VectorCompressor) Utils.testSerialization(compressor);
88 | //compressor.compressSparse(keys, values);
89 | compressor.parallelCompressSparse(keys, values);
90 | Pair dResult = compressor.decompressSparse();
91 | int[] dKeys = dResult.getLeft();
92 | double[] dValues = dResult.getRight();
93 | LOG.info(String.format("Array length: [%d, %d] vs. [%d, %d]",
94 | nnz, nnz, dKeys.length, dValues.length));
95 | LOG.info("First 10 keys before: " + Arrays.toString(Arrays.copyOf(keys, 10)));
96 | LOG.info("First 10 keys after: " + Arrays.toString(Arrays.copyOf(dKeys, 10)));
97 | LOG.info("First 10 values before: " + Arrays.toString(Arrays.copyOf(values, 10)));
98 | LOG.info("First 10 values after: " + Arrays.toString(Arrays.copyOf(dValues, 10)));
99 | QuantileSketch qSketch = new HeapQuantileSketch((long) nnz);
100 | double rmse = 0.0;
101 | for (int i = 0; i < nnz; i++) {
102 | if (keys[i] != dKeys[i]) {
103 | LOG.error(String.format("Keys not match: [%d, %d]", keys[i], dKeys[i]));
104 | } else {
105 | qSketch.update(dValues[i] - values[i]);
106 | rmse += (dValues[i] - values[i]) * (dValues[i] - values[i]);
107 | }
108 | }
109 | double[] err = qSketch.getQuantiles(100);
110 | rmse = Math.sqrt(rmse / nnz);
111 | LOG.info("Quantiles of errors: " + Arrays.toString(err));
112 | LOG.info("RMSE: " + rmse);
113 | int originBytes = 12 * nnz;
114 | int compressBytes = compressor.memoryBytes();
115 | LOG.info(String.format("Compress %d bytes into %d bytes, compression rate: %f",
116 | originBytes, compressBytes, 1.0 * originBytes / compressBytes));
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sample/DenseVectorCompressor.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sample;
2 |
3 | import org.apache.commons.lang3.tuple.ImmutablePair;
4 | import org.apache.commons.lang3.tuple.Pair;
5 | import org.dma.sketchml.sketch.base.Quantizer;
6 | import org.dma.sketchml.sketch.base.SketchMLException;
7 | import org.dma.sketchml.sketch.base.VectorCompressor;
8 | import org.dma.sketchml.sketch.util.Maths;
9 | import org.dma.sketchml.sketch.util.Utils;
10 | import org.slf4j.Logger;
11 | import org.slf4j.LoggerFactory;
12 |
13 | import java.io.IOException;
14 | import java.io.Serializable;
15 | import java.util.Arrays;
16 | import java.util.concurrent.ExecutionException;
17 |
18 | public class DenseVectorCompressor implements VectorCompressor {
19 | private static final Logger LOG = LoggerFactory.getLogger(DenseVectorCompressor.class);
20 |
21 | private int size;
22 |
23 | private Quantizer.QuantizationType quantType;
24 | private int quantBinNum;
25 | private Quantizer quantizer;
26 |
27 | public DenseVectorCompressor(
28 | Quantizer.QuantizationType quantType, int quantBinNum) {
29 | this.quantType = quantType;
30 | this.quantBinNum = quantBinNum;
31 | }
32 |
33 | @Override
34 | public void compressDense(double[] values) {
35 | long startTime = System.currentTimeMillis();
36 | size = values.length;
37 | quantizer = Quantizer.newQuantizer(quantType, quantBinNum);
38 | quantizer.quantize(values);
39 | LOG.debug(String.format("Dense vector compression cost %d ms, %d items " +
40 | "in total", System.currentTimeMillis() - startTime, size));
41 | }
42 |
43 | @Override
44 | public void compressSparse(int[] keys, double[] values) {
45 | LOG.warn("Compressing a sparse vector with DenseVectorCompressor");
46 | if (keys.length != values.length) {
47 | throw new SketchMLException(String.format(
48 | "Lengths of key array and value array do not match: %d, %d",
49 | keys.length, values.length));
50 | }
51 | int maxKey = Maths.max(keys);
52 | double[] dense = new double[maxKey];
53 | for (int i = 0; i < keys.length; i++)
54 | dense[keys[i]] = values[i];
55 | compressDense(dense);
56 | }
57 |
58 | @Override
59 | public void parallelCompressDense(double[] values) throws InterruptedException, ExecutionException {
60 | long startTime = System.currentTimeMillis();
61 | size = values.length;
62 | quantizer = Quantizer.newQuantizer(quantType, quantBinNum);
63 | quantizer.parallelQuantize(values);
64 | LOG.debug(String.format("Dense vector parallel compression cost %d ms, %d items " +
65 | "in total", System.currentTimeMillis() - startTime, size));
66 | }
67 |
68 | @Override
69 | public void parallelCompressSparse(int[] keys, double[] values) throws InterruptedException, ExecutionException {
70 | LOG.warn("Compressing a sparse vector with DenseVectorCompressor");
71 | if (keys.length != values.length) {
72 | throw new SketchMLException(String.format(
73 | "Lengths of key array and value array do not match: %d, %d",
74 | keys.length, values.length));
75 | }
76 | int maxKey = Maths.max(keys);
77 | double[] dense = new double[maxKey];
78 | for (int i = 0; i < keys.length; i++)
79 | dense[keys[i]] = values[i];
80 | parallelCompressDense(dense);
81 | }
82 |
83 | @Override
84 | public double[] decompressDense() {
85 | double[] values = new double[size];
86 | double[] quantValues = quantizer.getValues();
87 | int[] bins = quantizer.getBins();
88 | for (int i = 0; i < size; i++)
89 | values[i] = quantValues[bins[i]];
90 | return values;
91 | }
92 |
93 | @Override
94 | public Pair decompressSparse() {
95 | double[] values = decompressDense();
96 | int[] keys = new int[values.length];
97 | Arrays.setAll(keys, i -> i);
98 | return new ImmutablePair<>(keys, values);
99 | }
100 |
101 | @Override
102 | public void timesBy(double x) {
103 | quantizer.timesBy(x);
104 | }
105 |
106 | @Override
107 | public double size() {
108 | return size;
109 | }
110 |
111 | @Override
112 | public int memoryBytes() throws IOException {
113 | int res = 12;
114 | if (quantizer != null) res += Utils.sizeof(quantizer);
115 | return res;
116 | }
117 | }
118 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sample/SparseVectorCompressor.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sample;
2 |
3 |
4 | import org.apache.commons.lang3.tuple.ImmutablePair;
5 | import org.apache.commons.lang3.tuple.Pair;
6 | import org.dma.sketchml.sketch.sketch.frequency.GroupedMinMaxSketch;
7 | import org.dma.sketchml.sketch.base.Quantizer;
8 | import org.dma.sketchml.sketch.base.SketchMLException;
9 | import org.dma.sketchml.sketch.base.VectorCompressor;
10 | import org.dma.sketchml.sketch.util.Utils;
11 | import org.slf4j.Logger;
12 | import org.slf4j.LoggerFactory;
13 |
14 | import java.io.*;
15 | import java.util.Arrays;
16 | import java.util.concurrent.ExecutionException;
17 |
18 | public class SparseVectorCompressor implements VectorCompressor {
19 | private static final Logger LOG = LoggerFactory.getLogger(SparseVectorCompressor.class);
20 |
21 | private int size;
22 |
23 | private Quantizer.QuantizationType quantType;
24 | private int quantBinNum;
25 | private double[] quantValues;
26 |
27 | private GroupedMinMaxSketch mmSketches;
28 |
29 | private int mmSketchGroupNum;
30 | private int mmSketchRowNum;
31 | private double mmSketchColRatio;
32 |
33 | public SparseVectorCompressor(
34 | Quantizer.QuantizationType quantType, int quantBinNum,
35 | int mmSketchGroupNum, int mmSketchRowNum, double mmSketchColRatio) {
36 | this.quantType = quantType;
37 | this.quantBinNum = quantBinNum;
38 | this.mmSketchGroupNum = mmSketchGroupNum;
39 | this.mmSketchRowNum = mmSketchRowNum;
40 | this.mmSketchColRatio = mmSketchColRatio;
41 | }
42 |
43 | @Override
44 | public void compressDense(double[] values) {
45 | LOG.warn("Compressing a dense vector with SparseVectorCompressor");
46 | int[] keys = new int[values.length];
47 | Arrays.setAll(keys, i -> i);
48 | compressSparse(keys, values);
49 | }
50 |
51 | @Override
52 | public void compressSparse(int[] keys, double[] values) {
53 | long startTime = System.currentTimeMillis();
54 | if (keys.length != values.length) {
55 | throw new SketchMLException(String.format(
56 | "Lengths of key array and value array do not match: %d, %d",
57 | keys.length, values.length));
58 | }
59 | size = keys.length;
60 | // 1. quantize into bin indexes
61 | Quantizer quantizer = Quantizer.newQuantizer(quantType, quantBinNum);
62 | quantizer.quantize(values);
63 | quantValues = quantizer.getValues();
64 | // 2. encode bins and keys
65 | mmSketches = new GroupedMinMaxSketch(mmSketchGroupNum, mmSketchRowNum,
66 | mmSketchColRatio, quantizer.getBinNum(), quantizer.getZeroIdx());
67 | mmSketches.create(keys, quantizer.getBins());
68 | LOG.debug(String.format("Sparse vector compression cost %d ms, %d key-value " +
69 | "pairs in total", System.currentTimeMillis() - startTime, size));
70 | }
71 |
72 | @Override
73 | public void parallelCompressDense(double[] values) throws InterruptedException, ExecutionException {
74 | LOG.warn("Compressing a dense vector with SparseVectorCompressor");
75 | int[] keys = new int[values.length];
76 | Arrays.setAll(keys, i -> i);
77 | parallelCompressSparse(keys, values);
78 | }
79 |
80 | @Override
81 | public void parallelCompressSparse(int[] keys, double[] values) throws InterruptedException, ExecutionException {
82 | long startTime = System.currentTimeMillis();
83 | if (keys.length != values.length) {
84 | throw new SketchMLException(String.format(
85 | "Lengths of key array and value array do not match: %d, %d",
86 | keys.length, values.length));
87 | }
88 | size = keys.length;
89 | // 1. quantize into bin indexes
90 | Quantizer quantizer = Quantizer.newQuantizer(quantType, quantBinNum);
91 | quantizer.parallelQuantize(values);
92 | quantValues = quantizer.getValues();
93 | // 2. encode bins and keys
94 | mmSketches = new GroupedMinMaxSketch(mmSketchGroupNum, mmSketchRowNum,
95 | mmSketchColRatio, quantizer.getBinNum(), quantizer.getZeroIdx());
96 | mmSketches.parallelCreate(keys, quantizer.getBins());
97 | LOG.debug(String.format("Sparse vector parallel compression cost %d ms, %d key-value " +
98 | "pairs in total", System.currentTimeMillis() - startTime, size));
99 | }
100 |
101 | @Override
102 | public double[] decompressDense() {
103 | Pair kv = decompressSparse();
104 | int[] keys = kv.getLeft();
105 | double[] values = kv.getRight();
106 | int maxKey = 0;
107 | for (int key : keys) {
108 | maxKey = Math.max(key, maxKey);
109 | }
110 | double[] res = new double[maxKey + 1];
111 | for (int i = 0; i < size; i++) {
112 | res[keys[i]] = values[i];
113 | }
114 | return res;
115 | }
116 |
117 | @Override
118 | public Pair decompressSparse() {
119 | Pair kb = mmSketches.restore();
120 | int[] keys = kb.getLeft();
121 | int[] bins = kb.getRight();
122 | double[] values = new double[size];
123 | for (int i = 0; i < size; i++)
124 | values[i] = quantValues[bins[i]];
125 | return new ImmutablePair<>(keys, values);
126 | }
127 |
128 | @Override
129 | public void timesBy(double x) {
130 | if (quantValues != null) {
131 | for (int i = 0; i < quantValues.length; i++)
132 | quantValues[i] *= x;
133 | }
134 | }
135 |
136 | @Override
137 | public double size() {
138 | return size;
139 | }
140 |
141 | @Override
142 | public int memoryBytes() throws IOException {
143 | int res = 28 + quantValues.length * 8;
144 | if (mmSketches != null)
145 | res += Utils.sizeof(mmSketches);
146 | return res;
147 | }
148 | }
149 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sketch/frequency/FSketchUtils.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sketch.frequency;
2 |
3 | import it.unimi.dsi.fastutil.ints.IntArrayList;
4 | import org.apache.commons.lang3.tuple.ImmutablePair;
5 | import org.apache.commons.lang3.tuple.Pair;
6 |
7 | public class FSketchUtils {
8 |
9 | public static int[] calGroupEdges(int zeroIdx, int binNum, int groupNum) {
10 | if (groupNum == 2) {
11 | return new int[]{zeroIdx, binNum};
12 | } else {
13 | int[] groupEdges = new int[groupNum];
14 | int binsPerGroup = binNum / groupNum;
15 | if (zeroIdx < binsPerGroup) {
16 | groupEdges[0] = zeroIdx;
17 | } else if ((zeroIdx % binsPerGroup) < (binsPerGroup / 2)) {
18 | groupEdges[0] = binsPerGroup + zeroIdx % binsPerGroup;
19 | } else {
20 | groupEdges[0] = zeroIdx % binsPerGroup;
21 | }
22 | for (int i = 1; i < groupNum - 1; i++) {
23 | groupEdges[i] = groupEdges[i - 1] + binsPerGroup;
24 | }
25 | groupEdges[groupNum - 1] = binNum;
26 | return groupEdges;
27 | }
28 | }
29 |
30 | public static Pair partition(int[] keys, int[] bins, int[] groupEdges) {
31 | int groupNum = groupEdges.length;
32 | IntArrayList[] keyLists = new IntArrayList[groupNum];
33 | IntArrayList[] binLists = new IntArrayList[groupNum];
34 | for (int i = 0; i < groupNum; i++) {
35 | int groupSpan = i > 0 ? (groupEdges[i] - groupEdges[i - 1]) : groupEdges[0];
36 | int estimatedGroupSize = (int) Math.ceil(1.0 * keys.length / groupNum * groupSpan);
37 | keyLists[i] = new IntArrayList(estimatedGroupSize);
38 | binLists[i] = new IntArrayList(estimatedGroupSize);
39 | }
40 | for (int i = 0; i < keys.length; i++) {
41 | int groupIdx = 0;
42 | while (groupEdges[groupIdx] <= bins[i]) groupIdx++;
43 | keyLists[groupIdx].add(keys[i]);
44 | binLists[groupIdx].add(bins[i]);
45 | }
46 | return new ImmutablePair<>(keyLists, binLists);
47 | }
48 |
49 | }
50 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sketch/frequency/GroupedMinMaxSketch.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sketch.frequency;
2 |
3 | import it.unimi.dsi.fastutil.ints.IntArrayList;
4 | import org.apache.commons.lang3.tuple.ImmutablePair;
5 | import org.apache.commons.lang3.tuple.Pair;
6 | import org.dma.sketchml.sketch.base.BinaryEncoder;
7 | import org.dma.sketchml.sketch.binary.DeltaAdaptiveEncoder;
8 | import org.dma.sketchml.sketch.common.Constants;
9 | import org.dma.sketchml.sketch.util.Sort;
10 | import org.slf4j.Logger;
11 | import org.slf4j.LoggerFactory;
12 |
13 | import java.io.IOException;
14 | import java.io.ObjectInputStream;
15 | import java.io.ObjectOutputStream;
16 | import java.io.Serializable;
17 | import java.util.ArrayList;
18 | import java.util.List;
19 | import java.util.concurrent.Callable;
20 | import java.util.concurrent.ExecutionException;
21 | import java.util.concurrent.ExecutorService;
22 | import java.util.concurrent.Future;
23 |
24 | public class GroupedMinMaxSketch implements Serializable {
25 | private static final Logger LOG = LoggerFactory.getLogger(GroupedMinMaxSketch.class);
26 |
27 | private int groupNum;
28 | private int rowNum;
29 | private double colRatio;
30 | private int binNum;
31 | private int zeroValue;
32 | private MinMaxSketch[] sketches;
33 | private BinaryEncoder[] encoders;
34 |
35 | public static final int DEFAULT_MINMAXSKETCH_GROUP_NUM = 8;
36 | public static final double DEFAULT_MINMAXSKETCH_COL_RATIO = 0.3;
37 |
38 | public GroupedMinMaxSketch(int groupNum, int rowNum, double colRatio, int binNum, int zeroValue) {
39 | this.groupNum = groupNum;
40 | this.rowNum = rowNum;
41 | this.colRatio = colRatio;
42 | this.binNum = binNum;
43 | this.zeroValue = zeroValue;
44 | }
45 |
46 | public GroupedMinMaxSketch(int binNum, int zeroValue) {
47 | this(DEFAULT_MINMAXSKETCH_GROUP_NUM, MinMaxSketch.DEFAULT_MINMAXSKETCH_ROW_NUM,
48 | DEFAULT_MINMAXSKETCH_COL_RATIO, binNum, zeroValue);
49 | }
50 |
51 | public void create(int[] keys, int[] bins) {
52 | long startTime = System.currentTimeMillis();
53 | // 1. divide bins into several groups
54 | int[] groupEdges = FSketchUtils.calGroupEdges(zeroValue, binNum, groupNum);
55 | sketches = new MinMaxSketch[groupNum];
56 | encoders = new BinaryEncoder[groupNum];
57 | Pair partKBLists =
58 | FSketchUtils.partition(keys, bins, groupEdges);
59 | // 2. encode bins and keys
60 | for (int i = 0; i < groupNum; i++) {
61 | IntArrayList keyList = partKBLists.getLeft()[i];
62 | IntArrayList binList = partKBLists.getRight()[i];
63 | Pair group = compOneGroup(
64 | keyList, binList, groupEdges, i);
65 | sketches[i] = group.getLeft();
66 | encoders[i] = group.getRight();
67 | }
68 | LOG.debug(String.format("Create grouped MinMaxSketch cost %d ms",
69 | System.currentTimeMillis() - startTime));
70 | }
71 |
72 | public void parallelCreate(int[] keys, int[] bins) throws InterruptedException, ExecutionException {
73 | long startTime = System.currentTimeMillis();
74 | // 1. divide bins into several groups
75 | int[] groupEdges = FSketchUtils.calGroupEdges(zeroValue, binNum, groupNum);
76 | sketches = new MinMaxSketch[groupNum];
77 | encoders = new BinaryEncoder[groupNum];
78 | Pair partKBLists =
79 | FSketchUtils.partition(keys, bins, groupEdges);
80 | // 2. each thread encode one group of bins and keys
81 | ExecutorService threadPool = Constants.Parallel.getThreadPool();
82 | Future>[] futures = new Future[groupNum];
83 | for (int i = 0; i < groupNum; i++) {
84 | int threadId = i;
85 | futures[threadId] = threadPool.submit(new Callable>() {
86 | @Override
87 | public Pair call() throws Exception {
88 | IntArrayList keyList = partKBLists.getLeft()[threadId];
89 | IntArrayList binList = partKBLists.getRight()[threadId];
90 | return compOneGroup(keyList, binList, groupEdges, threadId);
91 | }
92 | });
93 | }
94 | for (int i = 0; i < groupNum; i++) {
95 | Pair res = futures[i].get();
96 | sketches[i] = res.getLeft();
97 | encoders[i] = res.getRight();
98 | }
99 | LOG.debug(String.format("Create grouped MinMaxSketch cost %d ms",
100 | System.currentTimeMillis() - startTime));
101 | }
102 |
103 | private Pair compOneGroup(IntArrayList keyList, IntArrayList binList,
104 | int[] groupEdges, int groupId) {
105 | int groupSize = keyList.size();
106 | if (groupSize == 0) {
107 | LOG.warn(String.format("Group[%d] is empty, group edges: [%d, %d)", groupId,
108 | groupId == 0 ? 0 : groupEdges[groupId - 1], groupEdges[groupId]));
109 | return new ImmutablePair<>(null, null);
110 | }
111 | // encode bins
112 | int colNum = (int) Math.ceil(groupSize * colRatio);
113 | MinMaxSketch sketch = new MinMaxSketch(rowNum, colNum, zeroValue);
114 | for (int j = 0; j < groupSize; j++) {
115 | sketch.insert(keyList.getInt(j), binList.getInt(j));
116 | }
117 | // encode keys
118 | BinaryEncoder encoder = new DeltaAdaptiveEncoder();
119 | encoder.encode(keyList.toIntArray(null));
120 | return new ImmutablePair<>(sketch, encoder);
121 | }
122 |
123 | public Pair restore() {
124 | int size = 0;
125 | // decode each group
126 | // in case there are empty groups
127 | List keysToMerge = new ArrayList<>(groupNum);
128 | List binsToMerge = new ArrayList<>(groupNum);
129 | for (int i = 0; i < groupNum; i++) {
130 | if (encoders[i] != null && sketches[i] != null) {
131 | int[] groupKeys = encoders[i].decode();
132 | int[] groupBins = new int[groupKeys.length];
133 | for (int j = 0; j < groupKeys.length; j++)
134 | groupBins[j] = sketches[i].query(groupKeys[j]);
135 | keysToMerge.add(groupKeys);
136 | binsToMerge.add(groupBins);
137 | size += groupKeys.length;
138 | }
139 | }
140 | // merge
141 | int[] keys = new int[size];
142 | int[] bins = new int[size];
143 | Sort.merge(keysToMerge.toArray(new int[keysToMerge.size()][]),
144 | binsToMerge.toArray(new int[binsToMerge.size()][]), keys, bins);
145 | return new ImmutablePair<>(keys, bins);
146 | }
147 |
148 | private void writeObject(ObjectOutputStream oos) throws IOException {
149 | oos.writeInt(groupNum);
150 | oos.writeInt(rowNum);
151 | oos.writeDouble(colRatio);
152 | oos.writeInt(binNum);
153 | oos.writeInt(zeroValue);
154 | for (MinMaxSketch sketch : sketches)
155 | oos.writeObject(sketch);
156 | for (BinaryEncoder encoder : encoders)
157 | oos.writeObject(encoder);
158 | }
159 |
160 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
161 | groupNum = ois.readInt();
162 | rowNum = ois.readInt();
163 | colRatio = ois.readDouble();
164 | binNum = ois.readInt();
165 | zeroValue = ois.readInt();
166 | sketches = new MinMaxSketch[groupNum];
167 | for (int i = 0; i < groupNum; i++)
168 | sketches[i] = (MinMaxSketch) ois.readObject();
169 | encoders = new BinaryEncoder[groupNum];
170 | for (int i = 0; i < groupNum; i++)
171 | encoders[i] = (BinaryEncoder) ois.readObject();
172 | }
173 |
174 | }
175 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sketch/frequency/MinMaxSketch.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sketch.frequency;
2 |
3 | import org.dma.sketchml.sketch.base.BinaryEncoder;
4 | import org.dma.sketchml.sketch.base.Int2IntHash;
5 | import org.dma.sketchml.sketch.binary.HuffmanEncoder;
6 | import org.dma.sketchml.sketch.hash.HashFactory;
7 | import org.slf4j.Logger;
8 | import org.slf4j.LoggerFactory;
9 |
10 | import java.io.IOException;
11 | import java.io.ObjectInputStream;
12 | import java.io.ObjectOutputStream;
13 | import java.io.Serializable;
14 | import java.util.Arrays;
15 |
16 | public class MinMaxSketch implements Serializable {
17 | private static final Logger LOG = LoggerFactory.getLogger(MinMaxSketch.class);
18 |
19 | protected int rowNum;
20 | protected int colNum;
21 | protected int[] table;
22 | protected int zeroValue;
23 | protected Int2IntHash[] hashes;
24 |
25 | public static final int DEFAULT_MINMAXSKETCH_ROW_NUM = 2;
26 |
27 | public MinMaxSketch(int rowNum, int colNum, int zeroValue) {
28 | this.rowNum = rowNum;
29 | this.colNum = colNum;
30 | this.table = new int[rowNum * colNum];
31 | this.zeroValue = zeroValue;
32 | int maxValue = compare(Integer.MIN_VALUE, Integer.MAX_VALUE) <= 0
33 | ? Integer.MIN_VALUE : Integer.MAX_VALUE;
34 | Arrays.fill(table, maxValue);
35 | this.hashes = HashFactory.getRandomInt2IntHashes(rowNum, colNum);
36 | }
37 |
38 | public MinMaxSketch(int colNum, int zeroValue) {
39 | this(DEFAULT_MINMAXSKETCH_ROW_NUM, colNum, zeroValue);
40 | }
41 |
42 | /**
43 | * Min: insert the minimal (closest to `zeroValue`) value
44 | *
45 | * @param key
46 | * @param value
47 | */
48 | public void insert(int key, int value) {
49 | for (int i = 0; i < rowNum; i++) {
50 | int code = hashes[i].hash(key);
51 | int index = i * colNum + code;
52 | if (compare(value, table[index]) < 0)
53 | table[index] = value;
54 | }
55 | }
56 |
57 |
58 | /**
59 | * Max: return the maximal (furthest to `zeroValue`) value
60 | *
61 | * @param key
62 | * @return
63 | */
64 | public int query(int key) {
65 | int res = zeroValue;
66 | for (int i = 0; i < rowNum; i++) {
67 | int code = hashes[i].hash(key);
68 | int index = i * colNum + code;
69 | if (compare(table[index], res) > 0)
70 | res = table[index];
71 | }
72 | return res;
73 | }
74 |
75 | /**
76 | * Compare two numbers' distances w.r.t. `zeroValue`
77 | *
78 | * @param v1
79 | * @param v2
80 | * @return
81 | */
82 | private int compare(int v1, int v2) {
83 | int d1 = Math.abs(v1 - zeroValue);
84 | int d2 = Math.abs(v2 - zeroValue);
85 | return d1 - d2;
86 | }
87 |
88 | private void writeObject(ObjectOutputStream oos) throws IOException {
89 | oos.writeInt(rowNum);
90 | oos.writeInt(colNum);
91 | oos.writeInt(zeroValue);
92 | for (Int2IntHash hash : hashes)
93 | oos.writeObject(hash);
94 | BinaryEncoder huffman = new HuffmanEncoder();
95 | huffman.encode(table);
96 | oos.writeObject(huffman);
97 | }
98 |
99 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
100 | rowNum = ois.readInt();
101 | colNum = ois.readInt();
102 | zeroValue = ois.readInt();
103 | hashes = new Int2IntHash[rowNum];
104 | for (int i = 0; i < rowNum; i++)
105 | hashes[i] = (Int2IntHash) ois.readObject();
106 | BinaryEncoder encoder = (BinaryEncoder) ois.readObject();
107 | table = encoder.decode();
108 | }
109 |
110 | public int getRowNum() {
111 | return rowNum;
112 | }
113 |
114 | public int getColNum() {
115 | return colNum;
116 | }
117 |
118 | public int getZeroValue() {
119 | return zeroValue;
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sketch/quantile/HeapQuantileSketch.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sketch.quantile;
2 |
3 | import org.dma.sketchml.sketch.base.QuantileSketch;
4 |
5 | import java.util.Arrays;
6 |
7 | /**
8 | * Implementation of quantile sketch on the Java heap
9 | * bashed on `DataSketches` of Yahoo!
10 | */
11 | public class HeapQuantileSketch extends QuantileSketch {
12 | private int k; // parameter that controls space usage
13 | public static final int DEFAULT_K = 128;
14 |
15 | /**
16 | * This single array contains the base buffer plus all levels some of which may not be used.
17 | * A level is of size K and is either full and sorted, or not used. A "not used" buffer may have
18 | * garbage. Whether a level buffer used or not is indicated by the bitPattern_.
19 | * The base buffer has length 2*K but might not be full and isn't necessarily sorted.
20 | * The base buffer precedes the level buffers.
21 | *
22 | * The levels arrays require quite a bit of explanation, which we defer until later.
23 | */
24 | private double[] combinedBuffer;
25 | private int combinedBufferCapacity; // equals combinedBuffer.length
26 | private int baseBufferCount; // #samples currently in base buffer (= n % (2*k))
27 | private long bitPattern; // active levels expressed as a bit pattern (= n / (2*k))
28 | private static final int MIN_BASE_BUF_SIZE = 4;
29 |
30 | /**
31 | * data structure for answering quantile queries
32 | */
33 | private double[] samplesArr; // array of size samples
34 | private long[] weightsArr; // array of cut points
35 |
36 | public HeapQuantileSketch(int k, long estimateN) {
37 | super(estimateN);
38 | QSketchUtils.checkK(k);
39 | this.k = k;
40 | reset();
41 | }
42 |
43 | public HeapQuantileSketch() {
44 | this(DEFAULT_K, -1L);
45 | }
46 |
47 | public HeapQuantileSketch(int k) {
48 | this(k, -1L);
49 | }
50 |
51 | public HeapQuantileSketch(long estimateN) {
52 | this(DEFAULT_K, estimateN);
53 | }
54 |
55 | @Override
56 | public void reset() {
57 | n = 0;
58 | if (estimateN < 0)
59 | combinedBufferCapacity = Math.min(MIN_BASE_BUF_SIZE, k * 2);
60 | else if (estimateN < k * 2)
61 | combinedBufferCapacity = k * 4;
62 | else
63 | combinedBufferCapacity = QSketchUtils.needBufferCapacity(k, estimateN);
64 | combinedBuffer = new double[combinedBufferCapacity];
65 | baseBufferCount = 0;
66 | bitPattern = 0L;
67 | minValue = Double.MAX_VALUE;
68 | maxValue = Double.MIN_VALUE;
69 | samplesArr = null;
70 | weightsArr = null;
71 | }
72 |
73 | @Override
74 | public void update(double value) {
75 | if (Double.isNaN(value))
76 | throw new QuantileSketchException("Encounter NaN value");
77 | maxValue = Math.max(maxValue, value);
78 | minValue = Math.min(minValue, value);
79 |
80 | if (baseBufferCount + 1 > combinedBufferCapacity)
81 | ensureBaseBuffer();
82 | combinedBuffer[baseBufferCount++] = value;
83 | n++;
84 | if (baseBufferCount == (k * 2))
85 | fullBaseBufferPropagation();
86 | }
87 |
88 | private void ensureBaseBuffer() {
89 | final double[] baseBuffer = combinedBuffer;
90 | int oldSize = combinedBufferCapacity;
91 | if (oldSize >= k * 2)
92 | throw new QuantileSketchException("Buffer over size");
93 | int newSize = Math.max(Math.min(k * 2, oldSize * 2), 1);
94 | combinedBufferCapacity = newSize;
95 | combinedBuffer = Arrays.copyOf(baseBuffer, newSize);
96 | }
97 |
98 | private void ensureLevels(long newN) {
99 | int numLevels = 1 + (63 - Long.numberOfLeadingZeros(newN / (k * 2)));
100 | int spaceNeeded = k * (numLevels + 2);
101 | if (spaceNeeded <= combinedBufferCapacity) return;
102 | final double[] baseBuffer = combinedBuffer;
103 | combinedBuffer = Arrays.copyOf(baseBuffer, spaceNeeded);
104 | combinedBufferCapacity = spaceNeeded;
105 | }
106 |
107 | private void fullBaseBufferPropagation() {
108 | ensureLevels(n);
109 | final double[] baseBuffer = combinedBuffer;
110 | Arrays.sort(baseBuffer, 0, baseBufferCount);
111 | inPlacePropagationUpdate(0, baseBuffer, 0);
112 | baseBufferCount = 0;
113 | QSketchUtils.checkBitPattern(bitPattern, n, k);
114 | }
115 |
116 | private void inPlacePropagationUpdate(int beginLevel, final double[] buf, int bufBeginPos) {
117 | final double[] levelsArr = combinedBuffer;
118 | int endLevel = beginLevel;
119 | long tmp = bitPattern >>> beginLevel;
120 | while ((tmp & 1) != 0) { tmp >>>= 1; endLevel++; }
121 | QSketchUtils.compactBuffer(buf, bufBeginPos, levelsArr, (endLevel + 2) * k, k);
122 | QSketchUtils.levelwisePropagation(bitPattern, k, beginLevel, endLevel, buf, bufBeginPos, levelsArr);
123 | bitPattern += 1L << beginLevel;
124 | }
125 |
126 | public void makeSummary() {
127 | int baseBufferItems = (int)(n % (k * 2));
128 | QSketchUtils.checkBitPattern(bitPattern, n, k);
129 | int validLevels = Long.bitCount(bitPattern);
130 | int numSamples = baseBufferItems + validLevels * k;
131 | samplesArr = new double[numSamples];
132 | weightsArr = new long[numSamples + 1];
133 |
134 | copyBuf2Arr(numSamples);
135 | QSketchUtils.blockyMergeSort(samplesArr, weightsArr, numSamples, k);
136 |
137 | long cnt = 0L;
138 | for (int i = 0; i <= numSamples; i++) {
139 | long newCnt = cnt + weightsArr[i];
140 | weightsArr[i] = cnt;
141 | cnt = newCnt;
142 | }
143 | }
144 |
145 | private void copyBuf2Arr(int numSamples) {
146 | long weight = 1L;
147 | int cur = 0;
148 | long bp = bitPattern;
149 |
150 | // copy the highest levels
151 | for (int level = 0; bp != 0; level++, bp >>>= 1) {
152 | weight *= 2;
153 | if ((bp & 1) != 0) {
154 | int offset = k * (level + 2);
155 | for (int i = 0; i < k; i++) {
156 | samplesArr[cur] = combinedBuffer[i + offset];
157 | weightsArr[cur] = weight;
158 | cur++;
159 | }
160 | }
161 | }
162 |
163 | // copy baseBuffer
164 | int startBlk = cur;
165 | for (int i = 0; i < baseBufferCount; i++) {
166 | samplesArr[cur] = combinedBuffer[i];
167 | weightsArr[cur] = 1L;
168 | cur++;
169 | }
170 | weightsArr[cur] = 0L;
171 | if (cur != numSamples)
172 | throw new QuantileSketchException("Missing items when copying buffer to array");
173 | Arrays.sort(samplesArr, startBlk, cur);
174 | }
175 |
176 | @Override
177 | public void merge(QuantileSketch other) {
178 | if (other instanceof HeapQuantileSketch) {
179 | merge((HeapQuantileSketch) other);
180 | } else {
181 | throw new QuantileSketchException("Cannot merge different " +
182 | "kinds of quantile sketches");
183 | }
184 | }
185 |
186 | public void merge(HeapQuantileSketch other) {
187 | if (other == null || other.isEmpty()) return;
188 | if (other.k != this.k)
189 | throw new QuantileSketchException("Merge sketches with different k");
190 | QSketchUtils.checkBitPattern(other.bitPattern, other.n, other.k);
191 | if (this.isEmpty()) {
192 | this.copy(other);
193 | return;
194 | }
195 |
196 | // merge two non-empty quantile sketches
197 | long totalN = this.n + other.n;
198 | for (int i = 0; i < other.baseBufferCount; i++) {
199 | update(other.combinedBuffer[i]);
200 | }
201 | ensureLevels(totalN);
202 |
203 | final double[] auxBuf = new double[k * 2];
204 | long bp = other.bitPattern;
205 | for (int level = 0; bp != 0L; level++, bp >>>= 1) {
206 | if ((bp & 1L) != 0L) {
207 | inPlacePropagationMerge(level, other.combinedBuffer,
208 | k * (level + 2), auxBuf, 0);
209 | }
210 | }
211 |
212 | this.n = totalN;
213 | this.maxValue = Math.max(this.maxValue, other.maxValue);
214 | this.minValue = Math.min(this.minValue, other.minValue);
215 | this.samplesArr = null;
216 | this.weightsArr = null;
217 | }
218 |
219 | private void inPlacePropagationMerge(int beginLevel, final double[] buf, int bufStart,
220 | final double[] auxBuf, int auxBufStart) {
221 | final double[] levelsArr = combinedBuffer;
222 | int endLevel = beginLevel;
223 | long tmp = bitPattern >>> beginLevel;
224 | while ((tmp & 1) != 0) { tmp >>>= 1; endLevel++; }
225 | System.arraycopy(buf, bufStart, levelsArr, k * (endLevel + 2), k);
226 | QSketchUtils.levelwisePropagation(bitPattern, k, beginLevel, endLevel, auxBuf, auxBufStart, levelsArr);
227 | bitPattern += 1L << beginLevel;
228 | }
229 |
230 | public void copy(HeapQuantileSketch other) {
231 | this.n = other.n;
232 | this.minValue = other.minValue;
233 | this.maxValue = other.maxValue;
234 | if (this.estimateN == -1) {
235 | this.combinedBufferCapacity = other.combinedBufferCapacity;
236 | this.combinedBuffer = other.combinedBuffer.clone();
237 | } else if (other.combinedBufferCapacity > this.combinedBufferCapacity) {
238 | this.combinedBufferCapacity = other.combinedBufferCapacity;
239 | this.combinedBuffer = other.combinedBuffer.clone();
240 | } else {
241 | System.arraycopy(other.combinedBuffer, 0,
242 | this.combinedBuffer, 0, other.combinedBufferCapacity);
243 | }
244 | this.baseBufferCount = other.baseBufferCount;
245 | this.bitPattern = other.bitPattern;
246 | if (other.samplesArr != null && other.weightsArr != null) {
247 | this.samplesArr = other.samplesArr.clone();
248 | this.weightsArr = other.weightsArr.clone();
249 | }
250 | }
251 |
252 | @Override
253 | public double getQuantile(double fraction) {
254 | QSketchUtils.checkFraction(fraction);
255 | if (samplesArr == null || weightsArr == null)
256 | makeSummary();
257 |
258 | if (samplesArr.length == 0)
259 | return Double.NaN;
260 |
261 | if (fraction == 0.0)
262 | return minValue;
263 | else if (fraction == 1.0)
264 | return maxValue;
265 | else
266 | return getQuantileFromArr(fraction);
267 | }
268 |
269 | @Override
270 | public double[] getQuantiles(double[] fractions) {
271 | QSketchUtils.checkFractions(fractions);
272 | if (samplesArr == null || weightsArr == null)
273 | makeSummary();
274 |
275 | double[] res = new double[fractions.length];
276 | if (samplesArr.length == 0) {
277 | Arrays.fill(res, Double.NaN);
278 | return res;
279 | }
280 |
281 | for (int i = 0; i < fractions.length; i++) {
282 | if (fractions[i] == 0.0)
283 | res[i] = minValue;
284 | else if (fractions[i] == 1.0)
285 | res[i] = maxValue;
286 | else
287 | res[i] = getQuantileFromArr(fractions[i]);
288 | }
289 | return res;
290 | }
291 |
292 | @Override
293 | public double[] getQuantiles(int evenPartition) {
294 | QSketchUtils.checkEvenPartiotion(evenPartition);
295 | if (samplesArr == null || weightsArr == null)
296 | makeSummary();
297 |
298 | double[] splits = new double[evenPartition - 1];
299 | if (samplesArr.length == 0) {
300 | Arrays.fill(splits, Double.NaN);
301 | return splits;
302 | }
303 |
304 | int index = 0;
305 | double curFrac = 1.0 / evenPartition;
306 | double step = 1.0 / evenPartition;
307 | for (int i = 0; i + 1 < evenPartition; i++) {
308 | long rank = (long)(n * curFrac);
309 | rank = Math.min(rank, n - 1);
310 | int left = index, right = weightsArr.length - 1;
311 | while (left + 1 < right) {
312 | int mid = left + ((right - left) >> 1);
313 | if (weightsArr[mid] <= rank)
314 | left = mid;
315 | else
316 | right = mid;
317 | }
318 | splits[i] = samplesArr[left];
319 | index = left;
320 | curFrac += step;
321 | }
322 | return splits;
323 | }
324 |
325 | private double getQuantileFromArr(double fraction) {
326 | long rank = (long)(n * fraction);
327 | if (rank == n) n--;
328 | int left = 0, right = weightsArr.length - 1;
329 | while (left + 1 < right) {
330 | int mid = left + ((right - left) >> 1);
331 | if (weightsArr[mid] <= rank)
332 | left = mid;
333 | else
334 | right = mid;
335 | }
336 | return samplesArr[left];
337 | }
338 |
339 | public int getK() {
340 | return k;
341 | }
342 |
343 |
344 | }
345 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sketch/quantile/QSketchUtils.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sketch.quantile;
2 |
3 | import org.dma.sketchml.sketch.util.Maths;
4 |
5 | import java.util.Arrays;
6 | import java.util.Random;
7 |
8 | public class QSketchUtils {
9 | private static final Random rand = new Random();
10 |
11 | protected static void checkK(int k) {
12 | if (k < 1)
13 | throw new QuantileSketchException("Invalid value of k: k should be positive");
14 | else if (k >= 65535)
15 | throw new QuantileSketchException("Invalid value of k: k should not be larger than 65536");
16 | else if (!Maths.isPowerOf2(k))
17 | throw new QuantileSketchException("Invalid value of k: k should be power of 2");
18 | }
19 |
20 | protected static int needBufferCapacity(int k, long estimateN) {
21 | int numLevels = 1 + (63 - Long.numberOfLeadingZeros(estimateN / (k * 2)));
22 | return k * (numLevels + 2);
23 | }
24 |
25 | protected static void checkBitPattern(long bitPattern, long n, int k) {
26 | if (bitPattern != n / (k * 2))
27 | throw new QuantileSketchException("Bit Pattern not match");
28 | }
29 |
30 | protected static void checkFraction(double fraction) {
31 | if (fraction < 0.0 || fraction > 1.0)
32 | throw new QuantileSketchException("Fraction should be in range [0.0, 1.0]");
33 | }
34 |
35 | protected static void checkFractions(double[] fractions) {
36 | for (double f: fractions)
37 | checkFraction(f);
38 | }
39 |
40 | protected static void checkEvenPartiotion(int evenPartition) {
41 | if (evenPartition <= 1)
42 | throw new QuantileSketchException("Invalid partition number: " + evenPartition);
43 | }
44 |
45 | protected static void compactBuffer(final double[] srcBuf, int srcOffset,
46 | final double[] dstBuf, int dstOffset, int dstSize) {
47 | int offset = rand.nextBoolean() ? 1 : 0;
48 | int bound = dstOffset + dstSize;
49 | for (int i = srcOffset + offset, j = dstOffset; j < bound; i += 2, j++)
50 | dstBuf[j] = srcBuf[i];
51 | }
52 |
53 | protected static void mergeArrays(final double[] src1, int srcOffset1,
54 | final double[] src2, int srcOffset2,
55 | final double[] dst, int dstOffset, int size) {
56 | int bound1 = srcOffset1 + size;
57 | int bound2 = srcOffset2 + size;
58 | int i1 = srcOffset1, i2 = srcOffset2, i3 = dstOffset;
59 | while (i1 < bound1 && i2 < bound2) {
60 | if (src1[i1] < src2[i2])
61 | dst[i3++] = src1[i1++];
62 | else
63 | dst[i3++] = src2[i2++];
64 | }
65 | if (i1 < bound1)
66 | System.arraycopy(src1, i1, dst, i3, bound1 - i1);
67 | else
68 | System.arraycopy(src2, i2, dst, i3, bound2 - i2);
69 | }
70 |
71 | protected static void levelwisePropagation(long bitPattern, int k,
72 | int beginLevel, int endLevel,
73 | final double[] buf, int bufBeginPos,
74 | final double[] levelsArr) {
75 | for (int level = beginLevel; level < endLevel; level++) {
76 | if ((bitPattern & (1L << level)) == 0)
77 | throw new QuantileSketchException("Encounter empty level: " + level);
78 | QSketchUtils.mergeArrays(levelsArr, k * (level + 2),
79 | levelsArr, k * (endLevel + 2), buf, bufBeginPos, k);
80 | QSketchUtils.compactBuffer(buf, bufBeginPos, levelsArr, k * (endLevel + 2), k);
81 | }
82 | }
83 |
84 | protected static void blockyMergeSort(final double[] keys, final long[] values,
85 | int length, int blkSize) {
86 | if (blkSize <= 0 || length <= blkSize) return;
87 | int numBlks = (length + (blkSize - 1)) / blkSize;
88 | final double[] tmpKeys = Arrays.copyOf(keys, length);
89 | final long[] tmpValues = Arrays.copyOf(values, length);
90 | recursiveBlockyMergeSort(tmpKeys, tmpValues, keys, values, 0, numBlks, blkSize, length);
91 | }
92 |
93 | protected static void recursiveBlockyMergeSort(final double[] kSrc, final long[] vSrc,
94 | final double[] kDst, final long[] vDst,
95 | int blkStart, int blkLen, int blkSize, int arrLimit) {
96 | if (blkLen == 1) return;
97 | int blkLen1 = blkLen >> 1;
98 | int blkLen2 = blkLen - blkLen1;
99 | int blkStart1 = blkStart;
100 | int blkStart2 = blkStart + blkLen1;
101 |
102 | recursiveBlockyMergeSort(kDst, vDst, kSrc, vSrc, blkStart1, blkLen1, blkSize, arrLimit);
103 | recursiveBlockyMergeSort(kDst, vDst, kSrc, vSrc, blkStart2, blkLen2, blkSize, arrLimit);
104 |
105 | int arrStart1 = blkStart1 * blkSize;
106 | int arrStart2 = blkStart2 * blkSize;
107 | int arrLen1 = blkLen1 * blkSize;
108 | int arrLen2 = blkLen2 * blkSize;
109 | if (arrStart2 + arrLen2 > arrLimit)
110 | arrLen2 = arrLimit - arrStart2;
111 |
112 | blockyMerge(kSrc, vSrc, arrStart1, arrLen1, arrStart2, arrLen2, kDst, vDst, arrStart1);
113 | }
114 |
115 | protected static void blockyMerge(final double []kSrc, final long []vSrc,
116 | int arrStart1, int arrLen1,
117 | int arrStart2, int arrLen2,
118 | final double []kDst, final long []vDst, int arrStart3){
119 | int arrEnd1 = arrStart1 + arrLen1;
120 | int arrEnd2 = arrStart2 + arrLen2;
121 | int i1 = arrStart1, i2 = arrStart2, i3 = arrStart3;
122 | while (i1 < arrEnd1 && i2 < arrEnd2) {
123 | if (kSrc[i1] <= kSrc[i2]){
124 | kDst[i3] = kSrc[i1];
125 | vDst[i3] = vSrc[i1];
126 | ++i1; ++i3;
127 | } else {
128 | kDst[i3] = kSrc[i2];
129 | vDst[i3] = vSrc[i2];
130 | ++i2; ++i3;
131 | }
132 | }
133 |
134 | if (i1 < arrEnd1) {
135 | System.arraycopy(kSrc, i1, kDst, i3, arrEnd1 - i1);
136 | System.arraycopy(vSrc, i1, vDst, i3, arrEnd1 - i1);
137 | } else {
138 | System.arraycopy(kSrc, i2, kDst, i3, arrEnd2 - i2);
139 | System.arraycopy(vSrc, i2, vDst, i3, arrEnd2 - i2);
140 | }
141 | }
142 | }
143 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/sketch/quantile/QuantileSketchException.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.sketch.quantile;
2 |
3 | import org.dma.sketchml.sketch.base.SketchMLException;
4 |
5 | public class QuantileSketchException extends SketchMLException {
6 | public QuantileSketchException(String message) {
7 | super(message);
8 | }
9 |
10 | public QuantileSketchException(Throwable cause) {
11 | super(cause);
12 | }
13 |
14 | public QuantileSketchException(String message, Throwable cause) {
15 | super(message, cause);
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/util/Maths.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.util;
2 |
3 | import org.dma.sketchml.sketch.base.SketchMLException;
4 |
5 | import java.util.Random;
6 |
7 | public class Maths {
8 | public static boolean isPowerOf2(int k) {
9 | for (int i = 1; i < 65536; i <<= 1) {
10 | if (k == i)
11 | return true;
12 | }
13 | return false;
14 | }
15 |
16 | public static int log2nlz(int k) {
17 | if (k <= 0)
18 | throw new SketchMLException("Log for " + k);
19 | else
20 | return 31 - Integer.numberOfLeadingZeros(k);
21 | }
22 |
23 | public static int max(int[] array) {
24 | int res = array[0];
25 | for (int i = 1; i < array.length; i++)
26 | res = Math.max(res, array[i]);
27 | return res;
28 | }
29 |
30 | public static int argmax(int[] array) {
31 | int max = array[0], res = 0;
32 | for (int i = 1; i < array.length; i++) {
33 | if (array[i] > max) {
34 | max = array[i];
35 | res = i;
36 | }
37 | }
38 | return res;
39 | }
40 |
41 | public static void shuffle(int[] array) {
42 | Random random = new Random();
43 | for (int i = array.length - 1; i > 0; i--) {
44 | int index = random.nextInt(i + 1);
45 | int t = array[index];
46 | array[index] = array[i];
47 | array[i] = t;
48 | }
49 | }
50 |
51 | public static double[] unique(double[] sorted) {
52 | int size = sorted.length, cnt = 1;
53 | for (int i = 1; i < size; i++)
54 | if (sorted[i] != sorted[i - 1])
55 | cnt++;
56 | if (cnt != size) {
57 | double[] res = new double[cnt];
58 | res[0] = sorted[0];
59 | int index = 1;
60 | for (int i = 1; i < size; i++)
61 | if (sorted[i] != sorted[i - 1])
62 | res[index++] = sorted[i];
63 | return res;
64 | } else {
65 | return sorted;
66 | }
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/util/Sort.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.util;
2 |
3 | import it.unimi.dsi.fastutil.doubles.DoubleArrayPriorityQueue;
4 | import it.unimi.dsi.fastutil.doubles.DoubleComparator;
5 | import it.unimi.dsi.fastutil.doubles.DoublePriorityQueue;
6 | import it.unimi.dsi.fastutil.ints.IntComparator;
7 | import org.dma.sketchml.sketch.base.SketchMLException;
8 |
9 | /**
10 | * Quick sort utils
11 | */
12 | public class Sort {
13 | public static int quickSelect(int[] array, int k, int low, int high) {
14 | if (k > 0 && k <= high - low + 1) {
15 | int pivot = array[high];
16 | int ii = low;
17 | for (int jj = low; jj < high; jj++) {
18 | if (array[jj] <= pivot) {
19 | swap(array, ii++, jj);
20 | }
21 | }
22 | swap(array, ii, high);
23 |
24 | if (ii - low == k - 1) {
25 | return array[ii];
26 | } else if (ii - low > k - 1) {
27 | return quickSelect(array, k, low, ii - 1);
28 | } else {
29 | return quickSelect(array, k - ii + low - 1, ii + 1, high);
30 | }
31 | }
32 | throw new SketchMLException("k is more than number of elements in array");
33 | }
34 |
35 | public static double selectKthLargest(double[] array, int k) {
36 | return selectKthLargest(array, k, new DoubleArrayPriorityQueue(k));
37 | }
38 |
39 | public static double selectKthLargest(double[] array, int k, DoubleComparator comp) {
40 | return selectKthLargest(array, k, new DoubleArrayPriorityQueue(k, comp));
41 | }
42 |
43 | private static double selectKthLargest(double[] array, int k, DoublePriorityQueue queue) {
44 | if (k > array.length)
45 | throw new SketchMLException("k is more than number of elements in array");
46 |
47 | int i = 0;
48 | while (i < k)
49 | queue.enqueue(array[i++]);
50 | for (; i < array.length; i++) {
51 | double top = queue.firstDouble();
52 | if (array[i] > top) {
53 | queue.dequeueDouble();
54 | queue.enqueue(array[i]);
55 | }
56 | }
57 | return queue.firstDouble();
58 | }
59 |
60 | public static void quickSort(int[] array, double[] values, int low, int high) {
61 | if (low < high) {
62 | int tmp = array[low];
63 | double tmpValue = values[low];
64 | int ii = low, jj = high;
65 | while (ii < jj) {
66 | while (ii < jj && array[jj] >= tmp) {
67 | jj--;
68 | }
69 |
70 | array[ii] = array[jj];
71 | values[ii] = values[jj];
72 |
73 | while (ii < jj && array[ii] <= tmp) {
74 | ii++;
75 | }
76 |
77 | array[jj] = array[ii];
78 | values[jj] = values[ii];
79 | }
80 | array[ii] = tmp;
81 | values[ii] = tmpValue;
82 |
83 | quickSort(array, values, low, ii - 1);
84 | quickSort(array, values, ii + 1, high);
85 | }
86 | }
87 |
88 | public static void quickSort(long[] array, double[] values, int low, int high) {
89 | if (low < high) {
90 | long tmp = array[low];
91 | double tmpValue = values[low];
92 | int ii = low, jj = high;
93 | while (ii < jj) {
94 | while (ii < jj && array[jj] >= tmp) {
95 | jj--;
96 | }
97 |
98 | array[ii] = array[jj];
99 | values[ii] = values[jj];
100 |
101 | while (ii < jj && array[ii] <= tmp) {
102 | ii++;
103 | }
104 |
105 | array[jj] = array[ii];
106 | values[jj] = values[ii];
107 | }
108 | array[ii] = tmp;
109 | values[ii] = tmpValue;
110 |
111 | quickSort(array, values, low, ii - 1);
112 | quickSort(array, values, ii + 1, high);
113 | }
114 | }
115 |
116 | public static void quickSort(long[] array, int low, int high) {
117 | if (low < high) {
118 | long tmp = array[low];
119 | int ii = low, jj = high;
120 | while (ii < jj) {
121 | while (ii < jj && array[jj] >= tmp) {
122 | jj--;
123 | }
124 |
125 | array[ii] = array[jj];
126 |
127 | while (ii < jj && array[ii] <= tmp) {
128 | ii++;
129 | }
130 |
131 | array[jj] = array[ii];
132 | }
133 | array[ii] = tmp;
134 |
135 | quickSort(array, low, ii - 1);
136 | quickSort(array, ii + 1, high);
137 | }
138 | }
139 |
140 | public static void quickSort(int[] array, int[] values, int low, int high) {
141 | if (low < high) {
142 | int tmp = array[low];
143 | int tmpValue = values[low];
144 | int ii = low, jj = high;
145 | while (ii < jj) {
146 | while (ii < jj && array[jj] >= tmp) {
147 | jj--;
148 | }
149 |
150 | array[ii] = array[jj];
151 | values[ii] = values[jj];
152 |
153 | while (ii < jj && array[ii] <= tmp) {
154 | ii++;
155 | }
156 |
157 | array[jj] = array[ii];
158 | values[jj] = values[ii];
159 | }
160 | array[ii] = tmp;
161 | values[ii] = tmpValue;
162 |
163 | quickSort(array, values, low, ii - 1);
164 | quickSort(array, values, ii + 1, high);
165 | }
166 | }
167 |
168 | public static void quickSort(double[] x, double[] y, int from, int to, DoubleComparator comp) {
169 | int len = to - from;
170 | if (len < 7) {
171 | selectionSort(x, y, from, to, comp);
172 | } else {
173 | int m = from + len / 2;
174 | int v;
175 | int a;
176 | int b;
177 | if (len > 7) {
178 | v = from;
179 | a = to - 1;
180 | if (len > 50) {
181 | b = len / 8;
182 | v = med3(x, from, from + b, from + 2 * b, comp);
183 | m = med3(x, m - b, m, m + b, comp);
184 | a = med3(x, a - 2 * b, a - b, a, comp);
185 | }
186 |
187 | m = med3(x, v, m, a, comp);
188 | }
189 |
190 | double seed = x[m];
191 | a = from;
192 | b = from;
193 | int c = to - 1;
194 | int d = c;
195 |
196 | while (true) {
197 | int s;
198 | while (b > c || (s = comp.compare(x[b], seed)) > 0) {
199 | for (; c >= b && (s = comp.compare(x[c], seed)) >= 0; --c) {
200 | if (s == 0) {
201 | swap(x, c, d);
202 | swap(y, c, d);
203 | d--;
204 | }
205 | }
206 |
207 | if (b > c) {
208 | s = Math.min(a - from, b - a);
209 | vecSwap(x, from, b - s, s);
210 | vecSwap(y, from, b - s, s);
211 | s = Math.min(d - c, to - d - 1);
212 | vecSwap(x, b, to - s, s);
213 | vecSwap(y, b, to - s, s);
214 | if ((s = b - a) > 1) {
215 | quickSort(x, y, from, from + s, comp);
216 | }
217 |
218 | if ((s = d - c) > 1) {
219 | quickSort(x, y, to - s, to, comp);
220 | }
221 |
222 | return;
223 | }
224 |
225 | swap(x, b, c);
226 | swap(y, b, c);
227 | b++;
228 | c--;
229 | }
230 |
231 | if (s == 0) {
232 | swap(x, a, b);
233 | swap(y, a, b);
234 | a++;
235 | }
236 |
237 | ++b;
238 | }
239 | }
240 | }
241 |
242 | public static void quickSort(double[] x, double[] y, int from, int to) {
243 | DoubleComparator cmp = new DoubleComparator() {
244 | public int compare(double v, double v1) {
245 | if (Math.abs(v - v1) < 10e-12)
246 | return 0;
247 | else
248 | return v - v1 > 10e-12 ? 1 : -1;
249 | }
250 |
251 | public int compare(Double o1, Double o2) {
252 | if (Math.abs(o1 - o2) < 10e-12)
253 | return 0;
254 | else
255 | return o1 - o2 > 10e-12 ? 1 : -1;
256 | }
257 | };
258 | quickSort(x, y, from, to, cmp);
259 | }
260 |
261 | public static void quickSort(int[] array, float[] values, int low, int high) {
262 | if (low < high) {
263 | int tmp = array[low];
264 | float tmpValue = values[low];
265 | int ii = low, jj = high;
266 | while (ii < jj) {
267 | while (ii < jj && array[jj] >= tmp) {
268 | jj--;
269 | }
270 |
271 | array[ii] = array[jj];
272 | values[ii] = values[jj];
273 |
274 | while (ii < jj && array[ii] <= tmp) {
275 | ii++;
276 | }
277 |
278 | array[jj] = array[ii];
279 | values[jj] = values[ii];
280 | }
281 | array[ii] = tmp;
282 | values[ii] = tmpValue;
283 |
284 | quickSort(array, values, low, ii - 1);
285 | quickSort(array, values, ii + 1, high);
286 | }
287 | }
288 |
289 |
290 | private static int med3(double[] x, int a, int b, int c, DoubleComparator comp) {
291 | int ab = comp.compare(x[a], x[b]);
292 | int ac = comp.compare(x[a], x[c]);
293 | int bc = comp.compare(x[b], x[c]);
294 | return ab < 0 ? (bc < 0 ? b : (ac < 0 ? c : a)) : (bc > 0 ? b : (ac > 0 ? c : a));
295 | }
296 |
297 | private static void vecSwap(double[] x, int a, int b, int n) {
298 | for (int i = 0; i < n; ++b) {
299 | swap(x, a, b);
300 | ++i;
301 | ++a;
302 | }
303 |
304 | }
305 |
306 | private static void swap(int[] x, int a, int b) {
307 | int t = x[a];
308 | x[a] = x[b];
309 | x[b] = t;
310 | }
311 |
312 | private static void swap(double[] x, int a, int b) {
313 | double t = x[a];
314 | x[a] = x[b];
315 | x[b] = t;
316 | }
317 |
318 | public static void selectionSort(int[] a, int[] y, int from, int to, IntComparator comp) {
319 | for (int i = from; i < to - 1; ++i) {
320 | int m = i;
321 |
322 | int u;
323 | for (u = i + 1; u < to; ++u) {
324 | if (comp.compare(a[u], a[m]) < 0) {
325 | m = u;
326 | }
327 | }
328 |
329 | if (m != i) {
330 | u = a[i];
331 | a[i] = a[m];
332 | a[m] = u;
333 | u = y[i];
334 | y[i] = y[m];
335 | y[m] = u;
336 | }
337 | }
338 |
339 | }
340 |
341 | public static void selectionSort(double[] a, double[] y, int from, int to,
342 | DoubleComparator comp) {
343 | for (int i = from; i < to - 1; ++i) {
344 | int m = i;
345 | for (int u = i + 1; u < to; ++u) {
346 | if (comp.compare(a[u], a[m]) < 0) {
347 | m = u;
348 | }
349 | }
350 |
351 | if (m != i) {
352 | double temp = a[i];
353 | a[i] = a[m];
354 | a[m] = temp;
355 | temp = y[i];
356 | y[i] = y[m];
357 | y[m] = temp;
358 | }
359 | }
360 | }
361 |
362 | public static void merge(int[][] as, int[][] ys, int[] a, int[] y) {
363 | int[] ks = new int[as.length];
364 | int cur = 0;
365 | while (cur < a.length) {
366 | int argmin = -1;
367 | int min = Integer.MAX_VALUE;
368 | for (int i = 0; i < ks.length; i++) {
369 | if (ks[i] < as[i].length && as[i][ks[i]] < min) {
370 | argmin = i;
371 | min = as[i][ks[i]];
372 | }
373 | }
374 | a[cur] = as[argmin][ks[argmin]];
375 | y[cur] = ys[argmin][ks[argmin]];
376 | ks[argmin]++;
377 | cur++;
378 | }
379 | }
380 |
381 | public static void merge(int[][] as, double[][] ys, int[] a, double[] y) {
382 | int[] ks = new int[as.length];
383 | int cur = 0;
384 | while (cur < a.length) {
385 | int argmin = -1;
386 | int min = Integer.MAX_VALUE;
387 | for (int i = 0; i < ks.length; i++) {
388 | if (ks[i] < as[i].length && as[i][ks[i]] < min)
389 | argmin = i;
390 | }
391 | a[cur] = as[argmin][ks[argmin]];
392 | y[cur] = ys[argmin][ks[argmin]];
393 | ks[argmin]++;
394 | cur++;
395 | }
396 | }
397 | }
398 |
--------------------------------------------------------------------------------
/sketch/src/main/java/org/dma/sketchml/sketch/util/Utils.java:
--------------------------------------------------------------------------------
1 | package org.dma.sketchml.sketch.util;
2 |
3 | import java.io.*;
4 |
public class Utils {
    /**
     * Returns the number of bytes produced by Java-serializing {@code obj} with a
     * fresh ObjectOutputStream, including the stream header.
     *
     * @throws IOException if {@code obj} cannot be serialized
     */
    public static int sizeof(Object obj) throws IOException {
        ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream();
        // close() flushes, so every written byte is in the buffer after the try block.
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteOutputStream)) {
            objectOutputStream.writeObject(obj);
        }
        return byteOutputStream.size();
    }

    /**
     * Serializes {@code obj} to memory and deserializes it again, returning the
     * copy. This is useful for checking that an object survives Java serialization.
     *
     * @throws IOException            if serialization or deserialization fails
     * @throws ClassNotFoundException if a class in the stream cannot be resolved
     */
    public static Serializable testSerialization(Serializable obj) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream();
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteOutputStream)) {
            objectOutputStream.writeObject(obj);
        }
        try (ObjectInputStream inputStream =
                     new ObjectInputStream(new ByteArrayInputStream(byteOutputStream.toByteArray()))) {
            return (Serializable) inputStream.readObject();
        }
    }
}
29 |
--------------------------------------------------------------------------------
/sketch/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootLogger=INFO, STDOUT
2 | log4j.logger.deng=INFO
3 | log4j.appender.STDOUT=org.apache.log4j.ConsoleAppender
4 | log4j.appender.STDOUT.layout=org.apache.log4j.PatternLayout
5 | log4j.appender.STDOUT.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %C %p - %m%n
--------------------------------------------------------------------------------