├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── build.sbt ├── project ├── build.properties ├── build.sbt └── plugins.sbt └── src ├── main └── scala │ └── scalarank │ ├── datapoint │ ├── Datapoint.scala │ ├── Query.scala │ └── SVMRankDatapoint.scala │ ├── metrics │ └── package.scala │ └── ranker │ ├── LinearRegressionRanker.scala │ ├── OracleRanker.scala │ ├── RandomRanker.scala │ ├── RankNetRanker.scala │ └── Ranker.scala └── test ├── resources ├── test.txt └── train.txt └── scala └── scalarank ├── TestData.scala ├── metrics └── MetricsSpec.scala └── ranker ├── GradientCheck.scala ├── LinearRegressionRankerSpec.scala ├── OracleRankerSpec.scala ├── RandomRankerSpec.scala ├── RankNetRankerSpec.scala └── RankerSpec.scala /.gitignore: -------------------------------------------------------------------------------- 1 | # general 2 | *.class 3 | *.log 4 | 5 | # sbt specific 6 | .cache 7 | .history 8 | .lib/ 9 | dist/* 10 | target/ 11 | lib_managed/ 12 | src_managed/ 13 | project/boot/ 14 | project/plugins/project/ 15 | 16 | # Scala-IDE specific 17 | .scala_dependencies 18 | .worksheet 19 | .idea 20 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.11.8 4 | jdk: 5 | - oraclejdk7 6 | - oraclejdk8 7 | branches: 8 | only: 9 | - master 10 | 11 | script: "sbt clean coverage test" 12 | after_success: "sbt coverageReport coveralls" 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Rolf Jagerman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ScalaRank 2 | 3 | [![Build Status](https://travis-ci.org/rjagerman/scalarank.svg?branch=master)](https://travis-ci.org/rjagerman/scalarank) [![Coverage Status](https://coveralls.io/repos/github/rjagerman/scalarank/badge.svg?branch=master)](https://coveralls.io/github/rjagerman/scalarank?branch=master) 4 | 5 | :warning: This library is no longer maintained and the repository has been archived 6 | 7 | A modern scala library providing efficient implementations of offline learning to rank algorithms. Under the hood we use 8 | [nd4j](http://nd4j.org/) and [deeplearning4j](https://deeplearning4j.org/) for our scientific computing and neural 9 | network needs. 10 | 11 | ## Algorithms 12 | 13 | Included algorithms are: 14 | * **Oracle**: An oracle ranker that predicts perfectly but requires relevance labels during ranking. 15 | * **Linear Regression**: A linear regression ranker that ranks by predicting scalar values. 
A modern Scala library providing efficient implementations of offline learning to rank algorithms.
/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn -------------------------------------------------------------------------------- /src/main/scala/scalarank/datapoint/Datapoint.scala: -------------------------------------------------------------------------------- 1 | package scalarank.datapoint 2 | 3 | import org.nd4j.linalg.api.ndarray.INDArray 4 | 5 | /** 6 | * A data point, this is typically a feature vector containing query-document features 7 | */ 8 | abstract class Datapoint { 9 | 10 | /** 11 | * The features as a dense vector 12 | */ 13 | val features: INDArray 14 | 15 | } 16 | 17 | /** 18 | * For labeling data points with relevance 19 | */ 20 | trait Relevance { 21 | 22 | /** 23 | * The relevance of the data point. Typically higher means more relevant. 24 | */ 25 | val relevance: Double 26 | 27 | } 28 | 29 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/datapoint/Query.scala: -------------------------------------------------------------------------------- 1 | package scalarank.datapoint 2 | 3 | /** 4 | * A query 5 | * 6 | * @param id The query identifier 7 | * @param datapoints An array of data points representing query-document pairs 8 | */ 9 | class Query[A <: Datapoint](val id: Int, val datapoints: IndexedSeq[A]) 10 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/datapoint/SVMRankDatapoint.scala: -------------------------------------------------------------------------------- 1 | package scalarank.datapoint 2 | 3 | import org.nd4j.linalg.api.ndarray.INDArray 4 | import org.nd4j.linalg.factory.Nd4j 5 | 6 | /** 7 | * A datapoint based on SVM rank syntax 8 | * 9 | * @param line The line containing this data point 10 | */ 11 | class SVMRankDatapoint(line: String) extends Datapoint with Relevance { 12 | 13 | override val features: INDArray = { 14 | val (_, 
values) = SVMRankDatapoint.FEATURE_REGEX.findAllIn(line). 15 | map(m => m.split(":")). 16 | map(m => (m(0).toInt, m(1).toDouble)). 17 | toArray.sorted.unzip 18 | Nd4j.create(values) 19 | } 20 | 21 | override val relevance: Double = SVMRankDatapoint.RELEVANCE_REGEX.findFirstIn(line).get.toDouble 22 | 23 | val qid: Int = line match { case SVMRankDatapoint.QID_REGEX(id) => id.toInt } 24 | 25 | } 26 | 27 | object SVMRankDatapoint { 28 | 29 | private val QID_REGEX = """.*qid:([0-9]+).*""".r 30 | private val RELEVANCE_REGEX= """^[0-9]+""".r 31 | private val FEATURE_REGEX = """[0-9]+:[^ ]+""".r 32 | 33 | def apply(line: String): SVMRankDatapoint = new SVMRankDatapoint(line) 34 | 35 | } 36 | 37 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/metrics/package.scala: -------------------------------------------------------------------------------- 1 | package scalarank 2 | 3 | import scala.reflect.ClassTag 4 | import scalarank.datapoint.{Datapoint, Relevance} 5 | import scalarank.ranker.OracleRanker 6 | 7 | /** 8 | * Provides methods for computing common IR metrics 9 | * 10 | * ==Overview== 11 | * Each method takes as input a sorted array of Datapoints with Relevance labels. The sorting of this array is what 12 | * will be evaluated. 
13 | * 14 | * {{{ 15 | * scala> metrics.precision(ranking) 16 | * res0: Double = 0.66666666 17 | * }}} 18 | * 19 | * In order to compute any metric at a cutoff point K, just use the take method on the input array: 20 | * 21 | * {{{ 22 | * scala> val K = 10 23 | * K: Int = 10 24 | * scala> metrics.precision(ranking.take(K)) 25 | * res1: Double = 0.7 26 | * }}} 27 | * 28 | */ 29 | package object metrics { 30 | 31 | /** 32 | * Computes precision 33 | * 34 | * @param ranking The ranked list 35 | * @return The precision 36 | */ 37 | def precision[D <: Datapoint with Relevance](ranking: Seq[D]): Double = { 38 | ranking.count(d => d.relevance > 0.0).toDouble / ranking.length.toDouble 39 | } 40 | 41 | /** 42 | * Computes average precision 43 | * 44 | * @param ranking The ranked list 45 | * @return The average precision 46 | */ 47 | def averagePrecision[D <: Datapoint with Relevance](ranking: Seq[D]): Double = { 48 | val relevantDocuments = ranking.zipWithIndex.filter { case (d, i) => d.relevance != 0.0 } 49 | average(relevantDocuments.zipWithIndex.map { case ((d, i), c) => 50 | (c + 1.0) / (i + 1.0) 51 | }) 52 | } 53 | 54 | /** 55 | * Computes the reciprocal rank 56 | * 57 | * @param ranking The ranked list 58 | * @return The reciprocal rank 59 | */ 60 | def reciprocalRank[D <: Datapoint with Relevance](ranking: Seq[D]): Double = { 61 | ranking.indexWhere(d => d.relevance > 0.0) match { 62 | case -1 => 0.0 63 | case x => 1.0 / (1 + x).toDouble 64 | } 65 | } 66 | 67 | /** 68 | * Computes the discounted cumulative gain 69 | * 70 | * @param ranking The ranked list 71 | * @return The discounted cumulative gain 72 | */ 73 | def dcg[D <: Datapoint with Relevance](ranking: Seq[D]): Double = { 74 | ranking.zipWithIndex.map { 75 | case (d, 0) => d.relevance 76 | case (d, i) => d.relevance * (1.0 / (Math.log(2 + i) / Math.log(2.0))) 77 | }.sum 78 | } 79 | 80 | /** 81 | * Computes the normalized discounted cumulative gain 82 | * 83 | * @param ranking The ranked list 84 | * @return 
The normalized discounted cumulative gain 85 | */ 86 | def ndcg[D <: Datapoint with Relevance : ClassTag](ranking: Seq[D]): Double = { 87 | val oracle = new OracleRanker[D] 88 | dcg(oracle.rank(ranking.toIndexedSeq)) match { 89 | case 0 => 0.0 90 | case perfectDcg => dcg(ranking) / perfectDcg 91 | } 92 | } 93 | 94 | /** 95 | * Computes the mean of a metric over a series of rankings 96 | * 97 | * @param rankings The rankings (e.g. per query) 98 | * @param metric The metric to use 99 | * @return The mean 100 | */ 101 | def mean[D <: Datapoint with Relevance](rankings: Iterable[Seq[D]], metric: Seq[D] => Double): Double = { 102 | meanAtK[D](rankings, metric, rankings.map(r => r.size).max) 103 | } 104 | 105 | /** 106 | * Computes the mean of a metric over a series of rankings and cutting them off at K 107 | * 108 | * @param rankings The rankings (e.g. per query) 109 | * @param metric The metric to use 110 | * @param K The cutoff point 111 | * @return The mean 112 | */ 113 | def meanAtK[D <: Datapoint with Relevance](rankings: Iterable[Seq[D]], metric: Seq[D] => Double, K: Int): Double = { 114 | average(rankings.map(ranking => metric(ranking.take(K)))) 115 | } 116 | 117 | /** 118 | * Computes the average of an iterable of numerics 119 | * 120 | * @param ts The iterable 121 | * @param num The numerical type 122 | * @tparam T The type 123 | * @return The average 124 | */ 125 | private def average[T](ts: Iterable[T])(implicit num: Numeric[T]): Double = ts.size match { 126 | case 0 => 0.0 127 | case size => num.toDouble(ts.sum) / size.toDouble 128 | } 129 | 130 | } 131 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/ranker/LinearRegressionRanker.scala: -------------------------------------------------------------------------------- 1 | package scalarank.ranker 2 | 3 | import org.deeplearning4j.nn.api.OptimizationAlgorithm 4 | import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer} 5 | import 
org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater} 6 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork 7 | import org.deeplearning4j.nn.weights.WeightInit 8 | import org.nd4j.linalg.api.ndarray.INDArray 9 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator 10 | import org.nd4j.linalg.factory.Nd4j 11 | import org.nd4s.Implicits._ 12 | import org.nd4j.linalg.lossfunctions.LossFunctions 13 | 14 | import scala.collection.JavaConverters._ 15 | import scala.reflect.ClassTag 16 | import scalarank.datapoint.{Datapoint, Query, Relevance} 17 | 18 | /** 19 | * A linear regression ranker that ranks by scoring data points as scalar values. 20 | * 21 | * Linear regression is implemented as a single-layer neural network with an MSE loss function and an identity 22 | * activation function. 23 | * 24 | * @param features The dimensionality of the input features 25 | * @param seed The random seed 26 | * @param iterations The number of iterations 27 | * @param learningRate The learning rate 28 | * @tparam TrainType Type to train on which needs to be at least Datapoint with Relevance 29 | * @tparam RankType Type to rank with which needs to be at least Datapoint 30 | */ 31 | class LinearRegressionRanker[TrainType <: Datapoint with Relevance,RankType <: Datapoint : ClassTag](val features: Int, 32 | val seed: Int = 42, 33 | val iterations: Int = 100, 34 | val learningRate: Double = 1e-3) 35 | extends Ranker[TrainType, RankType] { 36 | 37 | val network = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() 38 | .seed(seed) 39 | .iterations(iterations) 40 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 41 | .learningRate(learningRate) 42 | .updater(Updater.ADAM) 43 | .list() 44 | .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) 45 | .activation("identity") 46 | .nIn(features) 47 | .nOut(1) 48 | .build()) 49 | .pretrain(false).backprop(true).build() 50 | ) 51 | 52 | /** 53 | * Trains the ranker on a set of labeled 
data points 54 | * 55 | * @param data The set of labeled data points 56 | */ 57 | override def train(data: Iterable[Query[TrainType]]): Unit = { 58 | 59 | val datapoints = data.toArray.flatMap(x => x.datapoints) 60 | val labels = datapoints.map(x => x.relevance) 61 | 62 | val X = toMatrix[TrainType](datapoints) 63 | val y = labels.toNDArray 64 | 65 | network.fit(X, y) 66 | } 67 | 68 | /** 69 | * Ranks given set of data points 70 | * 71 | * @param data The set of data points 72 | * @return An ordered list of data points 73 | */ 74 | override def score(data: IndexedSeq[RankType]): IndexedSeq[Double] = { 75 | val X = toMatrix(data) 76 | val y = network.output(X) 77 | (0 until y.length()).map(i => y(i)) 78 | } 79 | 80 | /** 81 | * Converts given iterable of data points to an ND4J matrix 82 | * 83 | * @param data The data points 84 | * @tparam D The datapoint type 85 | * @return A matrix of the features 86 | */ 87 | private def toMatrix[D <: Datapoint](data: Iterable[D]): INDArray = { 88 | Nd4j.vstack(data.map(x => x.features).asJavaCollection) 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/ranker/OracleRanker.scala: -------------------------------------------------------------------------------- 1 | package scalarank.ranker 2 | 3 | import scala.reflect.ClassTag 4 | import scalarank.datapoint.{Datapoint, Query, Relevance} 5 | 6 | /** 7 | * Ranks documents perfectly but requires relevance labels to be known during ranking 8 | * 9 | * @tparam T Type to train on and rank with which needs to be at least Datapoint with Relevance 10 | */ 11 | class OracleRanker[T <: Datapoint with Relevance : ClassTag] extends Ranker[T, T] { 12 | 13 | /** 14 | * Trains the ranker on a set of labeled data points 15 | * 16 | * @param data The set of labeled data points 17 | */ 18 | override def train(data: Iterable[Query[T]]): Unit = { } 19 | 20 | /** 21 | * Ranks given set of data points 22 | * 23 | * @param 
data The set of data points 24 | * @return An ordered list of data points 25 | */ 26 | override def score(data: IndexedSeq[T]): IndexedSeq[Double] = { 27 | val maximum = data.map(d => d.relevance).max 28 | data.map(d => maximum - d.relevance) 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/ranker/RandomRanker.scala: -------------------------------------------------------------------------------- 1 | package scalarank.ranker 2 | 3 | import scala.reflect.ClassTag 4 | import scala.util.Random 5 | import scalarank.datapoint.{Datapoint, Query, Relevance} 6 | 7 | /** 8 | * Ranks documents randomly 9 | * 10 | * @tparam T Type to train on and rank with which needs to be at least Datapoint with Relevance 11 | */ 12 | class RandomRanker[T <: Datapoint with Relevance, R <: Datapoint : ClassTag](seed: Int) extends Ranker[T, R] { 13 | 14 | val rng = new Random(seed) 15 | 16 | /** 17 | * Trains the ranker on a set of labeled data points 18 | * 19 | * @param data The set of labeled data points 20 | */ 21 | override def train(data: Iterable[Query[T]]): Unit = { } 22 | 23 | /** 24 | * Ranks given set of data points 25 | * 26 | * @param data The set of data points 27 | * @return An ordered list of data points 28 | */ 29 | override def score(data: IndexedSeq[R]): IndexedSeq[Double] = { 30 | data.indices.map(_ => rng.nextDouble()).toArray 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/ranker/RankNetRanker.scala: -------------------------------------------------------------------------------- 1 | package scalarank.ranker 2 | 3 | import org.apache.commons.math3.util.Pair 4 | import org.deeplearning4j.nn.api.OptimizationAlgorithm 5 | import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer} 6 | import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater} 7 | import 
org.deeplearning4j.nn.multilayer.MultiLayerNetwork 8 | import org.deeplearning4j.nn.weights.WeightInit 9 | import org.nd4j.linalg.api.ndarray.INDArray 10 | import org.nd4j.linalg.ops.transforms.Transforms._ 11 | import org.nd4j.linalg.factory.Nd4j 12 | import org.nd4s.Implicits._ 13 | import org.nd4j.linalg.lossfunctions.ILossFunction 14 | 15 | import scala.collection.JavaConverters._ 16 | import scala.reflect.ClassTag 17 | import scalarank.datapoint.{Datapoint, Query, Relevance} 18 | 19 | /** 20 | * A RankNet ranker that minimizes number of pair-wise inversions 21 | * 22 | * Burges, Chris, et al. "Learning to rank using gradient descent." 23 | * Proceedings of the 22nd international conference on Machine learning. ACM, 2005. 24 | * 25 | * @param features The dimensionality of the input features 26 | * @param σ The shape of the sigmoid 27 | * @param hidden An array where each value n corresponds to a dense layer of size n in the network 28 | * @param seed The random seed 29 | * @param iterations The number of iterations 30 | * @param learningRate The learning rate 31 | * @tparam TrainType Type to train on which needs to be at least Datapoint with Relevance 32 | * @tparam RankType Type to rank with which needs to be at least Datapoint 33 | */ 34 | class RankNetRanker[TrainType <: Datapoint with Relevance,RankType <: Datapoint : ClassTag](val features: Int, 35 | val σ: Double = 1.0, 36 | val hidden: Array[Int] = Array(10), 37 | val seed: Int = 42, 38 | val iterations: Int = 20, 39 | val learningRate: Double = 5e-5) 40 | extends Ranker[TrainType, RankType] { 41 | 42 | /** 43 | * Custom RankNet loss function 44 | */ 45 | private val loss = new RankNetLoss(σ) 46 | 47 | /** 48 | * Neural network 49 | */ 50 | val network = new MultiLayerNetwork({ 51 | 52 | // Basic neural network settings 53 | var build = new NeuralNetConfiguration.Builder() 54 | .seed(seed) 55 | .iterations(1) 56 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 57 | 
Scores the given set of data points 103 | * 104 | * @param data The set of data points 105 | * @return The scores, one per data point
def computeGradientAndScore(labels: INDArray, 134 | preOutput: INDArray, 135 | activationFn: String, 136 | mask: INDArray, 137 | average: Boolean): Pair[java.lang.Double, INDArray] = { 138 | val S_var = S(labels) 139 | val sigma_var = sigma(output(preOutput, activationFn)) 140 | Pair.create(score(S_var, sigma_var, average), gradient(S_var, sigma_var)) 141 | } 142 | 143 | override def computeGradient(labels: INDArray, 144 | preOutput: INDArray, 145 | activationFn: String, 146 | mask: INDArray): INDArray = { 147 | gradient(S(labels), sigma(output(preOutput, activationFn))) 148 | } 149 | 150 | override def computeScoreArray(labels: INDArray, 151 | preOutput: INDArray, 152 | activationFn: String, 153 | mask: INDArray): INDArray = { 154 | scoreArray(S(labels), sigma(output(preOutput, activationFn))) 155 | } 156 | 157 | override def computeScore(labels: INDArray, 158 | preOutput: INDArray, 159 | activationFn: java.lang.String, 160 | mask: INDArray, 161 | average: Boolean): Double = { 162 | score(S(labels), sigma(output(preOutput, activationFn)), average) 163 | } 164 | 165 | /** 166 | * Computes the gradient for the full ranking 167 | * 168 | * @param S The S_ij matrix, indicating whether certain elements should be ranked higher or lower 169 | * @param sigma The sigma matrix, indicating how scores relate to each other 170 | * @return The gradient 171 | */ 172 | private def gradient(S: INDArray, sigma: INDArray): INDArray = { 173 | Nd4j.mean(((-S + 1)*0.5 - sigmoid(-sigma)) * σ, 0).transpose 174 | } 175 | 176 | /** 177 | * Computes the score for the full ranking 178 | * 179 | * @param S The S_ij matrix, indicating whether certain elements should be ranked higher or lower 180 | * @param sigma The sigma matrix, indicating how scores relate to each other 181 | * @return The score array 182 | */ 183 | private def scoreArray(S: INDArray, sigma: INDArray): INDArray = { 184 | Nd4j.mean((-S + 1) * 0.5 * sigma + log(exp(-sigma) + 1), 0) 185 | } 186 | 187 | /** 188 | * Computes an 
Computes the matrix S_ij, which indicates whether certain elements should be ranked higher or lower
Nd4j.getExecutioner.execAndReturn(Nd4j.getOpFactory.createTransform(activationFn, preOutput.dup)) 238 | } 239 | 240 | } 241 | 242 | -------------------------------------------------------------------------------- /src/main/scala/scalarank/ranker/Ranker.scala: -------------------------------------------------------------------------------- 1 | package scalarank.ranker 2 | 3 | import scala.reflect.ClassTag 4 | import scalarank.datapoint.{Datapoint, Query, Relevance} 5 | 6 | /** 7 | * An abstract ranker interface 8 | * 9 | * @tparam TrainType Type to train on which needs to be at least Datapoint with Relevance 10 | * @tparam RankType Type to rank with which needs to be at least Datapoint 11 | */ 12 | trait Ranker[TrainType <: Datapoint with Relevance, RankType <: Datapoint] { 13 | 14 | /** 15 | * Trains the ranker on a set of labeled data points 16 | * 17 | * @param data The set of labeled data points 18 | */ 19 | def train(data: Iterable[Query[TrainType]]): Unit 20 | 21 | /** 22 | * Scores the given set of query-document pairs 23 | * 24 | * @param data The data set 25 | * @return The scores 26 | */ 27 | def score(data: IndexedSeq[RankType]): IndexedSeq[Double] 28 | 29 | /** 30 | * Ranks given set of data points 31 | * 32 | * @param data The set of data points 33 | * @return An ordered list of data points 34 | */ 35 | def rank[R <: RankType : ClassTag](data: IndexedSeq[R]): IndexedSeq[R] = { 36 | sort(data, score(data)) 37 | } 38 | 39 | /** 40 | * Sorts given data using given set of scores 41 | * 42 | * @param data The data 43 | * @param scores The computed scores 44 | * @return A sorted array of ranks 45 | */ 46 | protected def sort[R <: RankType : ClassTag](data: IndexedSeq[R], scores: IndexedSeq[Double]): IndexedSeq[R] = { 47 | data.zip(scores).sortBy(_._2).map(_._1) 48 | } 49 | 50 | } 51 | 52 | -------------------------------------------------------------------------------- /src/test/scala/scalarank/TestData.scala: 
-------------------------------------------------------------------------------- 1 | package scalarank 2 | 3 | import org.nd4j.linalg.api.ndarray.INDArray 4 | import org.nd4s.Implicits._ 5 | 6 | import scalarank.datapoint.{Datapoint, Query, Relevance, SVMRankDatapoint} 7 | 8 | /** 9 | * An object that contains test data 10 | */ 11 | object TestData { 12 | 13 | val featureless: Array[Datapoint with Relevance] = Array( 14 | 4.0, 3.0, 4.0, 3.0, 1.0, 2.0, 1.0, 4.0, 0.0, 4.0, 0.0, 2.0, 2.0, 2.0, 1.0, 3.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0 15 | ).map(r => new FeaturelessDatapointRelevance(r)) 16 | 17 | val featurelessPrecision = 0.7272727272727273 18 | val featurelessAveragePrecision = 0.9343461583351291 19 | val featurelessReciprocalRank = 1.000000 20 | val featurelessDCG = 16.31221516353917 21 | val featurelessnDCG = 0.937572811083981 22 | 23 | val sampleTrainData = readSVMRank("/train.txt") 24 | val sampleTestData = readSVMRank("/test.txt") 25 | 26 | def readSVMRank(file: String): IndexedSeq[Query[SVMRankDatapoint]] = { 27 | val samples = scala.io.Source.fromInputStream(getClass.getResourceAsStream(file)). 
28 | getLines.map(l => SVMRankDatapoint(l)) 29 | samples.toArray.groupBy(d => d.qid).map { case (qid, ds) => 30 | new Query[SVMRankDatapoint](qid, ds) 31 | }.toIndexedSeq 32 | } 33 | 34 | /** 35 | * A test data point with a dense feature vector 36 | * 37 | * @param f The features as an array of doubles 38 | * @param r The relevance 39 | */ 40 | class TestDatapoint(f: Array[Double], r: Double) extends Datapoint with Relevance { 41 | override val features: INDArray = f.toNDArray 42 | override val relevance: Double = r 43 | } 44 | 45 | /** 46 | * A datapoint with relevance that does not contain features 47 | * 48 | * @param r The relevance label 49 | */ 50 | class FeaturelessDatapointRelevance(r: Double) extends Datapoint with Relevance { 51 | override val features: INDArray = null 52 | override val relevance: Double = r 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/test/scala/scalarank/metrics/MetricsSpec.scala: -------------------------------------------------------------------------------- 1 | package scalarank.metrics 2 | 3 | import org.scalatest.FlatSpec 4 | import scalarank.TestData 5 | 6 | /** 7 | * Test specification for metrics 8 | */ 9 | class MetricsSpec extends FlatSpec { 10 | 11 | "Precision" should "be 1.0 for only relevant documents" in { 12 | val data = Array(1.0, 1.0, 1.0, 1.0, 1.0).map(x => new TestData.FeaturelessDatapointRelevance(x)) 13 | assert(precision(data) == 1.0) 14 | } 15 | 16 | it should "be 0.0 for only non-relevant documents" in { 17 | val data = Array(0.0, 0.0, 0.0, 0.0, 0.0).map(x => new TestData.FeaturelessDatapointRelevance(x)) 18 | assert(precision(data) == 0.0) 19 | } 20 | 21 | it should "be 0.5 when half the documents are relevant" in { 22 | val data = Array(0.0, 1.0, 0.0, 1.0).map(x => new TestData.FeaturelessDatapointRelevance(x)) 23 | assert(precision(data) == 0.5) 24 | } 25 | 26 | it should "be invariant to changes in ordering" in { 27 | val data = Array( 28 | 
package scalarank.metrics

import org.scalatest.FlatSpec
import scalarank.TestData

/**
 * Test specification for the ranking metrics (precision, average precision,
 * reciprocal rank, DCG and nDCG).
 */
class MetricsSpec extends FlatSpec {

  /** Builds featureless datapoints carrying the given relevance labels, in order. */
  private def rel(labels: Double*): Array[TestData.FeaturelessDatapointRelevance] =
    labels.toArray.map(l => new TestData.FeaturelessDatapointRelevance(l))

  "Precision" should "be 1.0 for only relevant documents" in {
    assert(precision(rel(1.0, 1.0, 1.0, 1.0, 1.0)) == 1.0)
  }

  it should "be 0.0 for only non-relevant documents" in {
    assert(precision(rel(0.0, 0.0, 0.0, 0.0, 0.0)) == 0.0)
  }

  it should "be 0.5 when half the documents are relevant" in {
    assert(precision(rel(0.0, 1.0, 0.0, 1.0)) == 0.5)
  }

  it should "be invariant to changes in ordering" in {
    // Every permutation has two relevant documents out of four
    val orderings = Array(
      rel(0.0, 1.0, 0.0, 1.0),
      rel(1.0, 1.0, 0.0, 0.0),
      rel(1.0, 0.0, 0.0, 1.0),
      rel(0.0, 0.0, 1.0, 1.0)
    )
    orderings.foreach(ordering => assert(precision(ordering) == 0.5))
  }

  it should "be %.4f for our test data set".format(TestData.featurelessPrecision) in {
    assert(precision(TestData.featureless) == TestData.featurelessPrecision)
  }

  "AveragePrecision" should "be 1.0 for only relevant documents" in {
    assert(averagePrecision(rel(1.0, 1.0, 1.0, 1.0, 1.0)) == 1.0)
  }

  it should "be 0.0 for only non-relevant documents" in {
    assert(averagePrecision(rel(0.0, 0.0, 0.0, 0.0, 0.0)) == 0.0)
  }

  it should "be 1.0 when exactly the first half of the documents are relevant" in {
    assert(averagePrecision(rel(1.0, 1.0, 1.0, 0.0, 0.0, 0.0)) == 1.0)
  }

  it should "be %.4f for our test data set".format(TestData.featurelessAveragePrecision) in {
    assert(averagePrecision(TestData.featureless) == TestData.featurelessAveragePrecision)
  }

  "ReciprocalRank" should "be 1.0 for only relevant documents" in {
    assert(reciprocalRank(rel(1.0, 1.0, 1.0, 1.0, 1.0)) == 1.0)
  }

  it should "be 1.0 when only the first document is relevant" in {
    assert(reciprocalRank(rel(1.0, 0.0, 0.0, 0.0, 0.0)) == 1.0)
  }

  it should "be 0.0 for only non-relevant documents" in {
    assert(reciprocalRank(rel(0.0, 0.0, 0.0, 0.0, 0.0)) == 0.0)
  }

  it should "be 0.5 when the second document is the first relevant one" in {
    assert(reciprocalRank(rel(0.0, 1.0, 0.0, 0.0, 0.0)) == 0.5)
  }

  it should "be 0.3333 when the third document is the first relevant one" in {
    assert(reciprocalRank(rel(0.0, 0.0, 1.0, 0.0, 0.0)) == 0.3333333333333333)
  }

  it should "be %.4f for our test data set".format(TestData.featurelessReciprocalRank) in {
    assert(reciprocalRank(TestData.featureless) == TestData.featurelessReciprocalRank)
  }

  "DCG" should "be 2.1309 for three relevant documents" in {
    assert(dcg(rel(1.0, 1.0, 1.0)) == 2.1309297535714573)
  }

  it should "be 0.0 for only non-relevant documents" in {
    assert(dcg(rel(0.0, 0.0, 0.0, 0.0, 0.0)) == 0.0)
  }

  it should "be %.4f for our test data set".format(TestData.featurelessDCG) in {
    assert(dcg(TestData.featureless) == TestData.featurelessDCG)
  }

  "nDCG" should "be 1.0 for only relevant documents" in {
    assert(ndcg(rel(1.0, 1.0, 1.0, 1.0, 1.0)) == 1.0)
  }

  it should "be 0.0 for only non-relevant documents" in {
    assert(ndcg(rel(0.0, 0.0, 0.0, 0.0, 0.0)) == 0.0)
  }

  it should "be 1.0 for a perfectly ranked list" in {
    // Labels already in non-increasing order
    assert(ndcg(rel(5.0, 5.0, 4.0, 2.0, 0.0)) == 1.0)
  }

  it should "be less than 1.0 for a non-perfectly ranked list" in {
    assert(ndcg(rel(5.0, 4.0, 5.0, 2.0, 0.0)) < 1.0)
  }

  it should "be %.4f for our test data set".format(TestData.featurelessnDCG) in {
    assert(ndcg(TestData.featureless) == TestData.featurelessnDCG)
  }

}
package scalarank.ranker

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._

/**
 * Test trait for checking analytic gradient functions against numerical
 * (central-difference) approximations.
 */
trait GradientCheck {

  /**
   * Compares an analytic gradient with central-difference approximations of
   * decreasing step size ε, mirroring lim h→0 (‖f(x+h) - f(x) - ∇f(x) · h‖ / ‖h‖):
   * for each ε it sums, over all components, |((f(x+εe_i) - f(x-εe_i)) / 2ε)_i - ∇f(x)_i|.
   * If the gradient is correct these sums should shrink as ε shrinks.
   *
   * @param gradient The analytic gradient (as a vector)
   * @param x The input at which said gradient was computed (as a vector)
   * @param function The function over which the gradient is computed
   * @return One error sum per step size, for ε ∈ {1e1, 1, 1e-1, 1e-2}
   */
  def gradientLimits(gradient: INDArray, x: INDArray, function: INDArray => INDArray): Array[Double] = {
    // NOTE: the original computed an unused Nd4j.randn(x.rows, x.columns) here; removed as dead work
    Array(1e1, 1, 1e-1, 1e-2).map { ε =>
      (0 until x.columns).map { i =>
        // Unit vector along dimension i, used to perturb x in one coordinate
        val e = Nd4j.zeros(x.columns)
        e(i) = 1.0
        val approximateGradient = (function(x + e * ε) - function(x - e * ε)) / (2*ε)
        Math.abs(approximateGradient(i) - gradient(i))
      }.sum
    }
  }

}
package scalarank.ranker

import org.scalatest.FlatSpec

import scalarank.{TestData, metrics}
import scalarank.datapoint.{Datapoint, Relevance, SVMRankDatapoint}
import scalarank.metrics._

/**
 * Test specification for the Linear Regression ranker
 */
class LinearRegressionRankerSpec extends RankerSpec {

  "A LinearRegression Ranker" should "report appropriate nDCG results on MQ2008 Fold 1" in {
    val regressionRanker = new LinearRegressionRanker(featureSize, seed=42)
    testRanker(regressionRanker, ndcg, "nDCG")
  }

}

/**
 * Test specification for the Oracle ranker
 */
class OracleRankerSpec extends RankerSpec {

  /** Trains a fresh oracle on no data and ranks the featureless test set. */
  private def oracleRanking = {
    val oracle = new OracleRanker[Datapoint with Relevance]
    oracle.train(Iterable.empty)
    oracle.rank(TestData.featureless)
  }

  "An Oracle ranker" should "rank perfectly on our test data" in {
    val ranking = oracleRanking
    // Adjacent pairs must be in non-increasing relevance order
    assert((ranking, ranking.drop(1)).zipped.forall { case (higher, lower) => higher.relevance >= lower.relevance })
  }

  it should "have perfect nDCG on our test data" in {
    assert(metrics.ndcg(oracleRanking) == 1.0)
  }

  it should "report appropriate nDCG results on MQ2008 Fold 1" in {
    testRanker(new OracleRanker(), ndcg, "nDCG")
  }

}
package scalarank.ranker

import scalarank.metrics._

/**
 * Test specification for the Random ranker
 */
class RandomRankerSpec extends RankerSpec {

  "A random ranker" should "report appropriate nDCG results on MQ2008 Fold 1" in {
    // Fixed seed keeps the random baseline reproducible across runs
    val seededRanker = new RandomRanker(42)
    testRanker(seededRanker, ndcg, "nDCG")
  }

}
package scalarank.ranker

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._
import org.scalatest.{FlatSpec, Matchers}

import scalarank.datapoint.SVMRankDatapoint
import scalarank.metrics._
import scalarank.{TestData, metrics}

/**
 * Test specification for the RankNet ranker and its pairwise loss function.
 */
class RankNetRankerSpec extends RankerSpec with GradientCheck with Matchers {

  "A RankNet ranker" should "report appropriate nDCG results on MQ2008 Fold 1" in {
    testRanker(new RankNetRanker(featureSize, seed=42), ndcg, "nDCG")
  }

  "A RankNet loss function" should "be approximately log(2) when correctly predicted" in {
    val rankNetLoss = new RankNetLoss()

    // Identical labels and identical outputs: a correctly predicted tie
    val relevanceLabels = Nd4j.create(Array(0.0, 0.0, 0.0))
    val predictedScores = Nd4j.create(Array(0.0, 0.0, 0.0))

    val cost = rankNetLoss.computeScore(relevanceLabels, predictedScores, "identity", null, true)
    assert(Math.abs(cost - Math.log(2.0)) < 0.0000001)
  }

  it should "succesfully perform the gradient limit check" in {
    val rankNetLoss = new RankNetLoss()

    // Sample labels and scores at which the gradient is checked
    val relevanceLabels = Nd4j.create(Array(0.0, 1.0, 0.0, 4.0))
    val predictedScores = Nd4j.create(Array(0.1, -2.0, 7.0, 3.4))

    // Compare the analytic gradient against central-difference approximations;
    // the error should strictly decrease as the step size shrinks
    val analyticGradient = -rankNetLoss.computeGradient(relevanceLabels, predictedScores, "identity", null)
    def scoreAt(scores: INDArray): INDArray = rankNetLoss.computeScoreArray(relevanceLabels, scores, "identity", null)
    val limits = gradientLimits(analyticGradient, predictedScores, scoreAt)
    info(limits.mkString(" > "))
    limits.sliding(2).foreach { case Array(coarse, fine) => assert(coarse > fine) }
  }

  it should "succesfully compute both the gradient and cost" in {
    val rankNetLoss = new RankNetLoss()

    val relevanceLabels = Nd4j.create(Array(0.0, 1.0, 0.0, 4.0))
    val predictedScores = Nd4j.create(Array(0.1, -2.0, 7.0, 3.4))

    // The fused computation must agree with the two separate calls
    val gradient = rankNetLoss.computeGradient(relevanceLabels, predictedScores, "identity", null)
    val score = rankNetLoss.computeScore(relevanceLabels, predictedScores, "identity", null, average=true)
    val gradientAndScore = rankNetLoss.computeGradientAndScore(relevanceLabels, predictedScores, "identity", null, average=true)

    gradientAndScore.getFirst shouldBe score
    gradientAndScore.getSecond shouldBe gradient
  }

}
package scalarank.ranker

import org.scalatest.FlatSpec

import scala.collection.mutable
import scala.reflect.ClassTag
import scalarank.TestData
import scalarank.metrics.{averagePrecision, mean, meanAtK, ndcg}
import scalarank.datapoint.{Datapoint, Query, Relevance, SVMRankDatapoint}

/**
 * Base spec for testing ranker performance; concrete ranker specs extend this
 * and call [[testRanker]].
 */
class RankerSpec extends FlatSpec {

  // Shared MQ2008 sample data loaded once for all ranker specs
  val trainData = TestData.sampleTrainData
  val testData = TestData.sampleTestData
  val featureSize = trainData(0).datapoints(0).features.length()

  /**
   * Tests a ranker by training it on our training set and testing it on our test set,
   * reporting the metric at cutoffs 1..10 and its mean over all test queries.
   *
   * @param ranker The ranker to train and evaluate
   * @param metric The metric to score by
   * @param metricName Human-readable metric name used in the report output
   */
  protected def testRanker(ranker: Ranker[SVMRankDatapoint, SVMRankDatapoint],
                           metric: Seq[SVMRankDatapoint] => Double,
                           metricName: String = ""): Unit = {
    ranker.train(trainData)
    val rankings = testData.map(d => ranker.rank(d.datapoints))
    // Idiomatic range instead of a hand-written Array(1, 2, ..., 10)
    (1 to 10).foreach { k =>
      val result = meanAtK(rankings, metric, k)
      info(s"$metricName@${k.toString.padTo(4, ' ')} = $result")
    }
    info(s"$metricName mean = ${mean(rankings, metric)}")
  }

}