├── NOTICE ├── project ├── build.properties └── plugins.sbt ├── catalog-info.yaml ├── .github ├── dependabot.yml └── workflows │ ├── release.yml │ └── ci.yml ├── .gitignore ├── .scalafix.conf ├── .scalafmt.conf ├── core └── src │ ├── main │ └── scala │ │ └── com │ │ └── spotify │ │ └── noether │ │ ├── package.scala │ │ ├── Prediction.scala │ │ ├── LogLoss.scala │ │ ├── ErrorRateSummary.scala │ │ ├── BinaryConfusionMatrix.scala │ │ ├── ConfusionMatrix.scala │ │ ├── MeanAveragePrecision.scala │ │ ├── MultiAggregatorMap.scala │ │ ├── PrecisionAtK.scala │ │ ├── NdcgAtK.scala │ │ ├── CalibrationHistogram.scala │ │ ├── AUC.scala │ │ └── ClassificationReport.scala │ └── test │ └── scala │ └── com │ └── spotify │ └── noether │ ├── MeanAveragePrecisionTest.scala │ ├── RankingData.scala │ ├── BinaryConfusionMatrixTest.scala │ ├── LogLossTest.scala │ ├── ErrorRateSummaryTest.scala │ ├── NdcgAtKTest.scala │ ├── PrecisionAtKTest.scala │ ├── AggregatorTest.scala │ ├── ConfusionMatrixTest.scala │ ├── AUCTest.scala │ ├── ClassificationReportTest.scala │ └── CalibrationHistogramTest.scala ├── tfx └── src │ ├── main │ ├── scala │ │ └── com │ │ │ └── spotify │ │ │ └── noether │ │ │ └── tfx │ │ │ ├── package.scala │ │ │ ├── Tfma.scala │ │ │ ├── TfmaConverter.scala │ │ │ └── TfmaImplicits.scala │ └── protobuf │ │ └── metrics_for_slice.proto │ └── test │ └── scala │ └── com │ └── spotify │ └── noether │ └── tfx │ └── TfmaConverterTest.scala ├── examples └── src │ ├── test │ └── scala │ │ └── com │ │ └── spotify │ │ └── noether │ │ ├── AggregatorExampleTest.scala │ │ └── MultiAggregatorMapTests.scala │ └── main │ └── scala │ └── com │ └── spotify │ └── noether │ └── AggregatorExample.scala ├── benchmark ├── src │ └── main │ │ └── scala │ │ └── com │ │ └── spotify │ │ └── noether │ │ └── benchmark │ │ └── CalibrationHistogramCreateBenchmark.scala └── README.md ├── README.md └── LICENSE /NOTICE: -------------------------------------------------------------------------------- 1 | Noether 2 
| Copyright 2018 Spotify AB -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.8.2 2 | -------------------------------------------------------------------------------- /catalog-info.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: backstage.io/v1alpha1 2 | kind: Resource 3 | metadata: 4 | name: noether 5 | spec: 6 | type: resource 7 | owner: flatmap 8 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "04:00" 8 | open-pull-requests-limit: 10 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .bigquery 3 | .repl 4 | target 5 | editorial 6 | top 7 | data 8 | normed 9 | tracks 10 | .DS_Store 11 | *.pyc 12 | *.spl 13 | *.spi 14 | local_data 15 | *.iml 16 | 17 | .metals 18 | .bloop 19 | .bsp 20 | -------------------------------------------------------------------------------- /.scalafix.conf: -------------------------------------------------------------------------------- 1 | rules = [ 2 | RemoveUnused, 3 | LeakingImplicitClassVal 4 | ProcedureSyntax 5 | ExplicitResultTypes 6 | ] 7 | 8 | ExplicitResultTypes.memberKind = [Def, Val, Var] 9 | ExplicitResultTypes.memberVisibility = [Public] 10 | ExplicitResultTypes.skipSimpleDefinition = false 11 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | version = "3.7.4" 2 | maxColumn = 100 
3 | runner.dialect = scala213 4 | 5 | binPack.literalArgumentLists = true 6 | 7 | continuationIndent { 8 | callSite = 2 9 | defnSite = 2 10 | } 11 | 12 | newlines { 13 | alwaysBeforeMultilineDef = false 14 | sometimesBeforeColonInMethodReturnType = true 15 | } 16 | 17 | docstrings.style = Asterisk 18 | 19 | project.git = false 20 | 21 | rewrite { 22 | rules = [PreferCurlyFors, RedundantBraces, RedundantParens, SortImports, SortModifiers] 23 | redundantBraces.generalExpressions = false 24 | redundantBraces.maxLines = 1 25 | } 26 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | branches: [main] 5 | tags: ["*"] 6 | jobs: 7 | publish: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - uses: actions/setup-java@v3 12 | with: 13 | distribution: temurin 14 | java-version: 17 15 | cache: sbt 16 | - uses: olafurpg/setup-gpg@v3 17 | - name: Publish ${{ github.ref }} 18 | run: sbt ci-release 19 | env: 20 | PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} 21 | PGP_SECRET: ${{ secrets.PGP_SECRET }} 22 | SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} 23 | SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} 24 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify 19 | 20 | package object noether { 21 | type RankingPrediction[T] = Prediction[Array[T], Array[T]] 22 | } 23 | -------------------------------------------------------------------------------- /tfx/src/main/scala/com/spotify/noether/tfx/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | package object tfx extends TfmaImplicits { 21 | type BinaryPred = Prediction[Boolean, Double] 22 | } 23 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0") 2 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.8") 3 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.0") 4 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.5") 5 | addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "1.1.2") 6 | addSbtPlugin("com.github.sbt" % "sbt-ghpages" % "0.7.0") 7 | addSbtPlugin("com.typesafe.sbt" % "sbt-site" % "1.4.1") 8 | addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.6") 9 | addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.0") 10 | addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12") 11 | addSbtPlugin("ch.epfl.lamp" % "sbt-dotty" % "0.5.5") 12 | 13 | libraryDependencies ++= Seq( 14 | "com.thesamet.scalapb" %% "compilerplugin" % "0.11.13" 15 | ) 16 | 17 | dependencyOverrides += "org.scala-lang.modules" %% "scala-xml" % "2.1.0" 18 | -------------------------------------------------------------------------------- /examples/src/test/scala/com/spotify/noether/AggregatorExampleTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. 
See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import org.scalatest.flatspec.AnyFlatSpec 21 | import org.scalatest.matchers.should.Matchers 22 | 23 | class AggregatorExampleTest extends AnyFlatSpec with Matchers { 24 | it should "not fail when executing example" in { 25 | AggregatorExample.main(Array.empty) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /core/src/test/scala/com/spotify/noether/MeanAveragePrecisionTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import org.scalactic.{Equality, TolerantNumerics} 21 | 22 | class MeanAveragePrecisionTest extends AggregatorTest { 23 | import RankingData._ 24 | 25 | implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1) 26 | 27 | it should "compute map for rankings" in { 28 | assert(run(MeanAveragePrecision[Int]())(rankingData) === 0.355026) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/Prediction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | /** 21 | * Generic Prediction Object used by most aggregators 22 | * 23 | * @param actual 24 | * Real value for this entry. Also normally seen as label. 25 | * @param predicted 26 | * Predicted value. Can be a class or a score depending on the aggregator. 
27 | * @tparam L 28 | * Type of the Real Value 29 | * @tparam S 30 | * Type of the Predicted Value 31 | */ 32 | final case class Prediction[L, S](actual: L, predicted: S) { 33 | override def toString: String = s"$actual,$predicted" 34 | } 35 | -------------------------------------------------------------------------------- /core/src/test/scala/com/spotify/noether/RankingData.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | object RankingData { 21 | def rankingData: Seq[RankingPrediction[Int]] = 22 | Seq( 23 | Prediction(Array(1, 2, 3, 4, 5), Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5)), 24 | Prediction(Array(1, 2, 3), Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10)), 25 | Prediction(Array.empty[Int], Array(1, 2, 3, 4, 5)) 26 | ) 27 | 28 | def smallRankingData: Seq[RankingPrediction[Int]] = 29 | Seq( 30 | Prediction(Array(1, 2, 3, 4, 5), Array(1, 6, 2)), 31 | Prediction(Array(1, 2, 3), Array.empty[Int]) 32 | ) 33 | } 34 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/LogLoss.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.{Aggregator, Semigroup} 21 | 22 | /** 23 | * LogLoss of the predictions. 24 | */ 25 | case object LogLoss extends Aggregator[Prediction[Int, List[Double]], (Double, Long), Double] { 26 | def prepare(input: Prediction[Int, List[Double]]): (Double, Long) = 27 | (math.log(input.predicted(input.actual)), 1L) 28 | 29 | def semigroup: Semigroup[(Double, Long)] = 30 | implicitly[Semigroup[(Double, Long)]] 31 | 32 | def present(score: (Double, Long)): Double = -1 * (score._1 / score._2) 33 | } 34 | -------------------------------------------------------------------------------- /core/src/test/scala/com/spotify/noether/BinaryConfusionMatrixTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. 
See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | class BinaryConfusionMatrixTest extends AggregatorTest { 21 | it should "return correct scores" in { 22 | val data = List( 23 | (false, 0.1), 24 | (false, 0.6), 25 | (false, 0.2), 26 | (true, 0.2), 27 | (true, 0.8), 28 | (true, 0.7), 29 | (true, 0.6) 30 | ).map { case (pred, s) => Prediction(pred, s) } 31 | 32 | val matrix = run(BinaryConfusionMatrix())(data) 33 | 34 | assert(matrix(1, 1) === 3L) 35 | assert(matrix(0, 1) === 1L) 36 | assert(matrix(1, 0) === 1L) 37 | assert(matrix(0, 0) === 2L) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/src/main/scala/com/spotify/noether/AggregatorExample.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.MultiAggregator 21 | 22 | object AggregatorExample { 23 | def main(args: Array[String]): Unit = { 24 | val multiAggregator = 25 | MultiAggregator((AUC(ROC), AUC(PR), ClassificationReport(), BinaryConfusionMatrix())) 26 | .andThenPresent { case (roc, pr, report, cm) => 27 | (roc, pr, report.accuracy, report.recall, report.precision, cm(1, 1), cm(0, 0)) 28 | } 29 | 30 | val predictions = List(Prediction(false, 0.1), Prediction(false, 0.6), Prediction(true, 0.9)) 31 | 32 | multiAggregator.apply(predictions) 33 | () 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /core/src/test/scala/com/spotify/noether/LogLossTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import org.scalactic.{Equality, TolerantNumerics} 21 | 22 | class LogLossTest extends AggregatorTest { 23 | implicit private val doubleEq: Equality[Double] = 24 | TolerantNumerics.tolerantDoubleEquality(0.1) 25 | private val classes = 10 26 | private def s(idx: Int, score: Double): List[Double] = 27 | 0.until(classes).map(i => if (i == idx) score else 0.0).toList 28 | 29 | it should "return correct scores" in { 30 | val data = List((s(0, 0.8), 0), (s(1, 0.6), 1), (s(2, 0.7), 2)).map { case (scores, label) => 31 | Prediction(label, scores) 32 | } 33 | 34 | assert(run(LogLoss)(data) === 0.363548039673) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/ErrorRateSummary.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.{Aggregator, Semigroup} 21 | 22 | /** 23 | * Fraction (0.0 to 1.0) of entries whose highest-scoring class is not the actual label. 
24 | */ 25 | case object ErrorRateSummary 26 | extends Aggregator[Prediction[Int, List[Double]], (Double, Long), Double] { 27 | def prepare(input: Prediction[Int, List[Double]]): (Double, Long) = { 28 | val best = input.predicted.zipWithIndex.maxBy(_._1)._2 29 | if (best == input.actual) (0.0, 1L) else (1.0, 1L) 30 | } 31 | 32 | def semigroup: Semigroup[(Double, Long)] = 33 | implicitly[Semigroup[(Double, Long)]] 34 | 35 | def present(score: (Double, Long)): Double = score._1 / score._2 36 | } 37 | -------------------------------------------------------------------------------- /tfx/src/main/scala/com/spotify/noether/tfx/Tfma.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether.tfx 19 | 20 | import com.twitter.algebird.Aggregator 21 | 22 | object Tfma { 23 | trait ConversionOps[A, B, T <: Aggregator[A, B, _]] { 24 | val self: T 25 | val converter: TfmaConverter[A, B, T] 26 | def asTfmaProto: Aggregator[A, B, EvalResult] = converter.convertToTfmaProto(self) 27 | } 28 | 29 | object ConversionOps { 30 | def apply[A, B, T <: Aggregator[A, B, _]]( 31 | instance: T, 32 | tfmaConverter: TfmaConverter[A, B, T] 33 | ): ConversionOps[A, B, T] = 34 | new ConversionOps[A, B, T] { 35 | override val self: T = instance 36 | override val converter: TfmaConverter[A, B, T] = tfmaConverter 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /core/src/test/scala/com/spotify/noether/ErrorRateSummaryTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import org.scalactic.TolerantNumerics 21 | import org.scalactic.Equality 22 | 23 | class ErrorRateSummaryTest extends AggregatorTest { 24 | implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1) 25 | private val classes = 10 26 | private def s(idx: Int): List[Double] = 27 | 0.until(classes).map(i => if (i == idx) 1.0 else 0.0).toList 28 | 29 | it should "return correct scores" in { 30 | val data = 31 | List((s(1), 1), (s(3), 1), (s(5), 5), (s(2), 3), (s(0), 0), (s(8), 1)).map { 32 | case (scores, label) => Prediction(label, scores) 33 | } 34 | 35 | assert(run(ErrorRateSummary)(data) === 0.5) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /examples/src/test/scala/com/spotify/noether/MultiAggregatorMapTests.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird._ 21 | import org.scalatest.flatspec.AnyFlatSpec 22 | import org.scalatest.matchers.must.Matchers._ 23 | 24 | class MultiAggregatorMapTests extends AnyFlatSpec { 25 | 26 | it must "aggregate into a list of individual aggregator results" in { 27 | 28 | val multiListAgg = MultiAggregatorMap[Long, Long, Long]( 29 | List( 30 | "min" -> Aggregator.min, 31 | "max" -> Aggregator.max, 32 | "size" -> Aggregator.size 33 | ) 34 | ) 35 | 36 | val result = multiListAgg(List(0, 1, 2, 3, 4, 5)) 37 | 38 | result("min") mustBe 0L 39 | result("max") mustBe 5L 40 | result("size") mustBe 6L 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /core/src/test/scala/com/spotify/noether/NdcgAtKTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import org.scalactic.{Equality, TolerantNumerics} 21 | 22 | class NdcgAtKTest extends AggregatorTest { 23 | import RankingData._ 24 | implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1) 25 | 26 | it should "compute ndcg for rankings" in { 27 | assert(run(NdcgAtK[Int](3))(rankingData) === 1.0 / 3) 28 | assert(run(NdcgAtK[Int](5))(rankingData) === 0.328788) 29 | assert(run(NdcgAtK[Int](10))(rankingData) === 0.487913) 30 | assert(run(NdcgAtK[Int](15))(rankingData) === run(NdcgAtK[Int](10))(rankingData)) 31 | } 32 | 33 | it should "compute ndcg for rankings with few predictions" in { 34 | assert(run(NdcgAtK[Int](1))(smallRankingData) === 0.5) 35 | assert(run(NdcgAtK[Int](2))(smallRankingData) === 0.30657) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /tfx/src/main/scala/com/spotify/noether/tfx/TfmaConverter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether.tfx 19 | 20 | import com.twitter.algebird.Aggregator 21 | import tensorflow_model_analysis.MetricsForSliceOuterClass._ 22 | 23 | trait TfmaConverter[A, B, T <: Aggregator[A, B, _]] { 24 | def convertToTfmaProto(underlying: T): Aggregator[A, B, EvalResult] 25 | } 26 | 27 | sealed trait Plot { 28 | def plotData: PlotsForSlice 29 | } 30 | object Plot { 31 | case class CalibrationHistogram(plotData: PlotsForSlice) extends Plot 32 | case class ConfusionMatrix(plotData: PlotsForSlice) extends Plot 33 | } 34 | 35 | case class EvalResult(metrics: Option[MetricsForSlice], plots: Option[Plot]) 36 | object EvalResult { 37 | def apply(metrics: MetricsForSlice): EvalResult = EvalResult(Some(metrics), None) 38 | def apply(metrics: MetricsForSlice, plot: Plot): EvalResult = 39 | EvalResult(Some(metrics), Some(plot)) 40 | def apply(plot: Plot): EvalResult = EvalResult(None, Some(plot)) 41 | } 42 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/BinaryConfusionMatrix.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import breeze.linalg.DenseMatrix 21 | import com.twitter.algebird.{Aggregator, Semigroup} 22 | 23 | /** 24 | * Special Case for a Binary Confusion Matrix to make it easier to compose with other binary 25 | * aggregators 26 | * 27 | * @param threshold 28 | * Threshold to apply on predictions 29 | */ 30 | case class BinaryConfusionMatrix(threshold: Double = 0.5) 31 | extends Aggregator[Prediction[Boolean, Double], Map[(Int, Int), Long], DenseMatrix[Long]] { 32 | private val confusionMatrix = ConfusionMatrix(Seq(0, 1)) 33 | 34 | def prepare(input: Prediction[Boolean, Double]): Map[(Int, Int), Long] = { 35 | val pred = Prediction(if (input.actual) 1 else 0, if (input.predicted > threshold) 1 else 0) 36 | confusionMatrix.prepare(pred) 37 | } 38 | def semigroup: Semigroup[Map[(Int, Int), Long]] = confusionMatrix.semigroup 39 | def present(m: Map[(Int, Int), Long]): DenseMatrix[Long] = 40 | confusionMatrix.present(m) 41 | } 42 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/ConfusionMatrix.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
/**
 * Generic confusion matrix aggregator for any number of labels. Thresholds must be applied to
 * produce a label prediction before using this aggregator.
 *
 * Counts are keyed by `(predicted, actual)`, so in the presented matrix the row is the predicted
 * label and the column is the actual label. Label values are used directly as matrix indices, so
 * each one has to be a valid index into a `labels.size` x `labels.size` matrix.
 *
 * @param labels
 *   List of possible label values
 */
final case class ConfusionMatrix(labels: Seq[Int])
    extends Aggregator[Prediction[Int, Int], Map[(Int, Int), Long], DenseMatrix[Long]] {

  def prepare(input: Prediction[Int, Int]): Map[(Int, Int), Long] = {
    val cell = (input.predicted, input.actual)
    Map(cell -> 1L)
  }

  def semigroup: Semigroup[Map[(Int, Int), Long]] =
    Semigroup.mapSemigroup[(Int, Int), Long]

  def present(m: Map[(Int, Int), Long]): DenseMatrix[Long] = {
    val dimension = labels.size
    val matrix = DenseMatrix.zeros[Long](dimension, dimension)
    labels.foreach { row =>
      labels.foreach { col =>
        // Cells that were never observed stay at zero.
        matrix(row, col) = m.getOrElse((row, col), 0L)
      }
    }
    matrix
  }
}
/** Checks [[PrecisionAtK]] against the shared ranking fixtures for several cut-off values. */
class PrecisionAtKTest extends AggregatorTest {
  import RankingData._
  implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

  it should "compute precisionAtK for rankings" in {
    // cut-off k -> expected mean precision over `rankingData`
    val expectedByK = Seq(
      1 -> 1.0 / 3,
      2 -> 1.0 / 3,
      3 -> 1.0 / 3,
      4 -> 0.75 / 3,
      5 -> 0.8 / 3,
      10 -> 0.8 / 3,
      15 -> 8.0 / 45
    )
    expectedByK.foreach { case (k, expected) =>
      assert(run(PrecisionAtK[Int](k))(rankingData) === expected)
    }
  }

  it should "compute precisionAtK for rankings with few predictions" in {
    val expectedByK = Seq(1 -> 0.5, 2 -> 0.25)
    expectedByK.foreach { case (k, expected) =>
      assert(run(PrecisionAtK[Int](k))(smallRankingData) === expected)
    }
  }
}
https://codecov.io/bash) 30 | - if: "!startsWith(matrix.scala,'2.13')" 31 | run: sbt "++${{matrix.scala}} test" 32 | strategy: 33 | matrix: 34 | java: 35 | - 11 36 | scala: 37 | - 2.11.12 38 | - 2.12.17 39 | - 2.13.10 40 | - 3.2.2 41 | mimaReport: 42 | runs-on: ubuntu-latest 43 | steps: 44 | - uses: actions/checkout@v3 45 | - name: Java ${{matrix.java}} setup 46 | uses: actions/setup-java@v3 47 | with: 48 | distribution: temurin 49 | java-version: ${{matrix.java}} 50 | cache: sbt 51 | - run: sbt "++${{matrix.scala}} mimaReportBinaryIssues" 52 | strategy: 53 | matrix: 54 | java: 55 | - 11 56 | scala: 57 | - 2.11.12 58 | - 2.12.17 59 | - 2.13.10 60 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/MeanAveragePrecision.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.{Aggregator, Semigroup} 21 | 22 | /** 23 | * Returns the mean average precision (MAP) of all the predictions. 
/**
 * Returns the mean average precision (MAP) of all the predictions. If a query has an empty ground
 * truth set, its average precision is zero.
 *
 * The intermediate value is a pair of (sum of per-query average precision, number of queries);
 * `present` divides the first by the second.
 */
case class MeanAveragePrecision[T]()
    extends Aggregator[RankingPrediction[T], (Double, Long), Double] {

  def prepare(input: RankingPrediction[T]): (Double, Long) = {
    val relevant = input.actual.toSet
    if (relevant.isEmpty) {
      (0.0, 1L)
    } else {
      // Walk the ranking in order, tracking hits so far and the running sum of precision
      // measured at each hit position (positions are 1-based in the formula).
      val (_, precisionSum) = input.predicted.indices.foldLeft((0, 0.0)) {
        case ((hits, acc), rank) =>
          if (relevant.contains(input.predicted(rank))) {
            val hitsSoFar = hits + 1
            (hitsSoFar, acc + hitsSoFar.toDouble / (rank + 1))
          } else {
            (hits, acc)
          }
      }
      (precisionSum / relevant.size, 1L)
    }
  }

  def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]

  def present(score: (Double, Long)): Double = {
    val (total, queries) = score
    total / queries
  }
}
/**
 * Base trait for aggregator specs. `run` drives an aggregator over a sequence of inputs and
 * round-trips the inputs, the reduced intermediate value and the presented result through Java
 * serialization along the way.
 */
trait AggregatorTest extends AnyFlatSpec with Matchers {
  def run[A, B, C](aggregator: Aggregator[A, B, C])(as: Seq[A]): C = {
    val prepared = as.map(a => aggregator.prepare(ensureSerializable(a)))
    val reduced = ensureSerializable(aggregator.reduce(prepared))
    ensureSerializable(aggregator.present(reduced))
  }

  private def serializeToByteArray(value: Any): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    new ObjectOutputStream(bytes).writeObject(value)
    bytes.toByteArray
  }

  private def deserializeFromByteArray(encodedValue: Array[Byte]): AnyRef =
    new ObjectInputStream(new ByteArrayInputStream(encodedValue)).readObject()

  // Serialize and immediately deserialize, returning the copy.
  private def ensureSerializable[T](value: T): T = {
    val encoded = serializeToByteArray(value)
    deserializeFromByteArray(encoded).asInstanceOf[T]
  }
}
/** Helpers for generating random benchmark input. */
object PredictionUtils {

  /** Builds `nbPrediction` predictions with a random boolean label and a random score. */
  def generatePredictions(nbPrediction: Int): Seq[Prediction[Boolean, Double]] =
    Seq.fill(nbPrediction) {
      val label = Random.nextBoolean()
      val score = Random.nextDouble()
      Prediction(label, score)
    }
}

object CalibrationHistogramCreateBenchmark {

  /**
   * JMH state holding the benchmark parameters and the histogram built from them in `setup`.
   * The parameter fields are mutable because JMH assigns them before each run.
   */
  @State(Scope.Benchmark)
  class CalibrationHistogramState() {
    @Param(Array("100", "1000", "3000"))
    var nbElement = 0

    @Param(Array("100", "200", "300"))
    var nbBucket = 0

    @Param(Array("0.1", "0.2", "0.3"))
    var lowerBound = 0.0

    @Param(Array("0.2", "0.4", "0.5"))
    var upperBound = 0.0

    var histogram: CalibrationHistogram = _

    @Setup
    def setup(): Unit = {
      histogram = CalibrationHistogram(
        lowerBound = lowerBound,
        upperBound = upperBound,
        numBuckets = nbBucket
      )
    }
  }
}

class CalibrationHistogramCreateBenchmark {

  /** Reads `bucketSize` from the histogram that the state object built during setup. */
  @Benchmark
  def createCalibrationHistogram(calibrationHistogramState: CalibrationHistogramState): Double = {
    val histogram = calibrationHistogramState.histogram
    histogram.bucketSize
  }
}
/** Checks [[ConfusionMatrix]] on a three-label and a two-label data set. */
class ConfusionMatrixTest extends AggregatorTest {
  it should "return correct confusion matrix" in {
    val pairs = List(
      (0, 0),
      (0, 0),
      (0, 0),
      (0, 1),
      (0, 1),
      (1, 0),
      (1, 0),
      (1, 0),
      (1, 0),
      (1, 1),
      (1, 1),
      (2, 1),
      (2, 2),
      (2, 2),
      (2, 2)
    )
    val data = pairs.map { case (first, second) => Prediction(second, first) }

    val labels = Seq(0, 1, 2)
    val actual = run(ConfusionMatrix(labels))(data)

    // (row, column) -> expected count
    val expectedCells = Seq(
      (0, 0) -> 3L,
      (0, 1) -> 2L,
      (0, 2) -> 0L,
      (1, 0) -> 4L,
      (1, 1) -> 2L,
      (1, 2) -> 0L,
      (2, 0) -> 0L,
      (2, 1) -> 1L,
      (2, 2) -> 3L
    )
    val mat = DenseMatrix.zeros[Long](labels.size, labels.size)
    expectedCells.foreach { case ((row, col), count) =>
      mat(row, col) = count
    }
    assert(actual == mat)
  }

  it should "return correct scores" in {
    val pairs = List(
      (0, 0),
      (0, 1),
      (0, 0),
      (1, 0),
      (1, 1),
      (1, 1),
      (1, 1)
    )
    val data = pairs.map { case (first, second) => Prediction(second, first) }

    val matrix = run(ConfusionMatrix(Seq(0, 1)))(data)

    assert(matrix(1, 1) === 3L)
    assert(matrix(0, 1) === 1L)
    assert(matrix(1, 0) === 1L)
    assert(matrix(0, 0) === 2L)
  }
}
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.{Aggregator, Semigroup} 21 | import scala.collection.mutable.ArrayBuffer 22 | 23 | /** 24 | * Aggregator which combines an unbounded list of other aggregators. Each aggregator in the list is 25 | * tagged by a string. The string (i.e. the name) can be used to retrieve the aggregated value from the 26 | * Map emitted by the "present" function.
/**
 * Aggregator which combines an unbounded list of other aggregators. Each aggregator in the list
 * is tagged by a string, and that name can be used to retrieve the aggregated value from the Map
 * emitted by `present`.
 *
 * The intermediate value is a `List[B]` holding one element per aggregator, in the same order as
 * `aggregatorsMap`. If two aggregators share a name, the later one wins in the presented Map.
 *
 * @param aggregatorsMap
 *   Named aggregators, all sharing the same input and intermediate types
 */
case class MultiAggregatorMap[-A, B, +C](aggregatorsMap: List[(String, Aggregator[A, B, C])])
    extends Aggregator[A, List[B], Map[String, C]] {

  private[this] val aggregators = aggregatorsMap.map(_._2)

  def prepare(input: A): List[B] =
    aggregators.map(_.prepare(input))

  def semigroup: Semigroup[List[B]] = new Semigroup[List[B]] {
    def plus(x: List[B], y: List[B]): List[B] = {
      // Walk the three lists in lock-step with iterators: positional access on a List is linear,
      // so indexing inside the loop made this quadratic in the number of aggregators.
      // An operand shorter than the aggregator list fails with NoSuchElementException.
      val lefts = x.iterator
      val rights = y.iterator
      val merged = new ArrayBuffer[B]
      aggregators.foreach { aggregator =>
        merged.append(aggregator.semigroup.plus(lefts.next(), rights.next()))
      }
      merged.toList
    }
  }

  def present(reduction: List[B]): Map[String, C] = {
    // Same single-pass traversal as in `plus`: pair each named aggregator with its reduced value.
    val reduced = reduction.iterator
    val named = new ArrayBuffer[(String, C)]
    aggregatorsMap.foreach { case (name, aggregator) =>
      named.append(name -> aggregator.present(reduced.next()))
    }
    named.toMap
  }
}
/** Checks the AUC values and curve points computed for a small labelled data set. */
class AUCTest extends AggregatorTest {
  implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

  private val data = {
    val pairs = List(
      (0.1, false),
      (0.1, true),
      (0.4, false),
      (0.6, false),
      (0.6, true),
      (0.6, true),
      (0.8, true)
    )
    pairs.map { case (first, second) => Prediction(second, first) }
  }

  // Wraps raw coordinate pairs the same way the aggregator output is wrapped.
  private def pointsOf(coordinates: Array[(Double, Double)]) =
    MetricCurvePoints(coordinates.map { case (a, b) => MetricCurvePoint(a, b) }).points

  it should "return ROC AUC" in {
    val auc = run(AUC(ROC, samples = 50))(data)
    assert(auc === 0.7)
  }

  it should "return PR AUC" in {
    val auc = run(AUC(PR, samples = 50))(data)
    assert(auc === 0.83)
  }

  it should "return points of a PR Curve" in {
    val coordinates = Array(
      (0.0, 1.0),
      (0.0, 1.0),
      (0.25, 1.0),
      (0.75, 0.75),
      (0.75, 0.6),
      (1.0, 0.5714285714285714)
    )
    assert(run(Curve(PR, samples = 5))(data).points === pointsOf(coordinates))
  }

  it should "return points of a ROC Curve" in {
    val coordinates = Array(
      (0.0, 0.0),
      (0.0, 0.25),
      (0.3333333333333333, 0.75),
      (0.6666666666666666, 0.75),
      (1.0, 1.0),
      (1.0, 1.0)
    )
    assert(run(Curve(ROC, samples = 5))(data).points === pointsOf(coordinates))
  }
}
/** Checks the binary and the multi-class classification reports. */
class ClassificationReportTest extends AggregatorTest {
  implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

  it should "return correct scores" in {
    val pairs = List(
      (0.1, false),
      (0.1, true),
      (0.4, false),
      (0.6, false),
      (0.6, true),
      (0.6, true),
      (0.8, true)
    )
    val data = pairs.map { case (first, second) => Prediction(second, first) }

    val score = run(ClassificationReport())(data)

    assert(score.recall === 0.75)
    assert(score.precision === 0.75)
    assert(score.fscore === 0.75)
    assert(score.fpr === 0.333)
  }

  it should "support multiclass reports" in {
    val pairs = Seq(
      (0, 0),
      (0, 0),
      (0, 1),
      (1, 1),
      (1, 1),
      (1, 0),
      (1, 2),
      (2, 2),
      (2, 2),
      (2, 2)
    )
    val predictions = pairs.map { case (first, second) => Prediction(second, first) }

    val reports = run(MultiClassificationReport(Seq(0, 1, 2)))(predictions)

    // (label, expected recall, expected precision); compared exactly, not with the tolerance.
    val expectations = Seq(
      (0, 2.0 / 3.0, 2.0 / 3.0),
      (1, 2.0 / 3.0, 0.5),
      (2, 0.75, 1.0)
    )
    expectations.foreach { case (label, recall, precision) =>
      val report = reports(label)
      assert(report.recall == recall)
      assert(report.precision == precision)
    }
  }
}
/core/src/main/scala/com/spotify/noether/PrecisionAtK.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.{Aggregator, Semigroup} 21 | 22 | /** 23 | * Compute the average precision of all the predictions, truncated at ranking position k. 24 | * 25 | * If for a prediction, the ranking algorithm returns n (n is less than k) results, the precision 26 | * value will be computed as #(relevant items retrieved) / k. This formula also applies when the 27 | * size of the ground truth set is less than k. 28 | * 29 | * If a prediction has an empty ground truth set, zero will be used as its precision. 30 | * 31 | * See the following paper for detail: 32 | * 33 | * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J.
/**
 * Compute the average precision of all the predictions, truncated at ranking position k.
 *
 * If a ranking contains fewer than k results, the precision is still computed as #(relevant items
 * retrieved) / k. The same formula applies when the ground truth set has fewer than k items.
 *
 * If a prediction has an empty ground truth set, zero is used as its precision.
 *
 * See the following paper for detail:
 *
 * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
 *
 * @param k
 *   the position to compute the truncated precision, must be positive
 */
case class PrecisionAtK[T](k: Int)
    extends Aggregator[RankingPrediction[T], (Double, Long), Double] {
  require(k > 0, "ranking position k should be positive")

  def prepare(input: RankingPrediction[T]): (Double, Long) = {
    val relevant = input.actual.toSet
    if (relevant.isEmpty) {
      (0.0, 1L)
    } else {
      // Only the first k ranked items (or fewer, if the ranking is shorter) are inspected.
      val cutoff = math.min(input.predicted.length, k)
      val hits = (0 until cutoff).count(rank => relevant.contains(input.predicted(rank)))
      (hits.toDouble / k, 1L)
    }
  }

  def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]

  def present(score: (Double, Long)): Double = {
    val (total, queries) = score
    total / queries
  }
}
/**
 * Compute the average NDCG value of all the predictions, truncated at ranking position k.
 *
 * Relevance is binary in this implementation: an item at 0-based position `i` contributes a gain
 * of `1 / ln(i + 2)` to the DCG when it belongs to the ground truth set. The ideal DCG sums the
 * same gains over the first `min(|ground truth|, k)` positions, and the NDCG is the ratio of the
 * two. If a query has an empty ground truth set, zero is used as its NDCG.
 *
 * See the following paper for detail:
 *
 * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
 *
 * @param k
 *   the position to compute the truncated ndcg, must be positive
 */
case class NdcgAtK[T](k: Int) extends Aggregator[RankingPrediction[T], (Double, Long), Double] {
  require(k > 0, "ranking position k should be positive")

  def prepare(input: RankingPrediction[T]): (Double, Long) = {
    val relevant = input.actual.toSet
    if (relevant.isEmpty) {
      (0.0, 1L)
    } else {
      val idealSize = relevant.size
      val predictedSize = input.predicted.length
      val depth = math.min(math.max(predictedSize, idealSize), k)

      // Accumulate the achieved and the ideal DCG side by side, position by position.
      val (dcg, idealDcg) = (0 until depth).foldLeft((0.0, 0.0)) {
        case ((achieved, ideal), rank) =>
          val gain = 1.0 / math.log(rank + 2.0)
          val isHit = rank < predictedSize && relevant.contains(input.predicted(rank))
          val nextAchieved = if (isHit) achieved + gain else achieved
          val nextIdeal = if (rank < idealSize) ideal + gain else ideal
          (nextAchieved, nextIdeal)
      }
      (dcg / idealDcg, 1L)
    }
  }

  def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]

  def present(score: (Double, Long)): Double = {
    val (total, queries) = score
    total / queries
  }
}
/** Checks bucket boundaries and per-bucket sums produced by [[CalibrationHistogram]]. */
class CalibrationHistogramTest extends AggregatorTest {
  it should "return correct histogram" in {
    implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.001)
    val pairs = Seq(
      (0.15, 1.15), // lb
      (0.288, 1.288), // rounding error puts this in (0.249, 0.288)
      (0.30, 1.30), // (0.288, 0.3269)
      (0.36, 1.36), // (0.3269, 0.365)
      (0.555, 1.555), // (0.5219, 0.5609)
      (1.2, 2.2), // ub
      (0.7, 1.7) // ub
    )
    val data = pairs.map { case (first, second) => Prediction(second, first) }

    val actual = run(CalibrationHistogram(0.21, 0.60, 10))(data)

    val expected = List(
      CalibrationHistogramBucket(Double.NegativeInfinity, 0.21, 1.0, 1.15, 0.15),
      CalibrationHistogramBucket(0.21, 0.249, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.249, 0.288, 1.0, 1.288, 0.288),
      CalibrationHistogramBucket(0.288, 0.327, 1.0, 1.30, 0.30),
      CalibrationHistogramBucket(0.327, 0.366, 1.0, 1.36, 0.36),
      CalibrationHistogramBucket(0.366, 0.405, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.405, 0.4449, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.444, 0.483, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.483, 0.522, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.522, 0.561, 1.0, 1.555, 0.555),
      CalibrationHistogramBucket(0.561, 0.6, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.6, Double.PositiveInfinity, 2.0, 3.9, 1.9)
    )

    assert(actual.length == expected.length)
    // Compare field by field so that the tolerant Double equality applies to each value.
    actual.zip(expected).foreach { case (got, wanted) =>
      assert(got.numPredictions === wanted.numPredictions)
      assert(got.sumPredictions === wanted.sumPredictions)
      assert(got.sumLabels === wanted.sumLabels)
      assert(got.lowerThresholdInclusive === wanted.lowerThresholdInclusive)
      assert(got.upperThresholdExclusive === wanted.upperThresholdExclusive)
    }
  }
}
18 | sbt:noetherBenchmark> jmh:run .* 19 | 20 | # Run specific benchmark (CalibrationHistogram benchmark) 21 | sbt:noetherBenchmark> jmh:run .*CalibrationHistogram.* 22 | 23 | ### Example 24 | 25 | jmh:run -t1 -f1 -wi 2 -i 3 .*Calibration.* 26 | 27 | 28 | The output should look something like this: 29 | 30 | 31 | [info] Benchmark (lowerBound) (nbBucket) (nbElement) (upperBound) Mode Cnt Score Error Units 32 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 100 0.2 thrpt 3 354495958,865 ± 19222766,245 ops/s 33 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 100 0.4 thrpt 3 354159610,708 ± 19887361,006 ops/s 34 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 100 0.5 thrpt 3 353401735,022 ± 9100858,662 ops/s 35 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 1000 0.2 thrpt 3 352173891,822 ± 26480103,060 ops/s 36 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 1000 0.4 thrpt 3 352870170,909 ± 21414322,507 ops/s 37 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 1000 0.5 thrpt 3 354991297,412 ± 37282062,985 ops/s 38 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 3000 0.2 thrpt 3 352385150,389 ± 25101115,471 ops/s 39 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 3000 0.4 thrpt 3 355695107,774 ± 27960629,389 ops/s 40 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 3000 0.5 thrpt 3 353599563,110 ± 35675109,158 ops/s 41 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 200 100 0.2 thrpt 3 354172265,477 ± 29726736,262 ops/s 42 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 200 100 0.4 thrpt 3 353407712,527 ± 15231956,560 ops/s 43 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 200 100 0.5 thrpt 3 355430167,821 
/**
 * One bucket of a calibration histogram.
 *
 * @param lowerThresholdInclusive
 *   Lower bound of the bucket, inclusive
 * @param upperThresholdExclusive
 *   Upper bound of the bucket, exclusive
 * @param numPredictions
 *   Number of predictions that fell into this bucket
 * @param sumLabels
 *   Sum of the label values of those predictions
 * @param sumPredictions
 *   Sum of the predicted values of those predictions
 */
final case class CalibrationHistogramBucket(
  lowerThresholdInclusive: Double,
  upperThresholdExclusive: Double,
  numPredictions: Double,
  sumLabels: Double,
  sumPredictions: Double
)
47 | * 48 | * If a prediction is less than the lower bound, it belongs to the bucket [-inf, lower bound). If it 49 | * is greater than or equal to the upper bound, it belongs to the bucket [upper bound, inf). 50 | * 51 | * @param lowerBound 52 | * Left boundary, inclusive 53 | * @param upperBound 54 | * Right boundary, exclusive 55 | * @param numBuckets 56 | * Number of equal-width buckets between lowerBound and upperBound; the two overflow buckets are emitted in addition to these 57 | */ 58 | 59 | final case class CalibrationHistogram( 60 | lowerBound: Double = 0.0, 61 | upperBound: Double = 1.0, 62 | numBuckets: Int = 10 63 | ) extends Aggregator[ 64 | Prediction[Double, Double], 65 | Map[Double, (Double, Double, Long)], 66 | List[CalibrationHistogramBucket] 67 | ] { 68 | val bucketSize: Double = (upperBound - lowerBound) / numBuckets.toDouble 69 | 70 | private def thresholdsFromBucket(b: Double): (Double, Double) = b match { 71 | case Double.PositiveInfinity => (upperBound, b) 72 | case Double.NegativeInfinity => (b, lowerBound) 73 | case _ => { 74 | (lowerBound + (b * bucketSize), lowerBound + (b * bucketSize) + bucketSize) 75 | } 76 | } 77 | 78 | def prepare(input: Prediction[Double, Double]): Map[Double, (Double, Double, Long)] = { 79 | val bucketNumber = input.predicted match { 80 | case p if p < lowerBound => Double.NegativeInfinity 81 | case p if p >= upperBound => Double.PositiveInfinity 82 | case _ => /* Rounding in the division can yield numBuckets for a prediction just below upperBound (e.g. 0.9999999999999999 with bounds [0, 1) and 3 buckets); clamp to the last regular bucket so present does not drop the example. */ 83 | math.min(floor((input.predicted - lowerBound) / bucketSize), (numBuckets - 1).toDouble) 84 | } 85 | 86 | Map((bucketNumber, (input.predicted, input.actual, 1L))) 87 | } 88 | 89 | def semigroup: Semigroup[Map[Double, (Double, Double, Long)]] = 90 | Semigroup.mapSemigroup[Double, (Double, Double, Long)] 91 | 92 | def present(m: Map[Double, (Double, Double, Long)]): List[CalibrationHistogramBucket] = { 93 | val buckets = Vector(Double.NegativeInfinity) ++ 94 | (0 until numBuckets).map(_.toDouble).toVector ++ 95 | Vector(Double.PositiveInfinity) 96 | buckets.map { l => 97 | val (lb, ub) = thresholdsFromBucket(l) 98 | m.get(l) match { 99 | case None =>
CalibrationHistogramBucket(lb, ub, 0.0, 0.0, 0.0) 100 | case Some((predictionSum, labelSum, numExamples)) => 101 | CalibrationHistogramBucket(lb, ub, numExamples.toDouble, labelSum, predictionSum) 102 | } 103 | }.toList 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /tfx/src/main/protobuf/metrics_for_slice.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | syntax = "proto3"; 16 | 17 | import "google/protobuf/wrappers.proto"; 18 | 19 | package tensorflow_model_analysis; 20 | 21 | 22 | // The value will be converted into an error message if we do not know its type. 23 | message UnknownType { 24 | string value = 2; 25 | string error = 1; 26 | } 27 | 28 | message BoundedValue { 29 | // The lower bound of the range. 30 | google.protobuf.DoubleValue lower_bound = 1; 31 | // The upper bound of the range. 32 | google.protobuf.DoubleValue upper_bound = 2; 33 | // Represents an exact value if the lower_bound and upper_bound are unset, 34 | // else it's an approximate value. For the approximate value, it should be 35 | // within the range [lower_bound, upper_bound]. 36 | google.protobuf.DoubleValue value = 3; 37 | } 38 | 39 | // Value at cutoffs, e.g.
for precision@K, recall@K 40 | message ValueAtCutoffs { 41 | message ValueCutoffPair { 42 | int32 cutoff = 1; 43 | double value = 2; 44 | } 45 | repeated ValueCutoffPair values = 1; 46 | } 47 | 48 | // Confusion matrix at thresholds. 49 | message ConfusionMatrixAtThresholds { 50 | message ConfusionMatrixAtThreshold { 51 | double threshold = 1; 52 | double false_negatives = 2; 53 | double true_negatives = 3; 54 | double false_positives = 4; 55 | double true_positives = 5; 56 | double precision = 6; 57 | double recall = 7; 58 | } 59 | repeated ConfusionMatrixAtThreshold matrices = 1; 60 | } 61 | 62 | // It stores metrics values in different types, so that the frontend will know 63 | // how to visualize the values based on the types. 64 | message MetricValue { 65 | oneof type { 66 | google.protobuf.DoubleValue double_value = 1; 67 | BoundedValue bounded_value = 2; 68 | ValueAtCutoffs value_at_cutoffs = 4; 69 | ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 5; 70 | UnknownType unknown_type = 3; 71 | } 72 | } 73 | 74 | // A single slice key. 75 | message SingleSliceKey { 76 | string column = 1; 77 | oneof kind { 78 | bytes bytes_value = 2; 79 | float float_value = 3; 80 | int64 int64_value = 4; 81 | } 82 | } 83 | 84 | // A slice key, which may consist of multiple single slice keys. 85 | message SliceKey { 86 | repeated SingleSliceKey single_slice_keys = 1; 87 | } 88 | 89 | message MetricsForSlice { 90 | // The slice key for the metrics. 91 | SliceKey slice_key = 1; 92 | // A map to store metrics. Currently we convert the post_export_metric 93 | // provided by TFMA to its appropriate type for better visualization, and map 94 | // all other metrics to DoubleValue type. 
95 | map<string, MetricValue> metrics = 2; 96 | } 97 | 98 | message CalibrationHistogramBuckets { 99 | message Bucket { 100 | double lower_threshold_inclusive = 1; 101 | double upper_threshold_exclusive = 2; 102 | google.protobuf.DoubleValue num_weighted_examples = 3; 103 | google.protobuf.DoubleValue total_weighted_label = 4; 104 | google.protobuf.DoubleValue total_weighted_refined_prediction = 5; 105 | } 106 | repeated Bucket buckets = 1; 107 | } 108 | 109 | message PlotData { 110 | // For calibration plot and prediction distribution. 111 | CalibrationHistogramBuckets calibration_histogram_buckets = 1; 112 | // For auc curve and auprc curve. 113 | ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 2; 114 | } 115 | 116 | message PlotsForSlice { 117 | // The slice key for the metrics. 118 | SliceKey slice_key = 1; 119 | // The plot data 120 | PlotData plot_data = 2; 121 | } 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Noether 2 | ======= 3 | 4 | [![Build Status](https://travis-ci.org/spotify/noether.svg?branch=master)](https://travis-ci.org/spotify/noether) 5 | [![codecov.io](https://codecov.io/github/spotify/noether/coverage.svg?branch=master)](https://codecov.io/github/spotify/noether?branch=master) 6 | [![GitHub license](https://img.shields.io/github/license/spotify/noether.svg)](./LICENSE) 7 | [![Maven Central](https://img.shields.io/maven-central/v/com.spotify/noether-core_2.12.svg)](https://maven-badges.herokuapp.com/maven-central/com.spotify/noether-core_2.12) 8 | [![Scaladoc](https://img.shields.io/badge/scaladoc-latest-blue.svg)](https://spotify.github.io/noether/latest/api/com/spotify/noether/index.html) 9 | [![Scala Steward badge](https://img.shields.io/badge/Scala_Steward-helping-brightgreen.svg?style=flat&logo=)](https://scala-steward.org) 10 | 11 | > [Emmy Noether](https://en.wikipedia.org/wiki/Emmy_Noether) was a German
mathematician known for her landmark contributions to abstract algebra and theoretical physics. 12 | 13 | Noether is a collection of Machine Learning tools targeted at the JVM and Scala. 14 | It relies heavily on the [Algebird](https://github.com/twitter/algebird) library, especially for Aggregators. 15 | 16 | # Aggregators 17 | 18 | Aggregators enable creation of reusable and composable aggregation functions. Most Machine Learning loss functions and metrics can be 19 | decomposed into a single aggregator. This becomes useful when a model produces a set of predictions and one or more metrics need 20 | to be computed on this collection. 21 | 22 | Below is an example for a binary classification task. Algebird's MultiAggregator can be used to combine multiple metrics into a 23 | single callable aggregator. 24 | 25 | ```scala 26 | val multiAggregator = 27 | MultiAggregator(AUC(ROC), AUC(PR), ClassificationReport(), BinaryConfusionMatrix()) 28 | .andThenPresent{case (roc, pr, report, cm) => 29 | (roc, pr, report.accuracy, report.recall, report.precision, cm(1, 1), cm(0, 0)) 30 | } 31 | 32 | val predictions = List(Prediction(false, 0.1), Prediction(false, 0.6), Prediction(true, 0.9)) 33 | 34 | println(multiAggregator(predictions)) 35 | ``` 36 | 37 | ## Prediction Object 38 | 39 | Most aggregators take a single parameterized class called Prediction as input to the aggregator. However, the type of 40 | the prediction object differs based on the aggregator. In the above example each binary classifier takes a prediction 41 | of type `Prediction[Boolean, Double]` where the first type is the label and the second is the predicted score. 42 | 43 | Other aggregators take slightly different types, such as the Error Rate Aggregator which expects `Prediction[Int, List[Double]]` 44 | where the types are label and a list of scores.
45 | 46 | ## Available Aggregators 47 | 48 | See the docs on each aggregator for a more detailed walk-through on the functionality and the return objects. 49 | 50 | 1. ConfusionMatrix 51 | 1. Includes a special BinaryConfusionMatrix case to make composition easier with the other binary classification metrics. 52 | 2. AUC 53 | 1. Supports both ROC and PR 54 | 3. ClassificationReport 55 | 1. Returns a list of summary metrics for a binary classification problem. 56 | 4. LogLoss 57 | 1. Available for multiclass. Returns the total log loss for the predictions. 58 | 5. ErrorRateSummary 59 | 1. Available for multiclass. Returns the proportion of misclassified predictions. 60 | 61 | # Tensorflow Model Analysis Support 62 | 63 | Noether supports outputting metrics as TFX `metrics_for_slice` protobufs, which can be used in 64 | TFMA methods. This is available in the `noether-tfx` package: 65 | 66 | ```scala 67 | libraryDependencies += "com.spotify" %% "noether-tfx" % noetherVersion 68 | ``` 69 | 70 | ```scala 71 | import com.spotify.noether.tfx._ 72 | 73 | val data = List( 74 | (0, 0), 75 | (0, 1), 76 | (0, 0), 77 | (1, 0), 78 | (1, 1), 79 | (1, 1), 80 | (1, 1) 81 | ).map { case (s, pred) => Prediction(pred, s) } 82 | 83 | val tfmaProto = ConfusionMatrix(Seq(0, 1)).asTfmaProto(data) 84 | ``` 85 | 86 | # License 87 | 88 | Copyright 2016-2018 Spotify AB. 89 | 90 | Licensed under the Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0 91 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/AUC.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import breeze.linalg._ 21 | import com.twitter.algebird.{Aggregator, Semigroup} 22 | 23 | case class MetricCurve(cm: Array[Map[(Int, Int), Long]]) extends Serializable 24 | 25 | case class MetricCurvePoints(points: Array[MetricCurvePoint]) extends Serializable 26 | 27 | case class MetricCurvePoint(x: Double, y: Double) extends Serializable 28 | 29 | private[noether] object AreaUnderCurve { 30 | def trapezoid(points: Seq[MetricCurvePoint]): Double = { 31 | val x = points.head 32 | val y = points.last 33 | (y.x - x.x) * (y.y + x.y) / 2.0 34 | } 35 | 36 | def of(curve: MetricCurvePoints): Double = { 37 | curve.points.iterator 38 | .sliding(2) 39 | .withPartial(false) 40 | .foldLeft(0.0)((auc: Double, points: Seq[MetricCurvePoint]) => auc + trapezoid(points)) 41 | } 42 | } 43 | 44 | /** 45 | * Selects which curve is derived from the list of confusion matrices prior to the AUC calculation. 46 | */ 47 | sealed trait AUCMetric 48 | 49 | /** 50 | * Receiver Operating Characteristic curve: 51 | * x is the false positive rate, y is the recall. 52 | */ 53 | case object ROC extends AUCMetric 54 | 55 | /** 56 | * Precision-Recall curve: x is the recall, y is the precision. 57 | */ 58 | case object PR extends AUCMetric 59 | 60 | /** 61 | * Compute a series of points for a collection of predictions. 62 | * 63 | * Internally a linspace is defined using the given number of `samples`. Each point in the 64 | * linspace represents a threshold which is used to build a confusion matrix. The (x, y) point of 65 | * the curve at each threshold is then returned.
66 | * 67 | * [[AUCMetric]] which is given to the aggregate selects the function to apply on the confusion 68 | * matrix prior to the AUC calculation. 69 | * 70 | * @param metric 71 | * Which function to apply on the confusion matrix. 72 | * @param samples 73 | * Number of samples to use for the curve definition. 74 | */ 75 | case class Curve(metric: AUCMetric, samples: Int = 100) 76 | extends Aggregator[Prediction[Boolean, Double], MetricCurve, MetricCurvePoints] { 77 | private lazy val thresholds = linspace(0.0, 1.0, samples) 78 | private lazy val aggregators = 79 | thresholds.data.map(ClassificationReport(_)).toArray 80 | 81 | def prepare(input: Prediction[Boolean, Double]): MetricCurve = 82 | MetricCurve(aggregators.map(_.prepare(input))) 83 | 84 | def semigroup: Semigroup[MetricCurve] = { 85 | val sg = ClassificationReport().semigroup 86 | Semigroup.from { case (l, r) => 87 | MetricCurve(l.cm.zip(r.cm).map { case (cl, cr) => sg.plus(cl, cr) }) 88 | } 89 | } 90 | 91 | def present(c: MetricCurve): MetricCurvePoints = { 92 | val total = c.cm.map { matrix => 93 | val scores = ClassificationReport().present(matrix) 94 | metric match { 95 | case ROC => MetricCurvePoint(scores.fpr, scores.recall) 96 | case PR => MetricCurvePoint(scores.recall, scores.precision) 97 | } 98 | }.reverse 99 | 100 | val points = metric match { 101 | case ROC => total ++ Array(MetricCurvePoint(1.0, 1.0)) 102 | case PR => Array(MetricCurvePoint(0.0, 1.0)) ++ total 103 | } 104 | 105 | MetricCurvePoints(points) 106 | } 107 | } 108 | 109 | /** 110 | * Compute the "Area Under the Curve" for a collection of predictions. Uses the Trapezoid method to 111 | * compute the area. 112 | * 113 | * Internally a linspace is defined using the given number of `samples`. Each point in the 114 | * linspace represents a threshold which is used to build a confusion matrix. The area is then 115 | * defined on this list of confusion matrices.
116 | * 117 | * [[AUCMetric]] which is given to the aggregate selects the function to apply on the confusion 118 | * matrix prior to the AUC calculation. 119 | * 120 | * @param metric 121 | * Which function to apply on the confusion matrix. 122 | * @param samples 123 | * Number of samples to use for the curve definition. 124 | */ 125 | case class AUC(metric: AUCMetric, samples: Int = 100) 126 | extends Aggregator[Prediction[Boolean, Double], MetricCurve, Double] { 127 | private val curve = Curve(metric, samples) 128 | def prepare(input: Prediction[Boolean, Double]): MetricCurve = curve.prepare(input) 129 | def semigroup: Semigroup[MetricCurve] = curve.semigroup 130 | def present(c: MetricCurve): Double = AreaUnderCurve.of(curve.present(c)) 131 | } 132 | -------------------------------------------------------------------------------- /core/src/main/scala/com/spotify/noether/ClassificationReport.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | 18 | package com.spotify.noether 19 | 20 | import com.twitter.algebird.{Aggregator, Semigroup} 21 | 22 | /** 23 | * Classification Report 24 | * 25 | * @param mcc 26 | * Matthews Correlation Coefficient 27 | * @param fscore 28 | * f-score 29 | * @param precision 30 | * Precision 31 | * @param recall 32 | * Recall 33 | * @param accuracy 34 | * Accuracy 35 | * @param fpr 36 | * False Positive Rate 37 | */ 38 | final case class Report( 39 | mcc: Double, 40 | fscore: Double, 41 | precision: Double, 42 | recall: Double, 43 | accuracy: Double, 44 | fpr: Double 45 | ) 46 | 47 | /** 48 | * Generate a Classification Report for a collection of binary predictions. The output of this 49 | * aggregator will be a [[Report]] object. 50 | * 51 | * @param threshold 52 | * Scores strictly greater than this threshold are treated as positive predictions. 53 | * @param beta 54 | * Beta parameter used in the f-score calculation. 55 | */ 56 | final case class ClassificationReport(threshold: Double = 0.5, beta: Double = 1.0) 57 | extends Aggregator[Prediction[Boolean, Double], Map[(Int, Int), Long], Report] { 58 | private val aggregator = MultiClassificationReport(Seq(0, 1), beta) /* beta must be forwarded, otherwise the f-score silently uses the default of 1.0 */ 59 | 60 | def prepare(input: Prediction[Boolean, Double]): Map[(Int, Int), Long] = { 61 | val predicted = Prediction( 62 | if (input.actual) 1 else 0, 63 | if (input.predicted > threshold) 1 else 0 64 | ) 65 | aggregator.prepare(predicted) 66 | } 67 | 68 | def semigroup: Semigroup[Map[(Int, Int), Long]] = aggregator.semigroup 69 | 70 | def present(m: Map[(Int, Int), Long]): Report = aggregator.present(m)(1) 71 | } 72 | 73 | /** 74 | * Generate a Classification Report for a collection of multiclass predictions. A report is 75 | * generated for each class by treating the predictions as binary of either "class" or "not class". 76 | * The output of this aggregator will be a map of classes and their [[Report]] objects. 77 | * 78 | * @param labels 79 | * List of possible label values.
80 | * @param beta 81 | * Beta parameter used in the f-score calculation. 82 | */ 83 | final case class MultiClassificationReport(labels: Seq[Int], beta: Double = 1.0) 84 | extends Aggregator[Prediction[Int, Int], Map[(Int, Int), Long], Map[Int, Report]] { 85 | private val aggregator = ConfusionMatrix(labels) 86 | 87 | override def prepare(input: Prediction[Int, Int]): Map[(Int, Int), Long] = 88 | aggregator.prepare(input) 89 | 90 | override def semigroup: Semigroup[Map[(Int, Int), Long]] = 91 | aggregator.semigroup 92 | 93 | override def present(m: Map[(Int, Int), Long]): Map[Int, Report] = { 94 | val mat = m.withDefaultValue(0L) 95 | labels.foldLeft(Map.empty[Int, Report]) { (result, clazz) => 96 | val fp = mat.iterator 97 | .filter { case ((p, a), _) => 98 | p == clazz && a != clazz 99 | } 100 | .map(_._2) 101 | .sum 102 | .toDouble 103 | val tp = mat(clazz -> clazz).toDouble 104 | val tn = mat.iterator 105 | .filter { case ((p, a), _) => 106 | p != clazz && a != clazz 107 | } 108 | .map(_._2) 109 | .sum 110 | .toDouble 111 | val fn = mat.iterator 112 | .filter { case ((p, a), _) => p != clazz && a == clazz } 113 | .map(_._2) 114 | .sum 115 | .toDouble 116 | 117 | val mccDenom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) 118 | val mcc = if (mccDenom > 0.0) ((tp * tn) - (fp * fn)) / mccDenom else 0.0 119 | 120 | val precDenom = tp + fp 121 | val precision = if (precDenom > 0.0) tp / precDenom else 1.0 122 | 123 | val recallDenom = tp + fn 124 | val recall = if (recallDenom > 0.0) tp / recallDenom else 1.0 125 | 126 | val accuracyDenom = tp + fn + tn + fp 127 | val accuracy = if (accuracyDenom > 0.0) (tp + tn) / accuracyDenom else 0.0 128 | 129 | val fpDenom = fp + tn 130 | val fpr = if (fpDenom > 0.0) fp / fpDenom else 0.0 131 | 132 | val betaSqr = Math.pow(beta, 2.0) 133 | 134 | val fScoreDenom = (betaSqr * precision) + recall 135 | val fscore = if (fScoreDenom > 0.0) { 136 | (1 + betaSqr) * ((precision * recall) / fScoreDenom) 137 | } else { 138 |
0.0 /* the denominator is 0 only when recall is 0, i.e. there are no true positives, so the f-score is 0 rather than a perfect 1 */ 139 | } 140 | 141 | result + (clazz -> Report(mcc, fscore, precision, recall, accuracy, fpr)) 142 | } 143 | } 144 | 145 | } 146 | -------------------------------------------------------------------------------- /tfx/src/test/scala/com/spotify/noether/tfx/TfmaConverterTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License.
16 | */ 17 | 18 | package com.spotify.noether.tfx 19 | 20 | import com.spotify.noether._ 21 | import org.scalactic.{Equality, TolerantNumerics} 22 | import org.scalatest.Assertion 23 | import tensorflow_model_analysis.MetricsForSliceOuterClass.MetricsForSlice 24 | import tensorflow_model_analysis.MetricsForSliceOuterClass.CalibrationHistogramBuckets 25 | import scala.jdk.CollectionConverters._ 26 | import org.scalatest.flatspec.AnyFlatSpec 27 | import org.scalatest.matchers.should.Matchers 28 | 29 | class TfmaConverterTest extends AnyFlatSpec with Matchers { 30 | implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.001) 31 | 32 | "TfmaConverter" should "work with ConfusionMatrix" in { 33 | val data = List( 34 | (0, 0), 35 | (0, 1), 36 | (0, 0), 37 | (1, 0), 38 | (1, 1), 39 | (1, 1), 40 | (1, 1) 41 | ).map { case (s, pred) => Prediction(pred, s) } 42 | 43 | val evalResult = ConfusionMatrix(Seq(0, 1)).asTfmaProto(data) 44 | val cmProto = evalResult.metrics.get 45 | 46 | val cm = cmProto.getMetricsMap 47 | .get("Noether_ConfusionMatrix") 48 | .getConfusionMatrixAtThresholds 49 | .getMatrices(0) 50 | 51 | assert(cm.getFalseNegatives.toLong === 1L) 52 | assert(cm.getFalsePositives.toLong === 1L) 53 | assert(cm.getTrueNegatives.toLong === 2L) 54 | assert(cm.getTruePositives.toLong === 3L) 55 | 56 | val plotCm = evalResult.plots.get.plotData.getPlotData.getConfusionMatrixAtThresholds 57 | .getMatrices(0) 58 | 59 | assert(plotCm.getFalseNegatives.toLong === 1L) 60 | assert(plotCm.getFalsePositives.toLong === 1L) 61 | assert(plotCm.getTrueNegatives.toLong === 2L) 62 | assert(plotCm.getTruePositives.toLong === 3L) 63 | } 64 | 65 | it should "work with BinaryConfusionMatrix" in { 66 | val data = List( 67 | (false, 0.1), 68 | (false, 0.6), 69 | (false, 0.2), 70 | (true, 0.2), 71 | (true, 0.8), 72 | (true, 0.7), 73 | (true, 0.6) 74 | ).map { case (pred, s) => Prediction(pred, s) } 75 | 76 | val evalResult = 
BinaryConfusionMatrix().asTfmaProto(data) 77 | val cmProto = evalResult.metrics.get 78 | 79 | val cm = cmProto.getMetricsMap 80 | .get("Noether_ConfusionMatrix") 81 | .getConfusionMatrixAtThresholds 82 | .getMatrices(0) 83 | 84 | assert(cm.getThreshold === 0.5) 85 | assert(cm.getTruePositives.toLong === 3L) 86 | assert(cm.getFalseNegatives.toLong === 1L) 87 | assert(cm.getFalsePositives.toLong === 1L) 88 | assert(cm.getTrueNegatives.toLong === 2L) 89 | 90 | val plotCm = evalResult.plots.get.plotData.getPlotData.getConfusionMatrixAtThresholds 91 | .getMatrices(0) 92 | 93 | assert(plotCm.getFalseNegatives.toLong === 1L) 94 | assert(plotCm.getFalsePositives.toLong === 1L) 95 | assert(plotCm.getTrueNegatives.toLong === 2L) 96 | assert(plotCm.getTruePositives.toLong === 3L) 97 | } 98 | 99 | it should "work with ClassificationReport" in { 100 | val data = List( 101 | (0.1, false), 102 | (0.1, true), 103 | (0.4, false), 104 | (0.6, false), 105 | (0.6, true), 106 | (0.6, true), 107 | (0.8, true) 108 | ).map { case (s, pred) => Prediction(pred, s) } 109 | 110 | val metrics = ClassificationReport().asTfmaProto(data).metrics 111 | 112 | def assertMetric(name: String, expected: Double): Assertion = 113 | assert(metrics.get.getMetricsMap.get(name).getDoubleValue.getValue === expected) 114 | 115 | assertMetric("Noether_Accuracy", 0.7142857142857143) 116 | assertMetric("Noether_FPR", 0.333) 117 | assertMetric("Noether_FScore", 0.75) 118 | assertMetric("Noether_MCC", 0.4166666666666667) 119 | assertMetric("Noether_Precision", 0.75) 120 | assertMetric("Noether_Recall", 0.75) 121 | } 122 | 123 | it should "work with ErrorRateSummary" in { 124 | val classes = 10 125 | def s(idx: Int): List[Double] = 0.until(classes).map(i => if (i == idx) 1.0 else 0.0).toList 126 | 127 | val data = 128 | List((s(1), 1), (s(3), 1), (s(5), 5), (s(2), 3), (s(0), 0), (s(8), 1)).map { 129 | case (scores, label) => Prediction(label, scores) 130 | } 131 | 132 | val ersProto: MetricsForSlice = 
ErrorRateSummary.asTfmaProto(data).metrics.get 133 | 134 | val ersV = 135 | ersProto.getMetricsMap.get("Noether_ErrorRateSummary").getDoubleValue.getValue 136 | assert(ersV === 0.5) 137 | } 138 | 139 | it should "work with AUC" in { 140 | val data = List( 141 | (false, 0.1), 142 | (false, 0.6), 143 | (false, 0.2), 144 | (true, 0.2), 145 | (true, 0.8), 146 | (true, 0.7), 147 | (true, 0.6) 148 | ).map { case (pred, s) => Prediction(pred, s) } 149 | 150 | val aucROCProto = AUC(ROC).asTfmaProto(data).metrics.get 151 | val aucPRProto = AUC(PR).asTfmaProto(data).metrics.get 152 | 153 | val actualROC = aucROCProto.getMetricsMap.get("Noether_AUC:ROC").getDoubleValue.getValue 154 | val actualPR = aucPRProto.getMetricsMap.get("Noether_AUC:PR").getDoubleValue.getValue 155 | 156 | assert(actualROC === 0.833) 157 | assert(actualPR === 0.896) 158 | } 159 | 160 | it should "work with LogLoss" in { 161 | val classes = 10 162 | 163 | def s(idx: Int, score: Double): List[Double] = 164 | 0.until(classes).map(i => if (i == idx) score else 0.0).toList 165 | 166 | val data = List((s(0, 0.8), 0), (s(1, 0.6), 1), (s(2, 0.7), 2)).map { case (scores, label) => 167 | Prediction(label, scores) 168 | } 169 | 170 | val logLossProto: MetricsForSlice = LogLoss.asTfmaProto(data).metrics.get 171 | 172 | val logLoss = logLossProto.getMetricsMap.get("Noether_LogLoss").getDoubleValue.getValue 173 | assert(logLoss === 0.363548039673) 174 | } 175 | 176 | it should "work with MeanAveragePrecision" in { 177 | import RankingData._ 178 | val proto = MeanAveragePrecision[Int]().asTfmaProto(rankingData).metrics.get 179 | val meanAvgPrecision = proto.getMetricsMap 180 | .get("Noether_MeanAvgPrecision") 181 | .getDoubleValue 182 | .getValue 183 | assert(meanAvgPrecision === 0.355026) 184 | } 185 | 186 | it should "work with NdcgAtK" in { 187 | import RankingData._ 188 | implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1) 189 | 190 | def getNdcgAtK(v: Int): Double = 191 | 
NdcgAtK[Int](v) 192 | .asTfmaProto(rankingData) 193 | .metrics 194 | .get 195 | .getMetricsMap 196 | .get("Noether_NdcgAtK") 197 | .getDoubleValue 198 | .getValue 199 | 200 | assert(getNdcgAtK(3) === 1.0 / 3) 201 | assert(getNdcgAtK(5) === 0.328788) 202 | assert(getNdcgAtK(10) === 0.487913) 203 | assert(getNdcgAtK(15) === getNdcgAtK(10)) 204 | } 205 | 206 | it should "work with PrecisionAtK" in { 207 | import RankingData._ 208 | implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1) 209 | 210 | def getPrecisionAtK(v: Int): Double = 211 | PrecisionAtK[Int](v) 212 | .asTfmaProto(rankingData) 213 | .metrics 214 | .get 215 | .getMetricsMap 216 | .get("Noether_PrecisionAtK") 217 | .getDoubleValue 218 | .getValue 219 | 220 | assert(getPrecisionAtK(1) === 1.0 / 3) 221 | assert(getPrecisionAtK(2) === 1.0 / 3) 222 | assert(getPrecisionAtK(3) === 1.0 / 3) 223 | assert(getPrecisionAtK(4) === 0.75 / 3) 224 | assert(getPrecisionAtK(5) === 0.8 / 3) 225 | assert(getPrecisionAtK(10) === 0.8 / 3) 226 | assert(getPrecisionAtK(15) === 8.0 / 45) 227 | } 228 | 229 | it should "work with CalibrationHistogram" in { 230 | val data = Seq( 231 | (0.15, 1.15), // lb 232 | (0.288, 1.288), // rounding error puts this in (0.249, 0.288) 233 | (0.30, 1.30), // (0.288, 0.3269) 234 | (0.36, 1.36), // (0.3269, 0.365) 235 | (0.555, 1.555), // (0.5219, 0.5609) 236 | (1.2, 2.2), // ub 237 | (0.7, 1.7) // ub 238 | ).map { case (p, a) => Prediction(a, p) } 239 | 240 | val result = CalibrationHistogram(0.21, 0.60, 10).asTfmaProto(data) 241 | 242 | def protoToCaseClass(p: CalibrationHistogramBuckets.Bucket): CalibrationHistogramBucket = { 243 | CalibrationHistogramBucket( 244 | p.getLowerThresholdInclusive, 245 | p.getUpperThresholdExclusive, 246 | p.getNumWeightedExamples.getValue, 247 | p.getTotalWeightedLabel.getValue, 248 | p.getTotalWeightedRefinedPrediction.getValue 249 | ) 250 | } 251 | 252 | val actual = 253 | 
result.plots.get.plotData.getPlotData.getCalibrationHistogramBuckets.getBucketsList.asScala 254 | .map(protoToCaseClass) 255 | 256 | val expected = List( 257 | CalibrationHistogramBucket(Double.NegativeInfinity, 0.21, 1.0, 1.15, 0.15), 258 | CalibrationHistogramBucket(0.21, 0.249, 0.0, 0.0, 0.0), 259 | CalibrationHistogramBucket(0.249, 0.288, 1.0, 1.288, 0.288), 260 | CalibrationHistogramBucket(0.288, 0.327, 1.0, 1.30, 0.30), 261 | CalibrationHistogramBucket(0.327, 0.366, 1.0, 1.36, 0.36), 262 | CalibrationHistogramBucket(0.366, 0.405, 0.0, 0.0, 0.0), 263 | CalibrationHistogramBucket(0.405, 0.444, 0.0, 0.0, 0.0), 264 | CalibrationHistogramBucket(0.444, 0.483, 0.0, 0.0, 0.0), 265 | CalibrationHistogramBucket(0.483, 0.522, 0.0, 0.0, 0.0), 266 | CalibrationHistogramBucket(0.522, 0.561, 1.0, 1.555, 0.555), 267 | CalibrationHistogramBucket(0.561, 0.6, 0.0, 0.0, 0.0), 268 | CalibrationHistogramBucket(0.6, Double.PositiveInfinity, 2.0, 3.9, 1.9) 269 | ) 270 | 271 | assert(actual.length == expected.length) 272 | (0 until expected.length).foreach { i => 273 | assert(actual(i).numPredictions === expected(i).numPredictions) 274 | assert(actual(i).sumPredictions === expected(i).sumPredictions) 275 | assert(actual(i).sumLabels === expected(i).sumLabels) 276 | assert(actual(i).lowerThresholdInclusive === expected(i).lowerThresholdInclusive) 277 | assert(actual(i).upperThresholdExclusive === expected(i).upperThresholdExclusive) 278 | } 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /tfx/src/main/scala/com/spotify/noether/tfx/TfmaImplicits.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2018 Spotify AB. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, 11 | * software distributed under the License is distributed on an 12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | * KIND, either express or implied. See the License for the 14 | * specific language governing permissions and limitations 15 | * under the License. 
*/

package com.spotify.noether.tfx

import breeze.linalg.DenseMatrix
import com.google.protobuf.DoubleValue
import com.spotify.noether._
import com.spotify.noether.tfx.Tfma.ConversionOps
import com.twitter.algebird.Aggregator
import tensorflow_model_analysis.MetricsForSliceOuterClass.ConfusionMatrixAtThresholds._
import tensorflow_model_analysis.MetricsForSliceOuterClass._

import scala.collection.JavaConverters._

/**
 * Implicit glue between Noether aggregators and TensorFlow Model Analysis protobufs.
 *
 * The trait contains three groups of members:
 *   - private helpers that assemble `MetricsForSlice` / `PlotsForSlice` messages,
 *   - implicit conversions that wrap an aggregator and its converter in a [[ConversionOps]],
 *   - implicit [[TfmaConverter]] instances that post-process an aggregator's presented value
 *     into an `EvalResult`.
 *
 * Every message built here uses `SliceKey.getDefaultInstance` as its slice key.
 */
trait TfmaImplicits {

  // ---------------------------------------------------------------------------------------------
  // Protobuf builders
  // ---------------------------------------------------------------------------------------------

  /** Stores `cm` in a `MetricsForSlice` under the metric key `"Noether_ConfusionMatrix"`. */
  private def confusionMatrixToMetric(cm: ConfusionMatrixAtThresholds): MetricsForSlice = {
    val metricValue = MetricValue
      .newBuilder()
      .setConfusionMatrixAtThresholds(cm)
      .build()

    MetricsForSlice
      .newBuilder()
      .setSliceKey(SliceKey.getDefaultInstance)
      .putMetrics("Noether_ConfusionMatrix", metricValue)
      .build()
  }

  /**
   * Converts a confusion matrix into a `ConfusionMatrixAtThresholds` holding a single entry.
   *
   * Cells are read as: (1, 1) true positives, (0, 0) true negatives, (1, 0) false positives and
   * (0, 1) false negatives. Precision is `tp / (tp + fp)` and recall is `tp / (tp + fn)`; the
   * divisions are not guarded, so all-zero counts produce `NaN`.
   *
   * NOTE(review): presumably the first index is the predicted label and the second the actual
   * label; confirm against the aggregators that produce `matrix`.
   *
   * @param threshold
   *   decision threshold recorded on the entry; left unset when `None`
   * @param matrix
   *   count matrix; only the four cells listed above are read
   */
  private def denseMatrixToConfusionMatrix(
    threshold: Option[Double] = None
  )(matrix: DenseMatrix[Long]): ConfusionMatrixAtThresholds = {
    val truePositives = matrix.valueAt(1, 1).toDouble
    val trueNegatives = matrix.valueAt(0, 0).toDouble
    val falsePositives = matrix.valueAt(1, 0).toDouble
    val falseNegatives = matrix.valueAt(0, 1).toDouble

    val precision = truePositives / (truePositives + falsePositives)
    val recall = truePositives / (truePositives + falseNegatives)

    val atThreshold = ConfusionMatrixAtThreshold
      .newBuilder()
      .setTruePositives(truePositives)
      .setTrueNegatives(trueNegatives)
      .setFalsePositives(falsePositives)
      .setFalseNegatives(falseNegatives)
      .setPrecision(precision)
      .setRecall(recall)

    // The threshold field is only populated when the caller supplied one.
    threshold.foreach(t => atThreshold.setThreshold(t))

    ConfusionMatrixAtThresholds
      .newBuilder()
      .addMatrices(atThreshold.build())
      .build()
  }

  /** Builds a `MetricsForSlice` holding one double-valued metric called `name`. */
  private def buildDoubleMetric(name: String, value: Double): MetricsForSlice = {
    val metricValue = MetricValue
      .newBuilder()
      .setDoubleValue(mkDoubleValue(value))
      .build()

    MetricsForSlice
      .newBuilder()
      .setSliceKey(SliceKey.getDefaultInstance)
      .putMetrics(name, metricValue)
      .build()
  }

  /** Builds a `MetricsForSlice` with one double-valued metric per entry of `metrics`. */
  private def buildDoubleMetrics(metrics: Map[String, Double]): MetricsForSlice = {
    val protoMetrics: Map[String, MetricValue] = metrics.map { case (name, value) =>
      name -> MetricValue.newBuilder().setDoubleValue(mkDoubleValue(value)).build()
    }

    MetricsForSlice
      .newBuilder()
      .setSliceKey(SliceKey.getDefaultInstance)
      .putAllMetrics(protoMetrics.asJava)
      .build()
  }

  /** Wraps `cm` as the plot data of a `PlotsForSlice`. */
  private def buildConfusionMatrixPlot(cm: ConfusionMatrixAtThresholds): PlotsForSlice = {
    val plotData = PlotData
      .newBuilder()
      .setConfusionMatrixAtThresholds(cm)
      .build()

    PlotsForSlice
      .newBuilder()
      .setSliceKey(SliceKey.getDefaultInstance)
      .setPlotData(plotData)
      .build()
  }

  /** Boxes a primitive double into a protobuf `DoubleValue`. */
  private def mkDoubleValue(d: Double): DoubleValue =
    DoubleValue.newBuilder().setValue(d).build()

  /**
   * Converts calibration histogram buckets into the plot data of a `PlotsForSlice`.
   *
   * Field mapping per bucket: `sumPredictions` -> total weighted refined prediction,
   * `sumLabels` -> total weighted label, `numPredictions` -> number of weighted examples. The
   * lower/upper thresholds are copied as they are. Bucket order follows the order of `ch`.
   */
  private def buildCalibrationHistogramPlot(ch: List[CalibrationHistogramBucket]): PlotsForSlice = {
    def toProtoBucket(bucket: CalibrationHistogramBucket): CalibrationHistogramBuckets.Bucket =
      CalibrationHistogramBuckets.Bucket
        .newBuilder()
        .setLowerThresholdInclusive(bucket.lowerThresholdInclusive)
        .setUpperThresholdExclusive(bucket.upperThresholdExclusive)
        .setTotalWeightedRefinedPrediction(mkDoubleValue(bucket.sumPredictions))
        .setTotalWeightedLabel(mkDoubleValue(bucket.sumLabels))
        .setNumWeightedExamples(mkDoubleValue(bucket.numPredictions))
        .build()

    val protoBuckets = CalibrationHistogramBuckets
      .newBuilder()
      .addAllBuckets(ch.map(toProtoBucket).asJava)

    val plotData = PlotData
      .newBuilder()
      .setCalibrationHistogramBuckets(protoBuckets)
      .build()

    PlotsForSlice
      .newBuilder()
      .setSliceKey(SliceKey.getDefaultInstance)
      .setPlotData(plotData)
      .build()
  }

  // ---------------------------------------------------------------------------------------------
  // Implicit conversions: aggregator -> ConversionOps
  //
  // Each conversion pairs the aggregator with the TfmaConverter found in implicit scope.
  // ---------------------------------------------------------------------------------------------

  /** Wraps a [[ConfusionMatrix]] and its converter in a [[ConversionOps]]. */
  implicit def confusionMatrixConversion(agg: ConfusionMatrix)(implicit
    c: TfmaConverter[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix]
  ): ConversionOps[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix] =
    ConversionOps[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix](agg, c)

  /** Wraps a [[CalibrationHistogram]] and its converter in a [[ConversionOps]]. */
  implicit def calibrationHistogramConversion(agg: CalibrationHistogram)(implicit
    c: TfmaConverter[
      Prediction[Double, Double],
      Map[Double, (Double, Double, Long)],
      CalibrationHistogram
    ]
  ): ConversionOps[
    Prediction[Double, Double],
    Map[Double, (Double, Double, Long)],
    CalibrationHistogram
  ] =
    ConversionOps[
      Prediction[Double, Double],
      Map[Double, (Double, Double, Long)],
      CalibrationHistogram
    ](agg, c)

  /** Wraps a [[BinaryConfusionMatrix]] and its converter in a [[ConversionOps]]. */
  implicit def binaryConfusionMatrixConversion(agg: BinaryConfusionMatrix)(implicit
    c: TfmaConverter[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix]
  ): ConversionOps[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix] =
    ConversionOps[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix](agg, c)

  /** Wraps a [[ClassificationReport]] and its converter in a [[ConversionOps]]. */
  implicit def classificationReportConversion(agg: ClassificationReport)(implicit
    c: TfmaConverter[BinaryPred, Map[(Int, Int), Long], ClassificationReport]
  ): ConversionOps[BinaryPred, Map[(Int, Int), Long], ClassificationReport] =
    ConversionOps[BinaryPred, Map[(Int, Int), Long], ClassificationReport](agg, c)

  /** Wraps an [[AUC]] aggregator and its converter in a [[ConversionOps]]. */
  implicit def aucConversion(agg: AUC)(implicit
    c: TfmaConverter[BinaryPred, MetricCurve, AUC]
  ): ConversionOps[BinaryPred, MetricCurve, AUC] =
    ConversionOps[BinaryPred, MetricCurve, AUC](agg, c)

  /** Wraps the [[ErrorRateSummary]] object and its converter in a [[ConversionOps]]. */
  implicit def errorRateSummaryConversion(agg: ErrorRateSummary.type)(implicit
    c: TfmaConverter[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type]
  ): ConversionOps[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type] =
    ConversionOps[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type](agg, c)

  /** Wraps the [[LogLoss]] object and its converter in a [[ConversionOps]]. */
  implicit def logLossConversion(agg: LogLoss.type)(implicit
    c: TfmaConverter[Prediction[Int, List[Double]], (Double, Long), LogLoss.type]
  ): ConversionOps[Prediction[Int, List[Double]], (Double, Long), LogLoss.type] =
    ConversionOps[Prediction[Int, List[Double]], (Double, Long), LogLoss.type](agg, c)

  /** Wraps a [[MeanAveragePrecision]] aggregator and its converter in a [[ConversionOps]]. */
  implicit def meanAvgPrecisionConversion[T](agg: MeanAveragePrecision[T])(implicit
    c: TfmaConverter[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]]
  ): ConversionOps[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]] =
    ConversionOps[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]](agg, c)

  /** Wraps an [[NdcgAtK]] aggregator and its converter in a [[ConversionOps]]. */
  implicit def ndcgAtKConversion[T](agg: NdcgAtK[T])(implicit
    c: TfmaConverter[RankingPrediction[T], (Double, Long), NdcgAtK[T]]
  ): ConversionOps[RankingPrediction[T], (Double, Long), NdcgAtK[T]] =
    ConversionOps[RankingPrediction[T], (Double, Long), NdcgAtK[T]](agg, c)

  /** Wraps a [[PrecisionAtK]] aggregator and its converter in a [[ConversionOps]]. */
  implicit def precisionAtKConversion[T](agg: PrecisionAtK[T])(implicit
    c: TfmaConverter[RankingPrediction[T], (Double, Long), PrecisionAtK[T]]
  ): ConversionOps[RankingPrediction[T], (Double, Long), PrecisionAtK[T]] =
    ConversionOps[RankingPrediction[T], (Double, Long), PrecisionAtK[T]](agg, c)

  // ---------------------------------------------------------------------------------------------
  // TfmaConverter instances
  //
  // Each instance appends a presentation step to the aggregator that turns its result into an
  // EvalResult.
  // ---------------------------------------------------------------------------------------------

  /** Reports the error rate as the double metric `"Noether_ErrorRateSummary"`. */
  implicit val errorRateSummaryConverter
    : TfmaConverter[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type] =
    new TfmaConverter[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type] {
      override def convertToTfmaProto(
        underlying: ErrorRateSummary.type
      ): Aggregator[Prediction[Int, List[Double]], (Double, Long), EvalResult] =
        ErrorRateSummary.andThenPresent { errorRate =>
          EvalResult(buildDoubleMetric("Noether_ErrorRateSummary", errorRate))
        }
    }

  /**
   * Reports a binary confusion matrix both as a metric and as a confusion-matrix plot, recording
   * the aggregator's `threshold` on the matrix entry.
   */
  implicit val binaryConfusionMatrixConverter
    : TfmaConverter[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix] =
    new TfmaConverter[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix] {
      override def convertToTfmaProto(
        underlying: BinaryConfusionMatrix
      ): Aggregator[BinaryPred, Map[(Int, Int), Long], EvalResult] = {
        // Read the threshold once, up front, rather than on every presentation.
        val threshold = Some(underlying.threshold)
        underlying.andThenPresent { matrix =>
          val cm = denseMatrixToConfusionMatrix(threshold)(matrix)
          val metrics = confusionMatrixToMetric(cm)
          val plots = buildConfusionMatrixPlot(cm)
          EvalResult(metrics, Plot.ConfusionMatrix(plots))
        }
      }
    }

  /**
   * Reports a confusion matrix both as a metric and as a confusion-matrix plot; no threshold is
   * recorded on the matrix entry.
   */
  implicit val confusionMatrixConverter
    : TfmaConverter[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix] =
    new TfmaConverter[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix] {
      override def convertToTfmaProto(
        underlying: ConfusionMatrix
      ): Aggregator[Prediction[Int, Int], Map[(Int, Int), Long], EvalResult] =
        underlying.andThenPresent { matrix =>
          val cm = denseMatrixToConfusionMatrix()(matrix)
          val metrics = confusionMatrixToMetric(cm)
          val plots = buildConfusionMatrixPlot(cm)
          EvalResult(metrics, Plot.ConfusionMatrix(plots))
        }
    }

  /**
   * Reports the fields of a classification report as six double metrics: `Noether_Accuracy`,
   * `Noether_FPR`, `Noether_FScore`, `Noether_MCC`, `Noether_Precision` and `Noether_Recall`.
   */
  implicit val classificationReportConverter
    : TfmaConverter[BinaryPred, Map[(Int, Int), Long], ClassificationReport] =
    new TfmaConverter[BinaryPred, Map[(Int, Int), Long], ClassificationReport] {
      override def convertToTfmaProto(
        underlying: ClassificationReport
      ): Aggregator[BinaryPred, Map[(Int, Int), Long], EvalResult] =
        underlying.andThenPresent { report =>
          val byName = Map(
            "Noether_Accuracy" -> report.accuracy,
            "Noether_FPR" -> report.fpr,
            "Noether_FScore" -> report.fscore,
            "Noether_MCC" -> report.mcc,
            "Noether_Precision" -> report.precision,
            "Noether_Recall" -> report.recall
          )
          EvalResult(buildDoubleMetrics(byName))
        }
    }

  /**
   * Reports the area under the curve as a double metric named `"Noether_AUC:ROC"` or
   * `"Noether_AUC:PR"`, depending on the aggregator's `metric`.
   */
  implicit val aucConverter: TfmaConverter[BinaryPred, MetricCurve, AUC] =
    new TfmaConverter[BinaryPred, MetricCurve, AUC] {
      override def convertToTfmaProto(
        underlying: AUC
      ): Aggregator[BinaryPred, MetricCurve, EvalResult] =
        underlying.andThenPresent { area =>
          val name = underlying.metric match {
            case ROC => "Noether_AUC:ROC"
            case PR  => "Noether_AUC:PR"
          }
          EvalResult(buildDoubleMetric(name, area))
        }
    }

  /** Reports the log loss as the double metric `"Noether_LogLoss"`. */
  implicit val logLossConverter
    : TfmaConverter[Prediction[Int, List[Double]], (Double, Long), LogLoss.type] =
    new TfmaConverter[Prediction[Int, List[Double]], (Double, Long), LogLoss.type] {
      override def convertToTfmaProto(
        underlying: LogLoss.type
      ): Aggregator[Prediction[Int, List[Double]], (Double, Long), EvalResult] =
        underlying.andThenPresent { loss =>
          EvalResult(buildDoubleMetric("Noether_LogLoss", loss))
        }
    }

  /** Reports mean average precision as the double metric `"Noether_MeanAvgPrecision"`. */
  implicit def meanAvgPrecisionConverter[T]
    : TfmaConverter[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]] =
    new TfmaConverter[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]] {
      override def convertToTfmaProto(
        underlying: MeanAveragePrecision[T]
      ): Aggregator[RankingPrediction[T], (Double, Long), EvalResult] =
        underlying.andThenPresent { score =>
          EvalResult(buildDoubleMetric("Noether_MeanAvgPrecision", score))
        }
    }

  /** Reports NDCG at K as the double metric `"Noether_NdcgAtK"`. */
  implicit def ndcgAtKConverter[T]
    : TfmaConverter[RankingPrediction[T], (Double, Long), NdcgAtK[T]] =
    new TfmaConverter[RankingPrediction[T], (Double, Long), NdcgAtK[T]] {
      override def convertToTfmaProto(
        underlying: NdcgAtK[T]
      ): Aggregator[RankingPrediction[T], (Double, Long), EvalResult] =
        underlying.andThenPresent { score =>
          EvalResult(buildDoubleMetric("Noether_NdcgAtK", score))
        }
    }

  /** Reports precision at K as the double metric `"Noether_PrecisionAtK"`. */
  implicit def precisionAtK[T]
    : TfmaConverter[RankingPrediction[T], (Double, Long), PrecisionAtK[T]] =
    new TfmaConverter[RankingPrediction[T], (Double, Long), PrecisionAtK[T]] {
      override def convertToTfmaProto(
        underlying: PrecisionAtK[T]
      ): Aggregator[RankingPrediction[T], (Double, Long), EvalResult] =
        underlying.andThenPresent { score =>
          EvalResult(buildDoubleMetric("Noether_PrecisionAtK", score))
        }
    }

  /** Reports the calibration histogram as a plot only; no metrics are attached. */
  implicit def calibrationHistogram: TfmaConverter[
    Prediction[Double, Double],
    Map[Double, (Double, Double, Long)],
    CalibrationHistogram
  ] =
    new TfmaConverter[
      Prediction[Double, Double],
      Map[Double, (Double, Double, Long)],
      CalibrationHistogram
    ] {
      override def convertToTfmaProto(
        underlying: CalibrationHistogram
      ): Aggregator[Prediction[Double, Double], Map[Double, (Double, Double, Long)], EvalResult] =
        underlying.andThenPresent { buckets =>
          EvalResult(Plot.CalibrationHistogram(buildCalibrationHistogramPlot(buckets)))
        }
    }
}
--------------------------------------------------------------------------------