├── NOTICE
├── project
├── build.properties
└── plugins.sbt
├── catalog-info.yaml
├── .github
├── dependabot.yml
└── workflows
│ ├── release.yml
│ └── ci.yml
├── .gitignore
├── .scalafix.conf
├── .scalafmt.conf
├── core
└── src
│ ├── main
│ └── scala
│ │ └── com
│ │ └── spotify
│ │ └── noether
│ │ ├── package.scala
│ │ ├── Prediction.scala
│ │ ├── LogLoss.scala
│ │ ├── ErrorRateSummary.scala
│ │ ├── BinaryConfusionMatrix.scala
│ │ ├── ConfusionMatrix.scala
│ │ ├── MeanAveragePrecision.scala
│ │ ├── MultiAggregatorMap.scala
│ │ ├── PrecisionAtK.scala
│ │ ├── NdcgAtK.scala
│ │ ├── CalibrationHistogram.scala
│ │ ├── AUC.scala
│ │ └── ClassificationReport.scala
│ └── test
│ └── scala
│ └── com
│ └── spotify
│ └── noether
│ ├── MeanAveragePrecisionTest.scala
│ ├── RankingData.scala
│ ├── BinaryConfusionMatrixTest.scala
│ ├── LogLossTest.scala
│ ├── ErrorRateSummaryTest.scala
│ ├── NdcgAtKTest.scala
│ ├── PrecisionAtKTest.scala
│ ├── AggregatorTest.scala
│ ├── ConfusionMatrixTest.scala
│ ├── AUCTest.scala
│ ├── ClassificationReportTest.scala
│ └── CalibrationHistogramTest.scala
├── tfx
└── src
│ ├── main
│ ├── scala
│ │ └── com
│ │ │ └── spotify
│ │ │ └── noether
│ │ │ └── tfx
│ │ │ ├── package.scala
│ │ │ ├── Tfma.scala
│ │ │ ├── TfmaConverter.scala
│ │ │ └── TfmaImplicits.scala
│ └── protobuf
│ │ └── metrics_for_slice.proto
│ └── test
│ └── scala
│ └── com
│ └── spotify
│ └── noether
│ └── tfx
│ └── TfmaConverterTest.scala
├── examples
└── src
│ ├── test
│ └── scala
│ │ └── com
│ │ └── spotify
│ │ └── noether
│ │ ├── AggregatorExampleTest.scala
│ │ └── MultiAggregatorMapTests.scala
│ └── main
│ └── scala
│ └── com
│ └── spotify
│ └── noether
│ └── AggregatorExample.scala
├── benchmark
├── src
│ └── main
│ │ └── scala
│ │ └── com
│ │ └── spotify
│ │ └── noether
│ │ └── benchmark
│ │ └── CalibrationHistogramCreateBenchmark.scala
└── README.md
├── README.md
└── LICENSE
/NOTICE:
--------------------------------------------------------------------------------
1 | Noether
2 | Copyright 2018 Spotify AB
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.8.2
2 |
--------------------------------------------------------------------------------
/catalog-info.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: backstage.io/v1alpha1
2 | kind: Resource
3 | metadata:
4 |   name: noether
5 | spec:
6 |   type: resource
7 |   owner: flatmap
8 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   - package-ecosystem: github-actions
4 |     directory: "/"
5 |     schedule:
6 |       interval: daily
7 |       time: "04:00"
8 |     open-pull-requests-limit: 10
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .bigquery
3 | .repl
4 | target
5 | editorial
6 | top
7 | data
8 | normed
9 | tracks
10 | .DS_Store
11 | *.pyc
12 | *.spl
13 | *.spi
14 | local_data
15 | *.iml
16 |
17 | .metals
18 | .bloop
19 | .bsp
20 |
--------------------------------------------------------------------------------
/.scalafix.conf:
--------------------------------------------------------------------------------
1 | rules = [
2 | RemoveUnused,
3 | LeakingImplicitClassVal
4 | ProcedureSyntax
5 | ExplicitResultTypes
6 | ]
7 |
8 | ExplicitResultTypes.memberKind = [Def, Val, Var]
9 | ExplicitResultTypes.memberVisibility = [Public]
10 | ExplicitResultTypes.skipSimpleDefinition = false
11 |
--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | version = "3.7.4"
2 | maxColumn = 100
3 | runner.dialect = scala213
4 |
5 | binPack.literalArgumentLists = true
6 |
7 | continuationIndent {
8 | callSite = 2
9 | defnSite = 2
10 | }
11 |
12 | newlines {
13 | alwaysBeforeMultilineDef = false
14 | sometimesBeforeColonInMethodReturnType = true
15 | }
16 |
17 | docstrings.style = Asterisk
18 |
19 | project.git = false
20 |
21 | rewrite {
22 | rules = [PreferCurlyFors, RedundantBraces, RedundantParens, SortImports, SortModifiers]
23 | redundantBraces.generalExpressions = false
24 | redundantBraces.maxLines = 1
25 | }
26 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 |   push:
4 |     branches: [main]
5 |     tags: ["*"]
6 | jobs:
7 |   publish:
8 |     runs-on: ubuntu-latest
9 |     steps:
10 |       - uses: actions/checkout@v3
11 |       - uses: actions/setup-java@v3
12 |         with:
13 |           distribution: temurin
14 |           java-version: 17
15 |           cache: sbt
16 |       - uses: olafurpg/setup-gpg@v3
17 |       - name: Publish ${{ github.ref }}
18 |         run: sbt ci-release
19 |         env:
20 |           PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
21 |           PGP_SECRET: ${{ secrets.PGP_SECRET }}
22 |           SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
23 |           SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
24 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify
19 |
20 | package object noether {
21 |
22 |   /**
23 |    * Alias for a [[Prediction]] whose `actual` and `predicted` values are both arrays of `T`.
24 |    */
25 |   type RankingPrediction[T] = Prediction[Array[T], Array[T]]
26 | }
27 |
--------------------------------------------------------------------------------
/tfx/src/main/scala/com/spotify/noether/tfx/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | package object tfx extends TfmaImplicits {
21 |
22 |   /**
23 |    * Alias for a [[Prediction]] with a `Boolean` actual value and a `Double` predicted value.
24 |    */
25 |   type BinaryPred = Prediction[Boolean, Double]
26 | }
27 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0")
2 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.8")
3 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.0")
4 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.5")
5 | addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "1.1.2")
6 | addSbtPlugin("com.github.sbt" % "sbt-ghpages" % "0.7.0")
7 | addSbtPlugin("com.typesafe.sbt" % "sbt-site" % "1.4.1")
8 | addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.6")
9 | addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.0")
10 | addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12")
11 | addSbtPlugin("ch.epfl.lamp" % "sbt-dotty" % "0.5.5")
12 |
13 | libraryDependencies ++= Seq(
14 | "com.thesamet.scalapb" %% "compilerplugin" % "0.11.13"
15 | )
16 |
17 | dependencyOverrides += "org.scala-lang.modules" %% "scala-xml" % "2.1.0"
18 |
--------------------------------------------------------------------------------
/examples/src/test/scala/com/spotify/noether/AggregatorExampleTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalatest.flatspec.AnyFlatSpec
21 | import org.scalatest.matchers.should.Matchers
22 |
23 | /** Smoke test: runs the example entry point with no arguments and expects no exception. */
24 | class AggregatorExampleTest extends AnyFlatSpec with Matchers {
25 |   it should "not fail when executing example" in {
26 |     val noArgs = Array.empty[String]
27 |     AggregatorExample.main(noArgs)
28 |   }
29 | }
30 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/MeanAveragePrecisionTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.{Equality, TolerantNumerics}
21 |
22 | /** Checks [[MeanAveragePrecision]] against the shared ranking fixture. */
23 | class MeanAveragePrecisionTest extends AggregatorTest {
24 |   // Doubles are compared with an absolute tolerance of 0.1.
25 |   implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)
26 |
27 |   it should "compute map for rankings" in {
28 |     val meanAveragePrecision = run(MeanAveragePrecision[Int]())(RankingData.rankingData)
29 |     assert(meanAveragePrecision === 0.355026)
30 |   }
31 | }
32 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/Prediction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | /**
21 |  * Pairs the observed value of an entry with the value a model produced for it. Used by most
22 |  * aggregators.
23 |  *
24 |  * @param actual
25 |  *   Real value for this entry, commonly called the label.
26 |  * @param predicted
27 |  *   Value produced by the model; a class or a score, depending on the aggregator.
28 |  * @tparam L
29 |  *   Type of the real value
30 |  * @tparam S
31 |  *   Type of the predicted value
32 |  */
33 | final case class Prediction[L, S](actual: L, predicted: S) {
34 |   // Renders as "<actual>,<predicted>".
35 |   override def toString: String = productIterator.mkString(",")
36 | }
37 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/RankingData.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | /** Shared ranking fixtures used by the ranking-metric tests. */
21 | object RankingData {
22 |
23 |   /** Three predictions; the last one has an empty `actual` array. */
24 |   def rankingData: Seq[RankingPrediction[Int]] = {
25 |     val first = Prediction((1 to 5).toArray, Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5))
26 |     val second = Prediction((1 to 3).toArray, Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10))
27 |     val withoutActuals = Prediction(Array.empty[Int], (1 to 5).toArray)
28 |     Seq(first, second, withoutActuals)
29 |   }
30 |
31 |   /** Two predictions with short `predicted` arrays; the second one is empty. */
32 |   def smallRankingData: Seq[RankingPrediction[Int]] = {
33 |     val shortPrediction = Prediction((1 to 5).toArray, Array(1, 6, 2))
34 |     val withoutPredictions = Prediction((1 to 3).toArray, Array.empty[Int])
35 |     Seq(shortPrediction, withoutPredictions)
36 |   }
37 | }
38 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/LogLoss.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 |
22 | /**
23 |  * LogLoss of the predictions.
24 |  *
25 |  * `predicted` holds one score per class and `actual` is used as the index into that list; the
26 |  * natural log of the selected score is what gets aggregated.
27 |  */
28 | case object LogLoss extends Aggregator[Prediction[Int, List[Double]], (Double, Long), Double] {
29 |
30 |   /** Emits `(ln(score at index actual), 1)` for a single prediction. */
31 |   def prepare(input: Prediction[Int, List[Double]]): (Double, Long) = {
32 |     val actualClassScore = input.predicted(input.actual)
33 |     (math.log(actualClassScore), 1L)
34 |   }
35 |
36 |   /** Implicitly resolved algebird semigroup for `(Double, Long)` pairs. */
37 |   def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]
38 |
39 |   /** Negates the first component divided by the second. */
40 |   def present(score: (Double, Long)): Double = {
41 |     val (logScores, entries) = score
42 |     -1 * (logScores / entries)
43 |   }
44 | }
45 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/BinaryConfusionMatrixTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | /** Verifies every cell of [[BinaryConfusionMatrix]] at its default threshold. */
21 | class BinaryConfusionMatrixTest extends AggregatorTest {
22 |   it should "return correct scores" in {
23 |     // Each entry is (actual label, predicted score).
24 |     val data = List(
25 |       Prediction(false, 0.1),
26 |       Prediction(false, 0.6),
27 |       Prediction(false, 0.2),
28 |       Prediction(true, 0.2),
29 |       Prediction(true, 0.8),
30 |       Prediction(true, 0.7),
31 |       Prediction(true, 0.6)
32 |     )
33 |
34 |     val matrix = run(BinaryConfusionMatrix())(data)
35 |
36 |     // Cells are addressed as (predicted, actual).
37 |     assert(matrix(1, 1) === 3L)
38 |     assert(matrix(0, 1) === 1L)
39 |     assert(matrix(1, 0) === 1L)
40 |     assert(matrix(0, 0) === 2L)
41 |   }
42 | }
43 |
--------------------------------------------------------------------------------
/examples/src/main/scala/com/spotify/noether/AggregatorExample.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.MultiAggregator
21 |
22 | /** Minimal example that combines several binary-classification aggregators. */
23 | object AggregatorExample {
24 |
25 |   /** Builds a composite aggregator, applies it to three predictions and discards the result. */
26 |   def main(args: Array[String]): Unit = {
27 |     val multiAggregator =
28 |       MultiAggregator((AUC(ROC), AUC(PR), ClassificationReport(), BinaryConfusionMatrix()))
29 |         .andThenPresent { case (rocAuc, prAuc, report, matrix) =>
30 |           // Matrix cells are addressed as (predicted, actual).
31 |           val bothPositive = matrix(1, 1)
32 |           val bothNegative = matrix(0, 0)
33 |           val (accuracy, recall, precision) = (report.accuracy, report.recall, report.precision)
34 |           (rocAuc, prAuc, accuracy, recall, precision, bothPositive, bothNegative)
35 |         }
36 |
37 |     val predictions = List((false, 0.1), (false, 0.6), (true, 0.9)).map {
38 |       case (label, score) => Prediction(label, score)
39 |     }
40 |
41 |     multiAggregator(predictions)
42 |     ()
43 |   }
44 | }
45 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/LogLossTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.{Equality, TolerantNumerics}
21 |
22 | /** Checks [[LogLoss]] on three predictions over ten classes. */
23 | class LogLossTest extends AggregatorTest {
24 |   implicit private val doubleEq: Equality[Double] =
25 |     TolerantNumerics.tolerantDoubleEquality(0.1)
26 |   private val classes = 10
27 |
28 |   // Score list of length `classes`: `score` at position `idx`, 0.0 everywhere else.
29 |   private def s(idx: Int, score: Double): List[Double] =
30 |     List.tabulate(classes)(i => if (i == idx) score else 0.0)
31 |
32 |   it should "return correct scores" in {
33 |     // (label, score given to that label)
34 |     val data = List(0 -> 0.8, 1 -> 0.6, 2 -> 0.7).map { case (label, score) =>
35 |       Prediction(label, s(label, score))
36 |     }
37 |
38 |     assert(run(LogLoss)(data) === 0.363548039673)
39 |   }
40 | }
41 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/ErrorRateSummary.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 |
22 | /**
23 |  * Measurement of what fraction of values were predicted incorrectly.
24 |  *
25 |  * The predicted class of an entry is the index of the highest score in `predicted`; it counts as
26 |  * an error when that index differs from `actual`.
27 |  */
28 | case object ErrorRateSummary
29 |     extends Aggregator[Prediction[Int, List[Double]], (Double, Long), Double] {
30 |
31 |   /** Emits `(1.0, 1)` when the top-scoring index differs from `actual`, else `(0.0, 1)`. */
32 |   def prepare(input: Prediction[Int, List[Double]]): (Double, Long) = {
33 |     val (_, predictedClass) = input.predicted.zipWithIndex.maxBy { case (score, _) => score }
34 |     val errors = if (predictedClass == input.actual) 0.0 else 1.0
35 |     (errors, 1L)
36 |   }
37 |
38 |   /** Implicitly resolved algebird semigroup for `(Double, Long)` pairs. */
39 |   def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]
40 |
41 |   /** Divides the first component (errors) by the second (entries). */
42 |   def present(score: (Double, Long)): Double = {
43 |     val (errors, entries) = score
44 |     errors / entries
45 |   }
46 | }
47 |
--------------------------------------------------------------------------------
/tfx/src/main/scala/com/spotify/noether/tfx/Tfma.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether.tfx
19 |
20 | import com.twitter.algebird.Aggregator
21 |
22 | /** Helpers for turning an aggregator into one that presents an [[EvalResult]]. */
23 | object Tfma {
24 |
25 |   /**
26 |    * Couples an aggregator with the [[TfmaConverter]] that can convert it.
27 |    *
28 |    * @tparam A
29 |    *   First type parameter of the wrapped `Aggregator`
30 |    * @tparam B
31 |    *   Second type parameter of the wrapped `Aggregator`
32 |    * @tparam T
33 |    *   Concrete aggregator type being converted
34 |    */
35 |   trait ConversionOps[A, B, T <: Aggregator[A, B, _]] {
36 |     val self: T
37 |     val converter: TfmaConverter[A, B, T]
38 |
39 |     /** Delegates to `converter.convertToTfmaProto(self)`. */
40 |     def asTfmaProto: Aggregator[A, B, EvalResult] = converter.convertToTfmaProto(self)
41 |   }
42 |
43 |   object ConversionOps {
44 |
45 |     /** Creates a [[ConversionOps]] holding `instance` and `tfmaConverter`. */
46 |     def apply[A, B, T <: Aggregator[A, B, _]](
47 |       instance: T,
48 |       tfmaConverter: TfmaConverter[A, B, T]
49 |     ): ConversionOps[A, B, T] = {
50 |       new ConversionOps[A, B, T] {
51 |         val self: T = instance
52 |         val converter: TfmaConverter[A, B, T] = tfmaConverter
53 |       }
54 |     }
55 |   }
56 | }
57 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/ErrorRateSummaryTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.TolerantNumerics
21 | import org.scalactic.Equality
22 |
23 | /** Checks [[ErrorRateSummary]] on six one-hot predictions over ten classes. */
24 | class ErrorRateSummaryTest extends AggregatorTest {
25 |   implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)
26 |   private val classes = 10
27 |
28 |   // One-hot score list of length `classes` with 1.0 at position `idx`.
29 |   private def s(idx: Int): List[Double] =
30 |     List.tabulate(classes)(i => if (i == idx) 1.0 else 0.0)
31 |
32 |   it should "return correct scores" in {
33 |     // (top-scoring class, actual label): three of the six pairs disagree.
34 |     val data = List(1 -> 1, 3 -> 1, 5 -> 5, 2 -> 3, 0 -> 0, 8 -> 1).map {
35 |       case (topClass, label) => Prediction(label, s(topClass))
36 |     }
37 |
38 |     assert(run(ErrorRateSummary)(data) === 0.5)
39 |   }
40 | }
41 |
--------------------------------------------------------------------------------
/examples/src/test/scala/com/spotify/noether/MultiAggregatorMapTests.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2020 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird._
21 | import org.scalatest.flatspec.AnyFlatSpec
22 | import org.scalatest.matchers.must.Matchers._
23 |
24 | /** Exercises `MultiAggregatorMap` with three named aggregators over `Long` input. */
25 | class MultiAggregatorMapTests extends AnyFlatSpec {
26 |
27 |   it must "aggregate into a list of individual aggregator results" in {
28 |     val multiListAgg = MultiAggregatorMap[Long, Long, Long](
29 |       List("min" -> Aggregator.min, "max" -> Aggregator.max, "size" -> Aggregator.size)
30 |     )
31 |
32 |     val input: List[Long] = (0L to 5L).toList
33 |     val result = multiListAgg(input)
34 |
35 |     // Results are looked up by the name each aggregator was registered under.
36 |     result("min") mustBe 0L
37 |     result("max") mustBe 5L
38 |     result("size") mustBe 6L
39 |   }
40 | }
41 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/NdcgAtKTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.{Equality, TolerantNumerics}
21 |
22 | /** Checks [[NdcgAtK]] at several cut-offs against the shared ranking fixtures. */
23 | class NdcgAtKTest extends AggregatorTest {
24 |   implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)
25 |
26 |   // Runs NdcgAtK with cut-off `k` over `data`.
27 |   private def ndcg(k: Int, data: Seq[RankingPrediction[Int]]) = run(NdcgAtK[Int](k))(data)
28 |
29 |   it should "compute ndcg for rankings" in {
30 |     val data = RankingData.rankingData
31 |     assert(ndcg(3, data) === 1.0 / 3)
32 |     assert(ndcg(5, data) === 0.328788)
33 |     assert(ndcg(10, data) === 0.487913)
34 |     // A cut-off of 15 is expected to give the same result as a cut-off of 10.
35 |     assert(ndcg(15, data) === ndcg(10, data))
36 |   }
37 |
38 |   it should "compute ndcg for rankings with few predictions" in {
39 |     val data = RankingData.smallRankingData
40 |     assert(ndcg(1, data) === 0.5)
41 |     assert(ndcg(2, data) === 0.30657)
42 |   }
43 | }
44 |
--------------------------------------------------------------------------------
/tfx/src/main/scala/com/spotify/noether/tfx/TfmaConverter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether.tfx
19 |
20 | import com.twitter.algebird.Aggregator
21 | import tensorflow_model_analysis.MetricsForSliceOuterClass._
22 |
23 | /**
24 |  * Converts an aggregator of type `T` into one that presents its result as an [[EvalResult]].
25 |  */
26 | trait TfmaConverter[A, B, T <: Aggregator[A, B, _]] {
27 |   def convertToTfmaProto(underlying: T): Aggregator[A, B, EvalResult]
28 | }
29 |
30 | /** A plot payload; every variant wraps a `PlotsForSlice` value. */
31 | sealed trait Plot {
32 |   def plotData: PlotsForSlice
33 | }
34 |
35 | object Plot {
36 |   case class CalibrationHistogram(plotData: PlotsForSlice) extends Plot
37 |   case class ConfusionMatrix(plotData: PlotsForSlice) extends Plot
38 | }
39 |
40 | /** Result of a converted aggregator: optional metrics and an optional plot. */
41 | case class EvalResult(metrics: Option[MetricsForSlice], plots: Option[Plot])
42 |
43 | /** Convenience constructors that wrap the supplied parts in `Some`. */
44 | object EvalResult {
45 |   def apply(metrics: MetricsForSlice): EvalResult =
46 |     new EvalResult(Some(metrics), None)
47 |
48 |   def apply(metrics: MetricsForSlice, plot: Plot): EvalResult =
49 |     new EvalResult(Some(metrics), Some(plot))
50 |
51 |   def apply(plot: Plot): EvalResult =
52 |     new EvalResult(None, Some(plot))
53 | }
54 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/BinaryConfusionMatrix.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import breeze.linalg.DenseMatrix
21 | import com.twitter.algebird.{Aggregator, Semigroup}
22 |
23 | /**
24 |  * Binary special case of [[ConfusionMatrix]], provided so it composes easily with other binary
25 |  * aggregators. The boolean label and the thresholded score are mapped to the classes 0 and 1
26 |  * and all aggregation is delegated to a two-label [[ConfusionMatrix]].
27 |  *
28 |  * @param threshold
29 |  *   Scores strictly greater than this value are treated as class 1, all others as class 0
30 |  */
31 | case class BinaryConfusionMatrix(threshold: Double = 0.5)
32 |     extends Aggregator[Prediction[Boolean, Double], Map[(Int, Int), Long], DenseMatrix[Long]] {
33 |   private val confusionMatrix = ConfusionMatrix(Seq(0, 1))
34 |
35 |   /** Converts the prediction to integer classes and delegates to the wrapped aggregator. */
36 |   def prepare(input: Prediction[Boolean, Double]): Map[(Int, Int), Long] = {
37 |     val actualClass = if (input.actual) 1 else 0
38 |     val predictedClass = if (input.predicted > threshold) 1 else 0
39 |     confusionMatrix.prepare(Prediction(actualClass, predictedClass))
40 |   }
41 |
42 |   def semigroup: Semigroup[Map[(Int, Int), Long]] = confusionMatrix.semigroup
43 |
44 |   def present(m: Map[(Int, Int), Long]): DenseMatrix[Long] = confusionMatrix.present(m)
45 | }
46 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/ConfusionMatrix.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import breeze.linalg.DenseMatrix
21 | import com.twitter.algebird.{Aggregator, Semigroup}
22 |
/**
 * Generic confusion matrix aggregator for any number of labels. Scores must already have been
 * thresholded into a predicted label before using this aggregator.
 *
 * Counts are accumulated per `(predicted, actual)` pair. In the presented matrix the row index is
 * the predicted label and the column index is the actual label.
 *
 * @param labels
 *   List of possible label values. Each value is used directly as a row and column index of a
 *   `labels.size` x `labels.size` matrix, so every label has to be a valid index for that matrix.
 */
