/** A hash function from values of type `T` to a 32-bit integer. */
trait HashFunc[@scala.specialized(Int, Long) T] extends Serializable {
  def apply(x: T): Int
}

object HashFunc {

  /** Builds a seeded multiply-add-shift hash function over `Int` keys.
    *
    * The function computes `(a * x + b) >>> (64 - randomBits)` in 64-bit
    * arithmetic and truncates the result to `Int`.
    *
    * @param seed       seed of the generator used to draw `a` and `b`; the same
    *                   seed always produces the same function
    * @param randomBits number of high bits of the 64-bit product that are kept
    * @return a serializable hash function
    */
  def random(seed: Int, randomBits: Int = 32): HashFunc[Int] = {
    val rng = new scala.util.Random(seed)
    val shift = 64 - randomBits

    // Draw order matters for reproducibility: multiplier first, then offset.
    // Multiplier: random odd positive integer.
    val multiplier: Long = (rng.nextLong() & ((1L << 62) - 1)) * 2 + 1
    // Offset: random non-negative integer masked to the low `shift` bits.
    val offset: Long = math.abs(rng.nextLong() & ((1L << shift) - 1))

    new HashFunc[Int] {
      def apply(x: Int): Int = ((multiplier * x + offset) >>> shift).toInt

      override def toString = s"HashFunc: f(x) = (${multiplier}L * x + ${offset}L) >>> $shift"
    }
  }
}

