├── project
│   ├── build.properties
│   └── plugins.sbt
├── .idea
│   ├── codeStyles
│   │   ├── codeStyleConfig.xml
│   │   └── Project.xml
│   ├── vcs.xml
│   ├── scala_compiler.xml
│   ├── misc.xml
│   ├── hydra.xml
│   ├── scala_settings.xml
│   ├── modules.xml
│   └── sbt.xml
├── CONTRIBUTORS.md
├── LICENSE
├── docs
│   ├── macos-java8.md
│   └── guide.md
├── src
│   ├── main
│   │   ├── scala
│   │   │   └── com
│   │   │       └── github
│   │   │           └── cleanzr
│   │   │               └── dblink
│   │   │                   ├── Parameters.scala
│   │   │                   ├── util
│   │   │                   │   ├── HardPartitioner.scala
│   │   │                   │   ├── PathToFileConverter.scala
│   │   │                   │   ├── BufferedFileWriter.scala
│   │   │                   │   ├── BufferedRDDWriter.scala
│   │   │                   │   ├── PeriodicRDDCheckpointer.scala
│   │   │                   │   └── PeriodicCheckpointer.scala
│   │   │                   ├── analysis
│   │   │                   │   ├── BinaryClassificationMetrics.scala
│   │   │                   │   ├── baselines.scala
│   │   │                   │   ├── PairwiseMetrics.scala
│   │   │                   │   ├── BinaryConfusionMatrix.scala
│   │   │                   │   ├── ClusteringContingencyTable.scala
│   │   │                   │   ├── ClusteringMetrics.scala
│   │   │                   │   └── package.scala
│   │   │                   ├── partitioning
│   │   │                   │   ├── PartitionFunction.scala
│   │   │                   │   ├── SimplePartitioner.scala
│   │   │                   │   ├── LPTScheduler.scala
│   │   │                   │   ├── DomainSplitter.scala
│   │   │                   │   ├── MutableBST.scala
│   │   │                   │   └── KDTreePartitioner.scala
│   │   │                   ├── DistortionProbs.scala
│   │   │                   ├── Run.scala
│   │   │                   ├── Logging.scala
│   │   │                   ├── accumulators
│   │   │                   │   ├── MapLongAccumulator.scala
│   │   │                   │   ├── MapDoubleAccumulator.scala
│   │   │                   │   └── MapArrayAccumulator.scala
│   │   │                   ├── random
│   │   │                   │   ├── DiscreteDist.scala
│   │   │                   │   ├── IndexNonUniformDiscreteDist.scala
│   │   │                   │   ├── NonUniformDiscreteDist.scala
│   │   │                   │   └── AliasSampler.scala
│   │   │                   ├── SummaryAccumulators.scala
│   │   │                   ├── ProjectSteps.scala
│   │   │                   ├── DiagnosticsWriter.scala
│   │   │                   ├── SimilarityFn.scala
│   │   │                   ├── CustomKryoRegistrator.scala
│   │   │                   ├── package.scala
│   │   │                   ├── RecordsCache.scala
│   │   │                   ├── Sampler.scala
│   │   │                   ├── ProjectStep.scala
│   │   │                   ├── LinkageChain.scala
│   │   │                   ├── AttributeIndex.scala
│   │   │                   └── Project.scala
│   │   └── resources
│   │       └── log4j.properties
│   └── test
│       └── scala
│           ├── Launch.scala
│           └── com
│               └── github
│                   └── cleanzr
│                       └── dblink
│                           ├── EntityInvertedIndexTest.scala
│                           ├── AttributeTest.scala
│                           ├── random
│                           │   ├── DiscreteDistTest.scala
│                           │   ├── DiscreteDistBehavior.scala
│                           │   └── AliasSamplerTest.scala
│                           ├── DistortionProbsTest.scala
│                           ├── AttributeIndexBehaviors.scala
│                           ├── SimilarityFnTest.scala
│                           └── AttributeIndexTest.scala
├── .gitignore
├── examples
│   ├── RLdata500.conf
│   └── RLdata10000.conf
└── README.md
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.2.8
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
2 |
3 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1")
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/scala_compiler.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/codeStyles/Project.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | # dblink contributors
2 |
3 | The following people (sorted alphabetically) have contributed to coding, documentation and testing efforts.
4 |
5 | * [Andee Kaplan](http://andeekaplan.com/)
6 | * Neil G. Marchant (partially supported by Australian Bureau of Statistics)
7 | * [Rebecca C. Steorts](https://resteorts.github.io/)
8 |
--------------------------------------------------------------------------------
/.idea/hydra.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/.idea/scala_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/.idea/sbt.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
16 |
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | dblink: Empirical Bayes Entity Resolution for Spark
2 |
3 | Copyright (C) 2018-2019 dblink contributors
4 | Copyright (C) 2018 Australian Bureau of Statistics
5 |
6 | This program is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | (at your option) any later version.
10 |
11 | This program is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
 17 | along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
--------------------------------------------------------------------------------
/docs/macos-java8.md:
--------------------------------------------------------------------------------
1 | # Installing Java 8+ on macOS
2 |
3 | We recommend using the [AdoptOpenJDK](https://adoptopenjdk.net) Java
4 | distribution on macOS.
5 |
6 | It can be installed using the [Homebrew](https://brew.sh/) package manager.
7 | Simply run the following commands in a terminal:
8 | ```bash
9 | $ brew tap AdoptOpenJDK/openjdk
10 | $ brew cask install adoptopenjdk8
11 | ```
12 |
13 | Note that it's possible to have multiple versions of Java installed in
14 | parallel. If you run
15 | ```bash
16 | $ java -version
17 | ```
18 | and don't see a version number like 1.8.x or 8.x, then you'll need to
19 | manually select Java 8.
20 |
21 | To do this temporarily in a terminal, run
 22 | ```bash
23 | $ export JAVA_HOME=$(/usr/libexec/java_home -v1.8)
24 | ```
25 | All references to Java within this session will then make use of Java 8.
26 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/Parameters.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | case class Parameters(maxClusterSize: Int) {
23 | require(maxClusterSize > 0, "`maxClusterSize` must be a positive integer.")
24 |
25 | def mkString: String = {
26 | "maxClusterSize: " + maxClusterSize
27 | }
28 | }
--------------------------------------------------------------------------------
/src/test/scala/Launch.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | import com.github.cleanzr.dblink.Run
21 | import org.apache.spark.{SparkConf, SparkContext}
22 |
23 | object Launch {
24 | def main(args: Array[String]) {
25 | val conf = new SparkConf().setMaster("local[2]").setAppName("dblink")
26 | val sc = SparkContext.getOrCreate(conf)
27 | Run.main(args)
28 | }
29 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/util/HardPartitioner.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.util
21 |
22 | import org.apache.spark.Partitioner
23 |
24 | class HardPartitioner(override val numPartitions: Int) extends Partitioner {
25 | override def getPartition(key: Any): Int = {
26 | val k = key.asInstanceOf[Int]
27 | k % numPartitions
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/util/PathToFileConverter.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.util
21 |
22 | import java.io.File
23 |
24 | import org.apache.hadoop.fs.{FileSystem, Path}
25 | import org.apache.hadoop.conf.Configuration
26 |
27 | object PathToFileConverter {
28 |
29 | def fileToPath(path: Path, conf: Configuration): File = {
30 | val fs = FileSystem.get(path.toUri, conf)
31 | val tempFile = File.createTempFile(path.getName, "")
32 | tempFile.deleteOnExit()
33 | fs.copyToLocalFile(path, new Path(tempFile.getAbsolutePath))
34 | tempFile
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/EntityInvertedIndexTest.scala:
--------------------------------------------------------------------------------
1 | package com.github.cleanzr.dblink
2 |
3 | import com.github.cleanzr.dblink.GibbsUpdates.EntityInvertedIndex
4 | import org.scalatest._
5 |
6 | class EntityInvertedIndexTest extends FlatSpec {
7 |
8 | def entityIndex = new EntityInvertedIndex
9 |
10 | val allEntityIds = Set(1, 2, 3, 4)
11 |
12 | lazy val entitiesWithOneDup = Seq(
13 | (1, Entity(Array(1, 0, 0))),
14 | (2, Entity(Array(2, 0, 1))),
15 | (3, Entity(Array(3, 0, 1))),
16 | (4, Entity(Array(4, 1, 0))),
17 | (4, Entity(Array(4, 1, 0))) // same entity twice
18 | )
19 |
20 | behavior of "An entity inverted index (containing one entity)"
21 |
22 | it should "return singleton sets for the entity's attribute values" in {
23 | val index = new EntityInvertedIndex
24 | index.add(entitiesWithOneDup.head._1, entitiesWithOneDup.head._2)
25 | assert(entitiesWithOneDup.head._2.values.zipWithIndex.forall {case (valueId, attrId) => index.getEntityIds(attrId, valueId).size === 1})
26 | }
27 |
28 | behavior of "An entity inverted index (containing multiple entities)"
29 |
30 | it should "return the correct responses to queries" in {
31 | val index = new EntityInvertedIndex
32 | entitiesWithOneDup.foreach { case (entId, entity) => index.add(entId, entity) }
33 | assert(index.getEntityIds(0, 1) === Set(1))
34 | assert(index.getEntityIds(0, 4) === Set(4))
35 | assert(index.getEntityIds(1, 0) === Set(1, 2, 3))
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/AttributeTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import org.scalatest.FlatSpec
23 | import com.github.cleanzr.dblink.SimilarityFn._
24 |
25 | class AttributeTest extends FlatSpec {
26 | behavior of "An attribute with a constant similarity function"
27 |
28 | it should "be constant" in {
29 | assert(Attribute("name", ConstantSimilarityFn, BetaShapeParameters(1.0, 1.0)).isConstant === true)
30 | }
31 |
32 | behavior of "An attribute with a non-constant similarity function"
33 |
34 | it should "not be constant" in {
35 | assert(Attribute("name", LevenshteinSimilarityFn(), BetaShapeParameters(1.0, 1.0)).isConstant === false)
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
3 |
4 | # User-specific stuff:
5 | .idea/**/workspace.xml
6 | .idea/**/tasks.xml
7 | .idea/dictionaries
8 |
9 | # Sensitive or high-churn files:
10 | .idea/**/dataSources/
11 | .idea/**/dataSources.ids
12 | .idea/**/dataSources.xml
13 | .idea/**/dataSources.local.xml
14 | .idea/**/sqlDataSources.xml
15 | .idea/**/dynamic.xml
16 | .idea/**/uiDesigner.xml
17 |
18 | # Gradle:
19 | .idea/**/gradle.xml
20 | .idea/**/libraries
21 |
22 | # CMake
23 | cmake-build-debug/
24 | cmake-build-release/
25 |
26 | # Mongo Explorer plugin:
27 | .idea/**/mongoSettings.xml
28 |
29 | ## File-based project format:
30 | *.iws
31 |
32 | ## Plugin-specific files:
33 |
34 | # IntelliJ
35 | out/
36 |
37 | # mpeltonen/sbt-idea plugin
38 | .idea_modules/
39 |
40 | # JIRA plugin
41 | atlassian-ide-plugin.xml
42 |
43 | # Cursive Clojure plugin
44 | .idea/replstate.xml
45 |
46 | # Crashlytics plugin (for Android Studio and IntelliJ)
47 | com_crashlytics_export_strings.xml
48 | crashlytics.properties
49 | crashlytics-build.properties
50 | fabric.properties
51 |
52 |
53 | # Simple Build Tool
54 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control
55 |
56 | dist/*
57 | target/
58 | lib_managed/
59 | src_managed/
60 | project/boot/
61 | project/plugins/project/
62 | .history
63 | .cache
64 | .lib/
65 |
66 | # Scala
67 | *.class
68 | *.log
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/random/DiscreteDistTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}
23 | import org.scalatest.FlatSpec
24 |
25 | class DiscreteDistTest extends FlatSpec with DiscreteDistBehavior {
26 |
27 | implicit val rand: RandomGenerator = new MersenneTwister()
28 |
29 | def valuesWeights = Map("A" -> 100.0, "B" -> 200.0, "C" -> 700.0)
30 | val indexDist = DiscreteDist(valuesWeights.values)
31 | val dist = DiscreteDist(valuesWeights)
32 |
33 | "A discrete distribution (without values given)" should behave like genericDiscreteDist(indexDist, 5)
34 |
35 | "A discrete distribution (with values given)" should behave like genericDiscreteDist(dist, "D")
36 | }
37 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/analysis/BinaryClassificationMetrics.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.analysis
21 |
22 | object BinaryClassificationMetrics {
23 | def precision(binaryConfusionMatrix: BinaryConfusionMatrix): Double = {
24 | binaryConfusionMatrix.TP.toDouble / binaryConfusionMatrix.PP
25 | }
26 |
27 | def recall(binaryConfusionMatrix: BinaryConfusionMatrix): Double = {
28 | binaryConfusionMatrix.TP.toDouble / binaryConfusionMatrix.P
29 | }
30 |
31 | def fMeasure(binaryConfusionMatrix: BinaryConfusionMatrix,
32 | beta: Double = 1.0): Double = {
33 | val betaSq = beta * beta
34 | val pr = precision(binaryConfusionMatrix)
35 | val re = recall(binaryConfusionMatrix)
36 | (1 + betaSq) * pr * re / (betaSq * pr + re)
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/partitioning/PartitionFunction.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.partitioning
21 |
22 | import org.apache.spark.rdd.RDD
23 |
24 | abstract class PartitionFunction[T] extends Serializable {
25 |
26 | /** Number of partitions */
27 | def numPartitions: Int
28 |
29 | /** Fit partition function based on a sample of records
30 | *
31 | * @param records RDD of records. Assumes no values are missing.
32 | */
33 | def fit(records: RDD[Array[T]]): Unit
34 |
35 | /** Get assigned partition for a set of attribute values
36 | *
37 | * @param attributeValues array of (entity) attribute values. Assumes no values are missing.
 38 |     * @return identifier of the assigned partition: an integer in {0, ..., numPartitions - 1}.
39 | */
40 | def getPartitionId(attributeValues: Array[T]): Int
41 |
42 | def mkString: String
43 | }
--------------------------------------------------------------------------------
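The contract above is easiest to see with a concrete instance. Below is a minimal sketch (not part of dblink; `HashPartitionFunction` is a hypothetical name) of a stateless implementation: `fit` does nothing and `getPartitionId` hashes the first attribute value into {0, ..., numPartitions - 1}.

```scala
import com.github.cleanzr.dblink.partitioning.PartitionFunction
import org.apache.spark.rdd.RDD

// Hypothetical sketch: a stateless PartitionFunction that ignores `fit`
// and assigns partitions by hashing the first attribute value.
class HashPartitionFunction[T](override val numPartitions: Int) extends PartitionFunction[T] {

  override def fit(records: RDD[Array[T]]): Unit = () // nothing to learn

  override def getPartitionId(attributeValues: Array[T]): Int = {
    val h = attributeValues(0).hashCode() % numPartitions
    if (h < 0) h + numPartitions else h // keep the id in {0, ..., numPartitions - 1}
  }

  override def mkString: String = s"HashPartitionFunction(numPartitions=$numPartitions)"
}
```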
/src/test/scala/com/github/cleanzr/dblink/DistortionProbsTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import org.scalatest.FlatSpec
23 |
24 | class DistortionProbsTest extends FlatSpec {
25 |
26 | val twoFilesOneAttribute = DistortionProbs(Seq("A", "B"), Iterator(BetaShapeParameters(3.0, 3.0)))
27 |
28 | behavior of "Distortion probabilities for two files and one attribute (with identical shape parameters)"
29 |
30 | it should "return probability 0.5 for both files" in {
31 | assert(twoFilesOneAttribute(0, "A") === 0.5)
32 | assert(twoFilesOneAttribute(0, "B") === 0.5)
33 | }
34 |
35 | it should "complain when a probability is requested for a third file" in {
36 | assertThrows[NoSuchElementException] {
37 | twoFilesOneAttribute(0, "C")
38 | }
39 | }
40 |
41 | it should "complain when a probability is requested for a second attribute" in {
42 | assertThrows[NoSuchElementException] {
43 | twoFilesOneAttribute(1, "A")
44 | }
45 | }
46 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/DistortionProbs.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | /** Container for the distortion probabilities (which vary for each file
23 | * and attribute).
24 | *
25 | * @param probs map used internally to store the probabilities
26 | */
27 | case class DistortionProbs(probs: Map[(AttributeId, FileId), Double]) {
28 | def apply(attrId: AttributeId, fileId: FileId): Double = probs.apply((attrId, fileId))
29 | }
30 |
31 | object DistortionProbs {
32 | /** Generate distortion probabilities based on the prior mean */
33 | def apply(fileIds: Traversable[FileId],
34 | distortionPrior: Iterator[BetaShapeParameters]): DistortionProbs = {
35 | require(fileIds.nonEmpty, "`fileIds` cannot be empty")
36 | require(distortionPrior.nonEmpty, "`distortionPrior` cannot be empty")
37 |
38 | val probs = distortionPrior.zipWithIndex.flatMap { case (BetaShapeParameters(alpha, beta), attrId) =>
39 | fileIds.map { fileId => ((attrId, fileId), alpha / (alpha + beta)) }
40 | }.toMap
41 |
42 | DistortionProbs(probs)
43 | }
44 | }
--------------------------------------------------------------------------------
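As a quick check of the factory above (mirroring `DistortionProbsTest` earlier in this dump), each attribute's prior `BetaShapeParameters(alpha, beta)` is mapped to the prior mean `alpha / (alpha + beta)` for every file. The file identifiers below are made up for illustration.

```scala
import com.github.cleanzr.dblink.{BetaShapeParameters, DistortionProbs}

// Two files, two attributes: attribute 0 has prior mean 2/(2+8) = 0.2,
// attribute 1 has prior mean 5/(5+5) = 0.5, identical across files.
val probs = DistortionProbs(
  fileIds = Seq("fileA", "fileB"),
  distortionPrior = Iterator(BetaShapeParameters(2.0, 8.0), BetaShapeParameters(5.0, 5.0))
)

probs(0, "fileA") // 0.2
probs(1, "fileB") // 0.5
```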
/src/main/scala/com/github/cleanzr/dblink/Run.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.util.{BufferedFileWriter, PathToFileConverter}
23 | import com.typesafe.config.ConfigFactory
24 | import org.apache.spark.sql.SparkSession
25 | import org.apache.hadoop.fs.Path
26 |
27 | object Run extends App with Logging {
28 |
29 | val spark = SparkSession.builder().appName("dblink").getOrCreate()
30 | val sc = spark.sparkContext
31 | sc.setLogLevel("WARN")
32 |
33 | val configFile = PathToFileConverter.fileToPath(new Path(args.head), sc.hadoopConfiguration)
34 |
35 | val config = ConfigFactory.parseFile(configFile).resolve()
36 |
37 | val project = Project(config)
38 | val writer = BufferedFileWriter(project.outputPath + "run.txt", append = false, sparkContext = project.sparkContext)
39 | writer.write(project.mkString)
40 |
41 | val steps = ProjectSteps(config, project)
42 | writer.write("\n" + steps.mkString)
43 | writer.close()
44 |
45 | sc.setCheckpointDir(project.checkpointPath)
46 |
47 | steps.execute()
48 |
49 | sc.stop()
50 | }
51 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/AttributeIndexBehaviors.scala:
--------------------------------------------------------------------------------
1 | package com.github.cleanzr.dblink
2 |
3 | import org.scalatest.{FlatSpec, Matchers}
4 |
5 | trait AttributeIndexBehaviors extends Matchers { this: FlatSpec =>
6 |
7 | def genericAttributeIndex(index: AttributeIndex, valuesWeights: Map[String, Double]) {
8 | it should "have the correct number of values" in {
9 | assert(index.numValues === valuesWeights.size)
10 | }
11 |
12 | it should "assign unique value ids in the set {0, ..., numValues - 1}" in {
13 | assert(valuesWeights.keysIterator.map(stringValue => index.valueIdxOf(stringValue)).toSet ===
14 | (0 until valuesWeights.size).toSet)
15 | }
16 |
17 | it should "have the correct probability distribution" in {
18 | val totalWeight = valuesWeights.foldLeft(0.0)((total, x) => total + x._2)
19 | assert(valuesWeights.mapValues(_/totalWeight).forall { case (stringValue, correctProb) =>
20 | index.probabilityOf(index.valueIdxOf(stringValue)) === (correctProb +- 1e-4)
21 | })
22 | }
23 |
24 | it should "complain when a probability is requested for a non-indexed value" in {
25 | assertThrows[RuntimeException] {
26 | index.probabilityOf(index.numValues + 1)
27 | }
28 | }
29 |
30 | it should "complain when a similarity normalization is requested for a non-indexed value" in {
31 | assertThrows[RuntimeException] {
32 | index.simNormalizationOf(index.numValues + 1)
33 | }
34 | }
35 |
36 | it should "complain when similar values are requested for a non-indexed value" in {
37 | assertThrows[RuntimeException] {
38 | index.simValuesOf(index.numValues + 1)
39 | }
40 | }
41 |
42 | it should "complain when an exponentiated similarity score is requested for a non-indexed value" in {
43 | assertThrows[RuntimeException] {
44 | index.expSimOf(index.numValues + 1, 0)
45 | }
46 | assertThrows[RuntimeException] {
47 | index.expSimOf(0, index.numValues + 1)
48 | }
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/Logging.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import org.apache.log4j.Logger
23 | import org.apache.log4j.Level
24 |
25 | trait Logging {
26 | private[this] lazy val logger = Logger.getLogger(logName)
27 |
28 | private[this] lazy val logName: String = this.getClass.getName.stripSuffix("$")
29 |
30 | def info(message: => String): Unit = if (logger.isEnabledFor(Level.INFO)) logger.info(message)
31 |
32 | def warn(message: => String): Unit = if (logger.isEnabledFor(Level.WARN)) logger.warn(message)
33 |
34 | def warn(message: => String, t: Throwable): Unit = if (logger.isEnabledFor(Level.WARN)) logger.warn(message, t)
35 |
36 | def error(message: => String): Unit = if (logger.isEnabledFor(Level.ERROR)) logger.error(message)
37 |
38 | def error(message: => String, t: Throwable): Unit = if (logger.isEnabledFor(Level.ERROR)) logger.error(message, t)
39 |
40 | def fatal(message: => String): Unit = if (logger.isEnabledFor(Level.FATAL)) logger.fatal(message)
41 |
42 | def fatal(message: => String, t: Throwable): Unit = if (logger.isEnabledFor(Level.FATAL)) logger.fatal(message, t)
43 |
44 | def debug(message: => String): Unit = if (logger.isEnabledFor(Level.DEBUG)) logger.debug(message)
45 | }
46 |
--------------------------------------------------------------------------------
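A minimal sketch of how the trait is meant to be mixed in (the class name is hypothetical). Because `message` is a by-name parameter, the interpolated string is only built when the corresponding log level is enabled.

```scala
import com.github.cleanzr.dblink.Logging

// Hypothetical example of mixing in the Logging trait.
class PartitionReport extends Logging {
  def report(partitionId: Int, numRecords: Long): Unit = {
    info(s"Partition $partitionId contains $numRecords records") // built only if INFO is enabled
    if (numRecords == 0) warn(s"Partition $partitionId is empty")
  }
}
```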
/src/main/scala/com/github/cleanzr/dblink/accumulators/MapLongAccumulator.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.accumulators
21 |
22 | import org.apache.spark.util.AccumulatorV2
23 | import scala.collection.mutable
24 |
25 | /**
26 | * Accumulates counts corresponding to keys.
 27 |  * e.g. if we add (K1 -> 1L) to the accumulator and (K1 -> 10L) is the current
 28 |  * key-value pair, the resulting key-value pair will be (K1 -> 11L).
29 | *
30 | * @tparam K key type
31 | */
32 | class MapLongAccumulator[K] extends AccumulatorV2[(K, Long), Map[K, Long]] {
33 | private val _map = mutable.HashMap.empty[K, Long]
34 |
35 | override def reset(): Unit = _map.clear()
36 |
37 | override def add(kv: (K, Long)): Unit = {
 38 |     _map.update(kv._1, _map.getOrElse(kv._1, 0L) + kv._2)
39 | }
40 |
41 | override def value: Map[K, Long] = _map.toMap
42 |
43 | override def isZero: Boolean = _map.isEmpty
44 |
45 | override def copy(): MapLongAccumulator[K] = {
46 | val newAcc = new MapLongAccumulator[K]
47 | newAcc._map ++= _map
48 | newAcc
49 | }
50 |
51 | def toIterator: Iterator[(K, Long)] = _map.iterator
52 |
53 | override def merge(other: AccumulatorV2[(K, Long), Map[K, Long]]): Unit = other match {
54 | case o: MapLongAccumulator[K] => o.toIterator.foreach { x => this.add(x) }
55 | case _ =>
56 | throw new UnsupportedOperationException(
57 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
58 | }
59 | }
--------------------------------------------------------------------------------
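A usage sketch for the accumulator above, assuming an existing `SparkContext` named `sc`: like any `AccumulatorV2`, it must be registered with the context before it is updated from inside an action.

```scala
import com.github.cleanzr.dblink.accumulators.MapLongAccumulator

// Assumes `sc: SparkContext` already exists (e.g. from a SparkSession).
val counts = new MapLongAccumulator[String]
sc.register(counts, "key counts")

// Count occurrences of each key from within an action.
sc.parallelize(Seq("a", "b", "a", "c")).foreach(key => counts.add((key, 1L)))

counts.value // Map("a" -> 2, "b" -> 1, "c" -> 1)
```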
/src/main/scala/com/github/cleanzr/dblink/analysis/baselines.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.analysis
21 |
22 | import com.github.cleanzr.dblink.{Cluster, RecordId}
23 | import org.apache.spark.rdd.RDD
24 |
25 | object baselines {
26 | def exactMatchClusters(records: RDD[(RecordId, Seq[String])]): RDD[Cluster] = {
27 | records.map(row => (row._2.mkString, row._1)) // (values, id)
28 | .aggregateByKey(Set.empty[RecordId]) (
29 | seqOp = (recIds, recId) => recIds + recId,
30 | combOp = (recIdsA, recIdsB) => recIdsA ++ recIdsB
31 | ).map(_._2)
32 | }
33 |
34 | /** Generates (overlapping) clusters that are near matches based on attribute agreements/disagreements
35 | *
36 | * @param records
37 | * @param numDisagree
38 | * @return
39 | */
40 | def nearClusters(records: RDD[(RecordId, Seq[String])], numDisagree: Int): RDD[Cluster] = {
 41 |     require(numDisagree >= 0, "`numDisagree` must be non-negative")
42 |
43 | records.flatMap { row =>
44 | val numAttr = row._2.length
45 | val attrIds = 0 until numAttr
46 | attrIds.combinations(numDisagree).map { delIds =>
47 | val partialValues = row._2.zipWithIndex.collect { case (value, attrId) if !delIds.contains(attrId) => value }
48 | (partialValues.mkString, row._1)
49 | }
50 | }.aggregateByKey(Set.empty[RecordId])(
51 | seqOp = (recIds, recId) => recIds + recId,
52 | combOp = (recIdsA, recIdsB) => recIdsA ++ recIdsB
53 | ).map(_._2)
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
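To make the `nearClusters` keying concrete: with 3 attributes and `numDisagree = 1`, every record emits one partial key per way of deleting a single attribute, so two records that agree on any 2 of the 3 attributes share at least one key and end up in a common cluster. The snippet below reproduces just the key construction ("|" is used as a separator here for readability; the dblink code concatenates the values directly).

```scala
// Partial keys for one record with 3 attributes and numDisagree = 1.
val values = Seq("john", "smith", "1970")

val partialKeys = (0 until values.length).combinations(1).map { delIds =>
  values.zipWithIndex.collect { case (v, i) if !delIds.contains(i) => v }.mkString("|")
}.toList
// List("smith|1970", "john|1970", "john|smith")
```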
/src/main/scala/com/github/cleanzr/dblink/accumulators/MapDoubleAccumulator.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.accumulators
21 |
22 | import org.apache.spark.util.AccumulatorV2
23 |
24 | import scala.collection.mutable
25 |
26 | /**
 27 |  * Accumulates double-valued sums corresponding to keys.
 28 |  * e.g. if we add (K1 -> 1.0) to the accumulator and (K1 -> 10.0) is the current
 29 |  * key-value pair, the resulting key-value pair will be (K1 -> 11.0).
30 | *
31 | * @tparam K key type
32 | */
33 | class MapDoubleAccumulator[K] extends AccumulatorV2[(K, Double), Map[K, Double]] {
34 | private val _map = mutable.HashMap.empty[K, Double]
35 |
36 | override def reset(): Unit = _map.clear()
37 |
38 | override def add(kv: (K, Double)): Unit = {
39 | _map.update(kv._1, _map.getOrElse(kv._1, 0.0) + kv._2)
40 | }
41 |
42 | override def value: Map[K, Double] = _map.toMap
43 |
44 | override def isZero: Boolean = _map.isEmpty
45 |
46 | override def copy(): MapDoubleAccumulator[K] = {
47 | val newAcc = new MapDoubleAccumulator[K]
48 | newAcc._map ++= _map
49 | newAcc
50 | }
51 |
52 | def toIterator: Iterator[(K, Double)] = _map.iterator
53 |
54 | override def merge(other: AccumulatorV2[(K, Double), Map[K, Double]]): Unit = other match {
55 | case o: MapDoubleAccumulator[K] => o.toIterator.foreach { x => this.add(x) }
56 | case _ =>
57 | throw new UnsupportedOperationException(
58 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
59 | }
60 | }
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/random/DiscreteDistBehavior.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import org.apache.commons.math3.random.RandomGenerator
23 | import org.scalatest.{FlatSpec, Matchers}
24 |
25 | trait DiscreteDistBehavior extends Matchers { this: FlatSpec =>
26 |
27 | def genericDiscreteDist[T](dist: DiscreteDist[T], valueOutsideSupport: T)(implicit rand: RandomGenerator): Unit = {
28 | it should "be normalised" in {
29 | assert(dist.values.foldLeft(0.0) { case (sum, v) => sum + dist.probabilityOf(v) } === (1.0 +- 1e-9) )
30 | }
31 |
32 | it should "return valid probabilities for all values in the support" in {
33 | assert(dist.values.forall {value =>
34 | val prob = dist.probabilityOf(value)
35 | prob >= 0 && !prob.isInfinity && !prob.isNaN && prob <= 1})
36 | }
37 |
38 | it should "return a probability of 0.0 for a value outside the support" in {
39 | assert(dist.probabilityOf(valueOutsideSupport) === 0.0)
40 | }
41 |
42 | it should "return `numValues` equal to the size of `values`" in {
43 | assert(dist.numValues === dist.values.size)
44 | }
45 |
46 | it should "return a valid `totalWeight`" in {
47 | assert(dist.totalWeight > 0)
48 | assert(!dist.totalWeight.isNaN && !dist.totalWeight.isInfinity)
49 | }
50 |
51 | it should "not return sample values that have probability 0.0" in {
52 | assert((1 to 1000).map(_ => dist.sample()).forall(v => dist.probabilityOf(v) > 0.0))
53 | }
54 | }
55 |
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/random/DiscreteDist.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import org.apache.commons.math3.random.RandomGenerator
23 |
24 | import scala.collection.Map
25 | import scala.reflect.ClassTag
26 |
27 | /** A distribution over a discrete set of values
28 | *
29 | * @tparam T type of the values
30 | */
31 | trait DiscreteDist[T] extends Serializable {
32 |
33 | /** Values in the support set */
34 | def values: Traversable[T]
35 |
36 | /** Number of values in the support set */
37 | def numValues: Int
38 |
39 | /** Total weight before normalization */
40 | def totalWeight: Double
41 |
42 | /** Draw a value according to the distribution
43 | *
44 | * @param rand external RandomGenerator to use for drawing sample
45 | * @return a value from the support set
46 | */
47 | def sample()(implicit rand: RandomGenerator): T
48 |
49 | /** Get the probability mass associated with a value
50 | *
51 | * @param value a value from the support set
52 | * @return probability. Returns 0.0 if the value is not in the support set.
53 | */
54 | def probabilityOf(value: T): Double
55 |
56 | /** Iterator over the values in the support, together with their probability
57 | * mass
58 | */
59 | def toIterator: Iterator[(T, Double)]
60 | }
61 |
62 | object DiscreteDist {
63 | def apply[T](valuesAndWeights: Map[T, Double])
64 | (implicit ev: ClassTag[T]): NonUniformDiscreteDist[T] = {
65 | NonUniformDiscreteDist[T](valuesAndWeights)
66 | }
67 |
68 | def apply(weights: Traversable[Double]): IndexNonUniformDiscreteDist = {
69 | IndexNonUniformDiscreteDist(weights)
70 | }
71 | }
--------------------------------------------------------------------------------
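A usage sketch mirroring `DiscreteDistTest` earlier in this dump: weights need not be normalised, and sampling requires an implicit `RandomGenerator` in scope.

```scala
import com.github.cleanzr.dblink.random.DiscreteDist
import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}

implicit val rand: RandomGenerator = new MersenneTwister(42L)

// Weights 100/200/700 are normalised internally to probabilities 0.1/0.2/0.7.
val dist = DiscreteDist(Map("A" -> 100.0, "B" -> 200.0, "C" -> 700.0))

dist.probabilityOf("C") // 0.7
dist.probabilityOf("D") // 0.0 (outside the support)
dist.sample()           // "A", "B" or "C"
```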
/src/main/scala/com/github/cleanzr/dblink/partitioning/SimplePartitioner.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.partitioning
21 |
22 | import SimplePartitioner._
23 | import org.apache.spark.rdd.RDD
24 | import scala.reflect.ClassTag
25 |
26 | /**
27 | * Block on the values of a single field, then combine to get roughly
28 | * equal-size partitions (based on empirical distribution)
29 | *
30 | * @param attributeId
31 | * @param numPartitions
32 | */
33 | class SimplePartitioner[T : ClassTag : Ordering](val attributeId: Int,
34 | override val numPartitions: Int)
35 | extends PartitionFunction[T] {
36 |
37 | private var _lpt: LPTScheduler[T, Double] = _
38 |
39 | override def fit(records: RDD[Array[T]]): Unit = {
40 | val weights = records.map(values => (values(attributeId), 1.0))
41 | .reduceByKey(_ + _)
42 | .collect()
43 | _lpt = generateLPTScheduler(weights, numPartitions)
44 | }
45 |
46 | override def getPartitionId(values: Array[T]): Int = {
47 | if (_lpt == null) -1
48 | else _lpt.getPartitionId(values(attributeId))
49 | }
50 |
51 | override def mkString: String = s"SimplePartitioner(attributeId=$attributeId, numPartitions=$numPartitions)"
52 | }
53 |
54 | object SimplePartitioner {
55 | private def generateLPTScheduler[T](weights: Array[(T, Double)],
56 | numPartitions: Int): LPTScheduler[T, Double] = {
57 | new LPTScheduler[T, Double](weights, numPartitions)
58 | }
59 |
60 | def apply[T : ClassTag : Ordering](attributeId: Int,
61 | numPartitions: Int): SimplePartitioner[T] = {
62 | new SimplePartitioner(attributeId, numPartitions)
63 | }
64 | }
--------------------------------------------------------------------------------
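A usage sketch for the partitioner above, assuming an existing `SparkContext` named `sc` and toy records whose second attribute is a surname: `fit` learns the empirical value weights, and `getPartitionId` then sends records with the same blocking value to the same partition.

```scala
import com.github.cleanzr.dblink.partitioning.SimplePartitioner

// Assumes `sc: SparkContext` already exists. Block on attribute 1 (surname),
// balancing the blocks across 2 partitions.
val records = sc.parallelize(Seq(
  Array("john", "smith", "1970"),
  Array("jane", "smith", "1972"),
  Array("anna", "jones", "1985")
))

val partitioner = SimplePartitioner[String](attributeId = 1, numPartitions = 2)
partitioner.fit(records)

// Returns the same partition id for any record whose surname is "smith".
partitioner.getPartitionId(Array("mary", "smith", "1990"))
```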
/src/main/scala/com/github/cleanzr/dblink/analysis/PairwiseMetrics.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.analysis
21 |
22 | import com.github.cleanzr.dblink.RecordPair
23 | import org.apache.spark.sql.Dataset
24 |
25 | case class PairwiseMetrics(precision: Double,
26 | recall: Double,
27 | f1score: Double) {
28 | def mkString: String = {
29 | "=====================================\n" +
30 | " Pairwise metrics \n" +
31 | "-------------------------------------\n" +
32 | s" Precision: $precision\n" +
33 | s" Recall: $recall\n" +
34 | s" F1-score: $f1score\n" +
35 | "=====================================\n"
36 | }
37 |
38 | def print(): Unit = {
39 | Console.print(mkString)
40 | }
41 | }
42 |
43 | object PairwiseMetrics {
44 | def LinksConfusionMatrix(predictedLinks: Dataset[RecordPair],
45 | trueLinks: Dataset[RecordPair]): BinaryConfusionMatrix = {
46 | // Create PairRDDs so we can use the fullOuterJoin function
47 | val predictedLinks2 = predictedLinks.rdd.map(pair => (pair, true))
48 | val trueLinks2 = trueLinks.rdd.map(pair => (pair, true))
49 | val joinedLinks = predictedLinks2.fullOuterJoin(trueLinks2)
50 | .map { case (_, link) => (link._1.isDefined, link._2.isDefined)}
51 | BinaryConfusionMatrix(joinedLinks)
52 | }
53 |
54 | def apply(predictedLinks: Dataset[RecordPair],
55 | trueLinks: Dataset[RecordPair]): PairwiseMetrics = {
56 | val confusionMatrix = LinksConfusionMatrix(predictedLinks, trueLinks)
57 |
58 | val precision = BinaryClassificationMetrics.precision(confusionMatrix)
59 | val recall = BinaryClassificationMetrics.recall(confusionMatrix)
60 | val f1score = BinaryClassificationMetrics.fMeasure(confusionMatrix, beta = 1.0)
61 |
62 | PairwiseMetrics(precision, recall, f1score)
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
 18 | # Log everything at INFO level to the console and to a rolling file
19 | log4j.rootCategory=INFO, console, file
20 | log4j.appender.console=org.apache.log4j.ConsoleAppender
21 | log4j.appender.console.target=System.err
22 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
23 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
24 |
25 | # Set up logging to file
26 | log4j.appender.file=org.apache.log4j.RollingFileAppender
27 | log4j.appender.file.File=./dblink.log
28 | log4j.appender.file.ImmediateFlush=true
 29 | ## Append to the existing log file (set to false to overwrite)
30 | log4j.appender.file.Append=true
31 | log4j.appender.file.MaxFileSize=100MB
32 | log4j.appender.file.MaxBackupIndex=10
 33 | ## Define the layout for the file appender
34 | log4j.appender.file.layout=org.apache.log4j.PatternLayout
35 | log4j.appender.file.layout.ConversionPattern=%d{yy-MM-dd HH:mm:ss} %p %c{1}: %m%n
36 |
37 | log4j.logger.com.github.cleanzr.dblink=INFO
38 |
39 | # Set the default spark-shell log level to WARN. When running the spark-shell, the
40 | # log level for this class is used to overwrite the root logger's log level, so that
41 | # the user can have different defaults for the shell and regular Spark apps.
42 | log4j.logger.org.apache.spark.repl.Main=WARN
43 |
44 | # Settings to quiet third party logs that are too verbose
45 | log4j.logger.org.spark_project.jetty=WARN
46 | log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
47 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
48 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
49 | log4j.logger.org.apache.parquet=ERROR
50 | log4j.logger.parquet=ERROR
51 |
52 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
53 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
54 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/accumulators/MapArrayAccumulator.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
 18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.accumulators
21 |
22 | import org.apache.spark.util.AccumulatorV2
23 |
24 | import scala.collection.mutable
25 |
26 | /**
27 | * Accumulates arrays of values corresponding to keys.
28 | * e.g. if we add (K1 -> "C") to the accumulator and (K1 -> ArrayBuffer("A","B"))
29 | * is the current key-value pair, the resulting key value pair will be
 30 |  * (K1 -> ArrayBuffer("A", "B", "C")).
31 | *
32 | * @tparam K key type
33 | * @tparam V value type (for elements of set)
34 | */
35 | class MapArrayAccumulator[K,V] extends AccumulatorV2[(K,V), Map[K, mutable.ArrayBuffer[V]]] {
36 | private val _map = mutable.HashMap.empty[K, mutable.ArrayBuffer[V]]
37 |
38 | override def reset(): Unit = _map.clear()
39 |
40 | override def add(x: (K,V)): Unit = {
41 | val _array = _map.getOrElseUpdate(x._1, mutable.ArrayBuffer.empty[V])
42 | _array += x._2
43 | }
44 |
45 | private def add_ks(x: (K, Traversable[V])): Unit = {
46 | val _array = _map.getOrElseUpdate(x._1, mutable.ArrayBuffer.empty[V])
47 | x._2.foreach { v => _array += v }
48 | }
49 |
50 | override def value: Map[K, mutable.ArrayBuffer[V]] = _map.toMap
51 |
52 | override def isZero: Boolean = _map.isEmpty
53 |
54 | def toIterator: Iterator[(K, mutable.ArrayBuffer[V])] = _map.iterator
55 |
56 | override def copy(): MapArrayAccumulator[K, V] = {
57 | val newAcc = new MapArrayAccumulator[K, V]
58 | _map.foreach( x => newAcc._map.update(x._1, x._2.clone()))
59 | newAcc
60 | }
61 |
62 | override def merge(other: AccumulatorV2[(K, V), Map[K, mutable.ArrayBuffer[V]]]): Unit =
63 | other match {
64 | case o: MapArrayAccumulator[K, V] => o.toIterator.foreach {x => this.add_ks(x)}
65 | case _ =>
66 | throw new UnsupportedOperationException(
67 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
68 | }
69 | }
--------------------------------------------------------------------------------
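The accumulator above follows the standard AccumulatorV2 lifecycle: register it on the driver, add to it inside an action, and read the merged value back on the driver. A minimal usage sketch (the object name, app name and a local SparkContext are illustrative assumptions, not part of the repository):

    import org.apache.spark.{SparkConf, SparkContext}
    import com.github.cleanzr.dblink.accumulators.MapArrayAccumulator

    object MapArrayAccumulatorExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("acc-example").setMaster("local[2]"))

        // Register the accumulator so that per-partition copies are merged back on the driver
        val acc = new MapArrayAccumulator[String, Int]
        sc.register(acc, "values-per-key")

        // Add (key, value) pairs from within a distributed action
        sc.parallelize(Seq("a" -> 1, "a" -> 2, "b" -> 3)).foreach(acc.add)

        // e.g. Map(a -> ArrayBuffer(1, 2), b -> ArrayBuffer(3)); the order of values within a
        // buffer depends on the order in which partitions are merged
        println(acc.value)
        sc.stop()
      }
    }
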
/src/main/scala/com/github/cleanzr/dblink/analysis/BinaryConfusionMatrix.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.analysis
21 |
22 | import org.apache.spark.rdd.RDD
23 | import org.apache.spark.sql.Dataset
24 |
25 | case class BinaryConfusionMatrix(TP: Long,
26 | FP: Long,
27 | FN: Long) {
28 | def P: Long = TP + FN
29 | def PP: Long = TP + FP
30 | }
31 |
32 | object BinaryConfusionMatrix {
33 | def apply(predictionsAndLabels: Seq[(Boolean, Boolean)]): BinaryConfusionMatrix = {
34 | var TP = 0L
35 | var FP = 0L
36 | var FN = 0L
37 | predictionsAndLabels.foreach {case (prediction, label) =>
38 | if (prediction & label) TP += 1L
39 | if (prediction & !label) FP += 1L
40 | if (!prediction & label) FN += 1L
41 | }
42 | BinaryConfusionMatrix(TP, FP, FN)
43 | }
44 |
45 | def apply(predictionsAndLabels: RDD[(Boolean, Boolean)]): BinaryConfusionMatrix = {
46 | val sc = predictionsAndLabels.sparkContext
47 | val accTP = sc.longAccumulator("TP")
48 | val accFP = sc.longAccumulator("FP")
49 | val accFN = sc.longAccumulator("FN")
50 | predictionsAndLabels.foreach {case (prediction, label) =>
51 | if (prediction & label) accTP.add(1L)
52 | if (prediction & !label) accFP.add(1L)
53 | if (!prediction & label) accFN.add(1L)
54 | }
55 | BinaryConfusionMatrix(accTP.value, accFP.value, accFN.value)
56 | }
57 |
58 | def apply(predictionsAndLabels: Dataset[(Boolean, Boolean)]): BinaryConfusionMatrix = {
59 | val spark = predictionsAndLabels.sparkSession
60 | val sc = spark.sparkContext
61 | val accTP = sc.longAccumulator("TP")
62 | val accFP = sc.longAccumulator("FP")
63 | val accFN = sc.longAccumulator("FN")
64 | predictionsAndLabels.rdd.foreach { case (prediction, label) =>
65 | if (prediction & label) accTP.add(1L)
66 | if (prediction & !label) accFP.add(1L)
67 | if (!prediction & label) accFN.add(1L)
68 | }
69 | BinaryConfusionMatrix(accTP.value, accFP.value, accFN.value)
70 | }
71 | }
--------------------------------------------------------------------------------
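Since the class only stores raw counts, precision and recall follow directly from TP, PP = TP + FP and P = TP + FN. A small sketch using the Seq overload above (the derived metrics here are illustrative; the repository computes pairwise metrics elsewhere, e.g. in PairwiseMetrics.scala):

    import com.github.cleanzr.dblink.analysis.BinaryConfusionMatrix

    object ConfusionMatrixExample {
      def main(args: Array[String]): Unit = {
        // (prediction, label) pairs: 2 true positives, 1 false positive, 1 false negative
        val predictionsAndLabels = Seq((true, true), (true, true), (true, false), (false, true))
        val cm = BinaryConfusionMatrix(predictionsAndLabels)

        val precision = cm.TP.toDouble / cm.PP // TP / (TP + FP) = 2/3
        val recall    = cm.TP.toDouble / cm.P  // TP / (TP + FN) = 2/3
        println(s"precision=$precision recall=$recall")
      }
    }
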
/src/main/scala/com/github/cleanzr/dblink/util/BufferedFileWriter.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.util
21 |
22 | import java.io.{BufferedWriter, OutputStreamWriter}
23 |
24 | import com.github.cleanzr.dblink.util.BufferedFileWriter._
25 | import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
26 | import org.apache.hadoop.util.Progressable
27 | import org.apache.spark.SparkContext
28 |
29 | case class BufferedFileWriter(path: String,
30 | append: Boolean,
31 | sparkContext: SparkContext) {
32 | private val hdfs = FileSystem.get(sparkContext.hadoopConfiguration)
33 | private val file = new Path(path)
34 | private val _progress = new WriterProgress
35 | private var partsDir: Path = _ // temp working dir for appending
36 |
37 | private val outStream = if (hdfs.exists(file) && append) {
38 | /** Hadoop doesn't support append on a ChecksumFileSystem.
39 | * Get around this limitation by writing to a temporary new file and
40 | * merging with the old file on .close() */
41 | partsDir = new Path(path + "-PARTS")
42 | hdfs.mkdirs(partsDir) // dir for new and old parts
43 | hdfs.rename(file, new Path(partsDir.toString + Path.SEPARATOR + "PART0.csv")) // move old part into dir
44 | hdfs.create(new Path(partsDir.toString + Path.SEPARATOR + "PART1.csv"), true, 1, _progress)
45 | } else {
46 | hdfs.create(file, true, 1, _progress)
47 | }
48 |
49 | private val writer = new OutputStreamWriter(outStream, "UTF-8")
50 |
51 | def close(): Unit = {
52 | writer.close()
53 | if (partsDir != null) {
54 | /** Need to merge new and old parts */
55 | FileUtil.copyMerge(hdfs, partsDir, hdfs, file,
56 | true, sparkContext.hadoopConfiguration, null)
57 | hdfs.delete(partsDir, true)
58 | }
59 | //hdfs.close()
60 | }
61 |
62 | def flush(): Unit = writer.flush()
63 |
64 | def write(str: String): Unit = writer.write(str, 0, str.length())
65 |
66 | def progress(): Unit = _progress.progress()
67 | }
68 |
69 | object BufferedFileWriter {
70 | class WriterProgress extends Progressable {
71 | override def progress(): Unit = {}
72 | }
73 | }
--------------------------------------------------------------------------------
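A minimal sketch of the writer's append-then-merge behaviour (the path and CSV content are illustrative): the first run creates the file, and a later run with append = true writes new rows to a temporary "-PARTS" directory that is merged with the old file when close() is called.

    import org.apache.spark.{SparkConf, SparkContext}
    import com.github.cleanzr.dblink.util.BufferedFileWriter

    object BufferedFileWriterExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("writer-example").setMaster("local[2]"))

        // Overwrites any existing file; pass append = true to keep and extend it instead
        val writer = BufferedFileWriter(path = "/tmp/diagnostics.csv", append = false, sparkContext = sc)
        writer.write("iteration,logLikelihood\n")
        writer.write("0,-1234.5\n")
        writer.flush()
        writer.close()
        sc.stop()
      }
    }
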
/examples/RLdata500.conf:
--------------------------------------------------------------------------------
1 | dblink : {
2 |
3 | // Define distortion hyperparameters (to be referenced below)
4 | lowDistortion : {alpha : 0.5, beta : 50.0}
5 |
6 | // Define similarity functions (to be referenced below)
7 | constSimFn : {
8 | name : "ConstantSimilarityFn",
9 | }
10 |
11 | levSimFn : {
12 | name : "LevenshteinSimilarityFn",
13 | parameters : {
14 | threshold : 7.0
15 | maxSimilarity : 10.0
16 | }
17 | }
18 |
19 | data : {
20 | // Path to data files. Must have header row (column names).
21 | path : "./examples/RLdata500.csv"
22 |
23 | // Specify columns that contain identifiers
24 | recordIdentifier : "rec_id",
25 | // fileIdentifier : null, // not needed since this data set is only a single file
26 | entityIdentifier : "ent_id" // optional
27 |
28 | // String representation of a missing value
29 | nullValue : "NA"
30 |
31 | // Specify properties of the attributes (columns) used for matching
32 | matchingAttributes : [
33 | {name : "by", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}},
34 | {name : "bm", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}},
35 | {name : "bd", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}},
36 | {name : "fname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}},
37 | {name : "lname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}}
38 | ]
39 | }
40 |
41 | randomSeed : 319158
42 | expectedMaxClusterSize : 10
43 |
44 | // Specify partitioner
45 | partitioner : {
46 | name : "KDTreePartitioner",
47 | parameters : {
48 | numLevels : 0, // a value of zero means no partitioning
49 | matchingAttributes : [] // cycle through matching attributes in this order when constructing the tree
50 | }
51 | }
52 |
53 | // Path to Markov chain and full state (for resuming MCMC)
54 | outputPath : "./examples/RLdata500_results/"
55 |
56 | // Path to save Spark checkpoints
57 | checkpointPath : "/tmp/spark_checkpoint/"
58 |
59 | // Steps to be performed (in order)
60 | steps : [
61 | {name : "sample", parameters : {
62 | sampleSize : 100,
63 | burninInterval : 0,
64 | thinningInterval : 10,
65 | resume : false,
66 | sampler : "PCG-I"
67 | }},
68 | {name : "summarize", parameters : {
69 | lowerIterationCutoff : 0,
70 | quantities : ["cluster-size-distribution"]
71 | }},
72 | {name : "evaluate", parameters : {
73 | lowerIterationCutoff : 100,
74 | metrics : ["pairwise", "cluster"],
75 | useExistingSMPC : false
76 | }}
77 | ]
78 | }
79 |
--------------------------------------------------------------------------------
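The example configurations are HOCON files that rely on substitutions such as ${dblink.lowDistortion}, so they must be resolved after parsing. A sketch of reading the file with the Typesafe Config library (the same library used by ProjectSteps further below); the object name is illustrative:

    import java.io.File
    import com.typesafe.config.ConfigFactory

    object LoadConfigExample {
      def main(args: Array[String]): Unit = {
        // Parse the HOCON file and resolve substitutions such as ${dblink.lowDistortion}
        val config = ConfigFactory.parseFile(new File("./examples/RLdata500.conf")).resolve()

        println(config.getString("dblink.data.path"))           // ./examples/RLdata500.csv
        println(config.getInt("dblink.expectedMaxClusterSize")) // 10
        println(config.getConfigList("dblink.steps").size())    // 3 scheduled steps
      }
    }
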
/examples/RLdata10000.conf:
--------------------------------------------------------------------------------
1 | dblink : {
2 |
3 | // Define distortion hyperparameters (to be referenced below)
4 | lowDistortion : {alpha : 10.0, beta : 1000.0}
5 |
6 | // Define similarity functions (to be referenced below)
7 | constSimFn : {
8 | name : "ConstantSimilarityFn",
9 | }
10 |
11 | levSimFn : {
12 | name : "LevenshteinSimilarityFn",
13 | parameters : {
14 | threshold : 7.0
15 | maxSimilarity : 10.0
16 | }
17 | }
18 |
19 | data : {
20 | // Path to data files. Must have header row (column names).
21 | path : "./examples/RLdata10000.csv"
22 |
23 | // Specify columns that contain identifiers
24 | recordIdentifier : "rec_id",
25 | // fileIdentifier : null, // not needed since this data set is only a single file
26 | entityIdentifier : "ent_id" // optional
27 |
28 | // String representation of a missing value
29 | nullValue : "NA"
30 |
31 | // Specify properties of the attributes (columns) used for matching
32 | matchingAttributes : [
33 | {name : "by", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}},
34 | {name : "bm", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}},
35 | {name : "bd", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}},
36 | {name : "fname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}},
37 | {name : "lname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}}
38 | ]
39 | }
40 |
41 | randomSeed : 319158
42 | expectedMaxClusterSize : 10
43 |
44 | // Specify partitioner
45 | partitioner : {
46 | name : "KDTreePartitioner",
47 | parameters : {
48 | numLevels : 1, // a value of zero means no partitioning
49 | matchingAttributes : ["fname_c1"] // cycle through matching attributes in this order when constructing the tree
50 | }
51 | }
52 |
53 | // Path to Markov chain and full state (for resuming MCMC)
54 | outputPath : "./examples/RLdata10000_results/"
55 |
56 | // Path to save Spark checkpoints
57 | checkpointPath : "/tmp/spark_checkpoint/"
58 |
59 | // Steps to be performed (in order)
60 | steps : [
61 | {name : "sample", parameters : {
62 | sampleSize : 100,
63 | burninInterval : 0,
64 | thinningInterval : 10,
65 | resume : false,
66 | sampler : "PCG-I"
67 | }},
68 | {name : "summarize", parameters : {
69 | lowerIterationCutoff : 0,
70 | quantities : ["cluster-size-distribution", "partition-sizes"]
71 | }},
72 | {name : "evaluate", parameters : {
73 | lowerIterationCutoff : 100,
74 | metrics : ["pairwise", "cluster"],
75 | useExistingSMPC : false
76 | }}
77 | ]
78 | }
79 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/SummaryAccumulators.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.accumulators.MapLongAccumulator
23 |
24 | import org.apache.spark.SparkContext
25 | import org.apache.spark.util.{DoubleAccumulator, LongAccumulator}
26 |
27 | /** Collection of accumulators used for computing `SummaryVars`
28 | *
29 | * @param logLikelihood double accumulator for the (un-normalised) log-likelihood
30 | * @param numIsolates long accumulator for the number of isolated latent entities
31 | * @param aggDistortions map accumulator for the number of distortions per
32 | * file and attribute.
33 | * @param recDistortions map accumulator for the frequency distribution of the
34 | */
35 | case class SummaryAccumulators(logLikelihood: DoubleAccumulator,
36 | numIsolates: LongAccumulator,
37 | aggDistortions: MapLongAccumulator[(AttributeId, FileId)],
38 | recDistortions: MapLongAccumulator[Int]) {
39 | /** Reset all accumulators */
40 | def reset(): Unit = {
41 | logLikelihood.reset()
42 | numIsolates.reset()
43 | aggDistortions.reset()
44 | recDistortions.reset()
45 | }
46 | }
47 |
48 | object SummaryAccumulators {
49 | /** Create a `SummaryAccumulators` object using a given SparkContext.
50 | *
51 | * @param sparkContext a SparkContext.
52 | * @return a `SummaryAccumulators` object.
53 | */
54 | def apply(sparkContext: SparkContext): SummaryAccumulators = {
55 | val logLikelihood = new DoubleAccumulator
56 | val numIsolates = new LongAccumulator
57 | val aggDistortions = new MapLongAccumulator[(AttributeId, FileId)]
58 | val recDistortions = new MapLongAccumulator[Int]
59 | sparkContext.register(logLikelihood, "log-likelihood (un-normalised)")
60 | sparkContext.register(numIsolates, "number of isolates")
61 | sparkContext.register(aggDistortions, "aggregate number of distortions per attribute/file")
62 | sparkContext.register(recDistortions, "frequency distribution of total record distortion")
63 | new SummaryAccumulators(logLikelihood, numIsolates, aggDistortions, recDistortions)
64 | }
65 | }
--------------------------------------------------------------------------------
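The intended lifecycle, as suggested by apply and reset, is: create the accumulators once per run, let executors add to them during a sweep over the data, read the totals on the driver, then reset before the next iteration. A minimal sketch using only the two scalar accumulators (the numbers added are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}
    import com.github.cleanzr.dblink.SummaryAccumulators

    object SummaryAccumulatorsExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("summary-acc").setMaster("local[2]"))
        val acc = SummaryAccumulators(sc) // registers all four accumulators

        // Executors contribute inside an action; the driver reads the merged totals afterwards
        sc.parallelize(1 to 100).foreach { _ => acc.logLikelihood.add(-1.0); acc.numIsolates.add(1L) }
        println(s"logLikelihood=${acc.logLikelihood.value} numIsolates=${acc.numIsolates.value}")

        acc.reset() // clear everything before the next iteration
        sc.stop()
      }
    }
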
/src/test/scala/com/github/cleanzr/dblink/random/AliasSamplerTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}
23 | import org.scalatest.{FlatSpec, Matchers}
24 |
25 | class AliasSamplerTest extends FlatSpec with Matchers {
26 | implicit val rand: RandomGenerator = new MersenneTwister(1)
27 |
28 | def regularProbs = IndexedSeq(0.1, 0.2, 0.7)
29 | def extremeProbs = IndexedSeq(0.000000001, 0.000000001, 0.999999999)
30 |
31 | def empiricalDistributionIsConsistent(inputProbs: IndexedSeq[Double],
32 | sampler: AliasSampler,
33 | numSamples: Int = 100000000,
34 | tolerance: Double = 1e-4): Boolean = {
35 | val empiricalProbs = Array.fill(sampler.size)(0.0)
36 | var i = 0
37 | while (i < numSamples) {
38 | empiricalProbs(sampler.sample()) += 1.0/numSamples
39 | i += 1
40 | }
41 | (empiricalProbs, inputProbs).zipped.forall( (pE, pT) => pE === (pT +- tolerance))
42 | }
43 |
44 | behavior of "An Alias sampler"
45 |
46 | it should "complain when initialized with a negative weight" in {
47 | assertThrows[IllegalArgumentException] {
48 | AliasSampler(Seq(-1.0, 1.0))
49 | }
50 | }
51 |
52 | it should "complain when initialized with a NaN weight" in {
53 | assertThrows[IllegalArgumentException] {
54 | AliasSampler(Seq(0.0/0.0, 1.0))
55 | }
56 | }
57 |
58 | it should "complain when initialized with an infinite weight" in {
59 | assertThrows[IllegalArgumentException] {
60 | AliasSampler(Seq(1.0/0.0, 1.0))
61 | }
62 | }
63 |
64 | it should "produce an asymptotic empirical distribution that is consistent with the input distribution [0.1, 0.2, 0.7]" in {
65 | val sampler = AliasSampler(regularProbs)
66 | assert(empiricalDistributionIsConsistent(regularProbs, sampler))
67 | }
68 |
69 | it should "produce an asymptotic empirical distribution that is consistent with the input distribution [0.000000001, 0.000000001, 0.999999999]" in {
70 | val sampler = AliasSampler(extremeProbs)
71 | assert(empiricalDistributionIsConsistent(extremeProbs, sampler))
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
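Outside of the test, the sampler is used in the same way: construct it from (possibly unnormalised) non-negative weights and draw indices with sample(). A minimal sketch, assuming the implicit RandomGenerator in scope is picked up as in the test above:

    import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}
    import com.github.cleanzr.dblink.random.AliasSampler

    object AliasSamplerExample {
      def main(args: Array[String]): Unit = {
        implicit val rand: RandomGenerator = new MersenneTwister(42L)

        // Weights need not sum to one; negative, NaN or infinite weights are rejected
        val sampler = AliasSampler(Seq(1.0, 2.0, 7.0))

        // Each draw is O(1); index 2 should appear roughly 70% of the time
        val counts = Array.fill(sampler.size)(0)
        (1 to 10000).foreach(_ => counts(sampler.sample()) += 1)
        println(counts.mkString(", "))
      }
    }
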
/src/main/scala/com/github/cleanzr/dblink/analysis/ClusteringContingencyTable.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.analysis
21 |
22 | import ClusteringContingencyTable.ContingencyTableRow
23 | import com.github.cleanzr.dblink.Cluster
24 | import org.apache.spark.sql.Dataset
25 | import org.apache.spark.storage.StorageLevel
26 |
27 | case class ClusteringContingencyTable(table: Dataset[ContingencyTableRow],
28 | size: Long) {
29 | def persist(newLevel: StorageLevel) {
30 | table.persist(newLevel)
31 | }
32 |
33 | def unpersist() {
34 | table.unpersist()
35 | }
36 | }
37 |
38 | object ClusteringContingencyTable {
39 | case class ContingencyTableRow(PredictedUID: Long, TrueUID: Long, NumCommonElements: Int)
40 |
41 | // Note: this is a sparse representation of the contingency table. Cluster
42 | // pairs which are not present in the table have no common elements.
43 | def apply(predictedClusters: Dataset[Cluster], trueClusters: Dataset[Cluster]): ClusteringContingencyTable = {
44 | val spark = predictedClusters.sparkSession
45 | import spark.implicits._
46 | // Convert to membership representation
47 | val predictedMembership = predictedClusters.toMembership
48 | val trueMembership = trueClusters.toMembership
49 |
50 | val predictedSize = predictedMembership.count()
51 | val trueSize = trueMembership.count()
52 |
53 | // Ensure that clusterings partition the same set of elements (continue checking after join...)
54 | if (predictedSize != trueSize) throw new Exception("Clusterings do not partition the same set of elements.")
55 |
56 | val joined = predictedMembership.rdd.join(trueMembership.rdd).persist()
57 | val joinedSize = joined.count()
58 |
59 | // Continued checking...
60 | if (trueSize != joinedSize) throw new Exception("Clusterings do not partition the same set of elements.")
61 |
62 | val table = joined.map{case (_, (predictedUID, trueUID)) => ((predictedUID, trueUID), 1)}
63 | .reduceByKey(_ + _)
64 | .map{case ((predictedUID, trueUID), count) => ContingencyTableRow(predictedUID, trueUID, count)}
65 | .toDS()
66 |
67 | joined.unpersist()
68 |
69 | ClusteringContingencyTable(table, trueSize)
70 | }
71 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/util/BufferedRDDWriter.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | // Copyright (C) 2018 Neil Marchant
3 | //
4 | // Author: Neil Marchant
5 | //
6 | // This file is part of dblink.
7 | //
8 | // This program is free software: you can redistribute it and/or modify
9 | // it under the terms of the GNU General Public License as published by
10 | // the Free Software Foundation, either version 3 of the License, or
11 | // (at your option) any later version.
12 | //
13 | // This program is distributed in the hope that it will be useful,
14 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | // GNU General Public License for more details.
17 | //
18 | // You should have received a copy of the GNU General Public License
19 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
20 |
21 | package com.github.cleanzr.dblink.util
22 |
23 | import org.apache.spark.rdd.RDD
24 | import org.apache.spark.sql.SparkSession
25 | import org.apache.spark.sql.Encoder
26 | import org.apache.spark.storage.StorageLevel
27 | import com.github.cleanzr.dblink.Logging
28 | import scala.reflect.ClassTag
29 |
30 | case class BufferedRDDWriter[T : ClassTag : Encoder](path: String,
31 | capacity: Int,
32 | append: Boolean,
33 | rdds: Seq[RDD[T]],
34 | firstFlush: Boolean) extends Logging {
35 |
36 | def append(rows: RDD[T]): BufferedRDDWriter[T] = {
37 | val writer = if (capacity - rdds.size == 0) this.flush() else this
38 | rows.persist(StorageLevel.MEMORY_AND_DISK)
39 | rows.count() // force evaluation
40 | val newRdds = writer.rdds :+ rows
41 | writer.copy(rdds = newRdds)
42 | }
43 |
44 | private def write(unionedRdds: RDD[T], overwrite: Boolean): Unit = {
45 | /** Write to disk in Parquet format relying on Dataset/Dataframe API */
46 | val spark = SparkSession.builder().getOrCreate()
47 | val unionedDS = spark.createDataset(unionedRdds)
48 | val saveMode = if (overwrite) "overwrite" else "append"
49 | unionedDS.write.partitionBy("partitionId").format("parquet").mode(saveMode).save(path)
50 | }
51 |
52 | def flush(): BufferedRDDWriter[T] = {
53 | if (rdds.isEmpty) return this
54 |
55 | // Combine RDDs in the buffer using a union operation
56 | val sc = rdds.head.sparkContext
57 | val unionedRdds = sc.union(rdds)
58 |
59 | // Write to disk. Note: overwrite if not appending and this is the first write to disk.
60 | write(unionedRdds, firstFlush && !append)
61 | debug(s"Flushed to disk at $path")
62 |
63 | // Unpersist RDDs in buffer as they're no longer required
64 | rdds.foreach(_.unpersist(blocking = false))
65 |
66 | this.copy(rdds = Seq.empty[RDD[T]], firstFlush = false)
67 | }
68 | }
69 |
70 | object BufferedRDDWriter extends Logging {
71 | def apply[T : ClassTag : Encoder](capacity: Int, path: String, append: Boolean): BufferedRDDWriter[T] = {
72 | val rdds = Seq.empty[RDD[T]]
73 | BufferedRDDWriter[T](path, capacity, append, rdds, firstFlush = true)
74 | }
75 | }
--------------------------------------------------------------------------------
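The writer buffers whole RDDs: each append persists the batch and counts it, and flush() unions the buffer and writes it out as Parquet, partitioned by a partitionId column of the element type. A sketch under those assumptions (the ChainRow case class and paths are illustrative); note the writer is immutable, so the value returned by append/flush must be kept:

    import org.apache.spark.sql.SparkSession
    import com.github.cleanzr.dblink.util.BufferedRDDWriter

    // Element type must expose a partitionId column, since flush() partitions the output by it
    case class ChainRow(partitionId: Int, iteration: Int, value: Double)

    object BufferedRDDWriterExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("buffered-writer").master("local[2]").getOrCreate()
        import spark.implicits._ // supplies the Encoder[ChainRow] required by the writer

        var writer = BufferedRDDWriter[ChainRow](capacity = 10, path = "/tmp/chain.parquet", append = false)
        writer = writer.append(spark.sparkContext.parallelize(Seq(ChainRow(0, 1, 0.5), ChainRow(1, 1, 0.7))))
        writer = writer.flush() // union the buffered RDDs and write them out as Parquet

        spark.stop()
      }
    }
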
/src/main/scala/com/github/cleanzr/dblink/random/IndexNonUniformDiscreteDist.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import org.apache.commons.math3.random.RandomGenerator
23 |
24 | class IndexNonUniformDiscreteDist(weights: Traversable[Double]) extends DiscreteDist[Int] {
25 | require(weights.nonEmpty, "`weights` must be non-empty")
26 |
27 | private val (_probsArray, _totalWeight) = IndexNonUniformDiscreteDist.processWeights(weights)
28 |
29 | override def totalWeight: Double = _totalWeight
30 |
31 | override def numValues: Int = _probsArray.length
32 |
33 | override def values: Traversable[Int] = 0 until numValues
34 |
35 | /** AliasSampler to efficiently sample from the distribution */
36 | private val sampler = AliasSampler(_probsArray, checkWeights = false, normalized = true)
37 |
38 | // /** Inverse CDF */
39 | // private def sampleNaive(): Int = {
40 | // var prob = rand.nextDouble() * totalWeight
41 | // var idx = 0
42 | // while (prob > 0.0 && idx < numValues) {
43 | // prob -= weights(idx)
44 | // idx += 1
45 | // }
46 | // if (prob > 0.0) idx else idx - 1
47 | // }
48 |
49 | override def sample()(implicit rand: RandomGenerator): Int = sampler.sample()
50 |
51 | override def probabilityOf(idx: Int): Double = {
52 | if (idx < 0 || idx >= numValues) 0.0
53 | else _probsArray(idx)
54 | }
55 |
56 | override def toIterator: Iterator[(Int, Double)] = {
57 | (0 until numValues).zip(_probsArray).toIterator
58 | }
59 | }
60 |
61 | object IndexNonUniformDiscreteDist {
62 | def apply(weights: Traversable[Double]): IndexNonUniformDiscreteDist = {
63 | new IndexNonUniformDiscreteDist(weights)
64 | }
65 |
66 | private def processWeights(weights: Traversable[Double]): (Array[Double], Double) = {
67 | val probs = Array.ofDim[Double](weights.size)
68 | var totalWeight: Double = 0.0
69 | var i = 0
70 | weights.foreach { weight =>
71 | if (weight < 0 || weight.isInfinity || weight.isNaN) {
72 | throw new IllegalArgumentException("invalid weight encountered")
73 | }
74 | probs(i) = weight
75 | totalWeight += weight
76 | i += 1
77 | }
78 | if (totalWeight.isInfinity) throw new IllegalArgumentException("total weight is not finite")
79 | if (totalWeight == 0.0) throw new IllegalArgumentException("zero probability mass")
80 | i = 0
81 | while (i < probs.length) {
82 | probs(i) = probs(i)/totalWeight
83 | i += 1
84 | }
85 | (probs, totalWeight)
86 | }
87 | }
--------------------------------------------------------------------------------
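A minimal sketch of the distribution over indices: weights are normalised internally, probabilityOf returns 0.0 outside the support, and sample() draws an index via the alias method (an implicit RandomGenerator must be in scope); the object name and weights are illustrative:

    import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}
    import com.github.cleanzr.dblink.random.IndexNonUniformDiscreteDist

    object IndexDistExample {
      def main(args: Array[String]): Unit = {
        implicit val rand: RandomGenerator = new MersenneTwister(0L)

        // Unnormalised weights: P(0) = 0.2, P(1) = 0.3, P(2) = 0.5
        val dist = IndexNonUniformDiscreteDist(Seq(2.0, 3.0, 5.0))

        println(dist.probabilityOf(2))  // 0.5
        println(dist.probabilityOf(99)) // 0.0 for out-of-range indices
        println(Seq.fill(5)(dist.sample()).mkString(", "))
      }
    }
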
/src/test/scala/com/github/cleanzr/dblink/SimilarityFnTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.SimilarityFn._
23 | import org.scalatest.FlatSpec
24 |
25 | class SimilarityFnTest extends FlatSpec {
26 | behavior of "A constant similarity function"
27 |
28 | it should "return the same (constant) value for the max, min, threshold similarities" in {
29 | assert(ConstantSimilarityFn.maxSimilarity === ConstantSimilarityFn.minSimilarity)
30 | assert(ConstantSimilarityFn.maxSimilarity === ConstantSimilarityFn.threshold)
31 | }
32 |
33 | it should "return the max (constant) similarity for identical values" in {
34 | assert(ConstantSimilarityFn.getSimilarity("TestValue", "TestValue") === ConstantSimilarityFn.maxSimilarity)
35 | }
36 |
37 | it should "return the max (constant) similarity for distinct values" in {
38 | assert(ConstantSimilarityFn.getSimilarity("TestValue1", "TestValue2") === ConstantSimilarityFn.maxSimilarity)
39 | }
40 |
41 | def thresSimFn: LevenshteinSimilarityFn = LevenshteinSimilarityFn(5.0, 10.0)
42 | def noThresSimFn: LevenshteinSimilarityFn = LevenshteinSimilarityFn(0.0, 10.0)
43 |
44 | behavior of "A Levenshtein similarity function (maxSimilarity=10.0 and threshold=5.0)"
45 |
46 | it should "return the max similarity for an identical pair of non-empty strings" in {
47 | assert(thresSimFn.getSimilarity("John Smith", "John Smith") === thresSimFn.maxSimilarity)
48 | }
49 |
50 | it should "return the max similarity for a pair of empty strings" in {
51 | assert(thresSimFn.getSimilarity("", "") === thresSimFn.maxSimilarity)
52 | }
53 |
54 | it should "return the min similarity for an empty and non-empty string" in {
55 | assert(thresSimFn.getSimilarity("", "John Smith") === thresSimFn.minSimilarity)
56 | }
57 |
58 | it should "be symmetric in its arguments" in {
59 | assert(thresSimFn.getSimilarity("Jane Smith", "John Smith") === thresSimFn.getSimilarity("John Smith", "Jane Smith"))
60 | }
61 |
62 | it should "return a similarity of 2.0 for the strings 'AB' and 'BB'" in {
63 | assert(thresSimFn.getSimilarity("AB", "BB") === 2.0)
64 | }
65 |
66 | behavior of "A Levenshtein similarity function with no threshold (maxSimilarity=10.0)"
67 |
68 | it should "return the same value for the min, threshold similarities" in {
69 | assert(noThresSimFn.threshold === noThresSimFn.minSimilarity)
70 | }
71 |
72 | it should "return a similarity of 6.0 for the strings 'AB' and 'BB'" in {
73 | assert(noThresSimFn.getSimilarity("AB", "BB") === 6.0)
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/partitioning/LPTScheduler.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.partitioning
21 |
22 | import LPTScheduler._
23 |
24 | import scala.Numeric.Implicits._
25 | import scala.Ordering.Implicits._
26 | import scala.collection.mutable
27 | import scala.collection.mutable.ArrayBuffer
28 | import scala.reflect.ClassTag
29 |
30 | /**
31 | * Implementation of the longest processing time (LPT) algorithm for
32 | * partitioning jobs across multiple processors
33 | * @param jobs jobs to schedule, given as (job identifier, job size) pairs
34 | * @param numPartitions number of partitions (processors) to schedule the jobs across
35 | * @tparam J job index
36 | * @tparam T job runtime
37 | */
38 | class LPTScheduler[J, T : ClassTag : Numeric](jobs: IndexedSeq[(J,T)],
39 | val numPartitions: Int) extends Serializable {
40 |
41 | val numJobs: Int = jobs.size
42 | private val partitions: Array[Partition[J, T]] = partitionJobs[J, T](jobs, numPartitions)
43 | val partitionSizes: Array[T] = partitions.iterator.map[T](_.size).toArray[T]
44 | private val partitionsIndex = indexPartitions(partitions)
45 |
46 | def getJobs(partitionId: Int): Seq[J] = partitions(partitionId).jobs
47 |
48 | def getSize(partitionId: Int): T = partitionSizes(partitionId)
49 |
50 | def getPartitionId(jobId: J): Int = partitionsIndex(jobId)
51 | // TODO: deal with non-existent jobId
52 | }
53 |
54 | object LPTScheduler {
55 | case class Partition[J, T : ClassTag : Numeric](var size: T, jobs: ArrayBuffer[J])
56 |
57 | private def partitionJobs[J, T : ClassTag : Numeric](jobs: IndexedSeq[(J, T)],
58 | numPartitions: Int): Array[Partition[J, T]] = {
59 | val sortedJobs = jobs.sortBy(-_._2)
60 |
61 | // Initialise array of Partitions
62 | val partitions = Array.fill[Partition[J, T]](numPartitions) {
63 | Partition(implicitly[Numeric[T]].zero, ArrayBuffer[J]())
64 | }
65 |
66 | // Loop through jobs, putting each in the partition with the smallest size so far
67 | sortedJobs.foreach {j =>
68 | val (_, smallestId) = partitions.zipWithIndex.foldLeft ((partitions.head, 0) ) {
69 | (xO, xN) => if (xO._1.size > xN._1.size) xN else xO
70 | }
71 | partitions(smallestId).size += j._2
72 | partitions(smallestId).jobs += j._1
73 | }
74 |
75 | partitions
76 | }
77 |
78 | private def indexPartitions[J, T : Numeric](partitions: Array[Partition[J,T]]): Map[J, Int] = {
79 | val bIndex = Map.newBuilder[J, Int]
80 | partitions.zipWithIndex.foreach { case (partition, idx) =>
81 | partition.jobs.foreach(item => bIndex += (item -> idx))
82 | }
83 | bIndex.result()
84 | }
85 | }
--------------------------------------------------------------------------------
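A small sketch of the greedy LPT behaviour: jobs are sorted by decreasing size and each is placed in the currently smallest partition (the job ids and sizes below are illustrative):

    import com.github.cleanzr.dblink.partitioning.LPTScheduler

    object LPTSchedulerExample {
      def main(args: Array[String]): Unit = {
        // (job id, job size) pairs scheduled across 2 partitions
        val jobs = IndexedSeq("a" -> 7, "b" -> 5, "c" -> 4, "d" -> 3)
        val scheduler = new LPTScheduler(jobs, numPartitions = 2)

        // Greedy assignment: {a, d} with total size 10 and {b, c} with total size 9
        println(scheduler.partitionSizes.mkString(", "))
        println(scheduler.getJobs(0))
        println(scheduler.getPartitionId("a")) // partition that job "a" was assigned to
      }
    }
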
/src/main/scala/com/github/cleanzr/dblink/analysis/ClusteringMetrics.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.analysis
21 |
22 | import com.github.cleanzr.dblink.Cluster
23 | import org.apache.commons.math3.util.CombinatoricsUtils.binomialCoefficient
24 | import org.apache.spark.sql.Dataset
25 | import org.apache.spark.storage.StorageLevel
26 |
27 | case class ClusteringMetrics(adjRandIndex: Double) {
28 | def mkString: String = {
29 | "=====================================\n" +
30 | " Cluster metrics \n" +
31 | "-------------------------------------\n" +
32 | s" Adj. Rand index: $adjRandIndex\n" +
33 | "=====================================\n"
34 | }
35 |
36 | def print(): Unit = {
37 | Console.print(mkString)
38 | }
39 | }
40 |
41 | object ClusteringMetrics {
42 | private def comb2(x: Int): Long = if (x >= 2) binomialCoefficient(x,2) else 0
43 |
44 | def AdjustedRandIndex(contingencyTable: ClusteringContingencyTable): Double = {
45 | // Compute sum_{PredictedUID} comb2(sum_{TrueUID} NumCommonElements(PredictedUID, TrueUID))
46 | val predCombSum = contingencyTable.table.rdd
47 | .map(row => (row.PredictedUID, row.NumCommonElements))
48 | .reduceByKey(_ + _) // sum over True UID
49 | .aggregate(0L)(
50 | seqOp = (sum, x) => sum + comb2(x._2),
51 | combOp = _ + _ // apply comb2(.) and sum over PredictedUID
52 | )
53 |
54 | // Compute sum_{TrueUID} comb2(sum_{PredictedUID} NumCommonElements(PredictedUID, TrueUID))
55 | val trueCombSum = contingencyTable.table.rdd
56 | .map(row => (row.TrueUID, row.NumCommonElements))
57 | .reduceByKey(_ + _) // sum over Pred UID
58 | .aggregate(0L)(
59 | seqOp = (sum, x) => sum + comb2(x._2),
60 | combOp = _ + _ // apply comb2(.) and sum over True UID
61 | )
62 |
63 | // Compute sum_{TrueUID} sum_{PredictedUID} comb2(NumCommonElements(PredictedUID, TrueUID))
64 | val totalCombSum = contingencyTable.table.rdd
65 | .aggregate(0L)(
66 | seqOp = (sum, x) => sum + comb2(x.NumCommonElements),
67 | combOp = _ + _ // apply comb2(.) and sum over PredictedUID & TrueUID
68 | )
69 |
70 | // Return adjusted Rand index
71 | val expectedIndex = predCombSum.toDouble * trueCombSum / comb2(contingencyTable.size.toInt)
72 | val maxIndex = (predCombSum.toDouble + trueCombSum) / 2.0
73 | (totalCombSum - expectedIndex)/(maxIndex - expectedIndex)
74 | }
75 |
76 | def apply(predictedClusters: Dataset[Cluster],
77 | trueClusters: Dataset[Cluster]): ClusteringMetrics = {
78 | val contingencyTable = ClusteringContingencyTable(predictedClusters, trueClusters)
79 | contingencyTable.persist(StorageLevel.MEMORY_ONLY)
80 | val adjRandIndex = AdjustedRandIndex(contingencyTable)
81 | contingencyTable.unpersist()
82 | ClusteringMetrics(adjRandIndex)
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
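For reference, AdjustedRandIndex computes the standard adjusted Rand index. Writing n_{ij} for NumCommonElements between predicted cluster i and true cluster j, a_i = sum_j n_{ij}, b_j = sum_i n_{ij}, and n for the total number of records, the code's totalCombSum, predCombSum, trueCombSum, expectedIndex and maxIndex correspond to the terms of (my rendering, for clarity):

    \mathrm{ARI} = \frac{\sum_{ij}\binom{n_{ij}}{2} - \left[\sum_i \binom{a_i}{2}\sum_j \binom{b_j}{2}\right]\Big/\binom{n}{2}}
                        {\tfrac{1}{2}\left[\sum_i \binom{a_i}{2}+\sum_j \binom{b_j}{2}\right] - \left[\sum_i \binom{a_i}{2}\sum_j \binom{b_j}{2}\right]\Big/\binom{n}{2}}
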
/src/main/scala/com/github/cleanzr/dblink/ProjectSteps.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import ProjectStep.{CopyFilesStep, EvaluateStep, SampleStep, SummarizeStep}
23 | import com.typesafe.config.{Config, ConfigException}
24 | import Project.toConfigTraversable
25 |
26 | import scala.collection.JavaConverters._
27 | import scala.collection.mutable
28 | import scala.util.Try
29 |
30 | class ProjectSteps(config: Config, project: Project) {
31 |
32 | val steps: Traversable[ProjectStep] = ProjectSteps.parseSteps(config, project)
33 |
34 | def execute(): Unit = {
35 | steps.foreach(_.execute())
36 | }
37 |
38 | def mkString: String = {
39 | val lines = mutable.ArrayBuffer.empty[String]
40 | lines += "Scheduled steps"
41 | lines += "---------------"
42 | lines ++= steps.map(step => " * " + step.mkString)
43 |
44 | lines.mkString("\n")
45 | }
46 | }
47 |
48 | object ProjectSteps {
49 | def apply(config: Config, project: Project): ProjectSteps = {
50 | new ProjectSteps(config, project)
51 | }
52 |
53 | private def parseSteps(config: Config, project: Project): Traversable[ProjectStep] = {
54 | config.getObjectList("dblink.steps").toTraversable.map { step =>
55 | step.getString("name") match {
56 | case "sample" =>
57 | new SampleStep(project,
58 | sampleSize = step.getInt("parameters.sampleSize"),
59 | burninInterval = Try {step.getInt("parameters.burninInterval")} getOrElse 0,
60 | thinningInterval = Try {step.getInt("parameters.thinningInterval")} getOrElse 1,
61 | resume = Try {step.getBoolean("parameters.resume")} getOrElse true,
62 | sampler = Try {step.getString("parameters.sampler")} getOrElse "PCG-I")
63 | case "evaluate" =>
64 | new EvaluateStep(project,
65 | lowerIterationCutoff = Try {step.getInt("parameters.lowerIterationCutoff")} getOrElse 0,
66 | metrics = step.getStringList("parameters.metrics").asScala,
67 | useExistingSMPC = Try {step.getBoolean("parameters.useExistingSMPC")} getOrElse false
68 | )
69 | case "summarize" =>
70 | new SummarizeStep(project,
71 | lowerIterationCutoff = Try {step.getInt("parameters.lowerIterationCutoff")} getOrElse 0,
72 | quantities = step.getStringList("parameters.quantities").asScala
73 | )
74 | case "copy-files" =>
75 | new CopyFilesStep(project,
76 | fileNames = step.getStringList("parameters.fileNames").asScala,
77 | destinationPath = step.getString("parameters.destinationPath"),
78 | overwrite = Try {step.getBoolean("parameters.overwrite")} getOrElse false,
79 | deleteSource = Try {step.getBoolean("parameters.deleteSource")} getOrElse false
80 | )
81 | case _ => throw new ConfigException.BadValue(config.origin(), "name", "unsupported step")
82 | }
83 | }
84 | }
85 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/DiagnosticsWriter.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.util.BufferedFileWriter
23 | import org.apache.spark.SparkContext
24 |
25 | /** Buffered writer for diagnostics along the Markov chain
26 | *
27 | * Output is in CSV format
28 | *
29 | * @param path path to save diagnostics (can be a HDFS or local path)
30 | * @param continueChain whether to append to existing diagnostics (if present)
31 | */
32 | class DiagnosticsWriter(path: String, continueChain: Boolean)
33 | (implicit sparkContext: SparkContext) extends Logging {
34 | private val writer = BufferedFileWriter(path, continueChain, sparkContext)
35 | private var firstWrite = true
36 | info(s"Writing diagnostics along chain to $path.")
37 |
38 | /** Write header for CSV */
39 | private def writeHeader(state: State): Unit = {
40 | val recordsCache = state.bcRecordsCache.value
41 | val aggDistortionsHeaders = recordsCache.indexedAttributes.map(x => s"aggDist-${x.name}").mkString(",")
42 | val recDistortionsHeaders = (0 to recordsCache.numAttributes).map(k => s"recDistortion-$k").mkString(",")
43 | writer.write(s"iteration,systemTime-ms,numObservedEntities,logLikelihood,popSize,$aggDistortionsHeaders,$recDistortionsHeaders\n")
44 | }
45 |
46 | /** Write a row of diagnostics for the given state */
47 | def writeRow(state: State): Unit = {
48 | if (firstWrite && !continueChain) { writeHeader(state) }; firstWrite = false
49 |
50 | // Get number of attributes
51 | val numAttributes = state.bcRecordsCache.value.numAttributes
52 |
53 | // Aggregate number of distortions for each attribute (sum over fileId)
54 | val aggAttrDistortions = state.summaryVars.aggDistortions
55 | .groupBy(_._1._1) // group by attrId
56 | .mapValues(_.values.sum) // sum over fileId
57 |
58 | // Convenience variable
59 | val recDistortions = state.summaryVars.recDistortions
60 |
61 | // Build row of string values matching header
62 | val row: Iterator[String] = Iterator(
63 | state.iteration.toString, // iteration
64 | System.currentTimeMillis().toString, // systemTime-ms
65 | (state.populationSize - state.summaryVars.numIsolates).toString, // numObservedEntities
66 | f"${state.summaryVars.logLikelihood}%.9e", // logLikelihood
67 | state.populationSize.toString) ++ // populationSize
68 | (0 until numAttributes).map(aggAttrDistortions.getOrElse(_, 0L).toString) ++ // aggDist-*
69 | (0 to numAttributes).map(recDistortions.getOrElse(_, 0L).toString) // recDistortion-*
70 |
71 | writer.write(row.mkString(",") + "\n")
72 | }
73 |
74 | def close(): Unit = writer.close()
75 |
76 | def flush(): Unit = writer.flush()
77 |
78 | /** Need to call this occasionally to keep the writer alive */
79 | def progress(): Unit = writer.progress()
80 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/util/PeriodicRDDCheckpointer.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.github.cleanzr.dblink.util
19 |
20 | import org.apache.spark.SparkContext
21 | import org.apache.spark.rdd.RDD
22 | import org.apache.spark.storage.StorageLevel
23 |
24 | /**
25 | * This class helps with persisting and checkpointing RDDs.
26 | * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
27 | * unpersisting and removing checkpoint files.
28 | *
29 | * Users should call update() when a new RDD has been created,
30 | * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are
31 | * responsible for materializing the RDD to ensure that persisting and checkpointing actually
32 | * occur.
33 | *
34 | * When update() is called, this does the following:
35 | * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
36 | * - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
37 | * - If using checkpointing and the checkpoint interval has been reached,
38 | * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
39 | * - Remove older checkpoints.
40 | *
41 | * WARNINGS:
42 | * - This class should NOT be copied (since copies may conflict on which RDDs should be
43 | * checkpointed).
44 | * - This class removes checkpoint files once later RDDs have been checkpointed.
45 | * However, references to the older RDDs will still return isCheckpointed = true.
46 | *
47 | * Example usage:
48 | * {{{
49 | * val (rdd1, rdd2, rdd3, ...) = ...
50 | * val cp = new PeriodicRDDCheckpointer(2, sc)
51 | * cp.update(rdd1)
52 | * rdd1.count();
53 | * // persisted: rdd1
54 | * cp.update(rdd2)
55 | * rdd2.count();
56 | * // persisted: rdd1, rdd2
57 | * // checkpointed: rdd2
58 | * cp.update(rdd3)
59 | * rdd3.count();
60 | * // persisted: rdd1, rdd2, rdd3
61 | * // checkpointed: rdd2
62 | * cp.update(rdd4)
63 | * rdd4.count();
64 | * // persisted: rdd2, rdd3, rdd4
65 | * // checkpointed: rdd4
66 | * cp.update(rdd5)
67 | * rdd5.count();
68 | * // persisted: rdd3, rdd4, rdd5
69 | * // checkpointed: rdd4
70 | * }}}
71 | *
72 | * @param checkpointInterval RDDs will be checkpointed at this interval
73 | * @tparam T RDD element type
74 | */
75 | class PeriodicRDDCheckpointer[T](checkpointInterval: Int,
76 | sc: SparkContext)
77 | extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
78 |
79 | override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
80 |
81 | override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
82 |
83 | override protected def persist(data: RDD[T]): Unit = {
84 | if (data.getStorageLevel == StorageLevel.NONE) {
85 | //data.persist(StorageLevel.MEMORY_ONLY)
86 | }
87 | }
88 |
89 | override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
90 |
91 | override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
92 | data.getCheckpointFile.map(x => x)
93 | }
94 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/analysis/package.scala:
--------------------------------------------------------------------------------
1 | package com.github.cleanzr.dblink
2 |
3 | import org.apache.hadoop.fs.{FileSystem, Path}
4 | import org.apache.spark.sql.{Dataset, SparkSession}
5 |
6 | import scala.reflect.ClassTag
7 |
8 | package object analysis {
9 |
10 | /**
11 | * Canonicalizes pairwise links: orders each pair, removes duplicates and rejects self-links.
12 | * @param links pairwise links, possibly unordered and containing duplicates
13 | * @return canonicalized, distinct pairwise links
14 | */
15 | def canonicalizePairwiseLinks(links: Dataset[RecordPair]): Dataset[RecordPair] = {
16 | val spark = links.sparkSession
17 | import spark.implicits._
18 |
19 | links.map { recIds =>
20 | // Ensure pairs are sorted
21 | recIds._1.compareTo(recIds._2) match {
22 | case x if x < 0 => (recIds._1, recIds._2)
23 | case x if x > 0 => (recIds._2, recIds._1)
24 | case 0 => throw new Exception(s"Invalid link: ${recIds._1} <-> ${recIds._2}.")
25 | }
26 | }.distinct()
27 | }
28 |
29 |
30 |
31 | /**
32 | * Reads clusters from a CSV file, where each line lists the record ids in one cluster.
33 | * @param path path to the CSV file (local filesystem or HDFS)
34 | * @return a Dataset of clusters
35 | */
36 | def readClustersCSV(path: String): Dataset[Cluster] = {
37 | val spark = SparkSession.builder.getOrCreate()
38 | val sc = spark.sparkContext
39 | import spark.implicits._
40 | sc.textFile(path)
41 | .map(line => line.split(",").map(_.trim).toSet).toDS()
42 | }
43 |
44 |
45 |
46 | /**
47 | * Converts a membership representation (record id paired with a cluster label) to clusters.
48 | * @param membership Dataset of (record id, cluster label) pairs
49 | * @tparam T type of the cluster label
50 | * @return a Dataset of clusters
51 | */
52 | def membershipToClusters[T : ClassTag](membership: Dataset[(RecordId, T)]): Dataset[Cluster] = {
53 | val spark = membership.sparkSession
54 | import spark.implicits._
55 | membership.rdd
56 | .map(_.swap)
57 | .aggregateByKey(Set.empty[RecordId])(
58 | seqOp = (recIds, recId) => recIds + recId,
59 | combOp = (partA, partB) => partA union partB
60 | )
61 | .map(_._2)
62 | .toDS()
63 | }
64 |
65 |
66 |
67 | /*
68 | These private methods must appear outside the Clusters implicit value class due to a limitation of the compiler
69 | */
70 | private def _toPairwiseLinks(clusters: Dataset[Cluster]): Dataset[RecordPair] = {
71 | val spark = clusters.sparkSession
72 | import spark.implicits._
73 | val links = clusters.flatMap(_.toSeq.combinations(2).map(x => (x(0), x(1))))
74 | canonicalizePairwiseLinks(links)
75 | }
76 |
77 | private def _toMembership(clusters: Dataset[Cluster]): Dataset[(RecordId, EntityId)] = {
78 | val spark = clusters.sparkSession
79 | import spark.implicits._
80 | clusters.rdd
81 | .zipWithUniqueId()
82 | .flatMap { case (recIds, entityId) => recIds.iterator.map(recId => (recId, entityId.toInt)) }
83 | .toDS()
84 | }
85 |
86 | /**
87 | * Represents a clustering of records as sets of record ids. Provides methods for converting to pairwise and
88 | * membership representations.
89 | *
90 | * @param ds A Dataset of record clusters
91 | */
92 | implicit class Clusters(val ds: Dataset[Cluster]) extends AnyVal {
93 | /**
94 | * Save to a CSV file
95 | *
96 | * @param path Path to the file. May be a path on the local filesystem or HDFS.
97 | * @param overwrite Whether to overwrite an existing file or not.
98 | */
99 | def saveCsv(path: String, overwrite: Boolean = true): Unit = {
100 | val sc = ds.sparkSession.sparkContext
101 | val hdfs = FileSystem.get(sc.hadoopConfiguration)
102 | val file = new Path(path)
103 | if (hdfs.exists(file)) {
104 | if (!overwrite) return else hdfs.delete(file, true)
105 | }
106 | ds.rdd.map(cluster => cluster.mkString(", "))
107 | .saveAsTextFile(path)
108 | }
109 |
110 | /**
111 | * Convert to a pairwise representation
112 | */
113 | def toPairwiseLinks: Dataset[RecordPair] = _toPairwiseLinks(ds)
114 |
115 | /**
116 | * Convert to a membership vector representation
117 | */
118 | def toMembership: Dataset[(RecordId, EntityId)] = _toMembership(ds)
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
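A minimal sketch of the Clusters syntax defined above, assuming a local SparkSession and that Cluster and RecordId are the Set[String]/String aliases suggested by readClustersCSV (the object name, record ids and output path are illustrative):

    import org.apache.spark.sql.SparkSession
    import com.github.cleanzr.dblink.analysis._

    object ClustersSyntaxExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("clusters-syntax").master("local[2]").getOrCreate()
        import spark.implicits._

        // Three records in two clusters, mirroring the Set-of-record-ids representation above
        val clusters = spark.sparkContext.parallelize(Seq(Set("r1", "r2"), Set("r3"))).toDS()

        clusters.toMembership.show()    // (record id, entity id) pairs
        clusters.toPairwiseLinks.show() // canonicalised pairs, here just (r1, r2)
        clusters.saveCsv("/tmp/clusters-csv")

        spark.stop()
      }
    }
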
/src/main/scala/com/github/cleanzr/dblink/random/NonUniformDiscreteDist.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import org.apache.commons.math3.random.RandomGenerator
23 |
24 | import scala.collection.Map
25 | import scala.reflect.ClassTag
26 |
27 | /** A non-uniform distribution over a discrete set of values
28 | *
29 | * @param valuesWeights map from values to weights (need not be normalised)
30 | * @param rand pseudo-random number generator
31 | * @param ev
32 | * @tparam T type of the values
33 | */
34 | class NonUniformDiscreteDist[T](valuesWeights: Map[T, Double])
35 | (implicit ev: ClassTag[T]) extends DiscreteDist[T] {
36 | require(valuesWeights.nonEmpty, "`valuesWeights` must be non-empty")
37 |
38 | private val (_valuesArray, _probsArray, _totalWeight) = NonUniformDiscreteDist.processValuesWeights(valuesWeights)
39 |
40 | override def totalWeight: Double = _totalWeight
41 |
42 | override def values: Traversable[T] = _valuesArray.toTraversable
43 |
44 | override val numValues: Int = valuesWeights.size
45 |
46 | /** AliasSampler to efficiently sample from the distribution */
47 | private val sampler = AliasSampler(_probsArray, checkWeights = false, normalized = true)
48 |
49 | // /** Inverse CDF */
50 | // private def sampleNaive(): T = {
51 | // var prob = rand.nextDouble() * totalWeight
52 | // val it = valuesWeights.iterator
53 | // var vw = it.next()
54 | // while (prob > 0.0 && it.hasNext) {
55 | // vw = it.next()
56 | // prob -= vw._2
57 | // }
58 | // vw._1
59 | // }
60 |
61 | override def sample()(implicit rand: RandomGenerator): T = _valuesArray(sampler.sample())
62 |
63 | override def probabilityOf(value: T): Double =
64 | valuesWeights.getOrElse(value, 0.0)/totalWeight
65 |
66 | override def toIterator: Iterator[(T, Double)] = {
67 | (_valuesArray,_probsArray).zipped.toIterator
68 | }
69 | }
70 |
71 | object NonUniformDiscreteDist {
72 | def apply[T](valuesAndWeights: Map[T, Double])
73 | (implicit ev: ClassTag[T]): NonUniformDiscreteDist[T] = {
74 | new NonUniformDiscreteDist[T](valuesAndWeights)
75 | }
76 |
77 | private def processValuesWeights[T : ClassTag](valuesWeights: Map[T, Double]): (Array[T], Array[Double], Double) = {
78 | val values = Array.ofDim[T](valuesWeights.size)
79 | val probs = Array.ofDim[Double](valuesWeights.size)
80 | val it = valuesWeights.iterator
81 | var totalWeight: Double = 0.0
82 | var i = 0
83 | valuesWeights.foreach { pair =>
84 | val weight = pair._2
85 | values(i) = pair._1
86 | probs(i) = weight
87 | if (weight < 0 || weight.isInfinity || weight.isNaN) {
88 | throw new IllegalArgumentException("invalid weight encountered")
89 | }
90 | totalWeight += weight
91 | i += 1
92 | }
93 | if (totalWeight.isInfinity) throw new IllegalArgumentException("total weight is not finite")
94 | if (totalWeight == 0.0) throw new IllegalArgumentException("zero probability mass")
95 | i = 0
96 | while (i < probs.length) {
97 | probs(i) = probs(i)/totalWeight
98 | i += 1
99 | }
100 | (values, probs, totalWeight)
101 | }
102 | }
103 |
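
As a rough usage sketch (the weights below are illustrative), the distribution can be built from a weight map and sampled with a Commons Math generator supplied implicitly:

```scala
import com.github.cleanzr.dblink.random.NonUniformDiscreteDist
import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}

object NonUniformDiscreteDistExample extends App {
  // Weights need not be normalised; they are rescaled internally.
  implicit val rand: RandomGenerator = new MersenneTwister(42L)
  val dist = NonUniformDiscreteDist(Map("NSW" -> 7.86, "VIC" -> 6.32, "QLD" -> 4.92))
  println(dist.probabilityOf("NSW"))   // ~0.41 = 7.86 / (7.86 + 6.32 + 4.92)
  println(dist.sample())               // draws "NSW", "VIC" or "QLD" via the alias method
}
```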
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/partitioning/DomainSplitter.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.partitioning
21 |
22 | import scala.collection.mutable
23 | import scala.math.abs
24 |
25 | /** Splits a domain of discrete weighted values into two partitions of roughly
26 | * equal weight.
27 | *
28 | * @tparam T value type
29 | */
30 | abstract class DomainSplitter[T] {
31 | /** Quality of the split (1.0 is best, 0.0 is worst) */
32 | val splitQuality: Double
33 |
34 | /** Returns whether value x is in the left (false) or right (true) set
35 | *
36 | * @param x value
37 | * @return
38 | */
39 | def apply(x: T): Boolean
40 | }
41 |
42 | object DomainSplitter {
43 | def apply[T : Ordering](domain: Array[(T, Double)]): DomainSplitter[T] = {
44 | /** Use the set splitter for small domains */
45 | if (domain.length <= 30) new LPTDomainSplitter[T](domain)
46 | else new RanDomainSplitter[T](domain)
47 | }
48 |
49 | /** Range splitter.
50 | *
51 | * Splits the domain by sorting the values, then splitting at the (weighted)
52 | * median.
53 | *
54 | * @param domain value-weight pairs for the domain
55 | * @tparam T type of the values in the domain
56 | */
57 | class RanDomainSplitter[T : Ordering](domain: Array[(T, Double)]) extends DomainSplitter[T] with Serializable {
58 | private val halfWeight = domain.iterator.map(_._2).sum / 2.0
59 |
60 | private val (splitWeight, splitValue) = {
61 | val numCandidates = domain.length
62 | val ordered = domain.sortBy(_._1)
63 | var cumWeight = 0.0
64 | var i = 0
65 | while (cumWeight <= halfWeight && i < numCandidates - 1) {
66 | cumWeight += ordered(i)._2
67 | i += 1
68 | }
69 | (cumWeight, ordered(i)._1)
70 | }
71 |
72 | override val splitQuality: Double = 1.0 - abs(splitWeight - halfWeight)/halfWeight
73 |
74 | override def apply(x: T): Boolean = Ordering[T].gt(x, splitValue)
75 | }
76 |
77 | /** LPT splitter.
78 | *
79 | * Splits the domain using the longest processing time (LPT) algorithm with
80 | * only two "processors" (buckets). This is not space efficient for large
81 | * domains, roughly half of the domain values must be stored internally.
82 | *
83 | * @param domain value-weight pairs for the domain
84 | * @tparam T type of the values in the domain
85 | */
86 | class LPTDomainSplitter[T](domain: Array[(T, Double)]) extends DomainSplitter[T] with Serializable {
87 | private val halfWeight = domain.iterator.map(_._2).sum / 2.0
88 |
89 | private val rightSet = mutable.Set.empty[T]
90 |
91 | private val splitWeight = {
92 | val ordered = domain.sortBy(-_._2) // decreasing order of weight
93 | var leftWeight = 0.0
94 | var rightWeight = 0.0
95 | ordered.foreach { case (value, weight) =>
96 | if (leftWeight >= rightWeight) {
97 | rightSet.add(value)
98 | rightWeight += weight
99 | } else {
100 | // effectively putting value in the left set
101 | leftWeight += weight
102 | }
103 | }
104 | leftWeight
105 | }
106 |
107 | override val splitQuality: Double = 1.0 - abs(splitWeight - halfWeight)/halfWeight
108 |
109 | override def apply(x: T): Boolean = rightSet.contains(x)
110 | }
111 | }
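
A small illustrative sketch (the domain and weights are made up): for a domain this small the factory returns an LPT splitter, and `apply` reports which side of the split a value falls on.

```scala
import com.github.cleanzr.dblink.partitioning.DomainSplitter

object DomainSplitterExample extends App {
  // Weighted domain values (the weights might be record counts per value)
  val domain = Array("Tasmania" -> 0.52, "Victoria" -> 6.32, "Queensland" -> 4.92, "NSW" -> 7.86)
  val splitter = DomainSplitter(domain)   // <= 30 values, so an LPTDomainSplitter is used
  println(splitter.splitQuality)          // 1.0 would mean the two halves have exactly equal weight
  domain.foreach { case (value, _) =>
    val side = if (splitter(value)) "right" else "left"
    println(s"$value -> $side")
  }
}
```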
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dblink: Distributed End-to-End Bayesian Entity Resolution
2 | `dblink` is a Spark package for performing unsupervised entity resolution
3 | (ER) on structured data.
4 | It's based on a Bayesian model called `blink`
5 | [(Steorts, 2015)](https://projecteuclid.org/euclid.ba/1441790411),
6 | with extensions proposed in
7 | [(Marchant et al., 2021)](https://doi.org/10.1080/10618600.2020.1825451).
8 | Unlike many ER algorithms, `dblink` approximates the full posterior
9 | distribution over clusterings of records (into entities).
10 | This facilitates propagation of uncertainty to post-ER analysis,
11 | and provides a framework for answering probabilistic queries about entity
12 | membership.
13 |
14 | `dblink` approximates the posterior using Markov chain Monte Carlo.
15 | It writes samples (of clustering configurations) to disk in Parquet format.
16 | Diagnostic summary statistics are also written to disk in CSV format—these are
17 | useful for assessing convergence of the Markov chain.
18 |
19 | ## Documentation
20 | The step-by-step [guide](docs/guide.md) includes information about
21 | building dblink from source and running it locally on a test data set.
22 | Further details about the configuration options for dblink are provided
23 | [here](docs/configuration.md).
24 |
25 | ## Example: RLdata
26 | Two synthetic data sets RLdata500 and RLdata10000 are included in the examples
27 | directory as CSV files.
28 | These data sets were extracted from the [RecordLinkage](https://cran.r-project.org/web/packages/RecordLinkage/index.html)
29 | R package and have been used as benchmark data sets in the entity resolution
30 | literature.
31 | Both contain 10 percent duplicates and are non-trivial to link due to added
32 | distortion.
33 | Standard entity resolution metrics can be computed, since ground-truth unique
34 | identifiers are provided in the files.
35 | Config files for these data sets are included in the examples directory:
36 | see `RLdata500.conf` and `RLdata10000.conf`.
37 | To run these examples locally (in Spark pseudocluster mode),
38 | ensure you've built or obtained the JAR according to the instructions in the
39 | "How to: Build a fat JAR" section below, then change into the source code
40 | directory and run the following command:
41 | ```bash
42 | $SPARK_HOME/bin/spark-submit \
43 | --master "local[*]" \
44 | --conf "spark.driver.extraJavaOptions=-Dlog4j.configuration=log4j.properties" \
45 | --conf "spark.driver.extraClassPath=./target/scala-2.11/dblink-assembly-0.2.0.jar" \
46 | ./target/scala-2.11/dblink-assembly-0.2.0.jar \
47 | ./examples/RLdata500.conf
48 | ```
49 | (To run with RLdata10000 instead, replace `RLdata500.conf` with
50 | `RLdata10000.conf`.)
51 | Note that the config file specifies that output will be saved in
52 | the `./examples/RLdata500_results/` (or `./examples/RLdata10000_results`)
53 | directory.
54 |
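For a quick look at the output (a sketch: the file names below follow the writers in
`Sampler.scala`, and the directory matches the RLdata500 config above), the samples and
diagnostics can be inspected with Spark and any CSV reader:
```scala
// In spark-shell (or any Spark application)
val linkage = spark.read.parquet("./examples/RLdata500_results/linkage-chain.parquet")
linkage.printSchema()   // one row per (iteration, partition) with the linkage structure
linkage.show(5)

val diagnostics = spark.read
  .option("header", "true")
  .csv("./examples/RLdata500_results/diagnostics.csv")
diagnostics.show(5)
```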
55 | ## How to: Add dblink as a project dependency
56 | _Note: This won't work yet. Waiting for project to be accepted._
57 |
58 | Maven:
59 | ```xml
60 | <dependency>
61 |   <groupId>com.github.cleanzr</groupId>
62 |   <artifactId>dblink</artifactId>
63 |   <version>0.2.0</version>
64 | </dependency>
65 | ```
66 |
67 | sbt:
68 | ```scala
69 | libraryDependencies += "com.github.cleanzr" % "dblink" % "0.2.0"
70 | ```
71 |
72 | ## How to: Build a fat JAR
73 | You can build a fat JAR using sbt by running the following command from
74 | within the project directory:
75 | ```bash
76 | $ sbt assembly
77 | ```
78 |
79 | This should output a JAR file at `./target/scala-2.11/dblink-assembly-0.2.0.jar`
80 | relative to the project directory.
81 | Note that the JAR file does not bundle Spark or Hadoop, but it does include
82 | all other dependencies.
83 |
84 | ## Contact
85 | If you encounter problems, please [open an issue](https://github.com/ngmarchant/dblink/issues)
86 | on GitHub.
87 | You can also contact the main developer by email ` gmail.com`
88 |
89 | ## License
90 | GPL-3
91 |
92 | ## Citing the package
93 |
94 | Marchant, N. G., Kaplan, A., Elazar, D. N., Rubinstein, B. I. P. and
95 | Steorts, R. C. (2021). d-blink: Distributed End-to-End Bayesian Entity
96 | Resolution. _Journal of Computational and Graphical Statistics_, _30_(2),
97 | 406–421. DOI: [10.1080/10618600.2020.1825451](https://doi.org/10.1080/10618600.2020.1825451)
98 | arXiv: [1909.06039](https://arxiv.org/abs/1909.06039).
99 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/SimilarityFn.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import org.apache.commons.lang3.StringUtils.getLevenshteinDistance
23 |
24 | /** Represents a truncated attribute similarity function */
25 | sealed trait SimilarityFn extends Serializable {
26 | /** Compute the truncated similarity
27 | *
28 | * @param a an attribute value
29 | * @param b an attribute value
30 | * @return truncated similarity for `a` and `b`
31 | */
32 | def getSimilarity(a: String, b: String): Double
33 |
34 | /** Maximum possible raw/truncated similarity score */
35 | def maxSimilarity: Double
36 |
37 | /** Minimum possible raw similarity */
38 | def minSimilarity: Double
39 |
40 | /** Truncation threshold.
41 | * All raw similarities below this threshold are truncated to zero. */
42 | def threshold: Double
43 |
44 | def mkString: String
45 | }
46 |
47 | object SimilarityFn {
48 | /** Constant attribute similarity function. */
49 | case object ConstantSimilarityFn extends SimilarityFn {
50 | override def getSimilarity(a: String, b: String): Double = 0.0
51 |
52 | override def maxSimilarity: Double = 0.0
53 |
54 | override def minSimilarity: Double = 0.0
55 |
56 | override def threshold: Double = 0.0
57 |
58 | override def mkString: String = "ConstantSimilarityFn"
59 | }
60 |
61 | abstract class NonConstantSimilarityFn(override val threshold: Double,
62 | override val maxSimilarity: Double) extends SimilarityFn {
63 | require(maxSimilarity > 0.0, "`maxSimilarity` must be positive")
64 | require(threshold >= minSimilarity && threshold < maxSimilarity,
65 | s"`threshold` must be in the interval [$minSimilarity, $maxSimilarity)")
66 |
67 | protected val transFactor: Double = maxSimilarity / (maxSimilarity - threshold)
68 |
69 | override def getSimilarity(a: String, b: String): Double = {
70 | val transSim = transFactor * (maxSimilarity * unitSimilarity(a, b) - threshold)
71 | if (transSim > 0.0) transSim else 0.0
72 | }
73 |
74 | override def minSimilarity: Double = 0.0
75 |
76 | /** Similarity function that provides scores on the unit interval */
77 | protected def unitSimilarity(a: String, b: String): Double
78 | }
79 |
80 |
81 | /* Levenshtein attribute similarity function. */
82 | class LevenshteinSimilarityFn(threshold: Double, maxSimilarity: Double)
83 | extends NonConstantSimilarityFn(threshold, maxSimilarity) {
84 |
85 | /** Similarity measure based on the normalized Levenshtein distance metric.
86 | *
87 | * See the following reference:
88 | * L. Yujian and L. Bo, “A Normalized Levenshtein Distance Metric,”
89 | * IEEE Transactions on Pattern Analysis and Machine Intelligence,
90 | * vol. 29, no. 6, pp. 1091–1095, Jun. 2007.
91 | */
92 | override protected def unitSimilarity(a: String, b: String): Double = {
93 | val totalLength = a.length + b.length
94 | if (totalLength > 0) {
95 | val dist = getLevenshteinDistance(a, b).toDouble
96 | 1.0 - 2.0 * dist / (totalLength + dist)
97 | } else 1.0
98 | }
99 |
100 | override def mkString: String = s"LevenshteinSimilarityFn(threshold=$threshold, maxSimilarity=$maxSimilarity)"
101 | }
102 |
103 | object LevenshteinSimilarityFn {
104 | def apply(threshold: Double = 7.0, maxSimilarity: Double = 10.0) =
105 | new LevenshteinSimilarityFn(threshold, maxSimilarity)
106 | }
107 | }
108 |
109 |
110 |
111 |
112 |
113 |
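
As a worked example (the strings are illustrative), the default Levenshtein similarity function with threshold 7.0 and maximum 10.0 maps a one-character edit between two five-letter names to a truncated similarity of roughly 3.9:

```scala
import com.github.cleanzr.dblink.SimilarityFn.{ConstantSimilarityFn, LevenshteinSimilarityFn}

object SimilarityFnExample extends App {
  val levenshtein = LevenshteinSimilarityFn()          // threshold = 7.0, maxSimilarity = 10.0
  // "Smith" vs "Smyth": Levenshtein distance 1, total length 10,
  // unit similarity = 1 - 2*1/(10 + 1) = 9/11, raw similarity = 10 * 9/11 ≈ 8.18,
  // truncated similarity = (10 / (10 - 7)) * (8.18 - 7) ≈ 3.94
  println(levenshtein.getSimilarity("Smith", "Smyth"))

  // A constant similarity function ignores its arguments entirely
  println(ConstantSimilarityFn.getSimilarity("Smith", "Smyth"))  // 0.0
}
```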
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/partitioning/MutableBST.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.partitioning
21 |
22 | import scala.collection.mutable.ArrayBuffer
23 | import scala.math.pow
24 |
25 | /** ArrayBuffer-based binary search tree.
26 | * Not very generic. Consider combining with KDTreePartitioner.
27 | * Fast to search, but not space efficient for imbalanced trees (not
28 | * a problem for our application where the trees are small and ideally
29 | * perfectly balanced.)
30 | *
31 | * @tparam A type of field values
32 | */
33 | class MutableBST[A : Ordering] extends Serializable {
34 | case class Node(var attributeId: Int, var splitter: DomainSplitter[A], var value: Int)
35 |
36 | private val _nodes = ArrayBuffer[Node](Node(-1, null, 0))
37 | private var _numLevels: Int = 0
38 | private var _numLeaves: Int = 1
39 |
40 | def numLeaves: Int = _numLeaves
41 | def numLevels: Int = _numLevels
42 | def isEmpty: Boolean = _nodes.head.splitter == null
43 | def nonEmpty: Boolean = _nodes.head.splitter != null
44 |
45 | /** Search the tree for the leaf node corresponding to the given set of attribute
46 | * values, and return the "number" of the leaf.
47 | *
48 | * @param values a set of attribute values
49 | * @return leaf number (in 0, ..., numLeaves)
50 | */
51 | def getLeafNumber(values: IndexedSeq[A]): Int = {
52 | val nodeId = getLeafNodeId(values)
53 | _nodes(nodeId).value
54 | }
55 |
56 | /** Search the tree for the leaf node corresponding to the given set of attribute
57 | * values.
58 | *
59 | * @param attributes a set of attribute values
60 | * @return id of the leaf node
61 | */
62 | def getLeafNodeId(attributes: IndexedSeq[A]): Int = {
63 | // TODO: check whether point has too many or too few dimensions
64 | var found = false
65 | var nodeId = 0
66 | while (!found && nodeId < _nodes.length) {
67 | val node = _nodes(nodeId)
68 |       require(node != null, s"encountered a null node (id=$nodeId) while searching the tree")
69 | if (node.splitter != null) {
70 | val pointVal = attributes(node.attributeId)
71 | // descend to next level
72 | if (node.splitter(pointVal)) nodeId = 2*nodeId + 2 // go right
73 | else nodeId = 2*nodeId + 1 // go left
74 | } else {
75 | found = true // at a leaf node
76 | }
77 | }
78 | nodeId
79 | }
80 |
81 | /** Split an existing node of the tree into two leaf nodes
82 | *
83 | * @param nodeId id of node to split
84 | * @param attributeId attribute associated with split
85 | * @param splitter splitter (specifies whether to go left or right)
86 | */
87 | def splitNode(nodeId: Int, attributeId: Int,
88 | splitter: DomainSplitter[A]): Unit = {
89 | /** Ensure that the node already exists as a leaf node. */
90 | val node = if (nodeId < _nodes.length) _nodes(nodeId) else null
91 | require(node != null, "node does not exist")
92 | require(node.splitter == null, "node is already split")
93 |
94 | /** Update the node */
95 | node.attributeId = attributeId
96 | node.splitter = splitter
97 |
98 | /** Insert children for the node */
99 | val leftChildId = 2*nodeId + 1
100 | val rightChildId = 2*nodeId + 2
101 | /** If `node` was a leaf at the lowest level, need to add space for its children */
102 | if (leftChildId >= _nodes.length) {
103 | // grow array for storing nodes
104 | _numLevels += 1
105 | val newLevelSize = pow(2.0, _numLevels).toInt
106 | _nodes ++= Iterator.fill(newLevelSize)(null)
107 | }
108 | _nodes(leftChildId) = Node(-1, null, node.value)
109 | _nodes(rightChildId) = Node(-1, null, _numLeaves)
110 | _numLeaves += 1
111 | }
112 | }
113 |
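
A minimal sketch of growing and querying the tree (the splitter and attribute values are made up):

```scala
import com.github.cleanzr.dblink.partitioning.{DomainSplitter, MutableBST}

object MutableBSTExample extends App {
  val tree = new MutableBST[String]
  // Split the root node (id 0) on attribute 0, using a splitter fitted to a toy domain
  val splitter = DomainSplitter(Array("a" -> 1.0, "b" -> 1.0, "c" -> 1.0, "d" -> 1.0))
  tree.splitNode(nodeId = 0, attributeId = 0, splitter = splitter)
  println(tree.numLeaves)                       // 2
  println(tree.getLeafNumber(IndexedSeq("a")))  // leaf number in {0, 1}, depending on the split
  println(tree.getLeafNumber(IndexedSeq("d")))
}
```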
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/random/AliasSampler.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.random
21 |
22 | import com.github.cleanzr.dblink.random.AliasSampler._
23 | import org.apache.commons.math3.random.RandomGenerator
24 |
25 | class AliasSampler(weights: Traversable[Double], checkWeights: Boolean, normalized: Boolean) extends Serializable {
26 |
27 | def size: Int = weights.size
28 |
29 | private val (_probabilities, _aliasTable) = computeTables(weights, checkWeights, normalized)
30 |
31 | def probabilities: IndexedSeq[Double] = _probabilities
32 |
33 | def sample()(implicit rand: RandomGenerator): Int = {
34 | val U = rand.nextDouble() * size
35 | val i = U.toInt
36 | if (U < _probabilities(i)) i else _aliasTable(i)
37 | }
38 |
39 | def sample(sampleSize: Int)(implicit rand: RandomGenerator): Array[Int] = {
40 | Array.tabulate(sampleSize)(_ => this.sample())
41 | }
42 | }
43 |
44 | object AliasSampler {
45 | def apply(weights: Traversable[Double], checkWeights: Boolean = true, normalized: Boolean = false): AliasSampler = {
46 | new AliasSampler(weights, checkWeights, normalized)
47 | }
48 |
49 | private def computeTables(weights: Traversable[Double],
50 | checkWeights: Boolean,
51 | normalized: Boolean): (Array[Double], Array[Int]) = {
52 | val size = weights.size
53 |
54 | val totalWeight = if (checkWeights && !normalized) {
55 | weights.foldLeft(0.0) { (sum, weight) =>
56 | if (weight < 0 || weight.isInfinity || weight.isNaN) {
57 | throw new IllegalArgumentException("invalid weight encountered")
58 | }
59 | sum + weight
60 | }
61 | } else if (checkWeights && normalized) {
62 | weights.foreach { weight =>
63 | if (weight < 0 || weight.isInfinity || weight.isNaN) {
64 | throw new IllegalArgumentException("invalid weight encountered")
65 | }
66 | }
67 | 1.0
68 | } else if (!checkWeights && !normalized) {
69 | weights.sum
70 | } else 1.0
71 |
72 | require(totalWeight > 0.0, "zero probability mass")
73 |
74 | val probabilities = weights.map( weight => weight * size / totalWeight).toArray
75 |
76 | val aliasTable = Array.ofDim[Int](size)
77 |
78 | // Store small and large worklists in a single array. "Small" elements are stored on the left side, and
79 | // "large" elements are stored on the right side.
80 | val worklist = Array.ofDim[Int](size)
81 | var posSmall = 0
82 | var posLarge = size
83 |
84 | // Fill worklists
85 | var i = 0
86 | while(i < size) {
87 | if (probabilities(i) < 1.0) {
88 | worklist(posSmall) = i // add index i to the small worklist
89 | posSmall += 1
90 | } else {
91 | posLarge -= 1 // add index i to the large worklist
92 | worklist(posLarge) = i
93 | }
94 | i += 1
95 | }
96 |
97 | // Remove elements from worklists
98 | if (posSmall > 0) { // both small and large worklists contain elements
99 | posSmall = 0
100 | while (posSmall < size && posLarge < size) {
101 | val l = worklist(posSmall)
102 | posSmall += 1 // "remove" element from small worklist
103 | val g = worklist(posLarge)
104 | aliasTable(l) = g
105 | probabilities(g) = (probabilities(g) + probabilities(l)) - 1.0
106 | if (probabilities(g) < 1.0) posLarge += 1 // "move" element g to small worklist
107 | }
108 | }
109 |
110 | // Since the uniform on [0,1] in the `sample` method is shifted by i
111 | i = 0
112 | while (i < size) {
113 | probabilities(i) += i
114 | i += 1
115 | }
116 |
117 | (probabilities, aliasTable)
118 | }
119 | }
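
A brief usage sketch (the weights are illustrative): the sampler returns indices into the weight sequence with frequencies proportional to the weights.

```scala
import com.github.cleanzr.dblink.random.AliasSampler
import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}

object AliasSamplerExample extends App {
  implicit val rand: RandomGenerator = new MersenneTwister(0L)
  val sampler = AliasSampler(Seq(1.0, 2.0, 7.0))   // unnormalised weights are fine
  val draws = sampler.sample(10000)                // indices in {0, 1, 2}
  val counts = draws.groupBy(identity).mapValues(_.length)
  println(counts)                                  // roughly 10%, 20% and 70% of the draws
}
```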
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/partitioning/KDTreePartitioner.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink.partitioning
21 |
22 | import com.github.cleanzr.dblink.accumulators.MapDoubleAccumulator
23 | import com.github.cleanzr.dblink.partitioning.KDTreePartitioner.getNewSplits
24 | import org.apache.spark.broadcast.Broadcast
25 | import org.apache.spark.rdd.RDD
26 | import com.github.cleanzr.dblink.Logging
27 |
28 | class KDTreePartitioner[T : Ordering](numLevels: Int,
29 | attributeIds: Traversable[Int]) extends PartitionFunction[T] {
30 | require(numLevels >= 0, "`numLevels` must be non-negative.")
31 | if (numLevels > 0) require(attributeIds.nonEmpty, "`attributeIds` must be non-empty if `numLevels` > 0")
32 |
33 | override def numPartitions: Int = tree.numLeaves
34 |
35 | private val tree = new MutableBST[T]
36 |
37 | override def fit(records: RDD[Array[T]]): Unit = {
38 | if (numLevels > 0) require(attributeIds.nonEmpty, "non-empty list of attributes is required to build the tree.")
39 |
40 | val sc = records.sparkContext
41 |
42 | var level = 0
43 | var itAttributeIds = attributeIds.toIterator
44 | while (level < numLevels) {
45 | /** Go back to the beginning of `attributeIds` if we reach the end */
46 | val attrId = if (itAttributeIds.hasNext) itAttributeIds.next() else {
47 | itAttributeIds = attributeIds.toIterator // reset
48 | itAttributeIds.next()
49 | }
50 | KDTreePartitioner.info(s"Splitting on attribute $attrId at level $level.")
51 | val bcTree = sc.broadcast(tree)
52 | val newLevel = getNewSplits(records, bcTree, attrId)
53 | newLevel.foreach { case (nodeId, splitter) =>
54 | if (splitter.splitQuality <= 0.9)
55 | KDTreePartitioner.warn(s"Poor quality split (${splitter.splitQuality*100}%) at node $nodeId.")
56 | tree.splitNode(nodeId, attrId, splitter)
57 | }
58 | level += 1
59 | }
60 | }
61 |
62 | override def getPartitionId(attributeValues: Array[T]): Int = tree.getLeafNumber(attributeValues)
63 |
64 | override def mkString: String = {
65 | if (numLevels == 0) s"KDTreePartitioner(numLevels=0)"
66 | else s"KDTreePartitioner(numLevels=$numLevels, attributeIds=${attributeIds.mkString("[",",","]")})"
67 | }
68 | }
69 |
70 | object KDTreePartitioner extends Logging {
71 |
72 | def apply[T : Ordering](numLevels: Int, attributeIds: Traversable[Int]): KDTreePartitioner[T] = {
73 | new KDTreePartitioner(numLevels, attributeIds)
74 | }
75 |
76 | def apply[T : Ordering](): KDTreePartitioner[T] = {
77 | new KDTreePartitioner(0, Traversable())
78 | }
79 |
80 | private def getNewSplits[T : Ordering](records: RDD[Array[T]],
81 | bcTree: Broadcast[MutableBST[T]],
82 | attributeId: Int): Map[Int, DomainSplitter[T]] = {
83 | val sc = records.sparkContext
84 |
85 | val acc = new MapDoubleAccumulator[(Int, T)]
86 | sc.register(acc, s"tree level builder")
87 |
88 | /** Iterate over the records counting the number of occurrences of each
89 | * attribute value per node */
90 | records.foreachPartition { partition =>
91 | val tree = bcTree.value
92 | partition.foreach { attributeValues =>
93 | val attributeValue = attributeValues(attributeId)
94 | val nodeId = tree.getLeafNodeId(attributeValues)
95 | acc.add((nodeId, attributeValue), 1L)
96 | }
97 | }
98 |
99 | /** Group by node and compute the splitters for each node */
100 | acc.value.toArray
101 | .groupBy(_._1._1)
102 | .mapValues { x =>
103 | DomainSplitter[T](x.map { case ((_, value), weight) => (value, weight) })
104 | }
105 | }
106 | }
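
A rough sketch of fitting the partitioner on toy data (assumes a local Spark context; the records and attribute choices are made up):

```scala
import com.github.cleanzr.dblink.partitioning.KDTreePartitioner
import org.apache.spark.{SparkConf, SparkContext}

object KDTreePartitionerExample extends App {
  val sc = SparkContext.getOrCreate(
    new SparkConf().setMaster("local[1]").setAppName("kdtree-example"))

  // Toy records: each record is an array of attribute values
  val records = sc.parallelize(Seq(
    Array("NSW", "Sydney"), Array("NSW", "Newcastle"),
    Array("VIC", "Melbourne"), Array("QLD", "Brisbane")
  ))

  val partitioner = KDTreePartitioner[String](numLevels = 2, attributeIds = Seq(0, 1))
  partitioner.fit(records)
  println(partitioner.numPartitions)                          // at most 2^numLevels = 4
  println(partitioner.getPartitionId(Array("NSW", "Sydney"))) // partition for a record
}
```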
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/CustomKryoRegistrator.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.esotericsoftware.kryo.Kryo
23 | import com.github.cleanzr.dblink.GibbsUpdates.{EntityInvertedIndex, LinksIndex}
24 | import com.github.cleanzr.dblink.partitioning.{DomainSplitter, KDTreePartitioner, LPTScheduler, MutableBST, PartitionFunction, SimplePartitioner}
25 | import com.github.cleanzr.dblink.random.{AliasSampler, DiscreteDist, IndexNonUniformDiscreteDist, NonUniformDiscreteDist}
29 | import org.apache.commons.math3.random.MersenneTwister
30 | import org.apache.spark.serializer.KryoRegistrator
31 |
32 | class CustomKryoRegistrator extends KryoRegistrator {
33 | override def registerClasses(kryo: Kryo): Unit = {
34 | kryo.register(classOf[EntRecCluster])
35 | kryo.register(classOf[PartEntRecCluster])
36 | kryo.register(classOf[Record[_]])
37 | kryo.register(classOf[Entity])
38 | kryo.register(classOf[DistortedValue])
39 | kryo.register(classOf[SummaryVars])
40 | kryo.register(classOf[Attribute])
41 | kryo.register(classOf[SimilarityFn])
42 | kryo.register(classOf[IndexedAttribute])
43 | kryo.register(classOf[BetaShapeParameters])
44 | kryo.register(classOf[EntityInvertedIndex])
45 | kryo.register(classOf[LinksIndex])
46 | kryo.register(classOf[DomainSplitter[_]])
47 | kryo.register(classOf[KDTreePartitioner[_]])
48 | kryo.register(classOf[LPTScheduler[_,_]])
49 | kryo.register(classOf[LPTScheduler.Partition[_,_]])
50 | kryo.register(classOf[MutableBST[_]])
51 | kryo.register(classOf[PartitionFunction[_]])
52 | kryo.register(classOf[SimplePartitioner[_]])
53 | kryo.register(classOf[AliasSampler])
54 | kryo.register(classOf[DiscreteDist[_]])
55 | kryo.register(classOf[IndexNonUniformDiscreteDist])
56 | kryo.register(classOf[NonUniformDiscreteDist[_]])
57 | kryo.register(Class.forName("org.apache.spark.sql.execution.columnar.CachedBatch"))
58 | kryo.register(Class.forName("[[B"))
59 | kryo.register(Class.forName("org.apache.spark.sql.catalyst.expressions.GenericInternalRow"))
60 | kryo.register(Class.forName("org.apache.spark.unsafe.types.UTF8String"))
61 | kryo.register(classOf[Array[Object]])
62 | kryo.register(classOf[Array[EntRecCluster]])
63 | kryo.register(classOf[RecordsCache])
64 | kryo.register(classOf[Parameters])
65 | kryo.register(classOf[AliasSampler])
66 | kryo.register(classOf[MersenneTwister])
67 | kryo.register(classOf[DistortionProbs])
68 | }
69 | }
70 | //
71 | //class AugmentedRecordSerializer extends Serializer[EntRecPair] {
72 | // override def write(kryo: Kryo, output: Output, `object`: EntRecPair): Unit = {
73 | // output.writeLong(`object`.latId, true)
74 | // output.writeString(`object`.recId)
75 | // output.writeString(`object`.fileId)
76 | // val numLatFields = `object`.latFieldValues.length
77 | // output.writeInt(numLatFields)
78 | // `object`.recFieldValues.foreach(v => output.writeString(v))
79 | // `object`.distortions.foreach(v => output.writeBoolean(v))
80 | // `object`.latFieldValues.foreach(v => output.writeString(v))
81 | // }
82 | //
83 | // override def read(kryo: Kryo, input: Input, `type`: Class[EntRecPair]): EntRecPair = {
84 | // val latId = input.readLong(true)
85 | // val recId = input.readString()
86 | // val fileId = input.readString()
87 | // val numLatFields = input.readInt()
88 | // val numRecFields = if (recId == null) 0 else numLatFields
89 | // val fieldValues = Array.fill(numRecFields)(input.readString())
90 | // val distortions = Array.fill(numRecFields)(input.readBoolean())
91 | // val latFieldValues = Array.fill(numLatFields)(input.readString())
92 | // LatRecPair(latId, latFieldValues, recId, fileId, fieldValues, distortions)
93 | // }
94 | //}
95 |
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/AttributeIndexTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}
23 | import org.apache.spark.{SparkConf, SparkContext}
24 | import org.scalatest._
25 | import com.github.cleanzr.dblink.SimilarityFn._
26 |
27 | class AttributeIndexTest extends FlatSpec with AttributeIndexBehaviors with Matchers {
28 |
29 | implicit val rand: RandomGenerator = new MersenneTwister(0)
30 | implicit val sc: SparkContext = {
31 | val conf = new SparkConf()
32 | .setMaster("local[1]")
33 | .setAppName("d-blink test")
34 | SparkContext.getOrCreate(conf)
35 | }
36 | sc.setLogLevel("WARN")
37 |
38 | lazy val stateWeights = Map("Australian Capital Territory" -> 0.410,
39 | "New South Wales" -> 7.86, "Northern Territory" -> 0.246, "Queensland" -> 4.92,
40 | "South Australia" -> 1.72, "Tasmania" -> 0.520, "Victoria" -> 6.32,
41 | "Western Australia" -> 2.58)
42 |
43 | lazy val constantIndex = AttributeIndex(stateWeights, ConstantSimilarityFn)
44 |
45 | lazy val nonConstantIndex = AttributeIndex(stateWeights, LevenshteinSimilarityFn(5.0, 10.0))
46 |
47 | /** The following results are for an index with LevenshteinSimilarityFn(5.0, 10.0) */
48 | def stateSimNormalizations = Map("Australian Capital Territory" -> 0.0027140755302269004,
49 | "New South Wales" -> 1.4193905286944585E-4,
50 | "Northern Territory" -> 0.00451528932619675,
51 | "Queensland" -> 2.2673706056780077E-4,
52 | "South Australia" -> 6.465919296781136E-4,
53 | "Tasmania" -> 0.00214117348291189,
54 | "Victoria" -> 1.7651936247903708E-4,
55 | "Western Australia" -> 4.317863538883541E-4)
56 |
57 | def simValuesSA = Map(7 -> 39.813678188084864, 4 -> 22026.465794806718)
58 |
59 | val expSimSAWA = 39.813678188084864
60 | val expSimVICTAS = 1.0
61 |
62 | "An attribute index (with a constant similarity function)" should behave like genericAttributeIndex(constantIndex, stateWeights)
63 |
64 | it should "return a similarity normalization constant of 1.0 for all values" in {
65 | assert((0 until constantIndex.numValues).forall(valueId => constantIndex.simNormalizationOf(valueId) == 1.0))
66 | }
67 |
68 | it should "return no similar values for all values" in {
69 | assert((0 until constantIndex.numValues).forall(valueId => constantIndex.simValuesOf(valueId).isEmpty))
70 | }
71 |
72 | it should "return an exponentiated similarity score of 1.0 for all value pairs" in {
73 | val allValueIds = 0 until constantIndex.numValues
74 | val allValueIdPairs = allValueIds.flatMap(valueId1 => allValueIds.map(valueId2 => (valueId1, valueId2)))
75 | assert(allValueIdPairs.forall { case (valueId1, valueId2) => constantIndex.expSimOf(valueId1, valueId2) == 1.0 })
76 | }
77 |
78 | "An attribute index (with a non-constant similarity function)" should behave like genericAttributeIndex(nonConstantIndex, stateWeights)
79 |
80 | it should "return the correct similarity normalization constants for all values" in {
81 | assert(stateSimNormalizations.forall { case (stringValue, trueSimNorm) =>
82 | nonConstantIndex.simNormalizationOf(nonConstantIndex.valueIdxOf(stringValue)) === (trueSimNorm +- 1e-4)
83 | })
84 | }
85 |
86 | it should "return the correct similar values for a query value" in {
87 | val testSimValues = nonConstantIndex.simValuesOf(nonConstantIndex.valueIdxOf("South Australia"))
88 | assert(simValuesSA.keySet === testSimValues.keySet)
89 | assert(simValuesSA.forall { case (valueId, trueExpSim) => testSimValues(valueId) === (trueExpSim +- 1e-4) })
90 | }
91 |
92 | it should "return the correct exponentiated similarity score for a query value pair" in {
93 | val valueIdSA = nonConstantIndex.valueIdxOf("South Australia")
94 | val valueIdWA = nonConstantIndex.valueIdxOf("Western Australia")
95 | assert(nonConstantIndex.expSimOf(valueIdSA, valueIdWA) === (expSimSAWA +- 1e-4))
96 | val valueIdVIC = nonConstantIndex.valueIdxOf("Victoria")
97 | val valueIdTAS = nonConstantIndex.valueIdxOf("Tasmania")
98 | assert(nonConstantIndex.expSimOf(valueIdVIC, valueIdTAS) === (expSimVICTAS +- 1e-4))
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/package.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr
21 |
22 | import com.github.cleanzr.dblink.SimilarityFn.ConstantSimilarityFn
23 |
24 | /** Types/case classes used throughout */
25 | package object dblink {
26 | import org.apache.spark.rdd.RDD
27 |
28 | type FileId = String
29 | type RecordId = String
30 | type PartitionId = Int
31 | type EntityId = Int
32 | type ValueId = Int
33 | type AttributeId = Int
34 | type Partitions = RDD[(PartitionId, EntRecCluster)]
35 | type AggDistortions = Map[(AttributeId, FileId), Long]
36 | type Cluster = Set[RecordId]
37 | type RecordPair = (RecordId, RecordId)
38 |
39 | case class MostProbableCluster(recordId: RecordId, cluster: Set[RecordId], frequency: Double)
40 |
41 | /** Record (row) in the input data
42 | *
43 | * @param id a unique identifier for the record (must be unique across
44 | * _all_ files)
45 | * @param fileId file identifier for the record
46 | * @param values attribute values for the record
47 | * @tparam T value type
48 | */
49 | case class Record[T](id: RecordId,
50 | fileId: FileId,
51 | values: Array[T]) {
52 | def mkString: String = s"Record(id=$id, fileId=$fileId, values=${values.mkString(",")})"
53 | }
54 |
55 |
56 | /** Latent entity
57 | *
58 | * @param values attribute values for the entity.
59 | */
60 | case class Entity(values: Array[ValueId]) {
61 | def mkString: String = s"Entity(${values.mkString(",")})"
62 | }
63 |
64 |
65 | /** Attribute value subject to distortion
66 | *
67 | * @param value the attribute value
68 | * @param distorted whether the value is distorted
69 | */
70 | case class DistortedValue(value: ValueId, distorted: Boolean) {
71 |     def mkString: String = s"DistortedValue(value=$value, distorted=$distorted)"
72 | }
73 |
74 |
75 | /**
76 | * Entity-record cluster
77 | * @param entity a latent entity
78 | * @param records records linked to the entity
79 | */
80 | case class EntRecCluster(entity: Entity,
81 | records: Option[Array[Record[DistortedValue]]]) {
82 | def mkString: String = {
83 | records match {
84 | case Some(r) => s"${entity.mkString}\t->\t${r.mkString(", ")}"
85 | case None => entity.mkString
86 | }
87 | }
88 | }
89 |
90 | /** Linkage state
91 | *
92 | * Represents the linkage structure within a partition for a single iteration.
93 | */
94 | case class LinkageState(iteration: Long,
95 | partitionId: PartitionId,
96 | linkageStructure: Seq[Seq[RecordId]])
97 |
98 |
99 | /** Partition-entity-record cluster triple
100 | *
101 | * This class is used in the Dataset representation of the partitions---not
102 | * the RDD representation.
103 | *
104 | * @param partitionId identifier for the partition
105 | * @param entRecCluster entity cluster of records
106 | */
107 | case class PartEntRecCluster(partitionId: PartitionId, entRecCluster: EntRecCluster)
108 |
109 | /** Container for the summary variables
110 | *
111 | * @param numIsolates
112 | * @param logLikelihood
113 | * @param aggDistortions
114 | * @param recDistortions
115 | */
116 | case class SummaryVars(numIsolates: Long,
117 | logLikelihood: Double,
118 | aggDistortions: AggDistortions,
119 | recDistortions: Map[Int, Long])
120 |
121 |
122 | /** Specifications for an attribute
123 | *
124 | * @param name column name of the attribute (should match original DataFrame)
125 | * @param similarityFn an attribute similarity function
126 | * @param distortionPrior prior for the distortion
127 | */
128 | case class Attribute(name: String,
129 | similarityFn: SimilarityFn,
130 | distortionPrior: BetaShapeParameters) {
131 | /** Whether the similarity function for this attribute is constant or not*/
132 | def isConstant: Boolean = {
133 | similarityFn match {
134 | case ConstantSimilarityFn => true
135 | case _ => false
136 | }
137 | }
138 | }
139 |
140 | /** Specifications for an indexed attribute
141 | *
142 | * @param name column name of the attribute (should match original DataFrame)
143 | * @param similarityFn an attribute similarity function
144 | * @param distortionPrior prior for the distortion
145 | * @param index index for the attribute
146 | */
147 | case class IndexedAttribute(name: String,
148 | similarityFn: SimilarityFn,
149 | distortionPrior: BetaShapeParameters,
150 | index: AttributeIndex) {
151 | /** Whether the similarity function for this attribute is constant or not*/
152 | def isConstant: Boolean = {
153 | similarityFn match {
154 | case ConstantSimilarityFn => true
155 | case _ => false
156 | }
157 | }
158 | }
159 |
160 |
161 | /** Container for shape parameters of the Beta distribution
162 | *
163 | * @param alpha "alpha" shape parameter
164 | * @param beta "beta" shape parameter
165 | */
166 | case class BetaShapeParameters(alpha: Double, beta: Double) {
167 | require(alpha > 0 && beta > 0, "shape parameters must be positive")
168 |
169 | def mkString: String = s"BetaShapeParameters(alpha=$alpha, beta=$beta)"
170 | }
171 | }
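
As an illustration of how these case classes fit together (the attribute names and prior parameters below are made up, loosely following the RLdata examples):

```scala
import com.github.cleanzr.dblink.{Attribute, BetaShapeParameters}
import com.github.cleanzr.dblink.SimilarityFn.{ConstantSimilarityFn, LevenshteinSimilarityFn}

object AttributeSpecExample extends App {
  val attributes = IndexedSeq(
    Attribute("fname_c1", LevenshteinSimilarityFn(), BetaShapeParameters(alpha = 1.0, beta = 50.0)),
    Attribute("by", ConstantSimilarityFn, BetaShapeParameters(alpha = 1.0, beta = 50.0))
  )
  attributes.foreach(a => println(s"${a.name}: constant similarity? ${a.isConstant}"))
}
```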
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/RecordsCache.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.accumulators.MapLongAccumulator
23 | import org.apache.spark.SparkContext
24 | import org.apache.spark.rdd.RDD
25 |
26 | /** Container to store statistics/metadata for a collection of records and
27 | * facilitate sampling from the attribute domains.
28 | *
29 | * This container is broadcast to each executor.
30 | *
31 | * @param indexedAttributes indexes for the attributes
32 | * @param fileSizes number of records in each file
33 | */
34 | class RecordsCache(val indexedAttributes: IndexedSeq[IndexedAttribute],
35 | val fileSizes: Map[FileId, Long],
36 | val missingCounts: Option[Map[(FileId, AttributeId), Long]] = None) extends Serializable {
37 |
38 | def distortionPrior: Iterator[BetaShapeParameters] = indexedAttributes.iterator.map(_.distortionPrior)
39 |
40 | /** Number of records across all files */
41 | val numRecords: Long = fileSizes.values.sum
42 |
43 | /** Number of attributes used for matching */
44 | def numAttributes: Int = indexedAttributes.length
45 |
46 | /** Transform to value ids
47 | *
48 | * @param records an RDD of records in raw form (string attribute values)
49 | * @return an RDD of records where the string attribute values are replaced
50 | * by integer value ids
51 | */
52 | def transformRecords(records: RDD[Record[String]]): RDD[Record[ValueId]] =
53 | RecordsCache._transformRecords(records, indexedAttributes)
54 | }
55 |
56 | object RecordsCache extends Logging {
57 |
58 | /** Build RecordsCache
59 | *
60 | * @param records an RDD of records in raw form (string attribute values)
61 | * @param attributeSpecs specifications for each record attribute. Must match
62 | * the order of attributes in `records`.
63 | * @param expectedMaxClusterSize largest expected record cluster size. Used
64 | * as a hint when precaching distributions
65 | * over the non-constant attribute domain.
66 | * @return a RecordsCache
67 | */
68 | def apply(records: RDD[Record[String]],
69 | attributeSpecs: IndexedSeq[Attribute],
70 | expectedMaxClusterSize: Int): RecordsCache = {
71 | val firstRecord = records.take(1).head
72 | require(firstRecord.values.length == attributeSpecs.length, "attribute specifications do not match the records")
73 |
74 | /** Use accumulators to gather record stats in one pass */
75 | val accFileSizes = new MapLongAccumulator[FileId]
76 | implicit val sc: SparkContext = records.sparkContext
77 | sc.register(accFileSizes, "number of records per file")
78 |
79 | val accMissingCounts = new MapLongAccumulator[(FileId, AttributeId)]
80 | sc.register(accMissingCounts, s"missing counts per file and attribute")
81 |
82 | val accValueCounts = attributeSpecs.map{ attribute =>
83 | val acc = new MapLongAccumulator[String]
84 | sc.register(acc, s"value counts for attribute ${attribute.name}")
85 | acc
86 | }
87 |
88 | info("Gathering statistics from source data files.")
89 | /** Get file and value counts in a single foreach action */
90 | records.foreach { case Record(_, fileId, values) =>
91 | accFileSizes.add((fileId, 1L))
92 | values.zipWithIndex.foreach { case (value, attrId) =>
93 | if (value != null) accValueCounts(attrId).add(value, 1L)
94 | else accMissingCounts.add(((fileId, attrId), 1L))
95 | }
96 | }
97 |
98 | val missingCounts = accMissingCounts.value
99 | val fileSizes = accFileSizes.value
100 | val totalRecords = fileSizes.values.sum
101 | val percentageMissing = 100.0 * missingCounts.values.sum / (totalRecords * attributeSpecs.size)
102 |     if (percentageMissing > 0.0) {
103 | info(f"Finished gathering statistics from $totalRecords records across ${fileSizes.size} file(s). $percentageMissing%.3f%% of the record attribute values are missing.")
104 | } else {
105 | info(s"Finished gathering statistics from $totalRecords records across ${fileSizes.size} file(s). There are no missing record attribute values.")
106 | }
107 |
108 | /** Build an index for each attribute (generates a mapping from strings -> integers) */
109 | val indexedAttributes = attributeSpecs.zipWithIndex.map { case (attribute, attrId) =>
110 | val valuesWeights = accValueCounts(attrId).value.mapValues(_.toDouble)
111 | info(s"Indexing attribute '${attribute.name}' with ${valuesWeights.size} unique values.")
112 | val index = AttributeIndex(valuesWeights, attribute.similarityFn,
113 | Some(1 to expectedMaxClusterSize))
114 | IndexedAttribute(attribute.name, attribute.similarityFn, attribute.distortionPrior, index)
115 | }
116 |
117 | new RecordsCache(indexedAttributes, fileSizes, Some(missingCounts))
118 | }
119 |
120 | private def _transformRecords(records: RDD[Record[String]],
121 | indexedAttributes: IndexedSeq[IndexedAttribute]): RDD[Record[ValueId]] = {
122 | val firstRecord = records.take(1).head
123 | require(firstRecord.values.length == indexedAttributes.length, "attribute specifications do not match the records")
124 |
125 | records.mapPartitions( partition => {
126 | partition.map { record =>
127 | val mappedValues = record.values.zipWithIndex.map { case (stringValue, attrId) =>
128 | if (stringValue != null) indexedAttributes(attrId).index.valueIdxOf(stringValue)
129 | else -1
130 | }
131 | Record[ValueId](record.id, record.fileId, mappedValues)
132 | }
133 | }, preservesPartitioning = true)
134 | }
135 | }
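
A condensed sketch of building the cache (assumes a local Spark context; the records and attribute specification are toy examples):

```scala
import com.github.cleanzr.dblink.{Attribute, BetaShapeParameters, Record, RecordsCache}
import com.github.cleanzr.dblink.SimilarityFn.ConstantSimilarityFn
import org.apache.spark.{SparkConf, SparkContext}

object RecordsCacheExample extends App {
  val sc = SparkContext.getOrCreate(
    new SparkConf().setMaster("local[1]").setAppName("records-cache-example"))

  // Toy records from a single file, with one attribute ("state")
  val records = sc.parallelize(Seq(
    Record("1", "fileA", Array("NSW")),
    Record("2", "fileA", Array("VIC")),
    Record("3", "fileA", Array(null: String))   // missing value
  ))
  val attributeSpecs = IndexedSeq(
    Attribute("state", ConstantSimilarityFn, BetaShapeParameters(1.0, 50.0)))

  val cache = RecordsCache(records, attributeSpecs, expectedMaxClusterSize = 10)
  println(cache.numRecords)                                  // 3
  val transformed = cache.transformRecords(records)
  transformed.collect().foreach(r => println(r.mkString))    // string values replaced by value ids (-1 for missing)
}
```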
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/Sampler.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <http://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.util.{BufferedRDDWriter, PeriodicRDDCheckpointer}
23 | import org.apache.spark.SparkContext
24 | import org.apache.spark.sql.SparkSession
25 |
26 | object Sampler extends Logging {
27 |
28 | /** Generates posterior samples by successively applying the Markov transition operator starting from a given
29 | * initial state. The samples are written to the path provided.
30 | *
31 | * @param initialState The initial state of the Markov chain.
32 | * @param sampleSize A positive integer specifying the desired number of samples (after burn-in and thinning)
33 | * @param outputPath A string specifying the path to save output (includes samples and diagnostics). HDFS and
34 | * local filesystems are supported.
35 | * @param burninInterval A non-negative integer specifying the number of initial samples to discard as burn-in.
36 | * The default is 0, which means no burn-in is applied.
37 | * @param thinningInterval A positive integer specifying the period for saving samples to disk. The default value is
38 | * 1, which means no thinning is applied.
39 | * @param checkpointInterval A non-negative integer specifying the period for checkpointing. This prevents the
40 | * lineage of the RDD (internal to state) from becoming too long. Smaller values require
41 | * more frequent writing to disk, larger values require more CPU/memory. The default
42 | value of 20 is a reasonable trade-off.
43 | * @param writeBufferSize A positive integer specifying the number of samples to queue in memory before writing to
44 | * disk.
45 | * @param collapsedEntityIds A Boolean specifying whether to collapse the distortions when updating the entity ids.
46 | * Defaults to false.
47 | * @param collapsedEntityValues A Boolean specifying whether to collapse the distortions when updating the entity
48 | * values. Defaults to true.
49 | * @return The final state of the Markov chain.
50 | */
51 | def sample(initialState: State,
52 | sampleSize: Int,
53 | outputPath: String,
54 | burninInterval: Int = 0,
55 | thinningInterval: Int = 1,
56 | checkpointInterval: Int = 20,
57 | writeBufferSize: Int = 10,
58 | collapsedEntityIds: Boolean = false,
59 | collapsedEntityValues: Boolean = true,
60 | sequential: Boolean = false): State = {
61 | require(sampleSize > 0, "`sampleSize` must be positive.")
62 | require(burninInterval >= 0, "`burninInterval` must be non-negative.")
63 | require(thinningInterval > 0, "`thinningInterval` must be positive.")
64 | require(checkpointInterval >= 0, "`checkpointInterval` must be non-negative.")
65 | require(writeBufferSize > 0, "`writeBufferSize` must be positive.")
66 | // TODO: ensure that savePath is a valid directory
67 |
68 | var sampleCtr = 0 // counter for number of samples produced (excludes burn-in/thinning)
69 | var state = initialState // current state
70 | val initialIteration = initialState.iteration // initial iteration (need not be zero)
71 | val continueChain = initialIteration != 0 // whether we're continuing a previous chain
72 |
73 | implicit val spark: SparkSession = SparkSession.builder().getOrCreate()
74 | implicit val sc: SparkContext = spark.sparkContext
75 | import spark.implicits._
76 |
77 | // Set-up writers
78 | val linkagePath = outputPath + "linkage-chain.parquet"
79 | var linkageWriter = BufferedRDDWriter[LinkageState](writeBufferSize, linkagePath, continueChain)
80 | val diagnosticsPath = outputPath + "diagnostics.csv"
81 | val diagnosticsWriter = new DiagnosticsWriter(diagnosticsPath, continueChain)
82 | val checkpointer = new PeriodicRDDCheckpointer[(PartitionId, EntRecCluster)](checkpointInterval, sc)
83 |
84 | if (!continueChain && burninInterval == 0) {
85 | // Need to record initial state
86 | checkpointer.update(state.partitions)
87 | linkageWriter = linkageWriter.append(state.getLinkageStructure())
88 | diagnosticsWriter.writeRow(state)
89 | }
90 |
91 | if (burninInterval > 0) info(s"Running burn-in for $burninInterval iterations.")
92 | while (sampleCtr < sampleSize) {
93 | state = state.nextState(checkpointer = checkpointer, collapsedEntityIds, collapsedEntityValues, sequential)
94 |
95 | //newState.partitions.persist(StorageLevel.MEMORY_ONLY_SER)
96 | //state = newState
97 | val completedIterations = state.iteration - initialIteration
98 |
99 | if (completedIterations - 1 == burninInterval) {
100 | if (burninInterval > 0) info("Burn-in complete.")
101 | info(s"Generating $sampleSize sample(s) with thinningInterval=$thinningInterval.")
102 | }
103 |
104 | if (completedIterations >= burninInterval) {
105 | // Finished burn-in, so start writing samples to disk (accounting for thinning)
106 | if ((completedIterations - burninInterval)%thinningInterval == 0) {
107 | linkageWriter = linkageWriter.append(state.getLinkageStructure())
108 | diagnosticsWriter.writeRow(state)
109 | sampleCtr += 1
110 | }
111 | }
112 |
113 | // Ensure writer is kept alive (may die if burninInterval/thinningInterval is large)
114 | diagnosticsWriter.progress()
115 | }
116 |
117 | info("Sampling complete. Writing final state and remaining samples to disk.")
118 | linkageWriter = linkageWriter.flush()
119 | diagnosticsWriter.close()
120 | state.save(outputPath)
121 | checkpointer.deleteAllCheckpoints()
122 | info(s"Finished writing to disk at $outputPath")
123 | state
124 | }
125 | }
126 |
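
A hedged sketch of a call site (the initial `State` is assumed to be constructed elsewhere, e.g. during project set-up, which is not shown in this file):

```scala
import com.github.cleanzr.dblink.{Sampler, State}

object SamplerExample {
  /** Runs the Markov chain from a previously constructed initial state. */
  def run(initialState: State): State =
    Sampler.sample(
      initialState,
      sampleSize = 1000,                          // samples kept after burn-in and thinning
      outputPath = "examples/RLdata500_results/", // trailing slash: file names are appended directly
      burninInterval = 100,                       // discard the first 100 iterations
      thinningInterval = 10,                      // keep every 10th iteration thereafter
      checkpointInterval = 20
    )
}
```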
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/util/PeriodicCheckpointer.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.github.cleanzr.dblink.util
19 |
20 | import scala.collection.mutable
21 |
22 | import org.apache.hadoop.conf.Configuration
23 | import org.apache.hadoop.fs.Path
24 |
25 | import org.apache.spark.SparkContext
26 | import org.apache.spark.internal.Logging
27 |
28 |
29 | /**
30 | * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
31 | * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
32 | * the distributed data type (RDD, Graph, etc.).
33 | *
34 | * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
35 | * as well as unpersisting and removing checkpoint files.
36 | *
37 | * Users should call update() when a new Dataset has been created,
38 | * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
39 | * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
40 | * occur.
41 | *
42 | * When update() is called, this does the following:
43 | * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
44 | * - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
45 | * - If using checkpointing and the checkpoint interval has been reached,
46 | * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
47 | * - Remove older checkpoints.
48 | *
49 | * WARNINGS:
50 | * - This class should NOT be copied (since copies may conflict on which Datasets should be
51 | * checkpointed).
52 | * - This class removes checkpoint files once later Datasets have been checkpointed.
53 | * However, references to the older Datasets will still return isCheckpointed = true.
54 | *
55 | * @param checkpointInterval Datasets will be checkpointed at this interval.
56 | * If this interval was set as -1, then checkpointing will be disabled.
57 | * @param sc SparkContext for the Datasets given to this checkpointer
58 | * @tparam T Dataset type, such as RDD[Double]
59 | */
60 | abstract class PeriodicCheckpointer[T](val checkpointInterval: Int,
61 | val sc: SparkContext) extends Logging {
62 |
63 | /** FIFO queue of past checkpointed Datasets */
64 | private val checkpointQueue = mutable.Queue[T]()
65 |
66 | /** FIFO queue of past persisted Datasets */
67 | private val persistedQueue = mutable.Queue[T]()
68 |
69 | /** Number of times [[update()]] has been called */
70 | private var updateCount = 0
71 |
72 | /**
73 | * Update with a new Dataset. Handle persistence and checkpointing as needed.
74 | * Since this handles persistence and checkpointing, this should be called before the Dataset
75 | * has been materialized.
76 | *
77 | * @param newData New Dataset created from previous Datasets in the lineage.
78 | */
79 | def update(newData: T): Unit = {
80 | persist(newData)
81 | persistedQueue.enqueue(newData)
82 |     // We try to maintain at most 3 Datasets in persistedQueue to support the semantics of this class:
83 | // Users should call [[update()]] when a new Dataset has been created,
84 | // before the Dataset has been materialized.
85 | while (persistedQueue.size > 3) {
86 | val dataToUnpersist = persistedQueue.dequeue()
87 | unpersist(dataToUnpersist)
88 | }
89 | updateCount += 1
90 |
91 | // Handle checkpointing (after persisting)
92 | if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0
93 | && sc.getCheckpointDir.nonEmpty) {
94 | // Add new checkpoint before removing old checkpoints.
95 | checkpoint(newData)
96 | checkpointQueue.enqueue(newData)
97 | // Remove checkpoints before the latest one.
98 | var canDelete = true
99 | while (checkpointQueue.size > 1 && canDelete) {
100 | // Delete the oldest checkpoint only if the next checkpoint exists.
101 | if (isCheckpointed(checkpointQueue.head)) {
102 | removeCheckpointFile()
103 | } else {
104 | canDelete = false
105 | }
106 | }
107 | }
108 | }
109 |
110 | /** Checkpoint the Dataset */
111 | protected def checkpoint(data: T): Unit
112 |
113 | /** Return true iff the Dataset is checkpointed */
114 | protected def isCheckpointed(data: T): Boolean
115 |
116 | /**
117 | * Persist the Dataset.
118 | * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
119 | */
120 | protected def persist(data: T): Unit
121 |
122 | /** Unpersist the Dataset */
123 | protected def unpersist(data: T): Unit
124 |
125 |   /** Get list of checkpoint files for the given Dataset */
126 | protected def getCheckpointFiles(data: T): Iterable[String]
127 |
128 | /**
129 | * Call this to unpersist the Dataset.
130 | */
131 | def unpersistDataSet(): Unit = {
132 | while (persistedQueue.nonEmpty) {
133 | val dataToUnpersist = persistedQueue.dequeue()
134 | unpersist(dataToUnpersist)
135 | }
136 | }
137 |
138 | /**
139 | * Call this at the end to delete any remaining checkpoint files.
140 | */
141 | def deleteAllCheckpoints(): Unit = {
142 | while (checkpointQueue.nonEmpty) {
143 | removeCheckpointFile()
144 | }
145 | }
146 |
147 | /**
148 | * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
149 | * Note that there may not be any checkpoints at all.
150 | */
151 | def deleteAllCheckpointsButLast(): Unit = {
152 | while (checkpointQueue.size > 1) {
153 | removeCheckpointFile()
154 | }
155 | }
156 |
157 | /**
158 | * Get all current checkpoint files.
159 | * This is useful in combination with [[deleteAllCheckpointsButLast()]].
160 | */
161 | def getAllCheckpointFiles: Array[String] = {
162 | checkpointQueue.flatMap(getCheckpointFiles).toArray
163 | }
164 |
165 | /**
166 | * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
167 | * This prints a warning but does not fail if the files cannot be removed.
168 | */
169 | private def removeCheckpointFile(): Unit = {
170 | val old = checkpointQueue.dequeue()
171 | // Since the old checkpoint is not deleted by Spark, we manually delete it.
172 | getCheckpointFiles(old).foreach(
173 | PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration))
174 | }
175 | }
176 |
177 | object PeriodicCheckpointer extends Logging {
178 |
179 | /** Delete a checkpoint file, and log a warning if deletion fails. */
180 | def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = {
181 | try {
182 | val path = new Path(checkpointFile)
183 | val fs = path.getFileSystem(conf)
184 | fs.delete(path, true)
185 | } catch {
186 | case e: Exception =>
187 | logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
188 | checkpointFile)
189 | }
190 | }
191 | }
--------------------------------------------------------------------------------
/docs/guide.md:
--------------------------------------------------------------------------------
1 | # Step-by-step guide
2 | This guide will take you through the steps involved in running dblink on a
3 | small test data set. To make the guide accessible, we assume that you're
4 | running dblink on your local machine. Of course, in practical applications
5 | you'll likely want to run dblink on a cluster. We'll provide some pointers
6 | for this option as we go along.
7 |
8 | ## 0. Install Java
9 | The following two steps require that Java 8+ is installed on your system.
10 | To check whether it is installed on a macOS or Linux system, run the command
11 | ```bash
12 | $ java -version
13 | ```
14 | You should see a version number of 8.x (equivalently 1.8.x) or higher.
15 | Installation instructions for Oracle JDK on Windows, macOS and Linux are
16 | available [here](https://java.com/en/download/help/download_options.xml).
17 |
18 | _Note: As of April 2019, the licensing terms of the Oracle JDK have changed.
19 | We recommend using an open source alternative such as the OpenJDK. Packages
20 | are available in many Linux distributions. Instructions for macOS are
21 | available [here](macos-java8.md)._
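
For example, on Debian- or Ubuntu-based distributions OpenJDK 8 can typically be
installed from the standard repositories (package names vary across distributions,
so treat this as a sketch rather than an exact recipe):
```bash
$ sudo apt-get update
$ sudo apt-get install openjdk-8-jdk
```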
22 |
23 | ## 1. Get access to a Spark cluster
24 | Since dblink is implemented as a Spark application, you'll need access to a
25 | Spark cluster in order to run it.
26 | Setting up a Spark cluster from scratch can be quite involved and is beyond
27 | the scope of this guide.
28 | We refer interested readers to the Spark
29 | [documentation](https://spark.apache.org/docs/latest/#launching-on-a-cluster),
30 | which discusses various deployment options.
31 | An easier route for most users is to use a preconfigured Spark cluster
32 | available through public cloud providers, such as
33 | [Amazon EMR](https://aws.amazon.com/emr/),
34 | [Azure HDInsight](https://azure.microsoft.com/en-us/services/hdinsight/),
35 | and [Google Cloud Dataproc](https://cloud.google.com/dataproc/).
36 | In this guide, we take an even simpler approach: we'll run Spark in
37 | _pseudocluster mode_ on your local machine.
38 | This is fine for testing purposes or for small data sets.
39 |
40 | We'll now take you through detailed instructions for setting up Spark in
41 | pseudocluster mode on a macOS or Linux system.
42 |
43 | First, download the prebuilt 2.4.5 release from the Spark
44 | [release archive](https://archive.apache.org/dist/spark/).
45 | ```bash
46 | $ wget https://archive.apache.org/dist/spark/spark-2.4.5/spark-2.4.5-bin-hadoop2.7.tgz
47 | ```
48 | then extract the archive.
49 | ```bash
50 | $ tar -xvf spark-2.4.5-bin-hadoop2.7.tgz
51 | ```
52 |
53 | Move the Spark folder to `/opt` and create a symbolic link so that you can
54 | easily switch to another version in the future.
55 | ```bash
56 | $ sudo mv spark-2.4.5-bin-hadoop2.7 /opt
57 | $ sudo ln -s /opt/spark-2.4.5-bin-hadoop2.7/ /opt/spark
58 | ```
59 |
60 | Define the `SPARK_HOME` variable and add the Spark binaries to your `PATH`.
61 | The way that this is done depends on your operating system and/or shell.
62 | Assuming environment variables are defined in `~/.profile`, you can
63 | run the following commands:
64 | ```bash
65 | $ echo 'export SPARK_HOME=/opt/spark' >> ~/.profile
66 | $ echo 'export PATH=$PATH:$SPARK_HOME/bin' >> ~/.profile
67 | ```
68 |
69 | After appending these two lines, run the following command to update your
70 | path for the current session.
71 | ```bash
72 | $ source ~/.profile
73 | ```
74 |
75 | Notes:
76 | * If using Bash on Debian, Fedora or RHEL derivatives, environment
77 | variables are typically defined in `~/.bash_profile` rather than
78 | `~/.profile`
79 | * If using ZSH, environment variables are typically defined in
80 | `~/.zprofile`
81 | * You can check which shell you're using by running `echo $SHELL`
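
Once your profile has been reloaded, you can check that the Spark binaries are
reachable on your `PATH` by asking `spark-submit` for its version (the exact
output depends on the release you installed):
```bash
$ spark-submit --version
```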
82 |
83 | ## 2. Obtain the dblink JAR file
84 | In this step you'll obtain the dblink fat JAR, which has file name
85 | `dblink-assembly-0.2.0.jar`.
86 | It contains all of the class files and resources for dblink, packed together
87 | with any dependencies.
88 |
89 | There are two options:
90 | * (Recommended) Download a prebuilt JAR from [here](https://github.com/ngmarchant/dblink/releases).
91 | This has been built against Spark 2.4.5 and is not guaranteed to work with
92 | other versions of Spark.
93 | * Build the fat JAR file from source as explained in the section below.
94 |
95 | ### 2.1. Building the fat JAR
96 | The build tool used for dblink is called sbt. You'll need to install
97 | sbt on your system. Instructions for Windows, macOS and Linux are available
98 | in the sbt [documentation](https://www.scala-sbt.org/1.x/docs/Setup.html).
99 | Alternative installation instructions for those using bash on macOS are
100 | given in [macos-java8.md](macos-java8.md).
101 |
102 | On macOS or Linux, you can verify that sbt is installed correctly by running:
103 | ```bash
104 | $ sbt about
105 | ```
106 |
107 | Once you've successfully installed sbt, get the dblink source code from
108 | GitHub:
109 | ```bash
110 | $ git clone https://github.com/cleanzr/dblink.git
111 | ```
112 | then change into the dblink directory and build the package
113 | ```bash
114 | $ cd dblink
115 | $ sbt assembly
116 | ```
117 | This should produce a fat JAR at `./target/scala-2.11/dblink-assembly-0.2.0.jar`.
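
You can confirm that the assembly was produced (and check its size) by listing
the file:
```bash
$ ls -lh ./target/scala-2.11/dblink-assembly-0.2.0.jar
```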
118 |
119 | _Note: [IntelliJ IDEA](https://www.jetbrains.com/idea/) can also be used to
120 | build the fat JAR. It is arguably more user-friendly as it has a GUI and
121 | users can avoid installing sbt._
122 |
123 | ## 3. Run dblink
124 | Having completed the above two steps, you're now ready to launch dblink.
125 | This is done using the [`spark-submit`](https://spark.apache.org/docs/latest/submitting-applications.html)
126 | interface, which supports all types of Spark deployments.
127 |
128 | As a test, let's try running the RLdata500 example provided with the source
129 | code on your local machine.
130 | From within the `dblink` directory, run the following command:
131 | ```bash
132 | $SPARK_HOME/bin/spark-submit \
133 | --master "local[1]" \
134 | --conf "spark.driver.extraJavaOptions=-Dlog4j.configuration=log4j.properties" \
135 | --conf "spark.driver.extraClassPath=./target/scala-2.11/dblink-assembly-0.2.0.jar" \
136 | ./target/scala-2.11/dblink-assembly-0.2.0.jar \
137 | ./examples/RLdata500.conf
138 | ```
139 | This will run Spark in pseudocluster (local) mode with 1 core. You can increase
140 | the number of cores available by changing `local[1]` to `local[n]`, where `n`
141 | is the number of cores, or to `local[*]` to use all available cores.
142 | To run dblink on other data sets you will need to edit the config file (called
143 | `RLdata500.conf` above).
144 | Instructions for doing this are provided [here](configuration.md).
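
For larger data sets you will likely also need to give the Spark driver more
memory than the default. As an illustrative sketch (the values below are
placeholders to adjust for your machine, not recommendations), extra resources
can be requested through the standard `spark-submit` options:
```bash
$SPARK_HOME/bin/spark-submit \
  --master "local[*]" \
  --driver-memory 4g \
  --conf "spark.driver.extraJavaOptions=-Dlog4j.configuration=log4j.properties" \
  --conf "spark.driver.extraClassPath=./target/scala-2.11/dblink-assembly-0.2.0.jar" \
  ./target/scala-2.11/dblink-assembly-0.2.0.jar \
  ./examples/RLdata500.conf
```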
145 |
146 | ## 4. Output of dblink
147 | dblink saves output into a specified directory. In the RLdata500 example from
148 | above, the output is written to `./examples/RLdata500_results/`.
149 |
150 | Below we provide a brief description of the files:
151 |
152 | * `run.txt`: contains details about the job (MCMC run). This includes the
153 | data files, the attributes used, parameter settings etc.
154 | * `partitions-state.parquet` and `driver-state`: stores the final state of
155 | the Markov chain, so that MCMC can be resumed (e.g. you can run the Markov
156 | chain for longer without starting from scratch).
157 | * `diagnostics.csv` contains summary statistics along the chain which can be
158 | used to assess convergence/mixing.
159 | * `linkage-chain.parquet` contains posterior samples of the linkage structure
160 | in Parquet format.
161 |
162 | Optional files:
163 |
164 | * `evaluation-results.txt`: contains output from an "evaluate" step (e.g.
165 | precision, recall, other measures). Requires ground truth entity identifiers
166 | in the data files.
167 | * `cluster-size-distribution.csv` contains the cluster size distribution
168 | along the chain (rows are iterations, columns contain counts for each
169 | cluster/entity size). Only appears if requested in a "summarize" step.
170 | * `partition-sizes.csv` contains the partition sizes along the chain (rows
171 | are iterations, columns are counts of the number of entities residing in each
172 | partition). Only appears if requested in a "summarize" step.
173 | * `shared-most-probable-clusters.csv` is a point estimate of the linkage
174 | structure computed from the posterior samples. Each line in the file contains
175 | a comma-separated list of record identifiers which are assigned to the same
176 | cluster/entity.
177 |
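Once a run has finished, you can quickly sanity-check the output from the
command line. For instance, assuming the RLdata500 example above (with results
in `./examples/RLdata500_results/`), the following shows the first few
diagnostic rows and the schema of the posterior linkage samples:
```bash
$ head -n 3 ./examples/RLdata500_results/diagnostics.csv
$ echo 'spark.read.parquet("./examples/RLdata500_results/linkage-chain.parquet").printSchema()' | \
    $SPARK_HOME/bin/spark-shell
```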
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/ProjectStep.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.analysis.{ClusteringMetrics, PairwiseMetrics}
23 | import com.github.cleanzr.dblink.util.BufferedFileWriter
24 | import com.github.cleanzr.dblink.LinkageChain._
25 | import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
26 | import org.apache.spark.storage.StorageLevel
27 |
28 | trait ProjectStep {
29 | def execute(): Unit
30 |
31 | def mkString: String
32 | }
33 |
34 | object ProjectStep {
35 | private val supportedSamplers = Set("PCG-I", "PCG-II", "Gibbs", "Gibbs-Sequential")
36 | private val supportedEvaluationMetrics = Set("pairwise", "cluster")
37 | private val supportedSummaryQuantities = Set("cluster-size-distribution", "partition-sizes", "shared-most-probable-clusters")
38 |
39 | class SampleStep(project: Project, sampleSize: Int, burninInterval: Int,
40 | thinningInterval: Int, resume: Boolean, sampler: String) extends ProjectStep with Logging {
41 | require(sampleSize > 0, "sampleSize must be positive")
42 | require(burninInterval >= 0, "burninInterval must be non-negative")
43 | require(thinningInterval >= 0, "thinningInterval must be non-negative")
44 | require(supportedSamplers.contains(sampler), s"sampler must be one of ${supportedSamplers.mkString("", ", ", "")}.")
45 |
46 | override def execute(): Unit = {
47 | info(mkString)
48 | val initialState = if (resume) {
49 | project.savedState.getOrElse(project.generateInitialState)
50 | } else {
51 | project.generateInitialState
52 | }
53 | sampler match {
54 | case "PCG-I" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = false, collapsedEntityValues = true, sequential = false)
55 | case "PCG-II" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = true, collapsedEntityValues = true, sequential = false)
56 | case "Gibbs" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = false, collapsedEntityValues = false, sequential = false)
57 | case "Gibbs-Sequential" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = false, collapsedEntityValues = false, sequential = true)
58 | }
59 | }
60 |
61 | override def mkString: String = {
62 | if (resume) s"SampleStep: Evolving the chain from saved state with sampleSize=$sampleSize, burninInterval=$burninInterval, thinningInterval=$thinningInterval and sampler=$sampler"
63 | else s"SampleStep: Evolving the chain from new initial state with sampleSize=$sampleSize, burninInterval=$burninInterval, thinningInterval=$thinningInterval and sampler=$sampler"
64 | }
65 | }
66 |
67 | class EvaluateStep(project: Project, lowerIterationCutoff: Int, metrics: Traversable[String],
68 | useExistingSMPC: Boolean) extends ProjectStep with Logging {
69 | require(project.entIdAttribute.isDefined, "Ground truth entity ids are required for evaluation")
70 | require(lowerIterationCutoff >=0, "lowerIterationCutoff must be non-negative")
71 | require(metrics.nonEmpty, "metrics must be non-empty")
72 | require(metrics.forall(m => supportedEvaluationMetrics.contains(m)), s"metrics must be one of ${supportedEvaluationMetrics.mkString("{", ", ", "}")}.")
73 |
74 | override def execute(): Unit = {
75 | info(mkString)
76 |
77 | // Get ground truth clustering
78 | val trueClusters = project.trueClusters match {
79 | case Some(clusters) => clusters.persist()
80 | case None =>
81 | error("Ground truth clusters are unavailable")
82 | return
83 | }
84 |
85 | import analysis._
86 |
87 | // Get predicted clustering (using sMPC method)
88 | val sMPC = if (useExistingSMPC && project.sharedMostProbableClustersOnDisk) {
89 | // Read saved sMPC from disk
90 | project.savedSharedMostProbableClusters
91 | } else {
92 | // Try to compute sMPC using saved linkage chain (and save to disk)
93 | project.savedLinkageChain(lowerIterationCutoff) match {
94 | case Some(chain) =>
95 | val sMPC = sharedMostProbableClusters(chain).persist()
96 | sMPC.saveCsv(project.outputPath + "shared-most-probable-clusters.csv")
97 | chain.unpersist()
98 | Some(sMPC)
99 | case None =>
100 | error("No linkage chain")
101 | None
102 | }
103 | }
104 |
105 | sMPC match {
106 | case Some(predictedClusters) =>
107 | val results = metrics.map {
108 | case metric if metric == "pairwise" =>
109 | PairwiseMetrics(predictedClusters.toPairwiseLinks, trueClusters.toPairwiseLinks).mkString
110 | case metric if metric == "cluster" =>
111 | ClusteringMetrics(predictedClusters, trueClusters).mkString
112 | }
113 | val writer = BufferedFileWriter(project.outputPath + "evaluation-results.txt", append = false, project.sparkContext)
114 | writer.write(results.mkString("", "\n", "\n"))
115 | writer.close()
116 | case None => error("Predicted clusters are unavailable")
117 | }
118 | }
119 |
120 | override def mkString: String = {
121 | if (useExistingSMPC) s"EvaluateStep: Evaluating saved sMPC clusters using ${metrics.map("'" + _ + "'").mkString("{", ", ", "}")} metrics"
122 | else s"EvaluateStep: Evaluating sMPC clusters (computed from the chain for iterations >= $lowerIterationCutoff) using ${metrics.map("'" + _ + "'").mkString("{", ", ", "}")} metrics"
123 | }
124 | }
125 |
126 | class SummarizeStep(project: Project, lowerIterationCutoff: Int,
127 | quantities: Traversable[String]) extends ProjectStep with Logging {
128 | require(lowerIterationCutoff >= 0, "lowerIterationCutoff must be non-negative")
129 | require(quantities.nonEmpty, "quantities must be non-empty")
130 | require(quantities.forall(q => supportedSummaryQuantities.contains(q)), s"quantities must be one of ${supportedSummaryQuantities.mkString("{", ", ", "}")}.")
131 |
132 | override def execute(): Unit = {
133 | info(mkString)
134 | project.savedLinkageChain(lowerIterationCutoff) match {
135 | case Some(chain) =>
136 | quantities.foreach {
137 | case "cluster-size-distribution" =>
138 | val clustSizeDist = clusterSizeDistribution(chain)
139 | saveClusterSizeDistribution(clustSizeDist, project.outputPath)
140 | case "partition-sizes" =>
141 | val partSizes = partitionSizes(chain)
142 | savePartitionSizes(partSizes, project.outputPath)
143 | case "shared-most-probable-clusters" =>
144 | import analysis._
145 | val smpc = sharedMostProbableClusters(chain)
146 | smpc.saveCsv(project.outputPath + "shared-most-probable-clusters.csv")
147 | }
148 | case None => error("No linkage chain")
149 | }
150 | }
151 |
152 | override def mkString: String = {
153 | s"SummarizeStep: Calculating summary quantities ${quantities.map("'" + _ + "'").mkString("{", ", ", "}")} along the chain for iterations >= $lowerIterationCutoff"
154 | }
155 | }
156 |
157 | class CopyFilesStep(project: Project, fileNames: Traversable[String], destinationPath: String,
158 | overwrite: Boolean, deleteSource: Boolean) extends ProjectStep with Logging {
159 |
160 | override def execute(): Unit = {
161 | info(mkString)
162 |
163 | val conf = project.sparkContext.hadoopConfiguration
164 | val srcParent = new Path(project.outputPath)
165 | val srcFs = FileSystem.get(srcParent.toUri, conf)
166 | val dstParent = new Path(destinationPath)
167 | val dstFs = FileSystem.get(dstParent.toUri, conf)
168 | fileNames.map(fName => new Path(srcParent.toString + Path.SEPARATOR + fName))
169 | .filter(srcFs.exists)
170 | .foreach { src =>
171 | val dst = new Path(dstParent.toString + Path.SEPARATOR + src.getName)
172 | FileUtil.copy(srcFs, src, dstFs, dst, deleteSource, overwrite, conf)
173 | }
174 | }
175 |
176 | override def mkString: String = {
177 | s"CopyFilesStep: Copying ${fileNames.mkString("{",", ","}")} to destination $destinationPath"
178 | }
179 | }
180 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/LinkageChain.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.util.BufferedFileWriter
23 | import org.apache.spark.sql.{Dataset, SparkSession}
24 |
25 | import scala.collection.mutable
26 |
27 | object LinkageChain extends Logging {
28 |
29 | /** Read samples of the linkage structure
30 | *
31 | * @param path path to the output directory.
32 | * @return a linkage chain: an RDD containing samples of the linkage
33 | * structure (by partition) along the Markov chain.
34 | */
35 | def readLinkageChain(path: String): Dataset[LinkageState] = {
36 | // TODO: check path
37 | val spark = SparkSession.builder().getOrCreate()
38 | import spark.implicits._
39 |
40 | spark.read.format("parquet")
41 | .load(path + "linkage-chain.parquet")
42 | .as[LinkageState]
43 | }
44 |
45 |
46 | /**
47 | * Computes the most probable clustering for each record
48 | *
49 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions.
50 | * @return A Dataset containing the most probable cluster for each record.
51 | */
52 | def mostProbableClusters(linkageChain: Dataset[LinkageState]): Dataset[MostProbableCluster] = {
53 | val spark = linkageChain.sparkSession
54 | import spark.implicits._
55 |
56 | val numSamples = linkageChain.map(_.iteration).distinct().count()
57 | linkageChain.rdd
58 | .flatMap(_.linkageStructure.iterator.collect {case cluster if cluster.nonEmpty => (cluster.toSet, 1.0/numSamples)})
59 | .reduceByKey(_ + _)
60 | .flatMap {case (recIds, freq) => recIds.iterator.map(recId => (recId, (recIds, freq)))}
61 | .reduceByKey((x, y) => if (x._2 >= y._2) x else y)
62 | .map(x => MostProbableCluster(x._1, x._2._1, x._2._2))
63 | .toDS()
64 | }
65 |
66 |
67 | /**
68 | * Computes a point estimate of the most likely clustering that obeys transitivity constraints. The method was
69 | * introduced by Steorts et al. (2016), where it is referred to as the method of shared most probable maximal
70 | * matching sets.
71 | *
72 | * @param mostProbableClusters A Dataset containing the most probable cluster for each record.
73 | * @return A Dataset of record clusters
74 | */
75 | def sharedMostProbableClusters(mostProbableClusters: Dataset[MostProbableCluster]): Dataset[Cluster] = {
76 | val spark = mostProbableClusters.sparkSession
77 | import spark.implicits._
78 |
79 | mostProbableClusters.rdd
80 | .map(x => (x.cluster, x.recordId))
81 | .aggregateByKey(zeroValue = Set.empty[RecordId])(
82 | seqOp = (recordIds, recordId) => recordIds + recordId,
83 | combOp = (recordIdsA, recordIdsB) => recordIdsA union recordIdsB
84 | )
85 | // key = most probable cluster | value = aggregated recIds
86 | .map(_._2)
87 | .toDS()
88 | // .flatMap[Set[RecordIdType]] {
89 | // case (cluster, recIds) if cluster == recIds => Iterator(recIds)
90 | // // cluster is shared most probable for all records it contains
91 | // case (cluster, recIds) if cluster != recIds => recIds.iterator.map(Set(_))
92 | // // cluster isn't shared most probable -- output each record as a
93 | // // separate cluster
94 | // }
95 | }
96 |
97 |
98 | /**
99 | * Computes a point estimate of the most likely clustering that obeys transitivity constraints. The method was
100 | * introduced by Steorts et al. (2016), where it is referred to as the method of shared most probable maximal
101 | * matching sets.
102 | *
103 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions.
104 | * @return A Dataset of record clusters
105 | */
106 | def sharedMostProbableClusters(linkageChain: Dataset[LinkageState])(implicit i1: DummyImplicit): Dataset[Cluster] = {
107 | val mpc = mostProbableClusters(linkageChain)
108 | sharedMostProbableClusters(mpc)
109 | }
110 |
111 |
112 | /** Computes the partition sizes along the linkage chain
113 | *
114 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions.
115 | * @return A Dataset containing the partition sizes at each iteration. The key is the iteration and the value is a
116 | * map from partition ids to their corresponding sizes.
117 | */
118 | def partitionSizes(linkageChain: Dataset[LinkageState]): Dataset[(Long, Map[PartitionId, Int])] = {
119 | val spark = linkageChain.sparkSession
120 | import spark.implicits._
121 | linkageChain.rdd
122 | .map(x => (x.iteration, (x.partitionId, x.linkageStructure.size)))
123 | .aggregateByKey(Map.empty[PartitionId, Int])(
124 | seqOp = (m, v) => m + v,
125 | combOp = (m1, m2) => m1 ++ m2
126 | )
127 | .toDS()
128 | }
129 |
130 |
131 | /** Computes the cluster size frequency distribution along the linkage chain
132 | *
133 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions.
134 | * @return A Dataset containing the cluster size frequency distribution at each iteration. The key is the iteration
135 | * and the value is a map from cluster sizes to their corresponding frequencies.
136 | */
137 | def clusterSizeDistribution(linkageChain: Dataset[LinkageState]): Dataset[(Long, mutable.Map[Int, Long])] = {
138 | // Compute distribution separately for each partition, then combine the results
139 | val spark = linkageChain.sparkSession
140 | import spark.implicits._
141 |
142 | linkageChain.rdd
143 | .map(x => {
144 | val clustSizes = mutable.Map[Int, Long]().withDefaultValue(0L)
145 | x.linkageStructure.foreach { cluster => clustSizes(cluster.size) += 1L }
146 | (x.iteration, clustSizes)
147 | })
148 | .reduceByKey((a, b) => {
149 | val combined = mutable.Map[Int, Long]().withDefaultValue(0L)
150 | (a.keySet ++ b.keySet).foreach(k => combined(k) = a(k) + b(k))
151 | combined // combine maps
152 | })
153 | .toDS()
154 | }
155 |
156 |
157 |   /** Saves the cluster size frequency distribution along the linkage chain to
158 |    * `cluster-size-distribution.csv` in the output directory.
159 |    * @param clusterSizeDistribution cluster size frequency distribution at each iteration
160 |    * @param path path to the output directory
161 |    */
162 | def saveClusterSizeDistribution(clusterSizeDistribution: Dataset[(Long, mutable.Map[Int, Long])], path: String): Unit = {
163 | val sc = clusterSizeDistribution.sparkSession.sparkContext
164 |
165 | val distAlongChain = clusterSizeDistribution.collect().sortBy(_._1) // collect on driver and sort by iteration
166 |
167 | // Get the size of the largest cluster in the samples
168 | val maxClustSize = distAlongChain.aggregate(0)(
169 | seqop = (currentMax, x) => math.max(currentMax, x._2.keySet.max),
170 | combop = (a, b) => math.max(a, b)
171 | )
172 | // Output file can be created from Hadoop file system.
173 | val fullPath = path + "cluster-size-distribution.csv"
174 | info(s"Writing cluster size frequency distribution along the chain to $fullPath")
175 | val writer = BufferedFileWriter(fullPath, append = false, sc)
176 |
177 | // Write CSV header
178 | writer.write("iteration" + "," + (0 to maxClustSize).mkString(",") + "\n")
179 | // Write rows (one for each iteration)
180 | distAlongChain.foreach { case (iteration, kToCounts) =>
181 | val countsArray = (0 to maxClustSize).map(k => kToCounts.getOrElse(k, 0L))
182 | writer.write(iteration.toString + "," + countsArray.mkString(",") + "\n")
183 | }
184 | writer.close()
185 | }
186 |
187 |
188 |   /** Saves the partition sizes along the chain to `partition-sizes.csv`
189 |    * in the output directory.
190 |    * @param partitionSizes partition sizes at each iteration
191 |    * @param path path to the output directory
192 |    */
193 | def savePartitionSizes(partitionSizes: Dataset[(Long, Map[PartitionId, Int])], path: String): Unit = {
194 | val sc = partitionSizes.sparkSession.sparkContext
195 | val partSizesAlongChain = partitionSizes.collect().sortBy(_._1)
196 |
197 | val partIds = partSizesAlongChain.map(_._2.keySet).reduce(_ ++ _).toArray.sorted
198 |
199 | // Output file can be created from Hadoop file system.
200 | val fullPath = path + "partition-sizes.csv"
201 | val writer = BufferedFileWriter(fullPath, append = false, sc)
202 |
203 | // Write CSV header
204 | writer.write("iteration" + "," + partIds.mkString(",") + "\n")
205 | // Write rows (one for each iteration)
206 | partSizesAlongChain.foreach { case (iteration, m) =>
207 | val sizes = partIds.map(partId => m.getOrElse(partId, 0))
208 | writer.write(iteration.toString + "," + sizes.mkString(",") + "\n")
209 | }
210 | writer.close()
211 | }
212 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/AttributeIndex.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import com.github.cleanzr.dblink.random.DiscreteDist
24 | import org.apache.commons.math3.random.RandomGenerator
25 | import org.apache.spark.SparkContext
26 |
27 | import scala.collection.mutable
28 | import scala.math.{exp, pow}
29 |
30 | /** An index for an attribute domain.
31 | * It includes:
32 | * - an index from raw strings in the domain to integer value ids
33 | * - the empirical distribution over the domain
34 | * - an index that accepts a query value id and returns a set of
35 | * "similar" value ids (whose truncated similarity is non-zero)
36 | * - an index over pairs of value ids that returns exponentiated
37 | * truncated similarity scores
38 | */
39 | trait AttributeIndex extends Serializable {
40 |
41 | /** Number of distinct attribute values */
42 | def numValues: Int
43 |
44 | /** Object for empirical distribution */
45 | def distribution: DiscreteDist[Int]
46 |
47 | /** Get the probability mass associated with a value id
48 | *
49 | * @param valueId value id.
50 | * @return probability mass. Returns 0.0 if the value id does not exist.
51 | */
52 | def probabilityOf(valueId: ValueId): Double
53 |
54 |
55 | /** Draw a value id according to the empirical distribution
56 | *
57 | * @return a value id.
58 | */
59 | def draw()(implicit rand: RandomGenerator): ValueId
60 |
61 |
62 | /** Get the value id for a given string value
63 | *
64 | * @param value original string value
65 | * @return integer value id. Returns `-1` if value does not exist in the
66 | * index.
67 | */
68 | def valueIdxOf(value: String): ValueId
69 |
70 |
71 | /** Get the similarity normalization corresponding to the value id.
72 | *
73 | * @param valueId integer value id
74 | * @return sum_{w}(probabilities(w) * exp(sim(w,value))
75 | */
76 | def simNormalizationOf(valueId: ValueId): Double
77 |
78 |
79 | /** Get the value ids that are "similar" (above the similarity threshold)
80 | * to the given value id, along with the exponentiated similarity scores.
81 | *
82 |     * @param valueId integer value id
83 |     * @return a map from "similar" value ids to their exponentiated similarity scores
84 | */
85 | def simValuesOf(valueId: ValueId): scala.collection.Map[ValueId, Double]
86 |
87 |
88 | /** Get the exponentiated similarity score for a pair of value ids
89 | *
90 | * @param valueId1 integer value id 1
91 | * @param valueId2 integer value id 2
92 | * @return exp(sim(valueId1, valueId2))
93 | */
94 | def expSimOf(valueId1: ValueId, valueId2: ValueId): Double
95 |
96 |
97 | /** Get distribution of the form:
98 | * p(v) \propto probabilityOf(v) * pow(simNormalizationOf(v), power)
99 | *
100 | * @param power the similarity normalization is raised to this power
101 | * @return a distribution
102 | */
103 | def getSimNormDist(power: Int): DiscreteDist[Int]
104 | }
105 |
106 | object AttributeIndex {
107 | def apply(valuesWeights: Map[String, Double],
108 | similarityFn: SimilarityFn,
109 | precachePowers: Option[Traversable[Int]] = None)
110 | (implicit sc: SparkContext): AttributeIndex = {
111 | require(valuesWeights.nonEmpty, "index cannot be empty")
112 |
113 | val valuesWeights_sorted = valuesWeights.toArray.sortBy(_._1)
114 | val totalWeight = valuesWeights_sorted.foldLeft(0.0){ case (sum, (_, weight)) => sum + weight }
115 | val probs = valuesWeights_sorted.map(_._2/totalWeight)
116 | val stringToId = valuesWeights_sorted.iterator.zipWithIndex.map(x => (x._1._1, x._2)).toMap
117 |
118 | similarityFn match {
119 | case SimilarityFn.ConstantSimilarityFn => new ConstantAttributeIndex(stringToId, probs)
120 | case simFn =>
121 | val simValueIndex = computeSimValueIndex(stringToId, simFn)
122 | val simNormalizations = computeSimNormalizations(simValueIndex, probs)
123 | new GenericAttributeIndex(stringToId, probs, simValueIndex, simNormalizations, precachePowers)
124 | }
125 | }
126 |
127 | /** Implementation for attributes with constant similarity functions */
128 | private class ConstantAttributeIndex(protected val stringToId: Map[String, ValueId],
129 | protected val probs: Array[Double]) extends AttributeIndex {
130 |
131 | /** Empirical distribution over the field values */
132 | val distribution: DiscreteDist[Int] = DiscreteDist(probs)
133 |
134 | val numValues: Int = stringToId.size
135 |
136 | override def probabilityOf(valueId: ValueId): Double = {
137 | require(valueId >= 0 && valueId < numValues, "valueId is not in the index")
138 | distribution.probabilityOf(valueId)
139 | }
140 |
141 | override def draw()(implicit rand: RandomGenerator): ValueId = distribution.sample()
142 |
143 | override def valueIdxOf(value: String): ValueId = stringToId(value)
144 |
145 | /** Assuming exp(sim(v1, v2)) = 1.0 for all v1, v2 */
146 | override def simNormalizationOf(valueId: ValueId): Double = {
147 | require(valueId >= 0 && valueId < numValues, "valueId is not in the index")
148 | 1.0
149 | }
150 |
151 | /** Assuming exp(sim(v1, v2)) = 1.0 for all v1, v2 */
152 | override def simValuesOf(valueId: ValueId): scala.collection.Map[ValueId, Double] = {
153 | require(valueId >= 0 && valueId < numValues, "valueId is not in the index")
154 | Map.empty[ValueId, Double]
155 | }
156 |
157 | /** Assuming exp(sim(v1, v2)) = 1.0 for all v1, v2 */
158 | override def expSimOf(valueId1: ValueId, valueId2: ValueId): Double = {
159 | require(valueId1 >= 0 && valueId1 < numValues, "valueId1 is not in the index")
160 | require(valueId2 >= 0 && valueId2 < numValues, "valueId2 is not in the index")
161 | 1.0
162 | }
163 |
164 | /** For a constant similarity function, this is the same as the empirical distribution */
165 | override def getSimNormDist(power: Int): DiscreteDist[Int] = {
166 | require(power > 0, "power must be a positive integer")
167 | distribution
168 | }
169 | }
170 |
171 | /** Implementation for attributes with non-constant attribute similarity functions */
172 | private class GenericAttributeIndex(stringToId: Map[String, ValueId],
173 | probs: Array[Double],
174 | private val simValueIndex: Array[Map[ValueId, Double]],
175 | private val simNormalizations: Array[Double],
176 | precachePowers: Option[Traversable[Int]])
177 | extends ConstantAttributeIndex(stringToId, probs) {
178 |
179 | override def simNormalizationOf(valueId: ValueId): Double = simNormalizations(valueId)
180 |
181 | override def simValuesOf(valueId: ValueId): scala.collection.Map[ValueId, Double] = simValueIndex(valueId)
182 |
183 | override def expSimOf(valueId1: ValueId, valueId2: ValueId): Double = {
184 | require(valueId2 >= 0 && valueId2 < numValues, "valueId2 is not in the index")
185 | simValuesOf(valueId1).getOrElse(valueId2, 1.0)
186 | }
187 |
188 | private val cachedSimNormDist = precachePowers match {
189 | case Some(powers) => powers.foldLeft(mutable.Map.empty[Int, DiscreteDist[Int]]) {
190 | case (m, power) if power > 0 =>
191 | m + (power -> computeSimNormDist(probs, simNormalizations, power))
192 | case (m, _) => m
193 | }
194 | case None => mutable.Map.empty[Int, DiscreteDist[Int]]
195 | }
196 |
197 | override def getSimNormDist(power: Int): DiscreteDist[Int] = {
198 | require(power > 0, "power must be a positive integer")
199 | cachedSimNormDist.get(power) match {
200 | case Some(dist) => dist
201 | case None =>
202 | val dist = computeSimNormDist(probs, simNormalizations, power)
203 | cachedSimNormDist.update(power, dist)
204 | dist
205 | }
206 | }
207 | }
208 |
209 |
210 | private def computeSimNormDist(probs: Array[Double], simNormalizations: Array[Double],
211 | power: Int): DiscreteDist[Int] = {
212 | val simNormWeights = Array.tabulate(probs.length) { valueId =>
213 | probs(valueId) * pow(simNormalizations(valueId), power)
214 | }
215 | DiscreteDist(simNormWeights)
216 | }
217 |
218 |
219 | private def computeSimValueIndex(stringToId: Map[String, ValueId],
220 | similarityFn: SimilarityFn)
221 | (implicit sc: SparkContext): Array[Map[ValueId, Double]] = {
222 | val valuesRDD = sc.parallelize(stringToId.toSeq) // TODO: how to set numSlices? Take into account number of executors.
223 | val valuePairsRDD = valuesRDD.cartesian(valuesRDD)
224 |
225 | val simValues = valuePairsRDD.map { case ((a, a_id), (b, b_id)) => (a_id, (b_id, exp(similarityFn.getSimilarity(a, b)))) }
226 | .filter(_._2._2 > 1.0) // non-zero truncated similarity
227 | .aggregateByKey(Map.empty[Int, Double])(seqOp = _ + _, combOp = _ ++ _)
228 | .collect()
229 |
230 | simValues.sortBy(_._1).map(_._2)
231 | }
232 |
233 |
234 | private def computeSimNormalizations(simValueIndex: Array[Map[ValueId, Double]],
235 | probs: Array[Double]): Array[Double] = {
236 | simValueIndex.map { simValues =>
237 | var valueId = 0
238 | var norm = 0.0
239 | while (valueId < probs.length) {
240 | norm += probs(valueId) * simValues.getOrElse(valueId, 1.0)
241 | valueId += 1
242 | }
243 | 1.0/norm
244 | }
245 | }
246 | }
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/Project.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see <https://www.gnu.org/licenses/>.
19 |
20 | package com.github.cleanzr.dblink
21 |
22 | import SimilarityFn.{ConstantSimilarityFn, LevenshteinSimilarityFn}
23 | import com.github.cleanzr.dblink.analysis._
24 | import partitioning.{KDTreePartitioner, PartitionFunction}
25 | import com.typesafe.config.{Config, ConfigException, ConfigObject}
26 | import org.apache.hadoop.fs.{FileSystem, Path}
27 | import org.apache.spark.SparkContext
28 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
29 |
30 | import scala.collection.JavaConverters._
31 | import scala.collection.mutable
32 | import scala.util.Try
33 |
34 | /** An entity resolution project
35 | *
36 | * @param dataPath path to source records (in CSV format)
37 | * @param outputPath path to project directory
38 | * @param checkpointPath path for saving Spark checkpoints
39 | * @param recIdAttribute name of record identifier column in dataFrame (must be unique across all files)
40 | * @param fileIdAttribute name of file identifier column in dataFrame (optional)
41 | * @param entIdAttribute name of entity identifier column in dataFrame (optional: if ground truth is available)
42 | * @param matchingAttributes attribute specifications to use for matching
43 | * @param partitionFunction partition function (determines how entities are partitioned across executors)
44 | * @param randomSeed random seed
45 | * @param populationSize size of the latent population
46 | * @param expectedMaxClusterSize expected size of the largest record cluster (used as a hint to improve precaching)
47 | * @param dataFrame data frame containing source records
48 | */
49 | case class Project(dataPath: String, outputPath: String, checkpointPath: String,
50 | recIdAttribute: String, fileIdAttribute: Option[String],
51 | entIdAttribute: Option[String], matchingAttributes: IndexedSeq[Attribute],
52 | partitionFunction: PartitionFunction[ValueId], randomSeed: Long, populationSize: Option[Int],
53 | expectedMaxClusterSize: Int, dataFrame: DataFrame) extends Logging {
54 | require(expectedMaxClusterSize >= 0, "expectedMaxClusterSize must be non-negative")
55 |
56 | def sparkContext: SparkContext = dataFrame.sparkSession.sparkContext
57 |
58 | def mkString: String = {
59 | val lines = mutable.ArrayBuffer.empty[String]
60 | lines += "Data settings"
61 | lines += "-------------"
62 | lines += s" * Using data files located at '$dataPath'"
63 | lines += s" * The record identifier attribute is '$recIdAttribute'"
64 | fileIdAttribute match {
65 | case Some(fId) => lines += s" * The file identifier attribute is '$fId'"
66 | case None => lines += " * There is no file identifier"
67 | }
68 | entIdAttribute match {
69 | case Some(eId) => lines += s" * The entity identifier attribute is '$eId'"
70 | case None => lines += " * There is no entity identifier"
71 | }
72 | lines += s" * The matching attributes are ${matchingAttributes.map("'" + _.name + "'").mkString(", ")}"
73 | lines += ""
74 |
75 | lines += "Hyperparameter settings"
76 | lines += "-----------------------"
77 | lines ++= matchingAttributes.zipWithIndex.map { case (attribute, attributeId) =>
78 | s" * '${attribute.name}' (id=$attributeId) with ${attribute.similarityFn.mkString} and ${attribute.distortionPrior.mkString}"
79 | }
80 | lines += s" * Size of latent population is ${populationSize.toString}"
81 | lines += ""
82 |
83 | lines += "Partition function settings"
84 | lines += "---------------------------"
85 | lines += " * " + partitionFunction.mkString
86 | lines += ""
87 |
88 | lines += "Project settings"
89 | lines += "----------------"
90 | lines += s" * Using randomSeed=$randomSeed"
91 | lines += s" * Using expectedMaxClusterSize=$expectedMaxClusterSize"
92 | lines += s" * Saving Markov chain and complete final state to '$outputPath'"
93 | lines += s" * Saving Spark checkpoints to '$checkpointPath'"
94 |
95 | lines.mkString("","\n","\n")
96 | }
97 |
98 | def sharedMostProbableClustersOnDisk: Boolean = {
99 | val hdfs = FileSystem.get(sparkContext.hadoopConfiguration)
100 | val fSMPC = new Path(outputPath + "shared-most-probable-clusters.csv")
101 | hdfs.exists(fSMPC)
102 | }
103 |
104 | def savedLinkageChain(lowerIterationCutoff: Int = 0): Option[Dataset[LinkageState]] = {
105 | val savedLinkageChainExists = {
106 | val hdfs = FileSystem.get(sparkContext.hadoopConfiguration)
107 | val file = new Path(outputPath + "linkage-chain.parquet")
108 | hdfs.exists(file)
109 | }
110 | if (savedLinkageChainExists) {
111 | val chain = if (lowerIterationCutoff == 0) LinkageChain.readLinkageChain(outputPath)
112 | else LinkageChain.readLinkageChain(outputPath).filter(_.iteration >= lowerIterationCutoff)
113 | if (chain.take(1).isEmpty) None
114 | else Some(chain)
115 | } else None
116 | }
117 |
118 | def savedState: Option[State] = {
119 | val savedStateExists = {
120 | val hdfs = FileSystem.get(sparkContext.hadoopConfiguration)
121 | val fDriverState = new Path(outputPath + "driver-state")
122 | val fPartitionState = new Path(outputPath + "partitions-state.parquet")
123 | hdfs.exists(fDriverState) && hdfs.exists(fPartitionState)
124 | }
125 | if (savedStateExists) {
126 | Some(State.read(outputPath))
127 | } else None
128 | }
129 |
130 | def generateInitialState: State = {
131 | info("Generating new initial state")
132 | val parameters = Parameters(
133 | maxClusterSize = expectedMaxClusterSize
134 | )
135 | State.deterministic(
136 | records = dataFrame,
137 | recIdColname = recIdAttribute,
138 | fileIdColname = fileIdAttribute,
139 | populationSize = populationSize,
140 | attributeSpecs = matchingAttributes,
141 | parameters = parameters,
142 | partitionFunction = partitionFunction,
143 | randomSeed = randomSeed
144 | )
145 | }
146 |
147 | def savedSharedMostProbableClusters: Option[Dataset[Cluster]] = {
148 | if (sharedMostProbableClustersOnDisk) {
149 | Some(readClustersCSV(outputPath + "shared-most-probable-clusters.csv"))
150 | } else None
151 | }
152 |
153 | /**
154 | * Loads the ground truth clustering, if available.
155 | */
156 | def trueClusters: Option[Dataset[Cluster]] = {
157 | entIdAttribute match {
158 | case Some(eId) =>
159 | val spark = dataFrame.sparkSession
160 | val recIdName = recIdAttribute
161 | import spark.implicits._
162 | val membership = dataFrame.map(r => (r.getAs[RecordId](recIdName), r.getAs[EntityId](eId)))
163 | Some(membershipToClusters(membership))
164 | case _ => None
165 | }
166 | }
167 | }
168 |
169 | object Project {
170 | def apply(config: Config): Project = {
171 | val dataPath = config.getString("dblink.data.path")
172 |
173 | val dataFrame: DataFrame = {
174 | val spark = SparkSession.builder().getOrCreate()
175 | spark.read.format("csv")
176 | .option("header", "true")
177 | .option("mode", "DROPMALFORMED")
178 | .option("nullValue", config.getString("dblink.data.nullValue"))
179 | .load(dataPath)
180 | }
181 |
182 | val matchingAttributes =
183 | parseMatchingAttributes(config.getObjectList("dblink.data.matchingAttributes"))
184 |
185 | Project(
186 | dataPath = dataPath,
187 | outputPath = config.getString("dblink.outputPath"),
188 | checkpointPath = config.getString("dblink.checkpointPath"),
189 | recIdAttribute = config.getString("dblink.data.recordIdentifier"),
190 | fileIdAttribute = Try {Some(config.getString("dblink.data.fileIdentifier"))} getOrElse None,
191 | entIdAttribute = Try {Some(config.getString("dblink.data.entityIdentifier"))} getOrElse None,
192 | matchingAttributes = matchingAttributes,
193 | partitionFunction = parsePartitioner(config.getConfig("dblink.partitioner"), matchingAttributes.map(_.name)),
194 | randomSeed = config.getLong("dblink.randomSeed"),
195 | populationSize = Try {Some(config.getInt("dblink.populationSize"))} getOrElse None,
196 | expectedMaxClusterSize = Try {config.getInt("dblink.expectedMaxClusterSize")} getOrElse 10,
197 | dataFrame = dataFrame
198 | )
199 | }
200 |
201 | implicit def toConfigTraversable[T <: ConfigObject](objectList: java.util.List[T]): Traversable[Config] = objectList.asScala.map(_.toConfig)
202 |
203 | private def parseMatchingAttributes(configList: Traversable[Config]): Array[Attribute] = {
204 | configList.map { c =>
205 | val simFn = c.getString("similarityFunction.name") match {
206 | case "ConstantSimilarityFn" => ConstantSimilarityFn
207 | case "LevenshteinSimilarityFn" =>
208 | LevenshteinSimilarityFn(c.getDouble("similarityFunction.parameters.threshold"), c.getDouble("similarityFunction.parameters.maxSimilarity"))
209 | case _ => throw new ConfigException.BadValue(c.origin(), "similarityFunction.name", "unsupported value")
210 | }
211 | val distortionPrior = BetaShapeParameters(
212 | c.getDouble("distortionPrior.alpha"),
213 | c.getDouble("distortionPrior.beta")
214 | )
215 | Attribute(c.getString("name"), simFn, distortionPrior)
216 | }.toArray
217 | }
218 |
219 | private def parsePartitioner(config: Config, attributeNames: Seq[String]): KDTreePartitioner[ValueId] = {
220 | if (config.getString("name") == "KDTreePartitioner") {
221 | val numLevels = config.getInt("parameters.numLevels")
222 | val attributeIds = config.getStringList("parameters.matchingAttributes").asScala.map( n =>
223 | attributeNames.indexOf(n)
224 | )
225 | KDTreePartitioner[ValueId](numLevels, attributeIds)
226 | } else {
227 | throw new ConfigException.BadValue(config.origin(), "name", "unsupported value")
228 | }
229 | }
230 | }
--------------------------------------------------------------------------------