final case class ConfusionMatrix(labels: Seq[Int])
    extends Aggregator[Prediction[Int, Int], Map[(Int, Int), Long], DenseMatrix[Long]] {

  /** A single observation contributes a count of one to its `(predicted, actual)` cell. */
  def prepare(input: Prediction[Int, Int]): Map[(Int, Int), Long] = {
    val cell = (input.predicted, input.actual)
    Map(cell -> 1L)
  }

  def semigroup: Semigroup[Map[(Int, Int), Long]] =
    Semigroup.mapSemigroup[(Int, Int), Long]

  /** Copies the accumulated counts into a dense matrix; pairs never observed stay at zero. */
  def present(m: Map[(Int, Int), Long]): DenseMatrix[Long] = {
    val size = labels.size
    val counts = DenseMatrix.zeros[Long](size, size)
    labels.foreach { predicted =>
      labels.foreach { actual =>
        counts(predicted, actual) = m.getOrElse((predicted, actual), 0L)
      }
    }
    counts
  }
}
49 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/PrecisionAtKTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.{Equality, TolerantNumerics}
21 |
class PrecisionAtKTest extends AggregatorTest {
  import RankingData._
  implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

  it should "compute precisionAtK for rankings" in {
    // Cut-off position k paired with the precision expected on the shared ranking fixture.
    val expectedByK = Seq(
      1 -> 1.0 / 3,
      2 -> 1.0 / 3,
      3 -> 1.0 / 3,
      4 -> 0.75 / 3,
      5 -> 0.8 / 3,
      10 -> 0.8 / 3,
      15 -> 8.0 / 45
    )
    expectedByK.foreach { case (k, expected) =>
      assert(run(PrecisionAtK[Int](k))(rankingData) === expected)
    }
  }

