├── .gitignore ├── ml ├── pom.xml └── src │ └── main │ ├── resources │ └── log4j.properties │ └── scala │ └── org │ └── dma │ └── sketchml │ └── ml │ ├── SketchML.scala │ ├── algorithm │ ├── GeneralizedLinearModel.scala │ ├── LRModel.scala │ ├── LinearRegModel.scala │ └── SVMModel.scala │ ├── common │ └── Constants.scala │ ├── conf │ └── MLConf.scala │ ├── data │ ├── DataSet.scala │ ├── LabeledData.scala │ └── Parser.scala │ ├── gradient │ ├── DenseDoubleGradient.scala │ ├── DenseFloatGradient.scala │ ├── FixedPointGradient.scala │ ├── Gradient.scala │ ├── Kind.scala │ ├── SketchGradient.scala │ ├── SparseDoubleGradient.scala │ ├── SparseFloatGradient.scala │ └── ZipGradient.scala │ ├── objective │ ├── Adam.scala │ ├── GradientDescent.scala │ └── Loss.scala │ └── util │ ├── Maths.scala │ └── ValidationUtil.scala ├── pom.xml └── sketch ├── pom.xml └── src └── main ├── java └── org │ └── dma │ └── sketchml │ └── sketch │ ├── base │ ├── BinaryEncoder.java │ ├── Int2IntHash.java │ ├── QuantileSketch.java │ ├── Quantizer.java │ ├── SketchMLException.java │ └── VectorCompressor.java │ ├── binary │ ├── BinaryUtils.java │ ├── DeltaAdaptiveEncoder.java │ ├── DeltaBinaryEncoder.java │ └── HuffmanEncoder.java │ ├── common │ └── Constants.java │ ├── hash │ ├── BJHash.java │ ├── BKDRHash.java │ ├── HashFactory.java │ ├── Mix64Hash.java │ └── TWHash.java │ ├── quantization │ ├── QuantileQuantizer.java │ └── UniformQuantizer.java │ ├── sample │ ├── App.java │ ├── DenseVectorCompressor.java │ └── SparseVectorCompressor.java │ ├── sketch │ ├── frequency │ │ ├── FSketchUtils.java │ │ ├── GroupedMinMaxSketch.java │ │ └── MinMaxSketch.java │ └── quantile │ │ ├── HeapQuantileSketch.java │ │ ├── QSketchUtils.java │ │ └── QuantileSketchException.java │ └── util │ ├── Maths.java │ ├── Sort.java │ └── Utils.java └── resources └── log4j.properties /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log 
file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.ear 17 | *.zip 18 | *.tar.gz 19 | *.rar 20 | 21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 22 | hs_err_pid* 23 | 24 | # IntelliJ project files 25 | .idea 26 | *.iml 27 | target/ 28 | out 29 | gen### Java template 30 | 31 | # DS_STORE 32 | *.DS_Store 33 | -------------------------------------------------------------------------------- /ml/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | sketchml 7 | org.dma.sketchml 8 | 1.0.0 9 | ../pom.xml 10 | 11 | 4.0.0 12 | 13 | ml 14 | 15 | 16 | 17 | org.dma.sketchml 18 | sketch 19 | 1.0.0 20 | 21 | 22 | 23 | junit 24 | junit 25 | 4.10 26 | test 27 | 28 | 29 | 30 | 31 | net.jcip 32 | jcip-annotations 33 | 1.0 34 | 35 | 36 | 37 | 38 | com.twitter 39 | chill_2.11 40 | 0.8.0 41 | test 42 | 43 | 44 | 45 | 46 | com.twitter 47 | chill-java 48 | 0.8.0 49 | test 50 | 51 | 52 | 53 | org.scala-lang 54 | scala-library 55 | ${scalaVersion} 56 | provided 57 | 58 | 59 | 60 | org.scala-tools 61 | maven-scala-plugin 62 | 2.11 63 | provided 64 | 65 | 66 | 67 | org.apache.spark 68 | spark-mllib_2.11 69 | 2.2.0 70 | 71 | 72 | 73 | org.apache.spark 74 | spark-mllib-local_2.11 75 | 2.2.0 76 | 77 | 78 | 79 | 80 | 81 | 82 | net.alchim31.maven 83 | scala-maven-plugin 84 | 3.2.2 85 | 86 | 87 | scala-compile-first 88 | process-resources 89 | 90 | compile 91 | 92 | 93 | 94 | scala-test-compile 95 | process-test-resources 96 | 97 | testCompile 98 | 99 | 100 | 101 | 102 | ${scalaVersion} 103 | incremental 104 | 105 | 106 | 107 | org.scala-tools 108 | maven-scala-plugin 109 | 2.15.0 110 | 111 | 112 | 113 | compile 114 | testCompile 115 | 116 | 117 | 118 | -dependencyfile 119 | ${project.build.directory}/.scala_dependencies 120 | 121 | 122 | 123 | 124 | 125 | 126 | org.apache.maven.plugins 127 | 
maven-surefire-plugin 128 | 129 | 130 | org.scalatest 131 | scalatest-maven-plugin 132 | 133 | 134 | org.apache.maven.plugins 135 | maven-compiler-plugin 136 | 3.3 137 | 138 | ${jdkVersion} 139 | ${jdkVersion} 140 | false 141 | 142 | 143 | 144 | org.apache.maven.plugins 145 | maven-assembly-plugin 146 | 2.5.4 147 | 148 | 149 | jar-with-dependencies 150 | 151 | 152 | 153 | 154 | make-assembly 155 | package 156 | 157 | single 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /ml/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=INFO, STDOUT 2 | log4j.logger.deng=INFO 3 | log4j.appender.STDOUT=org.apache.log4j.ConsoleAppender 4 | log4j.appender.STDOUT.layout=org.apache.log4j.PatternLayout 5 | log4j.appender.STDOUT.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %C %p - %m%n -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/SketchML.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml 2 | 3 | import org.apache.spark.{SparkConf, SparkContext} 4 | import org.dma.sketchml.ml.algorithm._ 5 | import org.dma.sketchml.ml.common.Constants 6 | import org.dma.sketchml.ml.conf.MLConf 7 | 8 | object SketchML { 9 | def main(args: Array[String]): Unit = { 10 | val sparkConf = new SparkConf().setAppName("SketchML") 11 | implicit val sc = SparkContext.getOrCreate(sparkConf) 12 | val mlConf = MLConf(sparkConf) 13 | val model = mlConf.algo match { 14 | case Constants.ML_LOGISTIC_REGRESSION => LRModel(mlConf) 15 | case Constants.ML_SUPPORT_VECTOR_MACHINE => SVMModel(mlConf) 16 | case Constants.ML_LINEAR_REGRESSION => LinearRegModel(mlConf) 17 | case _ => throw new UnknownError("Unsupported algorithm: " + mlConf.algo) 18 | } 19 | 20 | model.loadData() 21 | model.train() 22 
| 23 | // TODO: test data 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/algorithm/GeneralizedLinearModel.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.algorithm 2 | 3 | import org.apache.spark.broadcast.Broadcast 4 | import org.apache.spark.ml.linalg.DenseVector 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.storage.StorageLevel 7 | import org.apache.spark.{SparkContext, SparkEnv} 8 | import org.dma.sketchml.ml.data.{DataSet, Parser} 9 | import org.dma.sketchml.ml.conf.MLConf 10 | import org.dma.sketchml.ml.gradient.Gradient 11 | import org.dma.sketchml.ml.objective.{GradientDescent, Loss} 12 | import org.dma.sketchml.ml.util.ValidationUtil 13 | import org.slf4j.{Logger, LoggerFactory} 14 | 15 | import scala.collection.mutable.ArrayBuffer 16 | import scala.util.Random 17 | 18 | object GeneralizedLinearModel { 19 | private val logger: Logger = LoggerFactory.getLogger(GeneralizedLinearModel.getClass) 20 | 21 | object Model { 22 | var weights: DenseVector = _ 23 | var optimizer: GradientDescent = _ 24 | var loss: Loss = _ 25 | var gradient: Gradient = _ 26 | } 27 | 28 | object Data { 29 | var trainData: DataSet = _ 30 | var validData: DataSet = _ 31 | } 32 | 33 | } 34 | 35 | import GeneralizedLinearModel.Model._ 36 | import GeneralizedLinearModel.Data._ 37 | 38 | abstract class GeneralizedLinearModel(@transient protected val conf: MLConf) extends Serializable { 39 | @transient protected val logger: Logger = GeneralizedLinearModel.logger 40 | 41 | @transient protected implicit val sc: SparkContext = SparkContext.getOrCreate() 42 | @transient protected var executors: RDD[Int] = _ 43 | protected val bcConf: Broadcast[MLConf] = sc.broadcast(conf) 44 | 45 | def loadData(): Unit = { 46 | val startTime = System.currentTimeMillis() 47 | val dataRdd = Parser.loadData(conf.input, conf.format, 
conf.featureNum, conf.workerNum) 48 | .persist(StorageLevel.MEMORY_AND_DISK) 49 | executors = dataRdd.mapPartitionsWithIndex((partId, _) => { 50 | val exeId = SparkEnv.get.executorId match { 51 | case "driver" => partId 52 | case exeStr => exeStr.toInt 53 | } 54 | Seq(exeId).iterator 55 | }, preservesPartitioning = true) 56 | val (trainDataNum, validDataNum) = dataRdd.mapPartitions(iterator => { 57 | trainData = new DataSet 58 | validData = new DataSet 59 | while (iterator.hasNext) { 60 | if (Random.nextDouble() > bcConf.value.validRatio) 61 | trainData += iterator.next() 62 | else 63 | validData += iterator.next() 64 | } 65 | Seq((trainData.size, validData.size)).iterator 66 | }, preservesPartitioning = true) 67 | .reduce((c1, c2) => (c1._1 + c2._1, c1._2 + c2._2)) 68 | //val rdds = dataRdd.randomSplit(Array(1.0 - validRatio, validRatio)) 69 | //val trainRdd = rdds(0).persist(StorageLevel.MEMORY_AND_DISK) 70 | //val validRdd = rdds(1).persist(StorageLevel.MEMORY_AND_DISK) 71 | //val trainDataNum = trainRdd.count().toInt 72 | //val validDataNum = validRdd.count().toInt 73 | dataRdd.unpersist() 74 | logger.info(s"Load data cost ${System.currentTimeMillis() - startTime} ms, " + 75 | s"$trainDataNum train data, $validDataNum valid data") 76 | } 77 | 78 | protected def initModel(): Unit 79 | 80 | def train(): Unit = { 81 | logger.info(s"Start to train a $getName model") 82 | logger.info(s"Configuration: $conf") 83 | val startTime = System.currentTimeMillis() 84 | initModel() 85 | 86 | val trainLosses = new ArrayBuffer[Double](conf.epochNum) // capacity hint only; buffer starts empty (companion apply would insert epochNum as an element) 87 | val validLosses = new ArrayBuffer[Double](conf.epochNum) // one entry is appended per epoch 88 | val timeElapsed = new ArrayBuffer[Long](conf.epochNum) // cumulative ms since training start, per epoch 89 | val batchNum = Math.ceil(1.0 / conf.batchSpRatio).toInt 90 | for (epoch <- 0 until conf.epochNum) { 91 | logger.info(s"Epoch[$epoch] start training") 92 | trainLosses += trainOneEpoch(epoch, batchNum) 93 | validLosses += validate(epoch) 94 | timeElapsed += System.currentTimeMillis() - startTime 95 |
logger.info(s"Epoch[$epoch] done, ${timeElapsed.last} ms elapsed") 96 | } 97 | 98 | logger.info(s"Train done, total cost ${System.currentTimeMillis() - startTime} ms") 99 | logger.info(s"Train loss: [${trainLosses.mkString(", ")}]") 100 | logger.info(s"Valid loss: [${validLosses.mkString(", ")}]") 101 | logger.info(s"Time: [${timeElapsed.mkString(", ")}]") 102 | } 103 | 104 | protected def trainOneEpoch(epoch: Int, batchNum: Int): Double = { 105 | val epochStart = System.currentTimeMillis() 106 | var trainLoss = 0.0 107 | for (batch <- 0 until batchNum) { 108 | val batchLoss = trainOneIteration(epoch, batch) 109 | trainLoss += batchLoss 110 | } 111 | val epochCost = System.currentTimeMillis() - epochStart 112 | logger.info(s"Epoch[$epoch] train cost $epochCost ms, loss=${trainLoss / batchNum}") 113 | trainLoss / batchNum 114 | } 115 | 116 | protected def trainOneIteration(epoch: Int, batch: Int): Double = { 117 | val batchStart = System.currentTimeMillis() 118 | val batchLoss = computeGradient(epoch, batch) 119 | aggregateAndUpdate(epoch, batch) 120 | logger.info(s"Epoch[$epoch] batch $batch train cost " 121 | + s"${System.currentTimeMillis() - batchStart} ms") 122 | batchLoss 123 | } 124 | 125 | protected def computeGradient(epoch: Int, batch: Int): Double = { 126 | val miniBatchGDStart = System.currentTimeMillis() 127 | val (batchSize, objLoss, regLoss) = executors.aggregate(0, 0.0, 0.0)( 128 | seqOp = (_, _) => { 129 | val (grad, batchSize, objLoss ,regLoss) = 130 | optimizer.miniBatchGradientDescent(weights, trainData, loss) 131 | gradient = grad 132 | (batchSize, objLoss, regLoss) 133 | }, 134 | combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2, c1._3 + c2._3) 135 | ) 136 | val batchLoss = objLoss / batchSize + regLoss / conf.workerNum 137 | logger.info(s"Epoch[$epoch] batch $batch compute gradient cost " 138 | + s"${System.currentTimeMillis() - miniBatchGDStart} ms, " 139 | + s"batch size=$batchSize, batch loss=$batchLoss") 140 | batchLoss 141 | } 142 | 143 
| protected def aggregateAndUpdate(epoch: Int, batch: Int): Unit = { 144 | val aggrStart = System.currentTimeMillis() 145 | val sum = Gradient.sum( 146 | conf.featureNum, 147 | executors.map(_ => Gradient.compress(gradient, bcConf.value)).collect() 148 | ) 149 | val grad = Gradient.compress(sum, conf) 150 | grad.timesBy(1.0 / conf.workerNum) 151 | logger.info(s"Epoch[$epoch] batch $batch aggregate gradients cost " 152 | + s"${System.currentTimeMillis() - aggrStart} ms") 153 | 154 | val updateStart = System.currentTimeMillis() 155 | val bcGrad = sc.broadcast(grad) 156 | executors.foreach(_ => optimizer.update(bcGrad.value, weights)) 157 | logger.info(s"Epoch[$epoch] batch $batch update weights cost " 158 | + s"${System.currentTimeMillis() - updateStart} ms") 159 | } 160 | 161 | protected def validate(epoch: Int): Double = { 162 | val validStart = System.currentTimeMillis() 163 | val (sumLoss, truePos, trueNeg, falsePos, falseNeg, validNum) = 164 | executors.aggregate((0.0, 0, 0, 0, 0, 0))( 165 | seqOp = (_, _) => ValidationUtil.calLossPrecision(weights, validData, loss), 166 | combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2, c1._3 + c2._3, 167 | c1._4 + c2._4, c1._5 + c2._5, c1._6 + c2._6) 168 | ) 169 | val validLoss = sumLoss / validNum 170 | val precision = 1.0 * (truePos + trueNeg) / validNum 171 | val trueRecall = 1.0 * truePos / (truePos + falseNeg) 172 | val falseRecall = 1.0 * trueNeg / (trueNeg + falsePos) 173 | logger.info(s"Epoch[$epoch] validation cost ${System.currentTimeMillis() - validStart} ms, " 174 | + s"valid size=$validNum, loss=$validLoss, precision=$precision, " 175 | + s"trueRecall=$trueRecall, falseRecall=$falseRecall") 176 | validLoss 177 | } 178 | 179 | def getName: String 180 | 181 | } 182 | 183 | 184 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/algorithm/LRModel.scala: -------------------------------------------------------------------------------- 1 | package 
org.dma.sketchml.ml.algorithm 2 | 3 | import org.apache.spark.ml.linalg.DenseVector 4 | import org.dma.sketchml.ml.algorithm.GeneralizedLinearModel.Model._ 5 | import org.dma.sketchml.ml.common.Constants 6 | import org.dma.sketchml.ml.conf.MLConf 7 | import org.dma.sketchml.ml.objective.{Adam, L2LogLoss} 8 | import org.slf4j.{Logger, LoggerFactory} 9 | 10 | object LRModel { 11 | private val logger: Logger = LoggerFactory.getLogger(LRModel.getClass) 12 | 13 | def apply(conf: MLConf): LRModel = new LRModel(conf) 14 | 15 | def getName: String = Constants.ML_LOGISTIC_REGRESSION 16 | } 17 | 18 | class LRModel(_conf: MLConf) extends GeneralizedLinearModel(_conf) { 19 | @transient override protected val logger: Logger = LRModel.logger 20 | 21 | override protected def initModel(): Unit = { 22 | executors.foreach(_ => { 23 | weights = new DenseVector(new Array[Double](bcConf.value.featureNum)) 24 | optimizer = Adam(bcConf.value) 25 | loss = new L2LogLoss(bcConf.value.l2Reg) 26 | }) 27 | } 28 | 29 | override def getName: String = LRModel.getName 30 | } 31 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/algorithm/LinearRegModel.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.algorithm 2 | 3 | import org.apache.spark.ml.linalg.DenseVector 4 | import org.dma.sketchml.ml.algorithm.GeneralizedLinearModel.Model._ 5 | import org.dma.sketchml.ml.common.Constants 6 | import org.dma.sketchml.ml.conf.MLConf 7 | import org.dma.sketchml.ml.objective.{Adam, L2SquareLoss} 8 | import org.slf4j.{Logger, LoggerFactory} 9 | 10 | object LinearRegModel { 11 | private val logger: Logger = LoggerFactory.getLogger(LinearRegModel.getClass) 12 | 13 | def apply(conf: MLConf): LinearRegModel = new LinearRegModel(conf) 14 | 15 | def getName: String = Constants.ML_LINEAR_REGRESSION 16 | } 17 | 18 | class LinearRegModel(_conf: MLConf) extends 
GeneralizedLinearModel(_conf) { 19 | @transient override protected val logger: Logger = LinearRegModel.logger 20 | 21 | override protected def initModel(): Unit = { 22 | executors.foreach(_ => { 23 | weights = new DenseVector(new Array[Double](bcConf.value.featureNum)) 24 | optimizer = Adam(bcConf.value) 25 | loss = new L2SquareLoss(bcConf.value.l2Reg) 26 | }) 27 | } 28 | 29 | override def getName: String = LinearRegModel.getName 30 | 31 | } 32 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/algorithm/SVMModel.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.algorithm 2 | 3 | import org.apache.spark.ml.linalg.DenseVector 4 | import org.dma.sketchml.ml.algorithm.GeneralizedLinearModel.Model._ 5 | import org.dma.sketchml.ml.common.Constants 6 | import org.dma.sketchml.ml.conf.MLConf 7 | import org.dma.sketchml.ml.objective.{Adam, L2HingeLoss} 8 | import org.slf4j.{Logger, LoggerFactory} 9 | 10 | object SVMModel { 11 | private val logger: Logger = LoggerFactory.getLogger(SVMModel.getClass) 12 | 13 | def apply(conf: MLConf): SVMModel = new SVMModel(conf) 14 | 15 | def getName: String = Constants.ML_SUPPORT_VECTOR_MACHINE 16 | } 17 | 18 | class SVMModel(_conf: MLConf) extends GeneralizedLinearModel(_conf) { 19 | @transient override protected val logger: Logger = SVMModel.logger 20 | 21 | override protected def initModel(): Unit = { 22 | executors.foreach(_ => { 23 | weights = new DenseVector(new Array[Double](bcConf.value.featureNum)) 24 | optimizer = Adam(bcConf.value) 25 | loss = new L2HingeLoss(bcConf.value.l2Reg) 26 | }) 27 | } 28 | 29 | override def getName: String = SVMModel.getName 30 | 31 | } 32 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/common/Constants.scala: -------------------------------------------------------------------------------- 1 
| package org.dma.sketchml.ml.common 2 | 3 | object Constants { 4 | val ML_LOGISTIC_REGRESSION: String = "LogisticRegression" 5 | val ML_SUPPORT_VECTOR_MACHINE: String = "SupportVectorMachine" 6 | val ML_LINEAR_REGRESSION: String = "LinearRegression" 7 | val FORMAT_LIBSVM: String = "libsvm" 8 | val FORMAT_CSV: String = "csv" 9 | val FORMAT_DUMMY: String = "dummy" 10 | val GRADIENT_COMPRESSOR_NONE: String = "None" 11 | val GRADIENT_COMPRESSOR_FLOAT: String = "Float" 12 | val GRADIENT_COMPRESSOR_SKETCH: String = "Sketch" 13 | val GRADIENT_COMPRESSOR_FIXED_POINT: String = "FixedPoint" 14 | val GRADIENT_COMPRESSOR_ZIP: String = "Zip" 15 | 16 | } 17 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/conf/MLConf.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.conf 2 | 3 | import org.apache.spark.SparkConf 4 | import org.dma.sketchml.ml.common.Constants._ 5 | import org.dma.sketchml.sketch.base.{Quantizer, SketchMLException} 6 | import org.dma.sketchml.sketch.sketch.frequency.{GroupedMinMaxSketch, MinMaxSketch} 7 | 8 | object MLConf { 9 | // ML Conf 10 | val ML_ALGORITHM: String = "spark.sketchml.algo" 11 | val ML_INPUT_PATH: String = "spark.sketchml.input.path" 12 | val ML_INPUT_FORMAT: String = "spark.sketchml.input.format" 13 | //val ML_TEST_DATA_PATH: String = "spark.sketchml.test.path" 14 | //val ML_NUM_CLASS: String = "spark.sketchml.class.num" 15 | //val DEFAULT_ML_NUM_CLASS: Int = 2 16 | val ML_NUM_WORKER: String = "spark.sketchml.worker.num" 17 | val ML_NUM_FEATURE: String = "spark.sketchml.feature.num" 18 | val ML_VALID_RATIO: String = "spark.sketchml.valid.ratio" 19 | val DEFAULT_ML_VALID_RATIO: Double = 0.25 20 | val ML_EPOCH_NUM: String = "spark.sketchml.epoch.num" 21 | val DEFAULT_ML_EPOCH_NUM: Int = 100 22 | val ML_BATCH_SAMPLE_RATIO: String = "spark.sketchml.batch.sample.ratio" 23 | val 
DEFAULT_ML_BATCH_SAMPLE_RATIO: Double = 0.1 24 | val ML_LEARN_RATE: String = "spark.sketchml.learn.rate" 25 | val DEFAULT_ML_LEARN_RATE: Double = 0.1 26 | val ML_LEARN_DECAY: String = "spark.sketchml.learn.decay" 27 | val DEFAULT_ML_LEARN_DECAY: Double = 0.9 28 | val ML_REG_L1: String = "spark.sketchml.reg.l1" 29 | val DEFAULT_ML_REG_L1: Double = 0.1 30 | val ML_REG_L2: String = "spark.sketchml.reg.l2" 31 | val DEFAULT_ML_REG_L2: Double = 0.1 32 | // Sketch Conf 33 | val SKETCH_GRADIENT_COMPRESSOR: String = "spark.sketchml.gradient.compressor" 34 | val DEFAULT_SKETCH_GRADIENT_COMPRESSOR: String = GRADIENT_COMPRESSOR_SKETCH 35 | val SKETCH_QUANTIZATION_BIN_NUM: String = "spark.sketchml.quantization.bin.num" 36 | val DEFAULT_SKETCH_QUANTIZATION_BIN_NUM: Int = Quantizer.DEFAULT_BIN_NUM 37 | val SKETCH_MINMAXSKETCH_GROUP_NUM: String = "spark.sketchml.minmaxsketch.group.num" 38 | val DEFAULT_SKETCH_MINMAXSKETCH_GROUP_NUM: Int = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_GROUP_NUM 39 | val SKETCH_MINMAXSKETCH_ROW_NUM: String = "spark.sketchml.minmaxsketch.row.num" 40 | val DEFAULT_SKETCH_MINMAXSKETCH_ROW_NUM: Int = MinMaxSketch.DEFAULT_MINMAXSKETCH_ROW_NUM 41 | val SKETCH_MINMAXSKETCH_COL_RATIO: String = "spark.sketchml.minmaxsketch.col.ratio" 42 | val DEFAULT_SKETCH_MINMAXSKETCH_COL_RATIO: Double = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_COL_RATIO 43 | // FixedPoint Conf 44 | val FIXED_POINT_BIT_NUM: String = "spark.sketchml.fixed.point.bit.num" 45 | val DEFAULT_FIXED_POINT_BIT_NUM = 8 46 | 47 | def apply(sparkConf: SparkConf): MLConf = MLConf( 48 | sparkConf.get(ML_ALGORITHM), 49 | sparkConf.get(ML_INPUT_PATH), 50 | sparkConf.get(ML_INPUT_FORMAT), 51 | sparkConf.get(ML_NUM_WORKER).toInt, 52 | sparkConf.get(ML_NUM_FEATURE).toInt, 53 | sparkConf.getDouble(ML_VALID_RATIO, DEFAULT_ML_VALID_RATIO), 54 | sparkConf.getInt(ML_EPOCH_NUM, DEFAULT_ML_EPOCH_NUM), 55 | sparkConf.getDouble(ML_BATCH_SAMPLE_RATIO, DEFAULT_ML_BATCH_SAMPLE_RATIO), 56 | sparkConf.getDouble(ML_LEARN_RATE, 
DEFAULT_ML_LEARN_RATE), 57 | sparkConf.getDouble(ML_LEARN_DECAY, DEFAULT_ML_LEARN_DECAY), 58 | sparkConf.getDouble(ML_REG_L1, DEFAULT_ML_REG_L1), 59 | sparkConf.getDouble(ML_REG_L2, DEFAULT_ML_REG_L2), 60 | sparkConf.get(SKETCH_GRADIENT_COMPRESSOR, DEFAULT_SKETCH_GRADIENT_COMPRESSOR), 61 | sparkConf.getInt(SKETCH_QUANTIZATION_BIN_NUM, DEFAULT_SKETCH_QUANTIZATION_BIN_NUM), 62 | sparkConf.getInt(SKETCH_MINMAXSKETCH_GROUP_NUM, DEFAULT_SKETCH_MINMAXSKETCH_GROUP_NUM), 63 | sparkConf.getInt(SKETCH_MINMAXSKETCH_ROW_NUM, DEFAULT_SKETCH_MINMAXSKETCH_ROW_NUM), 64 | sparkConf.getDouble(SKETCH_MINMAXSKETCH_COL_RATIO, DEFAULT_SKETCH_MINMAXSKETCH_COL_RATIO), 65 | sparkConf.getInt(FIXED_POINT_BIT_NUM, DEFAULT_FIXED_POINT_BIT_NUM) 66 | ) 67 | 68 | } 69 | 70 | case class MLConf(algo: String, input: String, format: String, workerNum: Int, 71 | featureNum: Int, validRatio: Double, epochNum: Int,batchSpRatio: Double, 72 | learnRate: Double, learnDecay: Double, l1Reg: Double, l2Reg: Double, 73 | compressor: String, quantBinNum: Int, sketchGroupNum: Int, 74 | sketchRowNum: Int, sketchColRatio: Double, fixedPointBitNum: Int) { 75 | require(Seq(ML_LOGISTIC_REGRESSION, ML_SUPPORT_VECTOR_MACHINE, ML_LINEAR_REGRESSION).contains(algo), 76 | throw new SketchMLException(s"Unsupported algorithm: $algo")) 77 | require(Seq(FORMAT_LIBSVM, FORMAT_CSV, FORMAT_DUMMY).contains(format), 78 | throw new SketchMLException(s"Unrecognizable file format: $format")) 79 | require(Seq(GRADIENT_COMPRESSOR_SKETCH, GRADIENT_COMPRESSOR_FIXED_POINT, GRADIENT_COMPRESSOR_ZIP, 80 | GRADIENT_COMPRESSOR_FLOAT, GRADIENT_COMPRESSOR_NONE).contains(compressor), 81 | throw new SketchMLException(s"Unrecognizable gradient compressor: $compressor")) 82 | 83 | } 84 | 85 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/data/DataSet.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.data 2 | 3 
| import scala.collection.mutable.ArrayBuffer 4 | 5 | class DataSet { 6 | private val data = ArrayBuffer[LabeledData]() 7 | private var readIndex = 0 8 | 9 | def size: Int = data.size 10 | 11 | def add(ins: LabeledData): Unit = data += ins 12 | 13 | def get(i: Int): LabeledData = data(i) 14 | 15 | def loopingRead: LabeledData = { 16 | if (readIndex >= size) 17 | readIndex = 0 18 | val ins = data(readIndex) 19 | readIndex += 1 20 | ins 21 | } 22 | 23 | def +=(ins: LabeledData): Unit = add(ins) 24 | 25 | } 26 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/data/LabeledData.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.data 2 | 3 | import org.apache.spark.ml.linalg.Vector 4 | 5 | case class LabeledData(label: Double, feature: Vector) { 6 | 7 | override def toString: String = s"($label $feature)" 8 | } 9 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/data/Parser.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.data 2 | 3 | import org.apache.spark.SparkContext 4 | import org.apache.spark.ml.linalg.Vectors 5 | import org.apache.spark.rdd.RDD 6 | import org.dma.sketchml.ml.common.Constants 7 | import org.dma.sketchml.ml.util.Maths 8 | 9 | object Parser { 10 | def loadData(input: String, format: String, maxDim: Int, numPartition: Int, 11 | negY: Boolean = true)(implicit sc: SparkContext): RDD[LabeledData] = { 12 | val parse: (String, Int, Boolean) => LabeledData = format match { 13 | case Constants.FORMAT_LIBSVM => Parser.parseLibSVM 14 | case Constants.FORMAT_CSV => Parser.parseCSV 15 | case Constants.FORMAT_DUMMY => Parser.parseDummy 16 | case _ => throw new UnknownError("Unknown file format: " + format) 17 | } 18 | sc.textFile(input) 19 | .map(line => parse(line, maxDim, negY)) 20 
| .repartition(numPartition) 21 | } 22 | 23 | def parseLibSVM(line: String, maxDim: Int, negY: Boolean = true): LabeledData = { 24 | val splits = line.trim.split(" ") 25 | if (splits.length < 1) 26 | return null 27 | 28 | var y = splits(0).toDouble 29 | if (negY && Math.abs(y - 1) > Maths.EPS) 30 | y = -1 31 | 32 | val nnz = splits.length - 1 33 | val indices = new Array[Int](nnz) 34 | val values = new Array[Double](nnz) 35 | for (i <- 0 until nnz) { 36 | val kv = splits(i + 1).trim.split(":") 37 | indices(i) = kv(0).toInt 38 | values(i) = kv(1).toDouble 39 | } 40 | val x = Vectors.sparse(maxDim, indices, values) 41 | 42 | LabeledData(y, x) 43 | } 44 | 45 | def parseCSV(line: String, maxDim: Int, negY: Boolean = true): LabeledData = { 46 | val splits = line.trim.split(",") 47 | if (splits.length < 1) 48 | return null 49 | 50 | var y = splits(0).toDouble 51 | if (negY && Math.abs(y - 1) > Maths.EPS) 52 | y = -1 53 | 54 | val nnz = splits.length - 1 55 | val values = splits.slice(1, nnz + 1).map(_.trim.toDouble) 56 | val x = Vectors.dense(values) 57 | 58 | LabeledData(y, x) 59 | } 60 | 61 | def parseDummy(line: String, maxDim: Int, negY: Boolean = true): LabeledData = { 62 | val splits = line.trim.split(",") 63 | if (splits.length < 1) 64 | return null 65 | 66 | var y = splits(0).toDouble 67 | if (negY && Math.abs(y - 1) > Maths.EPS) 68 | y = -1 69 | 70 | val nnz = splits.length - 1 71 | val indices = splits.slice(1, nnz + 1).map(_.trim.toInt) 72 | val values = Array.fill(nnz)(1.0) 73 | val x = Vectors.sparse(maxDim, indices, values) 74 | 75 | LabeledData(y, x) 76 | } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/gradient/DenseDoubleGradient.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.gradient 2 | 3 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector} 4 | import 
org.dma.sketchml.ml.gradient.Kind.Kind 5 | import org.dma.sketchml.ml.util.Maths 6 | 7 | class DenseDoubleGradient(d: Int, val values: Array[Double]) extends Gradient(d) { 8 | def this(d: Int) = this(d, new Array[Double](d)) 9 | 10 | override def plusBy(dense: DenseDoubleGradient): Gradient = { 11 | for (i <- 0 until dim) 12 | values(i) += dense.values(i) 13 | this 14 | } 15 | 16 | override def plusBy(sparse: SparseDoubleGradient): Gradient = { 17 | val k = sparse.indices 18 | val v = sparse.values 19 | for (i <- k.indices) 20 | values(k(i)) += v(i) 21 | this 22 | } 23 | 24 | override def plusBy(dense: DenseFloatGradient): Gradient = { 25 | for (i <- 0 until dim) 26 | values(i) += dense.values(i) 27 | this 28 | } 29 | 30 | override def plusBy(sparse: SparseFloatGradient): Gradient = { 31 | val k = sparse.indices 32 | val v = sparse.values 33 | for (i <- k.indices) 34 | values(k(i)) += v(i) 35 | this 36 | } 37 | 38 | override def plusBy(sketchGrad: SketchGradient): Gradient = plusBy(sketchGrad.toAuto) 39 | 40 | override def plusBy(fpGrad: FixedPointGradient): Gradient = plusBy(fpGrad.toAuto) 41 | 42 | override def plusBy(zipGrad: ZipGradient): Gradient = plusBy(zipGrad.toAuto) 43 | 44 | override def plusBy(dense: DenseVector, x: Double): Gradient = { 45 | val v = dense.values 46 | for (i <- 0 until dim) 47 | values(i) += v(i) * x 48 | this 49 | } 50 | 51 | override def plusBy(sparse: SparseVector, x: Double): Gradient = { 52 | val k = sparse.indices 53 | val v = sparse.values 54 | for (i <- k.indices) 55 | values(k(i)) += v(i) * x 56 | this 57 | } 58 | 59 | override def timesBy(x: Double): Unit = { 60 | for (i <- 0 until dim) 61 | values(i) *= x 62 | } 63 | 64 | override def countNNZ: Int = { 65 | var nnz = 0 66 | for (i <- 0 until dim) 67 | if (Math.abs(values(i)) > Maths.EPS) 68 | nnz += 1 69 | nnz 70 | } 71 | 72 | override def toDense: DenseDoubleGradient = this 73 | 74 | override def toSparse: SparseDoubleGradient = toSparse(countNNZ) 75 | 76 | private def 
toSparse(nnz: Int): SparseDoubleGradient = { 77 | val k = new Array[Int](nnz) 78 | val v = new Array[Double](nnz) 79 | var i = 0 80 | var j = 0 81 | while (i < dim && j < nnz) { 82 | if (Math.abs(values(i)) > Maths.EPS) { 83 | k(j) = i 84 | v(j) = values(i) 85 | j += 1 86 | } 87 | i += 1 88 | } 89 | new SparseDoubleGradient(dim, k, v) 90 | } 91 | 92 | override def toAuto: Gradient = { 93 | val nnz = countNNZ 94 | if (nnz > dim * 2 / 3) toDense else toSparse(nnz) 95 | } 96 | 97 | override def kind: Kind = Kind.DenseDouble 98 | 99 | } 100 | -------------------------------------------------------------------------------- /ml/src/main/scala/org/dma/sketchml/ml/gradient/DenseFloatGradient.scala: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.ml.gradient 2 | 3 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector} 4 | import org.dma.sketchml.ml.gradient.Kind.Kind 5 | import org.dma.sketchml.ml.util.Maths 6 | import org.dma.sketchml.sketch.base.SketchMLException 7 | 8 | class DenseFloatGradient(d: Int, val values: Array[Float]) extends Gradient(d) { 9 | def this(d: Int) = this(d, new Array[Float](d)) 10 | 11 | def this(grad: Gradient) { 12 | this(grad.dim, new Array[Float](grad.dim)) 13 | grad.kind match { 14 | case Kind.DenseDouble => fromDense(grad.asInstanceOf[DenseDoubleGradient]) 15 | case Kind.SparseDouble => fromSparse(grad.asInstanceOf[SparseDoubleGradient]) 16 | case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}") 17 | } 18 | } 19 | 20 | def fromDense(dense: DenseDoubleGradient): Unit = { 21 | val dv = dense.values 22 | for (i <- 0 until dim) 23 | values(i) = dv(i).toFloat 24 | } 25 | 26 | def fromSparse(sparse: SparseDoubleGradient): Unit = { 27 | val k = sparse.indices 28 | val v = sparse.values 29 | for (i <- k.indices) 30 | values(k(i)) = v(i).toFloat 31 | } 32 | 33 | override def plusBy(dense: DenseDoubleGradient): Gradient = { 34 | for (i <- 0 until 
dim) 35 | values(i) += dense.values(i).toFloat 36 | this 37 | } 38 | 39 | override def plusBy(sparse: SparseDoubleGradient): Gradient = { 40 | val k = sparse.indices 41 | val v = sparse.values 42 | for (i <- k.indices) 43 | values(k(i)) += v(i).toFloat 44 | this 45 | } 46 | 47 | override def plusBy(dense: DenseFloatGradient): Gradient = { 48 | for (i <- 0 until dim) 49 | values(i) += dense.values(i) 50 | this 51 | } 52 | 53 | override def plusBy(sparse: SparseFloatGradient): Gradient = { 54 | val k = sparse.indices 55 | val v = sparse.values 56 | for (i <- k.indices) 57 | values(k(i)) += v(i) 58 | this 59 | } 60 | 61 | override def plusBy(sketchGrad: SketchGradient): Gradient = plusBy(sketchGrad.toAuto) 62 | 63 | override def plusBy(fpGrad: FixedPointGradient): Gradient = plusBy(fpGrad.toAuto) 64 | 65 | override def plusBy(zipGrad: ZipGradient): Gradient = plusBy(zipGrad.toAuto) 66 | 67 | override def plusBy(dense: DenseVector, x: Double): Gradient = { 68 | val v = dense.values 69 | val x_ = x.toFloat 70 | for (i <- 0 until dim) 71 | values(i) += v(i).toFloat * x_ 72 | this 73 | } 74 | 75 | override def plusBy(sparse: SparseVector, x: Double): Gradient = { 76 | val k = sparse.indices 77 | val v = sparse.values 78 | val x_ = x.toFloat 79 | for (i <- k.indices) 80 | values(k(i)) += v(i).toFloat * x_ 81 | this 82 | } 83 | 84 | override def timesBy(x: Double): Unit = { 85 | val x_ = x.toFloat 86 | for (i <- 0 until dim) 87 | values(i) *= x_ 88 | } 89 | 90 | override def countNNZ: Int = { 91 | var nnz = 0 92 | for (i <- 0 until dim) 93 | if (Math.abs(values(i)) > Maths.EPS) 94 | nnz += 1 95 | nnz 96 | } 97 | 98 | override def toDense: DenseDoubleGradient = 99 | new DenseDoubleGradient(dim, values.map(_.toDouble)) 100 | 101 | override def toSparse: SparseDoubleGradient = toSparse(countNNZ) 102 | 103 | private def toSparse(nnz: Int): SparseDoubleGradient = { 104 | val k = new Array[Int](nnz) 105 | val v = new Array[Double](nnz) 106 | var i = 0 107 | var j = 0 108 | while 
/**
  * Companion holding the random source shared by every [[FixedPointGradient]].
  */
object FixedPointGradient {
  // Fair coin: every encoded magnitude is bumped by one step with probability 0.5.
  private val bernoulli = new Bernoulli(0.5)
}

/**
  * Gradient compressed into `numBits`-bit codes packed in a `java.util.BitSet`.
  *
  * Each value is divided by the L2 norm of the encoded array and scaled to
  * `[0, 2^(numBits - 1) - 1]`; the code stores that magnitude plus the sign flag
  * `1 << (numBits - 1)`. Decoding multiplies the magnitude back by `norm`.
  *
  * @param d       dimension of the gradient
  * @param numBits width of one code, sign flag included; must lie in `[2, 30)`
  */
class FixedPointGradient(d: Int, val numBits: Int) extends Gradient(d) {
  import FixedPointGradient._

  // With a single bit the magnitude scale `(1 << (numBits - 1)) - 1` is 0 and decoding
  // would divide by zero, so at least one magnitude bit is required besides the sign flag.
  require(numBits >= 2 && numBits < 30, s"Bit num out of range: $numBits")

  /**
    * Compresses a double-precision gradient.
    *
    * @param grad    a [[DenseDoubleGradient]] or [[SparseDoubleGradient]]
    * @param numBits width of one code, sign flag included
    * @throws SketchMLException if `grad` is of any other kind
    */
  def this(grad: Gradient, numBits: Int) = {
    this(grad.dim, numBits)
    grad match {
      case dense: DenseDoubleGradient => fromDense(dense)
      case sparse: SparseDoubleGradient => fromSparse(sparse)
      case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}")
    }
  }

  /** Number of encoded values. */
  var size: Int = _
  /** L2 norm of the encoded values; acts as the decoding scale. */
  var norm: Double = _
  /** Positions of the encoded values, or `null` when the source gradient was dense. */
  var indices: Array[Int] = _
  /** Packed codes, `numBits` bits per value. */
  var bitset: util.BitSet = _

  /** Encodes all values of a dense gradient; `indices` stays `null`. */
  def fromDense(dense: DenseDoubleGradient): Unit = fromArray(dense.values)

  /** Encodes the stored values of a sparse gradient and keeps a reference to its indices. */
  def fromSparse(sparse: SparseDoubleGradient): Unit = {
    indices = sparse.indices
    fromArray(sparse.values)
  }

  /** Computes the norm of `values` and packs one code per value into `bitset`. */
  private def fromArray(values: Array[Double]): Unit = {
    size = values.length
    norm = Math.sqrt(values.foldLeft(0.0)((acc, v) => acc + v * v))
    bitset = new util.BitSet(numBits * size)
    val max = (1 << (numBits - 1)) - 1
    val sign = 1 << (numBits - 1)
    for (i <- values.indices) {
      val sigma = if (bernoulli.draw()) 1 else 0
      val rounded = Math.floor(Math.abs(values(i)) / norm * max).toInt + sigma
      // A value whose magnitude equals the norm already maps to `max`; without the clamp
      // the random bump would carry into the sign flag and the value would decode as -0.0.
      var x = Math.min(rounded, max)
      if (values(i) < 0) x |= sign
      BinaryUtils.setBits(bitset, i * numBits, x, numBits)
    }
  }

  /** Scaling the norm scales every decoded value, so the codes stay untouched. */
  override def timesBy(x: Double): Unit = norm *= x

  override def countNNZ: Int = size

  /**
    * Decodes into a dense gradient. A sparse-origin gradient holds only `size` values,
    * so it is decoded as sparse first and then expanded to `dim` entries.
    */
  override def toDense: DenseDoubleGradient =
    if (indices == null) new DenseDoubleGradient(dim, toArray) else toSparse.toDense

  /**
    * Decodes into a sparse gradient. A dense-origin gradient has no index array,
    * so it is decoded as dense first and then converted.
    */
  override def toSparse: SparseDoubleGradient =
    if (indices != null) new SparseDoubleGradient(dim, indices, toArray) else toDense.toSparse

  /** Unpacks every code back into a double: magnitude / max * norm, negated if flagged. */
  private def toArray: Array[Double] = {
    val decoded = new Array[Double](size)
    val max = (1 << (numBits - 1)) - 1
    val sign = 1 << (numBits - 1)
    for (i <- 0 until size) {
      val code = BinaryUtils.getBits(bitset, i * numBits, numBits)
      val magnitude = (code & max).toDouble / max * norm
      decoded(i) = if ((code & sign) != 0) -magnitude else magnitude
    }
    decoded
  }

  override def toAuto: Gradient = (if (indices == null) toDense else toSparse).toAuto

  override def kind: Kind = Kind.FixedPoint
}
/**
  * Factory and helper routines shared by all gradient representations.
  */
