├── project ├── build.properties ├── SparkSubmit.scala ├── plugins.sbt ├── Dependencies.scala └── Common.scala ├── .gitignore ├── .travis.yml ├── .gitattributes ├── data └── mnist │ ├── tsne.gif │ ├── mnist.csv.gz │ └── tsne.R ├── spark-tsne-core └── src │ ├── main │ └── scala │ │ ├── com │ │ └── github │ │ │ └── saurfang │ │ │ └── spark │ │ │ └── tsne │ │ │ ├── TSNEParam.scala │ │ │ ├── TSNEHelper.scala │ │ │ ├── tree │ │ │ └── SPTree.scala │ │ │ ├── impl │ │ │ ├── SimpleTSNE.scala │ │ │ ├── BHTSNE.scala │ │ │ └── LBFGSTSNE.scala │ │ │ ├── X2P.scala │ │ │ └── TSNEGradient.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── mllib │ │ └── X2PHelper.scala │ └── test │ └── scala │ ├── com │ └── github │ │ └── saurfang │ │ └── spark │ │ └── tsne │ │ ├── TSNEGradientTest.scala │ │ ├── X2PSuite.scala │ │ ├── BugDemonstrationTest.scala │ │ └── tree │ │ └── SPTreeSpec.scala │ └── org │ └── apache │ └── spark │ ├── SharedSparkContext.scala │ └── LocalSparkContext.scala ├── spark-tsne-examples └── src │ └── main │ ├── resources │ └── log4j.properties │ └── scala │ └── com │ └── github │ └── saurfang │ └── spark │ └── tsne │ └── examples │ └── MNIST.scala ├── README.md ├── spark-tsne-player └── src │ └── main │ └── html │ └── tsne.html └── LICENSE /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.13 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /RUNNING_PID 2 | /logs/ 3 | project/project/ 4 | project/target/ 5 | target/ 6 | .idea 7 | .tmp 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.10.6 4 | - 2.11.7 5 | jdk: 6 | - oraclejdk8 7 | - oraclejdk7 8 | - openjdk7 9 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | 2 | *.gz filter=lfs diff=lfs merge=lfs -text 3 | *.gif filter=lfs diff=lfs merge=lfs -text 4 | *.json filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /data/mnist/tsne.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e4f8da9e3ad4586d4776babe8cf8c41c8686974a29e9af2051be4b8edabf3dfc 3 | size 4371431 4 | -------------------------------------------------------------------------------- /data/mnist/mnist.csv.gz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3c06183a4876a6364923d297797d0ed6cb84337e2061bf1b250211a0af323e37 3 | size 13266494 4 | -------------------------------------------------------------------------------- /project/SparkSubmit.scala: -------------------------------------------------------------------------------- 1 | import sbtsparksubmit.SparkSubmitPlugin.autoImport._ 2 | 3 | object SparkSubmit { 4 | lazy val settings = 5 | SparkSubmitSetting("sparkMNIST", 6 | Seq( 7 | "--master", "local[3]", 8 | "--class", "com.github.saurfang.spark.tsne.examples.MNIST" 9 | ) 10 | ) 11 | } 12 | -------------------------------------------------------------------------------- /project/plugins.sbt: 
-------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0") 2 | 3 | addSbtPlugin("me.lessis" % "bintray-sbt" % "0.2.1") 4 | 5 | addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.8.4") 6 | 7 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0") 8 | 9 | addSbtPlugin("com.github.saurfang" % "sbt-spark-submit" % "0.0.4") 10 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/TSNEParam.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | case class TSNEParam( 4 | early_exaggeration: Int = 100, 5 | exaggeration_factor: Double = 4.0, 6 | t_momentum: Int = 25, 7 | initial_momentum: Double = 0.5, 8 | final_momentum: Double = 0.8, 9 | eta: Double = 500.0, 10 | min_gain: Double = 0.01 11 | ) 12 | -------------------------------------------------------------------------------- /spark-tsne-core/src/test/scala/com/github/saurfang/spark/tsne/TSNEGradientTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | import breeze.linalg._ 4 | import org.scalatest.{FunSuite, Matchers} 5 | 6 | /** 7 | * Created by forest on 7/17/15. 8 | */ 9 | class TSNEGradientTest extends FunSuite with Matchers { 10 | test("computeNumerator should compute numerator for sub indices") { 11 | val Y = DenseMatrix.create(3, 2, (1 to 6).map(_.toDouble).toArray) 12 | println(Y) 13 | val num = TSNEGradient.computeNumerator(Y, 0, 2) 14 | println(num) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /spark-tsne-core/src/test/scala/org/apache/spark/SharedSparkContext.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark 2 | 3 | import org.scalatest.{BeforeAndAfterAll, Suite} 4 | 5 | /** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ 6 | trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => 7 | 8 | @transient private var _sc: SparkContext = _ 9 | 10 | def sc: SparkContext = _sc 11 | 12 | var conf = new SparkConf(false) 13 | 14 | override def beforeAll() { 15 | _sc = new SparkContext("local[4]", "test", conf) 16 | super.beforeAll() 17 | } 18 | 19 | override def afterAll() { 20 | LocalSparkContext.stop(_sc) 21 | _sc = null 22 | super.afterAll() 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /spark-tsne-examples/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=INFO, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.spark-project.jetty=WARN 10 | log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 13 | log4j.logger.org.apache.spark=WARN 14 | 
log4j.logger.org.apache.spark.mllib=INFO 15 | -------------------------------------------------------------------------------- /project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object Dependencies { 5 | val Versions = Seq( 6 | crossScalaVersions := Seq("2.11.8", "2.10.5"), 7 | scalaVersion := crossScalaVersions.value.head 8 | ) 9 | 10 | object Compile { 11 | val spark = "org.apache.spark" %% "spark-mllib" % "2.1.0" % "provided" 12 | val breeze_natives = "org.scalanlp" %% "breeze-natives" % "0.11.2" % "provided" 13 | val logging = Seq( 14 | "org.slf4j" % "slf4j-api" % "1.7.16", 15 | "org.slf4j" % "slf4j-log4j12" % "1.7.16") 16 | 17 | object Test { 18 | val scalatest = "org.scalatest" %% "scalatest" % "3.0.0" % "test" 19 | } 20 | } 21 | 22 | import Compile._ 23 | val l = libraryDependencies 24 | 25 | val core = l ++= Seq(spark, breeze_natives, Test.scalatest) ++ logging 26 | } 27 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/org/apache/spark/mllib/X2PHelper.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib 2 | 3 | import breeze.linalg._ 4 | import breeze.numerics._ 5 | import org.apache.spark.mllib.linalg.{Vector, Vectors} 6 | import org.apache.spark.mllib.util.MLUtils 7 | 8 | 9 | object X2PHelper { 10 | 11 | case class VectorWithNorm(vector: Vector, norm: Double) 12 | 13 | def fastSquaredDistance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { 14 | MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) 15 | } 16 | 17 | def Hbeta(D: DenseVector[Double], beta: Double = 1.0) : (Double, DenseVector[Double]) = { 18 | val P: DenseVector[Double] = exp(- D * beta) 19 | val sumP = sum(P) 20 | if(sumP == 0) { 21 | (0.0, DenseVector.zeros(D.size)) 22 | }else { 23 | val H = log(sumP) + (beta * sum(D :* P) / sumP) 24 | (H, P / sumP) 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /spark-tsne-core/src/test/scala/com/github/saurfang/spark/tsne/X2PSuite.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | import org.apache.spark.SharedSparkContext 4 | import org.apache.spark.mllib.linalg.Vectors 5 | import org.apache.spark.mllib.linalg.distributed.RowMatrix 6 | import org.scalatest.{FunSuite, Matchers} 7 | 8 | /** 9 | * Created by forest on 8/16/15. 
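 * Checks the affinities computed by X2P against the reference tsne.jl implementation (the assertion itself is currently commented out below).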
10 | */ 11 | class X2PSuite extends FunSuite with SharedSparkContext with Matchers { 12 | 13 | test("Test X2P against tsne.jl implementation") { 14 | val input = new RowMatrix( 15 | sc.parallelize(Seq(1 to 3, 4 to 6, 7 to 9, 10 to 12)) 16 | .map(x => Vectors.dense(x.map(_.toDouble).toArray)) 17 | ) 18 | val output = X2P(input, 1e-5, 2).toRowMatrix().rows.collect().map(_.toArray.toList) 19 | println(output.toList) 20 | //output shouldBe List(List(0, .5, .5), List(.5, 0, .5), List(.5, .5, .0)) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /project/Common.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | import com.typesafe.sbt.GitPlugin.autoImport._ 4 | 5 | import scala.language.experimental.macros 6 | import scala.reflect.macros.Context 7 | 8 | object Common { 9 | val commonSettings = Seq( 10 | organization in ThisBuild := "com.github.saurfang", 11 | javacOptions ++= Seq("-source", "1.7", "-target", "1.7"), 12 | scalacOptions ++= Seq("-target:jvm-1.7", "-deprecation", "-feature"), 13 | //git.useGitDescribe := true, 14 | git.baseVersion := "0.0.1", 15 | parallelExecution in test := false, 16 | updateOptions := updateOptions.value.withCachedResolution(true) 17 | ) 18 | 19 | def tsneProject(path: String): Project = macro tsneProjectMacroImpl 20 | 21 | def tsneProjectMacroImpl(c: Context)(path: c.Expr[String]) = { 22 | import c.universe._ 23 | reify { 24 | (Project.projectMacroImpl(c).splice in file(path.splice)). 25 | settings(name := path.splice). 26 | settings(Dependencies.Versions). 27 | settings(commonSettings: _*) 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spark-tsne 2 | 3 | [![Join the chat at https://gitter.im/saurfang/spark-tsne](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/saurfang/spark-tsne?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://travis-ci.org/erwinvaneijk/spark-tsne.svg?branch=master)](https://travis-ci.org/erwinvaneijk/spark-tsne) 4 | Distributed [t-SNE](http://lvdmaaten.github.io/tsne/) with Apache Spark. WIP... 5 | 6 | t-SNE is a dimensionality reduction technique that is particularly good for visualizing 7 | high-dimensional data. This project is an attempt to implement the algorithm on Spark to 8 | leverage distributed computing power. 9 | 10 | The project is still replicating the reference implementations from the original papers. 11 | Spark-specific optimizations will be the next goal once correctness is verified. 12 | 13 | Currently I'm showcasing this using the standard [MNIST](http://yann.lecun.com/exdb/mnist/) 14 | handwritten digit dataset. I have created a [WebGL player](https://saurfang.github.io/spark-tsne-demo/tsne-pixi.html) 15 | (built using [pixi.js](https://github.com/pixijs/pixi.js)) to visualize the inner workings 16 | as well as the final results of t-SNE. If WebGL is unavailable to you, you can check out 17 | the [d3.js player](https://saurfang.github.io/spark-tsne-demo/tsne.html) instead.
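A minimal usage sketch (hypothetical driver code, not part of this repo: it assumes an existing `SparkContext` named `sc` and a local `data: Seq[Array[Double]]` of feature vectors; the `BHTSNE.tsne` signature and defaults below are the ones defined in `spark-tsne-core`):

```scala
import com.github.saurfang.spark.tsne.impl.BHTSNE
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// persist the input rows; BHTSNE warns when the input is not persisted
val rows = sc.parallelize(data).map(a => Vectors.dense(a)).cache()

// returns an (n x 2) breeze DenseMatrix with the final embedding;
// the callback sees every intermediate embedding and, every 10 iterations, the loss
val Y = BHTSNE.tsne(new RowMatrix(rows), noDims = 2, maxIterations = 500, perplexity = 30,
  callback = { case (iter, y, loss) => loss.foreach(l => println(s"iteration $iter: loss $l")) })
```

`SimpleTSNE.tsne` exposes the same core parameters if the Barnes-Hut approximation is not wanted.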
18 | 19 | ![](data/mnist/tsne.gif) 20 | 21 | ## Credits 22 | 23 | - [t-SNE Julia implementation](https://github.com/lejon/TSne.jl) 24 | - [Barnes-Hut t-SNE](https://github.com/lvdmaaten/bhtsne/) 25 | -------------------------------------------------------------------------------- /data/mnist/tsne.R: -------------------------------------------------------------------------------- 1 | library(dplyr) 2 | library(ggplot2) 3 | library(animation) 4 | library(jsonlite) 5 | 6 | resultFiles <- list.files("~/GitHub/spark-tsne/.tmp/MNIST/", "result", full.names = TRUE) 7 | results <- lapply(resultFiles, function(file) { read.csv(file, FALSE) }) 8 | resultsCombined <- lapply(1:length(results), function(i) { 9 | result <- results[[i]] 10 | names(result) <- c("label", "x", "y") 11 | mutate(result, i = i, key = row_number()) 12 | }) %>% 13 | rbind_all() 14 | 15 | #### save results as json for viewer #### 16 | iterations <- c(1:99, seq(100, length(results), 5)) # assume 100 early exaggeration here 17 | resultsByObs <- filter(resultsCombined, i %in% iterations) %>% 18 | group_by(key) %>% 19 | # do({ 20 | # list(key = unbox(.$key[1]), label = unbox(.$label[1]), 21 | # # assume order will preserve 22 | # pos = select(., x, y)) %>% 23 | # data_frame 24 | # }) 25 | do(key = unbox(.$key[1]), 26 | label = unbox(.$label[1]), 27 | pos = select(., x, y)) 28 | write(toJSON(list(iterations = iterations, data = resultsByObs)), "mnist.json") 29 | 30 | #### save plot as animated gif #### 31 | computeLimit <- function(f, cumf) { 32 | cumf(lapply(results, f)) 33 | } 34 | 35 | xmax <- computeLimit(. %>% {max(abs(.$V2))}, cummax) 36 | ymax <- computeLimit(. %>% {max(abs(.$V3))}, cummax) 37 | 38 | plotResult <- function(i) { 39 | ggplot(results[[i]]) + 40 | aes(V2, V3, color = as.factor(V1), label = V1) + 41 | #geom_point() + 42 | geom_text() + 43 | xlim(-xmax[i], xmax[i]) + 44 | ylim(-ymax[i], ymax[i]) 45 | } 46 | 47 | traceAnimate <- function(n = length(results), step = 1) { 48 | lapply(seq(1, n, step), function(i) { 49 | print(plotResult(i)) 50 | }) 51 | } 52 | 53 | file.remove("tsne.gif") 54 | saveGIF(traceAnimate(step = 5), interval = 0.05, movie.name = "tsne.gif", loop = 1) 55 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/TSNEHelper.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | import breeze.linalg._ 4 | import breeze.stats._ 5 | import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix 6 | import org.apache.spark.rdd.RDD 7 | 8 | object TSNEHelper { 9 | // p_ij = (p_{i|j} + p_{j|i}) / 2n 10 | def computeP(p_ji: CoordinateMatrix, n: Int): RDD[(Int, Iterable[(Int, Double)])] = { 11 | p_ji.entries 12 | .flatMap(e => Seq( 13 | ((e.i.toInt, e.j.toInt), e.value), 14 | ((e.j.toInt, e.i.toInt), e.value) 15 | )) 16 | .reduceByKey(_ + _) // p + p' 17 | .map{case ((i, j), v) => (i, (j, math.max(v / 2 / n, 1e-12))) } // p / 2n 18 | .groupByKey() 19 | } 20 | 21 | /** 22 | * Update Y via gradient dY 23 | * @param Y current Y 24 | * @param dY gradient dY 25 | * @param iY stored y_i - y_{i-1} 26 | * @param gains adaptive learning rates 27 | * @param iteration n 28 | * @param param [[TSNEParam]] 29 | * @return 30 | */ 31 | def update(Y: DenseMatrix[Double], 32 | dY: DenseMatrix[Double], 33 | iY: DenseMatrix[Double], 34 | gains: DenseMatrix[Double], 35 | iteration: Int, 36 | param: TSNEParam): DenseMatrix[Double] = { 37 | import param._ 38 
| val momentum = if (iteration <= t_momentum) initial_momentum else final_momentum 39 | gains.foreachPair { 40 | case ((i, j), old_gain) => 41 | val new_gain = math.max(min_gain, 42 | if ((dY(i, j) > 0.0) != (iY(i, j) > 0.0)) 43 | old_gain + 0.2 44 | else 45 | old_gain * 0.8 46 | ) 47 | gains.update(i, j, new_gain) 48 | 49 | val new_iY = momentum * iY(i, j) - eta * new_gain * dY(i, j) 50 | iY.update(i, j, new_iY) 51 | 52 | Y.update(i, j, Y(i, j) + new_iY) // Y += iY 53 | } 54 | val t_Y: DenseVector[Double] = mean(Y(::, *)).t 55 | val y_sub = Y(*, ::) 56 | Y := y_sub - t_Y 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /spark-tsne-core/src/test/scala/com/github/saurfang/spark/tsne/BugDemonstrationTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | import org.apache.spark.mllib.linalg.{Vectors, Vector} 4 | import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} 5 | import org.apache.spark.sql.SparkSession 6 | import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} 7 | 8 | /** 9 | * This test demonstrates the bug introduced when upgrading the codebase to spark 2.1. 10 | * 11 | * For completeness and to check regressions, it's now added to the codebase. 12 | * 13 | * @author erwin.vaneijk@gmail.com 14 | */ 15 | class BugDemonstrationTest extends FunSuite with Matchers with BeforeAndAfterAll { 16 | private var sparkSession : SparkSession = _ 17 | override def beforeAll(): Unit = { 18 | super.beforeAll() 19 | sparkSession = SparkSession.builder().appName("BugTests").master("local[2]").getOrCreate() 20 | } 21 | 22 | override def afterAll(): Unit = { 23 | super.afterAll() 24 | sparkSession.stop() 25 | } 26 | 27 | test("This demonstrates a bug was fixed in tsne-spark 2.1") { 28 | val sc = sparkSession.sparkContext 29 | 30 | val observations = sc.parallelize( 31 | Seq( 32 | Vectors.dense(1.0, 10.0, 100.0), 33 | Vectors.dense(2.0, 20.0, 200.0), 34 | Vectors.dense(3.0, 30.0, 300.0) 35 | ) 36 | ) 37 | 38 | // Compute column summary statistics. 39 | val summary: MultivariateStatisticalSummary = Statistics.colStats(observations) 40 | val expectedMean = Vectors.dense(2.0,20.0,200.0) 41 | val resultMean = summary.mean 42 | assertEqualEnough(resultMean, expectedMean) 43 | val expectedVariance = Vectors.dense(1.0,100.0,10000.0) 44 | assertEqualEnough(summary.variance, expectedVariance) 45 | val expectedNumNonZeros = Vectors.dense(3.0, 3.0, 3.0) 46 | assertEqualEnough(summary.numNonzeros, expectedNumNonZeros) 47 | } 48 | 49 | private def assertEqualEnough(sample: Vector, expected: Vector): Unit = { 50 | expected.toArray.zipWithIndex.foreach{ case(d: Double, i: Int) => 51 | sample(i) should be (d +- 1E-12) 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /spark-tsne-core/src/test/scala/org/apache/spark/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark 19 | 20 | import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} 21 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite} 22 | 23 | /** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ 24 | trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => 25 | 26 | @transient var sc: SparkContext = _ 27 | 28 | override def beforeAll() { 29 | InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) 30 | super.beforeAll() 31 | } 32 | 33 | override def afterEach() { 34 | resetSparkContext() 35 | super.afterEach() 36 | } 37 | 38 | def resetSparkContext(): Unit = { 39 | LocalSparkContext.stop(sc) 40 | sc = null 41 | } 42 | 43 | } 44 | 45 | object LocalSparkContext { 46 | def stop(sc: SparkContext) { 47 | if (sc != null) { 48 | sc.stop() 49 | } 50 | // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown 51 | System.clearProperty("spark.driver.port") 52 | } 53 | 54 | /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ 55 | def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { 56 | try { 57 | f(sc) 58 | } finally { 59 | stop(sc) 60 | } 61 | } 62 | 63 | } -------------------------------------------------------------------------------- /spark-tsne-core/src/test/scala/com/github/saurfang/spark/tsne/tree/SPTreeSpec.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne.tree 2 | 3 | import breeze.linalg._ 4 | import org.scalatest.{FunSpec, Matchers} 5 | 6 | class SPTreeSpec extends FunSpec with Matchers { 7 | 8 | describe("SPTree") { 9 | describe("with 2 dimensions (quadtree)") { 10 | val tree = new SPTree(2, DenseVector(0.0, 0.0), DenseVector(2.0, 4.0)) 11 | import tree._ 12 | it("should have 4 children") { 13 | children.length shouldBe 4 14 | } 15 | it("each child should have correct width") { 16 | val width = DenseVector(1.0, 2.0) 17 | children.foreach(x => x.width shouldBe width) 18 | } 19 | it("children should have correct corner") { 20 | children.map(_.corner) shouldBe Array( 21 | DenseVector(0.0, 0.0), 22 | DenseVector(0.0, 2.0), 23 | DenseVector(1.0, 0.0), 24 | DenseVector(1.0, 2.0) 25 | ) 26 | } 27 | it("getCell should return correct cell") { 28 | getCell(DenseVector(1.0, 1.0)).corner shouldBe DenseVector(0.0, 0.0) 29 | getCell(DenseVector(1.5, 1.5)).corner shouldBe DenseVector(1.0, 0.0) 30 | getCell(DenseVector(2.0, 2.0)).corner shouldBe DenseVector(1.0, 0.0) 31 | getCell(DenseVector(2.0, 2.5)).corner shouldBe DenseVector(1.0, 2.0) 32 | } 33 | it("should be able to be constructed from DenseMatrix") { 34 | val data = Array( 35 | 1.0, 1.0, 1.0, 2.0, 1.1, 1.11, 1.11, 1, 36 | 3.0, 1.0, 2.0, 2.0, 1.1, 1.11, 1.11, 1 37 | ) 38 | val matrix = DenseMatrix.create[Double](data.length / 2, 2, data) 39 | val tree = SPTree(matrix) 40 | 41 | tree.getCount shouldBe matrix.rows 42 | tree.children.map(_.getCount).sum shouldBe matrix.rows 43 | tree.center shouldBe 
DenseVector(data.grouped(matrix.rows).map(x => x.sum / x.length).toArray) 44 | verifyCorrectness(tree) 45 | } 46 | } 47 | } 48 | 49 | def verifyCorrectness(tree: SPTree): Unit = { 50 | if(tree.getCount <= 1) tree.isLeaf shouldBe true 51 | if(tree.getCount > 0) tree.center shouldBe (tree.totalMass / tree.getCount.toDouble) 52 | if(tree.isLeaf) { 53 | tree.children.foreach(_.isLeaf shouldBe true) 54 | tree.children.foreach(_.getCount shouldBe 0) 55 | } else { 56 | tree.children.map(_.getCount).sum shouldBe tree.getCount 57 | val totalMassTally = tree.children.foldLeft(DenseVector.zeros[Double](tree.dimension))((acc, t) => acc + t.totalMass) 58 | (0 until tree.dimension).foreach(i => totalMassTally(i) shouldBe (tree.totalMass(i) +- 1e-5)) 59 | tree.children.foreach(verifyCorrectness) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/tree/SPTree.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne.tree 2 | 3 | import breeze.linalg._ 4 | import breeze.numerics._ 5 | 6 | import scala.annotation.tailrec 7 | 8 | 9 | class SPTree private[tree](val dimension: Int, 10 | val corner: DenseVector[Double], 11 | val width: DenseVector[Double]) extends Serializable { 12 | private[this] val childWidth: DenseVector[Double] = width :/ 2.0 13 | lazy val radiusSq: Double = sum(pow(width, 2)) 14 | private[tree] val totalMass: DenseVector[Double] = DenseVector.zeros(dimension) 15 | private var count: Int = 0 16 | private var leaf: Boolean = true 17 | val center: DenseVector[Double] = DenseVector.zeros(dimension) 18 | 19 | lazy val children: Array[SPTree] = { 20 | (0 until pow(2, dimension)).toArray.map { 21 | i => 22 | val bits = DenseVector(s"%0${dimension}d".format(i.toBinaryString.toInt).toArray.map(_.toDouble - '0'.toDouble)) 23 | val childCorner: DenseVector[Double] = corner + (bits :* childWidth) 24 | new SPTree(dimension, childCorner, childWidth) 25 | } 26 | } 27 | 28 | final def insert(vector: DenseVector[Double], finalize: Boolean = false): SPTree = { 29 | totalMass += vector 30 | count += 1 31 | 32 | if(leaf) { 33 | if(count == 1) { // first to leaf 34 | center := vector 35 | } else if(!vector.equals(center)) { 36 | (1 until count).foreach(_ => getCell(center).insert(center, finalize)) //subdivide 37 | leaf = false 38 | } 39 | } 40 | 41 | if(finalize) computeCenter(false) 42 | 43 | if(leaf) this else getCell(vector).insert(vector, finalize) 44 | } 45 | 46 | def computeCenter(recursive: Boolean = true): Unit = { 47 | if(count > 0) { 48 | center := totalMass / count.toDouble 49 | if(recursive) children.foreach(_.computeCenter()) 50 | } 51 | } 52 | 53 | def getCell(vector: DenseVector[Double]): SPTree = { 54 | val idx = ((vector - corner) :/ childWidth).data 55 | children(idx.foldLeft(0)((acc, i) => acc * 2 + min(max(i.ceil.toInt - 1, 0), 1))) 56 | } 57 | 58 | def getCount: Int = count 59 | 60 | def isLeaf: Boolean = leaf 61 | } 62 | 63 | object SPTree { 64 | def apply(Y: DenseMatrix[Double]): SPTree = { 65 | val d = Y.cols 66 | val minMaxs = minMax(Y(::, *)).t 67 | val mins = minMaxs.mapValues(_._1) 68 | val maxs = minMaxs.mapValues(_._2) 69 | 70 | val tree = new SPTree(Y.cols, mins, maxs - mins) 71 | 72 | // insert points but wait till end to compute all centers 73 | //Y(*, ::).foreach(tree.insert(_, finalize = false)) 74 | (0 until Y.rows).foreach(i => tree.insert(Y(i, ::).t, finalize = false)) 75 | // 
compute all center of mass 76 | tree.computeCenter() 77 | 78 | tree 79 | } 80 | } -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/impl/SimpleTSNE.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne.impl 2 | 3 | import breeze.linalg._ 4 | import breeze.stats.distributions.Rand 5 | import com.github.saurfang.spark.tsne.{TSNEGradient, TSNEHelper, TSNEParam, X2P} 6 | import org.apache.spark.mllib.linalg.distributed.RowMatrix 7 | import org.apache.spark.storage.StorageLevel 8 | import org.slf4j.LoggerFactory 9 | 10 | import scala.util.Random 11 | 12 | object SimpleTSNE { 13 | private def logger = LoggerFactory.getLogger(SimpleTSNE.getClass) 14 | 15 | def tsne( 16 | input: RowMatrix, 17 | noDims: Int = 2, 18 | maxIterations: Int = 1000, 19 | perplexity: Double = 30, 20 | callback: (Int, DenseMatrix[Double], Option[Double]) => Unit = {case _ => }, 21 | seed: Long = Random.nextLong()): DenseMatrix[Double] = { 22 | if(input.rows.getStorageLevel == StorageLevel.NONE) { 23 | logger.warn("Input is not persisted and performance could be bad") 24 | } 25 | 26 | Rand.generator.setSeed(seed) 27 | 28 | val tsneParam = TSNEParam() 29 | import tsneParam._ 30 | 31 | val n = input.numRows().toInt 32 | val Y: DenseMatrix[Double] = DenseMatrix.rand(n, noDims, Rand.gaussian(0, 1)) 33 | val iY = DenseMatrix.zeros[Double](n, noDims) 34 | val gains = DenseMatrix.ones[Double](n, noDims) 35 | 36 | // approximate p_{j|i} 37 | val p_ji = X2P(input, 1e-5, perplexity) 38 | val P = TSNEHelper.computeP(p_ji, n).glom().cache() 39 | 40 | var iteration = 1 41 | while(iteration <= maxIterations) { 42 | val bcY = P.context.broadcast(Y) 43 | 44 | val numerator = P.map{ arr => TSNEGradient.computeNumerator(bcY.value, arr.map(_._1): _*) }.cache() 45 | val bcNumerator = P.context.broadcast({ 46 | numerator.treeAggregate(0.0)(seqOp = (x, v) => x + sum(v), combOp = _ + _) 47 | }) 48 | 49 | val (dY, loss) = P.zip(numerator).treeAggregate((DenseMatrix.zeros[Double](n, noDims), 0.0))( 50 | seqOp = (c, v) => { 51 | // c: (grad, loss), v: (Array[(i, Iterable(j, Distance))], numerator) 52 | val l = TSNEGradient.compute(v._1, bcY.value, v._2, bcNumerator.value, c._1, iteration <= early_exaggeration) 53 | (c._1, c._2 + l) 54 | }, 55 | combOp = (c1, c2) => { 56 | // c: (grad, loss) 57 | (c1._1 + c2._1, c1._2 + c2._2) 58 | }) 59 | 60 | bcY.destroy() 61 | bcNumerator.destroy() 62 | numerator.unpersist() 63 | 64 | TSNEHelper.update(Y, dY, iY, gains, iteration, tsneParam) 65 | 66 | logger.debug(s"Iteration $iteration finished with $loss") 67 | callback(iteration, Y.copy, Some(loss)) 68 | iteration += 1 69 | } 70 | Y 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/X2P.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | import breeze.linalg.DenseVector 4 | import org.apache.spark.mllib.X2PHelper._ 5 | import org.apache.spark.mllib.linalg.Vectors 6 | import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry, RowMatrix} 7 | import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ 8 | import org.slf4j.LoggerFactory 9 | 10 | object X2P { 11 | 12 | private def logger = LoggerFactory.getLogger(X2P.getClass) 13 | 14 | def apply(x: RowMatrix, tol: Double = 1e-5, 
perplexity: Double = 30.0): CoordinateMatrix = { 15 | require(tol >= 0, "Tolerance must be non-negative") 16 | require(perplexity > 0, "Perplexity must be positive") 17 | 18 | val mu = (3 * perplexity).toInt //TODO: Expose this as parameter 19 | val logU = Math.log(perplexity) 20 | val norms = x.rows.map(Vectors.norm(_, 2.0)) 21 | norms.persist() 22 | val rowsWithNorm = x.rows.zip(norms).map{ case (v, norm) => VectorWithNorm(v, norm) } 23 | val neighbors = rowsWithNorm.zipWithIndex() 24 | .cartesian(rowsWithNorm.zipWithIndex()) 25 | .flatMap { 26 | case ((u, i), (v, j)) => 27 | if(i < j) { 28 | val dist = fastSquaredDistance(u, v) 29 | Seq((i, (j, dist)), (j, (i, dist))) 30 | } else Seq.empty 31 | } 32 | .topByKey(mu)(Ordering.by(e => -e._2)) 33 | 34 | val p_betas = 35 | neighbors.map { 36 | case (i, arr) => 37 | var betamin = Double.NegativeInfinity 38 | var betamax = Double.PositiveInfinity 39 | var beta = 1.0 40 | 41 | val d = DenseVector(arr.map(_._2)) 42 | var (h, p) = Hbeta(d, beta) 43 | 44 | //logInfo("data was " + d.toArray.toList) 45 | //logInfo("array P was " + p.toList) 46 | 47 | // Evaluate whether the perplexity is within tolerance 48 | def Hdiff = h - logU 49 | var tries = 0 50 | while (Math.abs(Hdiff) > tol && tries < 50) { 51 | //If not, increase or decrease precision 52 | if (Hdiff > 0) { 53 | betamin = beta 54 | beta = if (betamax.isInfinite) beta * 2 else (beta + betamax) / 2 55 | } else { 56 | betamax = beta 57 | beta = if (betamin.isInfinite) beta / 2 else (beta + betamin) / 2 58 | } 59 | 60 | // Recompute the values 61 | val HP = Hbeta(d, beta) 62 | h = HP._1 63 | p = HP._2 64 | tries = tries + 1 65 | } 66 | 67 | //logInfo("array P is " + p.toList) 68 | 69 | (arr.map(_._1).zip(p.toArray).map { case (j, v) => MatrixEntry(i, j, v) }, beta) 70 | } 71 | 72 | logger.info("Mean value of sigma: " + p_betas.map(x => math.sqrt(1 / x._2)).mean) 73 | new CoordinateMatrix(p_betas.flatMap(_._1)) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /spark-tsne-examples/src/main/scala/com/github/saurfang/spark/tsne/examples/MNIST.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne.examples 2 | 3 | 4 | import java.io.{BufferedWriter, OutputStreamWriter} 5 | 6 | import com.github.saurfang.spark.tsne.impl._ 7 | import com.github.saurfang.spark.tsne.tree.SPTree 8 | import org.apache.hadoop.fs.{FileSystem, Path} 9 | import org.apache.spark.mllib.linalg.Vectors 10 | import org.apache.spark.mllib.linalg.distributed.RowMatrix 11 | import org.apache.spark.{SparkConf, SparkContext} 12 | import org.slf4j.LoggerFactory 13 | 14 | object MNIST { 15 | private def logger = LoggerFactory.getLogger(MNIST.getClass) 16 | 17 | def main (args: Array[String]) { 18 | val conf = new SparkConf() 19 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 20 | .registerKryoClasses(Array(classOf[SPTree])) 21 | val sc = new SparkContext(conf) 22 | val hadoopConf = sc.hadoopConfiguration 23 | val fs = FileSystem.get(hadoopConf) 24 | 25 | val dataset = sc.textFile("data/mnist/mnist.csv.gz") 26 | .zipWithIndex() 27 | .filter(_._2 < 6000) 28 | .sortBy(_._2, true, 60) 29 | .map(_._1) 30 | .map(_.split(",")) 31 | .map(x => (x.head.toInt, x.tail.map(_.toDouble))) 32 | .cache() 33 | //logInfo(dataset.collect.map(_._2.toList).toList.toString) 34 | 35 | //val features = dataset.map(x => Vectors.dense(x._2)) 36 | //val scaler = new StandardScaler(true, true).fit(features) 37 | //val 
scaledData = scaler.transform(features) 38 | // .map(v => Vectors.dense(v.toArray.map(x => if(x.isNaN || x.isInfinite) 0.0 else x))) 39 | // .cache() 40 | val data = dataset.flatMap(_._2) 41 | val mean = data.mean() 42 | val std = data.stdev() 43 | val scaledData = dataset.map(x => Vectors.dense(x._2.map(v => (v - mean) / std))).cache() 44 | 45 | val labels = dataset.map(_._1).collect() 46 | val matrix = new RowMatrix(scaledData) 47 | val pcaMatrix = matrix.multiply(matrix.computePrincipalComponents(50)) 48 | pcaMatrix.rows.cache() 49 | 50 | val costWriter = new BufferedWriter(new OutputStreamWriter(fs.create(new Path(s".tmp/MNIST/cost.txt"), true))) 51 | 52 | //SimpleTSNE.tsne(pcaMatrix, perplexity = 20, maxIterations = 200) 53 | BHTSNE.tsne(pcaMatrix, maxIterations = 500, callback = { 54 | //LBFGSTSNE.tsne(pcaMatrix, perplexity = 10, maxNumIterations = 500, numCorrections = 10, convergenceTol = 1e-8) 55 | case (i, y, loss) => 56 | if(loss.isDefined) logger.info(s"$i iteration finished with loss $loss") 57 | 58 | val os = fs.create(new Path(s".tmp/MNIST/result${"%05d".format(i)}.csv"), true) 59 | val writer = new BufferedWriter(new OutputStreamWriter(os)) 60 | try { 61 | (0 until y.rows).foreach { 62 | row => 63 | writer.write(labels(row).toString) 64 | writer.write(y(row, ::).inner.toArray.mkString(",", ",", "\n")) 65 | } 66 | if(loss.isDefined) costWriter.write(loss.get + "\n") 67 | } finally { 68 | writer.close() 69 | } 70 | }) 71 | costWriter.close() 72 | 73 | sc.stop() 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/impl/BHTSNE.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne.impl 2 | 3 | import breeze.linalg._ 4 | import breeze.stats.distributions.Rand 5 | import com.github.saurfang.spark.tsne.tree.SPTree 6 | import com.github.saurfang.spark.tsne.{TSNEGradient, TSNEHelper, TSNEParam, X2P} 7 | import org.apache.spark.mllib.linalg.distributed.RowMatrix 8 | import org.apache.spark.storage.StorageLevel 9 | import org.slf4j.LoggerFactory 10 | 11 | import scala.util.Random 12 | 13 | object BHTSNE { 14 | private def logger = LoggerFactory.getLogger(BHTSNE.getClass) 15 | 16 | def tsne( 17 | input: RowMatrix, 18 | noDims: Int = 2, 19 | maxIterations: Int = 1000, 20 | perplexity: Double = 30, 21 | theta: Double = 0.5, 22 | reportLoss: Int => Boolean = {i => i % 10 == 0}, 23 | callback: (Int, DenseMatrix[Double], Option[Double]) => Unit = {case _ => }, 24 | seed: Long = Random.nextLong() 25 | ): DenseMatrix[Double] = { 26 | if(input.rows.getStorageLevel == StorageLevel.NONE) { 27 | logger.warn("Input is not persisted and performance could be bad") 28 | } 29 | 30 | Rand.generator.setSeed(seed) 31 | 32 | val tsneParam = TSNEParam() 33 | import tsneParam._ 34 | 35 | val n = input.numRows().toInt 36 | val Y: DenseMatrix[Double] = DenseMatrix.rand(n, noDims, Rand.gaussian(0, 1)) :/ 1e4 37 | val iY = DenseMatrix.zeros[Double](n, noDims) 38 | val gains = DenseMatrix.ones[Double](n, noDims) 39 | 40 | // approximate p_{j|i} 41 | val p_ji = X2P(input, 1e-5, perplexity) 42 | val P = TSNEHelper.computeP(p_ji, n).glom() 43 | .map(rows => rows.map { 44 | case (i, data) => 45 | (i, data.map(_._1).toSeq, DenseVector(data.map(_._2 * exaggeration_factor).toArray)) 46 | }) 47 | .cache() 48 | 49 | var iteration = 1 50 | while(iteration <= maxIterations) { 51 | val bcY = P.context.broadcast(Y) 52 | val bcTree = 
P.context.broadcast(SPTree(Y)) 53 | 54 | val initialValue = (DenseMatrix.zeros[Double](n, noDims), DenseMatrix.zeros[Double](n, noDims), 0.0) 55 | val (posF, negF, sumQ) = P.treeAggregate(initialValue)( 56 | seqOp = (c, v) => { 57 | // c: (pos, neg, sumQ), v: Array[(i, Seq(j), vec(Distance))] 58 | TSNEGradient.computeEdgeForces(v, bcY.value, c._1) 59 | val q = TSNEGradient.computeNonEdgeForces(bcTree.value, bcY.value, theta, c._2, v.map(_._1): _*) 60 | (c._1, c._2, c._3 + q) 61 | }, 62 | combOp = (c1, c2) => { 63 | // c: (grad, loss) 64 | (c1._1 + c2._1, c1._2 + c2._2, c1._3 + c2._3) 65 | }) 66 | val dY: DenseMatrix[Double] = posF :- (negF :/ sumQ) 67 | 68 | TSNEHelper.update(Y, dY, iY, gains, iteration, tsneParam) 69 | 70 | if(reportLoss(iteration)) { 71 | val loss = P.treeAggregate(0.0)( 72 | seqOp = (c, v) => { 73 | TSNEGradient.computeLoss(v, bcY.value, sumQ) 74 | }, 75 | combOp = _ + _ 76 | ) 77 | logger.debug(s"Iteration $iteration finished with $loss") 78 | callback(iteration, Y.copy, Some(loss)) 79 | } else { 80 | logger.debug(s"Iteration $iteration finished") 81 | callback(iteration, Y.copy, None) 82 | } 83 | 84 | bcY.destroy() 85 | bcTree.destroy() 86 | 87 | //undo early exaggeration 88 | if(iteration == early_exaggeration) { 89 | P.foreach { 90 | rows => rows.foreach { 91 | case (_, _, vec) => vec.foreachPair { case (i, v) => vec.update(i, v / exaggeration_factor) } 92 | } 93 | } 94 | } 95 | 96 | iteration += 1 97 | } 98 | 99 | Y 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/impl/LBFGSTSNE.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne.impl 2 | 3 | import breeze.linalg._ 4 | import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS} 5 | import breeze.stats.distributions.Rand 6 | import com.github.saurfang.spark.tsne.{TSNEGradient, X2P} 7 | import org.apache.spark.mllib.linalg.distributed.RowMatrix 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.storage.StorageLevel 10 | import org.slf4j.LoggerFactory 11 | 12 | import scala.util.Random 13 | 14 | /** 15 | * TODO: This doesn't work at all (yet or ever). 
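 * A possible explanation (unverified): L-BFGS assumes a fixed smooth objective, but the t-SNE KL cost is non-convex and the objective switches between the two optimization phases below when early exaggeration is lifted, which can invalidate the line search and curvature estimates.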
16 | */ 17 | object LBFGSTSNE { 18 | private def logger = LoggerFactory.getLogger(LBFGSTSNE.getClass) 19 | 20 | def tsne( 21 | input: RowMatrix, 22 | noDims: Int = 2, 23 | maxNumIterations: Int = 1000, 24 | numCorrections: Int = 10, 25 | convergenceTol: Double = 1e-4, 26 | perplexity: Double = 30, 27 | seed: Long = Random.nextLong()): DenseMatrix[Double] = { 28 | if(input.rows.getStorageLevel == StorageLevel.NONE) { 29 | logger.warn("Input is not persisted and performance could be bad") 30 | } 31 | 32 | Rand.generator.setSeed(seed) 33 | 34 | val n = input.numRows().toInt 35 | val early_exaggeration = 100 36 | val t_momentum = 250 37 | val initial_momentum = 0.5 38 | val final_momentum = 0.8 39 | val eta = 500.0 40 | val min_gain = 0.01 41 | 42 | val Y: DenseMatrix[Double] = DenseMatrix.rand(n, noDims, Rand.gaussian) //:* .0001 43 | val iY = DenseMatrix.zeros[Double](n, noDims) 44 | val gains = DenseMatrix.ones[Double](n, noDims) 45 | 46 | // approximate p_{j|i} 47 | val p_ji = X2P(input, 1e-5, perplexity) 48 | //logInfo(p_ji.toRowMatrix().rows.collect().toList.toString) 49 | // p_ij = (p_{i|j} + p_{j|i}) / 2n 50 | val P = p_ji.transpose().entries.union(p_ji.entries) 51 | .map(e => ((e.i.toInt, e.j.toInt), e.value)) 52 | .reduceByKey(_ + _) 53 | .map{case ((i, j), v) => (i, (j, v / 2 / n)) } 54 | .groupByKey() 55 | .glom() 56 | .cache() 57 | 58 | var iteration = 1 59 | 60 | { 61 | val costFun = new CostFun(P, n, noDims, true) 62 | val lbfgs = new LBFGS[DenseVector[Double]](maxNumIterations, numCorrections, convergenceTol) 63 | val states = lbfgs.iterations(new CachedDiffFunction(costFun), new DenseVector(Y.data)) 64 | 65 | while (states.hasNext) { 66 | val state = states.next() 67 | val loss = state.value 68 | //logInfo(state.convergedReason.get.toString) 69 | logger.debug(s"Iteration $iteration finished with $loss") 70 | 71 | Y := asDenseMatrix(state.x, n, noDims) 72 | //subscriber.onNext((iteration, Y.copy, Some(loss))) 73 | iteration += 1 74 | } 75 | } 76 | 77 | { 78 | val costFun = new CostFun(P, n, noDims, false) 79 | val lbfgs = new LBFGS[DenseVector[Double]](maxNumIterations, numCorrections, convergenceTol) 80 | val states = lbfgs.iterations(new CachedDiffFunction(costFun), new DenseVector(Y.data)) 81 | 82 | while (states.hasNext) { 83 | val state = states.next() 84 | val loss = state.value 85 | //logInfo(state.convergedReason.get.toString) 86 | logger.debug(s"Iteration $iteration finished with $loss") 87 | 88 | Y := asDenseMatrix(state.x, n, noDims) 89 | //subscriber.onNext((iteration, Y.copy, Some(loss))) 90 | iteration += 1 91 | } 92 | } 93 | 94 | Y 95 | } 96 | 97 | private[this] def asDenseMatrix(v: DenseVector[Double], n: Int, noDims: Int) = { 98 | v.asDenseMatrix.reshape(n, noDims) 99 | } 100 | 101 | private class CostFun( 102 | P: RDD[Array[(Int, Iterable[(Int, Double)])]], 103 | n: Int, 104 | noDims: Int, 105 | exaggeration: Boolean) extends DiffFunction[DenseVector[Double]] { 106 | 107 | override def calculate(weights: DenseVector[Double]): (Double, DenseVector[Double]) = { 108 | val bcY = P.context.broadcast(asDenseMatrix(weights, n, noDims)) 109 | val bcExaggeration = P.context.broadcast(exaggeration) 110 | 111 | val numerator = P.map{ arr => TSNEGradient.computeNumerator(bcY.value, arr.map(_._1): _*) }.cache() 112 | val bcNumerator = P.context.broadcast({ 113 | numerator.treeAggregate(0.0)(seqOp = (x, v) => x + sum(v), combOp = _ + _) 114 | }) 115 | 116 | val (dY, loss) = P.zip(numerator).treeAggregate((DenseMatrix.zeros[Double](n, noDims), 0.0))( 117 | seqOp = (c, v) => 
{ 118 | // c: (grad, loss), v: (Array[(i, Iterable(j, Distance))], numerator) 119 | // TODO: See if we can include early_exaggeration 120 | val l = TSNEGradient.compute(v._1, bcY.value, v._2, bcNumerator.value, c._1, bcExaggeration.value) 121 | (c._1, c._2 + l) 122 | }, 123 | combOp = (c1, c2) => { 124 | // c: (grad, loss) 125 | (c1._1 += c2._1, c1._2 + c2._2) 126 | }) 127 | 128 | numerator.unpersist() 129 | 130 | (loss, new DenseVector(dY.data)) 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /spark-tsne-core/src/main/scala/com/github/saurfang/spark/tsne/TSNEGradient.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.tsne 2 | 3 | import breeze.linalg._ 4 | import breeze.numerics._ 5 | import com.github.saurfang.spark.tsne.tree.SPTree 6 | import org.slf4j.LoggerFactory 7 | 8 | object TSNEGradient { 9 | def logger = LoggerFactory.getLogger(TSNEGradient.getClass) 10 | 11 | /** 12 | * Compute the numerator from the matrix Y 13 | * 14 | * @param idx the index in the matrix to use. 15 | * @param Y the matrix to analyze 16 | * @return the numerator 17 | */ 18 | def computeNumerator(Y: DenseMatrix[Double], idx: Int *): DenseMatrix[Double] = { 19 | // Y_sum = ||Y_i||^2 20 | val sumY = sum(pow(Y, 2).apply(*, ::)) // n * 1 21 | val subY = Y(idx, ::).toDenseMatrix // k * 1 22 | val y1: DenseMatrix[Double] = Y * (-2.0 :* subY.t) // n * k 23 | val num: DenseMatrix[Double] = (y1(::, *) + sumY).t // k * n 24 | num := 1.0 :/ (1.0 :+ (num(::, *) + sumY(idx).toDenseVector)) // k * n 25 | 26 | idx.indices.foreach(i => num.update(i, idx(i), 0.0)) // num(i, i) = 0 27 | 28 | num 29 | } 30 | 31 | /** 32 | * Compute the TSNE Gradient at i. Update the gradient through dY then return costs attributed at i. 33 | * 34 | * @param data data point for row i by list of pair of (j, p_ij) and 0 <= j < n 35 | * @param Y current Y [n * 2] 36 | * @param totalNum the common numerator that captures the t-distribution of Y 37 | * @param dY gradient of Y 38 | * @return loss attributed to row i 39 | */ 40 | def compute( 41 | data: Array[(Int, Iterable[(Int, Double)])], 42 | Y: DenseMatrix[Double], 43 | num: DenseMatrix[Double], 44 | totalNum: Double, 45 | dY: DenseMatrix[Double], 46 | exaggeration: Boolean): Double = { 47 | // q = (1 + ||Y_i - Y_j||^2)^-1 / sum(1 + ||Y_k - Y_l||^2)^-1 48 | val q: DenseMatrix[Double] = num / totalNum 49 | q.foreachPair{case ((i, j), v) => q.update(i, j, math.max(v, 1e-12))} 50 | 51 | // q = q - p 52 | val loss = data.zipWithIndex.flatMap { 53 | case ((_, itr), i) => 54 | itr.map{ 55 | case (j, p) => 56 | val exaggeratedP = if(exaggeration) p * 4 else p 57 | val qij = q(i, j) 58 | val l = exaggeratedP * math.log(exaggeratedP / qij) 59 | q.update(i, j, qij - exaggeratedP) 60 | if(l.isNaN) 0.0 else l 61 | } 62 | }.sum 63 | 64 | // l = [ (p_ij - q_ij) * (1 + ||Y_i - Y_j||^2)^-1 ] 65 | q :*= -num 66 | // l_sum = [0 0 ... sum(l) ... 
0] 67 | sum(q(*, ::)).foreachPair{ case (i, v) => q.update(i, data(i)._1, q(i, data(i)._1) - v) } 68 | 69 | // dY_i = -4 * (l - l_sum) * Y 70 | val dYi: DenseMatrix[Double] = -4.0 :* (q * Y) 71 | data.map(_._1).zipWithIndex.foreach{ 72 | case (i, idx) => dY(i, ::) := dYi(idx, ::) 73 | } 74 | 75 | loss 76 | } 77 | 78 | /** BH Tree related functions **/ 79 | 80 | /** 81 | * 82 | * @param data array of (row_id, Seq(col_id), Vector(P_ij)) 83 | * @param Y matrix 84 | * @param posF positive forces 85 | */ 86 | def computeEdgeForces(data: Array[(Int, Seq[Int], DenseVector[Double])], 87 | Y: DenseMatrix[Double], 88 | posF: DenseMatrix[Double]): Unit = { 89 | data.foreach { 90 | case (i, cols, vec) => 91 | // k x D - 1 x D => k x D 92 | val diff = Y(cols, ::).toDenseMatrix.apply(*, ::) - Y(i, ::).t 93 | // k x D => k x 1 94 | val qZ = 1.0 :+ sum(pow(diff, 2).apply(*, ::)) 95 | posF(i, ::) := (vec :/ qZ).t * (-diff) 96 | } 97 | } 98 | 99 | def computeNonEdgeForces(tree: SPTree, 100 | Y: DenseMatrix[Double], 101 | theta: Double, 102 | negF: DenseMatrix[Double], 103 | idx: Int *): Double = { 104 | idx.foldLeft(0.0)((acc, i) => acc + computeNonEdgeForce(tree, Y(i, ::).t, theta, negF, i)) 105 | } 106 | 107 | /** 108 | * Calculate negative forces using BH approximation 109 | * 110 | * @param tree SPTree used for approximation 111 | * @param y y_i 112 | * @param theta threshold for correctness / speed 113 | * @param negF negative forces 114 | * @param i row 115 | * @return sum of Q 116 | */ 117 | private def computeNonEdgeForce(tree: SPTree, 118 | y: DenseVector[Double], 119 | theta: Double, 120 | negF: DenseMatrix[Double], 121 | i: Int): Double = { 122 | import tree._ 123 | if(getCount == 0 || (isLeaf && center.equals(y))) { 124 | 0.0 125 | } else { 126 | val diff = y - center 127 | val diffSq = sum(pow(diff, 2)) 128 | if(isLeaf || radiusSq / diffSq < theta) { 129 | val qZ = 1 / (1 + diffSq) 130 | val nqZ = getCount * qZ 131 | negF(i, ::) :+= (nqZ * qZ * diff).t 132 | nqZ 133 | } else { 134 | children.foldLeft(0.0)((acc, child) => acc + computeNonEdgeForce(child, y, theta, negF, i)) 135 | } 136 | } 137 | } 138 | 139 | def computeLoss(data: Array[(Int, Seq[Int], DenseVector[Double])], 140 | Y: DenseMatrix[Double], 141 | sumQ: Double): Double = { 142 | data.foldLeft(0.0){ 143 | case (acc, (i, cols, vec)) => 144 | val diff = Y(cols, ::).toDenseMatrix.apply(*, ::) - Y(i, ::).t 145 | val diffSq = sum(pow(diff, 2).apply(*, ::)) 146 | val Q = (1.0 :/ (1.0 :+ diffSq)) :/ sumQ 147 | sum(vec :* breeze.numerics.log(max(vec, 1e-12) :/ max(Q, 1e-12))) 148 | } 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /spark-tsne-player/src/main/html/tsne.html: --------------------------------------------------------------------------------

[The HTML/JS markup of tsne.html was stripped during extraction and only scattered fragments survived. It is the d3.js-based t-SNE viewer page referenced in the README; the recoverable text is the page title "t-SNE Viewer", the heading "T-SNE Viewer", and the attribution "Source: The Wealth & Health of Nations, Mike Bostock."]
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution."
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2015 Forest Fang 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | =================================== 204 | 205 | This project also contains code from TSne.jl and d3.js. 206 | License can be found at: 207 | https://github.com/lejon/TSne.jl/blob/master/LICENSE.md 208 | and 209 | https://github.com/mbostock/d3/blob/master/LICENSE 210 | respectively. 211 | --------------------------------------------------------------------------------