  it should "compute precisionAtK for rankings with few predictions" in {
    val expectedByK = Seq(1 -> 0.5, 2 -> 0.25)
    expectedByK.foreach { case (k, expected) =>
      assert(run(PrecisionAtK[Int](k))(smallRankingData) === expected)
    }
  }
}
41 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
# Continuous integration: formatting and lint checks, cross-version tests and
# binary-compatibility checks, run on every push and pull request.
name: ci
on: [push, pull_request]

jobs:
  # Formatting (scalafmt) and lint (scalafix) checks.
  checks:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Java 11 setup
        uses: actions/setup-java@v3
        with:
          distribution: temurin
          java-version: 11
          cache: sbt
      - run: sbt "; +scalafmtCheckAll; scalafmtSbtCheck" "; scalafixEnable; scalafixAll --check"
  # Test suite for every Scala version in the matrix; coverage is collected on 2.13 only.
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Java ${{matrix.java}} setup
        uses: actions/setup-java@v3
        with:
          distribution: temurin
          java-version: ${{matrix.java}}
          cache: sbt
      - if: startsWith(matrix.scala,'2.13')
        run: sbt coverage "++${{matrix.scala}} test" coverageReport
      # Upload through the maintained action rather than piping the deprecated
      # bash uploader from curl straight into a shell.
      - if: startsWith(matrix.scala,'2.13')
        uses: codecov/codecov-action@v3
      - if: "!startsWith(matrix.scala,'2.13')"
        run: sbt "++${{matrix.scala}} test"
    strategy:
      matrix:
        java:
          - 11
        scala:
          - 2.11.12
          - 2.12.17
          - 2.13.10
          - 3.2.2
  # Binary-compatibility report (MiMa) for the Scala 2 versions in the matrix.
  mimaReport:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Java ${{matrix.java}} setup
        uses: actions/setup-java@v3
        with:
          distribution: temurin
          java-version: ${{matrix.java}}
          cache: sbt
      - run: sbt "++${{matrix.scala}} mimaReportBinaryIssues"
    strategy:
      matrix:
        java:
          - 11
        scala:
          - 2.11.12
          - 2.12.17
          - 2.13.10
60 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/MeanAveragePrecision.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 |
/**
 * Computes the mean average precision (MAP) over all ranking predictions.
 *
 * For each query, the precision at every rank holding a relevant item is summed and divided by the
 * size of the ground truth set. A query with an empty ground truth set contributes an average
 * precision of zero.
 */
case class MeanAveragePrecision[T]()
    extends Aggregator[RankingPrediction[T], (Double, Long), Double] {

  /** Emits `(average precision of this query, 1)` so that sums can be averaged in `present`. */
  def prepare(input: RankingPrediction[T]): (Double, Long) = {
    val relevant = input.actual.toSet
    if (relevant.isEmpty) {
      (0.0, 1L)
    } else {
      val total = input.predicted.length

      // Walks the ranking once, adding hits / (rank + 1) at every position holding a relevant item.
      @scala.annotation.tailrec
      def precisionSum(rank: Int, hits: Int, acc: Double): Double =
        if (rank >= total) acc
        else if (relevant.contains(input.predicted(rank))) {
          val newHits = hits + 1
          precisionSum(rank + 1, newHits, acc + newHits.toDouble / (rank + 1))
        } else precisionSum(rank + 1, hits, acc)

      (precisionSum(0, 0, 0.0) / relevant.size, 1L)
    }
  }

  def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]

  /** Divides the summed average precisions by the number of queries. */
  def present(score: (Double, Long)): Double = {
    val (sum, count) = score
    sum / count
  }
}
52 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/AggregatorTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
21 |
22 | import com.twitter.algebird.Aggregator
23 | import org.scalatest.flatspec.AnyFlatSpec
24 | import org.scalatest.matchers.should.Matchers
25 |
/**
 * Base trait for aggregator specs. It runs an aggregator over a sequence of inputs while checking
 * that every stage survives a Java serialization round trip.
 */
trait AggregatorTest extends AnyFlatSpec with Matchers {

  /**
   * Runs `aggregator` over `as`. Each input, the reduced intermediate value and the presented
   * result are serialized and deserialized before being used or returned.
   */
  def run[A, B, C](aggregator: Aggregator[A, B, C])(as: Seq[A]): C = {
    val bs = as.map(a => aggregator.prepare(ensureSerializable(a)))
    val b = ensureSerializable(aggregator.reduce(bs))
    ensureSerializable(aggregator.present(b))
  }

  private def serializeToByteArray(value: Any): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buffer)
    // Closing flushes the object stream, so the buffer is complete before it is read.
    try oos.writeObject(value)
    finally oos.close()
    buffer.toByteArray
  }

  private def deserializeFromByteArray(encodedValue: Array[Byte]): AnyRef = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(encodedValue))
    try ois.readObject()
    finally ois.close()
  }

  /** Round-trips `value` through Java serialization and returns the deserialized copy. */
  private def ensureSerializable[T](value: T): T =
    deserializeFromByteArray(serializeToByteArray(value)).asInstanceOf[T]
}
48 |
--------------------------------------------------------------------------------
/benchmark/src/main/scala/com/spotify/noether/benchmark/CalibrationHistogramCreateBenchmark.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 | package benchmark
20 |
21 | import com.spotify.noether.benchmark.CalibrationHistogramCreateBenchmark.CalibrationHistogramState
22 | import org.openjdk.jmh.annotations._
23 |
24 | import scala.util.Random
25 |
/** Helpers for generating random input data for benchmarks. */
object PredictionUtils {

  /** Generates `nbPrediction` predictions, each with a random boolean label and a random score. */
  def generatePredictions(nbPrediction: Int): Seq[Prediction[Boolean, Double]] = {
    def randomPrediction(): Prediction[Boolean, Double] = {
      val label = Random.nextBoolean()
      val score = Random.nextDouble()
      Prediction(label, score)
    }
    Seq.fill(nbPrediction)(randomPrediction())
  }
}
30 |
object CalibrationHistogramCreateBenchmark {

  /**
   * JMH state holding the histogram under test. The histogram is built in `setup` from the
   * `lowerBound`, `upperBound` and `nbBucket` parameters; `nbElement` is declared as a parameter
   * but is not read by `setup`.
   */
  @State(Scope.Benchmark)
  class CalibrationHistogramState() {
    @Param(Array("100", "1000", "3000"))
    var nbElement: Int = 0

    @Param(Array("100", "200", "300"))
    var nbBucket: Int = 0

    @Param(Array("0.1", "0.2", "0.3"))
    var lowerBound: Double = 0.0

    @Param(Array("0.2", "0.4", "0.5"))
    var upperBound: Double = 0.0

    var histogram: CalibrationHistogram = _

    @Setup
    def setup(): Unit = {
      histogram = CalibrationHistogram(
        lowerBound = lowerBound,
        upperBound = upperBound,
        numBuckets = nbBucket
      )
    }
  }
}
53 |
/**
 * Benchmark over [[CalibrationHistogram]].
 *
 * NOTE(review): despite its name, the benchmark method only reads `bucketSize` from the histogram
 * that the state object built in its setup method; it does not construct a histogram itself.
 */
class CalibrationHistogramCreateBenchmark {
  @Benchmark
  def createCalibrationHistogram(calibrationHistogramState: CalibrationHistogramState): Double = {
    val histogram = calibrationHistogramState.histogram
    histogram.bucketSize
  }
}
59 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/ConfusionMatrixTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import breeze.linalg.DenseMatrix
21 |
class ConfusionMatrixTest extends AggregatorTest {
  it should "return correct confusion matrix" in {
    // (predicted, actual) pairs, grouped by how often each pair occurs.
    val pairs =
      List.fill(3)((0, 0)) ++
        List.fill(2)((0, 1)) ++
        List.fill(4)((1, 0)) ++
        List.fill(2)((1, 1)) ++
        List.fill(1)((2, 1)) ++
        List.fill(3)((2, 2))
    val data = pairs.map { case (predicted, actual) => Prediction(actual, predicted) }

    val labels = Seq(0, 1, 2)
    val actual = run(ConfusionMatrix(labels))(data)

    // Expected count for every (row, column) cell of the matrix.
    val expectedCells = Seq(
      (0, 0) -> 3L,
      (0, 1) -> 2L,
      (0, 2) -> 0L,
      (1, 0) -> 4L,
      (1, 1) -> 2L,
      (1, 2) -> 0L,
      (2, 0) -> 0L,
      (2, 1) -> 1L,
      (2, 2) -> 3L
    )
    val expected = DenseMatrix.zeros[Long](labels.size, labels.size)
    expectedCells.foreach { case ((row, col), count) =>
      expected(row, col) = count
    }
    assert(actual == expected)
  }

  it should "return correct scores" in {
    // (predicted, actual) pairs.
    val pairs = List(
      (0, 0),
      (0, 1),
      (0, 0),
      (1, 0),
      (1, 1),
      (1, 1),
      (1, 1)
    )
    val data = pairs.map { case (predicted, actual) => Prediction(actual, predicted) }

    val matrix = run(ConfusionMatrix(Seq(0, 1)))(data)

    val expectedCells = Seq(
      (1, 1) -> 3L,
      (0, 1) -> 1L,
      (1, 0) -> 1L,
      (0, 0) -> 2L
    )
    expectedCells.foreach { case ((row, col), count) =>
      assert(matrix(row, col) === count)
    }
  }
}
78 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/MultiAggregatorMap.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2020 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 | import scala.collection.mutable.ArrayBuffer
22 |
/**
 * Aggregator which combines an unbounded list of other aggregators. Each aggregator in the list is
 * tagged by a string. The string (aka name) can be used to retrieve the aggregated value from the
 * Map emitted by the "present" function.
 *
 * Intermediate values are kept in a list that is positionally aligned with `aggregatorsMap`. If
 * the same name is used more than once, the last aggregator with that name wins in the presented
 * Map.
 *
 * @param aggregatorsMap
 *   Named aggregators, in the order in which their intermediate values are stored
 */
case class MultiAggregatorMap[-A, B, +C](aggregatorsMap: List[(String, Aggregator[A, B, C])])
    extends Aggregator[A, List[B], Map[String, C]] {

  private[this] val aggregators = aggregatorsMap.map(_._2)

  /** Prepares `input` with every aggregator, in list order. */
  def prepare(input: A): List[B] =
    aggregators.map(_.prepare(input))

  /** Combines two intermediate lists position by position with each aggregator's own semigroup. */
  def semigroup: Semigroup[List[B]] = new Semigroup[List[B]] {
    def plus(x: List[B], y: List[B]): List[B] = {
      // Walk all three lists in lock-step; indexing into a List is linear per access.
      val left = x.iterator
      val right = y.iterator
      val resultList = new ArrayBuffer[B]
      aggregators.foreach { aggregator =>
        resultList.append(aggregator.semigroup.plus(left.next(), right.next()))
      }
      resultList.toList
    }
  }

  /** Presents every intermediate value with its aggregator and keys the result by name. */
  def present(reduction: List[B]): Map[String, C] = {
    val reduced = reduction.iterator
    val resultList = new ArrayBuffer[(String, C)]
    aggregatorsMap.foreach { case (name, aggregator) =>
      resultList.append(name -> aggregator.present(reduced.next()))
    }
    resultList.toMap
  }
}
59 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/AUCTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.TolerantNumerics
21 | import org.scalactic.Equality
22 |
class AUCTest extends AggregatorTest {
  implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

  // Each pair is (score, label).
  private val data = {
    val scored = List(
      (0.1, false),
      (0.1, true),
      (0.4, false),
      (0.6, false),
      (0.6, true),
      (0.6, true),
      (0.8, true)
    )
    scored.map { case (score, label) => Prediction(label, score) }
  }

  /** Turns raw coordinate pairs into the curve points expected from a `Curve` aggregator. */
  private def points(coordinates: (Double, Double)*) =
    coordinates.map { case (a, b) => MetricCurvePoint(a, b) }.toArray

  it should "return ROC AUC" in {
    val auc = run(AUC(ROC, samples = 50))(data)
    assert(auc === 0.7)
  }

  it should "return PR AUC" in {
    val auc = run(AUC(PR, samples = 50))(data)
    assert(auc === 0.83)
  }

  it should "return points of a PR Curve" in {
    val expected = points(
      (0.0, 1.0),
      (0.0, 1.0),
      (0.25, 1.0),
      (0.75, 0.75),
      (0.75, 0.6),
      (1.0, 0.5714285714285714)
    )
    assert(run(Curve(PR, samples = 5))(data).points === MetricCurvePoints(expected).points)
  }