object Gradient {
  /** The shared additive identity, see [[ZeroGradient]]. */
  def zero: ZeroGradient = ZeroGradient.getInstance()

  // Looked up once; previously a `def`, which asked the logger factory on every log call.
  private val logger: Logger = LoggerFactory.getLogger(Gradient.getClass)

  /**
    * Compresses `grad` with the compressor named by `conf.compressor`.
    *
    * @param grad gradient to compress
    * @param conf configuration carrying the compressor name and its parameters
    * @return the compressed gradient, or `grad` itself for the "none" compressor
    * @throws SketchMLException if the compressor name is unknown, or the float compressor
    *                           is asked to compress something other than a double gradient
    */
  def compress(grad: Gradient, conf: MLConf): Gradient = {
    val startTime = System.currentTimeMillis()
    val res = conf.compressor match {
      case Constants.GRADIENT_COMPRESSOR_SKETCH =>
        new SketchGradient(grad, conf.quantBinNum, conf.sketchGroupNum,
          conf.sketchRowNum, conf.sketchColRatio)
      case Constants.GRADIENT_COMPRESSOR_FIXED_POINT =>
        new FixedPointGradient(grad, conf.fixedPointBitNum)
      case Constants.GRADIENT_COMPRESSOR_ZIP =>
        new ZipGradient(grad, conf.quantBinNum)
      case Constants.GRADIENT_COMPRESSOR_FLOAT =>
        grad.kind match {
          case Kind.DenseDouble => new DenseFloatGradient(grad)
          case Kind.SparseDouble => SparseFloatGradient(grad)
          // Fail like the other compressors do instead of with a bare MatchError.
          case other => throw new SketchMLException(
            s"Cannot create a float gradient from $other")
        }
      case Constants.GRADIENT_COMPRESSOR_NONE => grad
      case _ => throw new SketchMLException(
        "Unrecognizable compressor: " + conf.compressor)
    }
    logger.info(s"Gradient compression from ${grad.kind} to ${res.kind} cost " +
      s"${System.currentTimeMillis() - startTime} ms")
    // uncomment to evaluate the performance of compression
    //evaluateCompression(grad, res)
    res
  }

  /**
    * Adds up `grads` into a fresh gradient of dimension `dim`.
    *
    * @return the sum, in whichever representation `toAuto` selects
    */
  def sum(dim: Int, grads: Array[Gradient]): Gradient = {
    require(!grads.exists(_.dim != dim))
    val sum = new DenseDoubleGradient(dim)
    grads.foreach(g => sum.plusBy(g))
    sum.toAuto
  }

  /**
    * Logs distance and size statistics comparing a gradient with its compressed form.
    *
    * @param origin the uncompressed gradient (dense or sparse double)
    * @param comp   the result of compressing `origin`
    * @throws SketchMLException if `origin` is not a double gradient
    */
  def evaluateCompression(origin: Gradient, comp: Gradient): Unit = {
    logger.info(s"Evaluating compression from ${origin.kind} to ${comp.kind}, " +
      s"sparsity[${origin.countNNZ.toDouble / origin.dim}]")
    // distances
    val (vOrig, vComp) = origin match {
      case dense: DenseDoubleGradient => (dense.values, comp.toDense.values)
      case sparse: SparseDoubleGradient => (sparse.values, comp.toSparse.values)
      case _ => throw new SketchMLException(
        s"Cannot evaluate compression of ${origin.kind}")
    }
    logger.info(s"Distances: euclidean[${Maths.euclidean(vOrig, vComp)}], " +
      s"cosine[${Maths.cosine(vOrig, vComp)}]")
    // size
    val sizeOrig = Utils.sizeof(origin)
    val sizeComp = Utils.sizeof(comp)
    val rate = 1.0 * sizeOrig / sizeComp
    logger.info(s"Sizeof gradients: nnz[${vOrig.length}], " +
      s"origin[$sizeOrig bytes], comp[$sizeComp bytes], rate[$rate]")
  }
}

/**
  * Base class of every gradient representation.
  *
  * The typed `plusBy` overloads reject their operand by default; a representation
  * overrides the ones it can accumulate in place.
  *
  * @param dim dimension of the gradient, must be positive
  */
abstract class Gradient(val dim: Int) extends Serializable {
  require(dim > 0, s"Dimension is non-positive: $dim")

  /** Shared failure path of the typed `plusBy` overloads. */
  private def unsupportedAdd(operand: Any): Nothing =
    throw new UnsupportedOperationException(s"Cannot add $operand to ${this.kind}")

  /**
    * Adds `o` to this gradient by dispatching to the overload for its concrete type.
    * Adding the zero gradient is a no-op and skips the dimension check.
    *
    * @throws ClassNotFoundException if `o` is of an unknown gradient type
    */
  def plusBy(o: Gradient): Gradient = {
    if (o.kind == Kind.ZeroGradient)
      this
    else {
      require(dim == o.dim, s"Adding gradients with " +
        s"different dimensions: $dim, ${o.dim}")
      o match {
        case g: DenseDoubleGradient => plusBy(g)
        case g: SparseDoubleGradient => plusBy(g)
        case g: DenseFloatGradient => plusBy(g)
        case g: SparseFloatGradient => plusBy(g)
        case g: SketchGradient => plusBy(g)
        case g: FixedPointGradient => plusBy(g)
        case g: ZipGradient => plusBy(g)
        case _ => throw new ClassNotFoundException(o.getClass.getName)
      }
    }
  }

  def plusBy(dense: DenseDoubleGradient): Gradient = unsupportedAdd(dense.kind)

  def plusBy(sparse: SparseDoubleGradient): Gradient = unsupportedAdd(sparse.kind)

  def plusBy(dense: DenseFloatGradient): Gradient = unsupportedAdd(dense.kind)

  def plusBy(sparse: SparseFloatGradient): Gradient = unsupportedAdd(sparse.kind)

  def plusBy(sketchGrad: SketchGradient): Gradient = unsupportedAdd(sketchGrad.kind)

  def plusBy(fpGrad: FixedPointGradient): Gradient = unsupportedAdd(fpGrad.kind)

  def plusBy(zipGrad: ZipGradient): Gradient = unsupportedAdd(zipGrad.kind)

  /** Adds `x * v` to this gradient, dispatching on the vector's storage format. */
  def plusBy(v: Vector, x: Double): Gradient = {
    v match {
      case dense: DenseVector => plusBy(dense, x)
      case sparse: SparseVector => plusBy(sparse, x)
    }
  }

  def plusBy(dense: DenseVector, x: Double): Gradient = unsupportedAdd("DenseVector")

  def plusBy(sparse: SparseVector, x: Double): Gradient = unsupportedAdd("SparseVector")

  /** Multiplies this gradient by `x`. */
  def timesBy(x: Double): Unit

  /** Number of entries this representation counts as non-zero. */
  def countNNZ: Int

  def toDense: DenseDoubleGradient

  def toSparse: SparseDoubleGradient

  /** Converts to whichever of dense/sparse the implementation considers appropriate. */
  def toAuto: Gradient

  def kind: Kind