/** A hash function from values of type `T` to a 64-bit integer. */
trait HashFuncLong[T] extends Serializable {
  def apply(x: T): Long
}
extends App { 8 | 9 | if (args.length < 1) sys.exit() 10 | 11 | val directory = new File(args(0)) 12 | 13 | if (!directory.exists) sys.exit() 14 | 15 | val files = directory 16 | .listFiles 17 | .filter(f => f.getName.toLowerCase.matches(".*\\.(jpg|png)$")) 18 | 19 | println("PHashing") 20 | 21 | val phash = new ImagePHash(32, 8) // PHash produces 64 bit hash (8x8) 22 | 23 | val hashArray = files.par 24 | .map(f => phash(new FileInputStream(f))) 25 | .toArray 26 | 27 | println("LSH construction") 28 | 29 | val sketch = HammingDistance(hashArray, 64) 30 | val (_, bands) = LSH.pickHashesAndBands(0.85, 64) 31 | val lsh = LSH.estimating(sketch, LSHBuildCfg(bands = bands)) 32 | 33 | println("LSH query") 34 | 35 | for ((idx, sims) <- lsh.allSimilarItems(LSHCfg(maxResults = 1, threshold = 0.85))) { 36 | for (sim <- sims) { 37 | println(files(idx)+" "+files(sim.idx)+" "+sim.sim) 38 | } 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Karel Čížek (kaja47@k47.cz) 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * The name of the copyright holder may not be used to endorse or promote 12 | products derived from this software without specific prior written 13 | permission. 
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL KAREL ČÍŽEK BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Sketches 2 | 3 | *Sketches* is a library for sketching, locality sensitive hashing, 4 | approximate similarity search and other things. 5 | 6 | ### Usage 7 | 8 | The basic use case is searching for all similar items in a dataset. 9 | 10 | ```scala 11 | import atrox.sketch._ 12 | 13 | val sets: IndexedSeq[Set[Int]] = loadMyData() 14 | 15 | val (bands, hashes) = LSH.pickHashesAndBands(threshold = 0.5, maxHashes = 64) 16 | val lsh = LSH(sets, MinHash[Set[Int]](hashes), LSHBuildCfg(bands = bands)) 17 | 18 | val cfg = LSHCfg(maxResults = 50) 19 | 20 | for (Sim(idx1, idx2, estimate, similarity) <- lsh.withConfig(cfg).allSimilarItems) { 21 | println(s"similarity between item $idx1 and $idx2 is estimated to $estimate") 22 | } 23 | ``` 24 | 25 | There are more configuration options available. 26 | 27 | ```scala 28 | lsh.withConfig(LSHCfg( 29 | // Return only the 100 most relevant results. 30 | // It's strongly recommended to use this option. 31 | maxResults = 100, 32 | 33 | // Perform similarity search in parallel. 
import breeze.linalg.{ SparseVector, DenseVector, BitVector }

/** Term-frequency / document-frequency helpers over sparse vectors. */
object crap {

  /** Vector length shared by the corpus, taken from the first vector; 0 for an empty corpus. */
  private def dimension(xs: Seq[SparseVector[Double]]): Int =
    xs.headOption.fold(0)(_.size)

  /** Element-wise sum of all vectors.
    *
    * The result has the length of the first vector. An empty `xs` yields a
    * zero-length vector instead of throwing on `xs.head`.
    */
  def sum(xs: Seq[SparseVector[Double]]): DenseVector[Double] = {
    val s = DenseVector.zeros[Double](dimension(xs))
    for (x <- xs) {
      s += x
    }
    s
  }

  /** Document frequency: for every index, the number of vectors that have a
    * stored (active) entry at that index.
    *
    * An empty `xs` yields a zero-length vector instead of throwing on `xs.head`.
    */
  def df(xs: Seq[SparseVector[Double]]): DenseVector[Double] = {
    val s = DenseVector.zeros[Double](dimension(xs))
    for (vec <- xs) {
      // Walk only the stored entries of the sparse vector.
      var offset = 0
      while (offset < vec.activeSize) {
        val i = vec.indexAt(offset)
        s(i) += 1
        offset += 1
      }
    }
    s
  }

  /** Boolean term frequency: every active value becomes 1.0. */
  def tfBoolean(fs: Seq[SparseVector[Double]]): Seq[SparseVector[Double]] =
    for (vec <- fs) yield vec mapActiveValues { _ => 1.0 }

  /** Logarithmic term frequency: every active value `f` becomes `1 + ln(f)`. */
  def tfLog(fs: Seq[SparseVector[Double]]): Seq[SparseVector[Double]] =
    for (vec <- fs) yield vec mapActiveValues { f => 1.0 + math.log(f) }

  /** Augmented term frequency: every active value `f` becomes
    * `0.5 + 0.5 * f / max`, where `max` is the maximum of the vector.
    */
  def tfAugmented(fs: Seq[SparseVector[Double]]): Seq[SparseVector[Double]] =
    for (vec <- fs) yield {
      val m = breeze.linalg.max(vec)
      vec mapActiveValues { f => 0.5 + (0.5 * f) / m }
    }

  /** TF-IDF weighting using document frequencies computed from `tfs` itself.
    * Returns an empty sequence for an empty corpus.
    */
  def tfidf(tfs: Seq[SparseVector[Double]]): Seq[SparseVector[Double]] =
    tfidf(tfs, df(tfs))

  /** TF-IDF weighting: every active pair `(idx, tf)` becomes
    * `tf * ln(N / df(idx))`, where `N` is the number of vectors in `tfs`.
    */
  def tfidf(tfs: Seq[SparseVector[Double]], df: DenseVector[Double]): Seq[SparseVector[Double]] = {
    val N = tfs.size
    for (vec <- tfs) yield {
      // `df(idx)` is a Double, so the division is floating point.
      vec mapActivePairs { case (idx, tf) => tf * math.log(N / df(idx)) }
    }
  }

}

/** Xorshift pseudo-random generator with four `Int` words of state
  * (shift constants 11, 19 and 8).
  *
  * The state words are plain mutable fields with no synchronization. The
  * default for `x` is derived from the current time, so the no-argument
  * constructor is not reproducible; pass all four words for a fixed sequence.
  */
class Xorshift(var x: Int = System.currentTimeMillis.toInt, var y: Int = 4711, var z: Int = 5485612, var w: Int = 992121) {

  /** Advances the state by one step and returns the new `w` word. */
  def nextInt(): Int = {
    val t = x ^ (x << 11)
    x = y
    y = z
    z = w
    w = w ^ (w >>> 19) ^ t ^ (t >>> 8)
    w
  }
}
import scala.specialized
import scala.reflect.ClassTag


/** Factory methods and result type for the DBSCAN clustering below. */
object DBSCAN {

  /** Clusters `dataset` using a linear-scan region query over `dist`.
    *
    * @param eps    neighbourhood radius (inclusive: `dist <= eps`)
    * @param minPts minimum neighbourhood size (the point itself included) for a core point
    */
  def apply[Point: ClassTag](dataset: IndexedSeq[Point], eps: Double, minPts: Int, dist: (Point, Point) => Double): Result[Point] =
    new DBSCAN(dataset.toArray, eps, minPts, dist).run

  /** Same as above for an array dataset (used as is, not copied). */
  def apply[Point: ClassTag](dataset: Array[Point], eps: Double, minPts: Int, dist: (Point, Point) => Double): Result[Point] =
    new DBSCAN(dataset, eps, minPts, dist).run

  /** Clusters `dataset` with a caller-supplied region query instead of a
    * distance function. `regionQueryFunc(i)` must return the indexes of all
    * points in the neighbourhood of point `i` (including `i`), without duplicates.
    */
  def apply[Point: ClassTag](dataset: Array[Point], minPts: Int, regionQueryFunc: Int => IndexedSeq[Int]): Result[Point] =
    // `eps` and `dist` are never consulted because `regionQuery` is overridden.
    new DBSCAN(dataset, -1, minPts, (a: Point, b: Point) => ???) {
      override def regionQuery(pIdx: Int): IndexedSeq[Int] = regionQueryFunc(pIdx)
    }.run

  /** @param clusters one map per cluster, from dataset index to point
    * @param noise    points that belong to no cluster, from dataset index to point
    */
  case class Result[Point](clusters: IndexedSeq[Map[Int, Point]], noise: Map[Int, Point])
}


/** Density-based clustering of `dataset`.
  *
  * The instance keeps its labelling in `pointClusters`, so `run` is meant to
  * be called once per instance.
  */
class DBSCAN[Point: ClassTag](dataset: Array[Point], val eps: Double, val minPts: Int, val dist: (Point, Point) => Double) {

  val NotVisited = -1
  val Noise = -2

  // Per point: NotVisited, Noise, or a cluster id (>= 0).
  val pointClusters = Array.fill(dataset.size)(NotVisited)

  /** Labels every point and returns the clusters together with the noise points. */
  def run: DBSCAN.Result[Point] = {
    var currentClusterId = 0
    for (pIdx <- 0 until dataset.size) {
      if (pointClusters(pIdx) == NotVisited) {
        val neighborPts = regionQuery(pIdx)
        if (neighborPts.size < minPts) {
          pointClusters(pIdx) = Noise
        } else {
          expandCluster(pIdx, neighborPts, currentClusterId)
          // A new id is needed only after a cluster has actually been created.
          // Advancing it in the noise branch instead would merge consecutive
          // clusters that have no noise point between them.
          currentClusterId += 1
        }
      }
    }

    assert(pointClusters forall (_ != NotVisited))

    // Strict map (not a lazy `mapValues` view) from label to its indexed points.
    val grouped: Map[Int, Map[Int, Point]] =
      (0 until pointClusters.length)
        .groupBy(idx => pointClusters(idx))
        .map { case (label, idxs) => label -> idxs.map { idx => (idx, dataset(idx)) }.toMap }

    val clusters = (grouped - Noise).values.toVector
    // A dataset without noise has no `Noise` key at all.
    val noise = grouped.getOrElse(Noise, Map.empty[Int, Point])
    DBSCAN.Result(clusters, noise)
  }

  /** Grows cluster `clusterId` from core point `pIdx` by repeatedly absorbing
    * the neighbourhoods of every core point reached.
    */
  def expandCluster(pIdx: Int, _neighborPts: IndexedSeq[Int], clusterId: Int) = {
    val neighborPts = scala.collection.mutable.Set[Int](_neighborPts: _*)
    pointClusters(pIdx) = clusterId

    while (!neighborPts.isEmpty) {
      val ppIdx = neighborPts.head
      neighborPts.remove(ppIdx)
      if (pointClusters(ppIdx) == NotVisited) {
        val newNeighborPts = regionQuery(ppIdx)
        if (newNeighborPts.size >= minPts) {
          // Only unvisited points need to be queued; visited ones are either
          // already clustered or were handled as noise.
          newNeighborPts foreach { idx =>
            if (pointClusters(idx) == NotVisited) {
              neighborPts += idx
            }
          }
        }
      }
      if (pointClusters(ppIdx) < 0) { // P' is not yet member of any cluster
        pointClusters(ppIdx) = clusterId
      }
    }
  }

  /** Return all points within P's eps-neighborhood (including P).
    * If this method is overridden, it must not return duplicate points. */
  def regionQuery(pIdx: Int): IndexedSeq[Int] = {
    val res = new collection.immutable.VectorBuilder[Int]
    val p = dataset(pIdx)

    var ppIdx = 0
    while (ppIdx < dataset.length) {
      if (dist(p, dataset(ppIdx)) <= eps) {
        res += ppIdx
      }
      ppIdx += 1
    }

    res.result
  }

}
*/ 24 | def derank(r: Int): Double 25 | } 26 | 27 | 28 | trait SimRank[@specialized(Long) S] extends Rank[S, S] { 29 | def apply(a: S, b: S): Double 30 | 31 | def rank(a: S, b: S): Int = rank(apply(a, b)) 32 | def rank(d: Double): Int = Bits.floatToSortableInt(d.toFloat) 33 | def derank(r: Int): Double = Bits.sortableIntToFloat(r) 34 | } 35 | 36 | case class SimFun[S](f: (S, S) => Double, dataset: IndexedSeq[S]) extends SimRank[S] { 37 | def map(q: S): S = q 38 | def map(idx: Int): S = dataset(idx) 39 | 40 | def apply(a: S, b: S) = f(a, b) 41 | } 42 | 43 | 44 | trait DistRank[@specialized(Long) S] extends Rank[S, S] { 45 | def apply(a: S, b: S): Double 46 | 47 | def rank(a: S, b: S): Int = rank(apply(a, b)) 48 | def rank(d: Double): Int = ~Bits.floatToSortableInt(d.toFloat) 49 | def derank(r: Int): Double = Bits.sortableIntToFloat(~r) 50 | } 51 | 52 | case class DistFun[@specialized(Long) S](f: (S, S) => Double, dataset: IndexedSeq[S]) extends DistRank[S] { 53 | def map(q: S): S = q 54 | def map(idx: Int): S = dataset(idx) 55 | 56 | def apply(a: S, b: S) = f(a, b) 57 | } 58 | 59 | case class SketchRank[Q, SketchArray](sk: Sketch[Q, SketchArray]) extends Rank[Q, (SketchArray, Int)] { 60 | 61 | type S = (SketchArray, Int) 62 | def es = sk.estimator 63 | 64 | def map(q: Q): S = (sk.sketchers.getSketchFragment(q), 0) 65 | def map(idx: Int): S = (sk.sketchArray, idx) 66 | 67 | def rank(a: S, b: S): Int = { 68 | val (skarra, idxa) = a 69 | val (skarrb, idxb) = b 70 | es.sameBits(skarra, idxa, skarrb, idxb) 71 | } 72 | override def rank(a: S, b: Int): Int = { 73 | val (skarra, idxa) = a 74 | es.sameBits(skarra, idxa, sk.sketchArray, b) 75 | } 76 | override def rank(a: Int, b: Int): Int = es.sameBits(sk.sketchArray, a, sk.sketchArray, b) 77 | 78 | def rank(d: Double): Int = es.minSameBits(d) 79 | 80 | /** Recover similarity/distance encoded in integer rank value. 
import scala.collection.mutable
import atrox.sort.RadixSort
import atrox.Bits
import java.util.Arrays

/** Collects `(key, value)` pairs of ints and exposes the values grouped by key.
  * Keys passed to `mkIterator`-based implementations lie in `[0, keyRange)`.
  */
abstract class Grouping(keyRange: Int) {
  /** Adds value `v` to the group of key `k`. */
  def add(k: Int, v: Int): Unit
  /** Iterates over the non-empty groups as `(key, values)`. */
  def getAll: Iterator[(Int, Array[Int])]
  /** Groups indexed by key; keys without values map to `null`. */
  def toArray: Array[Array[Int]]

  protected def mkIterator = Iterator.range(0, keyRange)
}

object Grouping {

  /** Picks an implementation: `Sorted` when the number of values is known,
    * `Counted` when per-key counts are known, `Buffered` otherwise.
    */
  def apply(keyRange: Int, numberOfValues: Int = Int.MaxValue, counts: Array[Int] = null) =
    if (numberOfValues < Int.MaxValue && numberOfValues > 0) {
      new Sorted(keyRange, numberOfValues)
    } else if (counts != null) {
      new Counted(keyRange, counts)
    } else {
      new Buffered(keyRange)
    }


  /** One lazily created growable buffer per key. */
  class Buffered(keyRange: Int) extends Grouping(keyRange) {
    val arr = new Array[mutable.ArrayBuilder.ofInt](keyRange)
    def add(k: Int, v: Int): Unit = {
      if (arr(k) == null) {
        arr(k) = new mutable.ArrayBuilder.ofInt
      }
      arr(k) += v
    }

    def getAll = mkIterator collect { case i if arr(i) != null => (i, arr(i).result()) }
    def toArray = arr map (x => if (x != null) x.result() else null)
  }


  /** Exactly sized arrays preallocated from `counts`; keys with count 0 have no array. */
  class Counted(keyRange: Int, counts: Array[Int]) extends Grouping(keyRange) {
    private val positions = new Array[Int](keyRange)
    private val arr = new Array[Array[Int]](keyRange)
    for (i <- 0 until keyRange) { if (counts(i) > 0) arr(i) = new Array[Int](counts(i)) }

    def add(k: Int, v: Int): Unit = {
      arr(k)(positions(k)) = v
      positions(k) += 1
    }

    def getAll = mkIterator collect { case i if arr(i) != null => (i, arr(i)) }
    def toArray = arr
  }


  /** Stores every pair packed into one `Long` and groups them by sorting. */
  class Sorted(keyRange: Int, numberOfValues: Int) extends Grouping(keyRange) {
    var head = 0
    val arr = new Array[Long](numberOfValues)

    def add(k: Int, v: Int): Unit = {
      arr(head) = Bits.pack(k, v)
      head += 1
    }

    private def getKey(l: Long) = Bits.unpackIntHi(l)
    private def getVal(l: Long) = Bits.unpackIntLo(l)

    def getAll = {
      // Unused slots are filled with key Int.MaxValue in the high word.
      // NOTE(review): presumably so they sort behind every real key - confirm against RadixSort.
      while (head < numberOfValues) { arr(head) = Int.MaxValue.toLong << 32 ; head += 1 }
      val scratch = new Array[Long](numberOfValues)
      val (sorted, _) = RadixSort.sort(arr, scratch, 0, numberOfValues, 4, 8, false)

      // `i` is shared by the lazily evaluated iterator below and only moves forward.
      var i = 0
      mkIterator map { key =>
        while (i < numberOfValues && getKey(sorted(i)) < key) { i += 1 }
        var start = i

        while (i < numberOfValues && getKey(sorted(i)) == key) { i += 1 }
        val end = i

        if (start == end) {
          null

        } else {
          val res = new Array[Int](end - start)
          var j = 0
          while (start < end) {
            res(j) = getVal(sorted(start))
            start += 1
            j += 1
          }

          (key, res)
        }
      } filter (_ != null )
    }

    /** Groups indexed by key, `null` for keys without values (same shape as
      * `Buffered.toArray`). Previously unimplemented.
      */
    def toArray = {
      val res = new Array[Array[Int]](keyRange)
      for ((key, values) <- getAll) res(key) = values
      res
    }
  }


  /** Hash-map backed grouping. */
  class Mapped(keyRange: Int) extends Grouping(keyRange) {
    private val map = mutable.Map[Int, mutable.ArrayBuilder.ofInt]()
    def add(k: Int, v: Int): Unit = {
      map.getOrElseUpdate(k, new mutable.ArrayBuilder.ofInt) += v
    }
    def getAll = for ((i, a) <- map.iterator) yield (i, a.result())

    /** Array of length `maxKey + 1`; an empty grouping yields an empty array
      * instead of failing on `max` of an empty key set.
      */
    def toArray =
      if (map.isEmpty) new Array[Array[Int]](0)
      else Array.tabulate(map.keys.max + 1) { i => map.get(i).map(_.result()).orNull }
  }

}
*/ 8 | sealed class StringIndex(initialCapacity: Int = 1024) extends Serializable { 9 | class SID(initialCapacity: Int) extends StringIntDictionary(initialCapacity = initialCapacity) { 10 | var max = 0 11 | var arrTop = 0 12 | // either (offset, length) or (inlined string, length) pairs, packed in same way as the `assoc` array 13 | var arr = new Array[Int](initialCapacity) 14 | 15 | def _index(str: CharSequence) = { 16 | val idx = getOrElseUpdate(str, max) 17 | if (idx == max) { // new string inserted 18 | if ((arrTop+2) >= arr.length) growArr() 19 | 20 | val word = tryInline(str) 21 | if (word != 0xffffffffffffffffL) { 22 | setInlinedWord(arr, arrTop, word) 23 | setInlined(arr, arrTop) 24 | setInlinedStringLength(arr, arrTop, str.length) 25 | } else { 26 | val pos = findPos(str, word) 27 | val offset = stringOffset(assoc, pos) 28 | setStringOffset(arr, arrTop, offset) 29 | setPackedStringLength(arr, arrTop, str.length) 30 | } 31 | 32 | arrTop += 2 33 | max += 1 34 | } 35 | idx 36 | } 37 | 38 | def _get(pos: Int) = makeString(arr, pos * 2) 39 | 40 | def growArr() = { 41 | val newArr = new Array[Int](arr.length * 2) 42 | System.arraycopy(arr, 0, newArr, 0, arr.length) 43 | arr = newArr 44 | } 45 | } 46 | 47 | private val sid = new SID(initialCapacity) 48 | 49 | /** Returns an integer index for the given string, adding it to the 50 | * index if it is not already present. */ 51 | def index(str: CharSequence): Int = sid._index(str) 52 | 53 | /** Returns the int id of the given element (0-based) or -1 if not 54 | * found in the index. This method never changes the index. */ 55 | def apply(str: CharSequence): Int = sid.getOrDefault(str, -1) 56 | 57 | /** Returns a string at the given position or throws 58 | * IndexOutOfBoundsException if it's not found. */ 59 | def get(pos: Int): String = sid._get(pos) 60 | 61 | /** Number of elements in this index. */ 62 | def size = sid.size 63 | 64 | /** Returns true if this index contains the string t. 
/** StringIndex whose operations are guarded by a read/write lock: lookups
  * take the read lock, insertion takes the write lock.
  */
final class ConcurrentStringIndex(initialCapacity: Int = 1024) extends StringIndex(initialCapacity) {

  private val rwlock = new java.util.concurrent.locks.ReentrantReadWriteLock
  private val rlock = rwlock.readLock
  private val wlock = rwlock.writeLock

  /** Evaluates `body` while holding the read lock. */
  private def reading[A](body: => A): A = {
    rlock.lock()
    try body
    finally rlock.unlock()
  }

  /** Returns an integer index for the given string, adding it to the
    * index if it is not already present. */
  override def index(str: CharSequence): Int = {
    // Cheap path first: a plain lookup under the read lock.
    val known = reading(super.apply(str))
    if (known == -1) {
      // Not found: insert under the write lock. `super.index` performs its own
      // lookup, so a string added by another thread in between is not duplicated.
      wlock.lock()
      try { super.index(str) }
      finally { wlock.unlock() }
    } else {
      known
    }
  }

  /** Returns the int id of the given element (0-based) or -1 if not
    * found in the index. This method never changes the index. */
  override def apply(str: CharSequence): Int = reading(super.apply(str))

  /** Returns a string at the given position or throws
    * IndexOutOfBoundsException if it's not found. */
  override def get(pos: Int): String = reading(super.get(pos))

  /** Number of elements in this index. */
  override def size = reading(super.size)

  /** Returns true if this index contains the string t. */
  override def contains(str: CharSequence) = reading(super.contains(str))
}
sort0(arr, from, to-from, depth)((RadixElement.Mapped(f))) 46 | 47 | 48 | 49 | private def sort0[S](arr: Array[S], base: Int, len: Int, depth: Int)(re: RadixElement[S]): Unit = { 50 | 51 | // |---equal---|---lt---|---not yet partitioned---|---gt---|---equal---| 52 | // aEq a b bEq 53 | 54 | if (len <= 1) 55 | return 56 | 57 | // insertion sort 58 | // if (len < 8) { 59 | // val start = base 60 | // val end = base + len 61 | // 62 | // var i = start + 1 63 | // while (i < end) { 64 | // val item = arr(i) 65 | // var hole = i 66 | // while (hole > start && arr(hole - 1) > item) { 67 | // arr(hole) = arr(hole - 1) 68 | // hole -= 1 69 | // } 70 | // arr(hole) = item 71 | // i += 1 72 | // } 73 | // 74 | // return 75 | // } 76 | 77 | var r = ThreadLocalRandom.current().nextInt(len) 78 | 79 | swap(arr, base, base+r) 80 | val pivot = re.byteAt(arr(base), depth) 81 | 82 | var aEq = base 83 | var a = base 84 | var b = base + len - 1 85 | var bEq = base + len - 1 86 | 87 | do { 88 | while (a <= b && re.byteAt(arr(a), depth) <= pivot) { 89 | if (re.byteAt(arr(a), depth) == pivot) { 90 | swap(arr, aEq, a) 91 | aEq += 1 92 | } 93 | a += 1 94 | } 95 | 96 | while (a <= b && re.byteAt(arr(b), depth) >= pivot) { 97 | if (re.byteAt(arr(b), depth) == pivot) { 98 | swap(arr, bEq, b) 99 | bEq -= 1 100 | } 101 | b -= 1 102 | } 103 | 104 | if (a <= b) { // if (a > b) break 105 | swap(arr, a, b) 106 | 107 | a += 1 108 | b -= 1 109 | } 110 | } while (a <= b) 111 | 112 | val aEqLen = math.min(aEq-base, a-aEq) 113 | vecswap(base, a-aEqLen, aEqLen, arr); 114 | 115 | val bEqLen = math.min(bEq-b, base+len-bEq-1) 116 | vecswap(a, base+len-bEqLen, bEqLen, arr); 117 | 118 | r = a-aEq 119 | sort0(arr, base, r, depth)(re) 120 | 121 | if (re.byteAt(arr(base+r), depth) != -1) { 122 | sort0(arr, base + r, aEq + len-bEq-1, depth+1)(re) 123 | } 124 | 125 | r = bEq-b 126 | sort0(arr, base + len-r, r, depth)(re) 127 | } 128 | 129 | private def swap[S](arr: Array[S], a: Int, b: Int) = { 130 | val x = 
import scala.language.postfixOps
import collection.mutable
import VPTree._


/** VP-tree (vantage point tree) is a datastructure that can be used for
  * nearest-neighbour queries in arbitrary metric space.
  *
  * TODO http://boytsov.info/pubs/nips2013.pdf
  */
final class VPTree[T](val root: Tree[T], val distance: Distance[T]) {
  /** Follows a single root-to-leaf path and returns the closest point of that leaf. */
  def approximateNearest(t: T): T = root.approxNear(t, distance)

  /** Up to `n` points gathered from the subtrees on `t`'s side of each split first. */
  def approximateNearestN(t: T, n: Int): IndexedSeq[T] = root.approxNearN(t, n, distance)

  /** All points whose distance to `t` is at most `maxDist`. */
  def nearest(t: T, maxDist: Double): IndexedSeq[T] = root.nearN(t, maxDist, distance)
}


object VPTree {
  type Distance[T] = (T, T) => Double

  /** Main constructor of VP-trees */
  def apply[T](items: IndexedSeq[T], distance: Distance[T], leafSize: Int): VPTree[T] =
    new VPTree(mkNode(items, distance, leafSize), distance)

  sealed trait Tree[T] {
    def size: Int
    def toSeq: IndexedSeq[T]
    def approxNear(t: T, f: Distance[T]): T
    def approxNearN(t: T, n: Int, f: Distance[T]): IndexedSeq[T]
    def nearN(t: T, maxDist: Double, f: Distance[T]): IndexedSeq[T]
  }

  /** Inner node: `in` holds points closer than `radius` to `point`, `out` the rest. */
  final case class Node[T](point: T, radius: Double, size: Int, in: Tree[T], out: Tree[T]) extends Tree[T] {
    def toSeq = in.toSeq ++ out.toSeq

    def approxNear(t: T, f: Distance[T]): T = {
      val side = if (f(point, t) < radius) in else out
      side.approxNear(t, f)
    }

    def approxNearN(t: T, n: Int, f: Distance[T]): IndexedSeq[T] =
      if (n <= 0) IndexedSeq()
      else if (n > size) toSeq
      else {
        // Search the side containing `t` first, then top up from the other side.
        val (first, second) = if (f(point, t) < radius) (in, out) else (out, in)
        first.approxNearN(t, n, f) ++ second.approxNearN(t, n - first.size, f)
      }

    def nearN(t: T, maxDist: Double, f: Distance[T]): IndexedSeq[T] =
      f(t, point) match {
        // The query ball lies entirely inside / outside the vantage ball.
        // NOTE(review): this pruning is only exact if `f` obeys the triangle inequality.
        case d if d + maxDist < radius  => in.nearN(t, maxDist, f)
        case d if d - maxDist >= radius => out.nearN(t, maxDist, f)
        case _                          => in.nearN(t, maxDist, f) ++ out.nearN(t, maxDist, f)
      }
  }

  /** Leaf: a flat bucket of points that is scanned linearly. */
  final case class Leaf[T](points: IndexedSeq[T]) extends Tree[T] {
    def size = points.length
    def toSeq = points

    def approxNear(t: T, f: Distance[T]): T = points.minBy(p => f(t, p))

    def approxNearN(t: T, n: Int, f: Distance[T]): IndexedSeq[T] =
      if (n <= 0) IndexedSeq()
      else if (n < size) points.sortBy(p => f(p, t)).take(n)
      else points

    def nearN(t: T, maxDist: Double, f: Distance[T]): IndexedSeq[T] =
      points.filter(p => f(t, p) <= maxDist)
  }

  /** Recursively builds a tree; buckets of at most `leafSize` items become leaves. */
  def mkNode[T](items: IndexedSeq[T], f: Distance[T], leafSize: Int): Tree[T] =
    if (items.length <= leafSize) Leaf(items)
    else {
      // Random vantage point.
      val vp = items(util.Random.nextInt(items.length))

      // Radius is the median distance from the vantage point to a random sample.
      val sampleSize = math.sqrt(items.length).toInt * 2
      val sampled = pickSample(items, sampleSize).map(item => f(vp, item)).toArray
      java.util.Arrays.sort(sampled)
      val radius = sampled(sampled.length / 2)

      items.partition(item => f(item, vp) < radius) match {
        // A degenerate split cannot make progress, so the bucket stays a leaf.
        case (in, out) if in.length == 0  => Leaf(out)
        case (in, out) if out.length == 0 => Leaf(in)
        case (in, out) =>
          Node(vp, radius, items.length, mkNode(in, f, leafSize), mkNode(out, f, leafSize))
      }
    }

  /** `size` items drawn at random with replacement, or all items if there are at most `size`. */
  def pickSample[T](items: IndexedSeq[T], size: Int): IndexedSeq[T] =
    if (items.length > size) IndexedSeq.fill(size)(items(util.Random.nextInt(items.length)))
    else items

  /** Pre-order list of `(in.size, out.size)` for every inner node. */
  def balance[T](t: Tree[T]): List[(Int, Int)] = t match {
    case Node(_, _, _, in, out) => (in.size, out.size) :: (balance(in) ::: balance(out))
    case Leaf(_)                => Nil
  }

  /** Indented multi-line rendering of the tree. */
  def prettyPrint[T](n: Tree[T], offset: Int = 0): String = {
    val pad = " " * offset
    n match {
      case Leaf(points) =>
        s"${pad}Leaf(${points.mkString(",")})\n"
      case node: Node[_] =>
        s"${pad}Node(point = ${node.point}, radius = ${node.radius}\n" +
          prettyPrint(node.in, offset + 2) +
          prettyPrint(node.out, offset + 2) +
          s"${pad})\n"
    }
  }
}
/** Multi-threaded crawl over the item indices `0 until itemsCount`.
  *
  * Each worker repeatedly drains a private work stack, calls `compute` on every
  * popped index, publishes the `(index, result)` pairs through `queue`, and
  * refills its stack with unmarked indices read from the first result of the
  * batch. When a stack runs dry it is reseeded from the shared `waterline`.
  *
  * @param itemsCount number of items; valid indices are `0 until itemsCount`
  * @param compute    produces the result for one index; called from worker threads
  * @param read       exposes the indices referenced by a result
  */
protected class ParallelCrawl[R](
  val itemsCount: Int,
  val compute: Int => R,
  val read: R => Cursor[Int]
) { self =>

  /** Indices already claimed by some worker. `crawl()` only touches it while
    * holding the monitor of `self`. */
  val mark = new BitSet(itemsCount)
  /** Scan position for the next unclaimed index; same locking as `mark`. */
  var waterline = 0

  val threads = Runtime.getRuntime().availableProcessors()
  val pool = Executors.newCachedThreadPool
  /** Bounded hand-off buffer carrying `(Int, R)` results and `Tombstone`s. */
  val queue = new ArrayBlockingQueue[Any](256)

  /** End-of-stream marker: every worker enqueues exactly one when it stops. */
  object Tombstone

  /** Starts `threads` workers, waits until all of them are running and returns
    * the iterator that consumes their output. */
  def run(): Iterator[(Int, R)] = {
    val cl = new CountDownLatch(threads)

    for (t <- 0 until threads) {
      pool.execute(new Runnable {
        def run() = {
          cl.countDown()
          crawl()
        }
      })
    }

    cl.await()
    // No further tasks are submitted; already running workers are not affected.
    pool.shutdown()

    iterator
  }

  /** If `stack` is empty, advances `waterline` past marked indices and claims
    * the next unmarked one for `stack`. Callers must hold the monitor of `self`. */
  def progressWaterlineAndFillStackIfEmpty(stack: WorkStack) = {
    if (stack.isEmpty) {
      while (waterline < itemsCount && mark(waterline)) { waterline += 1 }
      if (waterline < itemsCount) {
        stack.push(waterline)
        mark(waterline) = true
      }
    }
  }

  /** Worker loop. Always enqueues one `Tombstone` on exit, including when
    * `compute` or `read` throws. Without that guarantee the consuming iterator
    * would wait forever for a worker that has already died. The exception
    * itself still propagates to the caller of `crawl()`. */
  def crawl(): Unit = {
    val stack = new WorkStack(64)

    try {
      var exhausted = false
      while (!exhausted) {

        if (stack.isEmpty) {
          self.synchronized {
            progressWaterlineAndFillStackIfEmpty(stack)
          }
        }

        if (stack.isEmpty) {
          // nothing left to claim anywhere
          exhausted = true
        } else {
          val res = new ArrayBuffer[(Int, R)](stack.size)
          while (!stack.isEmpty) {
            val w = stack.pop()
            res += ((w, compute(w)))
          }

          res.foreach { r => queue.put(r) }

          // Only the first result of the batch seeds the next round.
          self.synchronized {
            val cur = read(res(0)._2)
            while (cur.moveNext() && !stack.isFull) {
              val s = cur.value
              if (!mark(s)) {
                stack.push(s)
                mark(s) = true
              }
            }
          }
        }

      }
    } finally {
      queue.put(Tombstone)
    }
  }

  /** Iterator over everything the workers publish. It ends once one
    * `Tombstone` per worker has been seen. `hasNext` blocks while the queue is
    * empty and workers are still running. */
  def iterator: Iterator[(Int, R)] = new Iterator[(Int, R)] {
    private var running = threads
    // look-ahead slot; null means "nothing buffered"
    private var el: (Int, R) = null

    def hasNext = {
      while (el == null && running > 0) {
        queue.take() match {
          case Tombstone => running -= 1
          case x         => el = x.asInstanceOf[(Int, R)]
        }
      }
      el != null
    }

    def next() = {
      if (!hasNext) throw new NoSuchElementException("next on exhausted ParallelCrawl iterator")
      val res = el
      el = null
      res
    }
  }

}
/**
 * pHash-like image hash.
 * Based On: http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
 *
 * @param size        side length of the grayscale image the DCT is computed on
 * @param smallerSize side length of the top-left DCT square that is turned into
 *                    hash bits; `smallerSize * smallerSize` must be a multiple of
 *                    64 and `smallerSize` must not exceed `size`
 */
class ImagePHash(size: Int = 32, smallerSize: Int = 8) {

  // `require` rather than `assert`: this validates caller input and must not be elidable.
  require((smallerSize * smallerSize) % 64 == 0,
    s"smallerSize * smallerSize must be a multiple of 64, got ${smallerSize * smallerSize}")
  // The hash reads dctVals(x)(y) for x, y < smallerSize out of a size x size matrix.
  require(smallerSize <= size, s"smallerSize ($smallerSize) must not exceed size ($size)")

  private[this] val cosines = Array.tabulate[Double](size, size) { (i, j) => math.cos((2*i+1) / (2.0*size) * j * math.Pi) }
  private[this] val coeff   = Array.tabulate[Double](size) { i => if (i == 0) 1 / math.sqrt(2.0) else 1.0 }


  /** Decodes an image from `is` and hashes it.
    *
    * @throws IllegalArgumentException if no registered image reader can decode
    *         the stream (`ImageIO.read` signals that by returning `null`)
    */
  def apply(is: InputStream): Array[Long] = {
    val img = ImageIO.read(is)
    require(img != null, "input stream does not contain an image in a supported format")
    apply(img)
  }


  /** Hashes `img` into `smallerSize * smallerSize / 64` longs.
    * The bit for DCT cell `(x, y)` has index `x * smallerSize + y`; the bit of
    * cell `(0, 0)` is never set.
    */
  def apply(img: BufferedImage): Array[Long] = {
    // Drawing a null image is a silent no-op, which would hash a blank picture.
    require(img != null, "img must not be null")

    /* 1. Reduce size.
     * Like Average Hash, pHash starts with a small image.
     * However, the image is larger than 8x8; 32x32 is a good size.
     * This is really done to simplify the DCT computation and not
     * because it is needed to reduce the high frequencies.
     * 2. Reduce color.
     * The image is reduced to a grayscale just to further simplify
     * the number of computations.
     */
    val resized = resizeAndGrayscale(img, size, size)

    val vals = Array.ofDim[Double](size, size)

    var x = 0
    while (x < resized.getWidth) {
      var y = 0
      while (y < resized.getHeight) {
        vals(x)(y) = getBlue(resized, x, y)
        y += 1
      }
      x += 1
    }

    /* 3. Compute the DCT.
     * The DCT separates the image into a collection of frequencies
     * and scalars. While JPEG uses an 8x8 DCT, this algorithm uses
     * a 32x32 DCT.
     */
    val dctVals = applyDCT(vals)

    /* 4. Reduce the DCT.
     * This is the magic step. While the DCT is 32x32, just keep the
     * top-left 8x8. Those represent the lowest frequencies in the
     * picture.
     * 5. Compute the average value.
     * Like the Average Hash, compute the mean DCT value (using only
     * the 8x8 DCT low-frequency values and excluding the first term
     * since the DC coefficient can be significantly different from
     * the other values and will throw off the average).
     */
    var total = 0.0

    {
      var x = 0
      while (x < smallerSize) {
        var y = 0
        while (y < smallerSize) {
          total += dctVals(x)(y)
          y += 1
        }
        x += 1
      }
      total -= dctVals(0)(0)
    }

    val avg = total / ((smallerSize * smallerSize) - 1.0)


    /* 6. Further reduce the DCT.
     * This is the magic step. Set the 64 hash bits to 0 or 1
     * depending on whether each of the 64 DCT values is above or
     * below the average value. The result doesn't tell us the
     * actual low frequencies; it just tells us the very-rough
     * relative scale of the frequencies to the mean. The result
     * will not vary as long as the overall structure of the image
     * remains the same; this can survive gamma and color histogram
     * adjustments without a problem.
     */
    val hash = new Array[Long]((smallerSize * smallerSize) / 64)

    {
      var x = 0
      while (x < smallerSize) {
        var y = 0
        while (y < smallerSize) {
          if (!(x == 0 && y == 0)) {
            val idx = x * smallerSize + y
            if (dctVals(x)(y) > avg) {
              hash(idx / 64) |= (1L << (idx % 64))
            }
          }
          y += 1
        }
        x += 1
      }
    }

    hash
  }

  /** Draws `image` scaled to `width` x `height` onto a new `TYPE_BYTE_GRAY` image. */
  private def resizeAndGrayscale(image: BufferedImage, width: Int, height: Int): BufferedImage = {
    val resizedImage = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY)
    val g = resizedImage.createGraphics()
    g.drawImage(image, 0, 0, width, height, null)
    g.dispose()
    resizedImage
  }

  /** Lowest 8 bits of the packed pixel returned by `getRGB`. */
  private def getBlue(img: BufferedImage, x: Int, y: Int): Int =
    img.getRGB(x, y) & 0xff

  // DCT function stolen from http://stackoverflow.com/questions/4240490/problems-with-dct-and-idct-algorithm-in-java
  /** Direct (four nested loops) 2-D DCT of the `size` x `size` matrix `f`,
    * using the cosine table and scaling coefficients built in the constructor. */
  def applyDCT(f: Array[Array[Double]]): Array[Array[Double]] = {
    val result = Array.ofDim[Double](size, size)
    var u = 0
    while (u < size) {
      var v = 0
      while (v < size) {
        var sum = 0.0
        var i = 0
        while (i < size) {
          var j = 0
          while (j < size) {
            sum += cosines(i)(u) * cosines(j)(v) * f(i)(j)
            j += 1
          }
          i += 1
        }
        sum *= (coeff(u) * coeff(v)) / 4.0
        result(u)(v) = sum
        v += 1
      }
      u += 1
    }
    result
  }

}
/** Stock `RadixElement` instances.
  *
  * Every `byteAt` returns the byte at position `pos` as a value in 0-255, most
  * significant byte first, or -1 once `pos` is past the end of the key.
  * Signed numeric keys are biased by their minimum value first, which flips
  * the sign bit.
  */
object RadixElement {

  /** Strings as a sequence of 16-bit chars, two bytes per char, high byte first. */
  implicit val Strings: RadixElement[String] = new RadixElement[String] {
    def byteAt(str: String, pos: Int): Int =
      if (pos >= str.length*2) -1 else (str.charAt(pos/2) >> ((~pos & 1) * 8)) & 0xff
    val classTag: ClassTag[String] = implicitly[ClassTag[String]]
  }

  /** One byte per char: only the low 8 bits of every char are looked at. */
  val ASCIIStrings: RadixElement[String] = new RadixElement[String] {
    def byteAt(str: String, pos: Int): Int =
      if (pos >= str.length) -1 else str.charAt(pos) & 0xff
    val classTag: ClassTag[String] = implicitly[ClassTag[String]]
  }

  implicit val UnsignedByteArrays: RadixElement[Array[Byte]] = new RadixElement[Array[Byte]] {
    def byteAt(str: Array[Byte], pos: Int): Int =
      if (pos < str.length) str(pos) & 0xff else -1
    val classTag: ClassTag[Array[Byte]] = implicitly[ClassTag[Array[Byte]]]
  }

  val SignedByteArrays: RadixElement[Array[Byte]] = new RadixElement[Array[Byte]] {
    def byteAt(str: Array[Byte], pos: Int): Int =
      if (pos < str.length) (str(pos)+0x80) & 0xff else -1
    val classTag: ClassTag[Array[Byte]] = implicitly[ClassTag[Array[Byte]]]
  }

  implicit val IntArrays: RadixElement[Array[Int]] = new RadixElement[Array[Int]] {
    // The shift must be (3 - byteIndex) * 8. Without the inner parentheses the
    // multiplication binds first and a wrong bit offset is used.
    def byteAt(str: Array[Int], pos: Int): Int =
      if (pos >= str.length*4) -1 else ((str(pos/4)+0x80000000) >>> ((3-(pos%4)) * 8)) & 0xff
    val classTag: ClassTag[Array[Int]] = implicitly[ClassTag[Array[Int]]]
  }

  implicit val LongArrays: RadixElement[Array[Long]] = new RadixElement[Array[Long]] {
    // Same precedence fix as in IntArrays: (7 - byteIndex) * 8.
    def byteAt(str: Array[Long], pos: Int): Int =
      if (pos >= str.length*8) -1 else (((str(pos/8)+0x8000000000000000L) >>> ((7-(pos%8)) * 8)) & 0xff).toInt
    val classTag: ClassTag[Array[Long]] = implicitly[ClassTag[Array[Long]]]
  }

  implicit val Ints: RadixElement[Int] = new RadixElement[Int] {
    def byteAt(str: Int, pos: Int): Int =
      if (pos >= 4) -1 else ((str+0x80000000) >>> ((3-pos) * 8)) & 0xff
    val classTag: ClassTag[Int] = implicitly[ClassTag[Int]]
  }

  implicit val Longs: RadixElement[Long] = new RadixElement[Long] {
    def byteAt(str: Long, pos: Int): Int =
      if (pos >= 8) -1 else (((str+0x8000000000000000L) >>> ((7-pos) * 8)) & 0xff).toInt
    val classTag: ClassTag[Long] = implicitly[ClassTag[Long]]
  }

  implicit val Floats: RadixElement[Float] = new RadixElement[Float] {
    def byteAt(str: Float, pos: Int): Int =
      if (pos >= 4) -1 else ((Bits.floatToSortableInt(str)+0x80000000) >>> ((3-pos) * 8)) & 0xff
    val classTag: ClassTag[Float] = implicitly[ClassTag[Float]]
  }

  implicit val Doubles: RadixElement[Double] = new RadixElement[Double] {
    def byteAt(str: Double, pos: Int): Int =
      if (pos >= 8) -1 else (((Bits.doubleToSortableLong(str)+0x8000000000000000L) >>> ((7-pos) * 8)) & 0xff).toInt
    val classTag: ClassTag[Double] = implicitly[ClassTag[Double]]
  }

  /** Sorts values of `T` by the key `f` extracts; `f` runs on every `byteAt` call. */
  class Mapped[T, S](f: T => S)(implicit res: RadixElement[S], clt: ClassTag[T]) extends RadixElement[T] {
    val classTag: ClassTag[T] = clt
    def byteAt(str: T, pos: Int): Int = res.byteAt(f(str), pos)
  }

  def Mapped[T, S](f: T => S)(implicit els: RadixElement[S], clt: ClassTag[T]): Mapped[T, S] =
    new Mapped[T, S](f)

  /** `None` is the terminal symbol at position 0; `Some(x)` contributes a 0
    * byte followed by the bytes of `x`. */
  implicit def Options[T](implicit el: RadixElement[T]): RadixElement[Option[T]] = new RadixElement[Option[T]] {
    val classTag: ClassTag[Option[T]] = implicitly[ClassTag[Option[T]]]
    def byteAt(str: Option[T], pos: Int): Int = str match {
      // terminal at every position, so probing past position 0 cannot throw
      case None    => -1
      case Some(x) => if (pos == 0) 0 else el.byteAt(x, pos-1)
    }
  }

  /** A tag byte (1 for `Left`, 2 for `Right`) followed by the bytes of the payload. */
  implicit def Eithers[A, B](implicit ela: RadixElement[A], elb: RadixElement[B]): RadixElement[Either[A, B]] = new RadixElement[Either[A, B]] {
    val classTag: ClassTag[Either[A, B]] = implicitly[ClassTag[Either[A, B]]]
    def byteAt(str: Either[A, B], pos: Int): Int =
      if (pos == 0) {
        if (str.isLeft) 1 else 2
      } else {
        str match {
          case Left(x)  => ela.byteAt(x, pos-1)
          case Right(x) => elb.byteAt(x, pos-1)
        }
      }
  }


  /** Lexicographic order treating every byte as unsigned; a proper prefix sorts first. */
  val unsignedByteArrayComparator: java.util.Comparator[Array[Byte]] = new java.util.Comparator[Array[Byte]] {
    def compare(l: Array[Byte], r: Array[Byte]): Int = {
      var i = 0
      while (i < l.length && i < r.length) {
        val a = (l(i) & 0xff)
        val b = (r(i) & 0xff)
        if (a != b) { return a - b }
        i += 1
      }
      l.length - r.length
    }
  }

  /** Lexicographic order treating every byte as signed; a proper prefix sorts first. */
  val signedByteArrayComparator: java.util.Comparator[Array[Byte]] = new java.util.Comparator[Array[Byte]] {
    def compare(l: Array[Byte], r: Array[Byte]): Int = {
      var i = 0
      while (i < l.length && i < r.length) {
        val a = l(i)
        val b = r(i)
        if (a != b) { return a - b }
        i += 1
      }
      l.length - r.length
    }
  }
}
{ 37 | val res = arr.sorted 38 | val srt = BurstTrie(arr).lazySort.toArray 39 | assert(srt === res) 40 | } 41 | { 42 | val res = arr.sorted.reverse 43 | BurstSort.reverseSort(arr) 44 | assert(arr === res) 45 | } 46 | } 47 | 48 | def shouldBeSorted(arr: Array[Array[Byte]]): Unit = { 49 | val arr1 = arr.clone 50 | val res1 = arr.clone 51 | Arrays.sort(res1, RadixElement.unsignedByteArrayComparator) 52 | BurstSort.sort(arr1) 53 | //RadixQuicksort.sort(arr1) 54 | assert(arr1 === res1, s"unsigned") 55 | 56 | val arr2 = arr.clone 57 | val res2 = arr.clone 58 | Arrays.sort(res2, RadixElement.signedByteArrayComparator) 59 | BurstSort.sort(arr2)(RadixElement.SignedByteArrays) 60 | //RadixQuicksort.sort(arr2)(RadixElement.SignedByteArray) 61 | assert(arr2 === res2, s"signed\n${arr2.toSeq.map(_.mkString("[", ", ", "]"))}\n${res2.toSeq.map(_.mkString("[", ", ", "]"))}") 62 | } 63 | 64 | def shouldBeSortedT[T <: AnyRef](arr: Array[T])(implicit ord: Ordering[T], bsel: RadixElement[T]): Unit = { 65 | val res = arr.sorted 66 | BurstSort.sort(arr) 67 | assert(arr === res) 68 | } 69 | 70 | 71 | "BurstSort" should "not crash sorting empty arrays" in { 72 | 73 | BurstSort.sort(Array[String]()) 74 | BurstSort.sort(Array[Array[Byte]]()) 75 | BurstSort.sortBy(Array[String]())(_.length) 76 | BurstSort.sortBy(Array[String]())(_.length.toLong) 77 | 78 | BurstSort.sorted(Array[String]()) 79 | BurstSort.sorted(Array[Array[Byte]]()) 80 | 81 | } 82 | 83 | 84 | "BurstSort" should "sort by string" in { 85 | shouldBeSorted(Array[String]()) 86 | shouldBeSorted(Array[String]("")) 87 | shouldBeSorted(Array[String]("", "")) 88 | shouldBeSorted(Array[String]("a", "c", "b")) 89 | shouldBeSorted(Array[String]("a", "", "aa")) 90 | } 91 | 92 | 93 | "BurstSort" should "be able to sort random strings" in { 94 | implicit val rand = new util.Random(4747) 95 | 96 | for (_ <- 0 until 100) { 97 | val strings = Vector.fill(256)(randomWord(12)) 98 | for (_ <- 0 until 10) { 99 | val shuffled = 
rand.shuffle(strings).toArray 100 | shouldBeSorted(shuffled) 101 | } 102 | } 103 | 104 | for (_ <- 0 until 100) { 105 | val strings = Vector.fill(256)(randomWord(rand.nextInt(16))) 106 | for (_ <- 0 until 10) { 107 | shouldBeSorted(rand.shuffle(strings).toArray) 108 | } 109 | } 110 | 111 | val strings = Array.fill(1<<15)(randomWord(rand.nextInt(16))) 112 | shouldBeSorted(strings) 113 | } 114 | 115 | 116 | "BurstSort" should "be able to sort random byte arrays" in { 117 | implicit val rand = new util.Random(4747) 118 | 119 | for (_ <- 0 until 100) { 120 | val strings = Vector.fill(4)(randomByteArray(12)) 121 | for (_ <- 0 until 10) { 122 | val shuffled = rand.shuffle(strings).toArray 123 | shouldBeSorted(shuffled) 124 | } 125 | } 126 | 127 | for (_ <- 0 until 100) { 128 | val strings = Vector.fill(256)(randomByteArray(rand.nextInt(16))) 129 | for (_ <- 0 until 10) { 130 | shouldBeSorted(rand.shuffle(strings).toArray) 131 | } 132 | } 133 | } 134 | 135 | "BurstSort" should "sort by integer" in { 136 | case class C(a: Int, ord: Int) 137 | 138 | def shouldBeSorted(arr: Array[Int]): Unit = 139 | shouldBeSortedC(arr.zipWithIndex map C.tupled) 140 | 141 | def shouldBeSortedC(arr: Array[C]): Unit = { 142 | val res = arr.sortBy(_.a).map(_.a) 143 | BurstSort.sortBy(arr)(_.a) 144 | assert(arr.map(_.a) === res) 145 | } 146 | 147 | shouldBeSorted(Array(1,2,3,4,5,6,7,8,9,10)) 148 | shouldBeSorted(Array(1,1,1,1,1,1,1)) 149 | shouldBeSorted(Array(10,9,8,7,6,5,4,3,2,1)) 150 | 151 | shouldBeSorted(Array(1, 1<<8, 1<<16, 1<<24)) 152 | shouldBeSorted(Array(1, 2, 1<<8, 1<<8+1, 1<<16, 1<<16+1, 1<<24, 1<<24+1)) 153 | shouldBeSorted(Array(-1, 0, -2, 1, -3, 2, -2, 0)) 154 | 155 | } 156 | 157 | 158 | "BurstSort" should "sort Options" in { 159 | val xs = Array[Option[Int]](Some(1), None, Some(2), None, Some(-1)) 160 | shouldBeSortedT(xs) 161 | } 162 | 163 | "BurstSort" should "sort Eithers" in { 164 | val xs = Array[Either[Int, String]](Left(99), Right("z"), Left(5), Left(0), Right(""), 
object IndexResultBuilder {
  /** Picks the builder implementation for the requested result shape:
    * a single best result, every result (optionally distinct), or a bounded top-k.
    */
  def make(distinct: Boolean, maxResults: Int): IndexResultBuilder =
    maxResults match {
      case 1 =>
        new SingularIndexResultBuilder
      case Int.MaxValue =>
        if (distinct) new AllDistinctIndexResultBuilder
        else new AllIndexResultBuilder
      case k =>
        new TopKIndexResultBuilder(k, distinct)
    }
}
/** Result builder that keeps only the single highest-scoring entry.
  * On equal scores the entry added first is kept.
  */
class SingularIndexResultBuilder extends IndexResultBuilder {
  protected[atrox] var empty = true
  protected[atrox] var idx, score = 0

  def size: Int = if (empty) 0 else 1

  def += (idx: Int, score: Int): Unit = {
    if (empty || score > this.score) {
      this.idx = idx
      this.score = score
      this.empty = false
    }
  }

  def ++= (rb: IndexResultBuilder): Unit = rb match {
    case rb: SingularIndexResultBuilder =>
      if (!rb.empty) this += (rb.idx, rb.score)
    case _ => sys.error("this should never happen")
  }

  def result: Array[Int] = if (empty) Array.empty[Int] else Array(idx)

  /** Cursor over at most one `(idx, score)` pair. `moveNext` returns true at
    * most once. The previous flag-toggling version reported the same element
    * again on every odd-numbered call. */
  def idxScoreCursor: Cursor2[Int, Int] = new Cursor2[Int, Int] {
    private var consumed = false
    def moveNext() =
      if (consumed || empty) false
      else { consumed = true ; true }
    def key = idx
    def value = score
  }
}
/** Inefficient version just for completeness sake.
  * Keeps every distinct `(idx, score)` pair in a mutable set.
  */
class AllDistinctIndexResultBuilder extends IndexResultBuilder {
  private val set = collection.mutable.Set[(Int, Int)]()

  def size = set.size

  def += (idx: Int, score: Int): Unit = {
    set += ((idx, score))
  }

  def ++= (rb: IndexResultBuilder): Unit = rb match {
    case other: AllDistinctIndexResultBuilder => this.set ++= other.set
    case _ => sys.error("this should never happen")
  }

  /** Distinct indexes in descending index order. */
  def result = set.map(_._1).toArray.sorted(Ordering.Int.reverse)

  /** Destructive cursor: every successful `moveNext` removes the returned
    * pair from the underlying set. */
  def idxScoreCursor = new Cursor2[Int, Int] {
    var key = 0
    var value = 0

    def moveNext() = set.headOption match {
      case Some(entry) =>
        key = entry._1
        value = entry._2
        set.remove(entry)
        true
      case None =>
        false
    }
  }
}
// new TopKIntIntEstimate(k, 3) 156 | 157 | def size = if (res == null) 0 else res.size 158 | 159 | def += (idx: Int, score: Int): Unit = { 160 | createTopK() 161 | res.add(score, idx) 162 | } 163 | def ++= (rb: IndexResultBuilder): Unit = rb match { 164 | case rb: TopKEstimateIndexResultBuilder => 165 | if (rb.res != null) { 166 | createTopK() 167 | res addAll rb.res 168 | } 169 | case rb: SingularIndexResultBuilder => 170 | if (!rb.empty) this += (rb.idx, rb.score) 171 | case _ => sys.error("this should never happen") 172 | } 173 | 174 | def result = if (res == null) Array() else res.drainToArray() 175 | def idxScoreCursor = { createTopK() ; res.cursor.swap } 176 | } 177 | -------------------------------------------------------------------------------- /src/main/scala/BurstSort.scala: -------------------------------------------------------------------------------- 1 | package atrox.sort 2 | 3 | import java.util.Arrays 4 | import scala.annotation.tailrec 5 | import scala.reflect.ClassTag 6 | 7 | 8 | /** BurstSort 9 | * 10 | * Cache-Conscious Sorting of Large Sets of Strings with Dynamic Tries 11 | * http://goanna.cs.rmit.edu.au/~jz/fulltext/alenex03.pdf 12 | * 13 | * Efficient Trie-Based Sorting of Large Sets of Strings 14 | * http://goanna.cs.rmit.edu.au/~jz/fulltext/acsc03sz.pdf 15 | **/ 16 | object BurstSort { 17 | /** in-place sorting */ 18 | def sort[S <: AnyRef: RadixElement](arr: Array[S]) = 19 | (new BurstTrie[S] ++= arr).sortInto(arr, s => s) 20 | 21 | def reverseSort[S <: AnyRef: RadixElement](arr: Array[S]) = 22 | (new BurstTrie[S] ++= arr).sortInto(arr, s => s, true) 23 | 24 | def sortBy[T <: AnyRef: ClassTag, S: RadixElement](arr: Array[T])(f: T => S) = 25 | (new BurstTrie[T]()(RadixElement.Mapped(f)) ++= arr).sortInto(arr, s => s) 26 | 27 | /* 28 | def sortBySchwartzianTransform[T <: AnyRef, S](arr: Array[T], f: T => S)(implicit ctts: ClassTag[(T, S)], cts: ClassTag[S], els: RadixElement[S]) = { 29 | implicit val rqelts = RadixElement.Mapped[(T, S), 
S](_._2) 30 | val trie = new BurstTrie[(T, S), T]()(RadixElement.Mapped(_._2)) 31 | for (x <- arr) trie += (x, f(x)) 32 | trie.sortInto(arr, ts => ts._1) 33 | } 34 | */ 35 | 36 | def sorted[S <: AnyRef : ClassTag : RadixElement](arr: TraversableOnce[S]) = { 37 | val trie = (new BurstTrie[S] ++= arr) 38 | val res = new Array[S](trie.size) 39 | trie.sortInto(res, s => s) 40 | res 41 | } 42 | 43 | def sortedBy[T <: AnyRef: ClassTag, S: RadixElement](arr: TraversableOnce[T], f: T => S) = { 44 | val trie = new BurstTrie[T]()(RadixElement.Mapped(f)) 45 | for (x <- arr) trie += x 46 | val res = new Array[T](trie.size) 47 | trie.sortInto(res, x => x) 48 | res 49 | } 50 | } 51 | 52 | 53 | 54 | object BurstTrie { 55 | def apply[T <: AnyRef: RadixElement](xs: TraversableOnce[T]) = new BurstTrie[T] ++= xs 56 | } 57 | 58 | 59 | class BurstTrie[S <: AnyRef](implicit el: RadixElement[S]) { 60 | 61 | implicit def ct = el.classTag 62 | 63 | val initSize = 16 64 | val resizeFactor = 8 65 | val maxSize = 1024*4 66 | 67 | private var _size = 0 68 | private val root: Array[AnyRef] = new Array[AnyRef](257) 69 | 70 | def size = _size 71 | 72 | 73 | def ++= (strs: TraversableOnce[S]): this.type = { 74 | for (str <- strs) { this += str } ; this 75 | } 76 | 77 | 78 | def += (str: S): this.type = { 79 | 80 | @tailrec 81 | def add(str: S, depth: Int, node: Array[AnyRef]): Unit = { 82 | 83 | val char = el.byteAt(str, depth)+1 // 0-th slot is for strings with terminal symbol at depth 84 | 85 | node(char) match { 86 | case null => 87 | val leaf = new BurstLeaf[S](initSize) 88 | leaf.add(str) 89 | node(char) = leaf 90 | 91 | case leaf: BurstLeaf[S @unchecked] => 92 | // add string, resize or burst 93 | if (leaf.size == leaf.values.length) { 94 | if (leaf.size * resizeFactor <= maxSize || char == 0 /* || current byte is the last one */) { // resize 95 | leaf.resize(resizeFactor) 96 | leaf.add(str) 97 | 98 | } else { // burst 99 | val newNode = new Array[AnyRef](257) 100 | node(char) = newNode 101 
| 102 | for (i <- 0 until leaf.size) { 103 | add0(leaf.values(i), depth+1, newNode) 104 | } 105 | 106 | add(str, depth+1, newNode) 107 | } 108 | 109 | } else { 110 | leaf.add(str) 111 | } 112 | 113 | case child: Array[AnyRef] => 114 | add(str, depth+1, child) 115 | } 116 | } 117 | 118 | def add0(str: S, depth: Int, node: Array[AnyRef]): Unit = add(str, depth, node) 119 | 120 | add(str, 0, root) 121 | _size += 1 122 | this 123 | } 124 | 125 | 126 | def sortInto[R <: AnyRef](res: Array[R], f: S => R, reverse: Boolean = false): Unit = { 127 | var pos = 0 128 | 129 | def run(node: Array[AnyRef], depth: Int): Unit = { 130 | if (!reverse) { 131 | var i = 0 ; while (i < 257) { 132 | doNode(node(i), i != 0, depth, f) 133 | i += 1 134 | } 135 | } else { 136 | var i = 256 ; while (i >= 0) { 137 | doNode(node(i), i != 0, depth, f) 138 | i -= 1 139 | } 140 | } 141 | } 142 | 143 | def doNode(n: AnyRef, sort: Boolean, depth: Int, f: S => R) = n match { 144 | case null => 145 | case leaf: BurstLeaf[S @unchecked] => 146 | if (sort) { 147 | el.sort(leaf.values, leaf.size, depth) 148 | } 149 | //System.arraycopy(leaf.values, 0, res, pos, leaf.size) 150 | if (reverse) { 151 | var i = 0 ; var j = leaf.size-1 ; while (i < leaf.size) { 152 | res(pos+j) = f(leaf.values(i)) 153 | i += 1 154 | j -= 1 155 | } 156 | 157 | } else { 158 | var i = 0 ; while (i < leaf.size) { 159 | res(pos+i) = f(leaf.values(i)) 160 | i += 1 161 | } 162 | } 163 | 164 | pos += leaf.size 165 | case node: Array[AnyRef] => run(node, depth + 1) 166 | } 167 | 168 | run(root, 0) 169 | } 170 | 171 | 172 | private case class LeafJob(leaf: BurstLeaf[S], depth: Int, sort: Boolean) 173 | 174 | private def inorder: Iterator[LeafJob] = { 175 | def iterate(node: AnyRef, depth: Int, sort: Boolean): Iterator[LeafJob] = node match { 176 | case null => Iterator() 177 | case leaf: BurstLeaf[S @unchecked] => Iterator(LeafJob(leaf, depth, sort)) 178 | case node: Array[AnyRef] => Iterator.range(0, node.length) flatMap { i => 
iterate(node(i), depth+1, i != 0) } 179 | } 180 | 181 | iterate(root, 0, true) 182 | } 183 | 184 | def lazySort = inorder.flatMap { case LeafJob(leaf, depth, sort) => 185 | if (sort) { 186 | el.sort(leaf.values, leaf.size, depth) 187 | } 188 | Iterator.tabulate(leaf.size)(i => leaf.values(i)) 189 | } 190 | 191 | } 192 | 193 | 194 | 195 | private final class BurstLeaf[S <: AnyRef](initSize: Int)(implicit ct: ClassTag[S]) { 196 | var size: Int = 0 197 | var values: Array[S] = new Array[S](initSize) 198 | 199 | def add(str: S) = { 200 | values(size) = str 201 | size += 1 202 | } 203 | 204 | def resize(factor: Int) = { 205 | values = Arrays.copyOfRange(values.asInstanceOf[Array[AnyRef]], 0, values.length * factor).asInstanceOf[Array[S]] 206 | } 207 | 208 | override def toString = values.mkString("BurstLeaf(", ",", ")") 209 | } 210 | -------------------------------------------------------------------------------- /src/main/scala/MemoryMappedLSH.scala: -------------------------------------------------------------------------------- 1 | package atrox.sketch 2 | 3 | import java.io.RandomAccessFile 4 | import java.nio.MappedByteBuffer 5 | import java.nio.IntBuffer 6 | import java.nio.channels.FileChannel 7 | import atrox.fastSparse 8 | 9 | 10 | 11 | /* 12 | object MmapIntArrArr { 13 | def persist(arr: Array[Array[Int]]) 14 | } 15 | 16 | class MmapIntArrArr(bb: ByteBuffer) { 17 | 18 | val ib = bb.asIntBuffer 19 | 20 | val tableLength = ib.get(0) 21 | 22 | val offsets = ib.position(1).limit(1+tableLength).slice() 23 | val data = ib.position(1+tableLength).slice() 24 | 25 | 26 | 27 | val start = offsets.get(idx) 28 | val end = offsets.get(idx + 1) 29 | val len = end-start 30 | 31 | val res = new Array[Int](len) 32 | 33 | for (i <- 0 until len) { 34 | res(i) = data.get(i) 35 | } 36 | 37 | require(fastSparse.isDistinctIncreasingArray(res)) 38 | 39 | res 40 | 41 | } 42 | */ 43 | 44 | 45 | 46 | abstract class MMCommon[SketchArray] extends LSHTable[SketchArray] { 47 | protected def 
tableLength: Int 48 | 49 | protected def lookup(skarr: SketchArray, skidx: Int, band: Int): Idxs = 50 | idxs(band * (1 << params.hashBits) + hashAndSlice.hashFun(skarr, skidx, band, params)) 51 | 52 | def rawStreamIndexes: Iterator[Idxs] = Iterator.tabulate(tableLength)(idxs) filter (arr => arr != null && arr.length != 0) 53 | 54 | def decodeParams(mm: IntBuffer, baseOffset: Int) = 55 | LSHTableParams( 56 | sketchLength = mm.get(baseOffset + 0), 57 | bands = mm.get(baseOffset + 1), 58 | bandLength = mm.get(baseOffset + 2), 59 | hashBits = mm.get(baseOffset + 3), 60 | itemsCount = mm.get(baseOffset + 4) 61 | ) 62 | 63 | protected def idxs(idx: Int): Array[Int] 64 | } 65 | 66 | abstract class MemoryMappedLSHTable[SketchArray](val mm: IntBuffer) extends MMCommon[SketchArray] { 67 | 68 | val params = decodeParams(mm, 0) 69 | 70 | protected val tableLength = mm.get(5) 71 | require(params.bands * (1 << params.hashBits) == tableLength) 72 | 73 | protected def idxs(idx: Int) = { 74 | require(idx < tableLength, s"idx < tableLength ($idx < $tableLength)") 75 | 76 | val start = mm.get(MemoryMappedLSHTable.headerSize + idx) 77 | val end = mm.get(MemoryMappedLSHTable.headerSize + idx + 1) 78 | val len = end-start 79 | 80 | val res = new Array[Int](len) 81 | 82 | for (i <- 0 until len) { 83 | res(i) = mm.get(start+i) 84 | } 85 | 86 | require(fastSparse.isDistinctIncreasingArray(res)) 87 | 88 | res 89 | } 90 | 91 | override def toString = s"MemoryMappedLSHTable($mm, params = $params, tableLength = $tableLength)" 92 | } 93 | 94 | 95 | object MemoryMappedLSHTable { 96 | 97 | def headerSize = 6 98 | 99 | def mmapTable(fileName: String): IntBuffer = { 100 | val chan = new RandomAccessFile(fileName, "r") 101 | .getChannel() 102 | 103 | val len = chan.size() 104 | 105 | chan 106 | .map(FileChannel.MapMode.READ_ONLY, 0, len) 107 | .asIntBuffer 108 | .asReadOnlyBuffer 109 | } 110 | 111 | def mmap[SketchArray](fileName: String, es: Estimator[SketchArray])(implicit has: 
HashAndSlice[SketchArray]): MemoryMappedLSHTable[SketchArray] = 112 | new MemoryMappedLSHTable[SketchArray](mmapTable(fileName)) { 113 | def hashAndSlice = has 114 | } 115 | 116 | // [ sketchLength | bands | bandLength | hashBits | itemsCount | idxs length | offsets ... + offset behind the last array | arrays ] 117 | def persist[SketchArray](table: IntArrayLSHTable[SketchArray], fileName: String): Unit = { 118 | val idxs = table.idxs 119 | 120 | val len = lengthOfTable(table) 121 | if (len > Int.MaxValue) throw new Exception("too long") 122 | 123 | val mmf = mmapFile(fileName, len.toInt) 124 | val b = mmf.asIntBuffer 125 | 126 | b.put(table.params.sketchLength) 127 | b.put(table.params.bands) 128 | b.put(table.params.bandLength) 129 | b.put(table.params.hashBits) 130 | b.put(table.params.itemsCount) 131 | b.put(idxs.length) 132 | 133 | var off = headerSize + idxs.length + 1 134 | 135 | for (arr <- idxs) { 136 | b.put(off) 137 | off += arrLen(arr) 138 | } 139 | 140 | b.put(off) 141 | 142 | for (arr <- idxs) { 143 | if (arr != null) { 144 | b.put(arr) 145 | } 146 | } 147 | 148 | mmf.force() 149 | 150 | } 151 | 152 | protected def lengthOfTable[SketchArray](table: IntArrayLSHTable[SketchArray]) = 153 | (headerSize + table.idxs.length.toLong + 1 + table.idxs.map(arrLen).sum) * 4 154 | 155 | protected def mmapFile(fileName: String, length: Int) = 156 | new RandomAccessFile(fileName, "rw") 157 | .getChannel() 158 | .map(FileChannel.MapMode.READ_WRITE, 0, length) 159 | 160 | private def arrLen(arr: Array[Int]) = if (arr == null) 0 else arr.length 161 | 162 | } 163 | 164 | 165 | 166 | 167 | /* 168 | abstract class CompactMemoryMappedLSHTable[SketchArray](val mm: IntBuffer) extends MMCommon[SketchArray] { 169 | 170 | val params = decodeParams(mm, 0) 171 | 172 | private val tableLength = mm.get(5) 173 | private val blockSize = mm.get(6) 174 | private val blockTableLength = (tableLength + blockSize-1) / blockSize 175 | require(params.bands * (1 << params.hashBits) == 
tableLength) 176 | 177 | protected def idxs(idx: Int) = { 178 | require(idx < tableLength, s"idx < tableLength ($idx < $tableLength)") 179 | 180 | val boff = mm.get(headerSize + idx / blockSize) 181 | 182 | val lstart = headerSize + blockTableLength + (idx / blockSize * blockSize) /2 183 | val lend = headerSize + blockTableLength + idx/2 184 | 185 | val arr = new Array[Short]((lend-lstart)*2) 186 | for (i < lstart until lend) { 187 | val j = i-lstart 188 | val int = mm.get(i) 189 | arr(j*2) = int & 0xffff 190 | arr(j*2+1) = int >>> 16 191 | } 192 | 193 | var sum = 0 194 | for (i <- 0 until (arr.length - (idx & 1)) { 195 | sum += arr(i) 196 | } 197 | 198 | 199 | val dstart = boff + sum 200 | val dend 201 | 202 | 203 | 204 | 205 | 206 | 207 | val res = new Array[Int](len) 208 | 209 | for (i <- 0 until len) { 210 | res(i) = mm.get(start+i) 211 | } 212 | 213 | require(fastSparse.isDistinctIncreasingArray(res)) 214 | 215 | res 216 | } 217 | 218 | override def toString = s"MemoryMappedLSHTable($mm, params = $params, tableLength = $tableLength)" 219 | } 220 | 221 | 222 | object CompactMemoryMappedLSHTable { 223 | 224 | def headerSize = 7 225 | 226 | // [ fields | block offsets | lengths (packed in 2B shorts) | data ] 227 | def persist[SketchArray](table: IntArrayLSHTable[SketchArray], fileName: String): Unit = { 228 | val len = lengthOfTable(table) // TODO 229 | if (len > Int.MaxValue) throw new Exception("too long") 230 | 231 | val idxs = table.idxs 232 | 233 | val blockSize = 16 234 | val tableLength = idxs.length 235 | val blockTableLength = (tableLength + blockSize-1) / blockSize 236 | val off = headerSize + blockTableLength + (tableLength+1/2) 237 | val boffs = idxs.grouped(blockSize).scanLeft(off) { (sum, block) => sum + block.map(arrLen).sum }.toArray 238 | 239 | 240 | val mmf = mmapFile(fileName, len.toInt) 241 | val b = mmf.asIntBuffer 242 | 243 | b.put(table.params.sketchLength) 244 | b.put(table.params.bands) 245 | b.put(table.params.bandLength) 246 | 
b.put(table.params.hashBits) 247 | b.put(table.params.itemsCount) 248 | b.put(idxs.length) 249 | b.put(blockSize) 250 | 251 | b.put(boffs) 252 | 253 | for (i <- 0 until idx.length by 2) { 254 | val l1 = idxs(i).length 255 | val l2 = idxs(i+1).length 256 | require(l1 < Short.MaxValue && l2 < Short.MaxValue) 257 | b.put(l2 << 16 | l1) 258 | } 259 | 260 | if (idxs.length % 2 == 1) { 261 | val l1 = idxs(idxs.length-1).length 262 | b.put(l1) 263 | } 264 | 265 | for (arr <- idxs) { 266 | if (arr != null) { 267 | b.put(arr) 268 | } 269 | } 270 | 271 | mmf.force() 272 | } 273 | } 274 | */ 275 | -------------------------------------------------------------------------------- /src/main/scala/BloomFilter.scala: -------------------------------------------------------------------------------- 1 | package atrox 2 | 3 | import atrox.sketch.HashFunc 4 | import atrox.sketch.MinHash 5 | import java.lang.Integer.highestOneBit 6 | import scala.math._ 7 | 8 | 9 | object BloomFilter { 10 | def apply[@scala.specialized(Int, Long) T](expectedItems: Int, falsePositiveRate: Double): BloomFilter[T] = { 11 | val (bitLength, hashFunctions) = optimalSize(expectedItems, falsePositiveRate) 12 | apply[T](hashFunctions, bitLength) 13 | } 14 | 15 | def apply[@scala.specialized(Int, Long) T](hashFunctions: Int, bitLength: Int): BloomFilter[T] = 16 | new BloomFilter[T](hashFunctions, higherPowerOfTwo(bitLength)) 17 | 18 | def optimalSize(expectedItems: Int, falsePositiveRate: Double): (Int, Int) = { 19 | val n = expectedItems 20 | val p = falsePositiveRate 21 | val m = ceil(-(n * log(p)) / log(pow(2.0, log(2.0)))) 22 | val k = round(log(2.0) * m / n) 23 | (m.toInt, k.toInt) 24 | } 25 | 26 | private def higherPowerOfTwo(x: Int) = 27 | highestOneBit(x) << (if (highestOneBit(x) == x) 0 else 1) 28 | } 29 | 30 | 31 | trait Bloomy[@scala.specialized(Int, Long) T] { 32 | 33 | def add(x: T): this.type 34 | 35 | def contains(x: T): Boolean 36 | 37 | /** Return true if x is in the Bloom filter. 
If it's not present, the method 38 | * adds x into the set. This precise behaviour makes no difference for 39 | * ordinary Bloom filters but it has an effect for counting bloom filters. 40 | **/ 41 | def getAndSet(x: T): Boolean = { 42 | val res = contains(x) 43 | if (!res) { 44 | add(x) 45 | } 46 | res 47 | } 48 | 49 | def += (x: T): this.type = add(x) 50 | def apply(x: T) = contains(x) 51 | 52 | } 53 | 54 | 55 | /** A Bloom filter is a space-efficient probabilistic data structure that is 56 | * used for set membership queries. It might give a false positive, but never 57 | * a false negative. That is, if bf.contains(x) returns false, element x is 58 | * definitely not in the set. If it returns true, element x might have been 59 | * added to the set. If the set is properly sized, the probability of a false 60 | * positive is very small. For example a bloom filter needs around 8 to 10 bits 61 | * per every inserted element to provide ~1% false positive rate. 62 | * 63 | * This implementation tends to overshoot and provides better guarantees by 64 | * rounding up the size of the underlying bit array to the nearest power of two. 
65 | */ 66 | class BloomFilter[@scala.specialized(Int, Long) T]( 67 | val hashFunctions: Int, val bitLength: Int 68 | ) extends (T => Boolean) with Bloomy[T] { 69 | 70 | require(hashFunctions > 0, "number of hash functions must be greater than zero") 71 | require(bitLength >= 64, "length of a bloom filter must be at least 64 bits") 72 | require((bitLength & (bitLength - 1)) == 0, "length of a bloom filter must be power of 2") 73 | 74 | private val arr = new Array[Long](bitLength / 64) 75 | private val mask = bitLength - 1 76 | 77 | protected def elemHashCode(x: T) = x.hashCode 78 | 79 | protected def hash(i: Int, x: T): Int = 80 | fs(i)(elemHashCode(x)) 81 | 82 | protected val fs = 83 | Array.tabulate[HashFunc[Int]](hashFunctions)(i => HashFunc.random(i * 4747)) 84 | 85 | def add(x: T): this.type = { 86 | var i = 0 87 | while (i < hashFunctions) { 88 | val pos = hash(i, x) & mask 89 | arr(pos / 64) |= (1L << (pos % 64)) 90 | i += 1 91 | } 92 | 93 | this 94 | } 95 | 96 | 97 | def contains(x: T): Boolean = { 98 | var i = 0 99 | while (i < hashFunctions) { 100 | val pos = hash(i, x) & mask 101 | if (((arr(pos / 64) >>> (pos % 64)) & 1L) == 0) { 102 | return false 103 | } 104 | i += 1 105 | } 106 | 107 | true 108 | } 109 | 110 | 111 | override def getAndSet(x: T): Boolean = { 112 | var isSet = true 113 | 114 | var i = 0 115 | while (i < hashFunctions) { 116 | val pos = hash(i, x) & mask 117 | val longPos = pos / 64 118 | val bitPos = pos % 64 119 | 120 | isSet &= (((arr(longPos) >>> bitPos) & 1L) != 0) 121 | arr(longPos) |= (1L << bitPos) 122 | 123 | i += 1 124 | } 125 | 126 | isSet 127 | } 128 | 129 | 130 | def falsePositiveRate(n: Int) = 131 | pow(1.0 - pow(1.0 - 1.0 / bitLength, hashFunctions * n), hashFunctions) 132 | 133 | def sizeInBytes = arr.length * 8 134 | 135 | override def toString = 136 | s"BloomFilter(hashFunctions = $hashFunctions, bitLength = $bitLength)" 137 | 138 | def approximateSize = { 139 | var bitsSet = 0 140 | var i = 0 ; while (i < arr.length) 
{ 141 | bitsSet += java.lang.Long.bitCount(arr(i)) 142 | i += 1 143 | } 144 | 145 | - bitLength.toDouble / hashFunctions * math.log(1 - bitsSet.toDouble / bitLength) 146 | } 147 | 148 | 149 | def union(bf: BloomFilter[T]): BloomFilter[T] = { 150 | require(hashFunctions == bf.hashFunctions && bitLength == bf.bitLength, 151 | "Cannot unite bloom filters with different number of hash functions and lengths") 152 | 153 | val res = new BloomFilter[T](hashFunctions, bitLength) 154 | var i = 0 ; while (i < res.arr.length) { 155 | res.arr(i) = arr(i) | bf.arr(i); 156 | i += 1 157 | } 158 | 159 | res 160 | } 161 | 162 | 163 | def intersection(bf: BloomFilter[T]): BloomFilter[T] = { 164 | require(hashFunctions == bf.hashFunctions && bitLength == bf.bitLength, 165 | "Cannot intersect bloom filters with different number of hash functions and lengths") 166 | 167 | val res = new BloomFilter[T](hashFunctions, bitLength) 168 | var i = 0 ; while (i < res.arr.length) { 169 | res.arr(i) = arr(i) & bf.arr(i); 170 | i += 1 171 | } 172 | 173 | res 174 | } 175 | 176 | } 177 | 178 | 179 | class CountingBloomFilter[@scala.specialized(Int, Long) T]( 180 | val hashFunctions: Int, val counters: Int, val counterBits: Int 181 | ) extends (T => Boolean) with Bloomy[T] { 182 | 183 | require(hashFunctions > 0, "number of hash functions must be greater than zero") 184 | require(counters >= 64, "length of a bloom filter must be at least 64 bits") 185 | require((counters & (counters - 1)) == 0, "length of a bloom filter must be power of 2") 186 | require(counterBits > 0 && counterBits <= 64, "number of counter bits must be greater than 0 and smaller than 64") 187 | 188 | private[this] val arr = new Array[Long](counters * counterBits / 64 + 1) // +1 for easier extraction of bits 189 | private[this] val mask = counters - 1 190 | 191 | protected def elemHashCode(x: T) = x.hashCode 192 | 193 | protected def hash(i: Int, x: T): Int = 194 | fs(i)(elemHashCode(x)) 195 | 196 | protected val fs = 197 | 
Array.tabulate[HashFunc[Int]](hashFunctions)(i => HashFunc.random(i * 4747)) 198 | 199 | 200 | protected def ripBits(arr: Array[Long], bitpos: Int, bitlen: Int): Long = { 201 | val mask = (1 << bitlen) - 1 202 | val endpos = (bitpos+bitlen) / 64 203 | 204 | ((arr(bitpos / 64) >>> (bitpos % 64)) & mask) | 205 | ((arr(endpos) << (64 - bitpos % 64)) & mask) 206 | } 207 | 208 | protected def ramBits(arr: Array[Long], bitpos: Int, bitlen: Int, value: Long): Unit = { 209 | val mask = (1 << bitlen) - 1 210 | val endpos = (bitpos+bitlen) / 64 211 | 212 | /* 213 | arr(bitpos / 64) &= ~(mask << (bitpos % 64)) 214 | arr(bitpos / 64) |= (value << (bitpos % 64)) 215 | 216 | arr(endpos) &= ~(mask >>> (64 - bitpos % 64)) 217 | arr(endpos) |= (value >>> (64 - bitpos % 64)) 218 | */ 219 | 220 | // TODO code without loops 221 | var i = 0 ; while (i < bitlen) { 222 | val pos = bitpos + i 223 | arr(pos/64) ^= (-((value>>i)&1L) ^ arr(pos/64)) & (1L << (pos%64)) 224 | i += 1 225 | } 226 | } 227 | 228 | 229 | def add(x: T): this.type = { 230 | var i = 0 231 | while (i < hashFunctions) { 232 | val pos = hash(i, x) & mask 233 | val bitpos = pos * counterBits 234 | val maxCnt = (1 << counterBits) - 1 235 | 236 | val cnt = ripBits(arr, bitpos, counterBits) 237 | if (cnt < maxCnt) { 238 | ramBits(arr, bitpos, counterBits, cnt+1) 239 | } 240 | 241 | i += 1 242 | } 243 | 244 | this 245 | } 246 | 247 | 248 | def contains(x: T): Boolean = { 249 | var i = 0 250 | while (i < hashFunctions) { 251 | val pos = hash(i, x) & mask 252 | val bitpos = pos * counterBits 253 | 254 | val cnt = ripBits(arr, bitpos, counterBits) 255 | if (cnt == 0) return false 256 | i += 1 257 | } 258 | true 259 | } 260 | 261 | 262 | def count(x: T): Long = { 263 | var minCnt = Long.MaxValue 264 | 265 | var i = 0 266 | while (i < hashFunctions) { 267 | val pos = hash(i, x) & mask 268 | val bitpos = pos * counterBits 269 | 270 | val cnt = ripBits(arr, bitpos, counterBits) 271 | minCnt = math.min(minCnt, cnt) 272 | 273 | i += 1 
274 | } 275 | minCnt 276 | } 277 | 278 | 279 | def falsePositiveRate(n: Int) = 280 | pow(1.0 - pow(1.0 - 1.0 / counters, hashFunctions * n), hashFunctions) 281 | 282 | def sizeInBytes = arr.length * 8 283 | 284 | override def toString = 285 | s"CountingBloomFilter(hashFunctions = $hashFunctions, counters = $counters, counterBits = $counterBits)" 286 | 287 | 288 | def bits = arr.map(_.toBinaryString.padTo(64, '_')).mkString(" ") 289 | } 290 | -------------------------------------------------------------------------------- /src/main/scala/ALS.scala: -------------------------------------------------------------------------------- 1 | package atrox 2 | 3 | import breeze.linalg._ 4 | import breeze.numerics._ 5 | import breeze.stats.distributions.Rand 6 | import java.util.Arrays 7 | 8 | 9 | object ALS { 10 | def apply(R: CSCMatrix[Float], factors: Int, α: Float, λ: Float, iterations: Int) = 11 | new ALS_Full(R, factors, α, λ, iterations).run 12 | 13 | def apply(R: Array[Array[Int]], factors: Int, α: Float, λ: Float, iterations: Int) = 14 | new ALS_Unary(new UnaryMatrix(R), factors, α, λ, iterations).run 15 | 16 | /** Sparse matrix whose values are all ones. 
It's column major just like CSCMatrix */ 17 | protected class UnaryMatrix(val data: Array[Array[Int]]) { self => 18 | val cols = data.length 19 | val rows = data.map(ks => if (ks.isEmpty) 0 else ks.max).max+1 20 | 21 | def t: UnaryMatrix = { 22 | val lengths = new Array[Int](rows) 23 | val positions = new Array[Int](rows) 24 | 25 | var r = 0 26 | while (r < data.length) { 27 | var c = 0 28 | while (c < data(r).length) { 29 | val col = data(r)(c) 30 | lengths(col) += 1 31 | c += 1 32 | } 33 | r += 1 34 | } 35 | 36 | val res = Array.tabulate(rows) { c => new Array[Int](lengths(c)) } 37 | 38 | r = 0 39 | while (r < data.length) { 40 | var c = 0 41 | while (c < data(r).length) { 42 | val col = data(r)(c) 43 | res(col)(positions(col)) = r 44 | positions(col) += 1 45 | c += 1 46 | } 47 | r += 1 48 | } 49 | 50 | new UnaryMatrix(res) { 51 | override val cols = self.rows 52 | override val rows = self.cols 53 | } 54 | } 55 | } 56 | 57 | /** ALS for dataset with only ones as values. 58 | * It needs little less memory and it might be marginally faster */ 59 | protected final class ALS_Unary(val R: UnaryMatrix, val factors: Int, val α: Float, val λ: Float, val iterations: Int) extends ALS[UnaryMatrix] { 60 | 61 | val rows = R.rows 62 | val cols = R.cols 63 | def transpose(R: UnaryMatrix) = R.t 64 | 65 | def sliceVec(m: UnaryMatrix, col: Int): SparseVector[Float] = { 66 | val index = m.data(col) 67 | val ones = new Array[Float](index.length) 68 | var i = 0; while (i < index.length) { ones(i) = 1.0f ; i += 1 } 69 | new SparseVector[Float](new breeze.collection.mutable.SparseArray[Float](index, ones, index.length, m.rows, 0.0f)) 70 | } 71 | 72 | private val lotOfOnes: Array[Float] = Array.fill(math.max(rows, cols))(1.0f) 73 | 74 | def copyAndPremultiply(Y: DenseMatrix[Float], R: UnaryMatrix, u: Int): (Array[Float], Int) = { 75 | val index = R.data(u) 76 | val activeSize = index.length 77 | val sqα = math.sqrt(α).toFloat 78 | 79 | val temp = new Array[Float](activeSize * factors) 
80 | for (f <- 0 until factors) { 81 | var offset = 0 82 | while (offset < activeSize) { 83 | temp(offset + f * activeSize) = (Y.data(index(offset) + f * Y.majorStride) * sqα) // no need to multiply by √value it's 1 anyway 84 | offset += 1 85 | } 86 | } 87 | 88 | (temp, activeSize) 89 | } 90 | 91 | } 92 | 93 | protected final class ALS_Full(val R: CSCMatrix[Float], val factors: Int, val α: Float, val λ: Float, val iterations: Int) extends ALS[CSCMatrix[Float]] { 94 | 95 | 96 | def transpose(R: CSCMatrix[Float]): CSCMatrix[Float] = R.t 97 | val cols = R.cols 98 | val rows = R.rows 99 | 100 | def sliceVec(m: CSCMatrix[Float], col: Int): SparseVector[Float] = { 101 | val start = m.colPtrs(col) 102 | val end = m.colPtrs(col+1) 103 | val data = Arrays.copyOfRange(m.data, start, end) 104 | val idxs = Arrays.copyOfRange(m.rowIndices, start, end) 105 | new SparseVector(idxs, data, m.rows) 106 | } 107 | 108 | def copyAndPremultiply(Y: DenseMatrix[Float], R: CSCMatrix[Float], u: Int): (Array[Float], Int) = { 109 | 110 | val ru = sliceVec(R, u) 111 | val activeSize = ru.activeSize 112 | val sqα = math.sqrt(α).toFloat 113 | 114 | 115 | var offset = 0 116 | val indexArr = ru.array.index 117 | val sqDataArr = new Array[Float](activeSize) 118 | while (offset < activeSize) { 119 | sqDataArr(offset) = math.sqrt(ru.valueAt(offset)).toFloat 120 | offset += 1 121 | } 122 | 123 | // copy sparse data for better cache locality and pre-multiply them by √α and √cu 124 | val temp = new Array[Float](activeSize * factors) 125 | for (f <- 0 until factors) { 126 | var offset = 0 127 | while (offset < activeSize) { 128 | //val index = ru.indexAt(offset) 129 | //val value = ru.valueAt(offset) 130 | //val sqValue = math.sqrt(value).toFloat 131 | val index = indexArr(offset) 132 | val sqValue = sqDataArr(offset) 133 | 134 | temp(offset + f * activeSize) = (Y.data(index + f * Y.majorStride) * sqα * sqValue) 135 | offset += 1 136 | } 137 | } 138 | 139 | (temp, ru.activeSize) 140 | } 141 | 142 | } 143 
| 144 | 145 | 146 | protected abstract class ALS[Dataset] { 147 | 148 | def rep[T](f: => T) = { println('rep) ; Iterator.continually(f).drop(1000000000).toSeq.head } 149 | 150 | def R: Dataset 151 | def factors: Int 152 | def α: Float 153 | def λ: Float 154 | def iterations: Int 155 | 156 | def transpose(R: Dataset): Dataset 157 | def cols: Int 158 | def rows: Int 159 | def sliceVec(m: Dataset, col: Int): SparseVector[Float] 160 | 161 | /** Returns the relevant portion of matrix Y, premultiplied by √α and √cu. 162 | * Yt * (Cu - I) * Y can be computed by multiplying the resulting matrix by its own transpose. */ 163 | def copyAndPremultiply(Y: DenseMatrix[Float], R: Dataset, u: Int): (Array[Float], Int) 164 | 165 | // R - user-item (users' ratings are in columns) 166 | // Rt - item-user 167 | def run = { 168 | print(s"users: $cols\nitems: $rows\n") 169 | val Rt = transpose(R) 170 | 171 | val X = DenseMatrix.ones[Float](rows = cols, cols = factors) // user factors 172 | val Y = DenseMatrix.ones[Float](rows = rows, cols = factors) // item factors 173 | 174 | // println((X.rows, X.cols)) 175 | // println((Y.rows, Y.cols)) 176 | // println((R.asInstanceOf[{def rows: Int}]rows, R.asInstanceOf[{def cols: Int}].cols)) 177 | // println((Rt.asInstanceOf[{def rows: Int}].rows, Rt.asInstanceOf[{def cols: Int}].cols)) 178 | 179 | for (iter <- 0 until iterations) { 180 | fix(R, X, Y, s"$iter - fit X") 181 | fix(Rt, Y, X, s"$iter - fit Y") 182 | } 183 | 184 | (X,Y) 185 | } 186 | 187 | def CuPu(Ru: SparseVector[Float]): Ru.type = { 188 | var offset = 0 189 | while (offset < Ru.activeSize) { 190 | Ru.data(offset) = Ru.data(offset) * α + 1.0f 191 | offset += 1 192 | } 193 | Ru 194 | } 195 | 196 | def fix(R: Dataset, X: DenseMatrix[Float], Y: DenseMatrix[Float], stage: String) = { 197 | println(stage) 198 | 199 | val s = System.nanoTime 200 | val YtY = Y.t * Y 201 | println("YtY "+(System.nanoTime-s)+"ns") 202 | 203 | for (u <- (0 until X.rows).par) { 204 | //val Ru = sliceVec(R, u) 
205 | //val YtCuPu = Y.t * Ru.mapActiveValues(v => 1 + v * α) 206 | 207 | val YtCuPu = Y.t * CuPu(sliceVec(R, u)) 208 | 209 | val m = mult(Y, R, u) 210 | m :+= YtY 211 | X(u, ::) := (invInPlace(m) * YtCuPu).t 212 | } 213 | } 214 | 215 | def invInPlace(m: DenseMatrix[Float]): m.type = { 216 | val invM = inv(dfm2ddm(m)) 217 | 218 | val src = invM.data 219 | val dest = m.data 220 | 221 | require( 222 | m.rows == invM.rows && 223 | m.cols == invM.cols && 224 | m.offset == invM.offset && 225 | m.majorStride == invM.majorStride && 226 | m.isTranspose == invM.isTranspose 227 | ) 228 | 229 | var i = 0 230 | while (i < src.length) { 231 | dest(i) = src(i).toFloat 232 | i += 1 233 | } 234 | 235 | m 236 | } 237 | 238 | 239 | def ddm2dfm(ddm: DenseMatrix[Double]): DenseMatrix[Float] = { 240 | val src: Array[Double] = ddm.data 241 | val arr = new Array[Float](src.length) 242 | var i = 0 243 | while (i < src.length) { 244 | arr(i) = src(i).toFloat 245 | i += 1 246 | } 247 | 248 | new DenseMatrix[Float](ddm.rows, ddm.cols, arr, ddm.offset, ddm.majorStride, ddm.isTranspose) 249 | } 250 | 251 | def dfm2ddm(ddm: DenseMatrix[Float]): DenseMatrix[Double] = { 252 | val src: Array[Float] = ddm.data 253 | val arr = new Array[Double](src.length) 254 | var i = 0 255 | while (i < src.length) { 256 | arr(i) = src(i).toDouble 257 | i += 1 258 | } 259 | 260 | new DenseMatrix[Double](ddm.rows, ddm.cols, arr, ddm.offset, ddm.majorStride, ddm.isTranspose) 261 | } 262 | 263 | 264 | 265 | def mult(Y: DenseMatrix[Float], R: Dataset, u: Int): DenseMatrix[Float] = { 266 | // Yt * Y + ... 267 | val res = DenseMatrix.zeros[Float](factors, factors) 268 | 269 | // ... + Yt * (Cu - I) * Y + ... 
/** Funk-style matrix factorization trained by stochastic gradient descent,
  * one latent feature at a time.
  *
  * Every variant takes the rating matrix as a sequence of sparse row vectors
  * and returns a pair `(rowFactors, colFactors)` of `rows x features` and
  * `cols x features` matrices.
  */
object FunkSVD {

  /** Straightforward reference implementation on breeze matrices.
    *
    * @param input      rating matrix, one sparse vector per row; must be non-empty
    * @param features   number of latent features
    * @param iterations number of passes over the data per feature
    */
  def apply(input: Seq[SparseVector[Double]], features: Int, iterations: Int): (DenseMatrix[Double], DenseMatrix[Double]) = {

    val m = input.size      // rows
    val n = input.head.size // cols
    val k = features
    val λ = 0.003 // learning rate
    val γ = 0.1   // regularization term

    // fresh iterator over all stored (rating, row, col) triples
    def R = for {
      row <- (0 until input.size).iterator
      (col, r) <- input(row).activeIterator
    } yield (r, row, col)

    val U = DenseMatrix.fill[Double](k, m)(0.1)
    val V = DenseMatrix.fill[Double](k, n)(0.1)

    val size = input.map(_.activeSize).sum
    println("size "+size)

    for (f <- 0 until k) {
      println("feature "+f)
      for (it <- 0 until iterations) {
        var err = 0.0d
        println("iteration "+it)
        for ((r, a, i) <- R) {
          val p = U(::, a) dot V(::, i)
          val ε = r - p

          err += ε * ε

          val du = λ * (ε * V(f, i) - γ * U(f, a))
          val dv = λ * (ε * U(f, a) - γ * V(f, i))

          U(f, a) += du
          V(f, i) += dv
        }
        println(sqrt(err/size))
      }
    }

    (U.t, V.t)
  }

  /** Array-based variant that caches, per rating, the part of the prediction
    * contributed by all features other than the one being trained.
    *
    * Training of a feature stops after `iterations` passes or as soon as the
    * RMSE improves by no more than 1e-6 between two passes.
    *
    * @param input        rating matrix, one sparse vector per row; must be non-empty
    * @param features     number of latent features
    * @param iterations   upper bound of passes per feature
    * @param learningRate gradient step size
    */
  def fast(input: Seq[SparseVector[Double]], features: Int, iterations: Int = Int.MaxValue, learningRate: Double = 0.003): (DenseMatrix[Double], DenseMatrix[Double]) = {

    val m = input.size      // rows
    val n = input.head.size // cols
    val k = features
    val λ = learningRate
    val γ = 0.1 // regularization term

    val totalSize = input.map(_.activeSize).sum
    println("totalSize "+totalSize)

    val rows      = new Array[Int](totalSize)
    val cols      = new Array[Int](totalSize)
    val ratings   = new Array[Double](totalSize)
    val residuals = new Array[Double](totalSize)

    var filled = 0
    for {
      row <- (0 until input.size).iterator
      (col, r) <- input(row).activeIterator
    } {
      rows(filled) = row
      cols(filled) = col
      ratings(filled) = r
      filled += 1
    }

    // both factor arrays are row major: the features of row `a` start at a*k
    val U = Array.fill[Double](k*m)(0.1)
    val V = Array.fill[Double](k*n)(0.1)

    def dot(a: Array[Double], b: Array[Double], ai: Int, bi: Int, len: Int): Double = {
      var i = 0
      var res = 0.0
      while (i < len) {
        res += a(ai*len+i) * b(bi*len+i)
        i += 1
      }
      res
    }

    def predict(j: Int, prod: Double): Double =
      residuals(j) + prod

    for (f <- 0 until k) {
      println("feature "+f)

      // prediction made by every feature except f; constant while f is trained
      var idx = 0
      while (idx < totalSize) {
        val (a, i) = (rows(idx), cols(idx))
        residuals(idx) = dot(U, V, a, i, k) - U(a*k+f) * V(i*k+f)
        idx += 1
      }

      var preverr = 99999999.0d
      var err = 0.0d
      var errdiff = Double.PositiveInfinity

      var it = 0
      while (it < iterations && errdiff > 0.000001) { // until convergence
        val start = System.currentTimeMillis

        err = 0.0d
        print("feature "+f+", iteration "+it)
        var j = 0
        while (j < totalSize) {
          val r = ratings(j)
          val a = rows(j)
          val i = cols(j)

          val u = U(a*k+f)
          val v = V(i*k+f)

          val p = predict(j, u * v)
          val ε = r - p

          err += ε * ε

          val du = λ * (ε * v - γ * u)
          val dv = λ * (ε * u - γ * v)

          U(a*k+f) += du
          V(i*k+f) += dv

          j += 1
        }
        val delta = System.currentTimeMillis - start
        println(" "+delta+"ms")

        // assign the outer variable: a local `val errdiff` here would shadow it
        // and the convergence test in the loop condition would never change
        errdiff = sqrt(preverr/totalSize) - sqrt(err/totalSize)
        println("err "+sqrt(err/totalSize)+" errdiff "+errdiff)
        preverr = err
        it += 1
      }
    }

    val UM = new DenseMatrix(m, k, U, offset = 0, majorStride = k, isTranspose = true)
    val VM = new DenseMatrix(n, k, V, offset = 0, majorStride = k, isTranspose = true)

    (UM, VM)
  }

  /** Splits rows and columns into `n` contiguous ranges each, of roughly
    * equal weight.
    *
    * NOTE(review): row weights are numbers of stored entries, but column
    * weights are sums of the stored values (`colCounts += vec`), not counts —
    * confirm this is intended.
    */
  def ballancedRanges(input: Seq[SparseVector[Double]], n: Int): (Seq[Range], Seq[Range]) = {
    val rowCounts = input.map(_.activeSize.toDouble)

    val colCounts = DenseVector.zeros[Double](input.head.size)
    for (vec <- input) {
      colCounts += vec
    }

    (ballance(rowCounts.toArray, n), ballance(colCounts.toArray, n))
  }

  /** Cuts the indices of `counts` into `n` consecutive ranges; range `i`
    * starts at the first index whose cumulative sum reaches `i/n` of the total. */
  def ballance(counts: Array[Double], n: Int) = {
    val sum = counts.sum
    val rangeSum = sum / n
    val cumsum = counts.scan(0.0)(_ + _).tail
    val borders = (0 until n map (i => cumsum.indexWhere(_ >= rangeSum * i))) :+ cumsum.length
    borders.sliding(2).toVector.map { case Seq(start, end) => start until end }
  }


  /** Ratings falling into one (row range, column range) tile, stored as
    * parallel arrays. */
  class Chunk(val size: Int) {
    val rows      = new Array[Int](size)
    val cols      = new Array[Int](size)
    val ratings   = new Array[Float](size)
    val residuals = new Array[Float](size)
  }

  /** Parallel Float variant. The rating matrix is cut into `threads x threads`
    * chunks; one pass runs `threads` stages and every stage processes, in
    * parallel, `threads` chunks that share neither a row range nor a column
    * range, so concurrent updates never touch the same factor.
    *
    * @param threads      number of row ranges and column ranges
    * @param input        rating matrix, one sparse vector per row; must be non-empty
    * @param features     number of latent features
    * @param iterations   upper bound of passes per feature
    * @param learningRate gradient step size
    */
  def par(threads: Int, input: Seq[SparseVector[Double]], features: Int, iterations: Int = Int.MaxValue, learningRate: Double = 0.003): (DenseMatrix[Float], DenseMatrix[Float]) = {

    val m = input.size      // rows
    val n = input.head.size // cols
    val k = features
    val λ = learningRate.toFloat
    val γ = 0.1f // regularization term

    val totalSize = input.map(_.activeSize).sum
    println("totalSize "+totalSize)

    val (rowRanges, colRanges) = ballancedRanges(input, threads)

    println(rowRanges map (r => (r.start, r.end)))
    println(colRanges map (r => (r.start, r.end)))

    val chunks: Array[Array[Chunk]] = (0 until threads).par.map { chunkRow =>
      Array.tabulate(threads) { chunkCol =>

        // yields (row, col, rating, position inside the chunk)
        def iterate: Iterator[(Int, Int, Float, Int)] = {
          var j = 0
          for {
            row <- rowRanges(chunkRow).iterator
            (col, r) <- input(row).activeIterator
            if colRanges(chunkCol).contains(col)
          } yield {
            j += 1
            (row, col, r.toFloat, j-1)
          }
        }

        val chunk = new Chunk(iterate.size)

        for ((row, col, r, j) <- iterate) {
          chunk.rows(j) = row
          chunk.cols(j) = col
          chunk.ratings(j) = r
        }

        chunk
      }
    }.toArray

    for (ch <- chunks; c <- ch) { println(c.size) }

    // Both factor arrays are feature major: feature f of row r lives at
    // r + m*f, feature f of column c at c + n*f.
    val U = Array.fill[Float](k*m)(0.1f)
    val V = Array.fill[Float](k*n)(0.1f)

    def dot(U: Array[Float], V: Array[Float], ui: Int, vi: Int, k: Int, n: Int): Float = {
      var i = 0
      var res = 0.0f
      while (i < k) {
        res += U(ui+m*i) * V(vi+n*i)
        i += 1
      }
      res
    }

    // Stage t pairs row range r with column range (r + t) % threads. The
    // modulus has to be the number of ranges; a fixed constant breaks the
    // schedule for every other thread count.
    val startChunkCoords = 0 until threads map (t => (t, t))
    val stages: Seq[Seq[(Int, Int)]] =
      0 until threads map { t => startChunkCoords map { case (row, col) => (row, (col + t) % threads) } }

    for (f <- 0 until k) {
      println("feature "+f)

      for (cs <- chunks.par) {
        for (chunk <- cs) {
          var j = 0
          while (j < chunk.size) {
            val row = chunk.rows(j)
            val col = chunk.cols(j)
            chunk.residuals(j) = dot(U, V, row, col, k, n) - U(row + m*f) * V(col + n*f)
            j += 1
          }
        }
      }

      var preverr = Double.MaxValue
      var err = 0.0d
      var errdiff = Double.MaxValue

      var it = 0
      while (it < iterations && errdiff > 1e-7) { // until convergence
        val start = System.currentTimeMillis
        print("feature "+f+", iteration "+it)

        err = 0.0d
        stages foreach { coords =>
          val xs = coords.par map { case (chunkRow, chunkCol) =>
            var chunkErr = 0.0d
            val chunk = chunks(chunkRow)(chunkCol)
            var j = 0
            while (j < chunk.size) {
              val r = chunk.ratings(j)
              val row = chunk.rows(j)
              val col = chunk.cols(j)

              val u = U(row + m*f)
              val v = V(col + n*f)

              val p = chunk.residuals(j) + u * v // predicted rating
              val ε = r - p

              chunkErr += ε * ε

              val du = λ * (ε * v - γ * u)
              val dv = λ * (ε * u - γ * v)

              U(row + m*f) += du
              V(col + n*f) += dv

              j += 1
            }
            chunkErr
          }
          err += xs.sum
        }

        val delta = System.currentTimeMillis - start
        println(" "+delta+"ms")

        errdiff = (sqrt(preverr/totalSize) - sqrt(err/totalSize)).toFloat
        println(s"prevErr: $preverr (${sqrt(preverr/totalSize)}), err: $err (${sqrt(err/totalSize)}), errdiff: $errdiff")
        preverr = err
        it += 1
      }
    }

    // The arrays were filled as `row + m*f`, i.e. column major with a stride
    // of the number of rows, so they must be wrapped as such.
    val UM = new DenseMatrix(m, k, U, offset = 0, majorStride = m, isTranspose = false)
    val VM = new DenseMatrix(n, k, V, offset = 0, majorStride = n, isTranspose = false)

    (UM, VM)
  }


  /** Wraps `f` into a thread that has not been started yet. */
  def thread(f: => Unit): Thread = {
    new Thread(new Runnable {
      def run(): Unit = f
    })
  }

  /** Runs `f` for every element of `xs` on its own thread and waits for all
    * of them to finish. */
  def runThreads[K](xs: Seq[K])(f: K => Unit) = {
    val ts = xs map { x => thread(f(x)) }
    ts.foreach(_.start())
    ts.foreach(_.join())
  }
}
/** Set specialized for int values that uses direct hashing.
  *
  * The home slot of a key is `k & (capacity - 1)`; collisions are resolved by
  * stepping to the next slot (linear probing, see `findIdx`).
  *
  * slots states: free (00) → occupied (10) → deleted (11)
  *
  * Everything lives in the single array `arr`: its first `bitmapWords` ints
  * hold two state bits per slot (bit `2*idx` = occupied, bit `2*idx+1` =
  * deleted) and the key of slot `idx` is stored at `arr(bitmapWords + idx)`.
  *
  * @param initialSize requested initial capacity; passed through
  *                    `Bits.higherPowerOfTwo` and never less than 8
  * @param loadFactor  fraction of slots that may be in use before the table
  *                    is rebuilt; must lie strictly between 0 and 1
  */
class IntSet(initialSize: Int = 16, loadFactor: Double = 0.5) {
  require(loadFactor > 0.0 && loadFactor < 1.0)

  // number of slots; findIdx masks with `capacity - 1`, which presumes a power of two
  private[this] var capacity = math.max(Bits.higherPowerOfTwo(initialSize), 8)
  // number of leading ints of `arr` reserved for the state bits
  private[this] var bitmapWords = getBitmapWords(capacity)
  // threshold of `filled` above which the table is rebuilt
  private[this] var maxSize = getMaxSize(capacity)
  private[this] var arr: Array[Int] = new Array[Int](capacity + bitmapWords)
  // slots that are not free, i.e. live keys plus slots marked as deleted
  private[this] var filled = 0
  // live keys only
  private[this] var _size = 0

  /** Number of keys currently in the set. */
  def size = _size

  /** Adds `k`; adding a key that is already present changes nothing. */
  def += (k: Int): this.type = {
    val i = findIdx(k)
    if (!isOccupied(arr, i)) {
      _size += 1
      filled += 1
    }
    setOccupied(arr, i)
    arr(bitmapWords + i) = k
    if (filled > maxSize) {
      // when at most half of the used slots hold live keys, rebuilding at the
      // same capacity is enough to get rid of the deleted slots
      if (size <= filled / 2) grow(1)
      else grow(2)
    }
    this
  }

  /** Adds every value produced by the cursor. */
  def ++= (cur: Cursor[Int]): this.type = {
    while (cur.moveNext) { this += cur.value }
    this
  }


  /** Removes `k`; removing an absent key changes nothing. */
  def -= (k: Int): this.type = {
    val i = findIdx(k)
    if (!isOccupied(arr, i)) return this
    // if the next slot is not occupied (and therefore also not deleted), delete directly
    if (!isOccupied(arr, (i + 1) & (capacity - 1))) {
      _size -= 1
      filled -= 1
      setUnoccupied(arr, i)
    } else {
      // a probe sequence may continue through this slot, so it is only marked
      // as deleted and keeps counting towards `filled`
      _size -= 1
      setDeleted(arr, i)
    }
    this
  }

  /** Tests whether `k` is in the set. */
  def contains(k: Int): Boolean = {
    val i = findIdx(k)
    isOccupied(arr, i) && !isDeleted(arr, i) && arr(bitmapWords + i) == k
  }

  /** Returns all keys in a new array, in slot order. */
  def toArray: Array[Int] = toArray(new Array[Int](size), 0)

  /** Writes all keys into `res` starting at index `off`, in slot order, and
    * returns `res`. The caller has to provide room for `size` elements. */
  def toArray(res: Array[Int], off: Int): Array[Int] = {
    var i, j = 0
    while (i < capacity) {
      if (isOccupied(arr, i) && !isDeleted(arr, i)) {
        res(j+off) = arr(bitmapWords + i)
        j += 1
      }
      i += 1
    }
    res
  }

  /** Removes all keys; the capacity stays as it is. */
  def clear() = {
    filled = 0
    _size = 0
    // zeroes the state bits and the stored keys alike
    var i = 0 ; while (i < arr.length) {
      arr(i) = 0
      i += 1
    }
  }

  // Get position of the first empty slot or slot containing value k.
  // Must never return deleted slot.
  private def findIdx(k: Int) = {
    val mask = capacity - 1
    var pos = k & mask
    var i = pos
    while (isDeleted(arr, i) | (isOccupied(arr, i) && arr(bitmapWords + i) != k)) {
      i = (i + 1) & mask
    }
    i
  }

  // two state bits per slot, rounded up to whole 32-bit words
  protected def getBitmapWords(capacity: Int) = ((capacity * 2) + 31) / 32
  protected def getMaxSize(capacity: Int) = (capacity * loadFactor).toInt

  // bit `bit` is stored in word `bit / 32` at position `bit % 32`
  def getBit(arr: Array[Int], bit: Int) = (arr(bit / 32) & (1 << (bit % 32))) != 0
  def setBit(arr: Array[Int], bit: Int) = arr(bit / 32) |= (1 << (bit % 32))
  def clrBit(arr: Array[Int], bit: Int) = arr(bit / 32) &= ~(1 << (bit % 32))

  // the occupied bit stays set for deleted slots (state 11), so a live slot is
  // one with isOccupied && !isDeleted
  private def isOccupied(arr: Array[Int], idx: Int) = getBit(arr, idx*2)
  private def setOccupied(arr: Array[Int], idx: Int) = setBit(arr, idx*2)
  private def setUnoccupied(arr: Array[Int], idx: Int) = clrBit(arr, idx*2)
  private def isDeleted(arr: Array[Int], idx: Int) = getBit(arr, idx*2+1)
  private def setDeleted(arr: Array[Int], idx: Int) = setBit(arr, idx*2+1)

  // Rebuilds the table with `factor` times the current capacity by
  // re-inserting every live key; slots marked as deleted are dropped.
  private def grow(factor: Int): Unit = {
    val oldCap = capacity
    val oldBmw = bitmapWords
    val oldArr = arr

    capacity *= factor
    bitmapWords = getBitmapWords(capacity)
    maxSize = getMaxSize(capacity)
    arr = new Array[Int](capacity + bitmapWords)
    // both counters are rebuilt by the `+=` calls below
    filled = 0
    _size = 0

    var i = 0
    while (i < oldCap) {
      if (isOccupied(oldArr, i) && !isDeleted(oldArr, i)) {
        val k = oldArr(oldBmw + i)
        this += k
      }
      i += 1
    }
  }
}
/** Frequency map intended for getting top-K elements from heavily skewed datasets.
  * It's precise if at most K elements have higher frequency than $freqThreshold.
  *
  * Keys and their frequencies sit in two parallel open-addressing arrays; a
  * slot whose frequency is 0 is free. On the side, `lowFreqs` counts how many
  * keys currently have each frequency below the threshold, with one extra
  * bucket for everything at or above it.
  *
  * @param initialSize   requested initial capacity; passed through
  *                      `Bits.higherPowerOfTwo` and never less than 16
  * @param loadFactor    fill ratio above which the table doubles; in (0, 1)
  * @param freqThreshold frequencies from here on (after `Bits.higherPowerOfTwo`)
  *                      share one bucket; must be greater than 1
  */
final class IntFreqMap(initialSize: Int = 32, loadFactor: Double = 0.3, freqThreshold: Int = 16) {

  require(loadFactor > 0.0 && loadFactor < 1.0)
  require(freqThreshold > 1)

  private[this] var capacity = math.max(Bits.higherPowerOfTwo(initialSize), 16)
  private[this] var _size = 0
  private[this] var maxSize = (capacity * loadFactor).toInt
  private[this] val realFreqThreshold = Bits.higherPowerOfTwo(freqThreshold)

  private[this] var keys: Array[Int] = new Array[Int](capacity)
  private[this] var freq: Array[Int] = new Array[Int](capacity) // frequency of corresponding key
  private[this] val lowFreqs: Array[Int] = new Array[Int](realFreqThreshold+1) // frequency of low frequencies

  /** Number of distinct keys. */
  def size = _size

  /** Raises the frequency of `k` by `count`, which has to be positive. */
  def += (k: Int, count: Int = 1): this.type = {
    assert(count > 0)

    val slot = findIdx(k)
    val before = freq(slot)
    val after = before + count

    if (before == 0) _size += 1

    keys(slot) = k
    freq(slot) = after
    updateLowFrequency(before, after)

    if (_size > maxSize) grow()
    this
  }


  /** Raises the frequency of every element of `ks` by `count`. */
  def ++= (ks: Array[Int], count: Int): this.type = {
    for (key <- ks) this.+=(key, count)
    this
  }

  /** Frequency of `k`.
    * @throws java.util.NoSuchElementException when `k` is not in the map */
  def get(k: Int): Int = {
    val found = freq(findIdx(k))
    if (found <= 0) throw new java.util.NoSuchElementException("key not found: "+k)
    found
  }

  /** Forgets all keys and counters; the capacity stays as it is. */
  def clear(): this.type = {
    this._size = 0
    java.util.Arrays.fill(keys, 0)
    java.util.Arrays.fill(freq, 0)
    java.util.Arrays.fill(lowFreqs, 0)
    this
  }

  /** Iterates over (key, frequency) pairs in slot order. */
  def iterator = keys.iterator.zip(freq.iterator).filter { case (_, f) => f > 0 }

  /** All (key, frequency) pairs in slot order. */
  def toArray: Array[(Int, Int)] = {
    val pairs = new Array[(Int, Int)](_size)
    var written = 0
    var slot = 0
    while (slot < capacity) {
      val f = freq(slot)
      if (f > 0) {
        pairs(written) = (keys(slot), f)
        written += 1
      }
      slot += 1
    }
    pairs
  }

  override def toString = iterator.mkString("IntFreqMap(", ",", ")")



  // one key leaves the bucket of its old frequency and enters the bucket of the new one
  private def updateLowFrequency(oldFreq: Int, newFreq: Int): Unit = {
    val mask = realFreqThreshold - 1
    def bucket(f: Int): Int = if (f < realFreqThreshold) f & mask else realFreqThreshold
    lowFreqs(bucket(oldFreq)) -= 1
    lowFreqs(bucket(newFreq)) += 1
  }

  // slot that holds `k`, or the first free slot of its probe sequence
  private def findIdx(k: Int) = {
    val mask = capacity - 1
    var slot = k & mask
    while (freq(slot) != 0 && keys(slot) != k) {
      slot = (slot + 1) & mask
    }
    slot
  }

  /** used to add elements into resized arrays in grow() method */
  private def growset(k: Int, count: Int): Unit = {
    val slot = findIdx(k)
    keys(slot) = k
    freq(slot) = count
  }

  // doubles the capacity and re-inserts every key; `_size` and `lowFreqs` are
  // not affected because no frequency changes
  private def grow(): Unit = {
    val previousKeys = keys
    val previousFreq = freq

    this.capacity *= 2
    this.maxSize = (this.capacity * loadFactor).toInt
    this.keys = new Array[Int](this.capacity)
    this.freq = new Array[Int](this.capacity)

    var slot = 0
    while (slot < previousKeys.length) {
      val f = previousFreq(slot)
      if (f > 0) this.growset(previousKeys(slot), f)
      slot += 1
    }
  }

  /** @return k most frequent elements, without corresponding frequency, not sorted */
  def topK(k: Int): Array[Int] = {
    require(k > 0)

    val wanted = math.min(k, size)
    val out = new Array[Int](wanted)

    // walk the buckets from the highest frequency down until they cover `wanted` keys
    var bucket = realFreqThreshold+1
    var covered = 0
    while (bucket > 1 && covered < wanted) {
      bucket -= 1
      covered += lowFreqs(bucket)
    }

    val smallestFreq = bucket
    val wholeFreq = bucket+1 // TODO: what if size of wholeFreq is bigger than realK, then method produces imprecise results

    var taken = 0
    // copies keys whose frequency passes `accept`, in slot order, until `out` is full
    def sweep(accept: Int => Boolean): Unit = {
      var slot = 0
      while (slot < freq.length && taken < out.length) {
        if (accept(freq(slot))) {
          out(taken) = keys(slot)
          taken += 1
        }
        slot += 1
      }
    }

    sweep(f => f >= wholeFreq)                      // keys that are in for sure
    sweep(f => f < wholeFreq && f >= smallestFreq)  // the rest comes from the boundary bucket

    out
  }

  /** Cursor over (key, frequency) pairs in slot order. */
  def cursor = new Cursor2[Int, Int] {
    private var pos = -1
    def moveNext() = {
      pos += 1
      while (pos < capacity && freq(pos) <= 0) pos += 1
      pos < capacity
    }
    def key = keys(pos)
    def value = freq(pos)
  }

}
object TopKIntIntEstimate {
  /** Shared pool of hash functions used by the convenience constructor. */
  protected[atrox] val hf = Array.tabulate[HashFunc[Int]](256)(i => HashFunc.random(i * 4747))
}


/* Probabilistic version of TopK data structure that produces only distinct elements.
 * It's based on ideas of cuckoo hashing and Robin Hood hashing.
 * - https://cs.uwaterloo.ca/research/tr/1986/CS-86-14.pdf
 * - http://www.sebastiansylvan.com/post/robin-hood-hashing-should-be-your-default-hash-table-implementation/
 * - https://www.pvk.ca/Blog/more_numerical_experiments_in_hashing.html
 *
 * Pairs are packed into one Long with the key in the high half, so comparing
 * two Longs compares keys first. The table has `kpow` slots (a power of two
 * not smaller than max(k, oversample)); an empty slot holds Long.MinValue.
 * Every pair has `numberOfFunctions` candidate slots. A bigger pair evicts a
 * smaller one, which then tries its own candidate slots.
 *
 * k                 – requested number of elements, must be positive
 * hf                – hash functions; the first `numberOfFunctions` are used
 * numberOfFunctions – candidate slots per pair, must be below `hf.length`
 * oversample        – lower bound of the table size
 */
class TopKIntIntEstimate(k: Int, hf: Array[HashFunc[Int]], numberOfFunctions: Int, oversample: Int) { self =>

  def this(k: Int, numberOfFunctions: Int, oversample: Int = 0) =
    this(k, TopKIntIntEstimate.hf, numberOfFunctions, oversample)

  val hashFunctions = numberOfFunctions

  require(k > 0)
  require(hashFunctions < hf.length)

  private[this] val kpow = Bits.higherPowerOfTwo(math.max(k, oversample))
  private[this] val arr = {
    val arr = new Array[Long](kpow)
    java.util.Arrays.fill(arr, Long.MinValue)
    arr
  }
  // smallest pair in the table once it is full, Long.MinValue before that
  private[this] var min = Long.MinValue
  private[this] var _size = 0

  // Slot of `pair` under the h-th hash function. The hash is an arbitrary Int
  // (the shared functions produce 32 random bits), so it has to be reduced to
  // the table size; without the mask nearly every lookup is out of bounds.
  private def f(h: Int, pair: Long) =
    hf(h)(key(pair) ^ value(pair)) & (kpow - 1)

  private def place(_pair: Long, h: Int): Unit = {
    var pair = _pair
    var i = 0 ; while (i < h) {
      val pos = f(i, pair)
      if (pair == arr(pos)) return // new value is the same as the value in array, filtering out duplicate
      if (pair > arr(pos)) { // new value is bigger than the old value, try to place the old value to another position
        val oldPair = arr(pos)
        arr(pos) = pair
        if (oldPair <= min) { // less then or equal comparison is there for handling array initialized to MinValue
          if (oldPair == Long.MinValue) { // replacing an empty slot
            _size += 1
          }

          // just removed current minimum, find a new one, this should be triggered rarely enough (1/kpow?)
          if (_size >= kpow) {
            min = findMin()
          }
          return // encountered smaller value, it's not necessary to try to place it somewhere else
        }
        // if index of hash function that hashes value to this position was encoded in the array slot,
        // it wouldn't be necessary to try every hash position all over again (TODO?)
        pair = oldPair
        i = 0 // start over with the evicted pair (manual tail recursion)
      } // else try next hash
      i += 1
    }
  }

  /** Key of the smallest stored pair once the table is full, Int.MinValue before that. */
  def keyThreshold = if (min == Long.MinValue) Int.MinValue else key(min)

  /** Number of non-empty slots. */
  def size = _size

  protected def findMin() = {
    var min = Long.MaxValue
    var i = 0; while (i < arr.length) {
      min = Math.min(min, arr(i))
      i += 1
    }
    min
  }

  /** Offers a pair; it is ignored unless it is bigger than the current minimum.
    * @throws IllegalArgumentException when the pair packs to Long.MinValue,
    *         which is reserved for empty slots */
  def add(key: Int, value: Int): Unit = {
    val pair = pack(key, value)
    if (pair == Long.MinValue) throw new IllegalArgumentException()
    if (pair > min) place(pair, hashFunctions)
  }

  /** Offers every pair stored in `tk`. */
  def addAll(tk: TopKIntIntEstimate): Unit = {
    val cur = tk.cursor
    while (cur.moveNext) { add(cur.key, cur.value) }
  }

  /** Values of all stored pairs in slot order. Despite the name, the table
    * itself is left untouched. */
  def drainToArray(): Array[Int] = {
    // TODO: heapify arr and return k items
    val buff = new collection.mutable.ArrayBuilder.ofInt
    buff.sizeHint(k)

    var i = 0 ; while (i < arr.length) {
      if (arr(i) != Long.MinValue) {
        buff += value(arr(i))
      }
      i += 1
    }
    buff.result
  }

  def cursor = rawCursor

  /** Cursor over the stored pairs in slot order. */
  def rawCursor = new Cursor2[Int, Int] {
    private var pos = -1
    def moveNext() = { do { pos += 1 } while (pos < arr.length && arr(pos) == Long.MinValue) ; pos < arr.length }
    def key = self.key(arr(pos))
    def value = self.value(arr(pos))
  }

  private def pack(key: Int, value: Int) = Bits.pack(key, value)
  private def key  (pair: Long): Int = Bits.unpackIntHi(pair)
  private def value(pair: Long): Int = Bits.unpackIntLo(pair)
}
136 | top += 1 137 | 138 | if (top == arr.length) { 139 | java.util.Arrays.sort(arr) 140 | } 141 | 142 | } else if (pair > arr(0)) { 143 | var i = 0 144 | while (i < k && arr(i) < pair) { i += 1 } 145 | 146 | if (i > arr.length) return 147 | if (i < arr.length && arr(i) == pair) return 148 | val end = i 149 | 150 | i = 1 151 | while (i < end) { arr(i-1) = arr(i) ; i += 1 } 152 | 153 | arr(end-1) = pair 154 | } 155 | 156 | def += (key: Float, value: Int) = insert(key, value) 157 | 158 | def ++= (tk: BruteForceTopKFloatInt): Unit = { 159 | var i = tk.top - 1; while (i >= 0) { 160 | insertPair(tk.arr(i)) 161 | i -= 1 162 | } 163 | } 164 | 165 | def minKey = key(arr(0)) 166 | 167 | def drainToArray(): Array[Int] = { 168 | val res = new Array[Int](size) 169 | var i = 0 ; while (i < size) { 170 | res(i) = value(arr(i)) 171 | i += 1 172 | } 173 | java.util.Arrays.fill(arr, Long.MinValue) 174 | res 175 | } 176 | 177 | def cursor = new Cursor2[Float, Int] { 178 | private var pos = -1 179 | def moveNext() = { pos += 1 ; pos < top } 180 | def key = self.key(arr(pos)) 181 | def value = self.value(arr(pos)) 182 | } 183 | 184 | def valuesCursor = cursor.asValues 185 | 186 | private def swap(arr: Array[Long], a: Int, b: Int) = { 187 | val tmp = arr(a) 188 | arr(a) = arr(b) 189 | arr(b) = tmp 190 | } 191 | 192 | private def pack(key: Float, value: Int) = Bits.packSortable(key, value) 193 | private def key (pair: Long): Float = Bits.unpackSortableFloatHi(pair) 194 | private def value (pair: Long): Int = Bits.unpackIntLo(pair) 195 | } 196 | */ 197 | 198 | 199 | class TopKIntInt(k: Int, distinct: Boolean = false) extends BaseMinIntIntHeap(k) { 200 | // private var valueSet: IntSet = if (distinct) new IntSet() else null 201 | protected var min = Int.MinValue 202 | 203 | /** returns value that was deleted or Int.MinValue */ 204 | def add(key: Int, value: Int): Unit = { 205 | if (size < k) { 206 | if (!distinct || !_containsValue(value)) { 207 | _insert(key, value) 208 | min = _minKey 
209 | // if (distinct) { 210 | // valueSet += value 211 | // } 212 | } 213 | 214 | } else if (key > min) { 215 | if (!distinct || !_containsValue(value)) { 216 | _deleteMinAndInsert(key, value) 217 | min = _minKey 218 | // if (distinct) { 219 | // valueSet -= _minValue 220 | // valueSet += value 221 | // } 222 | } 223 | 224 | } 225 | } 226 | 227 | /** Scanning the whole heap is faster than search in a auxiliary set for 228 | * reasonable small heaps (and consume dramatically less mememory). This is 229 | * caused by the fact that search in the auxiliary set needs two dependent 230 | * dereferences which most likely lead to cache misses. When scanning costs 231 | * more than those 2 misses, it's preferable to use the auxiliary set. */ 232 | protected def _containsValue(value: Int): Boolean = { 233 | var i = 0 ; while (i < top) { 234 | if (low(arr(i)) == value) return true 235 | i += 1 236 | } 237 | false 238 | } 239 | 240 | def addAll(tk: TopKIntInt): Unit = { 241 | // backwrds iterations because that way heap is filled by big values and 242 | // rest is filered out by `key > min` condition in the insert method 243 | var i = tk.top - 1; while (i >= 0) { 244 | val key = high(tk.arr(i)) 245 | val value = low(tk.arr(i)) 246 | add(key, value) 247 | i -= 1 248 | } 249 | } 250 | 251 | /** Return the content (the value part of key-value pair) of this heap sorted 252 | * by the key part. This collection is emptied. 
*/ 253 | def drainToArray(): Array[Int] = { 254 | val res = new Array[Int](size) 255 | var i = res.length-1 ; while (i >= 0) { 256 | res(i) = _minValue 257 | _deleteMin() 258 | i -= 1 259 | } 260 | res 261 | } 262 | 263 | def head: Int = _minValue 264 | def minKey: Int = _minKey 265 | def minValue: Int = _minValue 266 | 267 | override def toString = arr.drop(1).take(size).map(l => (high(l), low(l))).mkString("TopKIntInt(", ",", ")") 268 | 269 | def cursor = new Cursor2[Int, Int] { 270 | private var pos = -1 271 | def moveNext() = { pos += 1 ; pos < top } 272 | def key = high(arr(pos)) 273 | def value = low(arr(pos)) 274 | } 275 | 276 | def valuesCursor = cursor.asValues 277 | 278 | def drainCursorSortedAsc = new Cursor2[Int, Int] { 279 | private var k, v = 0 280 | def moveNext() = 281 | if (size == 0) false else { 282 | k = _minKey 283 | v = _minValue 284 | _deleteMin() 285 | true 286 | } 287 | 288 | def key = k 289 | def value = v 290 | } 291 | } 292 | 293 | 294 | class MinIntIntHeap(capacity: Int) extends BaseMinIntIntHeap(capacity) { 295 | def insert(key: Int, value: Int) = _insert(key, value) 296 | def minKey: Int = _minKey 297 | def minValue: Int = _minValue 298 | def deleteMin(): Unit = _deleteMin() 299 | def deleteMinAndInsert(key: Int, value: Int) = _deleteMinAndInsert(key, value) 300 | } 301 | 302 | object MinIntIntHeap { 303 | def builder(capacity: Int) = new MinIntIntHeapBuilder(capacity) 304 | } 305 | 306 | class MinIntIntHeapBuilder(capacity: Int) { 307 | private[this] var heap = new MinIntIntHeap(capacity) 308 | 309 | def insert(key: Int, value: Int) = 310 | heap._insertNoSwim(key, value) 311 | 312 | def result = { 313 | require(heap != null, "MinIntIntHeapBuilder cannot be reused") 314 | heap.makeHeap() 315 | val res = heap 316 | heap = null 317 | res 318 | } 319 | } 320 | 321 | // Both key and index are packed inside one Long value. 322 | // Key which is used for comparison forms high 4 bytes of said Long. 
// Both key and index are packed inside one Long value.
// Key which is used for comparison forms high 4 bytes of said Long.
/** Array-backed binary min-heap of (Int key, Int value) pairs.
  *
  * The root is at index 0 and the children of `pos` are at `2*pos+1` and
  * `2*pos+2`. Elements are compared as whole Longs, i.e. by key first.
  *
  * @param arr      backing array
  * @param capacity maximal number of elements
  */
abstract class BaseMinIntIntHeap protected (protected val arr: Array[Long], val capacity: Int) {

  def this(capacity: Int) = this(new Array[Long](capacity), capacity)

  // top points behind the last element
  protected var top = 0

  def size = top
  def isEmpty = top == (0)
  def nonEmpty = top != (0)

  /** Appends a pair without restoring the heap order; `makeHeap()` has to be
    * called before the heap is used. */
  protected[atrox] def _insertNoSwim(key: Int, value: Int): Unit = {
    arr(top) = pack(key, value)
    top += 1
  }

  protected def _insert(key: Int, value: Int): Unit = {
    arr(top) = pack(key, value)
    swim(top)
    top += 1
  }

  /** @throws NoSuchElementException when the heap is empty */
  protected def _minKey: Int = {
    if (top == 0) throw new NoSuchElementException
    high(arr(0))
  }
  /** @throws NoSuchElementException when the heap is empty */
  protected def _minValue: Int = {
    if (top == 0) throw new NoSuchElementException
    low(arr(0))
  }

  /** @throws NoSuchElementException when the heap is empty */
  protected def _deleteMin(): Unit = {
    if (top == 0) throw new NoSuchElementException("underflow")
    top -= 1
    swap(0, top)
    arr(top) = 0
    sink(0)
  }

  /** This method is equivalent to deleteMin() followed by insert(), but it's
    * more efficient.
    * @throws NoSuchElementException when the heap is empty */
  protected def _deleteMinAndInsert(key: Int, value: Int): Unit = {
    if (top == 0) throw new NoSuchElementException("underflow")
    arr(0) = pack(key, value)
    sink(0)
  }


  // `lo` must be masked: a plain `long | int` sign-extends the Int, so every
  // negative value would overwrite the key half with ones
  protected def pack(hi: Int, lo: Int): Long = (hi.toLong << 32) | (lo & 0xFFFFFFFFL)
  protected def high(x: Long): Int = (x >>> 32).toInt
  protected def low(x: Long): Int = x.toInt

  private def swap(a: Int, b: Int) = {
    val tmp = arr(a)
    arr(a) = arr(b)
    arr(b) = tmp
  }

  private def parent(pos: Int) = (pos - 1) / 2
  private def child(pos: Int) = pos * 2 + 1

  // moves value at the given position up towards the root
  private def swim(_pos: Int): Unit = {
    var pos = _pos
    while (pos > 0 && arr(parent(pos)) > arr(pos)) {
      swap(parent(pos), pos)
      pos = parent(pos)
    }
  }

  // moves value at the given position down towards leaves
  private def sink(_pos: Int): Unit = {
    val key = arr(_pos)
    var pos = _pos
    while (child(pos) < top) {
      var ch = child(pos)
      if ((ch+1) < top && arr(ch+1) < arr(ch)) ch += 1 // pick the smaller child
      if (key <= arr(ch)) {
        arr(pos) = key
        return
      }
      // pull the child up instead of swapping; `key` is written once at the end
      arr(pos) = arr(ch)
      pos = ch
    }
    arr(pos) = key
  }

  /** Establishes the heap order over the elements appended by `_insertNoSwim`. */
  protected[atrox] def makeHeap() = {
    // only positions that have a child among the `top` stored elements need sinking
    var i = top/2-1
    while (i >= 0) {
      sink(i)
      i -= 1
    }
  }

  // checks the heap property over the stored elements only; slots from `top`
  // on are not part of the heap
  private def isValidHeap =
    0 until top forall { i =>
      val ch = child(i)
      (ch >= top || arr(ch) >= arr(i)) &&
      (ch+1 >= top || arr(ch+1) >= arr(i))
    }

}
import java.lang.Double.doubleToRawLongBits
import atrox.Bits
import scala.reflect.ClassTag


/** Radix sort is a non-comparative sorting algorithm that has linear complexity
  * for fixed width integers. In practice it's much faster than
  * java.util.Arrays.sort for arrays larger than 1k. The only drawback is that the
  * implementation used here is not in-place and needs an auxiliary array that is as
  * big as the input to be sorted.
  *
  * Floats and doubles are ordered by their raw IEEE 754 bit patterns, so NaNs
  * with the sign bit set sort before all numbers and the remaining NaNs after
  * them. Arrays of up to 1024 elements are delegated to java.util.Arrays.sort,
  * which places every NaN last.
  *
  * Based on arcane knowledge of http://www.codercorner.com/RadixSortRevisited.htm
  */
object RadixSort {


  /** Turns per-byte histograms into write positions for every radix pass.
    *
    * `counts` and `offsets` hold `bytes` consecutive tables of 256 entries, one
    * table per byte position (least significant byte first). All positions
    * are relative to the start of the sorted range.
    *
    * @param length            number of counted elements, used for skip detection
    * @param dealWithNegatives lay out the most significant byte so that buckets
    *                          with the sign bit set (128-255) precede buckets 0-127
    * @param floats            sign-magnitude layout: negative buckets are placed
    *                          in reverse order and their offsets point at the LAST
    *                          slot of the bucket, so the caller must decrement
    * @param detectSkips       detect byte positions that hold the same value in
    *                          every element
    * @return bit mask in which bit `b` is set when the pass over byte `b` can be
    *         skipped
    */
  protected def computeOffsets(
    counts: Array[Int], offsets: Array[Int], bytes: Int, length: Int,
    dealWithNegatives: Boolean = true, floats: Boolean = false, detectSkips: Boolean = true
  ): Int = {

    var canSkip = 0

    // compute offsets/prefix sums
    var byte = 0
    while (byte < bytes) {
      val b256 = byte * 256

      offsets(b256 + 0) = 0

      var i = 1
      while (i < 256) {
        offsets(b256 + i) = counts(b256 + i-1) + offsets(b256 + i-1)
        i += 1
      }

      if (detectSkips) {
        // detect radices that can be skipped: one bucket holds every element
        var j = 0
        var skip = false
        while (j < 256 && (counts(b256 + j) == length || counts(b256 + j) == 0) && !skip) {
          skip = (counts(b256 + j) == length)
          j += 1
        }

        if (skip) {
          canSkip |= (1 << byte)
        }
      }

      byte += 1
    }

    val lastByte = bytes - 1
    val lb256 = lastByte * 256

    // deal with negative values
    if (dealWithNegatives) {
      var negativeValues = 0
      var n = 128
      while (n < 256) {
        negativeValues += counts(lb256 + n)
        n += 1
      }

      if (!floats) {

        // two's complement: buckets 128..255 (negative) first, then 0..127
        offsets(lb256 + 0) = negativeValues
        offsets(lb256 + 128) = 0

        var i = 1 ; while (i < 256) {
          val ii = i + 128
          val curr = ii % 256
          val prev = (ii - 1 + 256) % 256
          offsets(lb256 + curr) = counts(lb256 + prev) + offsets(lb256 + prev)
          i += 1
        }

      } else {

        // sign-magnitude: bucket 255 holds the most negative values and comes
        // first; every negative bucket is filled from its last slot downwards
        offsets(lb256 + 0) = negativeValues
        offsets(lb256 + 255) = counts(lb256 + 255) - 1

        var i = 1 ; while (i < 128) {
          offsets(lb256 + i) = offsets(lb256 + i - 1) + counts(lb256 + i - 1)
          i += 1
        }

        i = 254 ; while (i > 127) {
          offsets(lb256 + i) = offsets(lb256 + i + 1) + counts(lb256 + i)
          i -= 1
        }

        // The sign pass is what reverses the order of negative values, so it
        // must run even when every element shares the same top byte.
        if (negativeValues > 0) {
          canSkip &= ~(1 << lastByte)
        }

      }
    }

    canSkip
  }


  /** Decides which array is returned as the sorted one.
    *
    * `input` is the array that holds the sorted range after the last pass. When
    * the caller insists on getting the result in `arr` and it ended up in the
    * scratch array, only the range `[from, to)` is copied back, so data outside
    * of the range is preserved and a scratch array longer than `arr` is fine.
    *
    * @return (array with the sorted range, the other array)
    */
  protected def handleResults[T](arr: Array[T], input: Array[T], output: Array[T], returnResultInSourceArray: Boolean, from: Int, to: Int): (Array[T], Array[T]) = {
    assert(input != output)
    if (returnResultInSourceArray && !(input eq arr)) {
      // copy data into array that was passed as an argument to be sorted
      System.arraycopy(input, from, output, from, to - from)
      (output, input)
    } else {
      // return arrays as they are
      (input, output)
    }
  }


  /** Validates the element range `[from, to)` and the byte range `[fromByte, toByte)`. */
  protected def checkPreconditions[T](arr: Array[T], scratch: Array[T], from: Int, to: Int, fromByte: Int, toByte: Int, maxBytes: Int) = {
    require(to <= scratch.length)
    require(to <= arr.length)
    require(from >= 0)
    require(to >= 0)
    require(fromByte < toByte)
    require(fromByte >= 0 && fromByte < maxBytes)
    require(toByte > 0 && toByte <= maxBytes)
  }



  /** Sorts `arr` in place, allocating the scratch array internally. */
  def sort(arr: Array[Int]): Unit = {
    if (arr.length <= 1024) {
      Arrays.sort(arr)
    } else {
      sort(arr, new Array[Int](arr.length), 0, arr.length, 0, 4, true)
    }
  }

  def sort(arr: Array[Int], scratch: Array[Int]): (Array[Int], Array[Int]) =
    sort(arr, scratch, 0, arr.length, 0, 4, false)

  def sort(arr: Array[Int], scratch: Array[Int], returnResultInSourceArray: Boolean): (Array[Int], Array[Int]) =
    sort(arr, scratch, 0, arr.length, 0, 4, returnResultInSourceArray)

  /** Sorts `arr` array using `scratch` as temporary scratchpad. Returns both
    * arrays, first sorted, second in undefined state. Both returned arrays are the
    * same arrays passed as arguments but it's not specified which is which.
    *
    * If returnResultInSourceArray is set to true, the sorted array is the one
    * passed as argument to be sorted. In this case arrays cannot be swapped.
    *
    * Only elements in the range `[from, to)` are read and written. When the
    * result is returned in the scratch array, only that range of it is defined.
    *
    * from and fromByte are inclusive positions
    * to and toByte are exclusive positions
    */
  def sort(arr: Array[Int], scratch: Array[Int], from: Int, to: Int, fromByte: Int, toByte: Int, returnResultInSourceArray: Boolean): (Array[Int], Array[Int]) = {

    checkPreconditions(arr, scratch, from, to, fromByte, toByte, 4)

    if (from >= to) return (arr, scratch)

    var input = arr
    var output = scratch
    val counts = new Array[Int](4 * 256)
    val offsets = new Array[Int](4 * 256)
    var sorted = true
    var last = input(to - 1)

    // collect counts
    // This loop iterates backward because this way it brings beginning of the
    // `arr` array into a cache and that speeds up next iteration.
    var i = to - 1
    while (i >= from) {
      sorted &= last >= input(i)
      last = input(i)

      var byte = 0
      while (byte < 4) { // iterates through all 4 bytes on purpose, JVM unrolls this loop
        val c = (input(i) >>> (byte * 8)) & 0xff
        counts(byte * 256 + c) += 1
        byte += 1
      }
      i -= 1
    }

    if (sorted) return (input, output)

    // counts cover only [from, to), so that is the length a full bucket has
    val canSkip = computeOffsets(counts, offsets, 4, to - from)

    var byte = fromByte
    while (byte < toByte) {
      if ((canSkip & (1 << byte)) == 0) {

        val byteOffsets = Arrays.copyOfRange(offsets, byte * 256, byte * 256 + 256)

        var i = from
        while (i < to) {
          val c = (input(i) >>> (byte * 8)) & 0xff
          // offsets are relative to the start of the sorted range
          output(from + byteOffsets(c)) = input(i)
          byteOffsets(c) += 1
          i += 1
        }

        // swap input with output
        val tmp = input
        input = output
        output = tmp
      }

      byte += 1
    }

    handleResults(arr, input, output, returnResultInSourceArray, from, to)
  }



  /** Sorts `arr` in place, allocating the scratch array internally. */
  def sort(arr: Array[Long]): Unit = {
    if (arr.length <= 1024) {
      Arrays.sort(arr)
    } else {
      sort(arr, new Array[Long](arr.length), 0, arr.length, 0, 8, true)
    }
  }

  def sort(arr: Array[Long], scratch: Array[Long]): (Array[Long], Array[Long]) =
    sort(arr, scratch, 0, arr.length, 0, 8, false)

  def sort(arr: Array[Long], scratch: Array[Long], returnResultInSourceArray: Boolean): (Array[Long], Array[Long]) =
    sort(arr, scratch, 0, arr.length, 0, 8, returnResultInSourceArray)

  /** Long variant of the ranged sort, see the Int version for the contract. */
  def sort(arr: Array[Long], scratch: Array[Long], from: Int, to: Int, fromByte: Int, toByte: Int, returnResultInSourceArray: Boolean): (Array[Long], Array[Long]) = {

    checkPreconditions(arr, scratch, from, to, fromByte, toByte, 8)

    if (from >= to) return (arr, scratch)

    var input = arr
    var output = scratch
    val counts = new Array[Int](8 * 256)
    val offsets = new Array[Int](8 * 256)
    var sorted = true
    var last = input(to - 1)

    // collect counts
    // This loop iterates backward because this way it brings beginning of the
    // `arr` array into a cache and that speeds up next iteration.
    var i = to - 1
    while (i >= from) {
      sorted &= last >= input(i)
      last = input(i)

      var byte = 0
      while (byte < 8) {
        val c = ((input(i) >>> (byte * 8)) & 0xff).toInt
        counts(byte * 256 + c) += 1
        byte += 1
      }
      i -= 1
    }

    if (sorted) return (input, output)

    val canSkip = computeOffsets(counts, offsets, 8, to - from)

    var byte = fromByte
    while (byte < toByte) {
      if ((canSkip & (1 << byte)) == 0) {

        val byteOffsets = Arrays.copyOfRange(offsets, byte * 256, byte * 256 + 256)

        var i = from
        while (i < to) {
          val c = ((input(i) >>> (byte * 8)) & 0xff).toInt
          output(from + byteOffsets(c)) = input(i)
          byteOffsets(c) += 1
          i += 1
        }

        // swap input with output
        val tmp = input
        input = output
        output = tmp

      }
      byte += 1
    }

    handleResults(arr, input, output, returnResultInSourceArray, from, to)
  }



  /** Sorts `arr` in place, allocating the scratch array internally. */
  def sort(arr: Array[Float]): Unit = {
    if (arr.length <= 1024) {
      Arrays.sort(arr)
    } else {
      // a float has 4 bytes; passing 8 here failed the preconditions
      sort(arr, new Array[Float](arr.length), 0, arr.length, 0, 4, true)
    }
  }

  def sort(arr: Array[Float], scratch: Array[Float]): (Array[Float], Array[Float]) =
    sort(arr, scratch, 0, arr.length, 0, 4, false)

  def sort(arr: Array[Float], scratch: Array[Float], returnResultInSourceArray: Boolean): (Array[Float], Array[Float]) =
    sort(arr, scratch, 0, arr.length, 0, 4, returnResultInSourceArray)

  /** Float variant of the ranged sort, see the Int version for the contract.
    * Elements are ordered by their raw bit patterns (sign-magnitude).
    */
  def sort(arr: Array[Float], scratch: Array[Float], from: Int, to: Int, fromByte: Int, toByte: Int, returnResultInSourceArray: Boolean): (Array[Float], Array[Float]) = {

    checkPreconditions(arr, scratch, from, to, fromByte, toByte, 4)

    if (from >= to) return (arr, scratch)

    var input = arr
    var output = scratch
    val counts = new Array[Int](4 * 256)
    val offsets = new Array[Int](4 * 256)
    var sorted = true
    var last = input(to - 1)

    // collect counts
    // This loop iterates backward because this way it brings beginning of the
    // `arr` array into a cache and that speeds up next iteration.
    var i = to - 1
    while (i >= from) {
      sorted &= last >= input(i)
      last = input(i)

      var byte = 0
      while (byte < 4) {
        val c = (floatToRawIntBits(input(i)) >>> (byte * 8)) & 0xff
        counts(byte * 256 + c) += 1
        byte += 1
      }
      i -= 1
    }

    if (sorted) return (input, output)

    val canSkip = computeOffsets(counts, offsets, 4, to - from, floats = true)

    var byte = fromByte
    while (byte < toByte) {
      if ((canSkip & (1 << byte)) == 0) {

        val byteOffsets = Arrays.copyOfRange(offsets, byte * 256, byte * 256 + 256)

        var i = from
        while (i < to) {
          val c = (floatToRawIntBits(input(i)) >>> (byte * 8)) & 0xff
          output(from + byteOffsets(c)) = input(i)
          // In the sign byte, buckets 128..255 are filled backwards. The
          // direction is derived from the bucket itself, not from
          // `input(i) >= 0`, which disagrees with the sign bit for -0.0 and NaN.
          byteOffsets(c) += (if (byte < 3 || c < 128) 1 else -1)
          i += 1
        }

        // swap input with output
        val tmp = input
        input = output
        output = tmp
      }

      byte += 1
    }

    handleResults(arr, input, output, returnResultInSourceArray, from, to)
  }



  /** Sorts `arr` in place, allocating the scratch array internally. */
  def sort(arr: Array[Double]): Unit = {
    if (arr.length <= 1024) {
      Arrays.sort(arr)
    } else {
      sort(arr, new Array[Double](arr.length), 0, arr.length, 0, 8, true)
    }
  }

  def sort(arr: Array[Double], scratch: Array[Double]): (Array[Double], Array[Double]) =
    sort(arr, scratch, 0, arr.length, 0, 8, false)

  def sort(arr: Array[Double], scratch: Array[Double], returnResultInSourceArray: Boolean): (Array[Double], Array[Double]) =
    sort(arr, scratch, 0, arr.length, 0, 8, returnResultInSourceArray)

  /** Double variant of the ranged sort, see the Int version for the contract.
    * Elements are ordered by their raw bit patterns (sign-magnitude).
    */
  def sort(arr: Array[Double], scratch: Array[Double], from: Int, to: Int, fromByte: Int, toByte: Int, returnResultInSourceArray: Boolean): (Array[Double], Array[Double]) = {

    checkPreconditions(arr, scratch, from, to, fromByte, toByte, 8)

    if (from >= to) return (arr, scratch)

    var input = arr
    var output = scratch
    val counts = new Array[Int](8 * 256)
    val offsets = new Array[Int](8 * 256)
    var sorted = true
    var last = input(to - 1)

    // collect counts
    // This loop iterates backward because this way it brings beginning of the
    // `arr` array into a cache and that speeds up next iteration.
    var i = to - 1
    while (i >= from) {
      sorted &= last >= input(i)
      last = input(i)

      var byte = 0
      while (byte < 8) {
        val c = ((doubleToRawLongBits(input(i)) >>> (byte * 8)) & 0xff).toInt
        counts(byte * 256 + c) += 1
        byte += 1
      }
      i -= 1
    }

    if (sorted) return (input, output)

    val canSkip = computeOffsets(counts, offsets, 8, to - from, dealWithNegatives = true, floats = true, detectSkips = true)

    var byte = fromByte
    while (byte < toByte) {
      if ((canSkip & (1 << byte)) == 0) {

        val byteOffsets = Arrays.copyOfRange(offsets, byte * 256, byte * 256 + 256)

        var i = from
        while (i < to) {
          val c = ((doubleToRawLongBits(input(i)) >>> (byte * 8)) & 0xff).toInt
          output(from + byteOffsets(c)) = input(i)
          // see the Float variant: direction follows the sign bit of the bucket
          byteOffsets(c) += (if (byte < 7 || c < 128) 1 else -1)
          i += 1
        }

        // swap input with output
        val tmp = input
        input = output
        output = tmp
      }

      byte += 1
    }

    handleResults(arr, input, output, returnResultInSourceArray, from, to)
  }


  /** Returns a new array with elements of `arr` ordered by the Int key `f`.
    *
    * Every key is packed together with the element's index into one long and
    * only the key bytes (4 to 8) are radix sorted; elements are then picked in
    * the order of the indexes carried along in the sorted longs.
    */
  def sortedByInt[T <: AnyRef: ClassTag](arr: Array[T])(f: T => Int): Array[T] = {
    val pack = new Array[Long](arr.length)
    var i = 0 ; while (i < arr.length) {
      pack(i) = Bits.pack(f(arr(i)), i)
      i += 1
    }

    val (sorted, _) = RadixSort.sort(pack, new Array[Long](arr.length), 0, arr.length, 4, 8, false)

    val res = new Array[T](arr.length)
    i = 0 ; while (i < arr.length) {
      res(i) = arr(Bits.unpackIntLo(sorted(i)))
      i += 1
    }

    res
  }

  /** In-place version of `sortedByInt`: the sorted copy is written back into `arr`. */
  def sortByInt[T <: AnyRef: ClassTag](arr: Array[T])(f: T => Int): Unit = {
    val res = sortedByInt(arr)(f)
    System.arraycopy(res, 0, arr, 0, arr.length)
  }
}
--------------------------------------------------------------------------------
/src/main/scala/Sketch.scala:
--------------------------------------------------------------------------------
package atrox.sketch

import java.lang.System.arraycopy
import java.lang.Long.{ bitCount, rotateLeft }
import java.util.Arrays
import atrox.Bits


// Sketcher: one locality sensitive hash function
// Sketchers: collection of locality sensitive hash functions
// Sketching: bundle of Sketchers and items to be sensitively hashed
// Sketch: materialized table of sketch arrays
//
// Data in a BitSketch must be 8 byte aligned. SketchLength may not be a multiple
// of 64, but every sketch must start in a new long field.
16 | 17 | 18 | case class SketchCfg( 19 | maxResults: Int = Int.MaxValue, 20 | orderByEstimate: Boolean = false, 21 | compact: Boolean = true, 22 | parallel: Boolean = false 23 | ) 24 | 25 | 26 | 27 | trait IntSketcher[-T] extends (T => Int) { 28 | /** reduces one item to one component of sketch */ 29 | def apply(item: T): Int = multi(item).hash 30 | def multi(item: T): IntMulti 31 | } 32 | 33 | case class IntMulti( 34 | hash: Int, 35 | /** cost of flipping, smaller value means this component is more borderline and flipping can yield more suitable candidates */ 36 | cost: Double, 37 | neighbour: Int 38 | ) 39 | 40 | trait BitSketcher[-T] extends (T => Boolean) { 41 | /** reduces one item to one component of sketch */ 42 | def apply(item: T): Boolean = multi(item).hash 43 | def multi(item: T): BitMulti 44 | } 45 | 46 | case class BitMulti( 47 | hash: Boolean, 48 | cost: Double 49 | ) 50 | 51 | 52 | trait Sketchers[T, SketchArray] { self => 53 | def sketchLength: Int 54 | def estimator: Estimator[SketchArray] 55 | def rank: Option[IndexedSeq[T] => Rank[T, T]] 56 | def uniformCost: Boolean 57 | 58 | def getSketchFragment(item: T): SketchArray 59 | def getSketchMultiFragment(item: T): MultiFragment[SketchArray] 60 | } 61 | 62 | case class MultiFragment[SketchArray]( 63 | hashes: SketchArray, 64 | costs: Array[Double], 65 | neighbours: SketchArray 66 | ) 67 | 68 | object Sketchers { 69 | def apply[T](sketchers: Array[IntSketcher[T]], estimator: IntEstimator, rank: Option[IndexedSeq[T] => Rank[T, T]], uniformCost: Boolean) = 70 | IntSketchersOf(sketchers, estimator, rank, uniformCost) 71 | def apply[T](n: Int, mk: Int => IntSketcher[T], estimator: IntEstimator, rank: Option[IndexedSeq[T] => Rank[T, T]], uniformCost: Boolean) = 72 | IntSketchersOf(Array.tabulate(n)(mk), estimator, rank, uniformCost) 73 | 74 | def apply[T](sketchers: Array[BitSketcher[T]], estimator: BitEstimator, rank: Option[IndexedSeq[T] => Rank[T, T]], uniformCost: Boolean) = 75 | 
BitSketchersOf(sketchers, estimator, rank, uniformCost) 76 | def apply[T](n: Int, mk: Int => BitSketcher[T], estimator: BitEstimator, rank: Option[IndexedSeq[T] => Rank[T, T]], uniformCost: Boolean) = 77 | BitSketchersOf(Array.tabulate(n)(mk), estimator, rank, uniformCost) 78 | } 79 | 80 | trait IntSketchers[T] extends Sketchers[T, Array[Int]] 81 | trait BitSketchers[T] extends Sketchers[T, Array[Long]] 82 | 83 | case class IntSketchersOf[T]( 84 | sketchers: Array[IntSketcher[T]], 85 | estimator: IntEstimator, 86 | rank: Option[IndexedSeq[T] => Rank[T, T]], 87 | uniformCost: Boolean 88 | ) extends IntSketchers[T] { 89 | 90 | val sketchLength = sketchers.length 91 | 92 | def getSketchFragment(item: T): Array[Int] = { 93 | val res = new Array[Int](sketchLength) 94 | var i = 0 ; while (i < sketchLength) { 95 | res(i) = sketchers(i)(item) 96 | i += 1 97 | } 98 | res 99 | } 100 | 101 | def getSketchMultiFragment(item: T): MultiFragment[Array[Int]] = { 102 | val hashes = new Array[Int](sketchLength) 103 | val costs = if (uniformCost) null else new Array[Double](sketchLength) 104 | val neighbours = new Array[Int](sketchLength) 105 | 106 | var i = 0 ; while (i < sketchLength) { 107 | val m = sketchers(i).multi(item) 108 | hashes(i) = m.hash 109 | if (!uniformCost) { 110 | costs(i) = m.cost 111 | } 112 | neighbours(i) = m.neighbour 113 | i += 1 114 | } 115 | 116 | MultiFragment(hashes, costs, neighbours) 117 | } 118 | } 119 | 120 | case class BitSketchersOf[T]( 121 | sketchers: Array[BitSketcher[T]], 122 | estimator: BitEstimator, 123 | rank: Option[IndexedSeq[T] => Rank[T, T]], 124 | uniformCost: Boolean 125 | ) extends BitSketchers[T] { 126 | 127 | val sketchLength = sketchers.length 128 | 129 | def getSketchFragment(item: T): Array[Long] = { 130 | val res = new Array[Long]((sketchLength+63)/64) 131 | var i = 0 ; while (i < sketchLength) { 132 | val s = sketchers(i)(item) 133 | val bit = if (s) 1L else 0L 134 | res(i / 64) |= (bit << (i % 64)) 135 | i += 1 136 | } 137 | 
res 138 | } 139 | 140 | def getSketchMultiFragment(item: T): MultiFragment[Array[Long]] = { 141 | val hashes = new Array[Long]((sketchLength+63)/64) 142 | val costs = new Array[Double](sketchLength) 143 | 144 | var i = 0 ; while (i < sketchLength) { 145 | val m = sketchers(i).multi(item) 146 | val bit = (if (m.hash) 1L else 0L) 147 | hashes(i / 64) |= (bit << (i % 64)) 148 | costs(i) = m.cost 149 | i += 1 150 | } 151 | MultiFragment(hashes, costs, null) 152 | } 153 | 154 | } 155 | 156 | 157 | object Sketching { 158 | type IntSketching[T] = Sketching[T, Array[Int]] 159 | type BitSketching[T] = Sketching[T, Array[Long]] 160 | } 161 | 162 | sealed abstract class Sketching[T, SketchArray] { self => 163 | def itemsCount: Int 164 | def sketchLength: Int 165 | def estimator: Estimator[SketchArray] 166 | def sketchers: Sketchers[T, SketchArray] 167 | 168 | def getSketchFragment(itemIdx: Int): SketchArray 169 | } 170 | 171 | case class SketchingOf[T, SketchArray]( 172 | items: IndexedSeq[T], 173 | sketchers: Sketchers[T, SketchArray] 174 | ) extends Sketching[T, SketchArray] { 175 | 176 | val sketchLength = sketchers.sketchLength 177 | val itemsCount = items.length 178 | val estimator = sketchers.estimator 179 | 180 | def getSketchFragment(itemIdx: Int): SketchArray = 181 | sketchers.getSketchFragment(items(itemIdx)) 182 | } 183 | 184 | 185 | 186 | object Sketch { 187 | def apply[T](items: Seq[T], sk: IntSketchers[T]) = IntSketch(items, sk) 188 | def apply[T](items: Seq[T], sk: BitSketchers[T]) = BitSketch(items, sk) 189 | def apply(items: Array[Long], sk: BitSketchers[Nothing]) = BitSketch[Nothing](items, sk) 190 | 191 | def apply[T, SketchArray](items: Seq[T], sk: Sketchers[T, SketchArray]): Sketch[T, SketchArray] = sk match { 192 | case sk: IntSketchers[T @unchecked] => IntSketch(items, sk) 193 | case sk: BitSketchers[T @unchecked] => BitSketch(items, sk) 194 | } 195 | } 196 | 197 | sealed abstract class Sketch[T, SketchArray] extends Sketching[T, SketchArray] with 
Serializable { 198 | 199 | type Idxs = Array[Int] 200 | type SimFun = (Int, Int) => Double 201 | 202 | def sketchArray: SketchArray 203 | def itemsCount: Int 204 | 205 | def withConfig(cfg: SketchCfg): Sketch[T, SketchArray] 206 | 207 | def estimator: Estimator[SketchArray] 208 | def cfg: SketchCfg 209 | def sameBits(idxA: Int, idxB: Int): Int = estimator.sameBits(sketchArray, idxA, sketchArray, idxB) 210 | def estimateSimilarity(idxA: Int, idxB: Int): Double = estimator.estimateSimilarity(sameBits(idxA, idxB)) 211 | def minSameBits(sim: Double): Int = estimator.minSameBits(sim) 212 | 213 | def similarIndexes(idx: Int, minEst: Double): Idxs = similarIndexes(idx, minEst, 0.0, null) 214 | def similarIndexes(idx: Int, minEst: Double, minSim: Double, f: SimFun): Idxs = { 215 | val minBits = estimator.minSameBits(minEst) 216 | val res = new collection.mutable.ArrayBuilder.ofInt 217 | var i = 0 ; while (i < itemsCount) { 218 | val bits = sameBits(idx, i) 219 | if (bits >= minBits && idx != i && (f == null || f(i, idx) >= minSim)) { 220 | res += i 221 | } 222 | i += 1 223 | } 224 | res.result 225 | } 226 | 227 | def similarItems(idx: Int, minEst: Double): Iterator[Sim] = similarItems(idx, minEst, 0.0, null) 228 | def similarItems(idx: Int, minEst: Double, minSim: Double, f: SimFun): Iterator[Sim] = { 229 | val minBits = estimator.minSameBits(minEst) 230 | val res = new collection.mutable.ArrayBuffer[Sim] 231 | var i = 0 ; while (i < itemsCount) { 232 | val bits = sameBits(idx, i) 233 | var sim: Double = 0.0 234 | if (bits >= minBits && idx != i && (f == null || { sim = f(i, idx) ; sim >= minSim })) { 235 | res += Sim(i, if (f != null) sim else estimator.estimateSimilarity(bits)) 236 | } 237 | i += 1 238 | } 239 | res.iterator 240 | } 241 | 242 | /* 243 | def allSimilarIndexes(minEst: Double, minSim: Double, f: SimFun): Iterator[(Int, Idxs)] = { 244 | val minBits = estimator.minSameBits(minEst) 245 | 246 | (cfg.compact, cfg.parallel) match { 247 | // this needs in the 
worst case O(n * s / 4) additional space, where n is the 248 | // number of items in this sketch and s is average number of similar 249 | // indexes 250 | case (false, false) => 251 | val stripeSize = 64 252 | val res = Array.fill(itemsCount)(IndexResultBuilder.make(false, cfg.maxResults) 253 | 254 | Iterator.range(0, itemsCount, step = stripeSize) flatMap { start => 255 | 256 | stripeRun(stripeSize, start, itemsCount, minBits, minSim, f, true, new Op { 257 | def apply(i: Int, j: Int, est: Double, sim: Double): Unit = { 258 | res(i) += (j, sim) 259 | res(j) += (i, sim) 260 | } 261 | }) 262 | 263 | val endi = math.min(start + stripeSize, itemsCount) 264 | Iterator.range(start, endi) map { i => 265 | val arr = res(i).result 266 | res(i) = null 267 | (i, arr) 268 | } filter { _._2.nonEmpty } 269 | } 270 | 271 | // this needs no additional memory but it have to do full n**2 iterations 272 | case (_, _) => 273 | //Iterator.tabulate(itemsCount) { idx => (idx, similarIndexes(idx, minEst, minSim, f)) } 274 | 275 | val stripeSize = 64 276 | val stripesInParallel = if (cfg.parallel) 16 else 1 277 | println(s"copmact, stripesInParallel $stripesInParallel") 278 | 279 | Iterator.range(0, itemsCount, step = stripeSize * stripesInParallel) flatMap { pti => 280 | val res = Array.fill(stripeSize * stripesInParallel)(IndexResultBuilder.make(false, cfg.maxResults) 281 | 282 | parStripes(stripesInParallel, stripeSize, pti, itemsCount) { start => 283 | stripeRun(stripeSize, start, itemsCount, minBits, minSim, f, false, new Op { 284 | def apply(i: Int, j: Int, est: Double, sim: Double): Unit = { 285 | res(i-pti) += (j, sim) 286 | } 287 | }) 288 | } 289 | 290 | Iterator.tabulate(res.length) { i => (i + pti, res(i).result) } filter { _._2.nonEmpty } 291 | } 292 | } 293 | } 294 | def allSimilarIndexes(minEst: Double): Iterator[(Int, Idxs)] = 295 | allSimilarIndexes(minEst, 0.0, null) 296 | def allSimilarIndexes: Iterator[(Int, Idxs)] = 297 | allSimilarIndexes(0.0, 0.0, null) 298 | 299 | 
def allSimilarItems(minEst: Double, minSim: Double, f: SimFun): Iterator[(Int, Iterator[Sim])] = { 300 | val ff = if (cfg.orderByEstimate) null else f 301 | allSimilarIndexes(minEst, minSim, ff) map { case (idx, simIdxs) => (idx, indexesToSims(idx, simIdxs, f, sketchArray)) } 302 | } 303 | def allSimilarItems(minEst: Double): Iterator[(Int, Iterator[Sim])] = 304 | allSimilarItems(minEst, 0.0, null) 305 | def allSimilarItems: Iterator[(Int, Iterator[Sim])] = 306 | allSimilarItems(0.0, 0.0, null) 307 | 308 | // === internal cruft === 309 | 310 | protected def parStripes(stripesInParallel: Int, stripeSize: Int, base: Int, length: Int)(f: (Int) => Unit): Unit = { 311 | if (stripesInParallel > 1) { 312 | val end = math.min(base + stripeSize * stripesInParallel, length) 313 | (base until end by stripeSize).par foreach { b => 314 | f(b) 315 | } 316 | } else { 317 | f(base) 318 | } 319 | } 320 | 321 | // striping/tiling leads to better cache usage patterns and that subsequently leads to better performance 322 | protected def stripeRun(stripeSize: Int, stripeBase: Int, length: Int, minBits: Int, minSim: Double, f: SimFun, half: Boolean, op: Op): Unit = { 323 | val endi = math.min(stripeBase + stripeSize, length) 324 | val startj = if (!half) 0 else stripeBase + 1 325 | 326 | var j = startj ; while (j < length) { 327 | val realendi = if (!half) endi else math.min(j, endi) 328 | var i = stripeBase ; while (i < realendi) { 329 | val bits = sameBits(i, j) 330 | var sim = 0.0 331 | if (bits >= minBits && i != j && (f == null || { sim = f(i, j) ; sim >= minSim })) { 332 | val est = estimator.estimateSimilarity(bits) 333 | op.apply(i, j, est, if (f == null) est else sim) 334 | } 335 | i += 1 336 | } 337 | j += 1 338 | } 339 | } 340 | */ 341 | 342 | protected abstract class Op { def apply(thisIdx: Int, thatIdx: Int, est: Double, sim: Double): Unit } 343 | 344 | protected def indexesToSims(idx: Int, simIdxs: Idxs, f: SimFun, sketch: SketchArray) = 345 | simIdxs.iterator.map { simIdx 
=> 346 | val est = estimator.estimateSimilarity(sketch, idx, sketch, simIdx) 347 | Sim(simIdx, if (f != null) f(idx, simIdx) else est) 348 | } 349 | } 350 | 351 | trait Estimator[SketchArray] { 352 | def sketchLength: Int 353 | 354 | def minSameBits(sim: Double): Int 355 | def estimateSimilarity(sameBits: Int): Double 356 | 357 | def sameBits(arrA: SketchArray, idxA: Int, arrB: SketchArray, idxB: Int): Int 358 | def estimateSimilarity(arrA: SketchArray, idxA: Int, arrB: SketchArray, idxB: Int): Double = 359 | estimateSimilarity(sameBits(arrA, idxA, arrB, idxB)) 360 | } 361 | 362 | 363 | trait IntEstimator extends Estimator[Array[Int]] { 364 | def sameBits(arrA: Array[Int], idxA: Int, arrB: Array[Int], idxB: Int) = { 365 | var a = idxA * sketchLength 366 | var b = idxB * sketchLength 367 | var same = 0 368 | val end = a + sketchLength 369 | while (a < end) { 370 | same += (if (arrA(a) == arrB(b)) 1 else 0) 371 | a += 1 372 | b += 1 373 | } 374 | same 375 | } 376 | } 377 | 378 | trait BitEstimator extends Estimator[Array[Long]] { 379 | def sameBits(arrA: Array[Long], idxA: Int, arrB: Array[Long], idxB: Int) = { 380 | val longsLen = (sketchLength+63) / 64 // assumes sketch is Long aligned 381 | val a = idxA * longsLen 382 | val b = idxB * longsLen 383 | var i = 0 384 | var same = sketchLength 385 | while (i < longsLen) { 386 | same -= bitCount(arrA(a+i) ^ arrB(b+i)) 387 | i += 1 388 | } 389 | same 390 | } 391 | } 392 | 393 | 394 | trait BitEstimator64 extends BitEstimator { 395 | override def sameBits(arrA: Array[Long], idxA: Int, arrB: Array[Long], idxB: Int): Int = { 396 | 64 - bitCount(arrA(idxA) ^ arrB(idxB)) 397 | } 398 | } 399 | 400 | trait BitEstimator128 extends BitEstimator { 401 | override def sameBits(arrA: Array[Long], idxA: Int, arrB: Array[Long], idxB: Int): Int = { 402 | var same = 128 403 | same -= bitCount(arrA(idxA*2) ^ arrB(idxB*2)) 404 | same -= bitCount(arrA(idxA*2+1) ^ arrB(idxB*2+1)) 405 | same 406 | } 407 | } 408 | 409 | trait BitEstimator256 
extends BitEstimator { 410 | override def sameBits(arrA: Array[Long], idxA: Int, arrB: Array[Long], idxB: Int): Int = { 411 | var same = 256 412 | same -= bitCount(arrA(idxA*4) ^ arrB(idxB*4)) 413 | same -= bitCount(arrA(idxA*4+1) ^ arrB(idxB*4+1)) 414 | same -= bitCount(arrA(idxA*4+2) ^ arrB(idxB*4+2)) 415 | same -= bitCount(arrA(idxA*4+3) ^ arrB(idxB*4+3)) 416 | same 417 | } 418 | } 419 | 420 | 421 | 422 | 423 | 424 | 425 | object IntSketch { 426 | def makeSketchArray[T](sk: IntSketchers[T], items: Seq[T]): Array[Int] = { 427 | val sketchArray = new Array[Int](items.length * sk.sketchLength) 428 | for ((item, itemIdx) <- items.iterator.zipWithIndex) { 429 | val arr = sk.getSketchFragment(item) 430 | System.arraycopy(arr, 0, sketchArray, itemIdx * sk.sketchLength, sk.sketchLength) 431 | } 432 | sketchArray 433 | } 434 | 435 | def apply[T](items: Seq[T], sk: IntSketchers[T]): IntSketch[T] = 436 | IntSketch(makeSketchArray(sk, items), sk) 437 | } 438 | 439 | 440 | case class IntSketch[T]( 441 | sketchArray: Array[Int], 442 | sketchers: Sketchers[T, Array[Int]], 443 | cfg: SketchCfg = SketchCfg() 444 | ) extends Sketch[T, Array[Int]] { 445 | 446 | val sketchLength = sketchers.sketchLength 447 | val itemsCount = sketchArray.length / sketchLength 448 | val estimator = sketchers.estimator 449 | 450 | def withConfig(_cfg: SketchCfg): IntSketch[T] = copy(cfg = _cfg) 451 | 452 | def getSketchFragment(itemIdx: Int): Array[Int] = 453 | Arrays.copyOfRange(sketchArray, itemIdx * sketchLength, (itemIdx+1) * sketchLength) 454 | 455 | def sketeches: Iterator[Array[Int]] = 456 | Iterator.tabulate(itemsCount) { i => getSketchFragment(i) } 457 | } 458 | 459 | 460 | 461 | object BitSketch { 462 | def makeSketchArray[T](sk: BitSketchers[T], items: Seq[T]): Array[Long] = { 463 | val longsLen = (sk.sketchLength+63) / 64 // assumes sketch is Long aligned 464 | val sketchArray = new Array[Long](items.length * longsLen) 465 | for ((item, itemIdx) <- items.iterator.zipWithIndex) { 466 | val 
arr = sk.getSketchFragment(item) 467 | System.arraycopy(arr, 0, sketchArray, itemIdx * longsLen, arr.length) 468 | } 469 | sketchArray 470 | } 471 | 472 | def apply[T](items: Seq[T], sk: BitSketchers[T]): BitSketch[T] = 473 | BitSketch(makeSketchArray(sk, items), sk) 474 | } 475 | 476 | 477 | case class BitSketch[T]( 478 | sketchArray: Array[Long], 479 | sketchers: Sketchers[T, Array[Long]], 480 | cfg: SketchCfg = SketchCfg() 481 | ) extends Sketch[T, Array[Long]] { 482 | 483 | val sketchLength = sketchers.sketchLength 484 | val itemsCount = sketchArray.length * 64 / sketchLength 485 | val estimator = sketchers.estimator 486 | val bitsPerSketch = sketchLength 487 | 488 | def withConfig(_cfg: SketchCfg): BitSketch[T] = copy(cfg = _cfg) 489 | 490 | def getSketchFragment(itemIdx: Int): Array[Long] = 491 | Bits.getBits(sketchArray, itemIdx * bitsPerSketch, (itemIdx+1) * bitsPerSketch) 492 | } 493 | -------------------------------------------------------------------------------- /src/main/scala/SketchImpls.scala: -------------------------------------------------------------------------------- 1 | package atrox.sketch 2 | 3 | import scala.language.postfixOps 4 | import breeze.linalg.{ SparseVector, DenseVector, DenseMatrix, BitVector, normalize, Vector => bVector, operators, norm } 5 | import breeze.stats.distributions.Rand 6 | import atrox.fastSparse 7 | 8 | 9 | 10 | object MinHash { 11 | 12 | def apply[T](hashFunctions: Int)(implicit mk: MinHashImpl[T, Any]): IntSketchers[T] = 13 | weighted(hashFunctions, (a: Any) => 1) 14 | 15 | /** based on https://www.sumologic.com/2015/10/22/rapid-similarity-search-with-weighted-min-hash/ */ 16 | def weighted[T, El](hashFunctions: Int, weights: El => Int)(implicit mk: MinHashImpl[T, El]): IntSketchers[T] = 17 | Sketchers(hashFunctions, (i: Int) => mk(HashFunc.random(i*1000), weights), Estimator(hashFunctions), Some(mk.mkRank), true) 18 | 19 | /** MinHash that uses only one bit. 
/** MinHash sketchers estimating Jaccard similarity of sets. */
object MinHash {

  /** Unweighted MinHash with `hashFunctions` hash functions (every weight is 1). */
  def apply[T](hashFunctions: Int)(implicit mk: MinHashImpl[T, Any]): IntSketchers[T] =
    weighted(hashFunctions, (a: Any) => 1)

  /** Weighted MinHash: an element with weight `w` is hashed `w` times in a chain.
    *
    * based on https://www.sumologic.com/2015/10/22/rapid-similarity-search-with-weighted-min-hash/ */
  def weighted[T, El](hashFunctions: Int, weights: El => Int)(implicit mk: MinHashImpl[T, El]): IntSketchers[T] =
    Sketchers(hashFunctions, (i: Int) => mk(HashFunc.random(i*1000), weights), Estimator(hashFunctions), Some(mk.mkRank), true)

  /** MinHash that uses only one bit. It's much faster than traditional MinHash
    * but it seems it's less precise.
    * As of right now it's not really suitable for LSH, because most elements
    * are hashed into few buckets. Investigation pending.
    * https://www.endgame.com/blog/minhash-vs-bitwise-set-hashing-jaccard-similarity-showdown */
  def singleBit[T](hashFunctions: Int)(implicit mk: MinHashImpl[T, Any]): BitSketchers[T] =
    singleBitWeighted(hashFunctions, (a: Any) => 1)

  /** Weighted variant of [[singleBit]]. */
  def singleBitWeighted[T, El](hashFunctions: Int, weights: El => Int)(implicit mk: MinHashImpl[T, El]): BitSketchers[T] =
    Sketchers(hashFunctions, (i: Int) => onebit(mk(HashFunc.random(i*1000), weights)), SingleBitEstimator(hashFunctions), Some(mk.mkRank), true)

  /** Reduces an Int sketcher to its lowest bit. */
  private def onebit[T](f: IntSketcher[T]) = new BitSketcher[T] {
    override def apply(t: T) = (f(t) & 1) != 0
    def multi(t: T) = {
      val m = f.multi(t)
      BitMulti((m.hash & 1) != 0, m.cost)
    }
  }


  /** Type class: how to MinHash a collection type `T` with elements `El`. */
  trait MinHashImpl[T, +El] {
    def apply(hf: HashFunc[Int], weights: El => Int): IntSketcher[T]
    def mkRank: IndexedSeq[T] => Rank[T, T]
  }

  /** Sets represented as Int arrays, ranked by `fastSparse.jaccardSimilarity`. */
  implicit val IntArrayMinHashImpl = new MinHashImpl[Array[Int], Int] {
    type Item = Array[Int]
    def apply(hf: HashFunc[Int], weights: Int => Int) = new IntSketcher[Item] {
      override def apply(set: Item) = minhashArrInt(set, hf, weights)
      def multi(set: Item) = minhashArr(set, hf, weights)
    }
    def mkRank = (items: IndexedSeq[Item]) => SimFun[Item](fastSparse.jaccardSimilarity, items)
  }

  /** Generic `Set[El]`; elements are fed to the hash function via `hashCode`. */
  implicit def GeneralMinHashImpl[El] = new MinHashImpl[Set[El], El] {
    type Item = Set[El]
    def apply(hf: HashFunc[Int], weights: El => Int) = new IntSketcher[Item] {
      //def apply (set: Item) = minhashTrav(set, hf, weights)
      def multi(set: Item) = minhashTrav(set, hf, weights)
    }
    def mkRank = (items: IndexedSeq[Item]) => SimFun[Item](jacc, items)
  }


  /** Smallest chained hash over all elements; `Int.MaxValue` for an empty set.
    *
    * The weight is looked up by element (`set(j)`), matching `minhashTrav`;
    * the previous code passed the array index `j` to `weights`.
    */
  private def minhashArrInt(set: Array[Int], f: HashFunc[Int], weights: Int => Int): Int = {
    var min = Int.MaxValue
    var j = 0 ; while (j < set.length) {
      var h = set(j)
      val w = weights(set(j))
      var i = 0 ; while (i < w) {
        h = f(h)
        min = math.min(min, h)
        i += 1
      }
      j += 1
    }
    min
  }

  /** Like [[minhashArrInt]] but also tracks the second smallest distinct hash,
    * returned as the third field of the `IntMulti`.
    */
  private def minhashArr(set: Array[Int], f: HashFunc[Int], weights: Int => Int): IntMulti = {
    var min, min2 = Int.MaxValue
    var j = 0 ; while (j < set.length) {
      var h = set(j)
      val w = weights(set(j)) // weight of the element, not of its position
      var i = 0 ; while (i < w) {
        h = f(h)
        if (h < min) {
          // the old minimum becomes a candidate for second place
          min2 = math.min(min2, min)
          min = h
        }
        if (h != min) {
          min2 = math.min(min2, h)
        }
        i += 1
      }
      j += 1
    }
    IntMulti(min, 1, min2)
  }

  /** Same as [[minhashArr]] for a generic set; elements enter via `hashCode`. */
  private def minhashTrav[El](set: Set[El], f: HashFunc[Int], weights: El => Int): IntMulti = {
    var min, min2 = Int.MaxValue
    for (el <- set) {
      var h = el.hashCode
      for (_ <- 0 until weights(el)) {
        h = f(h)
        if (h < min) {
          min2 = math.min(min2, min)
          min = h
        }
        if (h != min) {
          min2 = math.min(min2, h)
        }
      }
    }
    IntMulti(min, 1, min2)
  }

  /** Exact Jaccard similarity |a ∩ b| / |a ∪ b|.
    *
    * Returns 0 when both sets are empty (previously NaN), which is what
    * `fastSparse.jaccardSimilarity` does for the array representation.
    */
  protected def jacc[El](a: Set[El], b: Set[El]) = {
    val small = if (a.size < b.size) a else b
    val large = if (a.size < b.size) b else a

    var in = 0
    for (el <- small) {
      if (large.contains(el)) in += 1
    }

    val un = small.size + large.size - in
    if (un == 0) 0.0 else in.toDouble / un
  }


  /** Similarity is the fraction of sketch positions that agree. */
  case class Estimator(sketchLength: Int) extends IntEstimator {
    def estimateSimilarity(sameBits: Int): Double =
      sameBits.toDouble / sketchLength

    def minSameBits(sim: Double): Int = {
      require(sim >= 0.0 && sim <= 1.0, "similarity must be from (0, 1)")
      (sim * sketchLength).toInt
    }
  }

  /** Estimator for the one-bit variant: `1 - 2 * (differing bits / sketchLength)`. */
  case class SingleBitEstimator(val sketchLength: Int) extends BitEstimator {
    def estimateSimilarity(sameBits: Int): Double =
      1.0 - 2.0 / sketchLength * (sketchLength - sameBits)

    def minSameBits(sim: Double): Int = {
      sketchLength - ((1 - sim) / (2.0 / sketchLength)).toInt
    }
  }
}
/** Estimates cosine of the angle between two vectors. */
object RandomHyperplanes {

  /** Sketchers made of `n` random hyperplanes over vectors of `vectorLength`.
    *
    * @param normalized when true the ranking function uses the plain dot
    *                   product, otherwise the dot product divided by both norms
    */
  def apply[T](n: Int, vectorLength: Int, normalized: Boolean = false)(implicit ev: CanDot[T]): BitSketchers[T] =
    Sketchers(n, (i: Int) => mkSketcher(vectorLength, i * 1000), Estimator(n), Some(mkRank(ev, normalized)), false)

  // def apply(rowMatrix: DenseMatrix[Double], n: Int): BitSketch[DenseVector[Double]] = ???
  //   apply(0 until rowMatrix.rows map { r => rowMatrix(r, ::).t }, n)(CanDotDouble)


  /** Ranking function factory: dot product for pre-normalized input, cosine otherwise. */
  private def mkRank[InVec](ev: CanDot[InVec], normalized: Boolean) = {
    def dot(a: InVec, b: InVec): Double = ev.dotInVec(a, b)
    def cosine(a: InVec, b: InVec): Double = ev.dotInVec(a, b) / (ev.normInVec(a) * ev.normInVec(b))

    (items: IndexedSeq[InVec]) =>
      if (normalized) SimFun[InVec]((a, b) => dot(a, b), items)
      else SimFun[InVec]((a, b) => cosine(a, b), items)
  }

  /** Operations a vector type must support to be sketched by hyperplanes. */
  trait CanDot[InVec] {
    type RndVec
    def makeRandomHyperplane(length: Int, seed: Int): RndVec
    def dotRndVec(a: InVec, b: RndVec): Double
    def dotInVec(a: InVec, b: InVec): Double
    def normRndVec(a: RndVec): Double
    def normInVec(a: InVec): Double
  }

  type Mul[A, B, C] = operators.OpMulInner.Impl2[A, B, C]
  type Norm[A, B, C] = norm.Impl2[A, B, C]

  /** Instance for vector types whose inner product yields Float. */
  implicit def CanDotFloat[T](implicit
      dotf1: Mul[T, DenseVector[Float], Float],
      dotf2: Mul[T, T, Float],
      normf1: Norm[DenseVector[Float], Double, Double],
      normf2: Norm[T, Double, Double]
    ) = new CanDot[T] {

    type RndVec = DenseVector[Float]

    def makeRandomHyperplane(length: Int, seed: Int): RndVec =
      mkRandomHyperplane(length, seed).mapValues(_.toFloat)

    def dotRndVec(a: T, b: RndVec): Double = dotf1(a, b).toDouble
    def dotInVec(a: T, b: T): Double       = dotf2(a, b).toDouble
    def normRndVec(a: RndVec): Double      = normf1(a, 2)
    def normInVec(a: T): Double            = normf2(a, 2)
  }

  /** Instance for vector types whose inner product yields Double. */
  implicit def CanDotDouble[T](implicit
      dotf1: Mul[T, DenseVector[Double], Double],
      dotf2: Mul[T, T, Double],
      normf1: Norm[DenseVector[Double], Double, Double],
      normf2: Norm[T, Double, Double]
    ) = new CanDot[T] {

    type RndVec = DenseVector[Double]

    def makeRandomHyperplane(length: Int, seed: Int): RndVec = mkRandomHyperplane(length, seed)

    def dotRndVec(a: T, b: RndVec): Double = dotf1(a, b)
    def dotInVec(a: T, b: T): Double       = dotf2(a, b)
    def normRndVec(a: RndVec): Double      = normf1(a, 2)
    def normInVec(a: T): Double            = normf2(a, 2)
  }

  /** Instance for sparse vectors. The hyperplane is never materialised: its
    * +1/-1 components are derived on demand from bits of a seeded hash function.
    */
  implicit def CanDotSparseDouble = new CanDot[SparseVector[Double]] {
    trait RndVec {
      def apply(i: Int): Double
      def norm: Double
    }

    def makeRandomHyperplane(length: Int, seed: Int) = {
      val hash = HashFunc.random(seed, 16)

      new RndVec {
        // one hash value supplies 16 consecutive components
        def apply(i: Int): Double = {
          val shifted = i+1337
          val word = hash(shifted/16)
          if (((word >> (shifted%16)) & 1) == 1) 1.0 else -1.0
        }
        val norm: Double = math.sqrt(length)
      }
    }

    def dotRndVec(a: SparseVector[Double], b: RndVec): Double = {
      var acc = 0.0
      var k = 0
      val active = a.activeSize
      while (k < active) {
        acc += a.valueAt(k) * b(a.indexAt(k))
        k += 1
      }
      acc
    }

    def dotInVec(a: SparseVector[Double], b: SparseVector[Double]) = a dot b

    def normInVec(a: SparseVector[Double]): Double = norm(a, 2)
    def normRndVec(a: RndVec): Double = a.norm
  }


  /** One sketch bit: on which side of a seeded random hyperplane the item lies. */
  private def mkSketcher[T](length: Int, seed: Int)(implicit ev: CanDot[T]): BitSketcher[T] = new BitSketcher[T] {
    private val hyperplane = ev.makeRandomHyperplane(length, seed)
    //def apply(item: T) = ev.dotRndVec(item, rand) > 0.0
    def multi(item: T) = {
      val projection = ev.dotRndVec(item, hyperplane)
      BitMulti(projection > 0.0, math.abs(projection))
    }
  }

  /** Dense vector of `length` components, each -1.0 or 1.0, drawn from `seed`. */
  private def mkRandomHyperplane(length: Int, seed: Int): DenseVector[Double] = {
    val rng = new scala.util.Random(seed)
    DenseVector.fill[Double](length) {
      if (rng.nextDouble() < 0.5) -1.0 else 1.0
    }
  }

  /** Maps the fraction of agreeing bits to a cosine and back. */
  case class Estimator(sketchLength: Int) extends BitEstimator {
    def estimateSimilarity(sameBits: Int): Double = {
      val agreement = sameBits / sketchLength.toDouble
      math.cos(math.Pi * (1 - agreement))
    }

    def minSameBits(sim: Double): Int = {
      require(sim >= -1 && sim <= 1, "similarity must be from (-1, 1)")
      val agreement = 1.0 - math.acos(sim) / math.Pi
      math.floor(agreement * sketchLength).toInt
    }
  }
}
/** Estimates distance */
object RandomProjections {

  /** Sketchers made of `projections` random unit vectors; each one buckets the
    * projection of an item into intervals of width `bucketSize`.
    */
  def apply[V](projections: Int, bucketSize: Double, vectorLength: Int): IntSketchers[bVector[Double]] =
    Sketchers(projections, (i: Int) => mkSketcher(vectorLength, i * 1000, bucketSize), Estimator(projections), None, false)


  private def mkSketcher(vectorLength: Int, seed: Int, bucketSize: Double) =
    new IntSketcher[bVector[Double]] {
      private val randVec = mkRandomUnitVector(vectorLength, seed)

      // def apply(item: bVector[Double]): Int =
      //   ((randVec dot item) / bucketSize).toInt

      /** Bucket of the item, the closer neighbouring bucket, and as cost the
        * distance (in bucket widths, 0 to 0.5) to the boundary between them.
        */
      def multi(item: bVector[Double]) = {
        val d = ((randVec dot item) / bucketSize)
        // floor, not toInt: truncation toward zero merged (-1, 0) and [0, 1)
        // into one double-width bucket and produced negative costs for d < 0.
        val bucket = math.floor(d).toInt
        val neighbour = bucket + (if (d < bucket+0.5) -1 else +1)
        val cost = 0.5 - math.abs(bucket+0.5-d)
        IntMulti(bucket, cost, neighbour)
      }
    }

  /** Gaussian random vector drawn from `seed`, scaled to unit L2 norm. */
  private def mkRandomUnitVector(length: Int, seed: Int) = {
    val rand = new scala.util.Random(seed)
    normalize(DenseVector.fill[Double](length)(rand.nextGaussian), 2)
    //normalize(DenseVector.rand[Double](length, Rand.gaussian), 2)
  }

  // def mkRank = (items: IndexedSeq[T]) => SimFun((a, b) => sum(pow(a - b, 2)), items)

  /** Placeholder: similarity estimation is not implemented for projections. */
  case class Estimator(sketchLength: Int) extends IntEstimator {
    def estimateSimilarity(sameBits: Int): Double = ???
    def minSameBits(sim: Double): Int = ???
  }

}
sampling without repetition 357 | val rnd = new util.Random(1234) 358 | IndexedSeq.fill(n) { items(rnd.nextInt(items.length)) } 359 | } 360 | 361 | def permutation[T](referencePoints: IndexedSeq[T], q: T, dist: DistFun[T]): Array[Int] = 362 | referencePoints.zipWithIndex.map { case (p, i) => (dist(q, p), i) }.sorted.map(_._2).toArray 363 | 364 | def inv(p: Array[Int]) = { 365 | val inv = new Array[Int](p.length) 366 | for (i <- 0 until p.length) inv(p(i)) = i 367 | inv 368 | } 369 | 370 | // m = p.length / 2 is apparebntly a good choice 371 | def encode(p: Array[Int], m: Int) = { 372 | require(m > 0) 373 | val pinv = inv(p) 374 | val C = new Array[Long]((p.length+63)/64) 375 | for (i <- 0 until p.length) { 376 | if (math.abs(i - pinv(i)) > m) { 377 | C(i / 64) |= (1 << (i % 64)) 378 | } 379 | } 380 | C 381 | } 382 | 383 | // Bit-encoding using permutation of the center. Interchangeable with encode. 384 | def encodePermCenter(p: Array[Int], m: Int) = { 385 | require(m > 0) 386 | val pinv = inv(p) 387 | val C = new Array[Long]((p.length+63)/64) 388 | val M = p.length / 4 389 | for (i <- 0 until p.length) { 390 | var I = i 391 | if ((I / M) % 3 == 0) { 392 | I += M 393 | } 394 | if (math.abs(I - pinv(i)) > m) { 395 | C(i / 64) |= (1 << (i % 64)) 396 | } 397 | } 398 | C 399 | } 400 | 401 | 402 | case class Estimator(sketchLength: Int) extends BitEstimator { 403 | def estimateSimilarity(sameBits: Int): Double = ??? 404 | def minSameBits(sim: Double): Int = ??? 
405 | } 406 | 407 | } 408 | 409 | object RandomBisectors { 410 | 411 | type DistFun[T] = (T, T) => Double 412 | 413 | def sketching[T](items: IndexedSeq[T], bisectors: IndexedSeq[(T, T)], dist: DistFun[T]): BitSketching = 414 | new BitSketchingOf(items, bisectors.length, i => mkSketcher(bisectors(i), dist), Estimator(bisectors.length)) 415 | 416 | def apply[T](items: IndexedSeq[T], bisectors: IndexedSeq[(T, T)], dist: DistFun[T]): BitSketch[T] = 417 | BitSketch.make(sketching(items, bisectors, dist)) 418 | 419 | 420 | def sketching[T](items: IndexedSeq[T], bisectors: Int, dist: DistFun[T]): BitSketching = 421 | sketching(items, samplePairs(items, bisectors), dist) 422 | 423 | def apply[T](items: IndexedSeq[T], bisectors: Int, dist: DistFun[T]): BitSketch[T] = 424 | apply(items, samplePairs(items, bisectors), dist) 425 | 426 | 427 | private def samplePairs[T](items: IndexedSeq[T], n: Int): IndexedSeq[(T, T)] = { 428 | // TODO sampling without repetition 429 | val rnd = new util.Random(1234) 430 | def pick() = rnd.nextInt(items.length) 431 | IndexedSeq.fill(n) { (items(pick()), items(pick())) } 432 | } 433 | 434 | 435 | private def mkSketcher[T](points: (T, T), dist: (T, T) => Double) = new BitSketcher[T] { 436 | val (a, b) = points 437 | def apply(item: T): Boolean = dist(a, item) < dist(b, item) 438 | } 439 | 440 | case class Estimator(sketchLength: Int) extends BitEstimator { 441 | def estimateSimilarity(sameBits: Int): Double = sameBits.toDouble / sketchLength 442 | def minSameBits(sim: Double): Int = { 443 | require(sim >= 0.0 && sim <= 1.0, "similarity must be from (0, 1)") 444 | (sim * sketchLength).toInt 445 | } 446 | } 447 | 448 | } 449 | */ 450 | 451 | 452 | 453 | /* 454 | object PStableDistributions { 455 | 456 | def apply(vectors: IndexedSeq[SparseVector[Double]], sketchLength: Int, p: Double): PStableDistributions = { 457 | } 458 | 459 | // http://www.cs.dartmouth.edu/~ac/Teach/CS49-Fall11/Papers/indyk-stable.pdf 460 | def pstable(p: Double, a: Double, 
/** Treats ready-made bit strings as sketches and compares them by Hamming distance. */
object HammingDistance {

  /** Wraps a flat array holding `bits`-bit sketches back to back.
    *
    * @param arr  all sketches concatenated, `bits / 64` Longs per item
    * @param bits sketch length in bits; must be a multiple of 64
    */
  def apply(arr: Array[Long], bits: Int): BitSketch[Array[Long]] = {
    require(bits % 64 == 0)

    val sketchers: BitSketchers[Array[Long]] = new BitSketchers[Array[Long]] { self =>
      val sketchLength = bits
      val estimator = Estimator(bits)
      val rank = None
      val uniformCost = true

      // items already are sketches; only their length is checked
      def getSketchFragment(item: Array[Long]) = {
        require(item.length*64 == bits)
        item
      }
      def getSketchMultiFragment(item: Array[Long]) =
        MultiFragment(getSketchFragment(item), null, null)
    }

    BitSketch[Array[Long]](arr, sketchers)
  }

  /** Flattens one array per item into a single array and wraps it.
    *
    * The first `ceil(bits / 64)` Longs of every row are copied. The previous
    * loop iterated `i` up to `longs * arr.length` while indexing `arr(i)`,
    * which threw for every input with more than one Long per item.
    */
  def apply(arr: Array[Array[Long]], bits: Int): BitSketch[Array[Long]] = {
    val longs = (bits+63) / 64
    val res = new Array[Long](longs * arr.length)
    var i = 0
    while (i < arr.length) {
      System.arraycopy(arr(i), 0, res, i * longs, longs)
      i += 1
    }
    apply(res, bits)
  }


  /** Similarity is the fraction of equal bits. */
  case class Estimator(sketchLength: Int) extends BitEstimator {
    private[this] val inv = 1.0 / sketchLength

    def estimateSimilarity(sameBits: Int): Double = sameBits * inv
    def minSameBits(sim: Double): Int = (sketchLength * sim).toInt
  }
}
/** 64-bit SimHash of arrays of features. */
object SimHash {

  /** Sketchers producing one Long per item from the 64-bit feature hash `f`. */
  def apply[T](implicit f: HashFuncLong[T]): BitSketchers[Array[T]] =
    new BitSketchers[Array[T]] {
      val sketchLength = 64
      val estimator = HammingDistance.Estimator(64)
      val rank = None
      val uniformCost = true

      def getSketchFragment(item: Array[T]): Array[Long] =
        Array[Long](doSimHash64(item, f))
      def getSketchMultiFragment(item: Array[T]) = ???
    }


  /** Hashes a string to the first 8 bytes of the MD5 digest of its UTF-8 bytes.
    * The charset is fixed so that hashes do not depend on the platform default.
    */
  implicit def md5: HashFuncLong[String] = new HashFuncLong[String] {
    def apply(x: String): Long = {
      val m = java.security.MessageDigest.getInstance("MD5")
      val bytes = m.digest(x.getBytes(java.nio.charset.StandardCharsets.UTF_8))
      java.nio.ByteBuffer.wrap(bytes).getLong
    }
  }

  /** Bit `i` of the result is set iff more than half of the feature hashes
    * have bit `i` set (ties and empty input yield 0).
    */
  private def doSimHash64[T](xs: Array[T], f: HashFuncLong[T]): Long = {

    val counts = new Array[Int](64)

    for (x <- xs) {
      val l = f(x)
      var i = 0 ; while (i < 64) {
        // Test the bit in Long arithmetic. `1 << i` is an Int shift whose
        // distance is taken modulo 32, so bits 32..63 were never inspected.
        counts(i) += (if (((l >>> i) & 1L) != 0L) 1 else -1)
        i += 1
      }
    }

    var hash = 0L
    var i = 0 ; while (i < 64) {
      if (counts(i) > 0) {
        hash |= 1L << i
      }
      i += 1
    }

    hash
  }
}
It's 11 | * caused by the fact that it contains lot less instruction and only one 12 | * semi-unpredictable conditional jump in hot path. 13 | */ 14 | object fastDotProduct extends breeze.generic.UFunc.UImpl2[breeze.linalg.operators.OpMulInner.type, SparseVector[Double], SparseVector[Double], Double] { 15 | def apply(a: SparseVector[Double], b: SparseVector[Double]): Double = { 16 | require(a.size == b.size, "Vectors must be the same length!") 17 | 18 | val ak: Array[Int] = a.array.index 19 | val av: Array[Double] = a.array.data 20 | 21 | val bk: Array[Int] = b.array.index 22 | val bv: Array[Double] = b.array.data 23 | 24 | var prod = 0.0 25 | var ai, bi = 0 26 | while (ai != ak.length && bi != bk.length) { 27 | val a = ak(ai) 28 | val b = bk(bi) 29 | if (a == b) { 30 | prod += av(ai) * bv(bi) 31 | } 32 | 33 | // progress counter with smaller key 34 | ai += (if (a <= b) 1 else 0) 35 | bi += (if (a >= b) 1 else 0) 36 | } 37 | prod 38 | } 39 | } 40 | 41 | 42 | object fastSparse { 43 | 44 | trait Rel[@specialized T] { 45 | def eq (a: T, b: T): Boolean 46 | def gt (a: T, b: T): Boolean 47 | def gte(a: T, b: T): Boolean = !gt(b, a) 48 | def lt (a: T, b: T): Boolean = gt(b, a) 49 | def lte(a: T, b: T): Boolean = !gt(a, b) 50 | def sort(a: Array[T]): a.type 51 | } 52 | 53 | object Rel { 54 | implicit val IntRel: Rel[Int] = new Rel[Int] { 55 | def gt(a: Int, b: Int) = a > b 56 | def eq(a: Int, b: Int) = a == b 57 | def sort(a: Array[Int]): a.type = { Arrays.sort(a) ; a } 58 | } 59 | implicit val LongRel: Rel[Long] = new Rel[Long] { 60 | def gt(a: Long, b: Long) = a > b 61 | def eq(a: Long, b: Long) = a == b 62 | def sort(a: Array[Long]): a.type = { Arrays.sort(a) ; a } 63 | } 64 | implicit val FloatRel: Rel[Float] = new Rel[Float] { 65 | def gt(a: Float, b: Float) = a > b 66 | def eq(a: Float, b: Float) = a == b 67 | def sort(a: Array[Float]): a.type = { Arrays.sort(a) ; a } 68 | } 69 | implicit val DoubleRel: Rel[Double] = new Rel[Double] { 70 | def gt(a: Double, b: 
/** Ordering and in-place sorting of a primitive type, specialised to avoid boxing. */
trait Rel[@specialized T] {
  def eq (a: T, b: T): Boolean
  def gt (a: T, b: T): Boolean
  def gte(a: T, b: T): Boolean = !gt(b, a)
  def lt (a: T, b: T): Boolean = gt(b, a)
  def lte(a: T, b: T): Boolean = !gt(a, b)
  /** Sorts `a` in place and returns it. */
  def sort(a: Array[T]): a.type
}

object Rel {
  implicit val IntRel: Rel[Int] = new Rel[Int] {
    def gt(a: Int, b: Int) = a > b
    def eq(a: Int, b: Int) = a == b
    def sort(a: Array[Int]): a.type = { Arrays.sort(a) ; a }
  }
  implicit val LongRel: Rel[Long] = new Rel[Long] {
    def gt(a: Long, b: Long) = a > b
    def eq(a: Long, b: Long) = a == b
    def sort(a: Array[Long]): a.type = { Arrays.sort(a) ; a }
  }
  implicit val FloatRel: Rel[Float] = new Rel[Float] {
    def gt(a: Float, b: Float) = a > b
    def eq(a: Float, b: Float) = a == b
    def sort(a: Array[Float]): a.type = { Arrays.sort(a) ; a }
  }
  implicit val DoubleRel: Rel[Double] = new Rel[Double] {
    def gt(a: Double, b: Double) = a > b
    def eq(a: Double, b: Double) = a == b
    def sort(a: Array[Double]): a.type = { Arrays.sort(a) ; a }
  }
}


/** Prepares an array to be used by the set functions in this module, i.e.
  * makes its values distinct and increasing. The input is always sorted in
  * place; the result is either that same array or a new de-duplicated one. */
def makeSet[@specialized(Int, Long) T](arr: Array[T])(implicit rel: Rel[T]): Array[T] = {
  rel.sort(arr)
  if (isDistinctIncreasingArray(arr)) arr
  else arr.distinct
}

/** True iff every element is strictly greater than its predecessor. */
def isDistinctIncreasingArray[@specialized(Int, Long) T](arr: Array[T])(implicit rel: Rel[T]): Boolean = {
  if (arr.length <= 1) return true

  var last = arr(0)
  var i = 1
  while (i < arr.length) {
    if (rel.gte(last, arr(i))) return false
    last = arr(i)
    i += 1
  }

  true
}

/** True iff no element is smaller than its predecessor (duplicates allowed). */
def isIncreasingArray[@specialized(Int, Long) T](arr: Array[T])(implicit rel: Rel[T]): Boolean = {
  if (arr.length <= 1) return true

  var last = arr(0)
  var i = 1
  while (i < arr.length) {
    if (rel.gt(last, arr(i))) return false
    last = arr(i)
    i += 1
  }

  true
}


/** arguments must be sets represented as sorted arrays */
def intersectionSize(a: Array[Int], b: Array[Int]): Int = {
  var size, ai, bi = 0
  while (ai < a.length && bi < b.length) {
    val av = a(ai)
    val bv = b(bi)
    size += (if (av == bv) 1 else 0)
    ai += (if (av <= bv) 1 else 0)
    bi += (if (av >= bv) 1 else 0)
  }
  size
}


/** This method tries to look ahead and skip some unnecessary iterations. In
  * some cases it can be faster than straightforward code, but it's rarely
  * slower.
  *
  * @param skip look-ahead distance; values below 1 are treated as 1, i.e. a
  *             plain merge (0 used to loop forever, negatives indexed out of
  *             range)
  */
def intersectionSizeWithSkips[@specialized(Int, Long) T](a: Array[T], b: Array[T], skip: Int)(implicit rel: Rel[T]): Int = {
  var size, ai, bi = 0

  val step = math.max(skip, 1)
  // while both cursors are at least `step` before the end, look-ahead is safe
  val alen = a.length - step
  val blen = b.length - step

  while (ai < alen && bi < blen) {
    val av = a(ai)
    val bv = b(bi)
    val _ai = ai
    val _bi = bi
    size += (if (rel.eq (av, bv)) 1 else 0)
    // if the element `step` ahead is still below the other side's head,
    // everything in between is too and can be jumped over
    ai += (if (rel.lte(av, bv)) (if (rel.lt(a(_ai+step), bv)) step else 1) else 0)
    bi += (if (rel.gte(av, bv)) (if (rel.lt(b(_bi+step), av)) step else 1) else 0)
  }

  // plain merge for the tails
  while (ai < a.length && bi < b.length) {
    val av = a(ai)
    val bv = b(bi)
    size += (if (rel.eq (av, bv)) 1 else 0)
    ai += (if (rel.lte(av, bv)) 1 else 0)
    bi += (if (rel.gte(av, bv)) 1 else 0)
  }

  size
}

/** result = |a ∪ b| */
def unionSize(a: Array[Int], b: Array[Int]): Int =
  a.length + b.length - intersectionSize(a, b)

/** result = |a -- b| */
def diffSize(a: Array[Int], b: Array[Int]): Int =
  a.length - intersectionSize(a, b)


/** result = (|a ∩ b|, |a ∪ b|) */
def intersectionAndUnionSize(a: Array[Int], b: Array[Int]): (Int, Int) = {
  val is = intersectionSize(a, b)
  (is, a.length + b.length - is)
}

/** |a ∩ b| / |a ∪ b|, or 0 when both sets are empty. */
def jaccardSimilarity(a: Array[Int], b: Array[Int]): Double = {
  val is = intersectionSize(a, b)
  val un = a.length + b.length - is
  if (un == 0) 0 else is.toDouble / un
}


/** result = a ∪ b as a new sorted array */
def union(a: Array[Int], b: Array[Int]): Array[Int] = {
  val res = new Array[Int](unionSize(a, b))
  var i, ai, bi = 0
  while (ai != a.length && bi != b.length) {
    val av = a(ai)
    val bv = b(bi)
    res(i) = (if (av > bv) bv else av)
    ai += (if (av <= bv) 1 else 0)
    bi += (if (av >= bv) 1 else 0)
    i += 1
  }

  while (ai != a.length) {
    res(i) = a(ai)
    i += 1
    ai += 1
  }

  while (bi != b.length) {
    res(i) = b(bi)
    i += 1
    bi += 1
  }

  res
}


/** result = a ∩ b as a new sorted array */
def intersection(a: Array[Int], b: Array[Int]): Array[Int] = {
  val res = new Array[Int](intersectionSize(a, b))
  var i, ai, bi = 0
  while (ai != a.length && bi != b.length) {
    val av = a(ai)
    val bv = b(bi)

    if (av == bv) {
      res(i) = av
      i += 1
    }

    ai += (if (av <= bv) 1 else 0)
    bi += (if (av >= bv) 1 else 0)
  }

  res
}


/** result = a -- b */
def diff(a: Array[Int], b: Array[Int]): Array[Int] = {
  val res = new Array[Int](diffSize(a, b))
  var i, ai, bi = 0
  while (ai != a.length && bi != b.length) {
    val av = a(ai)
    val bv = b(bi)
    if (av == bv) {
      ai += 1
      bi += 1

    } else if (av < bv) {
      res(i) = av
      i += 1
      ai += 1

    } else {
      bi += 1
    }
  }

  while (ai != a.length) {
    res(i) = a(ai)
    i += 1
    ai += 1
  }

  res
}


/** False when the value ranges of the two sets are disjoint or a set is empty. */
def possibleSetOverlap(a: Array[Int], b: Array[Int]) =
  a.length != 0 && b.length != 0 && !(a(a.length - 1) < b(0) || b(b.length - 1) < a(0))


/** Sum of `ws(x)` over all x in a ∩ b. */
def weightedIntersectionSize(a: Array[Int], b: Array[Int], ws: Array[Double]): Double = {
  if (!possibleSetOverlap(a, b)) return 0.0

  var ai, bi = 0
  var size = 0.0
  while (ai != a.length && bi != b.length) {
    val av = a(ai)
    val bv = b(bi)
    size += (if (av == bv) ws(av) else 0)
    ai += (if (av <= bv) 1 else 0)
    bi += (if (av >= bv) 1 else 0)
  }
  size
}

/** Sum of `ws(x)` over all x in a. */
private def _sum(a: Array[Int], ws: Array[Double]): Double = {
  var s = 0.0
  var i = 0
  while (i < a.length) {
    s += ws(a(i))
    i += 1
  }
  s
}

/** Weighted Jaccard similarity with element weights looked up in `ws`. */
def weightedJaccardSimilarity(a: Array[Int], b: Array[Int], ws: Array[Double]): Double =
  weightedJaccardSimilarity(a, b, ws, _sum(a, ws), _sum(b, ws))

/** Weighted Jaccard similarity when the weight sums of both sets are already
  * known (`wasum` for a, `wbsum` for b); 0 when the weighted union is 0. */
def weightedJaccardSimilarity(a: Array[Int], b: Array[Int], ws: Array[Double], wasum: Double, wbsum: Double): Double = {
  val is = weightedIntersectionSize(a, b, ws)
  val un = wasum + wbsum - is
  if (un == 0) 0 else is / un
}
/** Seeds a multiway merge: the first element of every non-null, non-empty
  * set goes onto the heap keyed by value and tagged with the set's index.
  * The returned array holds, per set, the index of its next unread element. */
private def _initUnion(sets: Array[Array[Int]]): (MinIntIntHeap, Array[Int]) = {
  val builder = MinIntIntHeap.builder(sets.length)
  val cursors = new Array[Int](sets.length)

  for (idx <- sets.indices) {
    val set = sets(idx)
    if (set != null && set.length > 0) {
      builder.insert(set(0), idx)
      cursors(idx) += 1
    }
  }

  (builder.result, cursors)
}


/** Replaces the heap minimum (which came from set `i`) by the next element of
  * that set, or just removes it when the set is exhausted. */
private def _stepUnion(sets: Array[Array[Int]], i: Int, heap: MinIntIntHeap, positions: Array[Int]) = {
  val next = positions(i)
  if (next < sets(i).length) {
    heap.deleteMinAndInsert(sets(i)(next), i)
    positions(i) = next + 1
  } else {
    heap.deleteMin()
  }
}


/** Computes size of union of array of sets via multiway merge */
def unionSize(sets: Array[Array[Int]]): Int =
  if (sets.length == 0) 0
  else if (sets.length == 1) sets(0).length
  else if (sets.length == 2) unionSize(sets(0), sets(1))
  else {
    val (heap, positions) = _initUnion(sets)
    // Long sentinel: no Int key can ever equal it
    var previous = Long.MinValue
    var distinct = 0

    while (heap.nonEmpty) {
      val key = heap.minKey
      val source = heap.minValue

      if (key.toLong != previous) {
        distinct += 1
        previous = key
      }

      _stepUnion(sets, source, heap, positions)
    }

    distinct
  }

/** Union of all sets as a sorted array of distinct values. */
def union(sets: Array[Array[Int]]): Array[Int] =
  union(sets, 0)

/** Union of all sets; `expectedResultSize` only pre-sizes the output buffer. */
def union(sets: Array[Array[Int]], expectedResultSize: Int): Array[Int] =
  // one- and two-set inputs deliberately take the multiway path as well
  if (sets.length == 0) new Array[Int](0)
  else multiwayUnion(sets, expectedResultSize)

/** Heap-based multiway merge emitting every distinct key once. */
private def multiwayUnion(sets: Array[Array[Int]], expectedResultSize: Int): Array[Int] = {
  val out = new collection.mutable.ArrayBuilder.ofInt
  out.sizeHint(expectedResultSize)

  val (heap, positions) = _initUnion(sets)
  var previous = Long.MinValue

  while (heap.nonEmpty) {
    val key = heap.minKey
    val source = heap.minValue

    if (key.toLong != previous) {
      out += key
      previous = key
    }

    _stepUnion(sets, source, heap, positions)
  }

  out.result()
}

/** Union of all sets found by scanning the head of every set for the minimum
  * on each step, without a heap. */
def unionBruteForce(sets: Array[Array[Int]]): Array[Int] =
  if (sets.length == 0) new Array[Int](0)
  else {
    // drop empty inputs; reuse the caller's array when there is nothing to drop
    val live: Array[Array[Int]] =
      if (sets.forall(_.length != 0)) sets else sets.filter(_.length > 0)

    // heads(k): smallest unread value of live(k); cursors(k): index after it
    val heads = live.map(set => set(0))
    val cursors = Array.fill(live.length)(1)

    val out = new collection.mutable.ArrayBuilder.ofInt
    out.sizeHint(sets(0).length)

    var lastEmitted = Long.MaxValue
    var remaining = live.length

    while (remaining > 0) {
      // leftmost position holding the smallest head
      var best = 0
      var bestVal = heads(0)
      var k = 1
      while (k < heads.length) {
        if (heads(k) < bestVal) {
          best = k
          bestVal = heads(k)
        }
        k += 1
      }

      if (bestVal.toLong != lastEmitted) {
        out += bestVal
        lastEmitted = bestVal
      }

      if (cursors(best) < live(best).length) {
        heads(best) = live(best)(cursors(best))
        cursors(best) += 1
      } else {
        // exhausted sets are parked at the largest Int
        heads(best) = Int.MaxValue
        remaining -= 1
      }
    }

    out.result()
  }

/** Sorted values that occur in at least `minFreq` of the sets. For
  * `minFreq <= 1` this is the plain union. */
def unionOfFrequentItems(sets: Array[Array[Int]], minFreq: Int): Array[Int] =
  if (minFreq <= 1) union(sets)
  else {
    val out = new collection.mutable.ArrayBuilder.ofInt
    val (heap, positions) = _initUnion(sets)
    var current = Long.MinValue
    var seen = 0

    while (heap.nonEmpty) {
      val key = heap.minKey
      val source = heap.minValue

      if (key.toLong == current) {
        seen += 1
        // emit exactly once, at the moment the threshold is reached
        if (seen == minFreq) out += key
      } else {
        current = key
        seen = 1
      }

      _stepUnion(sets, source, heap, positions)
    }

    out.result()
  }

/** Values present in every set. */
def intersection(sets: Array[Array[Int]]): Array[Int] =
  unionOfFrequentItems(sets, sets.length)


/** Lazy union: a cursor over the distinct values of all sets in increasing order. */
def unionCursor(sets: Array[Array[Int]]): Cursor[Int] =
  new Cursor[Int] {
    private val (heap, positions) = _initUnion(sets)
    private var lastKey = Long.MinValue
    private var current = -1

    def moveNext(): Boolean = {
      // discard every remaining copy of the value handed out last time
      while (heap.nonEmpty && heap.minKey.toLong == lastKey) {
        _stepUnion(sets, heap.minValue, heap, positions)
      }

      val hasMore = heap.nonEmpty
      if (hasMore) {
        val key = heap.minKey
        current = key
        lastKey = key
      }
      hasMore
    }

    def value = current
  }



/** Merges two sorted arrays into a new sorted array, keeping duplicates. */
def mergeSortedArrays(a: Array[Int], b: Array[Int]): Array[Int] =
  mergeSortedArrays(a, b, new Array[Int](a.length + b.length))

/** Merges two sorted arrays into `res` (at least `a.length + b.length` long)
  * and returns `res`. On ties the element of `a` comes first. */
def mergeSortedArrays(a: Array[Int], b: Array[Int], res: Array[Int]): res.type = {
  val al = a.length
  val bl = b.length
  var ai = 0
  var bi = 0
  var ri = 0

  while (ai < al && bi < bl) {
    if (a(ai) <= b(bi)) {
      res(ri) = a(ai)
      ai += 1
    } else {
      res(ri) = b(bi)
      bi += 1
    }
    ri += 1
  }

  while (ai < al) {
    res(ri) = a(ai)
    ai += 1
    ri += 1
  }

  while (bi < bl) {
    res(ri) = b(bi)
    bi += 1
    ri += 1
  }

  res
}
/** Renumbers the elements of `sets` so that the new id of an element is its
  * rank by frequency: the element occurring in the fewest sets becomes 0,
  * ties are broken by the original id (the sort is stable).
  *
  * The largest id is read from the last element of every set.
  * NOTE(review): this presumes each set is sorted ascending and holds only
  * non-negative ids — confirm against callers.
  *
  * Empty sets are allowed and map to `makeSet` of an empty array.
  *
  * @return one renumbered array per input set, each passed through `makeSet`
  */
def renumberSetsByFrequency(sets: Array[Array[Int]]): Array[Array[Int]] = {

  // Largest id present; empty sets contribute nothing (indexing their last
  // element used to throw).
  var max = -1
  var i = 0
  while (i < sets.length) {
    val set = sets(i)
    if (set.length > 0) max = math.max(max, set(set.length - 1))
    i += 1
  }

  val counts = new Array[Int](max + 1)
  for (set <- sets; x <- set) counts(x) += 1

  // idsByFrequency(rank) == old id, ordered from least to most frequent.
  val idsByFrequency = Array.range(0, max + 1).sortBy(id => counts(id))

  // Invert the permutation to obtain the old id -> new id (rank) table.
  // Indexing `idsByFrequency` directly with an old id would apply the
  // inverse mapping, which has no relation to the element's frequency.
  val newId = new Array[Int](max + 1)
  var rank = 0
  while (rank < idsByFrequency.length) {
    newId(idsByFrequency(rank)) = rank
    rank += 1
  }

  sets.map { set =>
    val arr = new Array[Int](set.length)
    var j = 0
    while (j < arr.length) {
      arr(j) = newId(set(j))
      j += 1
    }
    makeSet(arr)
  }
}


} // closes the enclosing object whose header precedes this part of the file


/** Bit-level helpers: bit ranges stored in `Array[Long]`, packing two 32-bit
  * values into one Long, and order-preserving encodings of floats/doubles.
  */
object Bits {

  import java.lang.Float.{ floatToRawIntBits, intBitsToFloat }
  import java.lang.Double.{ doubleToRawLongBits, longBitsToDouble }
  import java.lang.Integer.highestOneBit

  /** Long with the lowest `bitLen` bits set. `1L << 64` would wrap around to
    * `1L`, so a full-width request is answered with all ones explicitly. */
  private def lowBitsMask(bitLen: Int): Long =
    if (bitLen >= 64) -1L else (1L << bitLen) - 1

  /** Copies bits `[from, to)` of `arr` into a new array just large enough to
    * hold them, starting at bit 0 of the result. */
  def getBits(arr: Array[Long], from: Int, to: Int): Array[Long] = {
    val res = new Array[Long]((to-from+63)/64)
    copyBits(arr, from, to, res, 0)
    res
  }

  /** Copies bits `[from, to)` of `arr` into `dest` starting at bit `destpos`,
    * one bit at a time. Bits are OR-ed into `dest`, so destination bits that
    * are already set stay set.
    *
    * @return `dest`
    */
  def copyBits(arr: Array[Long], from: Int, to: Int, dest: Array[Long], destpos: Int): dest.type = {
    var i = from
    var j = destpos

    while (i < to) {
      val bit = (arr(i / 64) >>> (i % 64)) & 1L
      dest(j / 64) |= (bit << (j % 64))
      i += 1
      j += 1
    }

    dest
  }

  /** Extract up to 64 bits from a long array. The array is split into a number
    * of blocks of length `blockLen` (counted in longs). Bits may span two
    * neighbouring array elements. Requested bits that overrun the block are
    * taken from the beginning of that block (hence *WrappingBlocks).
    *
    * @param block  index of the block
    * @param bit    position of the first bit, relative to the block start
    * @param bitLen number of bits to extract (0 to 64)
    */
  def getBitsWrappingBlocks(arr: Array[Long], blockLen: Int, block: Int, bit: Int, bitLen: Int): Long = {

    // index of the long where the block starts
    val blockstart = block * blockLen

    // absolute position of the first bit to be extracted
    val startbit = blockstart * 64 + bit
    val offset = startbit % 64
    val low = arr(startbit / 64) >>> offset

    if (offset + bitLen <= 64) {
      // everything lives in one long
      low & lowBitsMask(bitLen)
    } else {
      val _endpos = (startbit+bitLen) / 64
      // if the second long lies outside of the block, wrap to the block start
      val endpos = if (_endpos < blockstart + blockLen) _endpos else blockstart
      // offset >= 1 here, so the shift distance is within 1..63
      (low | (arr(endpos) << (64 - offset))) & lowBitsMask(bitLen)
    }
  }


  /** Extract up to 64 bits from a long array. Bits may span two neighbouring
    * array elements. Requested bits that overrun the length of the provided
    * array are taken from its beginning (hence *Wrapping). */
  def getBitsWrapping(arr: Array[Long], bit: Int, bitLen: Int): Long = {

    val startbit = bit
    val offset = startbit % 64
    val low = arr(startbit / 64) >>> offset

    if (offset + bitLen <= 64) {
      low & lowBitsMask(bitLen)
    } else {
      val _endpos = (startbit+bitLen) / 64
      // if the second long lies outside of the array, wrap to index 0
      val endpos = if (_endpos < arr.length) _endpos else 0
      (low | (arr(endpos) << (64 - offset))) & lowBitsMask(bitLen)
    }
  }


  /** Extract up to 64 bits from a long array. Bits may span two neighbouring
    * array elements. If the requested bits overrun the length of the array,
    * an exception is thrown. This means the `bit` argument must be less than
    * or equal to `arr.length * 64 - bitLen`. */
  def getBitsOverlapping(arr: Array[Long], bit: Int, bitLen: Int): Long = {
    val startbit = bit
    val offset = startbit % 64
    val low = arr(startbit / 64) >>> offset

    // The second long is read only when bits really spill into it, so a
    // request ending exactly at the end of the array never reads past it.
    if (offset + bitLen <= 64) low & lowBitsMask(bitLen)
    else (low | (arr((startbit+bitLen) / 64) << (64 - offset))) & lowBitsMask(bitLen)
  }


  /** Extract up to 64 bits from a long array. All requested bits must be
    * contained inside one long, otherwise the result is incorrect (no
    * exception is thrown). */
  def getBitsInsideLong(arr: Array[Long], bit: Int, bitLen: Int): Long =
    (arr(bit / 64) >>> (bit % 64)) & lowBitsMask(bitLen)


  /** Packs `hi` into the upper and `lo` into the lower 32 bits of a Long.
    * `lo` is masked to 32 bits first: widening a negative Int sign-extends
    * it, which would otherwise overwrite `hi` with ones. */
  def pack(hi: Int, lo: Int): Long = (hi.toLong << 32) | (lo & 0xFFFFFFFFL)
  def pack(hi: Float, lo: Int): Long = pack(floatToRawIntBits(hi), lo)
  def pack(hi: Int, lo: Float): Long = pack(hi, floatToRawIntBits(lo))
  def pack(hi: Float, lo: Float): Long = pack(floatToRawIntBits(hi), floatToRawIntBits(lo))

  /** These methods encode floats in such a way that they can be sorted by radix sort. */
  def packSortable(hi: Float, lo: Int): Long = pack(floatFlip(floatToRawIntBits(hi)), lo)
  def packSortable(hi: Int, lo: Float): Long = pack(hi, floatFlip(floatToRawIntBits(lo)))
  def packSortable(hi: Float, lo: Float): Long = pack(floatFlip(floatToRawIntBits(hi)), floatFlip(floatToRawIntBits(lo)))

  /** Inverse operations of `pack` / `packSortable`. */
  def unpackIntHi(l: Long): Int = (l >>> 32).toInt
  def unpackIntLo(l: Long): Int = l.toInt
  def unpackFloatHi(l: Long): Float = intBitsToFloat(unpackIntHi(l))
  def unpackFloatLo(l: Long): Float = intBitsToFloat(unpackIntLo(l))
  def unpackSortableFloatHi(l: Long): Float = sortableIntToFloat(unpackIntHi(l))
  def unpackSortableFloatLo(l: Long): Float = sortableIntToFloat(unpackIntLo(l))

  /** Converts float to signed int that preserves ordering,
    * ie. if a < b, then ftsi(a) < ftsi(b)
    * and if a < 0, then ftsi(a) < 0 */
  def floatToSortableInt(f: Float): Int = floatFlip(floatToRawIntBits(f))
  def sortableIntToFloat(i: Int): Float = intBitsToFloat(floatFlip(i))

  def doubleToSortableLong(f: Double): Long = doubleFlip(doubleToRawLongBits(f))
  def sortableLongToDouble(i: Long): Double = longBitsToDouble(doubleFlip(i))

  /** based on http://stereopsis.com/radix.html, except this converts to signed
    * ints. Only difference is that sign bit is never flipped. Applying the
    * flip twice restores the input, so it serves as its own inverse. */
  protected def floatFlip(f: Int): Int = f ^ (-(f >>> 31) & 0x7FFFFFFF) // float to signed int
  protected def doubleFlip(f: Long): Long = f ^ (-(f >>> 63) & 0x7FFFFFFFFFFFFFFFL)

  /** Smallest power of two that is greater than or equal to `x`; returns 0
    * for `x == 0`. For `x` above 2^30 the shift overflows Int. */
  def higherPowerOfTwo(x: Int): Int =
    highestOneBit(x) << (if (highestOneBit(x) == x) 0 else 1)

}