  it should "return points of a ROC Curve" in {
    val expected = points(
      (0.0, 0.0),
      (0.0, 0.25),
      (0.3333333333333333, 0.75),
      (0.6666666666666666, 0.75),
      (1.0, 1.0),
      (1.0, 1.0)
    )
    assert(run(Curve(ROC, samples = 5))(data).points === MetricCurvePoints(expected).points)
  }
}
69 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/ClassificationReportTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import org.scalactic.TolerantNumerics
21 | import org.scalactic.Equality
22 |
class ClassificationReportTest extends AggregatorTest {
  implicit private val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

  it should "return correct scores" in {
    // Each pair is (score, label).
    val scored = List(
      (0.1, false),
      (0.1, true),
      (0.4, false),
      (0.6, false),
      (0.6, true),
      (0.6, true),
      (0.8, true)
    )
    val data = scored.map { case (score, label) => Prediction(label, score) }

    val report = run(ClassificationReport())(data)

    assert(report.recall === 0.75)
    assert(report.precision === 0.75)
    assert(report.fscore === 0.75)
    assert(report.fpr === 0.333)
  }

  it should "support multiclass reports" in {
    // Each pair is (predicted, actual).
    val pairs = Seq(
      (0, 0),
      (0, 0),
      (0, 1),
      (1, 1),
      (1, 1),
      (1, 0),
      (1, 2),
      (2, 2),
      (2, 2),
      (2, 2)
    )
    val predictions = pairs.map { case (predicted, actual) => Prediction(actual, predicted) }

    val reports = run(MultiClassificationReport(Seq(0, 1, 2)))(predictions)

    // (label, expected recall, expected precision); compared exactly, as the values are exact.
    val expectations = Seq(
      (0, 2.0 / 3.0, 2.0 / 3.0),
      (1, 2.0 / 3.0, 0.5),
      (2, 0.75, 1.0)
    )
    expectations.foreach { case (label, recall, precision) =>
      val report = reports(label)
      assert(report.recall == recall)
      assert(report.precision == precision)
    }
  }
}
72 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/PrecisionAtK.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 |
/**
 * Compute the average precision of all the predictions, truncated at ranking position k.
 *
 * If, for a prediction, the ranking algorithm returns n (n is less than k) results, the precision
 * value is still computed as #(relevant items retrieved) / k. The same formula applies when the
 * size of the ground truth set is less than k.
 *
 * If a prediction has an empty ground truth set, zero is used as its precision.
 *
 * See the following paper for detail:
 *
 * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
 *
 * @param k
 *   the position to compute the truncated precision, must be positive
 */
case class PrecisionAtK[T](k: Int)
    extends Aggregator[RankingPrediction[T], (Double, Long), Double] {
  require(k > 0, "ranking position k should be positive")

  /** Emits `(precision at k of this query, 1)` so that sums can be averaged in `present`. */
  def prepare(input: RankingPrediction[T]): (Double, Long) = {
    val relevant = input.actual.toSet
    val precision =
      if (relevant.isEmpty) {
        0.0
      } else {
        // Only the first k predictions (or fewer, if the ranking is shorter) are inspected.
        val cutoff = math.min(input.predicted.length, k)
        val hits = (0 until cutoff).count(rank => relevant.contains(input.predicted(rank)))
        hits.toDouble / k
      }
    (precision, 1L)
  }

  def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]

  /** Divides the summed precisions by the number of queries. */
  def present(score: (Double, Long)): Double = {
    val (sum, count) = score
    sum / count
  }
}
63 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/NdcgAtK.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 |
/**
 * Compute the average NDCG value of all the predictions, truncated at ranking position k. The
 * discounted cumulative gain at position k is computed as: sum,,i=1,,^k^ (2^{relevance of ''i''th
 * item}^ - 1) / log(i + 1), and the NDCG is obtained by dividing by the DCG value of the ground
 * truth set. In the current implementation the relevance value is binary. If a query has an empty
 * ground truth set, zero is used as its NDCG.
 *
 * See the following paper for detail:
 *
 * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
 *
 * @param k
 *   the position to compute the truncated ndcg, must be positive
 */
case class NdcgAtK[T](k: Int) extends Aggregator[RankingPrediction[T], (Double, Long), Double] {
  require(k > 0, "ranking position k should be positive")

  /** Emits `(NDCG at k of this query, 1)` so that sums can be averaged in `present`. */
  def prepare(input: RankingPrediction[T]): (Double, Long) = {
    val relevant = input.actual.toSet

    if (relevant.isEmpty) {
      (0.0, 1L)
    } else {
      val relevantCount = relevant.size
      val predictedCount = input.predicted.length
      val cutoff = math.min(math.max(predictedCount, relevantCount), k)

      // Accumulates the achieved DCG and the ideal DCG side by side, then returns their ratio.
      @scala.annotation.tailrec
      def ndcg(rank: Int, dcg: Double, idealDcg: Double): Double =
        if (rank >= cutoff) {
          dcg / idealDcg
        } else {
          val gain = 1.0 / math.log(rank + 2.0)
          val isHit = rank < predictedCount && relevant.contains(input.predicted(rank))
          val nextDcg = if (isHit) dcg + gain else dcg
          val nextIdealDcg = if (rank < relevantCount) idealDcg + gain else idealDcg
          ndcg(rank + 1, nextDcg, nextIdealDcg)
        }

      (ndcg(0, 0.0, 0.0), 1L)
    }
  }

  def semigroup: Semigroup[(Double, Long)] = implicitly[Semigroup[(Double, Long)]]

  /** Divides the summed NDCG values by the number of queries. */
  def present(score: (Double, Long)): Double = {
    val (sum, count) = score
    sum / count
  }
}
67 |
--------------------------------------------------------------------------------
/core/src/test/scala/com/spotify/noether/CalibrationHistogramTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 | package com.spotify.noether
18 |
19 | import org.scalactic.{Equality, TolerantNumerics}
20 |
class CalibrationHistogramTest extends AggregatorTest {
  it should "return correct histogram" in {
    implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.001)
    // Each pair is (prediction, label); the comment names the bucket the prediction falls into.
    val data = Seq(
      (0.15, 1.15), // lb
      (0.288, 1.288), // rounding error puts this in (0.249, 0.288)
      (0.30, 1.30), // (0.288, 0.3269)
      (0.36, 1.36), // (0.3269, 0.365)
      (0.555, 1.555), // (0.5219, 0.5609)
      (1.2, 2.2), // ub
      (0.7, 1.7) // ub
    ).map { case (p, a) => Prediction(a, p) }

    val actual = run(CalibrationHistogram(0.21, 0.60, 10))(data)

    // Ten buckets of width 0.039 between 0.21 and 0.60, plus the two overflow buckets.
    val expected = List(
      CalibrationHistogramBucket(Double.NegativeInfinity, 0.21, 1.0, 1.15, 0.15),
      CalibrationHistogramBucket(0.21, 0.249, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.249, 0.288, 1.0, 1.288, 0.288),
      CalibrationHistogramBucket(0.288, 0.327, 1.0, 1.30, 0.30),
      CalibrationHistogramBucket(0.327, 0.366, 1.0, 1.36, 0.36),
      CalibrationHistogramBucket(0.366, 0.405, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.405, 0.444, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.444, 0.483, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.483, 0.522, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.522, 0.561, 1.0, 1.555, 0.555),
      CalibrationHistogramBucket(0.561, 0.6, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.6, Double.PositiveInfinity, 2.0, 3.9, 1.9)
    )

    assert(actual.length == expected.length)
    actual.zip(expected).foreach { case (got, want) =>
      assert(got.numPredictions === want.numPredictions)
      assert(got.sumPredictions === want.sumPredictions)
      assert(got.sumLabels === want.sumLabels)
      assert(got.lowerThresholdInclusive === want.lowerThresholdInclusive)
      assert(got.upperThresholdExclusive === want.upperThresholdExclusive)
    }
  }
}
61 |
--------------------------------------------------------------------------------
/benchmark/README.md:
--------------------------------------------------------------------------------
1 |
2 | # noether-benchmark
3 |
4 | This module contains the JMH benchmarks for the noether project.
5 |
6 | ## Run benchmark
7 |
8 | * Open an sbt shell in the project root:
9 |
10 | sbt
11 |
12 | * Then switch to the subproject `noetherBenchmark`
13 |
14 |
15 | sbt:noether> project noetherBenchmark
16 |
17 | # Run all the benchmarks..
18 | sbt:noetherBenchmark> jmh:run .*
19 |
20 | # Run specific benchmark (CalibrationHistogram benchmark)
21 | sbt:noetherBenchmark> jmh:run .*CalibrationHistogram.*
22 |
23 | ### Example
24 |
25 | jmh:run -t1 -f1 -wi 2 -i 3 .*Calibration.*
26 |
27 |
28 | The output should look something like this:
29 |
30 |
31 | [info] Benchmark (lowerBound) (nbBucket) (nbElement) (upperBound) Mode Cnt Score Error Units
32 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 100 0.2 thrpt 3 354495958,865 ± 19222766,245 ops/s
33 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 100 0.4 thrpt 3 354159610,708 ± 19887361,006 ops/s
34 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 100 0.5 thrpt 3 353401735,022 ± 9100858,662 ops/s
35 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 1000 0.2 thrpt 3 352173891,822 ± 26480103,060 ops/s
36 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 1000 0.4 thrpt 3 352870170,909 ± 21414322,507 ops/s
37 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 1000 0.5 thrpt 3 354991297,412 ± 37282062,985 ops/s
38 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 3000 0.2 thrpt 3 352385150,389 ± 25101115,471 ops/s
39 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 3000 0.4 thrpt 3 355695107,774 ± 27960629,389 ops/s
40 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 100 3000 0.5 thrpt 3 353599563,110 ± 35675109,158 ops/s
41 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 200 100 0.2 thrpt 3 354172265,477 ± 29726736,262 ops/s
42 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 200 100 0.4 thrpt 3 353407712,527 ± 15231956,560 ops/s
43 | [info] CalibrationHistogramCreateBenchmark.createCalibrationHistogram 0.1 200 100 0.5 thrpt 3 355430167,821 ± 53204747,979 ops/s
44 | ...
45 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/CalibrationHistogram.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 | package com.spotify.noether
18 |
19 | import com.twitter.algebird.{Aggregator, Semigroup}
20 |
21 | import scala.math.floor
22 |
/**
 * Histogram bucket: one element of the list presented by the `CalibrationHistogram` aggregator.
 *
 * @param lowerThresholdInclusive
 *   Lower bound on bucket, inclusive
 * @param upperThresholdExclusive
 *   Upper bound on bucket, exclusive
 * @param numPredictions
 *   Number of predictions in this bucket
 * @param sumLabels
 *   Sum of label values for this bucket
 * @param sumPredictions
 *   Sum of prediction values for this bucket
 */
final case class CalibrationHistogramBucket(
  lowerThresholdInclusive: Double,
  upperThresholdExclusive: Double,
  numPredictions: Double,
  sumLabels: Double,
  sumPredictions: Double
)
44 |
/**
 * Split predictions into Tensorflow Model Analysis compatible CalibrationHistogramBucket buckets.
 *
 * The range [lower bound, upper bound) is divided into `numBuckets` buckets of equal width. Two
 * overflow buckets are always reported as well: a prediction that is less than the lower bound
 * belongs to the bucket [-inf, lower bound), and a prediction that is greater than or equal to the
 * upper bound belongs to the bucket [upper bound, inf).
 *
 * The intermediate aggregate maps a bucket number to the tuple (sum of predictions, sum of labels,
 * number of examples). The overflow buckets use negative/positive infinity as bucket number.
 *
 * @param lowerBound
 *   Left boundary, inclusive
 * @param upperBound
 *   Right boundary, exclusive
 * @param numBuckets
 *   Number of buckets in the histogram
 */