  /** Alias of `plusBy(o)`. */
  def +=(o: Gradient): Gradient = plusBy(o)

}

/**
  * Singleton object for zero value of gradients
  */
@Singleton
object ZeroGradient {
  private val instance = new ZeroGradient()

  def getInstance(): ZeroGradient = instance
}

/**
  * Additive identity: adding anything to it yields the other operand unchanged.
  * It carries no values, so the conversions are left unimplemented and throw.
  */
class ZeroGradient private extends Gradient(1) {
  override def plusBy(o: Gradient): Gradient = o

  override def timesBy(x: Double): Unit = {}

  override def countNNZ: Int = 0

  override def toDense: DenseDoubleGradient = ???

  override def toSparse: SparseDoubleGradient = ???

  override def toAuto: Gradient = ???

  override def kind: Kind = Kind.ZeroGradient

  override def plusBy(dense: DenseDoubleGradient): Gradient = dense

  override def plusBy(sparse: SparseDoubleGradient): Gradient = sparse

  override def plusBy(dense: DenseFloatGradient): Gradient = dense

  override def plusBy(sparse: SparseFloatGradient): Gradient = sparse

  override def plusBy(sketchGrad: SketchGradient): Gradient = sketchGrad

  override def plusBy(fpGrad: FixedPointGradient): Gradient = fpGrad