final case class CalibrationHistogram(
  lowerBound: Double = 0.0,
  upperBound: Double = 1.0,
  numBuckets: Int = 10
) extends Aggregator[
      Prediction[Double, Double],
      Map[Double, (Double, Double, Long)],
      List[CalibrationHistogramBucket]
    ] {

  /** Width of each regular (non-overflow) bucket. */
  val bucketSize: Double = (upperBound - lowerBound) / numBuckets.toDouble

  /** Translate a bucket number into its (inclusive lower, exclusive upper) thresholds. */
  private def thresholdsFromBucket(b: Double): (Double, Double) = b match {
    case Double.PositiveInfinity => (upperBound, b)
    case Double.NegativeInfinity => (b, lowerBound)
    case _ =>
      val lower = lowerBound + (b * bucketSize)
      (lower, lower + bucketSize)
  }

  /** Assign a single prediction to its bucket, seeding the per-bucket sums and count. */
  def prepare(input: Prediction[Double, Double]): Map[Double, (Double, Double, Long)] = {
    val bucketNumber = input.predicted match {
      case p if p < lowerBound  => Double.NegativeInfinity
      case p if p >= upperBound => Double.PositiveInfinity
      case p =>
        // Rounding in the division can produce `numBuckets` for a prediction that is just below
        // `upperBound`. `present` never emits that key, so the example would be silently lost;
        // clamp it into the last regular bucket instead.
        math.min(floor((p - lowerBound) / bucketSize), (numBuckets - 1).toDouble)
    }

    Map((bucketNumber, (input.predicted, input.actual, 1L)))
  }

  /** Merge aggregates by adding sums and counts of equal bucket numbers. */
  def semigroup: Semigroup[Map[Double, (Double, Double, Long)]] =
    Semigroup.mapSemigroup[Double, (Double, Double, Long)]

  /**
   * Emit every bucket in ascending order: the lower overflow bucket, the `numBuckets` regular
   * buckets and the upper overflow bucket. Buckets without predictions are reported with zeros.
   */
  def present(m: Map[Double, (Double, Double, Long)]): List[CalibrationHistogramBucket] = {
    val buckets = Vector(Double.NegativeInfinity) ++
      (0 until numBuckets).map(_.toDouble).toVector ++
      Vector(Double.PositiveInfinity)
    buckets.map { l =>
      val (lb, ub) = thresholdsFromBucket(l)
      m.get(l) match {
        case None => CalibrationHistogramBucket(lb, ub, 0.0, 0.0, 0.0)
        case Some((predictionSum, labelSum, numExamples)) =>
          CalibrationHistogramBucket(lb, ub, numExamples.toDouble, labelSum, predictionSum)
      }
    }.toList
  }
}
106 |
--------------------------------------------------------------------------------
/tfx/src/main/protobuf/metrics_for_slice.proto:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | syntax = "proto3";
16 |
17 | import "google/protobuf/wrappers.proto";
18 |
19 | package tensorflow_model_analysis;
20 |
21 |
// The value will be converted into an error message if we do not know its type.
// Note that the field numbers do not follow declaration order (value = 2, error = 1).
message UnknownType {
  // The value of unknown type, carried as a string.
  string value = 2;
  // Presumably the error message produced for the value; confirm against the TFMA writer.
  string error = 1;
}
27 |
// A value together with an optional range around it.
message BoundedValue {
  // The lower bound of the range.
  google.protobuf.DoubleValue lower_bound = 1;
  // The upper bound of the range.
  google.protobuf.DoubleValue upper_bound = 2;
  // Represents an exact value if the lower_bound and upper_bound are unset,
  // else it's an approximate value. For the approximate value, it should be
  // within the range [lower_bound, upper_bound].
  google.protobuf.DoubleValue value = 3;
}
38 |
// Value at cutoffs, e.g. for precision@K, recall@K
message ValueAtCutoffs {
  // One metric value paired with the cutoff it was computed at.
  message ValueCutoffPair {
    // The cutoff (the K in precision@K / recall@K).
    int32 cutoff = 1;
    // The metric value at that cutoff.
    double value = 2;
  }
  // All (cutoff, value) pairs for the metric.
  repeated ValueCutoffPair values = 1;
}
47 |
// Confusion matrix at thresholds.
message ConfusionMatrixAtThresholds {
  // The confusion-matrix cells, plus precision and recall, for one threshold.
  message ConfusionMatrixAtThreshold {
    // The threshold this matrix belongs to.
    double threshold = 1;
    // Cell counts are doubles rather than integers.
    double false_negatives = 2;
    double true_negatives = 3;
    double false_positives = 4;
    double true_positives = 5;
    double precision = 6;
    double recall = 7;
  }
  // One entry per threshold.
  repeated ConfusionMatrixAtThreshold matrices = 1;
}
61 |
// It stores metrics values in different types, so that the frontend will know
// how to visualize the values based on the types.
message MetricValue {
  // Exactly one of the following representations is set.
  // Field numbers do not follow declaration order (unknown_type = 3 is listed last).
  oneof type {
    // A plain scalar metric.
    google.protobuf.DoubleValue double_value = 1;
    // A scalar metric with an optional range around it.
    BoundedValue bounded_value = 2;
    // A metric reported at several cutoffs.
    ValueAtCutoffs value_at_cutoffs = 4;
    // Confusion matrices at one or more thresholds.
    ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 5;
    // Used when the type of the value is not known.
    UnknownType unknown_type = 3;
  }
}
73 |
// A single slice key.
message SingleSliceKey {
  // Name of the column the key refers to.
  string column = 1;
  // The key's value for that column; exactly one of the typed variants is set.
  oneof kind {
    bytes bytes_value = 2;
    float float_value = 3;
    int64 int64_value = 4;
  }
}
83 |
// A slice key, which may consist of multiple single slice keys.
message SliceKey {
  // The single slice keys that together make up this slice key.
  repeated SingleSliceKey single_slice_keys = 1;
}
88 |
// All metrics computed for one slice of the data.
message MetricsForSlice {
  // The slice key for the metrics.
  SliceKey slice_key = 1;
  // A map to store metrics, keyed by metric name. Currently we convert the
  // post_export_metric provided by TFMA to its appropriate type for better
  // visualization, and map all other metrics to DoubleValue type.
  map<string, MetricValue> metrics = 2;
}
97 |
// The buckets of a calibration histogram.
message CalibrationHistogramBuckets {
  // One histogram bucket covering [lower_threshold_inclusive, upper_threshold_exclusive).
  message Bucket {
    // Lower edge of the bucket, inclusive.
    double lower_threshold_inclusive = 1;
    // Upper edge of the bucket, exclusive.
    double upper_threshold_exclusive = 2;
    // Number of (weighted) examples in the bucket.
    google.protobuf.DoubleValue num_weighted_examples = 3;
    // Sum of the (weighted) labels of the examples in the bucket.
    google.protobuf.DoubleValue total_weighted_label = 4;
    // Sum of the (weighted) predictions of the examples in the bucket.
    google.protobuf.DoubleValue total_weighted_refined_prediction = 5;
  }
  // All buckets of the histogram.
  repeated Bucket buckets = 1;
}
108 |
// Data backing the plots of a slice.
message PlotData {
  // For calibration plot and prediction distribution.
  CalibrationHistogramBuckets calibration_histogram_buckets = 1;
  // For auc curve and auprc curve.
  ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 2;
}
115 |
// The plot data computed for one slice of the data.
message PlotsForSlice {
  // The slice key for the plots.
  SliceKey slice_key = 1;
  // The plot data
  PlotData plot_data = 2;
}
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Noether
2 | =======
3 |
4 | [Build Status](https://travis-ci.org/spotify/noether)
5 | [Code Coverage](https://codecov.io/github/spotify/noether?branch=master)
6 | [License](./LICENSE)
7 | [Maven Central](https://maven-badges.herokuapp.com/maven-central/com.spotify/noether-core_2.12)
8 | [Scaladoc](https://spotify.github.io/noether/latest/api/com/spotify/noether/index.html)
9 | [Scala Steward](https://scala-steward.org)
10 |
11 | > [Emmy Noether](https://en.wikipedia.org/wiki/Emmy_Noether) was a German mathematician known for her landmark contributions to abstract algebra and theoretical physics.
12 |
13 | Noether is a collection of Machine Learning tools targeted at the JVM and Scala.
14 | It relies heavily on the [Algebird](https://github.com/twitter/algebird) library especially for Aggregators.
15 |
16 | # Aggregators
17 |
18 | Aggregators enable creation of reusable and composable aggregation functions. Most Machine Learning loss functions and metrics can be
19 | decomposed into a single aggregator. This becomes useful when a model produces a set of predictions and one or more metrics need
20 | to be computed on this collection.
21 |
22 | Below is an example for a binary classification task. Algebird's MultiAggregator can be used to combine multiple metrics into a
23 | single callable aggregator.
24 |
25 | ```scala
26 | val multiAggregator =
27 | MultiAggregator(AUC(ROC), AUC(PR), ClassificationReport(), BinaryConfusionMatrix())
28 | .andThenPresent{case (roc, pr, report, cm) =>
29 | (roc, pr, report.accuracy, report.recall, report.precision, cm(1, 1), cm(0, 0))
30 | }
31 |
32 | val predictions = List(Prediction(false, 0.1), Prediction(false, 0.6), Prediction(true, 0.9))
33 |
34 | println(multiAggregator(predictions))
35 | ```
36 |
37 | ## Prediction Object
38 |
39 | Most aggregators take a single parameterized class called `Prediction` as input. However, the type parameters of
40 | the prediction object differ based on the aggregator. In the above example each binary classifier takes a prediction
41 | of type `Prediction[Boolean, Double]` where the first type is the label and the second is the predicted score.
42 |
43 | Other aggregators take slightly different types, such as the Error Rate Aggregator, which expects `Prediction[Int, List[Double]]`,
44 | where the types are the label and a list of scores.
45 |
46 | ## Available Aggregators
47 |
48 | See the docs on each aggregator for a more detailed walk-through on the functionality and the return objects.
49 |
50 | 1. ConfusionMatrix
51 | 1. Includes a special BinaryConfusionMatrix case to make composition easier with the other binary classification metrics.
52 | 2. AUC
53 | 1. Supports both ROC and PR
54 | 3. ClassificationReport
55 | 1. Returns a list of summary metrics for a binary classification problem.
56 | 4. LogLoss
57 | 1. Available for multiclass. Returns the total log loss for the predictions.
58 | 5. ErrorRateSummary
59 | 1. Available for multiclass. Returns the proportion of misclassified predictions.
60 |
61 | # Tensorflow Model Analysis Support
62 |
63 | Noether supports outputting metrics as TFX `metrics_for_slice` protobufs, which can be used in
64 | TFMA methods. This is available in the `noether-tfx` package:
65 |
66 | ```scala
67 | libraryDependencies += "com.spotify" %% "noether-tfx" % noetherVersion
68 | ```
69 |
70 | ```scala
71 | import com.spotify.noether.tfx._
72 |
73 | val data = List(
74 | (0, 0),
75 | (0, 1),
76 | (0, 0),
77 | (1, 0),
78 | (1, 1),
79 | (1, 1),
80 | (1, 1)
81 | ).map { case (s, pred) => Prediction(pred, s) }
82 |
83 | val tfmaProto = ConfusionMatrix(Seq(0, 1)).asTfmaProto(data)
84 | ```
85 |
86 | # License
87 |
88 | Copyright 2016-2018 Spotify AB.
89 |
90 | Licensed under the Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0
91 |
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/AUC.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import breeze.linalg._
21 | import com.twitter.algebird.{Aggregator, Semigroup}
22 |
/**
 * Intermediate aggregate of [[Curve]]: an array of confusion matrices, each stored as a map from
 * an (Int, Int) cell to its count.
 */
case class MetricCurve(cm: Array[Map[(Int, Int), Long]]) extends Serializable

/** The points that make up a metric curve. */
case class MetricCurvePoints(points: Array[MetricCurvePoint]) extends Serializable

/** A single (x, y) point on a metric curve. */
case class MetricCurvePoint(x: Double, y: Double) extends Serializable
28 |
/** Trapezoid-rule integration over the points of a metric curve. */
private[noether] object AreaUnderCurve {

  /**
   * Area of the trapezoid spanned by the first and the last point of `points`.
   *
   * Only the head and the last element are read; `points` must not be empty.
   */
  def trapezoid(points: Seq[MetricCurvePoint]): Double = {
    val (left, right) = (points.head, points.last)
    val width = right.x - left.x
    val heightSum = right.y + left.y
    width * heightSum / 2.0
  }

  /** Sum of the trapezoid areas between each pair of consecutive curve points. */
  def of(curve: MetricCurvePoints): Double = {
    // Incomplete trailing windows are discarded, so fewer than two points give an area of 0.0.
    val segments = curve.points.iterator.sliding(2).withPartial(false)
    segments.foldLeft(0.0) { (area, segment) => area + trapezoid(segment) }
  }
}
43 |
/**
 * Selects which function is applied to the list of confusion matrices before the AUC is
 * calculated.
 */
sealed trait AUCMetric

/** Receiver Operating Characteristic curve. */
case object ROC extends AUCMetric

/** Precision-Recall curve. */
case object PR extends AUCMetric
59 |
/**
 * Compute the points of a metric curve for a collection of predictions.
 *
 * A linspace over [0.0, 1.0] with [[samples]] entries is created. Every entry is used as the
 * threshold of a [[ClassificationReport]], which yields one confusion matrix per threshold. Each
 * matrix is then turned into the (x, y) location of one point on the curve.
 *
 * The given [[AUCMetric]] selects which scores of the report become the coordinates.
 *
 * @param metric
 *   Which function to apply on the confusion matrix.
 * @param samples
 *   Number of samples to use for the curve definition.
 */
case class Curve(metric: AUCMetric, samples: Int = 100)
    extends Aggregator[Prediction[Boolean, Double], MetricCurve, MetricCurvePoints] {
  private lazy val thresholds = linspace(0.0, 1.0, samples)
  private lazy val aggregators =
    thresholds.data.map(t => ClassificationReport(t)).toArray

  /** Build one confusion matrix per threshold for a single prediction. */
  def prepare(input: Prediction[Boolean, Double]): MetricCurve = {
    val matrices = aggregators.map(aggregator => aggregator.prepare(input))
    MetricCurve(matrices)
  }

  /** Combine two curves by adding their confusion matrices position by position. */
  def semigroup: Semigroup[MetricCurve] = {
    val matrixSemigroup = ClassificationReport().semigroup
    Semigroup.from { (left, right) =>
      val merged = left.cm.zip(right.cm).map { case (a, b) => matrixSemigroup.plus(a, b) }
      MetricCurve(merged)
    }
  }

  /** Turn the accumulated confusion matrices into the points of the selected curve. */
  def present(c: MetricCurve): MetricCurvePoints = {
    val reporter = ClassificationReport()

    def toPoint(scores: Report): MetricCurvePoint = metric match {
      case ROC => MetricCurvePoint(scores.fpr, scores.recall)
      case PR  => MetricCurvePoint(scores.recall, scores.precision)
    }

    // The matrices are stored by ascending threshold; the curve is reported in reverse order.
    val reversed = c.cm.map(matrix => toPoint(reporter.present(matrix))).reverse

    metric match {
      // ROC gets the closing point (1.0, 1.0) appended.
      case ROC => MetricCurvePoints(reversed :+ MetricCurvePoint(1.0, 1.0))
      // PR gets the starting point (0.0, 1.0) prepended.
      case PR => MetricCurvePoints(MetricCurvePoint(0.0, 1.0) +: reversed)
    }
  }
}
108 |
/**
 * Compute the "Area Under the Curve" for a collection of predictions, using the Trapezoid method.
 *
 * The curve itself is produced by a [[Curve]] aggregator built from the same arguments: a linspace
 * with [[samples]] entries supplies the thresholds, each threshold yields a confusion matrix, and
 * the area is computed from the resulting curve points.
 *
 * The given [[AUCMetric]] selects the function to apply on the confusion matrix prior to the AUC
 * calculation.
 *
 * @param metric
 *   Which function to apply on the confusion matrix.
 * @param samples
 *   Number of samples to use for the curve definition.
 */
case class AUC(metric: AUCMetric, samples: Int = 100)
    extends Aggregator[Prediction[Boolean, Double], MetricCurve, Double] {
  // Preparation and merging are delegated to the underlying curve aggregator.
  private val curve = Curve(metric, samples)

  def prepare(input: Prediction[Boolean, Double]): MetricCurve =
    curve.prepare(input)

  def semigroup: Semigroup[MetricCurve] =
    curve.semigroup

  /** Present the curve points and integrate them. */
  def present(c: MetricCurve): Double = {
    val curvePoints = curve.present(c)
    AreaUnderCurve.of(curvePoints)
  }
}
--------------------------------------------------------------------------------
/core/src/main/scala/com/spotify/noether/ClassificationReport.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether
19 |
20 | import com.twitter.algebird.{Aggregator, Semigroup}
21 |
/**
 * Summary metrics of a classification.
 *
 * @param mcc
 *   Matthews Correlation Coefficient
 * @param fscore
 *   F-score
 * @param precision
 *   Precision
 * @param recall
 *   Recall
 * @param accuracy
 *   Accuracy
 * @param fpr
 *   False Positive Rate
 */
final case class Report(
  mcc: Double,
  fscore: Double,
  precision: Double,
  recall: Double,
  accuracy: Double,
  fpr: Double
)
46 |
/**
 * Generate a Classification Report for a collection of binary predictions. The output of this
 * aggregator will be a [[Report]] object.
 *
 * Each prediction is binarized (score strictly greater than `threshold` becomes class 1, anything
 * else class 0) and handed to a [[MultiClassificationReport]] over the labels 0 and 1. The report
 * of class 1 is the result.
 *
 * @param threshold
 *   Threshold to apply to get the predictions.
 * @param beta
 *   Beta parameter used in the f-score calculation.
 */
final case class ClassificationReport(threshold: Double = 0.5, beta: Double = 1.0)
    extends Aggregator[Prediction[Boolean, Double], Map[(Int, Int), Long], Report] {
  // `beta` has to be forwarded explicitly; otherwise the delegate falls back to its own default
  // and the f-score ignores the value given to this aggregator.
  private val aggregator = MultiClassificationReport(Seq(0, 1), beta)

  /** Convert the boolean label and the thresholded score to 0/1 and delegate. */
  def prepare(input: Prediction[Boolean, Double]): Map[(Int, Int), Long] = {
    val predicted = Prediction(
      if (input.actual) 1 else 0,
      if (input.predicted > threshold) 1 else 0
    )
    aggregator.prepare(predicted)
  }

  def semigroup: Semigroup[Map[(Int, Int), Long]] = aggregator.semigroup

  /** Report for the positive class (label 1). */
  def present(m: Map[(Int, Int), Long]): Report = aggregator.present(m)(1)
}
72 |
/**
 * Generate a Classification Report for a collection of multiclass predictions. Every class gets
 * its own report, computed by viewing the predictions as binary: "class" versus "not class". The
 * output of this aggregator is a map from class to its [[Report]].
 *
 * @param labels
 *   List of possible label values.
 * @param beta
 *   Beta parameter used in the f-score calculation.
 */
final case class MultiClassificationReport(labels: Seq[Int], beta: Double = 1.0)
    extends Aggregator[Prediction[Int, Int], Map[(Int, Int), Long], Map[Int, Report]] {
  private val aggregator = ConfusionMatrix(labels)

  override def prepare(input: Prediction[Int, Int]): Map[(Int, Int), Long] =
    aggregator.prepare(input)

  override def semigroup: Semigroup[Map[(Int, Int), Long]] =
    aggregator.semigroup

  override def present(m: Map[(Int, Int), Long]): Map[Int, Report] = {
    val betaSqr = Math.pow(beta, 2.0)

    // Builds the one-vs-rest report of a single class from the confusion matrix.
    def reportFor(clazz: Int): Report = {
      // The first key component is read as the predicted class, the second as the actual class.
      val (tpCount, fpCount, fnCount, tnCount) =
        m.foldLeft((0L, 0L, 0L, 0L)) { case ((tpAcc, fpAcc, fnAcc, tnAcc), ((p, a), count)) =>
          (p == clazz, a == clazz) match {
            case (true, true)   => (tpAcc + count, fpAcc, fnAcc, tnAcc)
            case (true, false)  => (tpAcc, fpAcc + count, fnAcc, tnAcc)
            case (false, true)  => (tpAcc, fpAcc, fnAcc + count, tnAcc)
            case (false, false) => (tpAcc, fpAcc, fnAcc, tnAcc + count)
          }
        }
      val tp = tpCount.toDouble
      val fp = fpCount.toDouble
      val fn = fnCount.toDouble
      val tn = tnCount.toDouble

      // Every ratio falls back to a fixed value when its denominator is not positive.
      def ratio(numerator: Double, denominator: Double, fallback: Double): Double =
        if (denominator > 0.0) numerator / denominator else fallback

      val mccDenom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
      val mcc = ratio((tp * tn) - (fp * fn), mccDenom, 0.0)
      val precision = ratio(tp, tp + fp, 1.0)
      val recall = ratio(tp, tp + fn, 1.0)
      val accuracy = ratio(tp + tn, tp + fn + tn + fp, 0.0)
      val fpr = ratio(fp, fp + tn, 0.0)

      // NOTE(review): the f-score falls back to 1.0 when its denominator is not positive, which
      // includes the case where precision and recall are both 0 - confirm this is intended.
      val fScoreDenom = (betaSqr * precision) + recall
      val fscore =
        if (fScoreDenom > 0.0) (1 + betaSqr) * ((precision * recall) / fScoreDenom)
        else 1.0

      Report(mcc, fscore, precision, recall, accuracy, fpr)
    }

    labels.foldLeft(Map.empty[Int, Report]) { (reports, clazz) =>
      reports + (clazz -> reportFor(clazz))
    }
  }

}
146 |
--------------------------------------------------------------------------------
/tfx/src/test/scala/com/spotify/noether/tfx/TfmaConverterTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether.tfx
19 |
20 | import com.spotify.noether._
21 | import org.scalactic.{Equality, TolerantNumerics}
22 | import org.scalatest.Assertion
23 | import tensorflow_model_analysis.MetricsForSliceOuterClass.MetricsForSlice
24 | import tensorflow_model_analysis.MetricsForSliceOuterClass.CalibrationHistogramBuckets
25 | import scala.jdk.CollectionConverters._
26 | import org.scalatest.flatspec.AnyFlatSpec
27 | import org.scalatest.matchers.should.Matchers
28 |
/**
 * Checks that the aggregators converted with `asTfmaProto` expose the expected values in the
 * resulting TFMA metrics and plots protobufs.
 */
class TfmaConverterTest extends AnyFlatSpec with Matchers {
  // Default tolerance for Double comparisons; some tests below shadow it with a wider one.
  implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.001)

  "TfmaConverter" should "work with ConfusionMatrix" in {
    val data = List(
      (0, 0),
      (0, 1),
      (0, 0),
      (1, 0),
      (1, 1),
      (1, 1),
      (1, 1)
    ).map { case (s, pred) => Prediction(pred, s) }

    val evalResult = ConfusionMatrix(Seq(0, 1)).asTfmaProto(data)
    val cmProto = evalResult.metrics.get

    val cm = cmProto.getMetricsMap
      .get("Noether_ConfusionMatrix")
      .getConfusionMatrixAtThresholds
      .getMatrices(0)

    assert(cm.getFalseNegatives.toLong === 1L)
    assert(cm.getFalsePositives.toLong === 1L)
    assert(cm.getTrueNegatives.toLong === 2L)
    assert(cm.getTruePositives.toLong === 3L)

    // The same matrix must also be available as plot data.
    val plotCm = evalResult.plots.get.plotData.getPlotData.getConfusionMatrixAtThresholds
      .getMatrices(0)

    assert(plotCm.getFalseNegatives.toLong === 1L)
    assert(plotCm.getFalsePositives.toLong === 1L)
    assert(plotCm.getTrueNegatives.toLong === 2L)
    assert(plotCm.getTruePositives.toLong === 3L)
  }

  it should "work with BinaryConfusionMatrix" in {
    val data = List(
      (false, 0.1),
      (false, 0.6),
      (false, 0.2),
      (true, 0.2),
      (true, 0.8),
      (true, 0.7),
      (true, 0.6)
    ).map { case (pred, s) => Prediction(pred, s) }

    val evalResult = BinaryConfusionMatrix().asTfmaProto(data)
    val cmProto = evalResult.metrics.get

    val cm = cmProto.getMetricsMap
      .get("Noether_ConfusionMatrix")
      .getConfusionMatrixAtThresholds
      .getMatrices(0)

    assert(cm.getThreshold === 0.5)
    assert(cm.getTruePositives.toLong === 3L)
    assert(cm.getFalseNegatives.toLong === 1L)
    assert(cm.getFalsePositives.toLong === 1L)
    assert(cm.getTrueNegatives.toLong === 2L)

    // The same matrix must also be available as plot data.
    val plotCm = evalResult.plots.get.plotData.getPlotData.getConfusionMatrixAtThresholds
      .getMatrices(0)

    assert(plotCm.getFalseNegatives.toLong === 1L)
    assert(plotCm.getFalsePositives.toLong === 1L)
    assert(plotCm.getTrueNegatives.toLong === 2L)
    assert(plotCm.getTruePositives.toLong === 3L)
  }

  it should "work with ClassificationReport" in {
    val data = List(
      (0.1, false),
      (0.1, true),
      (0.4, false),
      (0.6, false),
      (0.6, true),
      (0.6, true),
      (0.8, true)
    ).map { case (s, pred) => Prediction(pred, s) }

    val metrics = ClassificationReport().asTfmaProto(data).metrics

    def assertMetric(name: String, expected: Double): Assertion =
      assert(metrics.get.getMetricsMap.get(name).getDoubleValue.getValue === expected)

    assertMetric("Noether_Accuracy", 0.7142857142857143)
    assertMetric("Noether_FPR", 0.333)
    assertMetric("Noether_FScore", 0.75)
    assertMetric("Noether_MCC", 0.4166666666666667)
    assertMetric("Noether_Precision", 0.75)
    assertMetric("Noether_Recall", 0.75)
  }

  it should "work with ErrorRateSummary" in {
    val classes = 10
    // One-hot score vector with the 1.0 at position `idx`.
    def s(idx: Int): List[Double] = 0.until(classes).map(i => if (i == idx) 1.0 else 0.0).toList

    val data =
      List((s(1), 1), (s(3), 1), (s(5), 5), (s(2), 3), (s(0), 0), (s(8), 1)).map {
        case (scores, label) => Prediction(label, scores)
      }

    val ersProto: MetricsForSlice = ErrorRateSummary.asTfmaProto(data).metrics.get

    val ersV =
      ersProto.getMetricsMap.get("Noether_ErrorRateSummary").getDoubleValue.getValue
    assert(ersV === 0.5)
  }

  it should "work with AUC" in {
    val data = List(
      (false, 0.1),
      (false, 0.6),
      (false, 0.2),
      (true, 0.2),
      (true, 0.8),
      (true, 0.7),
      (true, 0.6)
    ).map { case (pred, s) => Prediction(pred, s) }

    val aucROCProto = AUC(ROC).asTfmaProto(data).metrics.get
    val aucPRProto = AUC(PR).asTfmaProto(data).metrics.get

    val actualROC = aucROCProto.getMetricsMap.get("Noether_AUC:ROC").getDoubleValue.getValue
    val actualPR = aucPRProto.getMetricsMap.get("Noether_AUC:PR").getDoubleValue.getValue

    assert(actualROC === 0.833)
    assert(actualPR === 0.896)
  }

  it should "work with LogLoss" in {
    val classes = 10

    // Score vector that is `score` at position `idx` and 0.0 elsewhere.
    def s(idx: Int, score: Double): List[Double] =
      0.until(classes).map(i => if (i == idx) score else 0.0).toList

    val data = List((s(0, 0.8), 0), (s(1, 0.6), 1), (s(2, 0.7), 2)).map { case (scores, label) =>
      Prediction(label, scores)
    }

    val logLossProto: MetricsForSlice = LogLoss.asTfmaProto(data).metrics.get

    val logLoss = logLossProto.getMetricsMap.get("Noether_LogLoss").getDoubleValue.getValue
    assert(logLoss === 0.363548039673)
  }

  it should "work with MeanAveragePrecision" in {
    import RankingData._
    val proto = MeanAveragePrecision[Int]().asTfmaProto(rankingData).metrics.get
    val meanAvgPrecision = proto.getMetricsMap
      .get("Noether_MeanAvgPrecision")
      .getDoubleValue
      .getValue
    assert(meanAvgPrecision === 0.355026)
  }

  it should "work with NdcgAtK" in {
    import RankingData._
    // Wider tolerance than the class-level default.
    implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

    def getNdcgAtK(v: Int): Double =
      NdcgAtK[Int](v)
        .asTfmaProto(rankingData)
        .metrics
        .get
        .getMetricsMap
        .get("Noether_NdcgAtK")
        .getDoubleValue
        .getValue

    assert(getNdcgAtK(3) === 1.0 / 3)
    assert(getNdcgAtK(5) === 0.328788)
    assert(getNdcgAtK(10) === 0.487913)
    assert(getNdcgAtK(15) === getNdcgAtK(10))
  }

  it should "work with PrecisionAtK" in {
    import RankingData._
    // Wider tolerance than the class-level default.
    implicit val doubleEq: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.1)

    def getPrecisionAtK(v: Int): Double =
      PrecisionAtK[Int](v)
        .asTfmaProto(rankingData)
        .metrics
        .get
        .getMetricsMap
        .get("Noether_PrecisionAtK")
        .getDoubleValue
        .getValue

    assert(getPrecisionAtK(1) === 1.0 / 3)
    assert(getPrecisionAtK(2) === 1.0 / 3)
    assert(getPrecisionAtK(3) === 1.0 / 3)
    assert(getPrecisionAtK(4) === 0.75 / 3)
    assert(getPrecisionAtK(5) === 0.8 / 3)
    assert(getPrecisionAtK(10) === 0.8 / 3)
    assert(getPrecisionAtK(15) === 8.0 / 45)
  }

  it should "work with CalibrationHistogram" in {
    // (prediction, label) pairs; the comment names the bucket each pair is expected to land in.
    val data = Seq(
      (0.15, 1.15), // lb
      (0.288, 1.288), // rounding error puts this in (0.249, 0.288)
      (0.30, 1.30), // (0.288, 0.3269)
      (0.36, 1.36), // (0.3269, 0.365)
      (0.555, 1.555), // (0.5219, 0.5609)
      (1.2, 2.2), // ub
      (0.7, 1.7) // ub
    ).map { case (p, a) => Prediction(a, p) }

    val result = CalibrationHistogram(0.21, 0.60, 10).asTfmaProto(data)

    def protoToCaseClass(p: CalibrationHistogramBuckets.Bucket): CalibrationHistogramBucket = {
      CalibrationHistogramBucket(
        p.getLowerThresholdInclusive,
        p.getUpperThresholdExclusive,
        p.getNumWeightedExamples.getValue,
        p.getTotalWeightedLabel.getValue,
        p.getTotalWeightedRefinedPrediction.getValue
      )
    }

    val actual =
      result.plots.get.plotData.getPlotData.getCalibrationHistogramBuckets.getBucketsList.asScala
        .map(protoToCaseClass)

    // Bucket width is (0.60 - 0.21) / 10 = 0.039.
    val expected = List(
      CalibrationHistogramBucket(Double.NegativeInfinity, 0.21, 1.0, 1.15, 0.15),
      CalibrationHistogramBucket(0.21, 0.249, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.249, 0.288, 1.0, 1.288, 0.288),
      CalibrationHistogramBucket(0.288, 0.327, 1.0, 1.30, 0.30),
      CalibrationHistogramBucket(0.327, 0.366, 1.0, 1.36, 0.36),
      CalibrationHistogramBucket(0.366, 0.405, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.405, 0.444, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.444, 0.483, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.483, 0.522, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.522, 0.561, 1.0, 1.555, 0.555),
      CalibrationHistogramBucket(0.561, 0.6, 0.0, 0.0, 0.0),
      CalibrationHistogramBucket(0.6, Double.PositiveInfinity, 2.0, 3.9, 1.9)
    )

    assert(actual.length == expected.length)
    (0 until expected.length).foreach { i =>
      assert(actual(i).numPredictions === expected(i).numPredictions)
      assert(actual(i).sumPredictions === expected(i).sumPredictions)
      assert(actual(i).sumLabels === expected(i).sumLabels)
      assert(actual(i).lowerThresholdInclusive === expected(i).lowerThresholdInclusive)
      assert(actual(i).upperThresholdExclusive === expected(i).upperThresholdExclusive)
    }
  }
}
281 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/tfx/src/main/scala/com/spotify/noether/tfx/TfmaImplicits.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2018 Spotify AB.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing,
11 | * software distributed under the License is distributed on an
12 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | * KIND, either express or implied. See the License for the
14 | * specific language governing permissions and limitations
15 | * under the License.
16 | */
17 |
18 | package com.spotify.noether.tfx
19 |
20 | import breeze.linalg.DenseMatrix
21 | import com.google.protobuf.DoubleValue
22 | import com.spotify.noether._
23 | import com.spotify.noether.tfx.Tfma.ConversionOps
24 | import com.twitter.algebird.Aggregator
25 | import tensorflow_model_analysis.MetricsForSliceOuterClass.ConfusionMatrixAtThresholds._
26 | import tensorflow_model_analysis.MetricsForSliceOuterClass._
27 |
28 | import scala.collection.JavaConverters._
29 |
30 | /**
31 |  * Implicit syntax and [[TfmaConverter]] instances that turn Noether aggregators into aggregators
32 |  * presenting TensorFlow Model Analysis protobufs wrapped in an [[EvalResult]].
33 |  *
34 |  * Every message is built with the default slice key and every metric name starts with `Noether_`.
35 |  */
36 | trait TfmaImplicits {
37 |
38 |   /** Wraps `cm` in a `MetricsForSlice` under the key `Noether_ConfusionMatrix`. */
39 |   private def confusionMatrixToMetric(cm: ConfusionMatrixAtThresholds): MetricsForSlice = {
40 |     val metricValue = MetricValue.newBuilder().setConfusionMatrixAtThresholds(cm).build()
41 |     MetricsForSlice
42 |       .newBuilder()
43 |       .setSliceKey(SliceKey.getDefaultInstance)
44 |       .putMetrics("Noether_ConfusionMatrix", metricValue)
45 |       .build()
46 |   }
47 |
48 |   /**
49 |    * Converts a dense count matrix into a single-entry `ConfusionMatrixAtThresholds`.
50 |    *
51 |    * Only four cells are read: (1, 1) as true positives, (0, 0) as true negatives, (1, 0) as false
52 |    * positives and (0, 1) as false negatives. Precision and recall are plain floating point ratios,
53 |    * so `0.0 / 0.0` is reported as `NaN` rather than raising an error.
54 |    *
55 |    * @param threshold decision threshold recorded on the result; left unset when `None`
56 |    * @param matrix count matrix; presumably at least 2x2 (not checked here)
57 |    */
58 |   private def denseMatrixToConfusionMatrix(
59 |     threshold: Option[Double] = None
60 |   )(matrix: DenseMatrix[Long]): ConfusionMatrixAtThresholds = {
61 |     val truePositives = matrix.valueAt(1, 1).toDouble
62 |     val trueNegatives = matrix.valueAt(0, 0).toDouble
63 |     val falsePositives = matrix.valueAt(1, 0).toDouble
64 |     val falseNegatives = matrix.valueAt(0, 1).toDouble
65 |
66 |     val atThreshold = ConfusionMatrixAtThreshold
67 |       .newBuilder()
68 |       .setTruePositives(truePositives)
69 |       .setTrueNegatives(trueNegatives)
70 |       .setFalsePositives(falsePositives)
71 |       .setFalseNegatives(falseNegatives)
72 |       .setPrecision(truePositives / (truePositives + falsePositives))
73 |       .setRecall(truePositives / (truePositives + falseNegatives))
74 |     threshold.foreach(atThreshold.setThreshold)
75 |
76 |     ConfusionMatrixAtThresholds.newBuilder().addMatrices(atThreshold.build()).build()
77 |   }
78 |
79 |   /** Builds a `MetricsForSlice` holding the single double metric `name`. */
80 |   private def buildDoubleMetric(name: String, value: Double): MetricsForSlice = {
81 |     val metricValue = MetricValue.newBuilder().setDoubleValue(mkDoubleValue(value)).build()
82 |     MetricsForSlice
83 |       .newBuilder()
84 |       .setSliceKey(SliceKey.getDefaultInstance)
85 |       .putMetrics(name, metricValue)
86 |       .build()
87 |   }
88 |
89 |   /** Builds a `MetricsForSlice` with one double metric per entry of `metrics`. */
90 |   private def buildDoubleMetrics(metrics: Map[String, Double]): MetricsForSlice = {
91 |     val metricValues = metrics.map { case (name, value) =>
92 |       name -> MetricValue.newBuilder().setDoubleValue(mkDoubleValue(value)).build()
93 |     }
94 |
95 |     MetricsForSlice
96 |       .newBuilder()
97 |       .setSliceKey(SliceKey.getDefaultInstance)
98 |       .putAllMetrics(metricValues.asJava)
99 |       .build()
100 |   }
101 |
102 |   /** Wraps `cm` as the plot data of a `PlotsForSlice`. */
103 |   private def buildConfusionMatrixPlot(cm: ConfusionMatrixAtThresholds): PlotsForSlice = {
104 |     val plotData = PlotData.newBuilder().setConfusionMatrixAtThresholds(cm)
105 |     PlotsForSlice
106 |       .newBuilder()
107 |       .setSliceKey(SliceKey.getDefaultInstance)
108 |       .setPlotData(plotData)
109 |       .build()
110 |   }
111 |
112 |   /** Presents `cm` both as the confusion matrix metric and as a confusion matrix plot. */
113 |   private def confusionMatrixEvalResult(cm: ConfusionMatrixAtThresholds): EvalResult =
114 |     EvalResult(confusionMatrixToMetric(cm), Plot.ConfusionMatrix(buildConfusionMatrixPlot(cm)))
115 |
116 |   /** Wraps `d` in a protobuf `DoubleValue` message. */
117 |   private def mkDoubleValue(d: Double): DoubleValue =
118 |     DoubleValue.newBuilder().setValue(d).build()
119 |
120 |   /**
121 |    * Converts calibration histogram buckets into the plot data of a `PlotsForSlice`.
122 |    *
123 |    * Threshold bounds are copied as is; `sumPredictions`, `sumLabels` and `numPredictions` fill
124 |    * the total weighted refined prediction, total weighted label and number of weighted examples.
125 |    */
126 |   private def buildCalibrationHistogramPlot(ch: List[CalibrationHistogramBucket]): PlotsForSlice = {
127 |     val buckets = ch.map { bucket =>
128 |       CalibrationHistogramBuckets.Bucket
129 |         .newBuilder()
130 |         .setLowerThresholdInclusive(bucket.lowerThresholdInclusive)
131 |         .setUpperThresholdExclusive(bucket.upperThresholdExclusive)
132 |         .setTotalWeightedRefinedPrediction(mkDoubleValue(bucket.sumPredictions))
133 |         .setTotalWeightedLabel(mkDoubleValue(bucket.sumLabels))
134 |         .setNumWeightedExamples(mkDoubleValue(bucket.numPredictions))
135 |         .build()
136 |     }
137 |     val histogram = CalibrationHistogramBuckets.newBuilder().addAllBuckets(buckets.asJava)
138 |     val plotData = PlotData.newBuilder().setCalibrationHistogramBuckets(histogram).build()
139 |
140 |     PlotsForSlice
141 |       .newBuilder()
142 |       .setSliceKey(SliceKey.getDefaultInstance)
143 |       .setPlotData(plotData)
144 |       .build()
145 |   }
146 |
147 |   // Syntax enrichments: each implicit conversion below pairs an aggregator with the implicit
148 |   // converter found for it and wraps the two in a `ConversionOps`.
149 |
150 |   implicit def confusionMatrixConversion(agg: ConfusionMatrix)(implicit
151 |     c: TfmaConverter[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix]
152 |   ): ConversionOps[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix] =
153 |     ConversionOps[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix](agg, c)
154 |
155 |   implicit def calibrationHistogramConversion(agg: CalibrationHistogram)(implicit
156 |     c: TfmaConverter[
157 |       Prediction[Double, Double],
158 |       Map[Double, (Double, Double, Long)],
159 |       CalibrationHistogram
160 |     ]
161 |   ): ConversionOps[
162 |     Prediction[Double, Double],
163 |     Map[Double, (Double, Double, Long)],
164 |     CalibrationHistogram
165 |   ] =
166 |     ConversionOps[
167 |       Prediction[Double, Double],
168 |       Map[Double, (Double, Double, Long)],
169 |       CalibrationHistogram
170 |     ](agg, c)
171 |
172 |   implicit def binaryConfusionMatrixConversion(agg: BinaryConfusionMatrix)(implicit
173 |     c: TfmaConverter[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix]
174 |   ): ConversionOps[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix] =
175 |     ConversionOps[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix](agg, c)
176 |
177 |   implicit def classificationReportConversion(agg: ClassificationReport)(implicit
178 |     c: TfmaConverter[BinaryPred, Map[(Int, Int), Long], ClassificationReport]
179 |   ): ConversionOps[BinaryPred, Map[(Int, Int), Long], ClassificationReport] =
180 |     ConversionOps[BinaryPred, Map[(Int, Int), Long], ClassificationReport](agg, c)
181 |
182 |   implicit def aucConversion(agg: AUC)(implicit
183 |     c: TfmaConverter[BinaryPred, MetricCurve, AUC]
184 |   ): ConversionOps[BinaryPred, MetricCurve, AUC] =
185 |     ConversionOps[BinaryPred, MetricCurve, AUC](agg, c)
186 |
187 |   implicit def errorRateSummaryConversion(agg: ErrorRateSummary.type)(implicit
188 |     c: TfmaConverter[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type]
189 |   ): ConversionOps[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type] =
190 |     ConversionOps[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type](agg, c)
191 |
192 |   implicit def logLossConversion(agg: LogLoss.type)(implicit
193 |     c: TfmaConverter[Prediction[Int, List[Double]], (Double, Long), LogLoss.type]
194 |   ): ConversionOps[Prediction[Int, List[Double]], (Double, Long), LogLoss.type] =
195 |     ConversionOps[Prediction[Int, List[Double]], (Double, Long), LogLoss.type](agg, c)
196 |
197 |   implicit def meanAvgPrecisionConversion[T](agg: MeanAveragePrecision[T])(implicit
198 |     c: TfmaConverter[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]]
199 |   ): ConversionOps[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]] =
200 |     ConversionOps[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]](agg, c)
201 |
202 |   implicit def ndcgAtKConversion[T](agg: NdcgAtK[T])(implicit
203 |     c: TfmaConverter[RankingPrediction[T], (Double, Long), NdcgAtK[T]]
204 |   ): ConversionOps[RankingPrediction[T], (Double, Long), NdcgAtK[T]] =
205 |     ConversionOps[RankingPrediction[T], (Double, Long), NdcgAtK[T]](agg, c)
206 |
207 |   implicit def precisionAtKConversion[T](agg: PrecisionAtK[T])(implicit
208 |     c: TfmaConverter[RankingPrediction[T], (Double, Long), PrecisionAtK[T]]
209 |   ): ConversionOps[RankingPrediction[T], (Double, Long), PrecisionAtK[T]] =
210 |     ConversionOps[RankingPrediction[T], (Double, Long), PrecisionAtK[T]](agg, c)
211 |
212 |   /** Emits the error rate as the double metric `Noether_ErrorRateSummary`. */
213 |   implicit val errorRateSummaryConverter
214 |     : TfmaConverter[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type] =
215 |     new TfmaConverter[Prediction[Int, List[Double]], (Double, Long), ErrorRateSummary.type] {
216 |       override def convertToTfmaProto(
217 |         underlying: ErrorRateSummary.type
218 |       ): Aggregator[Prediction[Int, List[Double]], (Double, Long), EvalResult] =
219 |         // Go through `underlying` (not the singleton by name), as the sibling converters do.
220 |         underlying.andThenPresent { errorRate =>
221 |           EvalResult(buildDoubleMetric("Noether_ErrorRateSummary", errorRate))
222 |         }
223 |     }
224 |
225 |   /** Emits the confusion matrix at the aggregator's threshold as both a metric and a plot. */
226 |   implicit val binaryConfusionMatrixConverter
227 |     : TfmaConverter[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix] =
228 |     new TfmaConverter[BinaryPred, Map[(Int, Int), Long], BinaryConfusionMatrix] {
229 |       override def convertToTfmaProto(
230 |         underlying: BinaryConfusionMatrix
231 |       ): Aggregator[BinaryPred, Map[(Int, Int), Long], EvalResult] =
232 |         underlying.andThenPresent { matrix =>
233 |           // The aggregator's own threshold is recorded on the resulting matrix.
234 |           val cm = denseMatrixToConfusionMatrix(Some(underlying.threshold))(matrix)
235 |           confusionMatrixEvalResult(cm)
236 |         }
237 |     }
238 |
239 |   /** Emits the confusion matrix, with no threshold set, as both a metric and a plot. */
240 |   implicit val confusionMatrixConverter
241 |     : TfmaConverter[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix] =
242 |     new TfmaConverter[Prediction[Int, Int], Map[(Int, Int), Long], ConfusionMatrix] {
243 |       override def convertToTfmaProto(
244 |         underlying: ConfusionMatrix
245 |       ): Aggregator[Prediction[Int, Int], Map[(Int, Int), Long], EvalResult] =
246 |         underlying.andThenPresent { matrix =>
247 |           confusionMatrixEvalResult(denseMatrixToConfusionMatrix()(matrix))
248 |         }
249 |     }
250 |
251 |   /** Emits accuracy, FPR, F-score, MCC, precision and recall as six double metrics. */
252 |   implicit val classificationReportConverter
253 |     : TfmaConverter[BinaryPred, Map[(Int, Int), Long], ClassificationReport] =
254 |     new TfmaConverter[BinaryPred, Map[(Int, Int), Long], ClassificationReport] {
255 |       override def convertToTfmaProto(
256 |         underlying: ClassificationReport
257 |       ): Aggregator[BinaryPred, Map[(Int, Int), Long], EvalResult] =
258 |         underlying.andThenPresent { report =>
259 |           val allMetrics = Map(
260 |             "Noether_Accuracy" -> report.accuracy,
261 |             "Noether_FPR" -> report.fpr,
262 |             "Noether_FScore" -> report.fscore,
263 |             "Noether_MCC" -> report.mcc,
264 |             "Noether_Precision" -> report.precision,
265 |             "Noether_Recall" -> report.recall
266 |           )
267 |           EvalResult(buildDoubleMetrics(allMetrics))
268 |         }
269 |     }
270 |
271 |   /** Emits the area under the curve, named after the curve type (`ROC` or `PR`). */
272 |   implicit val aucConverter: TfmaConverter[BinaryPred, MetricCurve, AUC] =
273 |     new TfmaConverter[BinaryPred, MetricCurve, AUC] {
274 |       override def convertToTfmaProto(
275 |         underlying: AUC
276 |       ): Aggregator[BinaryPred, MetricCurve, EvalResult] =
277 |         underlying.andThenPresent { area =>
278 |           val metricName = underlying.metric match {
279 |             case ROC => "Noether_AUC:ROC"
280 |             case PR  => "Noether_AUC:PR"
281 |           }
282 |           EvalResult(buildDoubleMetric(metricName, area))
283 |         }
284 |     }
285 |
286 |   /** Emits the log loss as the double metric `Noether_LogLoss`. */
287 |   implicit val logLossConverter
288 |     : TfmaConverter[Prediction[Int, List[Double]], (Double, Long), LogLoss.type] =
289 |     new TfmaConverter[Prediction[Int, List[Double]], (Double, Long), LogLoss.type] {
290 |       override def convertToTfmaProto(
291 |         underlying: LogLoss.type
292 |       ): Aggregator[Prediction[Int, List[Double]], (Double, Long), EvalResult] =
293 |         underlying.andThenPresent { logLoss =>
294 |           EvalResult(buildDoubleMetric("Noether_LogLoss", logLoss))
295 |         }
296 |     }
297 |
298 |   /** Emits the mean average precision as the double metric `Noether_MeanAvgPrecision`. */
299 |   implicit def meanAvgPrecisionConverter[T]
300 |     : TfmaConverter[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]] =
301 |     new TfmaConverter[RankingPrediction[T], (Double, Long), MeanAveragePrecision[T]] {
302 |       override def convertToTfmaProto(
303 |         underlying: MeanAveragePrecision[T]
304 |       ): Aggregator[RankingPrediction[T], (Double, Long), EvalResult] =
305 |         underlying.andThenPresent { meanAvgPrecision =>
306 |           EvalResult(buildDoubleMetric("Noether_MeanAvgPrecision", meanAvgPrecision))
307 |         }
308 |     }
309 |
310 |   /** Emits the NDCG at k as the double metric `Noether_NdcgAtK`. */
311 |   implicit def ndcgAtKConverter[T]
312 |     : TfmaConverter[RankingPrediction[T], (Double, Long), NdcgAtK[T]] =
313 |     new TfmaConverter[RankingPrediction[T], (Double, Long), NdcgAtK[T]] {
314 |       override def convertToTfmaProto(
315 |         underlying: NdcgAtK[T]
316 |       ): Aggregator[RankingPrediction[T], (Double, Long), EvalResult] =
317 |         underlying.andThenPresent { ndcg =>
318 |           EvalResult(buildDoubleMetric("Noether_NdcgAtK", ndcg))
319 |         }
320 |     }
321 |
322 |   /** Emits the precision at k as the double metric `Noether_PrecisionAtK`. */
323 |   implicit def precisionAtK[T]
324 |     : TfmaConverter[RankingPrediction[T], (Double, Long), PrecisionAtK[T]] =
325 |     new TfmaConverter[RankingPrediction[T], (Double, Long), PrecisionAtK[T]] {
326 |       override def convertToTfmaProto(
327 |         underlying: PrecisionAtK[T]
328 |       ): Aggregator[RankingPrediction[T], (Double, Long), EvalResult] =
329 |         // `precision` rather than `precisionAtK`, so the lambda does not shadow this method.
330 |         underlying.andThenPresent { precision =>
331 |           EvalResult(buildDoubleMetric("Noether_PrecisionAtK", precision))
332 |         }
333 |     }
334 |
335 |   /** Emits the calibration histogram as a plot; no metrics are produced. */
336 |   implicit def calibrationHistogram: TfmaConverter[
337 |     Prediction[Double, Double],
338 |     Map[Double, (Double, Double, Long)],
339 |     CalibrationHistogram
340 |   ] =
341 |     new TfmaConverter[
342 |       Prediction[Double, Double],
343 |       Map[Double, (Double, Double, Long)],
344 |       CalibrationHistogram
345 |     ] {
346 |       override def convertToTfmaProto(
347 |         underlying: CalibrationHistogram
348 |       ): Aggregator[Prediction[Double, Double], Map[Double, (Double, Double, Long)], EvalResult] =
349 |         underlying.andThenPresent { buckets =>
350 |           EvalResult(Plot.CalibrationHistogram(buildCalibrationHistogramPlot(buckets)))
351 |         }
352 |     }
353 | }
354 |
--------------------------------------------------------------------------------