  override def plusBy(zipGrad: ZipGradient): Gradient = zipGrad

}
=> fromSparse(grad.asInstanceOf[SparseDoubleGradient]) 15 | case _ => throw new SketchMLException(s"Cannot create ${this.kind} from ${grad.kind}") 16 | } 17 | } 18 | 19 | private var nnz: Int = 0 20 | var bucketValues: Array[Double] = _ 21 | var bins: Array[Int] = _ 22 | var sketch: GroupedMinMaxSketch = _ 23 | 24 | def fromDense(dense: DenseDoubleGradient): Unit = { 25 | val values = dense.values 26 | val quantizer = new QuantileQuantizer(binNum) 27 | quantizer.quantize(values) 28 | //quantizer.parallelQuantize(values) 29 | bucketValues = quantizer.getValues 30 | bins = quantizer.getBins 31 | sketch = null 32 | nnz = dim 33 | } 34 | 35 | def fromSparse(sparse: SparseDoubleGradient): Unit = { 36 | // 1. quantize into bin indexes 37 | val quantizer = new QuantileQuantizer(binNum) 38 | quantizer.quantize(sparse.values) 39 | //quantizer.parallelQuantize(sparse.values) 40 | bucketValues = quantizer.getValues 41 | // 2. encode bins and keys 42 | sketch = new GroupedMinMaxSketch(groupNum, rowNum, colRatio, quantizer.getBinNum, quantizer.getZeroIdx) 43 | sketch.create(sparse.indices, quantizer.getBins) 44 | bins = null 45 | //sketch.parallelCreate(sparse.indices, quantizer.getBins) 46 | // 3. 
/**
  * Sparse gradient stored as parallel arrays of strictly increasing indices and
  * double-precision values.
  *
  * @param d       dimension of the gradient
  * @param indices positions of the stored values, strictly increasing, each in `[0, d)`;
  *                may be empty
  * @param values  stored values, same length as `indices`
  */
class SparseDoubleGradient(d: Int, val indices: Array[Int],
                           val values: Array[Double]) extends Gradient(d) {
  require(indices.length == values.length,
    s"Sizes of indices and values not match: ${indices.length} & ${values.length}")
  // An all-zero gradient legitimately has no stored entries; `head`/`last` would throw on it.
  if (indices.nonEmpty) {
    require(indices.head >= 0, s"Negative index: ${indices.head}.")
    for (i <- 1 until indices.length)
      require(indices(i - 1) < indices(i), s"Indices are not strictly increasing")
    require(indices.last < dim, s"Index ${indices.last} out of bounds for gradient of dimension $dim")
  }

  //override def plusBy(sparse: SparseDoubleGradient): Gradient = {
  //  val kv = Maths.add(this.indices, this.values, sparse.indices, sparse.values)
  //  new SparseDoubleGradient(dim, kv._1, kv._2)
  //}

  /** Scales the stored values in place. */
  override def timesBy(x: Double): Unit = {
    for (i <- values.indices)
      values(i) *= x
  }

  /** Counts stored values whose magnitude exceeds `Maths.EPS`. */
  override def countNNZ: Int = values.count(v => Math.abs(v) > Maths.EPS)

  /** Expands to `dim` entries; stored values within `Maths.EPS` of zero are dropped. */
  override def toDense: DenseDoubleGradient = {
    val dense = new Array[Double](dim)
    for (i <- values.indices)
      if (Math.abs(values(i)) > Maths.EPS)
        dense(indices(i)) = values(i)
    new DenseDoubleGradient(dim, dense)
  }

  override def toSparse: SparseDoubleGradient = this

  /** Goes dense when more than two thirds of the dimensions are non-zero. */
  override def toAuto: Gradient = {
    // Same test as `nnz > dim * 2 / 3`, done in Long so a large `dim` cannot overflow.
    if (countNNZ.toLong * 3 > dim.toLong * 2) toDense else toSparse
  }

  override def kind: Kind = Kind.SparseDouble
}

object SparseFloatGradient {
  /**
    * Builds a single-precision sparse copy of a double-precision gradient.
    * A dense source is stored with one entry per dimension.
    *
    * @throws SketchMLException if `grad` is neither dense-double nor sparse-double
    */
  def apply(grad: Gradient): SparseFloatGradient = {
    val (indices, values) = grad match {
      case dense: DenseDoubleGradient => (Array.range(0, dense.dim), dense.values)
      case sparse: SparseDoubleGradient => (sparse.indices, sparse.values)
      case _ => throw new SketchMLException(s"Cannot create ${Kind.SparseFloat} from ${grad.kind}")
    }
    new SparseFloatGradient(grad.dim, indices, values.map(_.toFloat))
  }
}

/**
  * Sparse gradient stored as parallel arrays of strictly increasing indices and
  * single-precision values.
  *
  * @param d       dimension of the gradient
  * @param indices positions of the stored values, strictly increasing, each in `[0, d)`;
  *                may be empty
  * @param values  stored values, same length as `indices`
  */
class SparseFloatGradient(d: Int, val indices: Array[Int],
                          val values: Array[Float]) extends Gradient(d) {
  require(indices.length == values.length,
    s"Sizes of indices and values not match: ${indices.length} & ${values.length}")
  // An all-zero gradient legitimately has no stored entries; `head`/`last` would throw on it.
  if (indices.nonEmpty) {
    require(indices.head >= 0, s"Negative index: ${indices.head}.")
    for (i <- 1 until indices.length)
      require(indices(i - 1) < indices(i), s"Indices are not strictly increasing")
    require(indices.last < dim, s"Index ${indices.last} out of bounds for gradient of dimension $dim")
  }

  /** Scales the stored values in place, in single precision. */
  override def timesBy(x: Double): Unit = {
    val x_ = x.toFloat
    for (i <- values.indices)
      values(i) *= x_
  }

  /** Counts stored values whose magnitude exceeds `Maths.EPS`. */
  override def countNNZ: Int = values.count(v => Math.abs(v) > Maths.EPS)

  /** Expands to `dim` double entries; stored values within `Maths.EPS` of zero are dropped. */
  override def toDense: DenseDoubleGradient = {
    val dense = new Array[Double](dim)
    for (i <- values.indices)
      if (Math.abs(values(i)) > Maths.EPS)
        dense(indices(i)) = values(i)
    new DenseDoubleGradient(dim, dense)
  }

  override def toSparse: SparseDoubleGradient =
    new SparseDoubleGradient(dim, indices, values.map(_.toDouble))

  /** Goes dense when more than two thirds of the dimensions are non-zero. */
  override def toAuto: Gradient = {
    // Same test as `nnz > dim * 2 / 3`, done in Long so a large `dim` cannot overflow.
    if (countNNZ.toLong * 3 > dim.toLong * 2) toDense else toSparse
  }

  override def kind: Kind = Kind.SparseFloat
}
object ZipMLQuantizer {
  private val logger: Logger = LoggerFactory.getLogger(classOf[ZipMLQuantizer])
}

/**
  * Quantizer that picks its split points by repeatedly merging adjacent intervals
  * of the sorted input, keeping apart the pairs whose merge would cost the most
  * squared error.
  *
  * @param b requested number of bins
  */
class ZipMLQuantizer(b: Int) extends Quantizer(b) {
  import ZipMLQuantizer._

  def this() = this(Quantizer.DEFAULT_BIN_NUM)

  /**
    * Chooses split points for `values` and assigns every value to a bin.
    *
    * @param values values to quantize; the array itself is not modified
    */
  override def quantize(values: Array[Double]): Unit = {
    val startTime = System.currentTimeMillis
    n = values.length
    // 1. pre-compute the errors: prefix sums of the sorted values (r) and of their
    //    squares (t) give the squared error of any interval in constant time
    val sortedValues = values.clone
    util.Arrays.sort(sortedValues)
    val r = new Array[Double](n)
    val t = new Array[Double](n)
    r(0) = sortedValues(0)
    t(0) = sortedValues(0) * sortedValues(0)
    for (i <- 1 until n) {
      r(i) = r(i - 1) + sortedValues(i)
      t(i) = t(i - 1) + sortedValues(i) * sortedValues(i)
    }
    // 2. find split points: start with one interval per value and merge adjacent pairs
    var splitNum = n
    val splitIndex = (0 until n).toArray
    while (splitNum > binNum) {
      // errors(i) is the squared error of merging intervals 2i and 2i+1 into one
      val errors = new Array[Double](splitNum)
      for (i <- 0 until splitNum / 2) {
        val lo = splitIndex(2 * i)
        val hi = (if (2 * i + 2 >= splitNum) n else splitIndex(2 * i + 2)) - 1
        val l1Sum = r(hi) - (if (lo == 0) 0 else r(lo - 1))
        val l2Sum = t(hi) - (if (lo == 0) 0 else t(lo - 1))
        val mean = l1Sum / (hi - lo + 1)
        errors(i) = l2Sum + mean * mean * (hi - lo + 1) - 2 * mean * l1Sum
      }

      // An odd interval count leaves one unpaired interval that always survives,
      // so one fewer pair may stay split.
      val thrNum = binNum / 2 - (if (splitNum % 2 == 1) 1 else 0)
      val threshold = Sort.selectKthLargest(errors.clone, thrNum)

      // NOTE(review): every pair tied at the threshold stays split, so inputs with many
      // equal errors may not shrink here - confirm against Sort.selectKthLargest.
      var newSplitNum = 0
      for (i <- 0 until splitNum / 2) {
        if (errors(i) >= threshold) {
          splitIndex(newSplitNum) = splitIndex(2 * i); newSplitNum += 1
          splitIndex(newSplitNum) = splitIndex(2 * i + 1); newSplitNum += 1
        } else {
          splitIndex(newSplitNum) = splitIndex(2 * i); newSplitNum += 1
        }
      }

      if (splitNum % 2 == 1) {
        // Carry the unpaired last interval over. The counter to advance is newSplitNum:
        // advancing splitNum was overwritten just below and dropped this interval.
        splitIndex(newSplitNum) = splitIndex(splitNum - 1); newSplitNum += 1
      }
      splitNum = newSplitNum
    }

    min = sortedValues(0)
    max = sortedValues(n - 1)
    binNum = splitNum
    splits = new Array[Double](binNum - 1)
    for (i <- 0 until binNum - 1)
      splits(i) = sortedValues(splitIndex(i + 1))
    // 3. find the zero index
    findZeroIdx()
    // 4. find index of each value
    quantizeToBins(values)
    logger.debug(s"ZipML quantization for $n items cost " +
      s"${System.currentTimeMillis - startTime} ms")
  }

  /** Logs a warning and falls back to the sequential [[quantize]]. */
  @throws[InterruptedException]
  @throws[ExecutionException]
  override def parallelQuantize(values: Array[Double]): Unit = {
    logger.warn(s"ZipML quantization should be sequential")
    quantize(values)
  }

  // Not implemented: calling this throws NotImplementedError.
  override def quantizationType: Quantizer.QuantizationType = ???
}
update0(grad: Gradient, weight: DenseVector): Unit = { 38 | grad match { 39 | case dense: DenseDoubleGradient => update(dense, weight, lr_0) 40 | case sparse: SparseDoubleGradient => update(sparse, weight, lr_0) 41 | case dense: DenseFloatGradient => update(dense, weight, lr_0) 42 | case sparse: SparseFloatGradient => update(sparse, weight, lr_0) 43 | case sketchGrad: SketchGradient => update0(sketchGrad.toAuto, weight) 44 | case fpGrad: FixedPointGradient => update0(fpGrad.toAuto, weight) 45 | case zipGrad: ZipGradient => update0(zipGrad.toAuto, weight) 46 | case _ => throw new ClassNotFoundException(grad.getClass.getName) 47 | } 48 | } 49 | 50 | private def update(grad: DenseDoubleGradient, weight: DenseVector, lr: Double): Unit = { 51 | val g = grad.values 52 | val w = weight.values 53 | for (i <- w.indices) { 54 | val m_t = beta1 * m(i) + (1 - beta1) * g(i) 55 | val v_t = beta2 * v(i) + (1 - beta2) * g(i) * g(i) 56 | val newGrad = (Math.sqrt(1 - beta2_t) * m_t) / ((1 - beta1_t) * (Math.sqrt(v_t) + Maths.EPS)) 57 | w(i) -= newGrad * lr 58 | m(i) = m_t 59 | v(i) = v_t 60 | } 61 | } 62 | 63 | private def update(grad: SparseDoubleGradient, weight: DenseVector, lr: Double): Unit = { 64 | val k = grad.indices 65 | val g = grad.values 66 | val w = weight.values 67 | for (i <- k.indices) { 68 | val dim = k(i) 69 | val grad = g(i) 70 | val m_t = beta1 * m(dim) + (1 - beta1) * grad 71 | val v_t = beta2 * v(dim) + (1 - beta2) * grad * grad 72 | val newGrad = (Math.sqrt(1 - beta2_t) * m_t) / ((1 - beta1_t) * (Math.sqrt(v_t) + Maths.EPS)) 73 | w(dim) -= newGrad * lr 74 | m(dim) = m_t 75 | v(dim) = v_t 76 | } 77 | } 78 | 79 | private def update(grad: DenseFloatGradient, weight: DenseVector, lr: Double): Unit = { 80 | val g = grad.values 81 | val w = weight.values 82 | for (i <- w.indices) { 83 | val m_t = beta1 * m(i) + (1 - beta1) * g(i) 84 | val v_t = beta2 * v(i) + (1 - beta2) * g(i) * g(i) 85 | val newGrad = (Math.sqrt(1 - beta2_t) * m_t) / ((1 - beta1_t) * 
object GradientDescent {
  private val logger: Logger = LoggerFactory.getLogger(GradientDescent.getClass)

  /** Builds a plain gradient-descent optimizer from the job configuration. */
  def apply(conf: MLConf): GradientDescent =
    new GradientDescent(conf.featureNum, conf.learnRate, conf.learnDecay, conf.batchSpRatio)
}

/**
  * Mini-batch gradient descent with a learning rate decayed per epoch.
  *
  * @param dim          number of features
  * @param lr_0         initial learning rate
  * @param decay        learning-rate decay factor
  * @param batchSpRatio fraction of the data set read per mini-batch
  */
class GradientDescent(dim: Int, lr_0: Double, decay: Double, batchSpRatio: Double) {
  protected val logger: Logger = GradientDescent.logger

  var epoch: Int = 0
  var batch: Int = 0
  /** Number of mini-batches per epoch. */
  val batchNum: Double = Math.ceil(1.0 / batchSpRatio).toInt

  /**
    * Computes the averaged, regularized gradient of one mini-batch and advances
    * the batch/epoch counters.
    *
    * @param weight  current model weights
    * @param dataSet data to read the mini-batch from
    * @param loss    loss function, also supplying the regularization settings
    * @return (gradient, batch size, summed objective loss, regularization loss)
    */
  def miniBatchGradientDescent(weight: DenseVector, dataSet: DataSet, loss: Loss): (Gradient, Int, Double, Double) = {
    val startTime = System.currentTimeMillis()

    val denseGrad = new DenseDoubleGradient(dim)
    var objLoss = 0.0
    val batchSize = (dataSet.size * batchSpRatio).toInt
    for (_ <- 0 until batchSize) {
      val ins = dataSet.loopingRead
      val pre = loss.predict(weight, ins.feature)
      val gradScalar = loss.grad(pre, ins.label)
      denseGrad.plusBy(ins.feature, -1.0 * gradScalar)
      objLoss += loss.loss(pre, ins.label)
    }
    val grad = denseGrad.toAuto
    // A batch that rounds down to zero instances has nothing to average; dividing
    // by zero would scale by infinity and turn the gradient into NaN.
    if (batchSize > 0)
      grad.timesBy(1.0 / batchSize)

    if (loss.isL1Reg)
      l1Reg(grad, 0, loss.getRegParam)
    if (loss.isL2Reg)
      l2Reg(grad, weight, loss.getRegParam)
    val regLoss = loss.getReg(weight)

    val meanObjLoss = if (batchSize > 0) objLoss / batchSize else 0.0
    logger.info(s"Epoch[$epoch] batch $batch gradient " +
      s"cost ${System.currentTimeMillis() - startTime} ms, "
      + s"batch size=$batchSize, obj loss=$meanObjLoss, reg loss=$regLoss")
    batch += 1
    if (batch == batchNum) { epoch += 1; batch = 0 }
    (grad, batchSize, objLoss, regLoss)
  }

  /**
    * Truncates, in place, the stored values that lie within `theta` of zero:
    * each is moved towards zero by `alpha` and stops at zero.
    *
    * @throws UnsupportedOperationException for anything but double gradients
    */
  private def l1Reg(grad: Gradient, alpha: Double, theta: Double): Unit = {
    val values = grad match {
      case dense: DenseDoubleGradient => dense.values
      case sparse: SparseDoubleGradient => sparse.values
      case _ => throw new UnsupportedOperationException(
        s"Cannot regularize ${grad.kind} kind of gradients")
    }
    if (values != null) {
      for (i <- values.indices) {
        if (values(i) >= 0 && values(i) <= theta)
          values(i) = (values(i) - alpha) max 0
        else if (values(i) < 0 && values(i) >= -theta)
          // Negative values shrink towards zero by adding alpha; subtracting it
          // would push them further away.
          values(i) = (values(i) + alpha) min 0
      }
    }
  }

  /**
    * Adds `lambda * weight` to the stored gradient values in place.
    *
    * @throws UnsupportedOperationException for anything but double gradients
    */
  private def l2Reg(grad: Gradient, weight: DenseVector, lambda: Double): Unit = {
    val w = weight.values
    grad match {
      case dense: DenseDoubleGradient =>
        val v = dense.values
        for (i <- v.indices)
          v(i) += w(i) * lambda
      case sparse: SparseDoubleGradient =>
        val k = sparse.indices
        val v = sparse.values
        for (i <- k.indices)
          v(i) += w(k(i)) * lambda
      case _ => throw new UnsupportedOperationException(
        s"Cannot regularize ${grad.kind} kind of gradients")
    }
  }

  /**
    * Applies one descent step `weight -= lr * grad`, with
    * `lr = lr_0 / sqrt(1 + decay * epoch)`. Compressed gradients are decoded with
    * `toAuto` and passed through this method again.
    *
    * @throws UnsupportedOperationException if the gradient type is not supported
    */
  def update(grad: Gradient, weight: DenseVector): Unit = {
    val startTime = System.currentTimeMillis()
    val lr = lr_0 / Math.sqrt(1.0 + decay * epoch)
    grad match {
      case dense: DenseDoubleGradient => update(dense, weight, lr)
      case sparse: SparseDoubleGradient => update(sparse, weight, lr)
      case dense: DenseFloatGradient => update(dense, weight, lr)
      case sparse: SparseFloatGradient => update(sparse, weight, lr)
      case sketchGrad: SketchGradient => update(sketchGrad.toAuto, weight)
      case fpGrad: FixedPointGradient => update(fpGrad.toAuto, weight)
      case zipGrad: ZipGradient => update(zipGrad.toAuto, weight)
      // Name the offending kind instead of failing with a bare MatchError.
      case _ => throw new UnsupportedOperationException(
        s"Cannot update weight with ${grad.kind} kind of gradients")
    }
    logger.info(s"Update weight cost ${System.currentTimeMillis() - startTime} ms")
  }

  private def update(grad: DenseDoubleGradient, weight: DenseVector, lr: Double): Unit = {
    val g = grad.values
    val w = weight.values
    for (i <- w.indices)
      w(i) -= g(i) * lr
  }

  private def update(grad: SparseDoubleGradient, weight: DenseVector, lr: Double): Unit = {
    val k = grad.indices
    val v = grad.values
    val w = weight.values
    for (i <- k.indices)
      w(k(i)) -= v(i) * lr
  }

  private def update(grad: DenseFloatGradient, weight: DenseVector, lr: Double): Unit = {
    val g = grad.values
    val w = weight.values
    for (i <- w.indices)
      w(i) -= g(i) * lr
  }

  private def update(grad: SparseFloatGradient, weight: DenseVector, lr: Double): Unit = {
    val k = grad.indices
    val v = grad.values
    val w = weight.values
    for (i <- k.indices)
      w(k(i)) -= v(i) * lr
  }

}
/**
  * Contract for a loss function plus its regularization term.
  */
trait Loss extends Serializable {
  /** Loss of prediction `pre` against label `y`. */
  def loss(pre: Double, y: Double): Double

  /** Gradient term for prediction `pre` and label `y`; the sign convention is set by each implementation. */
  def grad(pre: Double, y: Double): Double

  /** Prediction of the model `w` for the feature vector `x`. */
  def predict(w: Vector, x: Vector): Double

  /** Whether an L1 penalty is active. */
  def isL1Reg: Boolean

  /** Whether an L2 penalty is active. */
  def isL2Reg: Boolean

  /** The regularization coefficient. */
  def getRegParam: Double

  /** Regularization penalty of the weights `w`. */
  def getReg(w: Vector): Double
}

/**
  * Loss with an L1 penalty: `lambda * ||w||_1`, active only when `lambda` exceeds `Maths.EPS`.
  */
abstract class L1Loss extends Loss {
  protected var lambda: Double

  override def isL1Reg: Boolean = this.lambda > Maths.EPS

  override def isL2Reg: Boolean = false

  override def getRegParam: Double = lambda

  override def getReg(w: Vector): Double =
    if (isL1Reg) Vectors.norm(w, 1) * lambda else 0.0
}

/**
  * Loss with an L2 penalty: `lambda * ||w||_2`, active only when `lambda` exceeds `Maths.EPS`.
  */
abstract class L2Loss extends Loss {
  protected var lambda: Double

  override def isL1Reg: Boolean = false

  override def isL2Reg: Boolean = lambda > Maths.EPS

  override def getRegParam: Double = lambda

  override def getReg(w: Vector): Double =
    if (isL2Reg) Vectors.norm(w, 2) * lambda else 0.0
}

/**
  * Logistic loss `log(1 + exp(-pre * y))` with an L1 penalty.
  *
  * Outside the margin range [-18, 18] closed-form approximations are used instead of
  * evaluating the full expression.
  *
  * @param l regularization coefficient
  */
class L1LogLoss(l: Double) extends L1Loss {
  override protected var lambda: Double = l

  override def loss(pre: Double, y: Double): Double = {
    pre * y match {
      case margin if margin > 18 => Math.exp(-margin)
      case margin if margin < -18 => -margin
      case margin => Math.log(1 + Math.exp(-margin))
    }
  }

  /** Returns `y / (1 + exp(pre * y))`, i.e. the negated derivative of the loss w.r.t. `pre`. */
  override def grad(pre: Double, y: Double): Double = {
    pre * y match {
      case margin if margin > 18 => y * Math.exp(-margin)
      case margin if margin < -18 => y
      case margin => y / (1.0 + Math.exp(margin))
    }
  }

  override def predict(w: Vector, x: Vector): Double = Maths.dot(w, x)
}
/**
  * Numeric helpers for dense/sparse vector arithmetic.
  */
object Maths {
  /** Small positive threshold (1e-8) for "effectively non-zero" comparisons. */
  val EPS = 1e-8

  /**
    * Merges two sparse (key, value) lists, summing the values of keys present in both.
    *
    * The merge always advances the side holding the smaller key, so both key arrays are
    * assumed to be sorted in ascending order. Entries remaining on one side after the
    * other is exhausted are appended unchanged.
    *
    * @param k1 keys of the first operand
    * @param v1 values of the first operand, parallel to `k1`
    * @param k2 keys of the second operand
    * @param v2 values of the second operand, parallel to `k2`
    * @return the merged keys and their values
    */
  def add(k1: Array[Int], v1: Array[Double], k2: Array[Int],
          v2: Array[Double]): (Array[Int], Array[Double]) = {
    val k = ArrayBuffer[Int]()
    val v = ArrayBuffer[Double]()
    var i = 0
    var j = 0
    while (i < k1.length && j < k2.length) {
      if (k1(i) < k2(j)) {
        k += k1(i)
        v += v1(i)
        i += 1
      } else if (k1(i) > k2(j)) {
        k += k2(j)
        v += v2(j)
        j += 1
      } else {
        k += k1(i)
        v += v1(i) + v2(j)
        i += 1
        j += 1
      }
    }
    // Flush whatever is left on either side; without this the tail of the longer
    // operand would silently be dropped from the sum.
    while (i < k1.length) {
      k += k1(i)
      v += v1(i)
      i += 1
    }
    while (j < k2.length) {
      k += k2(j)
      v += v2(j)
      j += 1
    }
    (k.toArray, v.toArray)
  }

  /** Dot product of two vectors, dispatching on their dense/sparse representation. */
  def dot(a: Vector, b: Vector): Double = {
    (a, b) match {
      case (a: DenseVector, b: DenseVector) => dot(a, b)
      case (a: DenseVector, b: SparseVector) => dot(a, b)
      case (a: SparseVector, b: DenseVector) => dot(a, b)
      case (a: SparseVector, b: SparseVector) => dot(a, b)
    }
  }

  /** Dot product of two dense vectors of equal size. */
  def dot(a: DenseVector, b: DenseVector): Double = {
    require(a.size == b.size, s"Dot between vectors of size ${a.size} and ${b.size}")
    val size = a.size
    val aValues = a.values
    val bValues = b.values
    var dot = 0.0
    for (i <- 0 until size) {
      dot += aValues(i) * bValues(i)
    }
    dot
  }

  /** Dot product of a dense and a sparse vector; only the active entries of `b` are visited. */
  def dot(a: DenseVector, b: SparseVector): Double = {
    require(a.size == b.size, s"Dot between vectors of size ${a.size} and ${b.size}")
    val aValues = a.values
    val bIndices = b.indices
    val bValues = b.values
    val size = b.numActives
    var dot = 0.0
    for (i <- 0 until size) {
      val ind = bIndices(i)
      dot += aValues(ind) * bValues(i)
    }
    dot
  }

  /** Dot product of a sparse and a dense vector (delegates to the dense/sparse overload). */
  def dot(a: SparseVector, b: DenseVector): Double = dot(b, a)

  /**
    * Dot product of two sparse vectors of equal size.
    *
    * Walks both index arrays in lock-step, advancing the side with the smaller index and
    * multiplying values only where the indices coincide. Neither input is modified.
    */
  def dot(a: SparseVector, b: SparseVector): Double = {
    require(a.size == b.size, s"Dot between vectors of size ${a.size} and ${b.size}")
    val aIndices = a.indices
    val aValues = a.values
    val aNumActives = a.numActives
    val bIndices = b.indices
    val bValues = b.values
    val bNumActives = b.numActives
    var aOff = 0
    var bOff = 0
    var dot = 0.0
    while (aOff < aNumActives && bOff < bNumActives) {
      if (aIndices(aOff) < bIndices(bOff)) {
        // Advance the cursor; the previous code incremented the stored index itself,
        // corrupting the caller's vector and producing wrong products.
        aOff += 1
      } else if (aIndices(aOff) > bIndices(bOff)) {
        bOff += 1
      } else {
        dot += aValues(aOff) * bValues(bOff)
        aOff += 1
        bOff += 1
      }
    }
    dot
  }

  /**
    * Sum of squared element-wise differences of two equally long arrays.
    *
    * Note that no square root is taken, so this is the squared Euclidean distance.
    */
  def euclidean(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length)
    (a, b).zipped.map((x, y) => (x - y) * (x - y)).sum
  }

  /** Cosine similarity: the dot product divided by the product of the two L2 norms. */
  def cosine(a: Array[Double], b: Array[Double]): Double = {
    val va = new DenseVector(a)
    val vb = new DenseVector(b)
    dot(va, vb) / (Vectors.norm(va, 2) * Vectors.norm(vb, 2))
  }

}
trueRecall = 1.0 * truePos / (truePos + falseNeg) 36 | val falseRecall = 1.0 * trueNeg / (trueNeg + falsePos) 37 | logger.info(s"validation cost ${System.currentTimeMillis() - validStart} ms, " 38 | + s"loss=$validLoss, precision=$precision, " 39 | + s"trueRecall=$trueRecall, falseRecall=$falseRecall") 40 | (validLoss, truePos, trueNeg, falsePos, falseNeg, validNum) 41 | } 42 | 43 | def calLossAucPrecision(weights: Vector, validData: DataSet, loss: Loss): (Double, Int, Int, Int, Int, Int) = { 44 | val validStart = System.currentTimeMillis() 45 | val validNum = validData.size 46 | var validLoss = 0.0 47 | val scoresArray = new Array[Double](validNum) 48 | val labelsArray = new Array[Double](validNum) 49 | var truePos = 0 // ground truth: positive, precision: positive 50 | var falsePos = 0 // ground truth: negative, precision: positive 51 | var trueNeg = 0 // ground truth: negative, precision: negative 52 | var falseNeg = 0 // ground truth: positive, precision: negative 53 | 54 | for (i <- 0 until validNum) { 55 | val ins = validData.get(i) 56 | val pre = loss.predict(weights, ins.feature) 57 | if (pre * ins.label > 0) { 58 | if (pre > 0) truePos += 1 59 | else trueNeg += 1 60 | } else if (pre * ins.label < 0) { 61 | if (pre > 0) falsePos += 1 62 | else falseNeg += 1 63 | } 64 | scoresArray(i) = pre 65 | labelsArray(i) = ins.label 66 | validLoss += loss.loss(pre, ins.label) 67 | } 68 | 69 | Sort.quickSort(scoresArray, labelsArray, 0, scoresArray.length) 70 | var M = 0L 71 | var N = 0L 72 | for (i <- 0 until validNum) { 73 | if (labelsArray(i) == 1) 74 | M += 1 75 | else 76 | N += 1 77 | } 78 | var sigma = 0.0 79 | for (i <- M + N - 1 to 0 by -1) { 80 | if (labelsArray(i.toInt) == 1.0) 81 | sigma += i 82 | } 83 | val aucResult = (sigma - (M + 1) * M / 2) / M / N 84 | 85 | val precision = 1.0 * (truePos + trueNeg) / validNum 86 | val trueRecall = 1.0 * truePos / (truePos + falseNeg) 87 | val falseRecall = 1.0 * trueNeg / (trueNeg + falsePos) 88 | 89 | 
logger.info(s"validation cost ${System.currentTimeMillis() - validStart} ms, " 90 | + s"loss=$validLoss, auc=$aucResult, precision=$precision, " 91 | + s"trueRecall=$trueRecall, falseRecall=$falseRecall") 92 | (validLoss, truePos, trueNeg, falsePos, falseNeg, validNum) 93 | } 94 | } -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | org.dma.sketchml 8 | sketchml 9 | pom 10 | 1.0.0 11 | 12 | sketch 13 | ml 14 | 15 | 16 | 17 | 1.8 18 | 2.11.7 19 | 1.7.25 20 | 21 | 22 | 23 | 24 | org.slf4j 25 | slf4j-log4j12 26 | ${slf4jVersion} 27 | 28 | 29 | 30 | 31 | src/main/java 32 | src/test/java 33 | 34 | 35 | org.apache.maven.plugins 36 | maven-compiler-plugin 37 | 3.3 38 | 39 | ${jdkVersion} 40 | ${jdkVersion} 41 | false 42 | 43 | 44 | 45 | org.apache.maven.plugins 46 | maven-assembly-plugin 47 | 2.5.4 48 | 49 | 50 | jar-with-dependencies 51 | 52 | 53 | 54 | 55 | make-assembly 56 | package 57 | 58 | single 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /sketch/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | sketchml 7 | org.dma.sketchml 8 | 1.0.0 9 | ../pom.xml 10 | 11 | 4.0.0 12 | 13 | sketch 14 | 15 | 16 | 17 | it.unimi.dsi 18 | fastutil 19 | 7.1.0 20 | 21 | 22 | 23 | org.apache.commons 24 | commons-lang3 25 | 3.0 26 | 27 | 28 | 29 | 30 | 31 | 32 | org.apache.maven.plugins 33 | maven-compiler-plugin 34 | 3.3 35 | 36 | ${jdkVersion} 37 | ${jdkVersion} 38 | false 39 | 40 | 41 | 42 | org.apache.maven.plugins 43 | maven-assembly-plugin 44 | 2.5.4 45 | 46 | 47 | jar-with-dependencies 48 | 49 | 50 | 51 | 52 | make-assembly 53 | package 54 | 55 | single 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- 
/**
 * Base class for hash functions that map an int key to an int, parameterized by a
 * configurable {@code size}.
 */
public abstract class Int2IntHash implements Serializable {
    /** Size parameter available to implementations. */
    protected int size;

    /**
     * @param size initial size parameter
     */
    public Int2IntHash(int size) {
        setSize(size);
    }

    /** @return the current size parameter */
    public int getSize() {
        return this.size;
    }

    /** @param size new size parameter */
    public void setSize(int size) {
        this.size = size;
    }

    /**
     * @param key the key to hash
     * @return the hash of {@code key}
     */
    public abstract int hash(int key);

    /** @return a copy of this hash function */
    public abstract Int2IntHash clone();
}
estimateN : -1L; 16 | } 17 | 18 | public QuantileSketch() { 19 | this(-1L); 20 | } 21 | 22 | public abstract void reset(); 23 | 24 | public abstract void update(double value); 25 | 26 | public abstract void merge(QuantileSketch other); 27 | 28 | public abstract double getQuantile(double fraction); 29 | 30 | public abstract double[] getQuantiles(double[] fractions); 31 | 32 | public abstract double[] getQuantiles(int evenPartition); 33 | 34 | public boolean isEmpty() { 35 | return n == 0; 36 | } 37 | 38 | public long getN() { 39 | return n; 40 | } 41 | 42 | public long getEstimateN() { 43 | return estimateN; 44 | } 45 | 46 | public double getMinValue() { 47 | return minValue; 48 | } 49 | 50 | public double getMaxValue() { 51 | return maxValue; 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/base/Quantizer.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.base; 2 | 3 | import org.dma.sketchml.sketch.common.Constants; 4 | import org.dma.sketchml.sketch.quantization.QuantileQuantizer; 5 | import org.dma.sketchml.sketch.quantization.UniformQuantizer; 6 | import org.slf4j.Logger; 7 | import org.slf4j.LoggerFactory; 8 | 9 | import java.io.IOException; 10 | import java.io.ObjectInputStream; 11 | import java.io.ObjectOutputStream; 12 | import java.io.Serializable; 13 | import java.util.concurrent.Callable; 14 | import java.util.concurrent.ExecutionException; 15 | import java.util.concurrent.ExecutorService; 16 | import java.util.concurrent.Future; 17 | 18 | public abstract class Quantizer implements Serializable { 19 | public static Logger LOG = LoggerFactory.getLogger(Quantizer.class); 20 | 21 | protected int binNum; 22 | protected int n; 23 | protected double[] splits; 24 | protected int zeroIdx; 25 | protected double min; 26 | protected double max; 27 | 28 | protected int[] bins; 29 | public static 
final int DEFAULT_BIN_NUM = 256; 30 | 31 | public Quantizer(int binNum) { 32 | this.binNum = binNum; 33 | } 34 | 35 | public abstract void quantize(double[] values); 36 | 37 | public abstract void parallelQuantize(double[] values) throws InterruptedException, ExecutionException; 38 | 39 | public double[] getValues() { 40 | double[] res = new double[binNum]; 41 | int splitNum = binNum - 1; 42 | res[0] = 0.5 * (min + splits[0]); 43 | for (int i = 1; i < splitNum; i++) 44 | res[i] = 0.5 * (splits[i - 1] + splits[i]); 45 | res[splitNum] = 0.5 * (splits[splitNum - 1] + max); 46 | return res; 47 | } 48 | 49 | public int indexOf(double x) { 50 | if (x < splits[0]) { 51 | return 0; 52 | } else if (x >= splits[binNum - 2]) { 53 | return binNum - 1; 54 | } else { 55 | int l = zeroIdx, r = zeroIdx; 56 | if (x < 0.0) l = 0; 57 | else r = binNum - 2; 58 | while (l + 1 < r) { 59 | int mid = (l + r) >> 1; 60 | if (splits[mid] > x) { 61 | if (mid == 0 || splits[mid - 1] <= x) 62 | return mid; 63 | else 64 | r = mid; 65 | } else { 66 | l = mid; 67 | } 68 | } 69 | int mid = (l + r) >> 1; 70 | return splits[mid] <= x ? 
mid + 1 : mid; 71 | } 72 | } 73 | 74 | protected void findZeroIdx() { 75 | if (min > 0.0) 76 | zeroIdx = 0; 77 | else if (max < 0.0) 78 | zeroIdx = binNum - 1; 79 | else { 80 | int t = 0; 81 | while (t < binNum - 1 && splits[t] < 0.0) 82 | t++; 83 | zeroIdx = t; 84 | } 85 | } 86 | 87 | protected void quantizeToBins(double[] values) { 88 | int size = values.length; 89 | bins = new int[size]; 90 | for (int i = 0; i < size; i++) 91 | bins[i] = indexOf(values[i]); 92 | } 93 | 94 | protected void parallelQuantizeToBins(double[] values) throws InterruptedException, ExecutionException { 95 | int size = values.length; 96 | int threadNum = Constants.Parallel.getParallelism(); 97 | ExecutorService threadPool = Constants.Parallel.getThreadPool(); 98 | Future[] futures = new Future[threadNum]; 99 | bins = new int[size]; 100 | for (int i = 0; i < threadNum; i++) { 101 | int threadId = i; 102 | futures[threadId] = threadPool.submit(new Callable() { 103 | @Override 104 | public Void call() throws Exception { 105 | int elementPerThread = n / threadNum; 106 | int from = threadId * elementPerThread; 107 | int to = threadId + 1 == threadNum ? 
size : from + elementPerThread; 108 | for (int itemId = from; itemId < to; itemId++) 109 | bins[itemId] = indexOf(values[itemId]); 110 | return null; 111 | } 112 | }); 113 | } 114 | for (int i = 0; i < threadNum; i++) { 115 | futures[i].get(); 116 | } 117 | } 118 | 119 | public void timesBy(double x) { 120 | min *= x; 121 | max *= x; 122 | for (int i = 0; i < splits.length; i++) 123 | splits[i] *= x; 124 | } 125 | 126 | public static Quantizer newQuantizer(Quantizer.QuantizationType type, int binNum) { 127 | switch (type) { 128 | case QUANTILE: 129 | return new QuantileQuantizer(binNum); 130 | case UNIFORM: 131 | return new UniformQuantizer(binNum); 132 | default: 133 | throw new SketchMLException( 134 | "Unrecognizable quantization type: " + type); 135 | } 136 | } 137 | 138 | public enum QuantizationType { 139 | UNIFORM("UNIFORM"), 140 | QUANTILE("QUANTILE"); 141 | 142 | private final String type; 143 | 144 | QuantizationType(String type) { 145 | this.type = type; 146 | } 147 | 148 | @Override 149 | public String toString() { 150 | return type; 151 | } 152 | } 153 | 154 | public abstract QuantizationType quantizationType(); 155 | 156 | public int getBinNum() { 157 | return binNum; 158 | } 159 | 160 | public int getN() { 161 | return n; 162 | } 163 | 164 | public double[] getSplits() { 165 | return splits; 166 | } 167 | 168 | public int[] getBins() { 169 | return bins; 170 | } 171 | 172 | public int getZeroIdx() { 173 | return zeroIdx; 174 | } 175 | 176 | public double getMax() { 177 | return max; 178 | } 179 | 180 | public double getMin() { 181 | return min; 182 | } 183 | 184 | private void writeObject(ObjectOutputStream oos) throws IOException { 185 | oos.writeInt(binNum); 186 | oos.writeInt(n); 187 | for (double split : splits) 188 | oos.writeDouble(split); 189 | oos.writeInt(zeroIdx); 190 | oos.writeDouble(min); 191 | oos.writeDouble(max); 192 | oos.writeInt(bins.length); 193 | if (binNum <= 256) { 194 | for (int bin : bins) 195 | oos.writeByte(bin + 
Byte.MIN_VALUE); 196 | } else if (binNum <= 65536) { 197 | for (int bin : bins) 198 | oos.writeShort(bin + Short.MIN_VALUE); 199 | } else { 200 | for (int bin : bins) 201 | oos.writeInt(bin); 202 | } 203 | } 204 | 205 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 206 | binNum = ois.readInt(); 207 | n = ois.readInt(); 208 | int splitNum = binNum - 1; 209 | splits = new double[splitNum]; 210 | for (int i = 0; i < splitNum; i++) 211 | splits[i] = ois.readDouble(); 212 | zeroIdx = ois.readInt(); 213 | min = ois.readDouble(); 214 | max = ois.readDouble(); 215 | bins = new int[ois.readInt()]; 216 | if (binNum <= 256) { 217 | for (int i = 0; i < bins.length; i++) 218 | bins[i] = ((int) ois.readByte()) - Byte.MIN_VALUE; 219 | } else if (binNum <= 65536) { 220 | for (int i = 0; i < bins.length; i++) 221 | bins[i] = ((int) ois.readShort()) - Short.MIN_VALUE; 222 | } else { 223 | for (int i = 0; i < bins.length; i++) 224 | bins[i] = ois.readInt(); 225 | } 226 | } 227 | 228 | } 229 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/base/SketchMLException.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.base; 2 | 3 | public class SketchMLException extends RuntimeException { 4 | public SketchMLException(String message) { 5 | super(message); 6 | } 7 | 8 | public SketchMLException(Throwable cause) { 9 | super(cause); 10 | } 11 | 12 | public SketchMLException(String message, Throwable cause) { 13 | super(message, cause); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/base/VectorCompressor.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.base; 2 | 3 | import org.apache.commons.lang3.tuple.Pair; 4 | 5 | import 
java.io.IOException; 6 | import java.io.Serializable; 7 | import java.util.concurrent.ExecutionException; 8 | 9 | public interface VectorCompressor extends Serializable { 10 | void compressDense(double[] values); 11 | 12 | void compressSparse(int[] keys, double[] values); 13 | 14 | void parallelCompressDense(double[] values) throws InterruptedException, ExecutionException; 15 | 16 | void parallelCompressSparse(int[] keys, double[] values) throws InterruptedException, ExecutionException; 17 | 18 | double[] decompressDense(); 19 | 20 | Pair decompressSparse(); 21 | 22 | void timesBy(double x); 23 | 24 | double size(); 25 | 26 | int memoryBytes() throws IOException; 27 | } 28 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/binary/BinaryUtils.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.binary; 2 | 3 | import java.util.BitSet; 4 | 5 | public class BinaryUtils { 6 | public static void setBits(BitSet bitSet, int offset, int value, int numBits) { 7 | for (int i = numBits - 1; i >= 0; i--) { 8 | int t = value - (1 << i); 9 | if (t >= 0) { 10 | bitSet.set(offset + numBits - 1 - i); 11 | value = t; 12 | } 13 | } 14 | } 15 | 16 | public static void setBytes(BitSet bitSet, int offset, int value, int numBytes) { 17 | setBits(bitSet, offset, value, numBytes * 8); 18 | } 19 | 20 | public static int getBits(BitSet bitSet, int offset, int numBits) { 21 | int res = 0; 22 | for (int i = 0; i < numBits; i++) { 23 | res <<= 1; 24 | res |= bitSet.get(offset + i) ? 
1 : 0; 25 | if (bitSet.get(offset + i)) 26 | res |= 1; 27 | } 28 | return res; 29 | } 30 | 31 | public static int getBytes(BitSet bitSet, int offset, int numBytes) { 32 | return getBits(bitSet, offset, numBytes * 8); 33 | } 34 | 35 | public static String bits2String(int value, int numBits) { 36 | StringBuilder sb = new StringBuilder(); 37 | for (int i = numBits - 1; i >= 0; i--) { 38 | int t = value - (1 << i); 39 | if (t < 0) { 40 | sb.append("0"); 41 | } else { 42 | sb.append("1"); 43 | value = t; 44 | } 45 | } 46 | return sb.toString(); 47 | } 48 | 49 | public static String bits2String(BitSet bitset, int from, int length) { 50 | StringBuilder sb = new StringBuilder(); 51 | for (int i = from; i < from + length; i++) 52 | sb.append(bitset.get(i) ? 1 : 0); 53 | return sb.toString(); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/binary/DeltaAdaptiveEncoder.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.binary; 2 | 3 | import org.dma.sketchml.sketch.base.BinaryEncoder; 4 | import org.dma.sketchml.sketch.util.Maths; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | 8 | import java.io.IOException; 9 | import java.io.ObjectInputStream; 10 | import java.io.ObjectOutputStream; 11 | import java.util.BitSet; 12 | 13 | public class DeltaAdaptiveEncoder implements BinaryEncoder { 14 | private static final Logger LOG = LoggerFactory.getLogger(DeltaAdaptiveEncoder.class); 15 | 16 | private int size; 17 | private int numIntervals; // how many number of intervals it splits [0, 31] 18 | // should be exponential to 2 19 | private boolean flagKind; // whether the number of flag bits is dynamic to different interval 20 | private BitSet deltaBits; 21 | private BitSet flagBits; 22 | 23 | private void calOptimalIntervals(double[] prob) { 24 | double optBitsPerKey = 32.0; 25 | numIntervals = 1; 
26 | flagKind = false; 27 | for (int m = 2; m <= 16; m *= 2) { 28 | double[] intervalProb = new double[m]; 29 | int b = 32 / m; 30 | double sum = 0.0; 31 | for (int i = 0; i < m; i++) { 32 | for (int j = 0; j < b; j++) 33 | intervalProb[i] += prob[i * b + j]; 34 | sum += (i + 1) * intervalProb[i]; 35 | } 36 | // all flags have the same number of bits 37 | double t1 = sum * b + Maths.log2nlz(m); 38 | if (t1 < optBitsPerKey) { 39 | optBitsPerKey = t1; 40 | numIntervals = m; 41 | flagKind = false; 42 | } 43 | // one bit for each interval 44 | double t2 = sum * (b + 1) + 1; 45 | if (t2 < optBitsPerKey) { 46 | optBitsPerKey = t2; 47 | numIntervals = m; 48 | flagKind = true; 49 | } 50 | } 51 | } 52 | 53 | @Override 54 | public void encode(int[] values) { 55 | size = values.length; 56 | // 1. get probabilities of each range [2^i, 2^(i+1)) 57 | int[] delta = new int[size]; 58 | int[] bitsNeeded = new int[size]; 59 | double[] prob = new double[32]; 60 | delta[0] = values[0]; 61 | if (delta[0] == 0) 62 | bitsNeeded[0] = 1; 63 | else 64 | bitsNeeded[0] = Maths.log2nlz(delta[0]) + 1; 65 | prob[bitsNeeded[0]]++; 66 | for (int i = 1; i < size; i++) { 67 | delta[i] = values[i] - values[i - 1]; 68 | bitsNeeded[i] = Maths.log2nlz(delta[i]) + 1; 69 | prob[bitsNeeded[i]]++; 70 | } 71 | for (int i = 0; i < prob.length; i++) 72 | prob[i] /= size; 73 | // 2. get the optimal number of intervals, and the kind of flag bits 74 | calOptimalIntervals(prob); 75 | // 3. 
encode deltas 76 | deltaBits = new BitSet(); 77 | flagBits = new BitSet(); 78 | int bitsPerInterval = 32 / numIntervals; 79 | int bitsShift = Maths.log2nlz(bitsPerInterval); 80 | int flagOffset = 0, deltaOffset = 0; 81 | if (!flagKind) { 82 | int numBitsPerFlag = Maths.log2nlz(numIntervals); 83 | for (int i = 0; i < size; i++) { 84 | // ceil(bitsNeeded / bitsPerInterval) 85 | int intervalNeeded = (bitsNeeded[i] + bitsPerInterval - 1) >> bitsShift; 86 | // set flag 87 | BinaryUtils.setBits(flagBits, flagOffset, intervalNeeded - 1, numBitsPerFlag); 88 | flagOffset += numBitsPerFlag; 89 | // set delta 90 | BinaryUtils.setBits(deltaBits, deltaOffset, delta[i], bitsPerInterval * intervalNeeded); 91 | deltaOffset += bitsPerInterval * intervalNeeded; 92 | } 93 | } else { 94 | int[] flagCandidates = new int[numIntervals + 1]; 95 | for (int i = 1; i <= numIntervals; i++) { 96 | // 0b1110 = 0b10000 - 2 97 | flagCandidates[i] = (1 << (i + 1)) - 2; 98 | } 99 | for (int i = 0; i < size; i++) { 100 | // ceil(bitsNeeded / bitsPerInterval) 101 | int intervalNeeded = (bitsNeeded[i] + bitsPerInterval - 1) >> bitsShift; 102 | // set flag 103 | BinaryUtils.setBits(flagBits, flagOffset, flagCandidates[intervalNeeded], intervalNeeded + 1); 104 | flagOffset += intervalNeeded + 1; 105 | // set delta 106 | BinaryUtils.setBits(deltaBits, deltaOffset, delta[i], bitsPerInterval * intervalNeeded); 107 | deltaOffset += bitsPerInterval * intervalNeeded; 108 | } 109 | } 110 | //LOG.info(String.format("BitsPerKey[%f], flag[%f], delta[%f]", (flagOffset + deltaOffset) * 1. / size, 111 | // flagOffset * 1. / size, deltaOffset * 1. 
/ size)); 112 | } 113 | 114 | @Override 115 | public int[] decode() { 116 | int[] res = new int[size]; 117 | int bitsPerInterval = 32 / numIntervals; 118 | int flagOffset = 0, deltaOffset = 0, prev = 0;; 119 | if (!flagKind) { 120 | int numBitsPerFlag = Maths.log2nlz(numIntervals); 121 | for (int i = 0; i < size; i++) { 122 | // get flag 123 | int intervalNeeded = BinaryUtils.getBits(flagBits, flagOffset, numBitsPerFlag) + 1; 124 | flagOffset += numBitsPerFlag; 125 | // get delta 126 | int delta = BinaryUtils.getBits(deltaBits, deltaOffset, bitsPerInterval * intervalNeeded); 127 | deltaOffset += bitsPerInterval * intervalNeeded; 128 | // set value 129 | res[i] = prev + delta; 130 | prev = res[i]; 131 | } 132 | } else { 133 | for (int i = 0; i < size; i++) { 134 | // get flag 135 | int intervalNeeded = 0; 136 | while (flagBits.get(flagOffset++)) intervalNeeded++; 137 | // get delta 138 | int delta = BinaryUtils.getBits(deltaBits, deltaOffset, bitsPerInterval * intervalNeeded); 139 | deltaOffset += bitsPerInterval * intervalNeeded; 140 | // set value 141 | res[i] = prev + delta; 142 | prev = res[i]; 143 | } 144 | } 145 | return res; 146 | } 147 | 148 | private void writeObject(ObjectOutputStream oos) throws IOException { 149 | oos.writeInt(size); 150 | oos.writeInt(numIntervals); 151 | oos.writeBoolean(flagKind); 152 | if (flagBits == null) { 153 | oos.writeInt(0); 154 | } else { 155 | long[] flags = flagBits.toLongArray(); 156 | oos.writeInt(flags.length); 157 | for (long l : flags) { 158 | oos.writeLong(l); 159 | } 160 | } 161 | if (deltaBits == null) { 162 | oos.writeInt(0); 163 | } else { 164 | long[] delta = deltaBits.toLongArray(); 165 | oos.writeInt(delta.length); 166 | for (long l : delta) { 167 | oos.writeLong(l); 168 | } 169 | } 170 | } 171 | 172 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 173 | size = ois.readInt(); 174 | numIntervals = ois.readInt(); 175 | flagKind = ois.readBoolean(); 176 | int 
flagsLength = ois.readInt(); 177 | long[] flags = new long[flagsLength]; 178 | for (int i = 0; i < flagsLength; i++) { 179 | flags[i] = ois.readLong(); 180 | } 181 | flagBits = BitSet.valueOf(flags); 182 | int deltaLength = ois.readInt(); 183 | long[] delta = new long[deltaLength]; 184 | for (int i = 0; i < deltaLength; i++) { 185 | delta[i] = ois.readLong(); 186 | } 187 | deltaBits = BitSet.valueOf(delta); 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/binary/DeltaBinaryEncoder.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.binary; 2 | 3 | import org.dma.sketchml.sketch.base.BinaryEncoder; 4 | import org.dma.sketchml.sketch.base.SketchMLException; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | 8 | import java.io.IOException; 9 | import java.io.ObjectInputStream; 10 | import java.io.ObjectOutputStream; 11 | import java.util.BitSet; 12 | 13 | /** 14 | * This is the special case for DeltaAdaptiveEncoder 15 | * where numIntervals equals to 4 and number of flag bits is constant 16 | * 17 | * */ 18 | public class DeltaBinaryEncoder implements BinaryEncoder { 19 | private static final Logger LOG = LoggerFactory.getLogger(DeltaBinaryEncoder.class); 20 | 21 | private int size; 22 | private BitSet deltaBits; 23 | private BitSet flagBits; 24 | 25 | @Override 26 | public void encode(int[] values) { 27 | size = values.length; 28 | flagBits = new BitSet(size * 2); 29 | deltaBits = new BitSet(size * 12); 30 | int offset = 0, prev = 0; 31 | for (int i = 0; i < size; i++) { 32 | int delta = values[i] - prev; 33 | int bytesNeeded = needBytes(delta); 34 | BinaryUtils.setBits(flagBits, 2 * i, bytesNeeded - 1, 2); 35 | BinaryUtils.setBytes(deltaBits, offset, delta, bytesNeeded); 36 | prev = values[i]; 37 | offset += bytesNeeded * 8; 38 | } 39 | } 40 | 41 | @Override 42 | public int[] 
decode() { 43 | int[] res = new int[size]; 44 | int offset = 0, prev = 0; 45 | for (int i = 0; i < size; i++) { 46 | /* the stored flag is the byte count minus one */ int bytesNeeded = BinaryUtils.getBits(flagBits, i * 2, 2) + 1; 47 | int delta = BinaryUtils.getBytes(deltaBits, offset, bytesNeeded); 48 | res[i] = prev + delta; 49 | prev = res[i]; 50 | offset += bytesNeeded * 8; 51 | } 52 | return res; 53 | } 54 | /* Number of bytes used to store a delta: 1 when x is below 256, 2 when below 65536, otherwise 4 (3 is never returned). Throws SketchMLException when x is negative. */ 55 | public static int needBytes(int x) { 56 | if (x < 0) { 57 | throw new SketchMLException("Input of DeltaBinaryEncoder should be sorted"); 58 | } else if (x < 256) { 59 | return 1; 60 | } else if (x < 65536) { 61 | return 2; 62 | } else { 63 | return 4; 64 | } 65 | } 66 | /* Custom serialization: writes size, then flagBits and deltaBits, each as a long count followed by the longs of toLongArray (count 0 when the BitSet is null). */ 67 | private void writeObject(ObjectOutputStream oos) throws IOException { 68 | oos.writeInt(size); 69 | if (flagBits == null) { 70 | oos.writeInt(0); 71 | } else { 72 | long[] flags = flagBits.toLongArray(); 73 | oos.writeInt(flags.length); 74 | for (long l : flags) { 75 | oos.writeLong(l); 76 | } 77 | } 78 | if (deltaBits == null) { 79 | oos.writeInt(0); 80 | } else { 81 | long[] delta = deltaBits.toLongArray(); 82 | oos.writeInt(delta.length); 83 | for (long l : delta) { 84 | oos.writeLong(l); 85 | } 86 | } 87 | } 88 | /* Custom deserialization: reads size and the two long arrays in the order writeObject wrote them; a stored count of 0 produces an empty BitSet rather than null. */ 89 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 90 | size = ois.readInt(); 91 | int flagsLength = ois.readInt(); 92 | long[] flags = new long[flagsLength]; 93 | for (int i = 0; i < flagsLength; i++) { 94 | flags[i] = ois.readLong(); 95 | } 96 | flagBits = BitSet.valueOf(flags); 97 | int deltaLength = ois.readInt(); 98 | long[] delta = new long[deltaLength]; 99 | for (int i = 0; i < deltaLength; i++) { 100 | delta[i] = ois.readLong(); 101 | } 102 | deltaBits = BitSet.valueOf(delta); 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/binary/HuffmanEncoder.java: -------------------------------------------------------------------------------- 1 | package 
org.dma.sketchml.sketch.binary; 2 | 3 | import it.unimi.dsi.fastutil.ints.Int2ObjectMap; 4 | import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap; 5 | import org.dma.sketchml.sketch.base.BinaryEncoder; 6 | import org.slf4j.Logger; 7 | import org.slf4j.LoggerFactory; 8 | 9 | import java.io.IOException; 10 | import java.io.ObjectInputStream; 11 | import java.io.ObjectOutputStream; 12 | import java.util.*; 13 | import java.util.stream.IntStream; 14 | /* Huffman coder for int arrays: keeps the code book (items), the concatenated code bits (bitset) and the number of encoded values (size). */ 15 | public class HuffmanEncoder implements BinaryEncoder { 16 | private static final Logger LOG = LoggerFactory.getLogger(HuffmanEncoder.class); 17 | 18 | private Item[] items; 19 | private BitSet bitset; 20 | private int size; 21 | 22 | /* Tree node. Leaves carry a value and its occurrence count; nodes made by the no-argument constructor use -1 for both. */ 23 | private class Node { 24 | int value; 25 | int occurrence; 26 | Node leftChild; 27 | Node rightChild; 28 | boolean isLeaf; 29 | 30 | Node(int value, int occurrence, Node leftChild, Node rightChild, boolean isLeaf) { 31 | this.value = value; 32 | this.occurrence = occurrence; 33 | this.leftChild = leftChild; 34 | this.rightChild = rightChild; 35 | this.isLeaf = isLeaf; 36 | } 37 | 38 | Node(int value, int occurrence) { 39 | this(value, occurrence, null, null, false); 40 | } 41 | 42 | Node() { 43 | this(-1, -1, null, null, false); 44 | } 45 | 46 | int getValue() { 47 | return value; 48 | } 49 | 50 | int getOccurrence() { 51 | return occurrence; 52 | } 53 | } 54 | /* One code book entry: the value, its code held in the low numBits bits of bits, and the code length. */ 55 | private class Item { 56 | int value; 57 | int bits; 58 | int numBits; 59 | 60 | Item(int value, int bits, int numBits) { 61 | this.value = value; 62 | this.bits = bits; 63 | this.numBits = numBits; 64 | } 65 | 66 | @Override 67 | public String toString() { 68 | StringBuilder sb = new StringBuilder(); 69 | sb.append("("); 70 | sb.append(value); 71 | sb.append(" --> "); 72 | sb.append(BinaryUtils.bits2String(bits, numBits)); 73 | sb.append(")"); 74 | return sb.toString(); 75 | } 76 | } 77 | /* Depth-first walk that assigns codes: going left appends a 0 bit, going right appends a 1 bit, and each leaf is recorded in mapping under its value. A leaf reached at depth 0 (a tree with a single value) still gets a one-bit code. */ 78 | private void traverse(Node node, Int2ObjectMap mapping, int bits, int depth) { 79 | if (node.isLeaf) { 80 | 
mapping.put(node.value, new Item(node.value, bits, depth == 0 ? 1 : depth)); 81 | } else { 82 | traverse(node.leftChild, mapping, bits << 1, depth + 1); 83 | traverse(node.rightChild, mapping, (bits << 1) | 1, depth + 1); 84 | } 85 | } 86 | /* Builds a Huffman code for values and stores the code book in items, the concatenated codes in bitset and the value count in size. Steps: count occurrences per distinct value, repeatedly merge the two least frequent nodes until one root is left, derive the codes with traverse, then append the code of every input value to bitset. An empty input produces an empty code book and an empty bitset. */ 87 | @Override 88 | public void encode(int[] values) { 89 | /* Empty input: there is no tree to build (PriorityQueue rejects an initial capacity of 0 and heap.peek() would be null), so store the empty state that decode already understands and stop. */ if (values.length == 0) { items = new Item[0]; bitset = new BitSet(); size = 0; return; } long startTime = System.currentTimeMillis(); 90 | // 1. count occurrences 91 | Int2ObjectRBTreeMap<Node> freq = new Int2ObjectRBTreeMap<>(); 92 | for (int v : values) { 93 | Node node = freq.get(v); 94 | if (node != null) 95 | node.occurrence++; 96 | else 97 | freq.put(v, new Node(v, 1, null, null, true)); 98 | } 99 | // 2. build tree 100 | PriorityQueue<Node> heap = new PriorityQueue<>(freq.size(), 101 | Comparator.comparing(Node::getOccurrence)); 102 | heap.addAll(freq.values()); 103 | while (heap.size() > 1) { 104 | Node x = heap.poll(); 105 | Node y = heap.poll(); 106 | Node p = new Node(-1, x.occurrence + y.occurrence, x, y, false); 107 | heap.add(p); 108 | } 109 | Int2ObjectMap<Item> mapping = new Int2ObjectRBTreeMap<>(); 110 | traverse(heap.peek(), mapping, 0, 0); 111 | items = new Item[mapping.size()]; 112 | mapping.values().toArray(items); 113 | // 3. encode values 114 | bitset = new BitSet(); 115 | int offset = 0; 116 | for (int v : values) { 117 | Item item = mapping.get(v); 118 | BinaryUtils.setBits(bitset, offset, item.bits, item.numBits); 119 | offset += item.numBits; 120 | } 121 | size = values.length; 122 | LOG.debug(String.format("Huffman encoding for %d values cost %d ms", 123 | values.length, System.currentTimeMillis() - startTime)); 124 | } 125 | 126 | @Override 127 | public int[] decode() { 128 | if (size == 0) 129 | return new int[0]; 130 | 131 | // 1. 
build Huffman tree 132 | Node root = new Node(); 133 | for (Item item : items) { 134 | int bits = item.bits; 135 | int numBits = item.numBits; 136 | Node cur = root; 137 | /* Walk the code from its highest bit down: when bit i is set, subtracting 1 << i leaves a non-negative remainder and the walk goes right, otherwise it goes left. NOTE(review): the subtraction test assumes numBits is at most 31; confirm that longer codes cannot occur. */ for (int i = numBits - 1; i >= 0; i--) { 138 | int t = bits - (1 << i); 139 | if (t >= 0) { 140 | if (cur.rightChild == null) 141 | cur.rightChild = new Node(); 142 | cur = cur.rightChild; 143 | bits = t; 144 | } else { 145 | if (cur.leftChild == null) 146 | cur.leftChild = new Node(); 147 | cur = cur.leftChild; 148 | } 149 | } 150 | cur.value = item.value; 151 | cur.isLeaf = true; 152 | } 153 | // 2. decode bits 154 | int[] res = new int[size]; 155 | int cnt = 0; 156 | Node cur = root; 157 | int idx = 0; 158 | while (cnt < size) { 159 | cur = bitset.get(idx++) ? cur.rightChild : cur.leftChild; 160 | if (cur.isLeaf) { 161 | res[cnt++] = cur.value; 162 | cur = root; 163 | } 164 | } 165 | return res; 166 | } 167 | /* Custom serialization: writes the item count and each item as (value, bits, numBits), then the bitset as a long count followed by its longs, then size. A null items array or bitset is written as count 0. */ 168 | private void writeObject(ObjectOutputStream oos) throws IOException { 169 | // items 170 | if (items == null) { 171 | oos.writeInt(0); 172 | } else { 173 | oos.writeInt(items.length); 174 | for (Item item : items) { 175 | oos.writeInt(item.value); 176 | oos.writeInt(item.bits); 177 | oos.writeInt(item.numBits); 178 | } 179 | } 180 | // bit set 181 | if (bitset == null) { 182 | oos.writeInt(0); 183 | } else { 184 | long[] bits = bitset.toLongArray(); 185 | oos.writeInt(bits.length); 186 | for (long l : bits) 187 | oos.writeLong(l); 188 | } 189 | // size 190 | oos.writeInt(size); 191 | } 192 | /* Custom deserialization: reads items, bitset and size in the order writeObject wrote them; stored counts of 0 produce an empty array and an empty BitSet rather than null. */ 193 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 194 | // items 195 | int numItems = ois.readInt(); 196 | items = new Item[numItems]; 197 | for (int i = 0; i < numItems; i++) 198 | items[i] = new Item(ois.readInt(), ois.readInt(), ois.readInt()); 199 | // bit set 200 | int numLongs = ois.readInt(); 201 | long[] bits = new long[numLongs]; 202 | for (int i = 0; i < numLongs; i++) 203 | bits[i] = ois.readLong(); 204 | bitset = BitSet.valueOf(bits); 
205 | // size 206 | size = ois.readInt(); 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/common/Constants.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.common; 2 | 3 | import org.dma.sketchml.sketch.base.SketchMLException; 4 | 5 | import java.util.concurrent.ExecutorService; 6 | import java.util.concurrent.Executors; 7 | /* Shared settings for the sketch package. */ 8 | public class Constants { 9 | /* Static holder for the worker count and the fixed-size thread pool. Both stay unset (0 and null) until setParallelism is called, and the getters throw SketchMLException while they are unset. */ public static class Parallel { 10 | private static int parallelism; 11 | 12 | private static ExecutorService threadPool; 13 | 14 | static { 15 | parallelism = 0; 16 | threadPool = null; 17 | } 18 | /* Stores the worker count and creates a fixed thread pool of that size; throws SketchMLException when parallelism is below 1. NOTE(review): a second call replaces threadPool without shutting down the previous executor - confirm that callers set this only once. */ 19 | public static void setParallelism(int parallelism) { 20 | if (parallelism < 1) 21 | throw new SketchMLException("Invalid parallelism: " + parallelism); 22 | Parallel.parallelism = parallelism; 23 | Parallel.threadPool = Executors.newFixedThreadPool(parallelism); 24 | } 25 | 26 | public static int getParallelism() { 27 | if (parallelism <= 0) 28 | throw new SketchMLException("Parallelism is not set yet"); 29 | return parallelism; 30 | } 31 | 32 | public static ExecutorService getThreadPool() { 33 | if (threadPool == null) 34 | throw new SketchMLException("Parallelism is not set yet"); 35 | return threadPool; 36 | } 37 | /* Calls shutdown on the pool when one exists. The fields are left as they are, so getThreadPool keeps returning the shut-down executor afterwards. */ 38 | public static void shutdown() { 39 | if (threadPool != null) 40 | threadPool.shutdown(); 41 | } 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/hash/BJHash.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.hash; 2 | 3 | import org.dma.sketchml.sketch.base.Int2IntHash; 4 | /* Int2IntHash that scrambles the key with six add, xor and shift rounds using fixed constants and then reduces it modulo size. */ 5 | public class BJHash extends Int2IntHash { 6 | public BJHash(int size) { 7 | super(size); 8 | } 9 | /* Returns a bucket index in [0, size) for key. */ 10 | public int hash(int key) { 11 | int code = key; 12 | code = (code + 
0x7ed55d16) + (code << 12); 13 | code = (code ^ 0xc761c23c) ^ (code >> 19); 14 | code = (code + 0x165667b1) + (code << 5); 15 | code = (code + 0xd3a2646c) ^ (code << 9); 16 | code = (code + 0xfd7046c5) + (code << 3); 17 | code = (code ^ 0xb55a4f09) ^ (code >> 16); 18 | /* reduce to the table size: the remainder keeps the sign of code, so a negative result is shifted up by size on the next line */ code %= size; 19 | return code >= 0 ? code : code + size; 20 | } 21 | /* Returns a new BJHash with the same size. */ 22 | @Override 23 | public Int2IntHash clone() { 24 | return new BJHash(size); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/hash/BKDRHash.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.hash; 2 | 3 | import org.dma.sketchml.sketch.base.Int2IntHash; 4 | /* Int2IntHash that folds the decimal digits of the key into a polynomial in seed. */ 5 | public class BKDRHash extends Int2IntHash { 6 | private int seed; 7 | 8 | public BKDRHash(int size, int seed) { 9 | super(size); 10 | this.seed = seed; 11 | } 12 | /* Consumes the decimal digits of key from the least significant one, updating code = seed * code + digit, then reduces modulo size and shifts a negative remainder up by size. For a negative key the digits from key % 10 are negative as well. */ 13 | public int hash(int key) { 14 | int code = 0; 15 | while (key != 0) { 16 | code = seed * code + (key % 10); 17 | key /= 10; 18 | } 19 | code %= size; 20 | return code >= 0 ? 
code : code + size; 21 | } 22 | /* Returns a new BKDRHash with the same size and seed. */ 23 | @Override 24 | public Int2IntHash clone() { 25 | return new BKDRHash(size, seed); 26 | } 27 | 28 | public int getSeed() { 29 | return seed; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/hash/HashFactory.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.hash; 2 | 3 | import org.dma.sketchml.sketch.base.Int2IntHash; 4 | import org.dma.sketchml.sketch.base.SketchMLException; 5 | import org.dma.sketchml.sketch.util.Maths; 6 | 7 | import java.util.*; 8 | /* Hands out Int2IntHash instances cloned from a fixed table of eight prototypes. The prototypes are created with size 0 and are never used directly; every returned clone has its size set by the factory. */ 9 | public class HashFactory { 10 | private static final Int2IntHash[] int2intHashes = 11 | new Int2IntHash[]{new BJHash(0), new Mix64Hash(0), 12 | new TWHash(0), new BKDRHash(0, 31), new BKDRHash(0, 131), 13 | new BKDRHash(0, 267), new BKDRHash(0, 1313), new BKDRHash(0, 13131)}; 14 | private static final Random random = new Random(); 15 | /* Returns a clone of one randomly chosen prototype, resized to size. */ 16 | public static Int2IntHash getRandomInt2IntHash(int size) { 17 | int idx = random.nextInt(int2intHashes.length); 18 | Int2IntHash res = int2intHashes[idx].clone(); 19 | res.setSize(size); 20 | return res; 21 | } 22 | /* Returns hashNum clones, each resized to size, taken from the first hashNum entries of a shuffled index array (Maths.shuffle presumably permutes in place, which would make the chosen prototypes distinct - confirm). Throws SketchMLException when hashNum exceeds the number of prototypes. */ 23 | public static Int2IntHash[] getRandomInt2IntHashes(int hashNum, int size) { 24 | if (hashNum > int2intHashes.length) { 25 | throw new SketchMLException(String.format("Currently only %d " + 26 | "hash functions are available", int2intHashes.length)); 27 | } else { 28 | Int2IntHash[] res = new Int2IntHash[hashNum]; 29 | int[] indexes = new int[int2intHashes.length]; 30 | Arrays.setAll(indexes, i -> i); 31 | Maths.shuffle(indexes); 32 | for (int i = 0; i < hashNum; i++) { 33 | res[i] = int2intHashes[indexes[i]].clone(); 34 | res[i].setSize(size); 35 | } 36 | return res; 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/hash/Mix64Hash.java: 
-------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.hash; 2 | 3 | import org.dma.sketchml.sketch.base.Int2IntHash; 4 | /* Int2IntHash built from a chain of shift, add and xor rounds followed by a reduction modulo size. NOTE(review): the shift amounts look like a 64-bit mix applied to a 32-bit int - confirm this is intended. */ 5 | public class Mix64Hash extends Int2IntHash { 6 | public Mix64Hash(int size) { 7 | super(size); 8 | } 9 | /* Returns a bucket index in [0, size) for key; a negative remainder is shifted up by size. */ 10 | public int hash(int key) { 11 | int code = key; 12 | code = (~code) + (code << 21); // code = (code << 21) - code - 1; 13 | code = code ^ (code >> 24); 14 | code = (code + (code << 3)) + (code << 8); // code * 265 15 | code = code ^ (code >> 14); 16 | code = (code + (code << 2)) + (code << 4); // code * 21 17 | code = code ^ (code >> 28); 18 | code = code + (code << 31); 19 | code %= size; 20 | return code >= 0 ? code : code + size; 21 | } 22 | /* Returns a new Mix64Hash with the same size. */ 23 | @Override 24 | public Int2IntHash clone() { 25 | return new Mix64Hash(size); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/hash/TWHash.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.hash; 2 | 3 | import org.dma.sketchml.sketch.base.Int2IntHash; 4 | /* Int2IntHash built from shift, xor and multiply rounds followed by a reduction modulo size. */ 5 | public class TWHash extends Int2IntHash { 6 | public TWHash(int size) { 7 | super(size); 8 | } 9 | /* Returns a bucket index in [0, size) for key; a negative remainder is shifted up by size. */ 10 | public int hash(int key) { 11 | int code = key; 12 | code = ~code + (code << 15); 13 | code = code ^ (code >> 12); 14 | code = code + (code << 2); 15 | code = code ^ (code >> 4); 16 | code = code * 2057; 17 | code = code ^ (code >> 16); 18 | code %= size; 19 | return code >= 0 ? 
code : code + size; 20 | } 21 | /* Returns a new TWHash with the same size. */ 22 | @Override 23 | public Int2IntHash clone() { 24 | return new TWHash(size); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/quantization/QuantileQuantizer.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.quantization; 2 | 3 | import org.dma.sketchml.sketch.base.Quantizer; 4 | import org.dma.sketchml.sketch.common.Constants; 5 | import org.dma.sketchml.sketch.sketch.quantile.HeapQuantileSketch; 6 | import org.dma.sketchml.sketch.util.Maths; 7 | import org.slf4j.Logger; 8 | import org.slf4j.LoggerFactory; 9 | 10 | import java.util.concurrent.Callable; 11 | import java.util.concurrent.ExecutionException; 12 | import java.util.concurrent.ExecutorService; 13 | import java.util.concurrent.Future; 14 | /* Quantizer whose bin edges are quantiles obtained from a HeapQuantileSketch of the input. */ 15 | public class QuantileQuantizer extends Quantizer { 16 | private static final Logger LOG = LoggerFactory.getLogger(QuantileQuantizer.class); 17 | 18 | public QuantileQuantizer(int binNum) { 19 | super(binNum); 20 | } 21 | 22 | public QuantileQuantizer() { 23 | this(Quantizer.DEFAULT_BIN_NUM); 24 | } 25 | /* Single-threaded quantization: feeds every value into one sketch, takes min and max from it, uses the de-duplicated quantiles as splits, lowers binNum (with a warning) when de-duplication removed edges, then calls findZeroIdx and quantizeToBins from the base class. */ 26 | @Override 27 | public void quantize(double[] values) { 28 | long startTime = System.currentTimeMillis(); 29 | // 1. create quantile sketch summary 30 | n = values.length; 31 | HeapQuantileSketch qSketch = new HeapQuantileSketch((long) n); 32 | for (double v : values) { 33 | qSketch.update(v); 34 | } 35 | min = qSketch.getMinValue(); 36 | max = qSketch.getMaxValue(); 37 | // 2. query quantiles, set them as bin edges 38 | splits = Maths.unique(qSketch.getQuantiles(binNum)); 39 | if (splits.length + 1 != binNum) { 40 | LOG.warn(String.format("Actual bin num %d not equal to %d", 41 | splits.length + 1, binNum)); 42 | binNum = splits.length + 1; 43 | } 44 | // 3. find the zero index 45 | findZeroIdx(); 46 | // 4. 
find index of each value 47 | quantizeToBins(values); 48 | LOG.debug(String.format("Quantile quantization for %d items cost %d ms", 49 | n, System.currentTimeMillis() - startTime)); 50 | } 51 | /* Parallel quantization: the input is cut into threadNum contiguous ranges of n / threadNum items (the last range also takes the remainder), each range is summarised by its own HeapQuantileSketch on the shared thread pool, and the partial sketches are merged into the first one before the quantiles are queried. Throws InterruptedException or ExecutionException when waiting on a worker fails. NOTE(review): unlike quantize, the splits are not passed through Maths.unique and binNum is not adjusted here - confirm whether the two paths are meant to agree. */ 52 | @Override 53 | public void parallelQuantize(double[] values) throws InterruptedException, ExecutionException { 54 | long startTime = System.currentTimeMillis(); 55 | // 1. create quantile sketch summary in parallel 56 | n = values.length; 57 | // 1.1. each thread create a quantile sketch based on a portion of data 58 | int threadNum = Constants.Parallel.getParallelism(); 59 | ExecutorService threadPool = Constants.Parallel.getThreadPool(); 60 | /* typed futures so that get() yields a HeapQuantileSketch; the array itself has to be created raw because Java forbids generic array creation */ Future<HeapQuantileSketch>[] futures = new Future[threadNum]; 61 | for (int i = 0; i < threadNum; i++) { 62 | int threadId = i; 63 | futures[threadId] = threadPool.submit(new Callable<HeapQuantileSketch>() { 64 | @Override 65 | public HeapQuantileSketch call() throws Exception { 66 | int elementPerThread = n / threadNum; 67 | int from = threadId * elementPerThread; 68 | int to = threadId + 1 == threadNum ? n : from + elementPerThread; 69 | HeapQuantileSketch qSketch = new HeapQuantileSketch((long) (to - from)); 70 | for (int itemId = from; itemId < to; itemId++) { 71 | qSketch.update(values[itemId]); 72 | } 73 | return qSketch; 74 | } 75 | }); 76 | } 77 | // 1.2. merge all quantile sketches together 78 | HeapQuantileSketch qSketch = futures[0].get(); 79 | for (int i = 1; i < threadNum; i++) { 80 | qSketch.merge(futures[i].get()); 81 | } 82 | min = qSketch.getMinValue(); 83 | max = qSketch.getMaxValue(); 84 | // 2. query quantiles, set them as bin edges 85 | splits = qSketch.getQuantiles(binNum); 86 | // 3. find the zero index 87 | findZeroIdx(); 88 | // 4. 
find index of each value 89 | parallelQuantizeToBins(values); 90 | LOG.debug(String.format("Quantile quantization for %d items cost %d ms", 91 | n, System.currentTimeMillis() - startTime)); 92 | } 93 | 94 | @Override 95 | public QuantizationType quantizationType() { 96 | return QuantizationType.QUANTILE; 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/quantization/UniformQuantizer.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.quantization; 2 | 3 | import org.dma.sketchml.sketch.base.Quantizer; 4 | import org.slf4j.Logger; 5 | import org.slf4j.LoggerFactory; 6 | 7 | import java.util.concurrent.ExecutionException; 8 | /* Quantizer whose bin edges cut the range from min to max into binNum bins of equal width. */ 9 | public class UniformQuantizer extends Quantizer { 10 | public static final Logger LOG = LoggerFactory.getLogger(UniformQuantizer.class); 11 | 12 | public UniformQuantizer(int binNum) { 13 | super(binNum); 14 | } 15 | 16 | public UniformQuantizer() { 17 | super(Quantizer.DEFAULT_BIN_NUM); 18 | } 19 | /* Single-threaded quantization: scans values for min and max, places binNum - 1 equally spaced edges between them, then calls findZeroIdx and quantizeToBins from the base class. */ 20 | @Override 21 | public void quantize(double[] values) { 22 | long startTime = System.currentTimeMillis(); 23 | n = values.length; 24 | min = Double.MAX_VALUE; 25 | /* start below every finite double: Double.MIN_VALUE is the smallest positive value, so using it here would report a positive max for input that is entirely negative */ max = -Double.MAX_VALUE; 26 | for (double v : values) { 27 | if (v < min) min = v; 28 | if (v > max) max = v; 29 | } 30 | // 1. uniformly split the range of values 31 | double step = (max - min) / binNum; 32 | int splitNum = binNum - 1; 33 | splits = new double[splitNum]; 34 | /* a single bin has no interior edge, so there is no first split to seed */ if (splitNum > 0) splits[0] = min + step; 35 | for (int i = 1; i < splitNum; i++) { 36 | splits[i] = splits[i - 1] + step; 37 | } 38 | // 3. find the zero index 39 | findZeroIdx(); 40 | // 4. 
find index of each value 41 | quantizeToBins(values); 42 | LOG.debug(String.format("Uniform quantization for %d items cost %d ms", 43 | n, System.currentTimeMillis() - startTime)); 44 | } 45 | /* Parallel variant: min, max and the equally spaced edges are computed on the calling thread exactly as in quantize; only the final bin assignment goes through parallelQuantizeToBins. Declares InterruptedException and ExecutionException for that call. */ 46 | @Override 47 | public void parallelQuantize(double[] values) throws InterruptedException, ExecutionException { 48 | long startTime = System.currentTimeMillis(); 49 | n = values.length; 50 | min = Double.MAX_VALUE; 51 | /* start below every finite double: Double.MIN_VALUE is the smallest positive value, so using it here would report a positive max for input that is entirely negative */ max = -Double.MAX_VALUE; 52 | for (double v : values) { 53 | if (v < min) min = v; 54 | if (v > max) max = v; 55 | } 56 | // 1. uniformly split the range of values 57 | double step = (max - min) / binNum; 58 | int splitNum = binNum - 1; 59 | splits = new double[splitNum]; 60 | /* a single bin has no interior edge, so there is no first split to seed */ if (splitNum > 0) splits[0] = min + step; 61 | for (int i = 1; i < splitNum; i++) { 62 | splits[i] = splits[i - 1] + step; 63 | } 64 | // 3. find the zero index 65 | findZeroIdx(); 66 | // 4. find index of each value 67 | parallelQuantizeToBins(values); 68 | LOG.debug(String.format("Uniform quantization for %d items cost %d ms", 69 | n, System.currentTimeMillis() - startTime)); 70 | } 71 | 72 | @Override 73 | public QuantizationType quantizationType() { 74 | return QuantizationType.UNIFORM; 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sample/App.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sample; 2 | 3 | import it.unimi.dsi.fastutil.doubles.DoubleArrayList; 4 | import it.unimi.dsi.fastutil.ints.IntArrayList; 5 | import org.apache.commons.lang3.tuple.Pair; 6 | import org.dma.sketchml.sketch.base.QuantileSketch; 7 | import org.dma.sketchml.sketch.base.Quantizer; 8 | import org.dma.sketchml.sketch.base.VectorCompressor; 9 | import org.dma.sketchml.sketch.common.Constants; 10 | import org.dma.sketchml.sketch.sketch.frequency.GroupedMinMaxSketch; 11 | import org.dma.sketchml.sketch.sketch.frequency.MinMaxSketch; 12 | import 
org.dma.sketchml.sketch.sketch.quantile.HeapQuantileSketch; 13 | import org.dma.sketchml.sketch.util.Utils; 14 | import org.slf4j.Logger; 15 | import org.slf4j.LoggerFactory; 16 | 17 | import java.util.Arrays; 18 | import java.util.Random; 19 | /* Sample driver that compresses one random dense vector and one random sparse vector and logs the reconstruction error and the compression rate. */ 20 | public class App { 21 | private static final Logger LOG = LoggerFactory.getLogger(App.class); 22 | 23 | private static Random random = new Random(); 24 | /* Sets the parallelism to 4, runs both demos and shuts the shared pool down. The shutdown is skipped when a demo throws. */ 25 | public static void main(String[] args) throws Exception { 26 | Constants.Parallel.setParallelism(4); 27 | dense(); 28 | sparse(); 29 | Constants.Parallel.shutdown(); 30 | } 31 | /* Dense demo: fills about 90 percent of a 1,000,000-element array with Gaussian samples (the rest stay 0), compresses it in parallel with a quantile quantizer, passes the compressor through Utils.testSerialization, decompresses, and logs error quantiles, RMSE and the size ratio against 8 bytes per original value. */ 32 | private static void dense() throws Exception { 33 | int n = 1000000; 34 | double density = 0.9; 35 | double[] values = new double[n]; 36 | for (int i = 0; i < n; i++) { 37 | if (random.nextDouble() < density) { 38 | values[i] = random.nextGaussian(); 39 | } 40 | } 41 | Quantizer.QuantizationType quantType = Quantizer.QuantizationType.QUANTILE; 42 | int binNum = Quantizer.DEFAULT_BIN_NUM; 43 | VectorCompressor compressor = new DenseVectorCompressor(quantType, binNum); 44 | //compressor.compressDense(values); 45 | compressor.parallelCompressDense(values); 46 | compressor = (VectorCompressor) Utils.testSerialization(compressor); 47 | double[] dValues = compressor.decompressDense(); 48 | LOG.info("First 10 values before: " + Arrays.toString(Arrays.copyOf(values, 10))); 49 | LOG.info("First 10 values after: " + Arrays.toString(Arrays.copyOf(dValues, 10))); 50 | QuantileSketch qSketch = new HeapQuantileSketch((long) n); 51 | double rmse = 0.0; 52 | for (int i = 0; i < n; i++) { 53 | qSketch.update(dValues[i] - values[i]); 54 | rmse += (dValues[i] - values[i]) *(dValues[i] - values[i]); 55 | } 56 | double[] err = qSketch.getQuantiles(100); 57 | rmse = Math.sqrt(rmse / n); 58 | LOG.info("Quantiles of errors: " + Arrays.toString(err)); 59 | LOG.info("RMSE: " + rmse); 60 | int originBytes = values.length * 8; 61 | int compressBytes = compressor.memoryBytes(); 62 | 
LOG.info(String.format("Compress %d bytes into %d bytes, compression rate: %f", 63 | originBytes, compressBytes, 1.0 * originBytes / compressBytes)); 64 | } 65 | /* Sparse demo: keeps each index below 100000 with probability of about 10 percent and gives it a Gaussian value, compresses the key and value arrays in parallel with a SparseVectorCompressor built from the default sketch settings, decompresses, and compares the result with the input. Here the compressor goes through Utils.testSerialization before it holds any data. */ 66 | private static void sparse() throws Exception { 67 | int n = 100000; 68 | double sparsity = 0.9; 69 | IntArrayList keyList = new IntArrayList(); 70 | DoubleArrayList valueList = new DoubleArrayList(); 71 | for (int i = 0; i < n; i++) { 72 | if (random.nextDouble() > sparsity) { 73 | keyList.add(i); 74 | valueList.add(random.nextGaussian()); 75 | } 76 | } 77 | int nnz = keyList.size(); 78 | int[] keys = keyList.toIntArray(); 79 | double[] values = valueList.toDoubleArray(); 80 | Quantizer.QuantizationType quantType = Quantizer.QuantizationType.QUANTILE; 81 | int binNum = Quantizer.DEFAULT_BIN_NUM; 82 | int groupNum = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_GROUP_NUM; 83 | int rowNum = MinMaxSketch.DEFAULT_MINMAXSKETCH_ROW_NUM; 84 | double colRatio = GroupedMinMaxSketch.DEFAULT_MINMAXSKETCH_COL_RATIO; 85 | VectorCompressor compressor = new SparseVectorCompressor( 86 | quantType, binNum, groupNum, rowNum, colRatio); 87 | compressor = (VectorCompressor) Utils.testSerialization(compressor); 88 | //compressor.compressSparse(keys, values); 89 | compressor.parallelCompressSparse(keys, values); 90 | /* typed pair so that getLeft and getRight yield the key and value arrays without casts */ Pair<int[], double[]> dResult = compressor.decompressSparse(); 91 | int[] dKeys = dResult.getLeft(); 92 | double[] dValues = dResult.getRight(); 93 | LOG.info(String.format("Array length: [%d, %d] vs. 
[%d, %d]", 94 | nnz, nnz, dKeys.length, dValues.length)); 95 | LOG.info("First 10 keys before: " + Arrays.toString(Arrays.copyOf(keys, 10))); 96 | LOG.info("First 10 keys after: " + Arrays.toString(Arrays.copyOf(dKeys, 10))); 97 | LOG.info("First 10 values before: " + Arrays.toString(Arrays.copyOf(values, 10))); 98 | LOG.info("First 10 values after: " + Arrays.toString(Arrays.copyOf(dValues, 10))); 99 | QuantileSketch qSketch = new HeapQuantileSketch((long) nnz); 100 | double rmse = 0.0; 101 | /* compare the pairs position by position; a pair whose keys differ is logged as an error and left out of the error statistics */ for (int i = 0; i < nnz; i++) { 102 | if (keys[i] != dKeys[i]) { 103 | LOG.error(String.format("Keys not match: [%d, %d]", keys[i], dKeys[i])); 104 | } else { 105 | qSketch.update(dValues[i] - values[i]); 106 | rmse += (dValues[i] - values[i]) * (dValues[i] - values[i]); 107 | } 108 | } 109 | double[] err = qSketch.getQuantiles(100); 110 | /* NOTE(review): the sum above skips mismatched keys but the mean still divides by nnz */ rmse = Math.sqrt(rmse / nnz); 111 | LOG.info("Quantiles of errors: " + Arrays.toString(err)); 112 | LOG.info("RMSE: " + rmse); 113 | /* 12 bytes per pair, presumably a 4-byte int key plus an 8-byte double value */ int originBytes = 12 * nnz; 114 | int compressBytes = compressor.memoryBytes(); 115 | LOG.info(String.format("Compress %d bytes into %d bytes, compression rate: %f", 116 | originBytes, compressBytes, 1.0 * originBytes / compressBytes)); 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sample/DenseVectorCompressor.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sample; 2 | 3 | import org.apache.commons.lang3.tuple.ImmutablePair; 4 | import org.apache.commons.lang3.tuple.Pair; 5 | import org.dma.sketchml.sketch.base.Quantizer; 6 | import org.dma.sketchml.sketch.base.SketchMLException; 7 | import org.dma.sketchml.sketch.base.VectorCompressor; 8 | import org.dma.sketchml.sketch.util.Maths; 9 | import org.dma.sketchml.sketch.util.Utils; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | 13 | import java.io.IOException; 14 
| import java.io.Serializable; 15 | import java.util.Arrays; 16 | import java.util.concurrent.ExecutionException; 17 | /* VectorCompressor for dense vectors: the whole vector is quantized and the quantizer itself is the compressed form. */ 18 | public class DenseVectorCompressor implements VectorCompressor { 19 | private static final Logger LOG = LoggerFactory.getLogger(DenseVectorCompressor.class); 20 | 21 | private int size; 22 | 23 | private Quantizer.QuantizationType quantType; 24 | private int quantBinNum; 25 | private Quantizer quantizer; 26 | 27 | public DenseVectorCompressor( 28 | Quantizer.QuantizationType quantType, int quantBinNum) { 29 | this.quantType = quantType; 30 | this.quantBinNum = quantBinNum; 31 | } 32 | /* Records the vector length, creates a new quantizer of the configured type and bin count, and quantizes values with it. */ 33 | @Override 34 | public void compressDense(double[] values) { 35 | long startTime = System.currentTimeMillis(); 36 | size = values.length; 37 | quantizer = Quantizer.newQuantizer(quantType, quantBinNum); 38 | quantizer.quantize(values); 39 | LOG.debug(String.format("Dense vector compression cost %d ms, %d items " + 40 | "in total", System.currentTimeMillis() - startTime, size)); 41 | } 42 | /* Expands the key and value arrays into a dense array indexed by key (absent keys stay 0) and compresses that. Logs a warning because this compressor is meant for dense input, and throws SketchMLException when the two arrays differ in length. */ 43 | @Override 44 | public void compressSparse(int[] keys, double[] values) { 45 | LOG.warn("Compressing a sparse vector with DenseVectorCompressor"); 46 | if (keys.length != values.length) { 47 | throw new SketchMLException(String.format( 48 | "Lengths of key array and value array do not match: %d, %d", 49 | keys.length, values.length)); 50 | } 51 | int maxKey = Maths.max(keys); 52 | /* keys are used as indexes, so the array needs maxKey + 1 slots for dense[maxKey] to exist */ double[] dense = new double[maxKey + 1]; 53 | for (int i = 0; i < keys.length; i++) 54 | dense[keys[i]] = values[i]; 55 | compressDense(dense); 56 | } 57 | /* Same as compressDense but uses the parallelQuantize path of the quantizer. */ 58 | @Override 59 | public void parallelCompressDense(double[] values) throws InterruptedException, ExecutionException { 60 | long startTime = System.currentTimeMillis(); 61 | size = values.length; 62 | quantizer = Quantizer.newQuantizer(quantType, quantBinNum); 63 | quantizer.parallelQuantize(values); 64 | LOG.debug(String.format("Dense vector parallel compression cost %d ms, %d items " + 65 | "in total", System.currentTimeMillis() - 
startTime, size)); 66 | } 67 | /* Parallel counterpart of compressSparse: expands the pairs into a dense array indexed by key and hands it to parallelCompressDense. Throws SketchMLException when the two arrays differ in length. */ 68 | @Override 69 | public void parallelCompressSparse(int[] keys, double[] values) throws InterruptedException, ExecutionException { 70 | LOG.warn("Compressing a sparse vector with DenseVectorCompressor"); 71 | if (keys.length != values.length) { 72 | throw new SketchMLException(String.format( 73 | "Lengths of key array and value array do not match: %d, %d", 74 | keys.length, values.length)); 75 | } 76 | int maxKey = Maths.max(keys); 77 | /* keys are used as indexes, so the array needs maxKey + 1 slots for dense[maxKey] to exist */ double[] dense = new double[maxKey + 1]; 78 | for (int i = 0; i < keys.length; i++) 79 | dense[keys[i]] = values[i]; 80 | parallelCompressDense(dense); 81 | } 82 | /* Rebuilds the vector by replacing every stored bin index with the representative value of that bin. */ 83 | @Override 84 | public double[] decompressDense() { 85 | double[] values = new double[size]; 86 | double[] quantValues = quantizer.getValues(); 87 | int[] bins = quantizer.getBins(); 88 | for (int i = 0; i < size; i++) 89 | values[i] = quantValues[bins[i]]; 90 | return values; 91 | } 92 | /* Returns the decompressed vector as pairs whose keys are simply 0 to length - 1. */ 93 | @Override 94 | public Pair<int[], double[]> decompressSparse() { 95 | double[] values = decompressDense(); 96 | int[] keys = new int[values.length]; 97 | Arrays.setAll(keys, i -> i); 98 | return new ImmutablePair<>(keys, values); 99 | } 100 | /* Delegates the scaling by x to the quantizer. */ 101 | @Override 102 | public void timesBy(double x) { 103 | quantizer.timesBy(x); 104 | } 105 | 106 | @Override 107 | public double size() { 108 | return size; 109 | } 110 | /* Estimated footprint: a fixed 12 bytes plus Utils.sizeof of the quantizer when one exists. */ 111 | @Override 112 | public int memoryBytes() throws IOException { 113 | int res = 12; 114 | if (quantizer != null) res += Utils.sizeof(quantizer); 115 | return res; 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sample/SparseVectorCompressor.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sample; 2 | 3 | 4 | import org.apache.commons.lang3.tuple.ImmutablePair; 5 | import org.apache.commons.lang3.tuple.Pair; 6 | import org.dma.sketchml.sketch.sketch.frequency.GroupedMinMaxSketch; 7 
| import org.dma.sketchml.sketch.base.Quantizer; 8 | import org.dma.sketchml.sketch.base.SketchMLException; 9 | import org.dma.sketchml.sketch.base.VectorCompressor; 10 | import org.dma.sketchml.sketch.util.Utils; 11 | import org.slf4j.Logger; 12 | import org.slf4j.LoggerFactory; 13 | 14 | import java.io.*; 15 | import java.util.Arrays; 16 | import java.util.concurrent.ExecutionException; 17 | /* VectorCompressor for sparse vectors: values are quantized into bin indexes, and the keys together with those indexes are stored in a GroupedMinMaxSketch. Only the per-bin representative values (quantValues) and the sketch are kept after compression. */ 18 | public class SparseVectorCompressor implements VectorCompressor { 19 | private static final Logger LOG = LoggerFactory.getLogger(SparseVectorCompressor.class); 20 | 21 | private int size; 22 | 23 | private Quantizer.QuantizationType quantType; 24 | private int quantBinNum; 25 | private double[] quantValues; 26 | 27 | private GroupedMinMaxSketch mmSketches; 28 | 29 | private int mmSketchGroupNum; 30 | private int mmSketchRowNum; 31 | private double mmSketchColRatio; 32 | /* Stores the quantizer and sketch settings; nothing is allocated until one of the compress methods runs. */ 33 | public SparseVectorCompressor( 34 | Quantizer.QuantizationType quantType, int quantBinNum, 35 | int mmSketchGroupNum, int mmSketchRowNum, double mmSketchColRatio) { 36 | this.quantType = quantType; 37 | this.quantBinNum = quantBinNum; 38 | this.mmSketchGroupNum = mmSketchGroupNum; 39 | this.mmSketchRowNum = mmSketchRowNum; 40 | this.mmSketchColRatio = mmSketchColRatio; 41 | } 42 | /* Treats the dense vector as pairs with keys 0 to length - 1 and forwards to compressSparse; logs a warning because this compressor is meant for sparse input. */ 43 | @Override 44 | public void compressDense(double[] values) { 45 | LOG.warn("Compressing a dense vector with SparseVectorCompressor"); 46 | int[] keys = new int[values.length]; 47 | Arrays.setAll(keys, i -> i); 48 | compressSparse(keys, values); 49 | } 50 | /* Quantizes values, keeps the per-bin values, and builds the grouped sketch from the keys and bin indexes. Throws SketchMLException when the two arrays differ in length. */ 51 | @Override 52 | public void compressSparse(int[] keys, double[] values) { 53 | long startTime = System.currentTimeMillis(); 54 | if (keys.length != values.length) { 55 | throw new SketchMLException(String.format( 56 | "Lengths of key array and value array do not match: %d, %d", 57 | keys.length, values.length)); 58 | } 59 | size = keys.length; 60 | // 1. 
quantize into bin indexes 61 | Quantizer quantizer = Quantizer.newQuantizer(quantType, quantBinNum); 62 | quantizer.quantize(values); 63 | quantValues = quantizer.getValues(); 64 | // 2. encode bins and keys 65 | mmSketches = new GroupedMinMaxSketch(mmSketchGroupNum, mmSketchRowNum, 66 | mmSketchColRatio, quantizer.getBinNum(), quantizer.getZeroIdx()); 67 | mmSketches.create(keys, quantizer.getBins()); 68 | LOG.debug(String.format("Sparse vector compression cost %d ms, %d key-value " + 69 | "pairs in total", System.currentTimeMillis() - startTime, size)); 70 | } 71 | 72 | @Override 73 | public void parallelCompressDense(double[] values) throws InterruptedException, ExecutionException { 74 | LOG.warn("Compressing a dense vector with SparseVectorCompressor"); 75 | int[] keys = new int[values.length]; 76 | Arrays.setAll(keys, i -> i); 77 | parallelCompressSparse(keys, values); 78 | } 79 | 80 | @Override 81 | public void parallelCompressSparse(int[] keys, double[] values) throws InterruptedException, ExecutionException { 82 | long startTime = System.currentTimeMillis(); 83 | if (keys.length != values.length) { 84 | throw new SketchMLException(String.format( 85 | "Lengths of key array and value array do not match: %d, %d", 86 | keys.length, values.length)); 87 | } 88 | size = keys.length; 89 | // 1. quantize into bin indexes 90 | Quantizer quantizer = Quantizer.newQuantizer(quantType, quantBinNum); 91 | quantizer.parallelQuantize(values); 92 | quantValues = quantizer.getValues(); 93 | // 2. 
encode bins and keys 94 | mmSketches = new GroupedMinMaxSketch(mmSketchGroupNum, mmSketchRowNum, 95 | mmSketchColRatio, quantizer.getBinNum(), quantizer.getZeroIdx()); 96 | mmSketches.parallelCreate(keys, quantizer.getBins()); 97 | LOG.debug(String.format("Sparse vector parallel compression cost %d ms, %d key-value " + 98 | "pairs in total", System.currentTimeMillis() - startTime, size)); 99 | } 100 | 101 | @Override 102 | public double[] decompressDense() { 103 | Pair kv = decompressSparse(); 104 | int[] keys = kv.getLeft(); 105 | double[] values = kv.getRight(); 106 | int maxKey = 0; 107 | for (int key : keys) { 108 | maxKey = Math.max(key, maxKey); 109 | } 110 | double[] res = new double[maxKey + 1]; 111 | for (int i = 0; i < size; i++) { 112 | res[keys[i]] = values[i]; 113 | } 114 | return res; 115 | } 116 | 117 | @Override 118 | public Pair decompressSparse() { 119 | Pair kb = mmSketches.restore(); 120 | int[] keys = kb.getLeft(); 121 | int[] bins = kb.getRight(); 122 | double[] values = new double[size]; 123 | for (int i = 0; i < size; i++) 124 | values[i] = quantValues[bins[i]]; 125 | return new ImmutablePair<>(keys, values); 126 | } 127 | 128 | @Override 129 | public void timesBy(double x) { 130 | if (quantValues != null) { 131 | for (int i = 0; i < quantValues.length; i++) 132 | quantValues[i] *= x; 133 | } 134 | } 135 | 136 | @Override 137 | public double size() { 138 | return size; 139 | } 140 | 141 | @Override 142 | public int memoryBytes() throws IOException { 143 | int res = 28 + quantValues.length * 8; 144 | if (mmSketches != null) 145 | res += Utils.sizeof(mmSketches); 146 | return res; 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sketch/frequency/FSketchUtils.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sketch.frequency; 2 | 3 | import 
it.unimi.dsi.fastutil.ints.IntArrayList; 4 | import org.apache.commons.lang3.tuple.ImmutablePair; 5 | import org.apache.commons.lang3.tuple.Pair; 6 | 7 | public class FSketchUtils { 8 | 9 | public static int[] calGroupEdges(int zeroIdx, int binNum, int groupNum) { 10 | if (groupNum == 2) { 11 | return new int[]{zeroIdx, binNum}; 12 | } else { 13 | int[] groupEdges = new int[groupNum]; 14 | int binsPerGroup = binNum / groupNum; 15 | if (zeroIdx < binsPerGroup) { 16 | groupEdges[0] = zeroIdx; 17 | } else if ((zeroIdx % binsPerGroup) < (binsPerGroup / 2)) { 18 | groupEdges[0] = binsPerGroup + zeroIdx % binsPerGroup; 19 | } else { 20 | groupEdges[0] = zeroIdx % binsPerGroup; 21 | } 22 | for (int i = 1; i < groupNum - 1; i++) { 23 | groupEdges[i] = groupEdges[i - 1] + binsPerGroup; 24 | } 25 | groupEdges[groupNum - 1] = binNum; 26 | return groupEdges; 27 | } 28 | } 29 | 30 | public static Pair partition(int[] keys, int[] bins, int[] groupEdges) { 31 | int groupNum = groupEdges.length; 32 | IntArrayList[] keyLists = new IntArrayList[groupNum]; 33 | IntArrayList[] binLists = new IntArrayList[groupNum]; 34 | for (int i = 0; i < groupNum; i++) { 35 | int groupSpan = i > 0 ? 
(groupEdges[i] - groupEdges[i - 1]) : groupEdges[0]; 36 | int estimatedGroupSize = (int) Math.ceil(1.0 * keys.length / groupNum * groupSpan); 37 | keyLists[i] = new IntArrayList(estimatedGroupSize); 38 | binLists[i] = new IntArrayList(estimatedGroupSize); 39 | } 40 | for (int i = 0; i < keys.length; i++) { 41 | int groupIdx = 0; 42 | while (groupEdges[groupIdx] <= bins[i]) groupIdx++; 43 | keyLists[groupIdx].add(keys[i]); 44 | binLists[groupIdx].add(bins[i]); 45 | } 46 | return new ImmutablePair<>(keyLists, binLists); 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sketch/frequency/GroupedMinMaxSketch.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sketch.frequency; 2 | 3 | import it.unimi.dsi.fastutil.ints.IntArrayList; 4 | import org.apache.commons.lang3.tuple.ImmutablePair; 5 | import org.apache.commons.lang3.tuple.Pair; 6 | import org.dma.sketchml.sketch.base.BinaryEncoder; 7 | import org.dma.sketchml.sketch.binary.DeltaAdaptiveEncoder; 8 | import org.dma.sketchml.sketch.common.Constants; 9 | import org.dma.sketchml.sketch.util.Sort; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | 13 | import java.io.IOException; 14 | import java.io.ObjectInputStream; 15 | import java.io.ObjectOutputStream; 16 | import java.io.Serializable; 17 | import java.util.ArrayList; 18 | import java.util.List; 19 | import java.util.concurrent.Callable; 20 | import java.util.concurrent.ExecutionException; 21 | import java.util.concurrent.ExecutorService; 22 | import java.util.concurrent.Future; 23 | 24 | public class GroupedMinMaxSketch implements Serializable { 25 | private static final Logger LOG = LoggerFactory.getLogger(GroupedMinMaxSketch.class); 26 | 27 | private int groupNum; 28 | private int rowNum; 29 | private double colRatio; 30 | private int binNum; 31 | private int zeroValue; 
32 | private MinMaxSketch[] sketches; 33 | private BinaryEncoder[] encoders; 34 | 35 | public static final int DEFAULT_MINMAXSKETCH_GROUP_NUM = 8; 36 | public static final double DEFAULT_MINMAXSKETCH_COL_RATIO = 0.3; 37 | 38 | public GroupedMinMaxSketch(int groupNum, int rowNum, double colRatio, int binNum, int zeroValue) { 39 | this.groupNum = groupNum; 40 | this.rowNum = rowNum; 41 | this.colRatio = colRatio; 42 | this.binNum = binNum; 43 | this.zeroValue = zeroValue; 44 | } 45 | 46 | public GroupedMinMaxSketch(int binNum, int zeroValue) { 47 | this(DEFAULT_MINMAXSKETCH_GROUP_NUM, MinMaxSketch.DEFAULT_MINMAXSKETCH_ROW_NUM, 48 | DEFAULT_MINMAXSKETCH_COL_RATIO, binNum, zeroValue); 49 | } 50 | 51 | public void create(int[] keys, int[] bins) { 52 | long startTime = System.currentTimeMillis(); 53 | // 1. divide bins into several groups 54 | int[] groupEdges = FSketchUtils.calGroupEdges(zeroValue, binNum, groupNum); 55 | sketches = new MinMaxSketch[groupNum]; 56 | encoders = new BinaryEncoder[groupNum]; 57 | Pair partKBLists = 58 | FSketchUtils.partition(keys, bins, groupEdges); 59 | // 2. encode bins and keys 60 | for (int i = 0; i < groupNum; i++) { 61 | IntArrayList keyList = partKBLists.getLeft()[i]; 62 | IntArrayList binList = partKBLists.getRight()[i]; 63 | Pair group = compOneGroup( 64 | keyList, binList, groupEdges, i); 65 | sketches[i] = group.getLeft(); 66 | encoders[i] = group.getRight(); 67 | } 68 | LOG.debug(String.format("Create grouped MinMaxSketch cost %d ms", 69 | System.currentTimeMillis() - startTime)); 70 | } 71 | 72 | public void parallelCreate(int[] keys, int[] bins) throws InterruptedException, ExecutionException { 73 | long startTime = System.currentTimeMillis(); 74 | // 1. 
divide bins into several groups 75 | int[] groupEdges = FSketchUtils.calGroupEdges(zeroValue, binNum, groupNum); 76 | sketches = new MinMaxSketch[groupNum]; 77 | encoders = new BinaryEncoder[groupNum]; 78 | Pair partKBLists = 79 | FSketchUtils.partition(keys, bins, groupEdges); 80 | // 2. each thread encode one group of bins and keys 81 | ExecutorService threadPool = Constants.Parallel.getThreadPool(); 82 | Future>[] futures = new Future[groupNum]; 83 | for (int i = 0; i < groupNum; i++) { 84 | int threadId = i; 85 | futures[threadId] = threadPool.submit(new Callable>() { 86 | @Override 87 | public Pair call() throws Exception { 88 | IntArrayList keyList = partKBLists.getLeft()[threadId]; 89 | IntArrayList binList = partKBLists.getRight()[threadId]; 90 | return compOneGroup(keyList, binList, groupEdges, threadId); 91 | } 92 | }); 93 | } 94 | for (int i = 0; i < groupNum; i++) { 95 | Pair res = futures[i].get(); 96 | sketches[i] = res.getLeft(); 97 | encoders[i] = res.getRight(); 98 | } 99 | LOG.debug(String.format("Create grouped MinMaxSketch cost %d ms", 100 | System.currentTimeMillis() - startTime)); 101 | } 102 | 103 | private Pair compOneGroup(IntArrayList keyList, IntArrayList binList, 104 | int[] groupEdges, int groupId) { 105 | int groupSize = keyList.size(); 106 | if (groupSize == 0) { 107 | LOG.warn(String.format("Group[%d] is empty, group edges: [%d, %d)", groupId, 108 | groupId == 0 ? 
0 : groupEdges[groupId - 1], groupEdges[groupId])); 109 | return new ImmutablePair<>(null, null); 110 | } 111 | // encode bins 112 | int colNum = (int) Math.ceil(groupSize * colRatio); 113 | MinMaxSketch sketch = new MinMaxSketch(rowNum, colNum, zeroValue); 114 | for (int j = 0; j < groupSize; j++) { 115 | sketch.insert(keyList.getInt(j), binList.getInt(j)); 116 | } 117 | // encode keys 118 | BinaryEncoder encoder = new DeltaAdaptiveEncoder(); 119 | encoder.encode(keyList.toIntArray(null)); 120 | return new ImmutablePair<>(sketch, encoder); 121 | } 122 | 123 | public Pair restore() { 124 | int size = 0; 125 | // decode each group 126 | // in case there are empty groups 127 | List keysToMerge = new ArrayList<>(groupNum); 128 | List binsToMerge = new ArrayList<>(groupNum); 129 | for (int i = 0; i < groupNum; i++) { 130 | if (encoders[i] != null && sketches[i] != null) { 131 | int[] groupKeys = encoders[i].decode(); 132 | int[] groupBins = new int[groupKeys.length]; 133 | for (int j = 0; j < groupKeys.length; j++) 134 | groupBins[j] = sketches[i].query(groupKeys[j]); 135 | keysToMerge.add(groupKeys); 136 | binsToMerge.add(groupBins); 137 | size += groupKeys.length; 138 | } 139 | } 140 | // merge 141 | int[] keys = new int[size]; 142 | int[] bins = new int[size]; 143 | Sort.merge(keysToMerge.toArray(new int[keysToMerge.size()][]), 144 | binsToMerge.toArray(new int[binsToMerge.size()][]), keys, bins); 145 | return new ImmutablePair<>(keys, bins); 146 | } 147 | 148 | private void writeObject(ObjectOutputStream oos) throws IOException { 149 | oos.writeInt(groupNum); 150 | oos.writeInt(rowNum); 151 | oos.writeDouble(colRatio); 152 | oos.writeInt(binNum); 153 | oos.writeInt(zeroValue); 154 | for (MinMaxSketch sketch : sketches) 155 | oos.writeObject(sketch); 156 | for (BinaryEncoder encoder : encoders) 157 | oos.writeObject(encoder); 158 | } 159 | 160 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 161 | groupNum = 
ois.readInt(); 162 | rowNum = ois.readInt(); 163 | colRatio = ois.readDouble(); 164 | binNum = ois.readInt(); 165 | zeroValue = ois.readInt(); 166 | sketches = new MinMaxSketch[groupNum]; 167 | for (int i = 0; i < groupNum; i++) 168 | sketches[i] = (MinMaxSketch) ois.readObject(); 169 | encoders = new BinaryEncoder[groupNum]; 170 | for (int i = 0; i < groupNum; i++) 171 | encoders[i] = (BinaryEncoder) ois.readObject(); 172 | } 173 | 174 | } 175 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sketch/frequency/MinMaxSketch.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sketch.frequency; 2 | 3 | import org.dma.sketchml.sketch.base.BinaryEncoder; 4 | import org.dma.sketchml.sketch.base.Int2IntHash; 5 | import org.dma.sketchml.sketch.binary.HuffmanEncoder; 6 | import org.dma.sketchml.sketch.hash.HashFactory; 7 | import org.slf4j.Logger; 8 | import org.slf4j.LoggerFactory; 9 | 10 | import java.io.IOException; 11 | import java.io.ObjectInputStream; 12 | import java.io.ObjectOutputStream; 13 | import java.io.Serializable; 14 | import java.util.Arrays; 15 | 16 | public class MinMaxSketch implements Serializable { 17 | private static final Logger LOG = LoggerFactory.getLogger(MinMaxSketch.class); 18 | 19 | protected int rowNum; 20 | protected int colNum; 21 | protected int[] table; 22 | protected int zeroValue; 23 | protected Int2IntHash[] hashes; 24 | 25 | public static final int DEFAULT_MINMAXSKETCH_ROW_NUM = 2; 26 | 27 | public MinMaxSketch(int rowNum, int colNum, int zeroValue) { 28 | this.rowNum = rowNum; 29 | this.colNum = colNum; 30 | this.table = new int[rowNum * colNum]; 31 | this.zeroValue = zeroValue; 32 | int maxValue = compare(Integer.MIN_VALUE, Integer.MAX_VALUE) <= 0 33 | ? 
Integer.MIN_VALUE : Integer.MAX_VALUE; 34 | Arrays.fill(table, maxValue); 35 | this.hashes = HashFactory.getRandomInt2IntHashes(rowNum, colNum); 36 | } 37 | 38 | public MinMaxSketch(int colNum, int zeroValue) { 39 | this(DEFAULT_MINMAXSKETCH_ROW_NUM, colNum, zeroValue); 40 | } 41 | 42 | /** 43 | * Min: insert the minimal (closest to `zeroValue`) value 44 | * 45 | * @param key 46 | * @param value 47 | */ 48 | public void insert(int key, int value) { 49 | for (int i = 0; i < rowNum; i++) { 50 | int code = hashes[i].hash(key); 51 | int index = i * colNum + code; 52 | if (compare(value, table[index]) < 0) 53 | table[index] = value; 54 | } 55 | } 56 | 57 | 58 | /** 59 | * Max: return the maximal (furthest to `zeroValue`) value 60 | * 61 | * @param key 62 | * @return 63 | */ 64 | public int query(int key) { 65 | int res = zeroValue; 66 | for (int i = 0; i < rowNum; i++) { 67 | int code = hashes[i].hash(key); 68 | int index = i * colNum + code; 69 | if (compare(table[index], res) > 0) 70 | res = table[index]; 71 | } 72 | return res; 73 | } 74 | 75 | /** 76 | * Compare two numbers' distances w.r.t. 
`zeroValue` 77 | * 78 | * @param v1 79 | * @param v2 80 | * @return 81 | */ 82 | private int compare(int v1, int v2) { 83 | int d1 = Math.abs(v1 - zeroValue); 84 | int d2 = Math.abs(v2 - zeroValue); 85 | return d1 - d2; 86 | } 87 | 88 | private void writeObject(ObjectOutputStream oos) throws IOException { 89 | oos.writeInt(rowNum); 90 | oos.writeInt(colNum); 91 | oos.writeInt(zeroValue); 92 | for (Int2IntHash hash : hashes) 93 | oos.writeObject(hash); 94 | BinaryEncoder huffman = new HuffmanEncoder(); 95 | huffman.encode(table); 96 | oos.writeObject(huffman); 97 | } 98 | 99 | private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { 100 | rowNum = ois.readInt(); 101 | colNum = ois.readInt(); 102 | zeroValue = ois.readInt(); 103 | hashes = new Int2IntHash[rowNum]; 104 | for (int i = 0; i < rowNum; i++) 105 | hashes[i] = (Int2IntHash) ois.readObject(); 106 | BinaryEncoder encoder = (BinaryEncoder) ois.readObject(); 107 | table = encoder.decode(); 108 | } 109 | 110 | public int getRowNum() { 111 | return rowNum; 112 | } 113 | 114 | public int getColNum() { 115 | return colNum; 116 | } 117 | 118 | public int getZeroValue() { 119 | return zeroValue; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sketch/quantile/HeapQuantileSketch.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sketch.quantile; 2 | 3 | import org.dma.sketchml.sketch.base.QuantileSketch; 4 | 5 | import java.util.Arrays; 6 | 7 | /** 8 | * Implementation of quantile sketch on the Java heap 9 | * bashed on `DataSketches` of Yahoo! 
10 | */ 11 | public class HeapQuantileSketch extends QuantileSketch { 12 | private int k; // parameter that controls space usage 13 | public static final int DEFAULT_K = 128; 14 | 15 | /** 16 | * This single array contains the base buffer plus all levels some of which may not be used. 17 | * A level is of size K and is either full and sorted, or not used. A "not used" buffer may have 18 | * garbage. Whether a level buffer used or not is indicated by the bitPattern_. 19 | * The base buffer has length 2*K but might not be full and isn't necessarily sorted. 20 | * The base buffer precedes the level buffers. 21 | * 22 | * The levels arrays require quite a bit of explanation, which we defer until later. 23 | */ 24 | private double[] combinedBuffer; 25 | private int combinedBufferCapacity; // equals combinedBuffer.length 26 | private int baseBufferCount; // #samples currently in base buffer (= n % (2*k)) 27 | private long bitPattern; // active levels expressed as a bit pattern (= n / (2*k)) 28 | private static final int MIN_BASE_BUF_SIZE = 4; 29 | 30 | /** 31 | * data structure for answering quantile queries 32 | */ 33 | private double[] samplesArr; // array of size samples 34 | private long[] weightsArr; // array of cut points 35 | 36 | public HeapQuantileSketch(int k, long estimateN) { 37 | super(estimateN); 38 | QSketchUtils.checkK(k); 39 | this.k = k; 40 | reset(); 41 | } 42 | 43 | public HeapQuantileSketch() { 44 | this(DEFAULT_K, -1L); 45 | } 46 | 47 | public HeapQuantileSketch(int k) { 48 | this(k, -1L); 49 | } 50 | 51 | public HeapQuantileSketch(long estimateN) { 52 | this(DEFAULT_K, estimateN); 53 | } 54 | 55 | @Override 56 | public void reset() { 57 | n = 0; 58 | if (estimateN < 0) 59 | combinedBufferCapacity = Math.min(MIN_BASE_BUF_SIZE, k * 2); 60 | else if (estimateN < k * 2) 61 | combinedBufferCapacity = k * 4; 62 | else 63 | combinedBufferCapacity = QSketchUtils.needBufferCapacity(k, estimateN); 64 | combinedBuffer = new double[combinedBufferCapacity]; 65 
| baseBufferCount = 0; 66 | bitPattern = 0L; 67 | minValue = Double.MAX_VALUE; 68 | maxValue = Double.MIN_VALUE; 69 | samplesArr = null; 70 | weightsArr = null; 71 | } 72 | 73 | @Override 74 | public void update(double value) { 75 | if (Double.isNaN(value)) 76 | throw new QuantileSketchException("Encounter NaN value"); 77 | maxValue = Math.max(maxValue, value); 78 | minValue = Math.min(minValue, value); 79 | 80 | if (baseBufferCount + 1 > combinedBufferCapacity) 81 | ensureBaseBuffer(); 82 | combinedBuffer[baseBufferCount++] = value; 83 | n++; 84 | if (baseBufferCount == (k * 2)) 85 | fullBaseBufferPropagation(); 86 | } 87 | 88 | private void ensureBaseBuffer() { 89 | final double[] baseBuffer = combinedBuffer; 90 | int oldSize = combinedBufferCapacity; 91 | if (oldSize >= k * 2) 92 | throw new QuantileSketchException("Buffer over size"); 93 | int newSize = Math.max(Math.min(k * 2, oldSize * 2), 1); 94 | combinedBufferCapacity = newSize; 95 | combinedBuffer = Arrays.copyOf(baseBuffer, newSize); 96 | } 97 | 98 | private void ensureLevels(long newN) { 99 | int numLevels = 1 + (63 - Long.numberOfLeadingZeros(newN / (k * 2))); 100 | int spaceNeeded = k * (numLevels + 2); 101 | if (spaceNeeded <= combinedBufferCapacity) return; 102 | final double[] baseBuffer = combinedBuffer; 103 | combinedBuffer = Arrays.copyOf(baseBuffer, spaceNeeded); 104 | combinedBufferCapacity = spaceNeeded; 105 | } 106 | 107 | private void fullBaseBufferPropagation() { 108 | ensureLevels(n); 109 | final double[] baseBuffer = combinedBuffer; 110 | Arrays.sort(baseBuffer, 0, baseBufferCount); 111 | inPlacePropagationUpdate(0, baseBuffer, 0); 112 | baseBufferCount = 0; 113 | QSketchUtils.checkBitPattern(bitPattern, n, k); 114 | } 115 | 116 | private void inPlacePropagationUpdate(int beginLevel, final double[] buf, int bufBeginPos) { 117 | final double[] levelsArr = combinedBuffer; 118 | int endLevel = beginLevel; 119 | long tmp = bitPattern >>> beginLevel; 120 | while ((tmp & 1) != 0) { tmp >>>= 1; 
endLevel++; } 121 | QSketchUtils.compactBuffer(buf, bufBeginPos, levelsArr, (endLevel + 2) * k, k); 122 | QSketchUtils.levelwisePropagation(bitPattern, k, beginLevel, endLevel, buf, bufBeginPos, levelsArr); 123 | bitPattern += 1L << beginLevel; 124 | } 125 | 126 | public void makeSummary() { 127 | int baseBufferItems = (int)(n % (k * 2)); 128 | QSketchUtils.checkBitPattern(bitPattern, n, k); 129 | int validLevels = Long.bitCount(bitPattern); 130 | int numSamples = baseBufferItems + validLevels * k; 131 | samplesArr = new double[numSamples]; 132 | weightsArr = new long[numSamples + 1]; 133 | 134 | copyBuf2Arr(numSamples); 135 | QSketchUtils.blockyMergeSort(samplesArr, weightsArr, numSamples, k); 136 | 137 | long cnt = 0L; 138 | for (int i = 0; i <= numSamples; i++) { 139 | long newCnt = cnt + weightsArr[i]; 140 | weightsArr[i] = cnt; 141 | cnt = newCnt; 142 | } 143 | } 144 | 145 | private void copyBuf2Arr(int numSamples) { 146 | long weight = 1L; 147 | int cur = 0; 148 | long bp = bitPattern; 149 | 150 | // copy the highest levels 151 | for (int level = 0; bp != 0; level++, bp >>>= 1) { 152 | weight *= 2; 153 | if ((bp & 1) != 0) { 154 | int offset = k * (level + 2); 155 | for (int i = 0; i < k; i++) { 156 | samplesArr[cur] = combinedBuffer[i + offset]; 157 | weightsArr[cur] = weight; 158 | cur++; 159 | } 160 | } 161 | } 162 | 163 | // copy baseBuffer 164 | int startBlk = cur; 165 | for (int i = 0; i < baseBufferCount; i++) { 166 | samplesArr[cur] = combinedBuffer[i]; 167 | weightsArr[cur] = 1L; 168 | cur++; 169 | } 170 | weightsArr[cur] = 0L; 171 | if (cur != numSamples) 172 | throw new QuantileSketchException("Missing items when copying buffer to array"); 173 | Arrays.sort(samplesArr, startBlk, cur); 174 | } 175 | 176 | @Override 177 | public void merge(QuantileSketch other) { 178 | if (other instanceof HeapQuantileSketch) { 179 | merge((HeapQuantileSketch) other); 180 | } else { 181 | throw new QuantileSketchException("Cannot merge different " + 182 | "kinds of 
quantile sketches"); 183 | } 184 | } 185 | 186 | public void merge(HeapQuantileSketch other) { 187 | if (other == null || other.isEmpty()) return; 188 | if (other.k != this.k) 189 | throw new QuantileSketchException("Merge sketches with different k"); 190 | QSketchUtils.checkBitPattern(other.bitPattern, other.n, other.k); 191 | if (this.isEmpty()) { 192 | this.copy(other); 193 | return; 194 | } 195 | 196 | // merge two non-empty quantile sketches 197 | long totalN = this.n + other.n; 198 | for (int i = 0; i < other.baseBufferCount; i++) { 199 | update(other.combinedBuffer[i]); 200 | } 201 | ensureLevels(totalN); 202 | 203 | final double[] auxBuf = new double[k * 2]; 204 | long bp = other.bitPattern; 205 | for (int level = 0; bp != 0L; level++, bp >>>= 1) { 206 | if ((bp & 1L) != 0L) { 207 | inPlacePropagationMerge(level, other.combinedBuffer, 208 | k * (level + 2), auxBuf, 0); 209 | } 210 | } 211 | 212 | this.n = totalN; 213 | this.maxValue = Math.max(this.maxValue, other.maxValue); 214 | this.minValue = Math.min(this.minValue, other.minValue); 215 | this.samplesArr = null; 216 | this.weightsArr = null; 217 | } 218 | 219 | private void inPlacePropagationMerge(int beginLevel, final double[] buf, int bufStart, 220 | final double[] auxBuf, int auxBufStart) { 221 | final double[] levelsArr = combinedBuffer; 222 | int endLevel = beginLevel; 223 | long tmp = bitPattern >>> beginLevel; 224 | while ((tmp & 1) != 0) { tmp >>>= 1; endLevel++; } 225 | System.arraycopy(buf, bufStart, levelsArr, k * (endLevel + 2), k); 226 | QSketchUtils.levelwisePropagation(bitPattern, k, beginLevel, endLevel, auxBuf, auxBufStart, levelsArr); 227 | bitPattern += 1L << beginLevel; 228 | } 229 | 230 | public void copy(HeapQuantileSketch other) { 231 | this.n = other.n; 232 | this.minValue = other.minValue; 233 | this.maxValue = other.maxValue; 234 | if (this.estimateN == -1) { 235 | this.combinedBufferCapacity = other.combinedBufferCapacity; 236 | this.combinedBuffer = 
other.combinedBuffer.clone(); 237 | } else if (other.combinedBufferCapacity > this.combinedBufferCapacity) { 238 | this.combinedBufferCapacity = other.combinedBufferCapacity; 239 | this.combinedBuffer = other.combinedBuffer.clone(); 240 | } else { 241 | System.arraycopy(other.combinedBuffer, 0, 242 | this.combinedBuffer, 0, other.combinedBufferCapacity); 243 | } 244 | this.baseBufferCount = other.baseBufferCount; 245 | this.bitPattern = other.bitPattern; 246 | if (other.samplesArr != null && other.weightsArr != null) { 247 | this.samplesArr = other.samplesArr.clone(); 248 | this.weightsArr = other.weightsArr.clone(); 249 | } 250 | } 251 | 252 | @Override 253 | public double getQuantile(double fraction) { 254 | QSketchUtils.checkFraction(fraction); 255 | if (samplesArr == null || weightsArr == null) 256 | makeSummary(); 257 | 258 | if (samplesArr.length == 0) 259 | return Double.NaN; 260 | 261 | if (fraction == 0.0) 262 | return minValue; 263 | else if (fraction == 1.0) 264 | return maxValue; 265 | else 266 | return getQuantileFromArr(fraction); 267 | } 268 | 269 | @Override 270 | public double[] getQuantiles(double[] fractions) { 271 | QSketchUtils.checkFractions(fractions); 272 | if (samplesArr == null || weightsArr == null) 273 | makeSummary(); 274 | 275 | double[] res = new double[fractions.length]; 276 | if (samplesArr.length == 0) { 277 | Arrays.fill(res, Double.NaN); 278 | return res; 279 | } 280 | 281 | for (int i = 0; i < fractions.length; i++) { 282 | if (fractions[i] == 0.0) 283 | res[i] = minValue; 284 | else if (fractions[i] == 1.0) 285 | res[i] = maxValue; 286 | else 287 | res[i] = getQuantileFromArr(fractions[i]); 288 | } 289 | return res; 290 | } 291 | 292 | @Override 293 | public double[] getQuantiles(int evenPartition) { 294 | QSketchUtils.checkEvenPartiotion(evenPartition); 295 | if (samplesArr == null || weightsArr == null) 296 | makeSummary(); 297 | 298 | double[] splits = new double[evenPartition - 1]; 299 | if (samplesArr.length == 0) { 300 | 
Arrays.fill(splits, Double.NaN); 301 | return splits; 302 | } 303 | 304 | int index = 0; 305 | double curFrac = 1.0 / evenPartition; 306 | double step = 1.0 / evenPartition; 307 | for (int i = 0; i + 1 < evenPartition; i++) { 308 | long rank = (long)(n * curFrac); 309 | rank = Math.min(rank, n - 1); 310 | int left = index, right = weightsArr.length - 1; 311 | while (left + 1 < right) { 312 | int mid = left + ((right - left) >> 1); 313 | if (weightsArr[mid] <= rank) 314 | left = mid; 315 | else 316 | right = mid; 317 | } 318 | splits[i] = samplesArr[left]; 319 | index = left; 320 | curFrac += step; 321 | } 322 | return splits; 323 | } 324 | 325 | private double getQuantileFromArr(double fraction) { 326 | long rank = (long)(n * fraction); 327 | if (rank == n) n--; 328 | int left = 0, right = weightsArr.length - 1; 329 | while (left + 1 < right) { 330 | int mid = left + ((right - left) >> 1); 331 | if (weightsArr[mid] <= rank) 332 | left = mid; 333 | else 334 | right = mid; 335 | } 336 | return samplesArr[left]; 337 | } 338 | 339 | public int getK() { 340 | return k; 341 | } 342 | 343 | 344 | } 345 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sketch/quantile/QSketchUtils.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sketch.quantile; 2 | 3 | import org.dma.sketchml.sketch.util.Maths; 4 | 5 | import java.util.Arrays; 6 | import java.util.Random; 7 | 8 | public class QSketchUtils { 9 | private static final Random rand = new Random(); 10 | 11 | protected static void checkK(int k) { 12 | if (k < 1) 13 | throw new QuantileSketchException("Invalid value of k: k should be positive"); 14 | else if (k >= 65535) 15 | throw new QuantileSketchException("Invalid value of k: k should not be larger than 65536"); 16 | else if (!Maths.isPowerOf2(k)) 17 | throw new QuantileSketchException("Invalid value of k: k should be power 
of 2"); 18 | } 19 | 20 | protected static int needBufferCapacity(int k, long estimateN) { 21 | int numLevels = 1 + (63 - Long.numberOfLeadingZeros(estimateN / (k * 2))); 22 | return k * (numLevels + 2); 23 | } 24 | 25 | protected static void checkBitPattern(long bitPattern, long n, int k) { 26 | if (bitPattern != n / (k * 2)) 27 | throw new QuantileSketchException("Bit Pattern not match"); 28 | } 29 | 30 | protected static void checkFraction(double fraction) { 31 | if (fraction < 0.0 || fraction > 1.0) 32 | throw new QuantileSketchException("Fraction should be in range [0.0, 1.0]"); 33 | } 34 | 35 | protected static void checkFractions(double[] fractions) { 36 | for (double f: fractions) 37 | checkFraction(f); 38 | } 39 | 40 | protected static void checkEvenPartiotion(int evenPartition) { 41 | if (evenPartition <= 1) 42 | throw new QuantileSketchException("Invalid partition number: " + evenPartition); 43 | } 44 | 45 | protected static void compactBuffer(final double[] srcBuf, int srcOffset, 46 | final double[] dstBuf, int dstOffset, int dstSize) { 47 | int offset = rand.nextBoolean() ? 
1 : 0; 48 | int bound = dstOffset + dstSize; 49 | for (int i = srcOffset + offset, j = dstOffset; j < bound; i += 2, j++) 50 | dstBuf[j] = srcBuf[i]; 51 | } 52 | 53 | protected static void mergeArrays(final double[] src1, int srcOffset1, 54 | final double[] src2, int srcOffset2, 55 | final double[] dst, int dstOffset, int size) { 56 | int bound1 = srcOffset1 + size; 57 | int bound2 = srcOffset2 + size; 58 | int i1 = srcOffset1, i2 = srcOffset2, i3 = dstOffset; 59 | while (i1 < bound1 && i2 < bound2) { 60 | if (src1[i1] < src2[i2]) 61 | dst[i3++] = src1[i1++]; 62 | else 63 | dst[i3++] = src2[i2++]; 64 | } 65 | if (i1 < bound1) 66 | System.arraycopy(src1, i1, dst, i3, bound1 - i1); 67 | else 68 | System.arraycopy(src2, i2, dst, i3, bound2 - i2); 69 | } 70 | 71 | protected static void levelwisePropagation(long bitPattern, int k, 72 | int beginLevel, int endLevel, 73 | final double[] buf, int bufBeginPos, 74 | final double[] levelsArr) { 75 | for (int level = beginLevel; level < endLevel; level++) { 76 | if ((bitPattern & (1L << level)) == 0) 77 | throw new QuantileSketchException("Encounter empty level: " + level); 78 | QSketchUtils.mergeArrays(levelsArr, k * (level + 2), 79 | levelsArr, k * (endLevel + 2), buf, bufBeginPos, k); 80 | QSketchUtils.compactBuffer(buf, bufBeginPos, levelsArr, k * (endLevel + 2), k); 81 | } 82 | } 83 | 84 | protected static void blockyMergeSort(final double[] keys, final long[] values, 85 | int length, int blkSize) { 86 | if (blkSize <= 0 || length <= blkSize) return; 87 | int numBlks = (length + (blkSize - 1)) / blkSize; 88 | final double[] tmpKeys = Arrays.copyOf(keys, length); 89 | final long[] tmpValues = Arrays.copyOf(values, length); 90 | recursiveBlockyMergeSort(tmpKeys, tmpValues, keys, values, 0, numBlks, blkSize, length); 91 | } 92 | 93 | protected static void recursiveBlockyMergeSort(final double[] kSrc, final long[] vSrc, 94 | final double[] kDst, final long[] vDst, 95 | int blkStart, int blkLen, int blkSize, int arrLimit) { 96 
| if (blkLen == 1) return; 97 | int blkLen1 = blkLen >> 1; 98 | int blkLen2 = blkLen - blkLen1; 99 | int blkStart1 = blkStart; 100 | int blkStart2 = blkStart + blkLen1; 101 | 102 | recursiveBlockyMergeSort(kDst, vDst, kSrc, vSrc, blkStart1, blkLen1, blkSize, arrLimit); 103 | recursiveBlockyMergeSort(kDst, vDst, kSrc, vSrc, blkStart2, blkLen2, blkSize, arrLimit); 104 | 105 | int arrStart1 = blkStart1 * blkSize; 106 | int arrStart2 = blkStart2 * blkSize; 107 | int arrLen1 = blkLen1 * blkSize; 108 | int arrLen2 = blkLen2 * blkSize; 109 | if (arrStart2 + arrLen2 > arrLimit) 110 | arrLen2 = arrLimit - arrStart2; 111 | 112 | blockyMerge(kSrc, vSrc, arrStart1, arrLen1, arrStart2, arrLen2, kDst, vDst, arrStart1); 113 | } 114 | 115 | protected static void blockyMerge(final double []kSrc, final long []vSrc, 116 | int arrStart1, int arrLen1, 117 | int arrStart2, int arrLen2, 118 | final double []kDst, final long []vDst, int arrStart3){ 119 | int arrEnd1 = arrStart1 + arrLen1; 120 | int arrEnd2 = arrStart2 + arrLen2; 121 | int i1 = arrStart1, i2 = arrStart2, i3 = arrStart3; 122 | while (i1 < arrEnd1 && i2 < arrEnd2) { 123 | if (kSrc[i1] <= kSrc[i2]){ 124 | kDst[i3] = kSrc[i1]; 125 | vDst[i3] = vSrc[i1]; 126 | ++i1; ++i3; 127 | } else { 128 | kDst[i3] = kSrc[i2]; 129 | vDst[i3] = vSrc[i2]; 130 | ++i2; ++i3; 131 | } 132 | } 133 | 134 | if (i1 < arrEnd1) { 135 | System.arraycopy(kSrc, i1, kDst, i3, arrEnd1 - i1); 136 | System.arraycopy(vSrc, i1, vDst, i3, arrEnd1 - i1); 137 | } else { 138 | System.arraycopy(kSrc, i2, kDst, i3, arrEnd2 - i2); 139 | System.arraycopy(vSrc, i2, vDst, i3, arrEnd2 - i2); 140 | } 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/sketch/quantile/QuantileSketchException.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.sketch.quantile; 2 | 3 | import 
org.dma.sketchml.sketch.base.SketchMLException; 4 | 5 | public class QuantileSketchException extends SketchMLException { 6 | public QuantileSketchException(String message) { 7 | super(message); 8 | } 9 | 10 | public QuantileSketchException(Throwable cause) { 11 | super(cause); 12 | } 13 | 14 | public QuantileSketchException(String message, Throwable cause) { 15 | super(message, cause); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/util/Maths.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.util; 2 | 3 | import org.dma.sketchml.sketch.base.SketchMLException; 4 | 5 | import java.util.Random; 6 | 7 | public class Maths { 8 | public static boolean isPowerOf2(int k) { 9 | for (int i = 1; i < 65536; i <<= 1) { 10 | if (k == i) 11 | return true; 12 | } 13 | return false; 14 | } 15 | 16 | public static int log2nlz(int k) { 17 | if (k <= 0) 18 | throw new SketchMLException("Log for " + k); 19 | else 20 | return 31 - Integer.numberOfLeadingZeros(k); 21 | } 22 | 23 | public static int max(int[] array) { 24 | int res = array[0]; 25 | for (int i = 1; i < array.length; i++) 26 | res = Math.max(res, array[i]); 27 | return res; 28 | } 29 | 30 | public static int argmax(int[] array) { 31 | int max = array[0], res = 0; 32 | for (int i = 1; i < array.length; i++) { 33 | if (array[i] > max) { 34 | max = array[i]; 35 | res = i; 36 | } 37 | } 38 | return res; 39 | } 40 | 41 | public static void shuffle(int[] array) { 42 | Random random = new Random(); 43 | for (int i = array.length - 1; i > 0; i--) { 44 | int index = random.nextInt(i + 1); 45 | int t = array[index]; 46 | array[index] = array[i]; 47 | array[i] = t; 48 | } 49 | } 50 | 51 | public static double[] unique(double[] sorted) { 52 | int size = sorted.length, cnt = 1; 53 | for (int i = 1; i < size; i++) 54 | if (sorted[i] != sorted[i - 1]) 55 | cnt++; 56 | if 
(cnt != size) { 57 | double[] res = new double[cnt]; 58 | res[0] = sorted[0]; 59 | int index = 1; 60 | for (int i = 1; i < size; i++) 61 | if (sorted[i] != sorted[i - 1]) 62 | res[index++] = sorted[i]; 63 | return res; 64 | } else { 65 | return sorted; 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/util/Sort.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.util; 2 | 3 | import it.unimi.dsi.fastutil.doubles.DoubleArrayPriorityQueue; 4 | import it.unimi.dsi.fastutil.doubles.DoubleComparator; 5 | import it.unimi.dsi.fastutil.doubles.DoublePriorityQueue; 6 | import it.unimi.dsi.fastutil.ints.IntComparator; 7 | import org.dma.sketchml.sketch.base.SketchMLException; 8 | 9 | /** 10 | * Quick sort utils 11 | */ 12 | public class Sort { 13 | public static int quickSelect(int[] array, int k, int low, int high) { 14 | if (k > 0 && k <= high - low + 1) { 15 | int pivot = array[high]; 16 | int ii = low; 17 | for (int jj = low; jj < high; jj++) { 18 | if (array[jj] <= pivot) { 19 | swap(array, ii++, jj); 20 | } 21 | } 22 | swap(array, ii, high); 23 | 24 | if (ii - low == k - 1) { 25 | return array[ii]; 26 | } else if (ii - low > k - 1) { 27 | return quickSelect(array, k, low, ii - 1); 28 | } else { 29 | return quickSelect(array, k - ii + low - 1, ii + 1, high); 30 | } 31 | } 32 | throw new SketchMLException("k is more than number of elements in array"); 33 | } 34 | 35 | public static double selectKthLargest(double[] array, int k) { 36 | return selectKthLargest(array, k, new DoubleArrayPriorityQueue(k)); 37 | } 38 | 39 | public static double selectKthLargest(double[] array, int k, DoubleComparator comp) { 40 | return selectKthLargest(array, k, new DoubleArrayPriorityQueue(k, comp)); 41 | } 42 | 43 | private static double selectKthLargest(double[] array, int k, DoublePriorityQueue queue) { 44 | if (k > 
array.length) 45 | throw new SketchMLException("k is more than number of elements in array"); 46 | 47 | int i = 0; 48 | while (i < k) 49 | queue.enqueue(array[i++]); 50 | for (; i < array.length; i++) { 51 | double top = queue.firstDouble(); 52 | if (array[i] > top) { 53 | queue.dequeueDouble(); 54 | queue.enqueue(array[i]); 55 | } 56 | } 57 | return queue.firstDouble(); 58 | } 59 | 60 | public static void quickSort(int[] array, double[] values, int low, int high) { 61 | if (low < high) { 62 | int tmp = array[low]; 63 | double tmpValue = values[low]; 64 | int ii = low, jj = high; 65 | while (ii < jj) { 66 | while (ii < jj && array[jj] >= tmp) { 67 | jj--; 68 | } 69 | 70 | array[ii] = array[jj]; 71 | values[ii] = values[jj]; 72 | 73 | while (ii < jj && array[ii] <= tmp) { 74 | ii++; 75 | } 76 | 77 | array[jj] = array[ii]; 78 | values[jj] = values[ii]; 79 | } 80 | array[ii] = tmp; 81 | values[ii] = tmpValue; 82 | 83 | quickSort(array, values, low, ii - 1); 84 | quickSort(array, values, ii + 1, high); 85 | } 86 | } 87 | 88 | public static void quickSort(long[] array, double[] values, int low, int high) { 89 | if (low < high) { 90 | long tmp = array[low]; 91 | double tmpValue = values[low]; 92 | int ii = low, jj = high; 93 | while (ii < jj) { 94 | while (ii < jj && array[jj] >= tmp) { 95 | jj--; 96 | } 97 | 98 | array[ii] = array[jj]; 99 | values[ii] = values[jj]; 100 | 101 | while (ii < jj && array[ii] <= tmp) { 102 | ii++; 103 | } 104 | 105 | array[jj] = array[ii]; 106 | values[jj] = values[ii]; 107 | } 108 | array[ii] = tmp; 109 | values[ii] = tmpValue; 110 | 111 | quickSort(array, values, low, ii - 1); 112 | quickSort(array, values, ii + 1, high); 113 | } 114 | } 115 | 116 | public static void quickSort(long[] array, int low, int high) { 117 | if (low < high) { 118 | long tmp = array[low]; 119 | int ii = low, jj = high; 120 | while (ii < jj) { 121 | while (ii < jj && array[jj] >= tmp) { 122 | jj--; 123 | } 124 | 125 | array[ii] = array[jj]; 126 | 127 | while (ii < 
jj && array[ii] <= tmp) { 128 | ii++; 129 | } 130 | 131 | array[jj] = array[ii]; 132 | } 133 | array[ii] = tmp; 134 | 135 | quickSort(array, low, ii - 1); 136 | quickSort(array, ii + 1, high); 137 | } 138 | } 139 | 140 | public static void quickSort(int[] array, int[] values, int low, int high) { 141 | if (low < high) { 142 | int tmp = array[low]; 143 | int tmpValue = values[low]; 144 | int ii = low, jj = high; 145 | while (ii < jj) { 146 | while (ii < jj && array[jj] >= tmp) { 147 | jj--; 148 | } 149 | 150 | array[ii] = array[jj]; 151 | values[ii] = values[jj]; 152 | 153 | while (ii < jj && array[ii] <= tmp) { 154 | ii++; 155 | } 156 | 157 | array[jj] = array[ii]; 158 | values[jj] = values[ii]; 159 | } 160 | array[ii] = tmp; 161 | values[ii] = tmpValue; 162 | 163 | quickSort(array, values, low, ii - 1); 164 | quickSort(array, values, ii + 1, high); 165 | } 166 | } 167 | 168 | public static void quickSort(double[] x, double[] y, int from, int to, DoubleComparator comp) { 169 | int len = to - from; 170 | if (len < 7) { 171 | selectionSort(x, y, from, to, comp); 172 | } else { 173 | int m = from + len / 2; 174 | int v; 175 | int a; 176 | int b; 177 | if (len > 7) { 178 | v = from; 179 | a = to - 1; 180 | if (len > 50) { 181 | b = len / 8; 182 | v = med3(x, from, from + b, from + 2 * b, comp); 183 | m = med3(x, m - b, m, m + b, comp); 184 | a = med3(x, a - 2 * b, a - b, a, comp); 185 | } 186 | 187 | m = med3(x, v, m, a, comp); 188 | } 189 | 190 | double seed = x[m]; 191 | a = from; 192 | b = from; 193 | int c = to - 1; 194 | int d = c; 195 | 196 | while (true) { 197 | int s; 198 | while (b > c || (s = comp.compare(x[b], seed)) > 0) { 199 | for (; c >= b && (s = comp.compare(x[c], seed)) >= 0; --c) { 200 | if (s == 0) { 201 | swap(x, c, d); 202 | swap(y, c, d); 203 | d--; 204 | } 205 | } 206 | 207 | if (b > c) { 208 | s = Math.min(a - from, b - a); 209 | vecSwap(x, from, b - s, s); 210 | vecSwap(y, from, b - s, s); 211 | s = Math.min(d - c, to - d - 1); 212 | 
vecSwap(x, b, to - s, s); 213 | vecSwap(y, b, to - s, s); 214 | if ((s = b - a) > 1) { 215 | quickSort(x, y, from, from + s, comp); 216 | } 217 | 218 | if ((s = d - c) > 1) { 219 | quickSort(x, y, to - s, to, comp); 220 | } 221 | 222 | return; 223 | } 224 | 225 | swap(x, b, c); 226 | swap(y, b, c); 227 | b++; 228 | c--; 229 | } 230 | 231 | if (s == 0) { 232 | swap(x, a, b); 233 | swap(y, a, b); 234 | a++; 235 | } 236 | 237 | ++b; 238 | } 239 | } 240 | } 241 | 242 | public static void quickSort(double[] x, double[] y, int from, int to) { 243 | DoubleComparator cmp = new DoubleComparator() { 244 | public int compare(double v, double v1) { 245 | if (Math.abs(v - v1) < 10e-12) 246 | return 0; 247 | else 248 | return v - v1 > 10e-12 ? 1 : -1; 249 | } 250 | 251 | public int compare(Double o1, Double o2) { 252 | if (Math.abs(o1 - o2) < 10e-12) 253 | return 0; 254 | else 255 | return o1 - o2 > 10e-12 ? 1 : -1; 256 | } 257 | }; 258 | quickSort(x, y, from, to, cmp); 259 | } 260 | 261 | public static void quickSort(int[] array, float[] values, int low, int high) { 262 | if (low < high) { 263 | int tmp = array[low]; 264 | float tmpValue = values[low]; 265 | int ii = low, jj = high; 266 | while (ii < jj) { 267 | while (ii < jj && array[jj] >= tmp) { 268 | jj--; 269 | } 270 | 271 | array[ii] = array[jj]; 272 | values[ii] = values[jj]; 273 | 274 | while (ii < jj && array[ii] <= tmp) { 275 | ii++; 276 | } 277 | 278 | array[jj] = array[ii]; 279 | values[jj] = values[ii]; 280 | } 281 | array[ii] = tmp; 282 | values[ii] = tmpValue; 283 | 284 | quickSort(array, values, low, ii - 1); 285 | quickSort(array, values, ii + 1, high); 286 | } 287 | } 288 | 289 | 290 | private static int med3(double[] x, int a, int b, int c, DoubleComparator comp) { 291 | int ab = comp.compare(x[a], x[b]); 292 | int ac = comp.compare(x[a], x[c]); 293 | int bc = comp.compare(x[b], x[c]); 294 | return ab < 0 ? (bc < 0 ? b : (ac < 0 ? c : a)) : (bc > 0 ? b : (ac > 0 ? 
c : a)); 295 | } 296 | 297 | private static void vecSwap(double[] x, int a, int b, int n) { 298 | for (int i = 0; i < n; ++b) { 299 | swap(x, a, b); 300 | ++i; 301 | ++a; 302 | } 303 | 304 | } 305 | 306 | private static void swap(int[] x, int a, int b) { 307 | int t = x[a]; 308 | x[a] = x[b]; 309 | x[b] = t; 310 | } 311 | 312 | private static void swap(double[] x, int a, int b) { 313 | double t = x[a]; 314 | x[a] = x[b]; 315 | x[b] = t; 316 | } 317 | 318 | public static void selectionSort(int[] a, int[] y, int from, int to, IntComparator comp) { 319 | for (int i = from; i < to - 1; ++i) { 320 | int m = i; 321 | 322 | int u; 323 | for (u = i + 1; u < to; ++u) { 324 | if (comp.compare(a[u], a[m]) < 0) { 325 | m = u; 326 | } 327 | } 328 | 329 | if (m != i) { 330 | u = a[i]; 331 | a[i] = a[m]; 332 | a[m] = u; 333 | u = y[i]; 334 | y[i] = y[m]; 335 | y[m] = u; 336 | } 337 | } 338 | 339 | } 340 | 341 | public static void selectionSort(double[] a, double[] y, int from, int to, 342 | DoubleComparator comp) { 343 | for (int i = from; i < to - 1; ++i) { 344 | int m = i; 345 | for (int u = i + 1; u < to; ++u) { 346 | if (comp.compare(a[u], a[m]) < 0) { 347 | m = u; 348 | } 349 | } 350 | 351 | if (m != i) { 352 | double temp = a[i]; 353 | a[i] = a[m]; 354 | a[m] = temp; 355 | temp = y[i]; 356 | y[i] = y[m]; 357 | y[m] = temp; 358 | } 359 | } 360 | } 361 | 362 | public static void merge(int[][] as, int[][] ys, int[] a, int[] y) { 363 | int[] ks = new int[as.length]; 364 | int cur = 0; 365 | while (cur < a.length) { 366 | int argmin = -1; 367 | int min = Integer.MAX_VALUE; 368 | for (int i = 0; i < ks.length; i++) { 369 | if (ks[i] < as[i].length && as[i][ks[i]] < min) { 370 | argmin = i; 371 | min = as[i][ks[i]]; 372 | } 373 | } 374 | a[cur] = as[argmin][ks[argmin]]; 375 | y[cur] = ys[argmin][ks[argmin]]; 376 | ks[argmin]++; 377 | cur++; 378 | } 379 | } 380 | 381 | public static void merge(int[][] as, double[][] ys, int[] a, double[] y) { 382 | int[] ks = new int[as.length]; 
383 | int cur = 0; 384 | while (cur < a.length) { 385 | int argmin = -1; 386 | int min = Integer.MAX_VALUE; 387 | for (int i = 0; i < ks.length; i++) { 388 | if (ks[i] < as[i].length && as[i][ks[i]] < min) 389 | argmin = i; 390 | } 391 | a[cur] = as[argmin][ks[argmin]]; 392 | y[cur] = ys[argmin][ks[argmin]]; 393 | ks[argmin]++; 394 | cur++; 395 | } 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /sketch/src/main/java/org/dma/sketchml/sketch/util/Utils.java: -------------------------------------------------------------------------------- 1 | package org.dma.sketchml.sketch.util; 2 | 3 | import java.io.*; 4 | 5 | public class Utils { 6 | public static int sizeof(Object obj) throws IOException { 7 | ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream(); 8 | ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteOutputStream); 9 | 10 | objectOutputStream.writeObject(obj); 11 | objectOutputStream.flush(); 12 | objectOutputStream.close(); 13 | 14 | return byteOutputStream.toByteArray().length; 15 | } 16 | 17 | public static Serializable testSerialization(Serializable obj) throws IOException, ClassNotFoundException { 18 | ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream(); 19 | ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteOutputStream); 20 | objectOutputStream.writeObject(obj); 21 | objectOutputStream.flush(); 22 | objectOutputStream.close(); 23 | 24 | ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteOutputStream.toByteArray()); 25 | ObjectInputStream inputStream = new ObjectInputStream(byteArrayInputStream); 26 | return (Serializable) inputStream.readObject(); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /sketch/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=INFO, STDOUT 2 | 
log4j.logger.deng=INFO 3 | log4j.appender.STDOUT=org.apache.log4j.ConsoleAppender 4 | log4j.appender.STDOUT.layout=org.apache.log4j.PatternLayout 5 | log4j.appender.STDOUT.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %C %p - %m%n --------------------------------------------------------------------------------