├── project
├── build.properties
└── plugins.sbt
├── .idea
├── codeStyles
│ ├── codeStyleConfig.xml
│ └── Project.xml
├── vcs.xml
├── scala_compiler.xml
├── misc.xml
├── hydra.xml
├── scala_settings.xml
├── modules.xml
└── sbt.xml
├── CONTRIBUTORS.md
├── LICENSE
├── docs
├── macos-java8.md
└── guide.md
├── src
├── main
│ ├── scala
│ │ └── com
│ │ │ └── github
│ │ │ └── cleanzr
│ │ │ └── dblink
│ │ │ ├── Parameters.scala
│ │ │ ├── util
│ │ │ ├── HardPartitioner.scala
│ │ │ ├── PathToFileConverter.scala
│ │ │ ├── BufferedFileWriter.scala
│ │ │ ├── BufferedRDDWriter.scala
│ │ │ ├── PeriodicRDDCheckpointer.scala
│ │ │ └── PeriodicCheckpointer.scala
│ │ │ ├── analysis
│ │ │ ├── BinaryClassificationMetrics.scala
│ │ │ ├── baselines.scala
│ │ │ ├── PairwiseMetrics.scala
│ │ │ ├── BinaryConfusionMatrix.scala
│ │ │ ├── ClusteringContingencyTable.scala
│ │ │ ├── ClusteringMetrics.scala
│ │ │ └── package.scala
│ │ │ ├── partitioning
│ │ │ ├── PartitionFunction.scala
│ │ │ ├── SimplePartitioner.scala
│ │ │ ├── LPTScheduler.scala
│ │ │ ├── DomainSplitter.scala
│ │ │ ├── MutableBST.scala
│ │ │ └── KDTreePartitioner.scala
│ │ │ ├── DistortionProbs.scala
│ │ │ ├── Run.scala
│ │ │ ├── Logging.scala
│ │ │ ├── accumulators
│ │ │ ├── MapLongAccumulator.scala
│ │ │ ├── MapDoubleAccumulator.scala
│ │ │ └── MapArrayAccumulator.scala
│ │ │ ├── random
│ │ │ ├── DiscreteDist.scala
│ │ │ ├── IndexNonUniformDiscreteDist.scala
│ │ │ ├── NonUniformDiscreteDist.scala
│ │ │ └── AliasSampler.scala
│ │ │ ├── SummaryAccumulators.scala
│ │ │ ├── ProjectSteps.scala
│ │ │ ├── DiagnosticsWriter.scala
│ │ │ ├── SimilarityFn.scala
│ │ │ ├── CustomKryoRegistrator.scala
│ │ │ ├── package.scala
│ │ │ ├── RecordsCache.scala
│ │ │ ├── Sampler.scala
│ │ │ ├── ProjectStep.scala
│ │ │ ├── LinkageChain.scala
│ │ │ ├── AttributeIndex.scala
│ │ │ └── Project.scala
│ └── resources
│ │ └── log4j.properties
└── test
│ └── scala
│ ├── Launch.scala
│ └── com
│ └── github
│ └── cleanzr
│ └── dblink
│ ├── EntityInvertedIndexTest.scala
│ ├── AttributeTest.scala
│ ├── random
│ ├── DiscreteDistTest.scala
│ ├── DiscreteDistBehavior.scala
│ └── AliasSamplerTest.scala
│ ├── DistortionProbsTest.scala
│ ├── AttributeIndexBehaviors.scala
│ ├── SimilarityFnTest.scala
│ └── AttributeIndexTest.scala
├── .gitignore
├── examples
├── RLdata500.conf
└── RLdata10000.conf
└── README.md
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.2.8
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
2 |
3 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1")
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/scala_compiler.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/codeStyles/Project.xml:
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | # dblink contributors 2 | 3 | The following people (sorted alphabetically) have contributed to coding, documentation and testing efforts. 4 | 5 | * [Andee Kaplan](http://andeekaplan.com/) 6 | * Neil G. Marchant (partially supported by Australian Bureau of Statistics) 7 | * [Rebecca C. Steorts](https://resteorts.github.io/) 8 | -------------------------------------------------------------------------------- /.idea/hydra.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 9 | -------------------------------------------------------------------------------- /.idea/scala_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 10 | 12 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/sbt.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 16 | 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | dblink: Empirical Bayes Entity Resolution for Spark 2 | 3 | Copyright (C) 2018-2019 dblink contributors 4 | Copyright (C) 2018 Australian Bureau of Statistics 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 18 | -------------------------------------------------------------------------------- /docs/macos-java8.md: -------------------------------------------------------------------------------- 1 | # Installing Java 8+ on macOS 2 | 3 | We recommend using the [AdoptOpenJDK](https://adoptopenjdk.net) Java 4 | distribution on macOS. 5 | 6 | It can be installed using the [Homebrew](https://brew.sh/) package manager. 7 | Simply run the following commands in a terminal: 8 | ```bash 9 | $ brew tap AdoptOpenJDK/openjdk 10 | $ brew cask install adoptopenjdk8 11 | ``` 12 | 13 | Note that it's possible to have multiple versions of Java installed in 14 | parallel. If you run 15 | ```bash 16 | $ java -version 17 | ``` 18 | and don't see a version number like 1.8.x or 8.x, then you'll need to 19 | manually select Java 8. 
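If you're not sure which JDKs are present, the `java_home` utility bundled with macOS can list them (a quick check; the exact output depends on your system):
```bash
$ /usr/libexec/java_home -V
```
Look for an entry whose version number starts with 1.8.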
20 | 21 | To do this temporarily in a terminal, run 22 | ``` 23 | $ export JAVA_HOME=$(/usr/libexec/java_home -v1.8) 24 | ``` 25 | All references to Java within this session will then make use of Java 8. 26 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/Parameters.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | case class Parameters(maxClusterSize: Int) { 23 | require(maxClusterSize > 0, "`maxClusterSize` must be a positive integer.") 24 | 25 | def mkString: String = { 26 | "maxClusterSize: " + maxClusterSize 27 | } 28 | } -------------------------------------------------------------------------------- /src/test/scala/Launch.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | import com.github.cleanzr.dblink.Run 21 | import org.apache.spark.{SparkConf, SparkContext} 22 | 23 | object Launch { 24 | def main(args: Array[String]) { 25 | val conf = new SparkConf().setMaster("local[2]").setAppName("dblink") 26 | val sc = SparkContext.getOrCreate(conf) 27 | Run.main(args) 28 | } 29 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/util/HardPartitioner.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 
11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.util 21 | 22 | import org.apache.spark.Partitioner 23 | 24 | class HardPartitioner(override val numPartitions: Int) extends Partitioner { 25 | override def getPartition(key: Any): Int = { 26 | val k = key.asInstanceOf[Int] 27 | k % numPartitions 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/util/PathToFileConverter.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.util 21 | 22 | import java.io.File 23 | 24 | import org.apache.hadoop.fs.{FileSystem, Path} 25 | import org.apache.hadoop.conf.Configuration 26 | 27 | object PathToFileConverter { 28 | 29 | def fileToPath(path: Path, conf: Configuration): File = { 30 | val fs = FileSystem.get(path.toUri, conf) 31 | val tempFile = File.createTempFile(path.getName, "") 32 | tempFile.deleteOnExit() 33 | fs.copyToLocalFile(path, new Path(tempFile.getAbsolutePath)) 34 | tempFile 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/EntityInvertedIndexTest.scala: -------------------------------------------------------------------------------- 1 | package com.github.cleanzr.dblink 2 | 3 | import com.github.cleanzr.dblink.GibbsUpdates.EntityInvertedIndex 4 | import org.scalatest._ 5 | 6 | class EntityInvertedIndexTest extends FlatSpec { 7 | 8 | def entityIndex = new EntityInvertedIndex 9 | 10 | val allEntityIds = Set(1, 2, 3, 4) 11 | 12 | lazy val entitiesWithOneDup = Seq( 13 | (1, Entity(Array(1, 0, 0))), 14 | (2, Entity(Array(2, 0, 1))), 15 | (3, Entity(Array(3, 0, 1))), 16 | (4, Entity(Array(4, 1, 0))), 17 | (4, Entity(Array(4, 1, 0))) // same entity twice 18 | ) 19 | 20 | behavior of "An entity inverted index (containing one entity)" 21 | 22 | it should "return singleton sets for the entity's attribute values" in { 23 | val index = new EntityInvertedIndex 24 | index.add(entitiesWithOneDup.head._1, entitiesWithOneDup.head._2) 25 | assert(entitiesWithOneDup.head._2.values.zipWithIndex.forall {case (valueId, attrId) => index.getEntityIds(attrId, valueId).size === 1}) 26 | } 27 | 28 | behavior of "An entity inverted index (containing multiple entities)" 29 | 30 | it should 
"return the correct responses to queries" in { 31 | val index = new EntityInvertedIndex 32 | entitiesWithOneDup.foreach { case (entId, entity) => index.add(entId, entity) } 33 | assert(index.getEntityIds(0, 1) === Set(1)) 34 | assert(index.getEntityIds(0, 4) === Set(4)) 35 | assert(index.getEntityIds(1, 0) === Set(1, 2, 3)) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/AttributeTest.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import org.scalatest.FlatSpec 23 | import com.github.cleanzr.dblink.SimilarityFn._ 24 | 25 | class AttributeTest extends FlatSpec { 26 | behavior of "An attribute with a constant similarity function" 27 | 28 | it should "be constant" in { 29 | assert(Attribute("name", ConstantSimilarityFn, BetaShapeParameters(1.0, 1.0)).isConstant === true) 30 | } 31 | 32 | behavior of "An attribute with a non-constant similarity function" 33 | 34 | it should "not be constant" in { 35 | assert(Attribute("name", LevenshteinSimilarityFn(), BetaShapeParameters(1.0, 1.0)).isConstant === false) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 3 | 4 | # User-specific stuff: 5 | .idea/**/workspace.xml 6 | .idea/**/tasks.xml 7 | .idea/dictionaries 8 | 9 | # Sensitive or high-churn files: 10 | .idea/**/dataSources/ 11 | .idea/**/dataSources.ids 12 | .idea/**/dataSources.xml 13 | .idea/**/dataSources.local.xml 14 | .idea/**/sqlDataSources.xml 15 | .idea/**/dynamic.xml 16 | .idea/**/uiDesigner.xml 17 | 18 | # Gradle: 19 | .idea/**/gradle.xml 20 | .idea/**/libraries 21 | 22 | # CMake 23 | cmake-build-debug/ 24 | cmake-build-release/ 25 | 26 | # Mongo Explorer plugin: 27 | .idea/**/mongoSettings.xml 28 | 29 | ## File-based project format: 30 | *.iws 31 | 32 | ## Plugin-specific files: 33 | 34 | # IntelliJ 35 | out/ 36 | 37 | # mpeltonen/sbt-idea plugin 38 | .idea_modules/ 39 | 40 | # JIRA plugin 41 | atlassian-ide-plugin.xml 42 | 43 | # Cursive Clojure plugin 44 | .idea/replstate.xml 45 | 46 | # Crashlytics plugin (for Android Studio and IntelliJ) 47 | com_crashlytics_export_strings.xml 48 | crashlytics.properties 49 | crashlytics-build.properties 50 | fabric.properties 51 | 52 | 53 | # Simple Build Tool 54 | # 
http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control 55 | 56 | dist/* 57 | target/ 58 | lib_managed/ 59 | src_managed/ 60 | project/boot/ 61 | project/plugins/project/ 62 | .history 63 | .cache 64 | .lib/ 65 | 66 | # Scala 67 | *.class 68 | *.log -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/random/DiscreteDistTest.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.random 21 | 22 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator} 23 | import org.scalatest.FlatSpec 24 | 25 | class DiscreteDistTest extends FlatSpec with DiscreteDistBehavior { 26 | 27 | implicit val rand: RandomGenerator = new MersenneTwister() 28 | 29 | def valuesWeights = Map("A" -> 100.0, "B" -> 200.0, "C" -> 700.0) 30 | val indexDist = DiscreteDist(valuesWeights.values) 31 | val dist = DiscreteDist(valuesWeights) 32 | 33 | "A discrete distribution (without values given)" should behave like genericDiscreteDist(indexDist, 5) 34 | 35 | "A discrete distribution (with values given)" should behave like genericDiscreteDist(dist, "D") 36 | } 37 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/analysis/BinaryClassificationMetrics.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink.analysis 21 | 22 | object BinaryClassificationMetrics { 23 | def precision(binaryConfusionMatrix: BinaryConfusionMatrix): Double = { 24 | binaryConfusionMatrix.TP.toDouble / binaryConfusionMatrix.PP 25 | } 26 | 27 | def recall(binaryConfusionMatrix: BinaryConfusionMatrix): Double = { 28 | binaryConfusionMatrix.TP.toDouble / binaryConfusionMatrix.P 29 | } 30 | 31 | def fMeasure(binaryConfusionMatrix: BinaryConfusionMatrix, 32 | beta: Double = 1.0): Double = { 33 | val betaSq = beta * beta 34 | val pr = precision(binaryConfusionMatrix) 35 | val re = recall(binaryConfusionMatrix) 36 | (1 + betaSq) * pr * re / (betaSq * pr + re) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/partitioning/PartitionFunction.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.partitioning 21 | 22 | import org.apache.spark.rdd.RDD 23 | 24 | abstract class PartitionFunction[T] extends Serializable { 25 | 26 | /** Number of partitions */ 27 | def numPartitions: Int 28 | 29 | /** Fit partition function based on a sample of records 30 | * 31 | * @param records RDD of records. Assumes no values are missing. 32 | */ 33 | def fit(records: RDD[Array[T]]): Unit 34 | 35 | /** Get assigned partition for a set of attribute values 36 | * 37 | * @param attributeValues array of (entity) attribute values. Assumes no values are missing. 38 | * @return identifier of assigned parittion: an integer in {0, ..., numPartitions - 1}. 39 | */ 40 | def getPartitionId(attributeValues: Array[T]): Int 41 | 42 | def mkString: String 43 | } -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/DistortionProbsTest.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 
16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import org.scalatest.FlatSpec 23 | 24 | class DistortionProbsTest extends FlatSpec { 25 | 26 | val twoFilesOneAttribute = DistortionProbs(Seq("A", "B"), Iterator(BetaShapeParameters(3.0, 3.0))) 27 | 28 | behavior of "Distortion probabilities for two files and one attribute (with identical shape parameters)" 29 | 30 | it should "return probability 0.5 for both files" in { 31 | assert(twoFilesOneAttribute(0, "A") === 0.5) 32 | assert(twoFilesOneAttribute(0, "B") === 0.5) 33 | } 34 | 35 | it should "complain when a probability is requested for a third file" in { 36 | assertThrows[NoSuchElementException] { 37 | twoFilesOneAttribute(0, "C") 38 | } 39 | } 40 | 41 | it should "complain when a probability is requested for a second attribute" in { 42 | assertThrows[NoSuchElementException] { 43 | twoFilesOneAttribute(1, "A") 44 | } 45 | } 46 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/DistortionProbs.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | /** Container for the distortion probabilities (which vary for each file 23 | * and attribute). 24 | * 25 | * @param probs map used internally to store the probabilities 26 | */ 27 | case class DistortionProbs(probs: Map[(AttributeId, FileId), Double]) { 28 | def apply(attrId: AttributeId, fileId: FileId): Double = probs.apply((attrId, fileId)) 29 | } 30 | 31 | object DistortionProbs { 32 | /** Generate distortion probabilities based on the prior mean */ 33 | def apply(fileIds: Traversable[FileId], 34 | distortionPrior: Iterator[BetaShapeParameters]): DistortionProbs = { 35 | require(fileIds.nonEmpty, "`fileIds` cannot be empty") 36 | require(distortionPrior.nonEmpty, "`distortionPrior` cannot be empty") 37 | 38 | val probs = distortionPrior.zipWithIndex.flatMap { case (BetaShapeParameters(alpha, beta), attrId) => 39 | fileIds.map { fileId => ((attrId, fileId), alpha / (alpha + beta)) } 40 | }.toMap 41 | 42 | DistortionProbs(probs) 43 | } 44 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/Run.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 
6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.util.{BufferedFileWriter, PathToFileConverter} 23 | import com.typesafe.config.ConfigFactory 24 | import org.apache.spark.sql.SparkSession 25 | import org.apache.hadoop.fs.Path 26 | 27 | object Run extends App with Logging { 28 | 29 | val spark = SparkSession.builder().appName("dblink").getOrCreate() 30 | val sc = spark.sparkContext 31 | sc.setLogLevel("WARN") 32 | 33 | val configFile = PathToFileConverter.fileToPath(new Path(args.head), sc.hadoopConfiguration) 34 | 35 | val config = ConfigFactory.parseFile(configFile).resolve() 36 | 37 | val project = Project(config) 38 | val writer = BufferedFileWriter(project.outputPath + "run.txt", append = false, sparkContext = project.sparkContext) 39 | writer.write(project.mkString) 40 | 41 | val steps = ProjectSteps(config, project) 42 | writer.write("\n" + steps.mkString) 43 | writer.close() 44 | 45 | sc.setCheckpointDir(project.checkpointPath) 46 | 47 | steps.execute() 48 | 49 | sc.stop() 50 | } 51 | -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/AttributeIndexBehaviors.scala: -------------------------------------------------------------------------------- 1 | package com.github.cleanzr.dblink 2 | 3 | import org.scalatest.{FlatSpec, Matchers} 4 | 5 | trait AttributeIndexBehaviors extends Matchers { this: FlatSpec => 6 | 7 | def genericAttributeIndex(index: AttributeIndex, valuesWeights: Map[String, Double]) { 8 | it should "have the correct number of values" in { 9 | assert(index.numValues === valuesWeights.size) 10 | } 11 | 12 | it should "assign unique value ids in the set {0, ..., numValues - 1}" in { 13 | assert(valuesWeights.keysIterator.map(stringValue => index.valueIdxOf(stringValue)).toSet === 14 | (0 until valuesWeights.size).toSet) 15 | } 16 | 17 | it should "have the correct probability distribution" in { 18 | val totalWeight = valuesWeights.foldLeft(0.0)((total, x) => total + x._2) 19 | assert(valuesWeights.mapValues(_/totalWeight).forall { case (stringValue, correctProb) => 20 | index.probabilityOf(index.valueIdxOf(stringValue)) === (correctProb +- 1e-4) 21 | }) 22 | } 23 | 24 | it should "complain when a probability is requested for a non-indexed value" in { 25 | assertThrows[RuntimeException] { 26 | index.probabilityOf(index.numValues + 1) 27 | } 28 | } 29 | 30 | it should "complain when a similarity normalization is requested for a non-indexed value" in { 31 | assertThrows[RuntimeException] { 32 | index.simNormalizationOf(index.numValues + 1) 33 | } 34 | } 35 | 36 | it should "complain when similar values are requested for a non-indexed value" in { 37 | assertThrows[RuntimeException] { 38 | index.simValuesOf(index.numValues + 1) 39 | } 40 | } 41 | 42 | it should "complain when an exponentiated similarity 
score is requested for a non-indexed value" in { 43 | assertThrows[RuntimeException] { 44 | index.expSimOf(index.numValues + 1, 0) 45 | } 46 | assertThrows[RuntimeException] { 47 | index.expSimOf(0, index.numValues + 1) 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/Logging.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import org.apache.log4j.Logger 23 | import org.apache.log4j.Level 24 | 25 | trait Logging { 26 | private[this] lazy val logger = Logger.getLogger(logName) 27 | 28 | private[this] lazy val logName: String = this.getClass.getName.stripSuffix("$") 29 | 30 | def info(message: => String): Unit = if (logger.isEnabledFor(Level.INFO)) logger.info(message) 31 | 32 | def warn(message: => String): Unit = if (logger.isEnabledFor(Level.WARN)) logger.warn(message) 33 | 34 | def warn(message: => String, t: Throwable): Unit = if (logger.isEnabledFor(Level.WARN)) logger.warn(message, t) 35 | 36 | def error(message: => String): Unit = if (logger.isEnabledFor(Level.ERROR)) logger.error(message) 37 | 38 | def error(message: => String, t: Throwable): Unit = if (logger.isEnabledFor(Level.ERROR)) logger.error(message, t) 39 | 40 | def fatal(message: => String): Unit = if (logger.isEnabledFor(Level.FATAL)) logger.fatal(message) 41 | 42 | def fatal(message: => String, t: Throwable): Unit = if (logger.isEnabledFor(Level.FATAL)) logger.fatal(message, t) 43 | 44 | def debug(message: => String): Unit = if (logger.isEnabledFor(Level.DEBUG)) logger.debug(message) 45 | } 46 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/accumulators/MapLongAccumulator.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink.accumulators 21 | 22 | import org.apache.spark.util.AccumulatorV2 23 | import scala.collection.mutable 24 | 25 | /** 26 | * Accumulates counts corresponding to keys. 27 | * e.g. if we add K1 to the accumulator and (K1 -> 10L) is the current 28 | * key-value pair, the resulting key value pair will be (K1 -> 11L). 29 | * 30 | * @tparam K key type 31 | */ 32 | class MapLongAccumulator[K] extends AccumulatorV2[(K, Long), Map[K, Long]] { 33 | private val _map = mutable.HashMap.empty[K, Long] 34 | 35 | override def reset(): Unit = _map.clear() 36 | 37 | override def add(kv: (K, Long)): Unit = { 38 | _map.update(kv._1, _map.getOrElse(kv._1, 0l) + kv._2) 39 | } 40 | 41 | override def value: Map[K, Long] = _map.toMap 42 | 43 | override def isZero: Boolean = _map.isEmpty 44 | 45 | override def copy(): MapLongAccumulator[K] = { 46 | val newAcc = new MapLongAccumulator[K] 47 | newAcc._map ++= _map 48 | newAcc 49 | } 50 | 51 | def toIterator: Iterator[(K, Long)] = _map.iterator 52 | 53 | override def merge(other: AccumulatorV2[(K, Long), Map[K, Long]]): Unit = other match { 54 | case o: MapLongAccumulator[K] => o.toIterator.foreach { x => this.add(x) } 55 | case _ => 56 | throw new UnsupportedOperationException( 57 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") 58 | } 59 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/analysis/baselines.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink.analysis 21 | 22 | import com.github.cleanzr.dblink.{Cluster, RecordId} 23 | import org.apache.spark.rdd.RDD 24 | 25 | object baselines { 26 | def exactMatchClusters(records: RDD[(RecordId, Seq[String])]): RDD[Cluster] = { 27 | records.map(row => (row._2.mkString, row._1)) // (values, id) 28 | .aggregateByKey(Set.empty[RecordId]) ( 29 | seqOp = (recIds, recId) => recIds + recId, 30 | combOp = (recIdsA, recIdsB) => recIdsA ++ recIdsB 31 | ).map(_._2) 32 | } 33 | 34 | /** Generates (overlapping) clusters that are near matches based on attribute agreements/disagreements 35 | * 36 | * @param records 37 | * @param numDisagree 38 | * @return 39 | */ 40 | def nearClusters(records: RDD[(RecordId, Seq[String])], numDisagree: Int): RDD[Cluster] = { 41 | require(numDisagree >= 0, "`numAgree` must be non-negative") 42 | 43 | records.flatMap { row => 44 | val numAttr = row._2.length 45 | val attrIds = 0 until numAttr 46 | attrIds.combinations(numDisagree).map { delIds => 47 | val partialValues = row._2.zipWithIndex.collect { case (value, attrId) if !delIds.contains(attrId) => value } 48 | (partialValues.mkString, row._1) 49 | } 50 | }.aggregateByKey(Set.empty[RecordId])( 51 | seqOp = (recIds, recId) => recIds + recId, 52 | combOp = (recIdsA, recIdsB) => recIdsA ++ recIdsB 53 | ).map(_._2) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/accumulators/MapDoubleAccumulator.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.accumulators 21 | 22 | import org.apache.spark.util.AccumulatorV2 23 | 24 | import scala.collection.mutable 25 | 26 | /** 27 | * Accumulates counts corresponding to keys. 28 | * e.g. if we add K1 to the accumulator and (K1 -> 10L) is the current 29 | * key-value pair, the resulting key value pair will be (K1 -> 11L). 
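 * (For this accumulator the accumulated values are Doubles rather than Longs:
 * adding (K1, 1.0) when the current pair is (K1 -> 10.0) yields (K1 -> 11.0).)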
30 | * 31 | * @tparam K key type 32 | */ 33 | class MapDoubleAccumulator[K] extends AccumulatorV2[(K, Double), Map[K, Double]] { 34 | private val _map = mutable.HashMap.empty[K, Double] 35 | 36 | override def reset(): Unit = _map.clear() 37 | 38 | override def add(kv: (K, Double)): Unit = { 39 | _map.update(kv._1, _map.getOrElse(kv._1, 0.0) + kv._2) 40 | } 41 | 42 | override def value: Map[K, Double] = _map.toMap 43 | 44 | override def isZero: Boolean = _map.isEmpty 45 | 46 | override def copy(): MapDoubleAccumulator[K] = { 47 | val newAcc = new MapDoubleAccumulator[K] 48 | newAcc._map ++= _map 49 | newAcc 50 | } 51 | 52 | def toIterator: Iterator[(K, Double)] = _map.iterator 53 | 54 | override def merge(other: AccumulatorV2[(K, Double), Map[K, Double]]): Unit = other match { 55 | case o: MapDoubleAccumulator[K] => o.toIterator.foreach { x => this.add(x) } 56 | case _ => 57 | throw new UnsupportedOperationException( 58 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") 59 | } 60 | } -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/random/DiscreteDistBehavior.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink.random 21 | 22 | import org.apache.commons.math3.random.RandomGenerator 23 | import org.scalatest.{FlatSpec, Matchers} 24 | 25 | trait DiscreteDistBehavior extends Matchers { this: FlatSpec => 26 | 27 | def genericDiscreteDist[T](dist: DiscreteDist[T], valueOutsideSupport: T)(implicit rand: RandomGenerator): Unit = { 28 | it should "be normalised" in { 29 | assert(dist.values.foldLeft(0.0) { case (sum, v) => sum + dist.probabilityOf(v) } === (1.0 +- 1e-9) ) 30 | } 31 | 32 | it should "return valid probabilities for all values in the support" in { 33 | assert(dist.values.forall {value => 34 | val prob = dist.probabilityOf(value) 35 | prob >= 0 && !prob.isInfinity && !prob.isNaN && prob <= 1}) 36 | } 37 | 38 | it should "return a probability of 0.0 for a value outside the support" in { 39 | assert(dist.probabilityOf(valueOutsideSupport) === 0.0) 40 | } 41 | 42 | it should "return `numValues` equal to the size of `values`" in { 43 | assert(dist.numValues === dist.values.size) 44 | } 45 | 46 | it should "return a valid `totalWeight`" in { 47 | assert(dist.totalWeight > 0) 48 | assert(!dist.totalWeight.isNaN && !dist.totalWeight.isInfinity) 49 | } 50 | 51 | it should "not return sample values that have probability 0.0" in { 52 | assert((1 to 1000).map(_ => dist.sample()).forall(v => dist.probabilityOf(v) > 0.0)) 53 | } 54 | } 55 | 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/random/DiscreteDist.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.random 21 | 22 | import org.apache.commons.math3.random.RandomGenerator 23 | 24 | import scala.collection.Map 25 | import scala.reflect.ClassTag 26 | 27 | /** A distribution over a discrete set of values 28 | * 29 | * @tparam T type of the values 30 | */ 31 | trait DiscreteDist[T] extends Serializable { 32 | 33 | /** Values in the support set */ 34 | def values: Traversable[T] 35 | 36 | /** Number of values in the support set */ 37 | def numValues: Int 38 | 39 | /** Total weight before normalization */ 40 | def totalWeight: Double 41 | 42 | /** Draw a value according to the distribution 43 | * 44 | * @param rand external RandomGenerator to use for drawing sample 45 | * @return a value from the support set 46 | */ 47 | def sample()(implicit rand: RandomGenerator): T 48 | 49 | /** Get the probability mass associated with a value 50 | * 51 | * @param value a value from the support set 52 | * @return probability. Returns 0.0 if the value is not in the support set. 
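 *
 * For example, an illustrative sketch using the weights from the accompanying unit tests:
 * {{{
 *   val dist = DiscreteDist(Map("A" -> 100.0, "B" -> 200.0, "C" -> 700.0))
 *   dist.probabilityOf("C")   // 0.7 (i.e. 700.0 / 1000.0 after normalisation)
 *   dist.probabilityOf("D")   // 0.0, since "D" is outside the support
 * }}}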
53 | */ 54 | def probabilityOf(value: T): Double 55 | 56 | /** Iterator over the values in the support, together with their probability 57 | * mass 58 | */ 59 | def toIterator: Iterator[(T, Double)] 60 | } 61 | 62 | object DiscreteDist { 63 | def apply[T](valuesAndWeights: Map[T, Double]) 64 | (implicit ev: ClassTag[T]): NonUniformDiscreteDist[T] = { 65 | NonUniformDiscreteDist[T](valuesAndWeights) 66 | } 67 | 68 | def apply(weights: Traversable[Double]): IndexNonUniformDiscreteDist = { 69 | IndexNonUniformDiscreteDist(weights) 70 | } 71 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/partitioning/SimplePartitioner.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.partitioning 21 | 22 | import SimplePartitioner._ 23 | import org.apache.spark.rdd.RDD 24 | import scala.reflect.ClassTag 25 | 26 | /** 27 | * Block on the values of a single field, then combine to get roughly 28 | * equal-size partitions (based on empirical distribution) 29 | * 30 | * @param attributeId 31 | * @param numPartitions 32 | */ 33 | class SimplePartitioner[T : ClassTag : Ordering](val attributeId: Int, 34 | override val numPartitions: Int) 35 | extends PartitionFunction[T] { 36 | 37 | private var _lpt: LPTScheduler[T, Double] = _ 38 | 39 | override def fit(records: RDD[Array[T]]): Unit = { 40 | val weights = records.map(values => (values(attributeId), 1.0)) 41 | .reduceByKey(_ + _) 42 | .collect() 43 | _lpt = generateLPTScheduler(weights, numPartitions) 44 | } 45 | 46 | override def getPartitionId(values: Array[T]): Int = { 47 | if (_lpt == null) -1 48 | else _lpt.getPartitionId(values(attributeId)) 49 | } 50 | 51 | override def mkString: String = s"SimplePartitioner(attributeId=$attributeId, numPartitions=$numPartitions)" 52 | } 53 | 54 | object SimplePartitioner { 55 | private def generateLPTScheduler[T](weights: Array[(T, Double)], 56 | numPartitions: Int): LPTScheduler[T, Double] = { 57 | new LPTScheduler[T, Double](weights, numPartitions) 58 | } 59 | 60 | def apply[T : ClassTag : Ordering](attributeId: Int, 61 | numPartitions: Int): SimplePartitioner[T] = { 62 | new SimplePartitioner(attributeId, numPartitions) 63 | } 64 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/analysis/PairwiseMetrics.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 
6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.analysis 21 | 22 | import com.github.cleanzr.dblink.RecordPair 23 | import org.apache.spark.sql.Dataset 24 | 25 | case class PairwiseMetrics(precision: Double, 26 | recall: Double, 27 | f1score: Double) { 28 | def mkString: String = { 29 | "=====================================\n" + 30 | " Pairwise metrics \n" + 31 | "-------------------------------------\n" + 32 | s" Precision: $precision\n" + 33 | s" Recall: $recall\n" + 34 | s" F1-score: $f1score\n" + 35 | "=====================================\n" 36 | } 37 | 38 | def print(): Unit = { 39 | Console.print(mkString) 40 | } 41 | } 42 | 43 | object PairwiseMetrics { 44 | def LinksConfusionMatrix(predictedLinks: Dataset[RecordPair], 45 | trueLinks: Dataset[RecordPair]): BinaryConfusionMatrix = { 46 | // Create PairRDDs so we can use the fullOuterJoin function 47 | val predictedLinks2 = predictedLinks.rdd.map(pair => (pair, true)) 48 | val trueLinks2 = trueLinks.rdd.map(pair => (pair, true)) 49 | val joinedLinks = predictedLinks2.fullOuterJoin(trueLinks2) 50 | .map { case (_, link) => (link._1.isDefined, link._2.isDefined)} 51 | BinaryConfusionMatrix(joinedLinks) 52 | } 53 | 54 | def apply(predictedLinks: Dataset[RecordPair], 55 | trueLinks: Dataset[RecordPair]): PairwiseMetrics = { 56 | val confusionMatrix = LinksConfusionMatrix(predictedLinks, trueLinks) 57 | 58 | val precision = BinaryClassificationMetrics.precision(confusionMatrix) 59 | val recall = BinaryClassificationMetrics.recall(confusionMatrix) 60 | val f1score = BinaryClassificationMetrics.fMeasure(confusionMatrix, beta = 1.0) 61 | 62 | PairwiseMetrics(precision, recall, f1score) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | # Set everything to be logged to the console 19 | log4j.rootCategory=INFO, console, file 20 | log4j.appender.console=org.apache.log4j.ConsoleAppender 21 | log4j.appender.console.target=System.err 22 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 23 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 24 | 25 | # Set up logging to file 26 | log4j.appender.file=org.apache.log4j.RollingFileAppender 27 | log4j.appender.file.File=./dblink.log 28 | log4j.appender.file.ImmediateFlush=true 29 | ## Set the append to false, overwrite 30 | log4j.appender.file.Append=true 31 | log4j.appender.file.MaxFileSize=100MB 32 | log4j.appender.file.MaxBackupIndex=10 33 | ##Define the layout for file appender 34 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 35 | log4j.appender.file.layout.ConversionPattern=%d{yy-MM-dd HH:mm:ss} %p %c{1}: %m%n 36 | 37 | log4j.logger.com.github.cleanzr.dblink=INFO 38 | 39 | # Set the default spark-shell log level to WARN. When running the spark-shell, the 40 | # log level for this class is used to overwrite the root logger's log level, so that 41 | # the user can have different defaults for the shell and regular Spark apps. 42 | log4j.logger.org.apache.spark.repl.Main=WARN 43 | 44 | # Settings to quiet third party logs that are too verbose 45 | log4j.logger.org.spark_project.jetty=WARN 46 | log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR 47 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 48 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 49 | log4j.logger.org.apache.parquet=ERROR 50 | log4j.logger.parquet=ERROR 51 | 52 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support 53 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL 54 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/accumulators/MapArrayAccumulator.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.accumulators 21 | 22 | import org.apache.spark.util.AccumulatorV2 23 | 24 | import scala.collection.mutable 25 | 26 | /** 27 | * Accumulates arrays of values corresponding to keys. 28 | * e.g. if we add (K1 -> "C") to the accumulator and (K1 -> ArrayBuffer("A","B")) 29 | * is the current key-value pair, the resulting key value pair will be 30 | * (K1 -> ArrayBuffer("A", B", "C")). 
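 * Unlike a set-valued accumulator, duplicate values are retained.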
31 | * 32 | * @tparam K key type 33 | * @tparam V value type (for elements of set) 34 | */ 35 | class MapArrayAccumulator[K,V] extends AccumulatorV2[(K,V), Map[K, mutable.ArrayBuffer[V]]] { 36 | private val _map = mutable.HashMap.empty[K, mutable.ArrayBuffer[V]] 37 | 38 | override def reset(): Unit = _map.clear() 39 | 40 | override def add(x: (K,V)): Unit = { 41 | val _array = _map.getOrElseUpdate(x._1, mutable.ArrayBuffer.empty[V]) 42 | _array += x._2 43 | } 44 | 45 | private def add_ks(x: (K, Traversable[V])): Unit = { 46 | val _array = _map.getOrElseUpdate(x._1, mutable.ArrayBuffer.empty[V]) 47 | x._2.foreach { v => _array += v } 48 | } 49 | 50 | override def value: Map[K, mutable.ArrayBuffer[V]] = _map.toMap 51 | 52 | override def isZero: Boolean = _map.isEmpty 53 | 54 | def toIterator: Iterator[(K, mutable.ArrayBuffer[V])] = _map.iterator 55 | 56 | override def copy(): MapArrayAccumulator[K, V] = { 57 | val newAcc = new MapArrayAccumulator[K, V] 58 | _map.foreach( x => newAcc._map.update(x._1, x._2.clone())) 59 | newAcc 60 | } 61 | 62 | override def merge(other: AccumulatorV2[(K, V), Map[K, mutable.ArrayBuffer[V]]]): Unit = 63 | other match { 64 | case o: MapArrayAccumulator[K, V] => o.toIterator.foreach {x => this.add_ks(x)} 65 | case _ => 66 | throw new UnsupportedOperationException( 67 | s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") 68 | } 69 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/analysis/BinaryConfusionMatrix.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink.analysis 21 | 22 | import org.apache.spark.rdd.RDD 23 | import org.apache.spark.sql.Dataset 24 | 25 | case class BinaryConfusionMatrix(TP: Long, 26 | FP: Long, 27 | FN: Long) { 28 | def P: Long = TP + FN 29 | def PP: Long = TP + FP 30 | } 31 | 32 | object BinaryConfusionMatrix { 33 | def apply(predictionsAndLabels: Seq[(Boolean, Boolean)]): BinaryConfusionMatrix = { 34 | var TP = 0L 35 | var FP = 0L 36 | var FN = 0L 37 | predictionsAndLabels.foreach {case (prediction, label) => 38 | if (prediction & label) TP += 1L 39 | if (prediction & !label) FP += 1L 40 | if (!prediction & label) FN += 1L 41 | } 42 | BinaryConfusionMatrix(TP, FP, FN) 43 | } 44 | 45 | def apply(predictionsAndLabels: RDD[(Boolean, Boolean)]): BinaryConfusionMatrix = { 46 | val sc = predictionsAndLabels.sparkContext 47 | val accTP = sc.longAccumulator("TP") 48 | val accFP = sc.longAccumulator("FP") 49 | val accFN = sc.longAccumulator("FN") 50 | predictionsAndLabels.foreach {case (prediction, label) => 51 | if (prediction & label) accTP.add(1L) 52 | if (prediction & !label) accFP.add(1L) 53 | if (!prediction & label) accFN.add(1L) 54 | } 55 | BinaryConfusionMatrix(accTP.value, accFP.value, accFN.value) 56 | } 57 | 58 | def apply(predictionsAndLabels: Dataset[(Boolean, Boolean)]): BinaryConfusionMatrix = { 59 | val spark = predictionsAndLabels.sparkSession 60 | val sc = spark.sparkContext 61 | val accTP = sc.longAccumulator("TP") 62 | val accFP = sc.longAccumulator("FP") 63 | val accFN = sc.longAccumulator("FN") 64 | predictionsAndLabels.rdd.foreach { case (prediction, label) => 65 | if (prediction & label) accTP.add(1L) 66 | if (prediction & !label) accFP.add(1L) 67 | if (!prediction & label) accFN.add(1L) 68 | } 69 | BinaryConfusionMatrix(accTP.value, accFP.value, accFN.value) 70 | } 71 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/util/BufferedFileWriter.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
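// Usage sketch for the BufferedFileWriter defined below (the path and the string
// written are illustrative; an active SparkContext `sc` is assumed). When `append`
// is true and the file already exists, new content goes to a temporary parts
// directory and is merged with the old file on close().
//
//   val writer = BufferedFileWriter("/tmp/dblink-diagnostics.csv", append = true, sparkContext = sc)
//   writer.write("iteration,logLikelihood\n")
//   writer.flush()
//   writer.close()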
19 | 20 | package com.github.cleanzr.dblink.util 21 | 22 | import java.io.{BufferedWriter, OutputStreamWriter} 23 | 24 | import com.github.cleanzr.dblink.util.BufferedFileWriter._ 25 | import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} 26 | import org.apache.hadoop.util.Progressable 27 | import org.apache.spark.SparkContext 28 | 29 | case class BufferedFileWriter(path: String, 30 | append: Boolean, 31 | sparkContext: SparkContext) { 32 | private val hdfs = FileSystem.get(sparkContext.hadoopConfiguration) 33 | private val file = new Path(path) 34 | private val _progress = new WriterProgress 35 | private var partsDir: Path = _ // temp working dir for appending 36 | 37 | private val outStream = if (hdfs.exists(file) && append) { 38 | /** Hadoop doesn't support append on a ChecksumFilesystem. 39 | * Get around this limitation by writing to a temporary new file and 40 | * merging with the old file on .close() */ 41 | partsDir = new Path(path + "-PARTS") 42 | hdfs.mkdirs(partsDir) // dir for new and old parts 43 | hdfs.rename(file, new Path(partsDir.toString + Path.SEPARATOR + "PART0.csv")) // move old part into dir 44 | hdfs.create(new Path(partsDir.toString + Path.SEPARATOR + "PART1.csv"), true, 1, _progress) 45 | } else { 46 | hdfs.create(file, true, 1, _progress) 47 | } 48 | 49 | private val writer = new OutputStreamWriter(outStream, "UTF-8") 50 | 51 | def close(): Unit = { 52 | writer.close() 53 | if (partsDir != null) { 54 | /** Need to merge new and old parts */ 55 | FileUtil.copyMerge(hdfs, partsDir, hdfs, file, 56 | true, sparkContext.hadoopConfiguration, null) 57 | hdfs.delete(partsDir, true) 58 | } 59 | //hdfs.close() 60 | } 61 | 62 | def flush(): Unit = writer.flush() 63 | 64 | def write(str: String): Unit = writer.write(str, 0, str.length()) 65 | 66 | def progress(): Unit = _progress.progress() 67 | } 68 | 69 | object BufferedFileWriter { 70 | class WriterProgress extends Progressable { 71 | override def progress(): Unit = {} 72 | } 73 | } -------------------------------------------------------------------------------- /examples/RLdata500.conf: -------------------------------------------------------------------------------- 1 | dblink : { 2 | 3 | // Define distortion hyperparameters (to be referenced below) 4 | lowDistortion : {alpha : 0.5, beta : 50.0} 5 | 6 | // Define similarity functions (to be referenced below) 7 | constSimFn : { 8 | name : "ConstantSimilarityFn", 9 | } 10 | 11 | levSimFn : { 12 | name : "LevenshteinSimilarityFn", 13 | parameters : { 14 | threshold : 7.0 15 | maxSimilarity : 10.0 16 | } 17 | } 18 | 19 | data : { 20 | // Path to data files. Must have header row (column names). 
21 | path : "./examples/RLdata500.csv" 22 | 23 | // Specify columns that contain identifiers 24 | recordIdentifier : "rec_id", 25 | // fileIdentifier : null, // not needed since this data set is only a single file 26 | entityIdentifier : "ent_id" // optional 27 | 28 | // String representation of a missing value 29 | nullValue : "NA" 30 | 31 | // Specify properties of the attributes (columns) used for matching 32 | matchingAttributes : [ 33 | {name : "by", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}}, 34 | {name : "bm", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}}, 35 | {name : "bd", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}}, 36 | {name : "fname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}}, 37 | {name : "lname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}} 38 | ] 39 | } 40 | 41 | randomSeed : 319158 42 | expectedMaxClusterSize : 10 43 | 44 | // Specify partitioner 45 | partitioner : { 46 | name : "KDTreePartitioner", 47 | parameters : { 48 | numLevels : 0, // a value of zero means no partitioning 49 | matchingAttributes : [] // cycle through matching attributes in this order when constructing the tree 50 | } 51 | } 52 | 53 | // Path to Markov chain and full state (for resuming MCMC) 54 | outputPath : "./examples/RLdata500_results/" 55 | 56 | // Path to save Spark checkpoints 57 | checkpointPath : "/tmp/spark_checkpoint/" 58 | 59 | // Steps to be performed (in order) 60 | steps : [ 61 | {name : "sample", parameters : { 62 | sampleSize : 100, 63 | burninInterval : 0, 64 | thinningInterval : 10, 65 | resume : false, 66 | sampler : "PCG-I" 67 | }}, 68 | {name : "summarize", parameters : { 69 | lowerIterationCutoff : 0, 70 | quantities : ["cluster-size-distribution"] 71 | }}, 72 | {name : "evaluate", parameters : { 73 | lowerIterationCutoff : 100, 74 | metrics : ["pairwise", "cluster"], 75 | useExistingSMPC : false 76 | }} 77 | ] 78 | } 79 | -------------------------------------------------------------------------------- /examples/RLdata10000.conf: -------------------------------------------------------------------------------- 1 | dblink : { 2 | 3 | // Define distortion hyperparameters (to be referenced below) 4 | lowDistortion : {alpha : 10.0, beta : 1000.0} 5 | 6 | // Define similarity functions (to be referenced below) 7 | constSimFn : { 8 | name : "ConstantSimilarityFn", 9 | } 10 | 11 | levSimFn : { 12 | name : "LevenshteinSimilarityFn", 13 | parameters : { 14 | threshold : 7.0 15 | maxSimilarity : 10.0 16 | } 17 | } 18 | 19 | data : { 20 | // Path to data files. Must have header row (column names). 
21 | path : "./examples/RLdata10000.csv" 22 | 23 | // Specify columns that contain identifiers 24 | recordIdentifier : "rec_id", 25 | // fileIdentifier : null, // not needed since this data set is only a single file 26 | entityIdentifier : "ent_id" // optional 27 | 28 | // String representation of a missing value 29 | nullValue : "NA" 30 | 31 | // Specify properties of the attributes (columns) used for matching 32 | matchingAttributes : [ 33 | {name : "by", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}}, 34 | {name : "bm", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}}, 35 | {name : "bd", similarityFunction : ${dblink.constSimFn}, distortionPrior : ${dblink.lowDistortion}}, 36 | {name : "fname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}}, 37 | {name : "lname_c1", similarityFunction : ${dblink.levSimFn}, distortionPrior : ${dblink.lowDistortion}} 38 | ] 39 | } 40 | 41 | randomSeed : 319158 42 | expectedMaxClusterSize : 10 43 | 44 | // Specify partitioner 45 | partitioner : { 46 | name : "KDTreePartitioner", 47 | parameters : { 48 | numLevels : 1, // a value of zero means no partitioning 49 | matchingAttributes : ["fname_c1"] // cycle through matching attributes in this order when constructing the tree 50 | } 51 | } 52 | 53 | // Path to Markov chain and full state (for resuming MCMC) 54 | outputPath : "./examples/RLdata10000_results/" 55 | 56 | // Path to save Spark checkpoints 57 | checkpointPath : "/tmp/spark_checkpoint/" 58 | 59 | // Steps to be performed (in order) 60 | steps : [ 61 | {name : "sample", parameters : { 62 | sampleSize : 100, 63 | burninInterval : 0, 64 | thinningInterval : 10, 65 | resume : false, 66 | sampler : "PCG-I" 67 | }}, 68 | {name : "summarize", parameters : { 69 | lowerIterationCutoff : 0, 70 | quantities : ["cluster-size-distribution", "partition-sizes"] 71 | }}, 72 | {name : "evaluate", parameters : { 73 | lowerIterationCutoff : 100, 74 | metrics : ["pairwise", "cluster"], 75 | useExistingSMPC : false 76 | }} 77 | ] 78 | } 79 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/SummaryAccumulators.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
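// Usage sketch for the SummaryAccumulators defined below (assumes an active
// SparkContext `sc`; the values added are illustrative). The companion object's
// apply method constructs and registers all four accumulators; reset() clears
// them between iterations.
//
//   val acc = SummaryAccumulators(sc)
//   acc.logLikelihood.add(-123.4)
//   acc.numIsolates.add(1L)
//   acc.reset()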
19 | 
20 | package com.github.cleanzr.dblink
21 | 
22 | import com.github.cleanzr.dblink.accumulators.MapLongAccumulator
23 | 
24 | import org.apache.spark.SparkContext
25 | import org.apache.spark.util.{DoubleAccumulator, LongAccumulator}
26 | 
27 | /** Collection of accumulators used for computing `SummaryVars`
28 |   *
29 |   * @param logLikelihood double accumulator for the (un-normalised) log-likelihood
30 |   * @param numIsolates long accumulator for the number of isolated latent entities
31 |   * @param aggDistortions map accumulator for the number of distortions per
32 |   *                       file and attribute.
33 |   * @param recDistortions map accumulator for the frequency distribution of the total number of distortions per record
34 |   */
35 | case class SummaryAccumulators(logLikelihood: DoubleAccumulator,
36 |                                numIsolates: LongAccumulator,
37 |                                aggDistortions: MapLongAccumulator[(AttributeId, FileId)],
38 |                                recDistortions: MapLongAccumulator[Int]) {
39 |   /** Reset all accumulators */
40 |   def reset(): Unit = {
41 |     logLikelihood.reset()
42 |     numIsolates.reset()
43 |     aggDistortions.reset()
44 |     recDistortions.reset()
45 |   }
46 | }
47 | 
48 | object SummaryAccumulators {
49 |   /** Create a `SummaryAccumulators` object using a given SparkContext.
50 |     *
51 |     * @param sparkContext a SparkContext.
52 |     * @return a `SummaryAccumulators` object.
53 |     */
54 |   def apply(sparkContext: SparkContext): SummaryAccumulators = {
55 |     val logLikelihood = new DoubleAccumulator
56 |     val numIsolates = new LongAccumulator
57 |     val aggDistortions = new MapLongAccumulator[(AttributeId, FileId)]
58 |     val recDistortions = new MapLongAccumulator[Int]
59 |     sparkContext.register(logLikelihood, "log-likelihood (un-normalised)")
60 |     sparkContext.register(numIsolates, "number of isolates")
61 |     sparkContext.register(aggDistortions, "aggregate number of distortions per attribute/file")
62 |     sparkContext.register(recDistortions, "frequency distribution of total record distortion")
63 |     new SummaryAccumulators(logLikelihood, numIsolates, aggDistortions, recDistortions)
64 |   }
65 | }
--------------------------------------------------------------------------------
/src/test/scala/com/github/cleanzr/dblink/random/AliasSamplerTest.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Neil Marchant
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see .
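// Usage sketch of the AliasSampler API exercised by the test below (the weights are
// illustrative and an implicit RandomGenerator must be in scope):
//
//   implicit val rng: RandomGenerator = new MersenneTwister(42)
//   val sampler = AliasSampler(IndexedSeq(0.1, 0.2, 0.7))
//   val idx: Int = sampler.sample()   // 0, 1 or 2, drawn with probability 0.1/0.2/0.7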
19 | 
20 | package com.github.cleanzr.dblink.random
21 | 
22 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}
23 | import org.scalatest.{FlatSpec, Matchers}
24 | 
25 | class AliasSamplerTest extends FlatSpec with Matchers {
26 |   implicit val rand: RandomGenerator = new MersenneTwister(1)
27 | 
28 |   def regularProbs = IndexedSeq(0.1, 0.2, 0.7)
29 |   def extremeProbs = IndexedSeq(0.000000001, 0.000000001, 0.999999999)
30 | 
31 |   def empiricalDistributionIsConsistent(inputProbs: IndexedSeq[Double],
32 |                                         sampler: AliasSampler,
33 |                                         numSamples: Int = 100000000,
34 |                                         tolerance: Double = 1e-4): Boolean = {
35 |     val empiricalProbs = Array.fill(sampler.size)(0.0)
36 |     var i = 0
37 |     while (i < numSamples) {
38 |       empiricalProbs(sampler.sample()) += 1.0/numSamples
39 |       i += 1
40 |     }
41 |     (empiricalProbs, inputProbs).zipped.forall( (pE, pT) => pE === (pT +- tolerance))
42 |   }
43 | 
44 |   behavior of "An Alias sampler"
45 | 
46 |   it should "complain when initialized with a negative weight" in {
47 |     assertThrows[IllegalArgumentException] {
48 |       AliasSampler(Seq(-1.0, 1.0))
49 |     }
50 |   }
51 | 
52 |   it should "complain when initialized with a NaN weight" in {
53 |     assertThrows[IllegalArgumentException] {
54 |       AliasSampler(Seq(0.0/0.0, 1.0))
55 |     }
56 |   }
57 | 
58 |   it should "complain when initialized with an infinite weight" in {
59 |     assertThrows[IllegalArgumentException] {
60 |       AliasSampler(Seq(1.0/0.0, 1.0))
61 |     }
62 |   }
63 | 
64 |   it should "produce an asymptotic empirical distribution that is consistent with the input distribution [0.1, 0.2, 0.7]" in {
65 |     val sampler = AliasSampler(regularProbs)
66 |     assert(empiricalDistributionIsConsistent(regularProbs, sampler))
67 |   }
68 | 
69 |   it should "produce an asymptotic empirical distribution that is consistent with the input distribution [0.000000001, 0.000000001, 0.999999999]" in {
70 |     val sampler = AliasSampler(extremeProbs)
71 |     assert(empiricalDistributionIsConsistent(extremeProbs, sampler))
72 |   }
73 | }
74 | 
--------------------------------------------------------------------------------
/src/main/scala/com/github/cleanzr/dblink/analysis/ClusteringContingencyTable.scala:
--------------------------------------------------------------------------------
1 | // Copyright (C) 2018 Australian Bureau of Statistics
2 | //
3 | // Author: Neil Marchant
4 | //
5 | // This file is part of dblink.
6 | //
7 | // This program is free software: you can redistribute it and/or modify
8 | // it under the terms of the GNU General Public License as published by
9 | // the Free Software Foundation, either version 3 of the License, or
10 | // (at your option) any later version.
11 | //
12 | // This program is distributed in the hope that it will be useful,
13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | // GNU General Public License for more details.
16 | //
17 | // You should have received a copy of the GNU General Public License
18 | // along with this program. If not, see .
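// Usage sketch for the ClusteringContingencyTable defined below (the file paths are
// illustrative; readClustersCSV is the helper from the analysis package object):
//
//   val predicted = readClustersCSV("predicted_clusters.csv")
//   val truth = readClustersCSV("true_clusters.csv")
//   val ct = ClusteringContingencyTable(predicted, truth)
//   ct.table.show()   // sparse rows of (PredictedUID, TrueUID, NumCommonElements)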
19 | 20 | package com.github.cleanzr.dblink.analysis 21 | 22 | import ClusteringContingencyTable.ContingencyTableRow 23 | import com.github.cleanzr.dblink.Cluster 24 | import org.apache.spark.sql.Dataset 25 | import org.apache.spark.storage.StorageLevel 26 | 27 | case class ClusteringContingencyTable(table: Dataset[ContingencyTableRow], 28 | size: Long) { 29 | def persist(newLevel: StorageLevel) { 30 | table.persist(newLevel) 31 | } 32 | 33 | def unpersist() { 34 | table.unpersist() 35 | } 36 | } 37 | 38 | object ClusteringContingencyTable { 39 | case class ContingencyTableRow(PredictedUID: Long, TrueUID: Long, NumCommonElements: Int) 40 | 41 | // Note: this is a sparse representation of the contingency table. Cluster 42 | // pairs which are not present in the table have no common elements. 43 | def apply(predictedClusters: Dataset[Cluster], trueClusters: Dataset[Cluster]): ClusteringContingencyTable = { 44 | val spark = predictedClusters.sparkSession 45 | import spark.implicits._ 46 | // Convert to membership representation 47 | val predictedMembership = predictedClusters.toMembership 48 | val trueMembership = trueClusters.toMembership 49 | 50 | val predictedSize = predictedMembership.count() 51 | val trueSize = trueMembership.count() 52 | 53 | // Ensure that clusterings partition the same set of elements (continue checking after join...) 54 | if (predictedSize != trueSize) throw new Exception("Clusterings do not partition the same set of elements.") 55 | 56 | val joined = predictedMembership.rdd.join(trueMembership.rdd).persist() 57 | val joinedSize = joined.count() 58 | 59 | // Continued checking... 60 | if (trueSize != joinedSize) throw new Exception("Clusterings do not partition the same set of elements.") 61 | 62 | val table = joined.map{case (_, (predictedUID, trueUID)) => ((predictedUID, trueUID), 1)} 63 | .reduceByKey(_ + _) 64 | .map{case ((predictedUID, trueUID), count) => ContingencyTableRow(predictedUID, trueUID, count)} 65 | .toDS() 66 | 67 | joined.unpersist() 68 | 69 | ClusteringContingencyTable(table, trueSize) 70 | } 71 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/util/BufferedRDDWriter.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // Copyright (C) 2018 Neil Marchant 3 | // 4 | // Author: Neil Marchant 5 | // 6 | // This file is part of dblink. 7 | // 8 | // This program is free software: you can redistribute it and/or modify 9 | // it under the terms of the GNU General Public License as published by 10 | // the Free Software Foundation, either version 3 of the License, or 11 | // (at your option) any later version. 12 | // 13 | // This program is distributed in the hope that it will be useful, 14 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | // GNU General Public License for more details. 17 | // 18 | // You should have received a copy of the GNU General Public License 19 | // along with this program. If not, see . 
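// Usage sketch for the BufferedRDDWriter defined below. `SampledState` is a
// hypothetical case class standing in for the real row type: it must have a
// `partitionId` column (required by the partitionBy call in write) and an implicit
// Encoder in scope. The writer is immutable, so each append/flush returns a new
// instance that must be carried forward.
//
//   var writer = BufferedRDDWriter[SampledState](capacity = 10, path = outputPath, append = false)
//   writer = writer.append(sampledStateRDD)   // persists the RDD and adds it to the buffer
//   writer = writer.flush()                   // unions the buffer and writes Parquet to `path`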
20 | 21 | package com.github.cleanzr.dblink.util 22 | 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.spark.sql.SparkSession 25 | import org.apache.spark.sql.Encoder 26 | import org.apache.spark.storage.StorageLevel 27 | import com.github.cleanzr.dblink.Logging 28 | import scala.reflect.ClassTag 29 | 30 | case class BufferedRDDWriter[T : ClassTag : Encoder](path: String, 31 | capacity: Int, 32 | append: Boolean, 33 | rdds: Seq[RDD[T]], 34 | firstFlush: Boolean) extends Logging { 35 | 36 | def append(rows: RDD[T]): BufferedRDDWriter[T] = { 37 | val writer = if (capacity - rdds.size == 0) this.flush() else this 38 | rows.persist(StorageLevel.MEMORY_AND_DISK) 39 | rows.count() // force evaluation 40 | val newRdds = writer.rdds :+ rows 41 | writer.copy(rdds = newRdds) 42 | } 43 | 44 | private def write(unionedRdds: RDD[T], overwrite: Boolean): Unit = { 45 | /** Write to disk in Parquet format relying on Dataset/Dataframe API */ 46 | val spark = SparkSession.builder().getOrCreate() 47 | val unionedDS = spark.createDataset(unionedRdds) 48 | val saveMode = if (overwrite) "overwrite" else "append" 49 | unionedDS.write.partitionBy("partitionId").format("parquet").mode(saveMode).save(path) 50 | } 51 | 52 | def flush(): BufferedRDDWriter[T] = { 53 | if (rdds.isEmpty) return this 54 | 55 | // Combine RDDs in the buffer using a union operation 56 | val sc = rdds.head.sparkContext 57 | val unionedRdds = sc.union(rdds) 58 | 59 | // Write to disk. Note: overwrite if not appending and this is the first write to disk. 60 | write(unionedRdds, firstFlush && !append) 61 | debug(s"Flushed to disk at $path") 62 | 63 | // Unpersist RDDs in buffer as they're no longer required 64 | rdds.foreach(_.unpersist(blocking = false)) 65 | 66 | this.copy(rdds = Seq.empty[RDD[T]], firstFlush = false) 67 | } 68 | } 69 | 70 | object BufferedRDDWriter extends Logging { 71 | def apply[T : ClassTag : Encoder](capacity: Int, path: String, append: Boolean): BufferedRDDWriter[T] = { 72 | val rdds = Seq.empty[RDD[T]] 73 | BufferedRDDWriter[T](path, capacity, append, rdds, firstFlush = true) 74 | } 75 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/random/IndexNonUniformDiscreteDist.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
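// Usage sketch for the IndexNonUniformDiscreteDist defined below (the weights are
// illustrative and need not sum to one; an implicit RandomGenerator must be in scope):
//
//   implicit val rng: RandomGenerator = new MersenneTwister(0)
//   val dist = IndexNonUniformDiscreteDist(Seq(2.0, 1.0, 1.0))
//   dist.probabilityOf(0)   // 0.5 after normalisation
//   dist.sample()           // an index in 0 until 3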
19 | 20 | package com.github.cleanzr.dblink.random 21 | 22 | import org.apache.commons.math3.random.RandomGenerator 23 | 24 | class IndexNonUniformDiscreteDist(weights: Traversable[Double]) extends DiscreteDist[Int] { 25 | require(weights.nonEmpty, "`weights` must be non-empty") 26 | 27 | private val (_probsArray, _totalWeight) = IndexNonUniformDiscreteDist.processWeights(weights) 28 | 29 | override def totalWeight: Double = _totalWeight 30 | 31 | override def numValues: Int = _probsArray.length 32 | 33 | override def values: Traversable[Int] = 0 until numValues 34 | 35 | /** AliasSampler to efficiently sample from the distribution */ 36 | private val sampler = AliasSampler(_probsArray, checkWeights = false, normalized = true) 37 | 38 | // /** Inverse CDF */ 39 | // private def sampleNaive(): Int = { 40 | // var prob = rand.nextDouble() * totalWeight 41 | // var idx = 0 42 | // while (prob > 0.0 && idx < numValues) { 43 | // prob -= weights(idx) 44 | // idx += 1 45 | // } 46 | // if (prob > 0.0) idx else idx - 1 47 | // } 48 | 49 | override def sample()(implicit rand: RandomGenerator): Int = sampler.sample() 50 | 51 | override def probabilityOf(idx: Int): Double = { 52 | if (idx < 0 || idx >= numValues) 0.0 53 | else _probsArray(idx) 54 | } 55 | 56 | override def toIterator: Iterator[(Int, Double)] = { 57 | (0 until numValues).zip(_probsArray).toIterator 58 | } 59 | } 60 | 61 | object IndexNonUniformDiscreteDist { 62 | def apply(weights: Traversable[Double]): IndexNonUniformDiscreteDist = { 63 | new IndexNonUniformDiscreteDist(weights) 64 | } 65 | 66 | private def processWeights(weights: Traversable[Double]): (Array[Double], Double) = { 67 | val probs = Array.ofDim[Double](weights.size) 68 | var totalWeight: Double = 0.0 69 | var i = 0 70 | weights.foreach { weight => 71 | if (weight < 0 || weight.isInfinity || weight.isNaN) { 72 | throw new IllegalArgumentException("invalid weight encountered") 73 | } 74 | probs(i) = weight 75 | totalWeight += weight 76 | i += 1 77 | } 78 | if (totalWeight.isInfinity) throw new IllegalArgumentException("total weight is not finite") 79 | if (totalWeight == 0.0) throw new IllegalArgumentException("zero probability mass") 80 | i = 0 81 | while (i < probs.length) { 82 | probs(i) = probs(i)/totalWeight 83 | i += 1 84 | } 85 | (probs, totalWeight) 86 | } 87 | } -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/SimilarityFnTest.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.SimilarityFn._ 23 | import org.scalatest.FlatSpec 24 | 25 | class SimilarityFnTest extends FlatSpec { 26 | behavior of "A constant similarity function" 27 | 28 | it should "return the same (constant) value for the max, min, threshold similarities" in { 29 | assert(ConstantSimilarityFn.maxSimilarity === ConstantSimilarityFn.minSimilarity) 30 | assert(ConstantSimilarityFn.maxSimilarity === ConstantSimilarityFn.threshold) 31 | } 32 | 33 | it should "return the max (constant) similarity for identical values" in { 34 | assert(ConstantSimilarityFn.getSimilarity("TestValue", "TestValue") === ConstantSimilarityFn.maxSimilarity) 35 | } 36 | 37 | it should "return the max (constant) similarity for distinct values" in { 38 | assert(ConstantSimilarityFn.getSimilarity("TestValue1", "TestValue2") === ConstantSimilarityFn.maxSimilarity) 39 | } 40 | 41 | def thresSimFn: LevenshteinSimilarityFn = LevenshteinSimilarityFn(5.0, 10.0) 42 | def noThresSimFn: LevenshteinSimilarityFn = LevenshteinSimilarityFn(0.0, 10.0) 43 | 44 | behavior of "A Levenshtein similarity function (maxSimilarity=10.0 and threshold=5.0)" 45 | 46 | it should "return the max similarity for an identical pair of non-empty strings" in { 47 | assert(thresSimFn.getSimilarity("John Smith", "John Smith") === thresSimFn.maxSimilarity) 48 | } 49 | 50 | it should "return the max similarity for a pair of empty strings" in { 51 | assert(thresSimFn.getSimilarity("", "") === thresSimFn.maxSimilarity) 52 | } 53 | 54 | it should "return the min similarity for an empty and non-empty string" in { 55 | assert(thresSimFn.getSimilarity("", "John Smith") === thresSimFn.minSimilarity) 56 | } 57 | 58 | it should "be symmetric in its arguments" in { 59 | assert(thresSimFn.getSimilarity("Jane Smith", "John Smith") === thresSimFn.getSimilarity("John Smith", "Jane Smith")) 60 | } 61 | 62 | it should "return a similarity of 2.0 for the strings 'AB' and 'BB'" in { 63 | assert(thresSimFn.getSimilarity("AB", "BB") === 2.0) 64 | } 65 | 66 | behavior of "A Levenshtein similarity function with no threshold (maxSimilarity=10.0)" 67 | 68 | it should "return the same value for the min, threshold similarities" in { 69 | assert(noThresSimFn.threshold === noThresSimFn.minSimilarity) 70 | } 71 | 72 | it should "return a similarity of 6.0 for the strings 'AB' and 'BB'" in { 73 | assert(noThresSimFn.getSimilarity("AB", "BB") === 6.0) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/partitioning/LPTScheduler.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
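// Usage sketch for the LPTScheduler defined below (job ids and runtimes are
// illustrative). Jobs are assigned greedily in decreasing order of runtime so that
// the total runtime per partition is roughly balanced.
//
//   val jobs = IndexedSeq(("a", 5.0), ("b", 3.0), ("c", 2.0), ("d", 2.0))
//   val scheduler = new LPTScheduler(jobs, numPartitions = 2)
//   scheduler.getJobs(0)           // jobs assigned to partition 0
//   scheduler.getSize(0)           // total runtime assigned to partition 0
//   scheduler.getPartitionId("b")  // id of the partition containing job "b"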
19 | 20 | package com.github.cleanzr.dblink.partitioning 21 | 22 | import LPTScheduler._ 23 | 24 | import scala.Numeric.Implicits._ 25 | import scala.Ordering.Implicits._ 26 | import scala.collection.mutable 27 | import scala.collection.mutable.ArrayBuffer 28 | import scala.reflect.ClassTag 29 | 30 | /** 31 | * Implementation of the longest processing time (LPT) algorithm for 32 | * partitioning jobs across multiple processors 33 | * @param jobs 34 | * @param numPartitions 35 | * @tparam J job index 36 | * @tparam T job runtime 37 | */ 38 | class LPTScheduler[J, T : ClassTag : Numeric](jobs: IndexedSeq[(J,T)], 39 | val numPartitions: Int) extends Serializable { 40 | 41 | val numJobs: Int = jobs.size 42 | private val partitions: Array[Partition[J, T]] = partitionJobs[J, T](jobs, numPartitions) 43 | val partitionSizes: Array[T] = partitions.iterator.map[T](_.size).toArray[T] 44 | private val partitionsIndex = indexPartitions(partitions) 45 | 46 | def getJobs(partitionId: Int): Seq[J] = partitions(partitionId).jobs 47 | 48 | def getSize(partitionId: Int): T = partitionSizes(partitionId) 49 | 50 | def getPartitionId(jobId: J): Int = partitionsIndex(jobId) 51 | // TODO: deal with non-existent jobId 52 | } 53 | 54 | object LPTScheduler { 55 | case class Partition[J, T : ClassTag : Numeric](var size: T, jobs: ArrayBuffer[J]) 56 | 57 | private def partitionJobs[J, T : ClassTag : Numeric](jobs: IndexedSeq[(J, T)], 58 | numPartitions: Int): Array[Partition[J, T]] = { 59 | val sortedJobs = jobs.sortBy(-_._2) 60 | 61 | // Initialise array of Partitions 62 | val partitions = Array.fill[Partition[J, T]](numPartitions) { 63 | Partition(implicitly[Numeric[T]].zero, ArrayBuffer[J]()) 64 | } 65 | 66 | // Loop through jobs, putting each in the partition with the smallest size so far 67 | sortedJobs.foreach {j => 68 | val (_, smallestId) = partitions.zipWithIndex.foldLeft ((partitions.head, 0) ) { 69 | (xO, xN) => if (xO._1.size > xN._1.size) xN else xO 70 | } 71 | partitions(smallestId).size += j._2 72 | partitions(smallestId).jobs += j._1 73 | } 74 | 75 | partitions 76 | } 77 | 78 | private def indexPartitions[J, T : Numeric](partitions: Array[Partition[J,T]]): Map[J, Int] = { 79 | val bIndex = Map.newBuilder[J, Int] 80 | partitions.zipWithIndex.foreach { case (partition, idx) => 81 | partition.jobs.foreach(item => bIndex += (item -> idx)) 82 | } 83 | bIndex.result() 84 | } 85 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/analysis/ClusteringMetrics.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
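// Usage sketch for the ClusteringMetrics defined below (the two Dataset[Cluster]
// values are assumed to come from elsewhere, e.g. the helpers in the analysis
// package object):
//
//   val metrics = ClusteringMetrics(predictedClusters, trueClusters)
//   metrics.print()                  // prints the adjusted Rand index
//   val ari = metrics.adjRandIndex   // or use the value directly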
19 | 20 | package com.github.cleanzr.dblink.analysis 21 | 22 | import com.github.cleanzr.dblink.Cluster 23 | import org.apache.commons.math3.util.CombinatoricsUtils.binomialCoefficient 24 | import org.apache.spark.sql.Dataset 25 | import org.apache.spark.storage.StorageLevel 26 | 27 | case class ClusteringMetrics(adjRandIndex: Double) { 28 | def mkString: String = { 29 | "=====================================\n" + 30 | " Cluster metrics \n" + 31 | "-------------------------------------\n" + 32 | s" Adj. Rand index: $adjRandIndex\n" + 33 | "=====================================\n" 34 | } 35 | 36 | def print(): Unit = { 37 | Console.print(mkString) 38 | } 39 | } 40 | 41 | object ClusteringMetrics { 42 | private def comb2(x: Int): Long = if (x >= 2) binomialCoefficient(x,2) else 0 43 | 44 | def AdjustedRandIndex(contingencyTable: ClusteringContingencyTable): Double = { 45 | // Compute sum_{PredictedUID} comb2(sum_{TrueUID} NumCommonElements(PredictedUID, TrueUID)) 46 | val predCombSum = contingencyTable.table.rdd 47 | .map(row => (row.PredictedUID, row.NumCommonElements)) 48 | .reduceByKey(_ + _) // sum over True UID 49 | .aggregate(0L)( 50 | seqOp = (sum, x) => sum + comb2(x._2), 51 | combOp = _ + _ // apply comb2(.) and sum over PredictedUID 52 | ) 53 | 54 | // Compute sum_{TrueUID} comb2(sum_{PredictedUID} NumCommonElements(PredictedUID, TrueUID)) 55 | val trueCombSum = contingencyTable.table.rdd 56 | .map(row => (row.TrueUID, row.NumCommonElements)) 57 | .reduceByKey(_ + _) // sum over Pred UID 58 | .aggregate(0L)( 59 | seqOp = (sum, x) => sum + comb2(x._2), 60 | combOp = _ + _ // apply comb2(.) and sum over True UID 61 | ) 62 | 63 | // Compute sum_{TrueUID} sum_{PredictedUID} comb2(NumCommonElements(PredictedUID, TrueUID)) 64 | val totalCombSum = contingencyTable.table.rdd 65 | .aggregate(0L)( 66 | seqOp = (sum, x) => sum + comb2(x.NumCommonElements), 67 | combOp = _ + _ // apply comb2(.) and sum over PredictedUID & TrueUID 68 | ) 69 | 70 | // Return adjusted Rand index 71 | val expectedIndex = predCombSum.toDouble * trueCombSum / comb2(contingencyTable.size.toInt) 72 | val maxIndex = (predCombSum.toDouble + trueCombSum) / 2.0 73 | (totalCombSum - expectedIndex)/(maxIndex - expectedIndex) 74 | } 75 | 76 | def apply(predictedClusters: Dataset[Cluster], 77 | trueClusters: Dataset[Cluster]): ClusteringMetrics = { 78 | val contingencyTable = ClusteringContingencyTable(predictedClusters, trueClusters) 79 | contingencyTable.persist(StorageLevel.MEMORY_ONLY) 80 | val adjRandIndex = AdjustedRandIndex(contingencyTable) 81 | contingencyTable.unpersist() 82 | ClusteringMetrics(adjRandIndex) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/ProjectSteps.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 
16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import ProjectStep.{CopyFilesStep, EvaluateStep, SampleStep, SummarizeStep} 23 | import com.typesafe.config.{Config, ConfigException} 24 | import Project.toConfigTraversable 25 | 26 | import scala.collection.JavaConverters._ 27 | import scala.collection.mutable 28 | import scala.util.Try 29 | 30 | class ProjectSteps(config: Config, project: Project) { 31 | 32 | val steps: Traversable[ProjectStep] = ProjectSteps.parseSteps(config, project) 33 | 34 | def execute(): Unit = { 35 | steps.foreach(_.execute()) 36 | } 37 | 38 | def mkString: String = { 39 | val lines = mutable.ArrayBuffer.empty[String] 40 | lines += "Scheduled steps" 41 | lines += "---------------" 42 | lines ++= steps.map(step => " * " + step.mkString) 43 | 44 | lines.mkString("\n") 45 | } 46 | } 47 | 48 | object ProjectSteps { 49 | def apply(config: Config, project: Project): ProjectSteps = { 50 | new ProjectSteps(config, project) 51 | } 52 | 53 | private def parseSteps(config: Config, project: Project): Traversable[ProjectStep] = { 54 | config.getObjectList("dblink.steps").toTraversable.map { step => 55 | step.getString("name") match { 56 | case "sample" => 57 | new SampleStep(project, 58 | sampleSize = step.getInt("parameters.sampleSize"), 59 | burninInterval = Try {step.getInt("parameters.burninInterval")} getOrElse 0, 60 | thinningInterval = Try {step.getInt("parameters.thinningInterval")} getOrElse 1, 61 | resume = Try {step.getBoolean("parameters.resume")} getOrElse true, 62 | sampler = Try {step.getString("parameters.sampler")} getOrElse "PCG-I") 63 | case "evaluate" => 64 | new EvaluateStep(project, 65 | lowerIterationCutoff = Try {step.getInt("parameters.lowerIterationCutoff")} getOrElse 0, 66 | metrics = step.getStringList("parameters.metrics").asScala, 67 | useExistingSMPC = Try {step.getBoolean("parameters.useExistingSMPC")} getOrElse false 68 | ) 69 | case "summarize" => 70 | new SummarizeStep(project, 71 | lowerIterationCutoff = Try {step.getInt("parameters.lowerIterationCutoff")} getOrElse 0, 72 | quantities = step.getStringList("parameters.quantities").asScala 73 | ) 74 | case "copy-files" => 75 | new CopyFilesStep(project, 76 | fileNames = step.getStringList("parameters.fileNames").asScala, 77 | destinationPath = step.getString("parameters.destinationPath"), 78 | overwrite = Try {step.getBoolean("parameters.overwrite")} getOrElse false, 79 | deleteSource = Try {step.getBoolean("parameters.deleteSource")} getOrElse false 80 | ) 81 | case _ => throw new ConfigException.BadValue(config.origin(), "name", "unsupported step") 82 | } 83 | } 84 | } 85 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/DiagnosticsWriter.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 
11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.util.BufferedFileWriter 23 | import org.apache.spark.SparkContext 24 | 25 | /** Buffered writer for diagnostics along the Markov chain 26 | * 27 | * Output is in CSV format 28 | * 29 | * @param path path to save diagnostics (can be a HDFS or local path) 30 | * @param continueChain whether to append to existing diagnostics (if present) 31 | */ 32 | class DiagnosticsWriter(path: String, continueChain: Boolean) 33 | (implicit sparkContext: SparkContext) extends Logging { 34 | private val writer = BufferedFileWriter(path, continueChain, sparkContext) 35 | private var firstWrite = true 36 | info(s"Writing diagnostics along chain to $path.") 37 | 38 | /** Write header for CSV */ 39 | private def writeHeader(state: State): Unit = { 40 | val recordsCache = state.bcRecordsCache.value 41 | val aggDistortionsHeaders = recordsCache.indexedAttributes.map(x => s"aggDist-${x.name}").mkString(",") 42 | val recDistortionsHeaders = (0 to recordsCache.numAttributes).map(k => s"recDistortion-$k").mkString(",") 43 | writer.write(s"iteration,systemTime-ms,numObservedEntities,logLikelihood,popSize,$aggDistortionsHeaders,$recDistortionsHeaders\n") 44 | } 45 | 46 | /** Write a row of diagnostics for the given state */ 47 | def writeRow(state: State): Unit = { 48 | if (firstWrite && !continueChain) writeHeader(state); firstWrite = false 49 | 50 | // Get number of attributes 51 | val numAttributes = state.bcRecordsCache.value.numAttributes 52 | 53 | // Aggregate number of distortions for each attribute (sum over fileId) 54 | val aggAttrDistortions = state.summaryVars.aggDistortions 55 | .groupBy(_._1._1) // group by attrId 56 | .mapValues(_.values.sum) // sum over fileId 57 | 58 | // Convenience variable 59 | val recDistortions = state.summaryVars.recDistortions 60 | 61 | // Build row of string values matching header 62 | val row: Iterator[String] = Iterator( 63 | state.iteration.toString, // iteration 64 | System.currentTimeMillis().toString, // systemTime-ms 65 | (state.populationSize - state.summaryVars.numIsolates).toString, // numObservedEntities 66 | f"${state.summaryVars.logLikelihood}%.9e", // logLikelihood 67 | state.populationSize.toString) ++ // populationSize 68 | (0 until numAttributes).map(aggAttrDistortions.getOrElse(_, 0L).toString) ++ // aggDist-* 69 | (0 to numAttributes).map(recDistortions.getOrElse(_, 0L).toString) // recDistortion-* 70 | 71 | writer.write(row.mkString(",") + "\n") 72 | } 73 | 74 | def close(): Unit = writer.close() 75 | 76 | def flush(): Unit = writer.flush() 77 | 78 | /** Need to call this occasionally to keep the writer alive */ 79 | def progress(): Unit = writer.progress() 80 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/util/PeriodicRDDCheckpointer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cleanzr.dblink.util 19 | 20 | import org.apache.spark.SparkContext 21 | import org.apache.spark.rdd.RDD 22 | import org.apache.spark.storage.StorageLevel 23 | 24 | /** 25 | * This class helps with persisting and checkpointing RDDs. 26 | * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as 27 | * unpersisting and removing checkpoint files. 28 | * 29 | * Users should call update() when a new RDD has been created, 30 | * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are 31 | * responsible for materializing the RDD to ensure that persisting and checkpointing actually 32 | * occur. 33 | * 34 | * When update() is called, this does the following: 35 | * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. 36 | * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. 37 | * - If using checkpointing and the checkpoint interval has been reached, 38 | * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. 39 | * - Remove older checkpoints. 40 | * 41 | * WARNINGS: 42 | * - This class should NOT be copied (since copies may conflict on which RDDs should be 43 | * checkpointed). 44 | * - This class removes checkpoint files once later RDDs have been checkpointed. 45 | * However, references to the older RDDs will still return isCheckpointed = true. 46 | * 47 | * Example usage: 48 | * {{{ 49 | * val (rdd1, rdd2, rdd3, ...) = ... 
50 | * val cp = new PeriodicRDDCheckpointer(2, sc) 51 | * cp.update(rdd1) 52 | * rdd1.count(); 53 | * // persisted: rdd1 54 | * cp.update(rdd2) 55 | * rdd2.count(); 56 | * // persisted: rdd1, rdd2 57 | * // checkpointed: rdd2 58 | * cp.update(rdd3) 59 | * rdd3.count(); 60 | * // persisted: rdd1, rdd2, rdd3 61 | * // checkpointed: rdd2 62 | * cp.update(rdd4) 63 | * rdd4.count(); 64 | * // persisted: rdd2, rdd3, rdd4 65 | * // checkpointed: rdd4 66 | * cp.update(rdd5) 67 | * rdd5.count(); 68 | * // persisted: rdd3, rdd4, rdd5 69 | * // checkpointed: rdd4 70 | * }}} 71 | * 72 | * @param checkpointInterval RDDs will be checkpointed at this interval 73 | * @tparam T RDD element type 74 | */ 75 | class PeriodicRDDCheckpointer[T](checkpointInterval: Int, 76 | sc: SparkContext) 77 | extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { 78 | 79 | override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() 80 | 81 | override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed 82 | 83 | override protected def persist(data: RDD[T]): Unit = { 84 | if (data.getStorageLevel == StorageLevel.NONE) { 85 | //data.persist(StorageLevel.MEMORY_ONLY) 86 | } 87 | } 88 | 89 | override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) 90 | 91 | override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { 92 | data.getCheckpointFile.map(x => x) 93 | } 94 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/analysis/package.scala: -------------------------------------------------------------------------------- 1 | package com.github.cleanzr.dblink 2 | 3 | import org.apache.hadoop.fs.{FileSystem, Path} 4 | import org.apache.spark.sql.{Dataset, SparkSession} 5 | 6 | import scala.reflect.ClassTag 7 | 8 | package object analysis { 9 | 10 | /** 11 | * TODO 12 | * @param links 13 | * @return 14 | */ 15 | def canonicalizePairwiseLinks(links: Dataset[RecordPair]): Dataset[RecordPair] = { 16 | val spark = links.sparkSession 17 | import spark.implicits._ 18 | 19 | links.map { recIds => 20 | // Ensure pairs are sorted 21 | recIds._1.compareTo(recIds._2) match { 22 | case x if x < 0 => (recIds._1, recIds._2) 23 | case x if x > 0 => (recIds._2, recIds._1) 24 | case 0 => throw new Exception(s"Invalid link: ${recIds._1} <-> ${recIds._2}.") 25 | } 26 | }.distinct() 27 | } 28 | 29 | 30 | 31 | /** 32 | * TODO 33 | * @param path 34 | * @return 35 | */ 36 | def readClustersCSV(path: String): Dataset[Cluster] = { 37 | val spark = SparkSession.builder.getOrCreate() 38 | val sc = spark.sparkContext 39 | import spark.implicits._ 40 | sc.textFile(path) 41 | .map(line => line.split(",").map(_.trim).toSet).toDS() 42 | } 43 | 44 | 45 | 46 | /** 47 | * TODO 48 | * @param membership 49 | * @tparam T 50 | * @return 51 | */ 52 | def membershipToClusters[T : ClassTag](membership: Dataset[(RecordId, T)]): Dataset[Cluster] = { 53 | val spark = membership.sparkSession 54 | import spark.implicits._ 55 | membership.rdd 56 | .map(_.swap) 57 | .aggregateByKey(Set.empty[RecordId])( 58 | seqOp = (recIds, recId) => recIds + recId, 59 | combOp = (partA, partB) => partA union partB 60 | ) 61 | .map(_._2) 62 | .toDS() 63 | } 64 | 65 | 66 | 67 | /* 68 | These private methods must appear outside the Clusters implicit value class due to a limitation of the compiler 69 | */ 70 | private def _toPairwiseLinks(clusters: Dataset[Cluster]): Dataset[RecordPair] = { 71 | val spark = clusters.sparkSession 72 | 
import spark.implicits._ 73 | val links = clusters.flatMap(_.toSeq.combinations(2).map(x => (x(0), x(1)))) 74 | canonicalizePairwiseLinks(links) 75 | } 76 | 77 | private def _toMembership(clusters: Dataset[Cluster]): Dataset[(RecordId, EntityId)] = { 78 | val spark = clusters.sparkSession 79 | import spark.implicits._ 80 | clusters.rdd 81 | .zipWithUniqueId() 82 | .flatMap { case (recIds, entityId) => recIds.iterator.map(recId => (recId, entityId.toInt)) } 83 | .toDS() 84 | } 85 | 86 | /** 87 | * Represents a clustering of records as sets of record ids. Provides methods for converting to pairwise and 88 | * membership representations. 89 | * 90 | * @param ds A Dataset of record clusters 91 | */ 92 | implicit class Clusters(val ds: Dataset[Cluster]) extends AnyVal { 93 | /** 94 | * Save to a CSV file 95 | * 96 | * @param path Path to the file. May be a path on the local filesystem or HDFS. 97 | * @param overwrite Whether to overwrite an existing file or not. 98 | */ 99 | def saveCsv(path: String, overwrite: Boolean = true): Unit = { 100 | val sc = ds.sparkSession.sparkContext 101 | val hdfs = FileSystem.get(sc.hadoopConfiguration) 102 | val file = new Path(path) 103 | if (hdfs.exists(file)) { 104 | if (!overwrite) return else hdfs.delete(file, true) 105 | } 106 | ds.rdd.map(cluster => cluster.mkString(", ")) 107 | .saveAsTextFile(path) 108 | } 109 | 110 | /** 111 | * Convert to a pairwise representation 112 | */ 113 | def toPairwiseLinks: Dataset[RecordPair] = _toPairwiseLinks(ds) 114 | 115 | /** 116 | * Convert to a membership vector representation 117 | */ 118 | def toMembership: Dataset[(RecordId, EntityId)] = _toMembership(ds) 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/random/NonUniformDiscreteDist.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
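// Usage sketch for the NonUniformDiscreteDist defined below (values and weights are
// illustrative; weights are normalised internally and an implicit RandomGenerator
// must be in scope):
//
//   implicit val rng: RandomGenerator = new MersenneTwister(0)
//   val dist = NonUniformDiscreteDist(Map("cat" -> 1.0, "dog" -> 3.0))
//   dist.probabilityOf("dog")   // 0.75
//   dist.sample()               // "cat" or "dog"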
19 | 20 | package com.github.cleanzr.dblink.random 21 | 22 | import org.apache.commons.math3.random.RandomGenerator 23 | 24 | import scala.collection.Map 25 | import scala.reflect.ClassTag 26 | 27 | /** A non-uniform distribution over a discrete set of values 28 | * 29 | * @param valuesWeights map from values to weights (need not be normalised) 30 | * @param rand pseudo-random number generator 31 | * @param ev 32 | * @tparam T type of the values 33 | */ 34 | class NonUniformDiscreteDist[T](valuesWeights: Map[T, Double]) 35 | (implicit ev: ClassTag[T]) extends DiscreteDist[T] { 36 | require(valuesWeights.nonEmpty, "`valuesWeights` must be non-empty") 37 | 38 | private val (_valuesArray, _probsArray, _totalWeight) = NonUniformDiscreteDist.processValuesWeights(valuesWeights) 39 | 40 | override def totalWeight: Double = _totalWeight 41 | 42 | override def values: Traversable[T] = _valuesArray.toTraversable 43 | 44 | override val numValues: Int = valuesWeights.size 45 | 46 | /** AliasSampler to efficiently sample from the distribution */ 47 | private val sampler = AliasSampler(_probsArray, checkWeights = false, normalized = true) 48 | 49 | // /** Inverse CDF */ 50 | // private def sampleNaive(): T = { 51 | // var prob = rand.nextDouble() * totalWeight 52 | // val it = valuesWeights.iterator 53 | // var vw = it.next() 54 | // while (prob > 0.0 && it.hasNext) { 55 | // vw = it.next() 56 | // prob -= vw._2 57 | // } 58 | // vw._1 59 | // } 60 | 61 | override def sample()(implicit rand: RandomGenerator): T = _valuesArray(sampler.sample()) 62 | 63 | override def probabilityOf(value: T): Double = 64 | valuesWeights.getOrElse(value, 0.0)/totalWeight 65 | 66 | override def toIterator: Iterator[(T, Double)] = { 67 | (_valuesArray,_probsArray).zipped.toIterator 68 | } 69 | } 70 | 71 | object NonUniformDiscreteDist { 72 | def apply[T](valuesAndWeights: Map[T, Double]) 73 | (implicit ev: ClassTag[T]): NonUniformDiscreteDist[T] = { 74 | new NonUniformDiscreteDist[T](valuesAndWeights) 75 | } 76 | 77 | private def processValuesWeights[T : ClassTag](valuesWeights: Map[T, Double]): (Array[T], Array[Double], Double) = { 78 | val values = Array.ofDim[T](valuesWeights.size) 79 | val probs = Array.ofDim[Double](valuesWeights.size) 80 | val it = valuesWeights.iterator 81 | var totalWeight: Double = 0.0 82 | var i = 0 83 | valuesWeights.foreach { pair => 84 | val weight = pair._2 85 | values(i) = pair._1 86 | probs(i) = weight 87 | if (weight < 0 || weight.isInfinity || weight.isNaN) { 88 | throw new IllegalArgumentException("invalid weight encountered") 89 | } 90 | totalWeight += weight 91 | i += 1 92 | } 93 | if (totalWeight.isInfinity) throw new IllegalArgumentException("total weight is not finite") 94 | if (totalWeight == 0.0) throw new IllegalArgumentException("zero probability mass") 95 | i = 0 96 | while (i < probs.length) { 97 | probs(i) = probs(i)/totalWeight 98 | i += 1 99 | } 100 | (values, probs, totalWeight) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/partitioning/DomainSplitter.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 
6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.partitioning 21 | 22 | import scala.collection.mutable 23 | import scala.math.abs 24 | 25 | /** Splits a domain of discrete weighted values into two partitions of roughly 26 | * equal weight. 27 | * 28 | * @tparam T value type 29 | */ 30 | abstract class DomainSplitter[T] { 31 | /** Quality of the split (1.0 is best, 0.0 is worst) */ 32 | val splitQuality: Double 33 | 34 | /** Returns whether value x is in the left (false) or right (true) set 35 | * 36 | * @param x value 37 | * @return 38 | */ 39 | def apply(x: T): Boolean 40 | } 41 | 42 | object DomainSplitter { 43 | def apply[T : Ordering](domain: Array[(T, Double)]): DomainSplitter[T] = { 44 | /** Use the set splitter for small domains */ 45 | if (domain.length <= 30) new LPTDomainSplitter[T](domain) 46 | else new RanDomainSplitter[T](domain) 47 | } 48 | 49 | /** Range splitter. 50 | * 51 | * Splits the domain by sorting the values, then splitting at the (weighted) 52 | * median. 53 | * 54 | * @param domain value-weight pairs for the domain 55 | * @tparam T type of the values in the domain 56 | */ 57 | class RanDomainSplitter[T : Ordering](domain: Array[(T, Double)]) extends DomainSplitter[T] with Serializable { 58 | private val halfWeight = domain.iterator.map(_._2).sum / 2.0 59 | 60 | private val (splitWeight, splitValue) = { 61 | val numCandidates = domain.length 62 | val ordered = domain.sortBy(_._1) 63 | var cumWeight = 0.0 64 | var i = 0 65 | while (cumWeight <= halfWeight && i < numCandidates - 1) { 66 | cumWeight += ordered(i)._2 67 | i += 1 68 | } 69 | (cumWeight, ordered(i)._1) 70 | } 71 | 72 | override val splitQuality: Double = 1.0 - abs(splitWeight - halfWeight)/halfWeight 73 | 74 | override def apply(x: T): Boolean = Ordering[T].gt(x, splitValue) 75 | } 76 | 77 | /** LPT splitter. 78 | * 79 | * Splits the domain using the longest processing time (LPT) algorithm with 80 | * only two "processors" (buckets). This is not space efficient for large 81 | * domains, roughly half of the domain values must be stored internally. 
82 | * 83 | * @param domain value-weight pairs for the domain 84 | * @tparam T type of the values in the domain 85 | */ 86 | class LPTDomainSplitter[T](domain: Array[(T, Double)]) extends DomainSplitter[T] with Serializable { 87 | private val halfWeight = domain.iterator.map(_._2).sum / 2.0 88 | 89 | private val rightSet = mutable.Set.empty[T] 90 | 91 | private val splitWeight = { 92 | val ordered = domain.sortBy(-_._2) // decreasing order of weight 93 | var leftWeight = 0.0 94 | var rightWeight = 0.0 95 | ordered.foreach { case (value, weight) => 96 | if (leftWeight >= rightWeight) { 97 | rightSet.add(value) 98 | rightWeight += weight 99 | } else { 100 | // effectively putting value in the left set 101 | leftWeight += weight 102 | } 103 | } 104 | leftWeight 105 | } 106 | 107 | override val splitQuality: Double = 1.0 - abs(splitWeight - halfWeight)/halfWeight 108 | 109 | override def apply(x: T): Boolean = rightSet.contains(x) 110 | } 111 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dblink: Distributed End-to-End Bayesian Entity Resolution 2 | `dblink` is a Spark package for performing unsupervised entity resolution 3 | (ER) on structured data. 4 | It's based on a Bayesian model called `blink` 5 | [(Steorts, 2015)](https://projecteuclid.org/euclid.ba/1441790411), 6 | with extensions proposed in 7 | [(Marchant et al., 2021)](https://doi.org/10.1080/10618600.2020.1825451). 8 | Unlike many ER algorithms, `dblink` approximates the full posterior 9 | distribution over clusterings of records (into entities). 10 | This facilitates propagation of uncertainty to post-ER analysis, 11 | and provides a framework for answering probabilistic queries about entity 12 | membership. 13 | 14 | `dblink` approximates the posterior using Markov chain Monte Carlo. 15 | It writes samples (of clustering configurations) to disk in Parquet format. 16 | Diagnostic summary statistics are also written to disk in CSV format—these are 17 | useful for assessing convergence of the Markov chain. 18 | 19 | ## Documentation 20 | The step-by-step [guide](docs/guide.md) includes information about 21 | building dblink from source and running it locally on a test data set. 22 | Further details about configuration options for dblink is provided 23 | [here](docs/configuration.md). 24 | 25 | ## Example: RLdata 26 | Two synthetic data sets RLdata500 and RLdata10000 are included in the examples 27 | directory as CSV files. 28 | These data sets were extracted from the [RecordLinkage](https://cran.r-project.org/web/packages/RecordLinkage/index.html) 29 | R package and have been used as benchmark data sets in the entity resolution 30 | literature. 31 | Both contain 10 percent duplicates and are non-trivial to link due to added 32 | distortion. 33 | Standard entity resolution metrics can be computed as unique ids are provided 34 | in the files. 35 | Config files for these data sets are included in the examples directory: 36 | see `RLdata500.conf` and `RLdata10000.conf`. 
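Returning to the `DomainSplitter` shown earlier: the factory method chooses the LPT (set-based) splitter for domains of at most 30 values and the range splitter otherwise. The sketch below, using a small made-up weighted domain, shows how a split is obtained and inspected.

```scala
import com.github.cleanzr.dblink.partitioning.DomainSplitter

object DomainSplitterExample extends App {
  // A small weighted domain: with 5 values the factory picks the LPT splitter
  val domain = Array("NSW" -> 7.86, "VIC" -> 6.32, "QLD" -> 4.92, "WA" -> 2.58, "SA" -> 1.72)

  val splitter = DomainSplitter(domain)

  // 1.0 is a perfect split, i.e. both sides carry exactly half of the total weight
  println(s"Split quality: ${splitter.splitQuality}")

  // `apply` reports which side a value falls on: false = left, true = right
  domain.foreach { case (value, _) =>
    println(s"$value -> ${if (splitter(value)) "right" else "left"}")
  }
}
```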
37 | To run these examples locally (in Spark pseudocluster mode), 38 | ensure you've built or obtained the JAR according to the instructions 39 | above, then change into the source code directory and run the following 40 | command: 41 | ```bash 42 | $SPARK_HOME/bin/spark-submit \ 43 | --master "local[*]" \ 44 | --conf "spark.driver.extraJavaOptions=-Dlog4j.configuration=log4j.properties" \ 45 | --conf "spark.driver.extraClassPath=./target/scala-2.11/dblink-assembly-0.2.0.jar" \ 46 | ./target/scala-2.11/dblink-assembly-0.2.0.jar \ 47 | ./examples/RLdata500.conf 48 | ``` 49 | (To run with RLdata10000 instead, replace `RLdata500.conf` with 50 | `RLdata10000.conf`.) 51 | Note that the config file specifies that output will be saved in 52 | the `./examples/RLdata500_results/` (or `./examples/RLdata10000_results`) 53 | directory. 54 | 55 | ## How to: Add dblink as a project dependency 56 | _Note: This won't work yet. Waiting for project to be accepted._ 57 | 58 | Maven: 59 | ```xml 60 | 61 | com.github.cleanzr 62 | dblink 63 | 0.2.0 64 | 65 | ``` 66 | 67 | sbt: 68 | ```scala 69 | libraryDependencies += "com.github.cleanzr" % "dblink" % "0.2.0" 70 | ``` 71 | 72 | ## How to: Build a fat JAR 73 | You can build a fat JAR using sbt by running the following command from 74 | within the project directory: 75 | ```bash 76 | $ sbt assembly 77 | ``` 78 | 79 | This should output a JAR file at `./target/scala-2.11/dblink-assembly-0.2.0.jar` 80 | relative to the project directory. 81 | Note that the JAR file does not bundle Spark or Hadoop, but it does include 82 | all other dependencies. 83 | 84 | ## Contact 85 | If you encounter problems, please [open an issue](https://github.com/ngmarchant/dblink/issues) 86 | on GitHub. 87 | You can also contact the main developer by email ` gmail.com` 88 | 89 | ## License 90 | GPL-3 91 | 92 | ## Citing the package 93 | 94 | Marchant, N. G., Kaplan, A., Elazar, D. N., Rubinstein, B. I. P. and 95 | Steorts, R. C. (2021). d-blink: Distributed End-to-End Bayesian Entity 96 | Resolution. _Journal of Computational and Graphical Statistics_, _30_(2), 97 | 406–421. DOI: [10.1080/10618600.2020.1825451](https://doi.org/10.1080/10618600.2020.1825451) 98 | arXiv: [1909.06039](https://arxiv.org/abs/1909.06039). 99 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/SimilarityFn.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
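To complement the dependency snippet in the README above, a minimal `build.sbt` for a downstream project might look like the sketch below. The project name is hypothetical, and the Scala and Spark versions are assumptions chosen to match the `scala-2.11` artefact paths and Spark 2.4.5 mentioned elsewhere in this repository; as the README notes, the dblink artefact is not published yet, so treat this as a sketch only.

```scala
// build.sbt -- minimal sketch for a project that depends on dblink
name := "my-er-project"      // hypothetical project name
scalaVersion := "2.11.12"    // assumed; must match the Scala version dblink is built against

libraryDependencies ++= Seq(
  // Spark is marked "provided" because spark-submit supplies it at runtime
  "org.apache.spark" %% "spark-core" % "2.4.5" % "provided",
  "org.apache.spark" %% "spark-sql"  % "2.4.5" % "provided",
  // Coordinates as given in the README
  "com.github.cleanzr" % "dblink" % "0.2.0"
)
```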
19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import org.apache.commons.lang3.StringUtils.getLevenshteinDistance 23 | 24 | /** Represents a truncated attribute similarity function */ 25 | sealed trait SimilarityFn extends Serializable { 26 | /** Compute the truncated similarity 27 | * 28 | * @param a an attribute value 29 | * @param b an attribute value 30 | * @return truncated similarity for `a` and `b` 31 | */ 32 | def getSimilarity(a: String, b: String): Double 33 | 34 | /** Maximum possible raw/truncated similarity score */ 35 | def maxSimilarity: Double 36 | 37 | /** Minimum possible raw similarity */ 38 | def minSimilarity: Double 39 | 40 | /** Truncation threshold. 41 | * All raw similarities below this threshold are truncated to zero. */ 42 | def threshold: Double 43 | 44 | def mkString: String 45 | } 46 | 47 | object SimilarityFn { 48 | /** Constant attribute similarity function. */ 49 | case object ConstantSimilarityFn extends SimilarityFn { 50 | override def getSimilarity(a: String, b: String): Double = 0.0 51 | 52 | override def maxSimilarity: Double = 0.0 53 | 54 | override def minSimilarity: Double = 0.0 55 | 56 | override def threshold: Double = 0.0 57 | 58 | override def mkString: String = "ConstantSimilarityFn" 59 | } 60 | 61 | abstract class NonConstantSimilarityFn(override val threshold: Double, 62 | override val maxSimilarity: Double) extends SimilarityFn { 63 | require(maxSimilarity > 0.0, "`maxSimilarity` must be positive") 64 | require(threshold >= minSimilarity && threshold < maxSimilarity, 65 | s"`threshold` must be in the interval [$minSimilarity, $maxSimilarity)") 66 | 67 | protected val transFactor: Double = maxSimilarity / (maxSimilarity - threshold) 68 | 69 | override def getSimilarity(a: String, b: String): Double = { 70 | val transSim = transFactor * (maxSimilarity * unitSimilarity(a, b) - threshold) 71 | if (transSim > 0.0) transSim else 0.0 72 | } 73 | 74 | override def minSimilarity: Double = 0.0 75 | 76 | /** Similarity function that provides scores on the unit interval */ 77 | protected def unitSimilarity(a: String, b: String): Double 78 | } 79 | 80 | 81 | /* Levenshtein attribute similarity function. */ 82 | class LevenshteinSimilarityFn(threshold: Double, maxSimilarity: Double) 83 | extends NonConstantSimilarityFn(threshold, maxSimilarity) { 84 | 85 | /** Similarity measure based on the normalized Levenshtein distance metric. 86 | * 87 | * See the following reference: 88 | * L. Yujian and L. Bo, “A Normalized Levenshtein Distance Metric,” 89 | * IEEE Transactions on Pattern Analysis and Machine Intelligence, 90 | * vol. 29, no. 6, pp. 1091–1095, Jun. 2007. 
91 | */ 92 | override protected def unitSimilarity(a: String, b: String): Double = { 93 | val totalLength = a.length + b.length 94 | if (totalLength > 0) { 95 | val dist = getLevenshteinDistance(a, b).toDouble 96 | 1.0 - 2.0 * dist / (totalLength + dist) 97 | } else 1.0 98 | } 99 | 100 | override def mkString: String = s"LevenshteinSimilarityFn(threshold=$threshold, maxSimilarity=$maxSimilarity)" 101 | } 102 | 103 | object LevenshteinSimilarityFn { 104 | def apply(threshold: Double = 7.0, maxSimilarity: Double = 10.0) = 105 | new LevenshteinSimilarityFn(threshold, maxSimilarity) 106 | } 107 | } 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/partitioning/MutableBST.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink.partitioning 21 | 22 | import scala.collection.mutable.ArrayBuffer 23 | import scala.math.pow 24 | 25 | /** ArrayBuffer-based binary search tree. 26 | * Not very generic. Consider combining with KDTreePartitioner. 27 | * Fast to search, but not space efficient for imbalanced trees (not 28 | * a problem for our application where the trees are small and ideally 29 | * perfectly balanced.) 30 | * 31 | * @tparam A type of field values 32 | */ 33 | class MutableBST[A : Ordering] extends Serializable { 34 | case class Node(var attributeId: Int, var splitter: DomainSplitter[A], var value: Int) 35 | 36 | private val _nodes = ArrayBuffer[Node](Node(-1, null, 0)) 37 | private var _numLevels: Int = 0 38 | private var _numLeaves: Int = 1 39 | 40 | def numLeaves: Int = _numLeaves 41 | def numLevels: Int = _numLevels 42 | def isEmpty: Boolean = _nodes.head.splitter == null 43 | def nonEmpty: Boolean = _nodes.head.splitter != null 44 | 45 | /** Search the tree for the leaf node corresponding to the given set of attribute 46 | * values, and return the "number" of the leaf. 47 | * 48 | * @param values a set of attribute values 49 | * @return leaf number (in 0, ..., numLeaves) 50 | */ 51 | def getLeafNumber(values: IndexedSeq[A]): Int = { 52 | val nodeId = getLeafNodeId(values) 53 | _nodes(nodeId).value 54 | } 55 | 56 | /** Search the tree for the leaf node corresponding to the given set of attribute 57 | * values. 
58 | * 59 | * @param attributes a set of attribute values 60 | * @return id of the leaf node 61 | */ 62 | def getLeafNodeId(attributes: IndexedSeq[A]): Int = { 63 | // TODO: check whether point has too many or too few dimensions 64 | var found = false 65 | var nodeId = 0 66 | while (!found && nodeId < _nodes.length) { 67 | val node = _nodes(nodeId) 68 | if (node == null) println("ERROR!!!!") 69 | if (node.splitter != null) { 70 | val pointVal = attributes(node.attributeId) 71 | // descend to next level 72 | if (node.splitter(pointVal)) nodeId = 2*nodeId + 2 // go right 73 | else nodeId = 2*nodeId + 1 // go left 74 | } else { 75 | found = true // at a leaf node 76 | } 77 | } 78 | nodeId 79 | } 80 | 81 | /** Split an existing node of the tree into two leaf nodes 82 | * 83 | * @param nodeId id of node to split 84 | * @param attributeId attribute associated with split 85 | * @param splitter splitter (specifies whether to go left or right) 86 | */ 87 | def splitNode(nodeId: Int, attributeId: Int, 88 | splitter: DomainSplitter[A]): Unit = { 89 | /** Ensure that the node already exists as a leaf node. */ 90 | val node = if (nodeId < _nodes.length) _nodes(nodeId) else null 91 | require(node != null, "node does not exist") 92 | require(node.splitter == null, "node is already split") 93 | 94 | /** Update the node */ 95 | node.attributeId = attributeId 96 | node.splitter = splitter 97 | 98 | /** Insert children for the node */ 99 | val leftChildId = 2*nodeId + 1 100 | val rightChildId = 2*nodeId + 2 101 | /** If `node` was a leaf at the lowest level, need to add space for its children */ 102 | if (leftChildId >= _nodes.length) { 103 | // grow array for storing nodes 104 | _numLevels += 1 105 | val newLevelSize = pow(2.0, _numLevels).toInt 106 | _nodes ++= Iterator.fill(newLevelSize)(null) 107 | } 108 | _nodes(leftChildId) = Node(-1, null, node.value) 109 | _nodes(rightChildId) = Node(-1, null, _numLeaves) 110 | _numLeaves += 1 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/random/AliasSampler.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
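To make the truncation scheme in `SimilarityFn.scala` above concrete, the sketch below evaluates the Levenshtein-based similarity for a few made-up string pairs. With the companion object's defaults (`threshold = 7.0`, `maxSimilarity = 10.0`), raw similarities at or below the threshold are truncated to 0.0 and the remainder are rescaled so that identical strings score 10.0.

```scala
import com.github.cleanzr.dblink.SimilarityFn.LevenshteinSimilarityFn

object SimilarityExample extends App {
  // Uses the default truncation threshold (7.0) and maximum similarity (10.0)
  val simFn = LevenshteinSimilarityFn()

  // Identical strings attain the maximum similarity of 10.0
  println(simFn.getSimilarity("Marchant", "Marchant"))

  // A single-character edit keeps the raw similarity above the threshold
  println(simFn.getSimilarity("Marchant", "Marchand"))

  // Very different strings fall below the threshold and truncate to 0.0
  println(simFn.getSimilarity("Marchant", "Kaplan"))
}
```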
19 | 20 | package com.github.cleanzr.dblink.random 21 | 22 | import com.github.cleanzr.dblink.random.AliasSampler._ 23 | import org.apache.commons.math3.random.RandomGenerator 24 | 25 | class AliasSampler(weights: Traversable[Double], checkWeights: Boolean, normalized: Boolean) extends Serializable { 26 | 27 | def size: Int = weights.size 28 | 29 | private val (_probabilities, _aliasTable) = computeTables(weights, checkWeights, normalized) 30 | 31 | def probabilities: IndexedSeq[Double] = _probabilities 32 | 33 | def sample()(implicit rand: RandomGenerator): Int = { 34 | val U = rand.nextDouble() * size 35 | val i = U.toInt 36 | if (U < _probabilities(i)) i else _aliasTable(i) 37 | } 38 | 39 | def sample(sampleSize: Int)(implicit rand: RandomGenerator): Array[Int] = { 40 | Array.tabulate(sampleSize)(_ => this.sample()) 41 | } 42 | } 43 | 44 | object AliasSampler { 45 | def apply(weights: Traversable[Double], checkWeights: Boolean = true, normalized: Boolean = false): AliasSampler = { 46 | new AliasSampler(weights, checkWeights, normalized) 47 | } 48 | 49 | private def computeTables(weights: Traversable[Double], 50 | checkWeights: Boolean, 51 | normalized: Boolean): (Array[Double], Array[Int]) = { 52 | val size = weights.size 53 | 54 | val totalWeight = if (checkWeights && !normalized) { 55 | weights.foldLeft(0.0) { (sum, weight) => 56 | if (weight < 0 || weight.isInfinity || weight.isNaN) { 57 | throw new IllegalArgumentException("invalid weight encountered") 58 | } 59 | sum + weight 60 | } 61 | } else if (checkWeights && normalized) { 62 | weights.foreach { weight => 63 | if (weight < 0 || weight.isInfinity || weight.isNaN) { 64 | throw new IllegalArgumentException("invalid weight encountered") 65 | } 66 | } 67 | 1.0 68 | } else if (!checkWeights && !normalized) { 69 | weights.sum 70 | } else 1.0 71 | 72 | require(totalWeight > 0.0, "zero probability mass") 73 | 74 | val probabilities = weights.map( weight => weight * size / totalWeight).toArray 75 | 76 | val aliasTable = Array.ofDim[Int](size) 77 | 78 | // Store small and large worklists in a single array. "Small" elements are stored on the left side, and 79 | // "large" elements are stored on the right side. 
80 | val worklist = Array.ofDim[Int](size) 81 | var posSmall = 0 82 | var posLarge = size 83 | 84 | // Fill worklists 85 | var i = 0 86 | while(i < size) { 87 | if (probabilities(i) < 1.0) { 88 | worklist(posSmall) = i // add index i to the small worklist 89 | posSmall += 1 90 | } else { 91 | posLarge -= 1 // add index i to the large worklist 92 | worklist(posLarge) = i 93 | } 94 | i += 1 95 | } 96 | 97 | // Remove elements from worklists 98 | if (posSmall > 0) { // both small and large worklists contain elements 99 | posSmall = 0 100 | while (posSmall < size && posLarge < size) { 101 | val l = worklist(posSmall) 102 | posSmall += 1 // "remove" element from small worklist 103 | val g = worklist(posLarge) 104 | aliasTable(l) = g 105 | probabilities(g) = (probabilities(g) + probabilities(l)) - 1.0 106 | if (probabilities(g) < 1.0) posLarge += 1 // "move" element g to small worklist 107 | } 108 | } 109 | 110 | // Since the uniform on [0,1] in the `sample` method is shifted by i 111 | i = 0 112 | while (i < size) { 113 | probabilities(i) += i 114 | i += 1 115 | } 116 | 117 | (probabilities, aliasTable) 118 | } 119 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/partitioning/KDTreePartitioner.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
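The `AliasSampler` above implements the alias method: a single pass builds the probability and alias tables, after which each draw costs O(1) (one uniform variate, one comparison, one table lookup). A minimal usage sketch with made-up weights:

```scala
import com.github.cleanzr.dblink.random.AliasSampler
import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator}

object AliasSamplerExample extends App {
  implicit val rand: RandomGenerator = new MersenneTwister(0)

  // Unnormalised weights over indices 0..3; they are checked and normalised internally
  val sampler = AliasSampler(Seq(1.0, 4.0, 3.0, 2.0))

  // Draw 10000 indices; the empirical frequencies should be close to (0.1, 0.4, 0.3, 0.2)
  val counts = sampler.sample(10000).groupBy(identity).mapValues(_.length / 10000.0)
  println(counts.toSeq.sortBy(_._1).mkString(", "))
}
```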
19 | 20 | package com.github.cleanzr.dblink.partitioning 21 | 22 | import com.github.cleanzr.dblink.accumulators.MapDoubleAccumulator 23 | import com.github.cleanzr.dblink.partitioning.KDTreePartitioner.getNewSplits 24 | import org.apache.spark.broadcast.Broadcast 25 | import org.apache.spark.rdd.RDD 26 | import com.github.cleanzr.dblink.Logging 27 | 28 | class KDTreePartitioner[T : Ordering](numLevels: Int, 29 | attributeIds: Traversable[Int]) extends PartitionFunction[T] { 30 | require(numLevels >= 0, "`numLevels` must be non-negative.") 31 | if (numLevels > 0) require(attributeIds.nonEmpty, "`attributeIds` must be non-empty if `numLevels` > 0") 32 | 33 | override def numPartitions: Int = tree.numLeaves 34 | 35 | private val tree = new MutableBST[T] 36 | 37 | override def fit(records: RDD[Array[T]]): Unit = { 38 | if (numLevels > 0) require(attributeIds.nonEmpty, "non-empty list of attributes is required to build the tree.") 39 | 40 | val sc = records.sparkContext 41 | 42 | var level = 0 43 | var itAttributeIds = attributeIds.toIterator 44 | while (level < numLevels) { 45 | /** Go back to the beginning of `attributeIds` if we reach the end */ 46 | val attrId = if (itAttributeIds.hasNext) itAttributeIds.next() else { 47 | itAttributeIds = attributeIds.toIterator // reset 48 | itAttributeIds.next() 49 | } 50 | KDTreePartitioner.info(s"Splitting on attribute $attrId at level $level.") 51 | val bcTree = sc.broadcast(tree) 52 | val newLevel = getNewSplits(records, bcTree, attrId) 53 | newLevel.foreach { case (nodeId, splitter) => 54 | if (splitter.splitQuality <= 0.9) 55 | KDTreePartitioner.warn(s"Poor quality split (${splitter.splitQuality*100}%) at node $nodeId.") 56 | tree.splitNode(nodeId, attrId, splitter) 57 | } 58 | level += 1 59 | } 60 | } 61 | 62 | override def getPartitionId(attributeValues: Array[T]): Int = tree.getLeafNumber(attributeValues) 63 | 64 | override def mkString: String = { 65 | if (numLevels == 0) s"KDTreePartitioner(numLevels=0)" 66 | else s"KDTreePartitioner(numLevels=$numLevels, attributeIds=${attributeIds.mkString("[",",","]")})" 67 | } 68 | } 69 | 70 | object KDTreePartitioner extends Logging { 71 | 72 | def apply[T : Ordering](numLevels: Int, attributeIds: Traversable[Int]): KDTreePartitioner[T] = { 73 | new KDTreePartitioner(numLevels, attributeIds) 74 | } 75 | 76 | def apply[T : Ordering](): KDTreePartitioner[T] = { 77 | new KDTreePartitioner(0, Traversable()) 78 | } 79 | 80 | private def getNewSplits[T : Ordering](records: RDD[Array[T]], 81 | bcTree: Broadcast[MutableBST[T]], 82 | attributeId: Int): Map[Int, DomainSplitter[T]] = { 83 | val sc = records.sparkContext 84 | 85 | val acc = new MapDoubleAccumulator[(Int, T)] 86 | sc.register(acc, s"tree level builder") 87 | 88 | /** Iterate over the records counting the number of occurrences of each 89 | * attribute value per node */ 90 | records.foreachPartition { partition => 91 | val tree = bcTree.value 92 | partition.foreach { attributeValues => 93 | val attributeValue = attributeValues(attributeId) 94 | val nodeId = tree.getLeafNodeId(attributeValues) 95 | acc.add((nodeId, attributeValue), 1L) 96 | } 97 | } 98 | 99 | /** Group by node and compute the splitters for each node */ 100 | acc.value.toArray 101 | .groupBy(_._1._1) 102 | .mapValues { x => 103 | DomainSplitter[T](x.map { case ((_, value), weight) => (value, weight) }) 104 | } 105 | } 106 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/CustomKryoRegistrator.scala: 
-------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.esotericsoftware.kryo.Kryo 23 | import com.github.cleanzr.dblink.GibbsUpdates.{EntityInvertedIndex, LinksIndex} 24 | import com.github.cleanzr.dblink.partitioning.{DomainSplitter, KDTreePartitioner, LPTScheduler, MutableBST, PartitionFunction, SimplePartitioner} 25 | import com.github.cleanzr.dblink.random.{AliasSampler, DiscreteDist, IndexNonUniformDiscreteDist, NonUniformDiscreteDist} 26 | import com.github.cleanzr.dblink.GibbsUpdates.{EntityInvertedIndex, LinksIndex} 27 | import com.github.cleanzr.dblink.partitioning._ 28 | import com.github.cleanzr.dblink.random.{AliasSampler, DiscreteDist, IndexNonUniformDiscreteDist, NonUniformDiscreteDist} 29 | import org.apache.commons.math3.random.MersenneTwister 30 | import org.apache.spark.serializer.KryoRegistrator 31 | 32 | class CustomKryoRegistrator extends KryoRegistrator { 33 | override def registerClasses(kryo: Kryo): Unit = { 34 | kryo.register(classOf[EntRecCluster]) 35 | kryo.register(classOf[PartEntRecCluster]) 36 | kryo.register(classOf[Record[_]]) 37 | kryo.register(classOf[Entity]) 38 | kryo.register(classOf[DistortedValue]) 39 | kryo.register(classOf[SummaryVars]) 40 | kryo.register(classOf[Attribute]) 41 | kryo.register(classOf[SimilarityFn]) 42 | kryo.register(classOf[IndexedAttribute]) 43 | kryo.register(classOf[BetaShapeParameters]) 44 | kryo.register(classOf[EntityInvertedIndex]) 45 | kryo.register(classOf[LinksIndex]) 46 | kryo.register(classOf[DomainSplitter[_]]) 47 | kryo.register(classOf[KDTreePartitioner[_]]) 48 | kryo.register(classOf[LPTScheduler[_,_]]) 49 | kryo.register(classOf[LPTScheduler.Partition[_,_]]) 50 | kryo.register(classOf[MutableBST[_]]) 51 | kryo.register(classOf[PartitionFunction[_]]) 52 | kryo.register(classOf[SimplePartitioner[_]]) 53 | kryo.register(classOf[AliasSampler]) 54 | kryo.register(classOf[DiscreteDist[_]]) 55 | kryo.register(classOf[IndexNonUniformDiscreteDist]) 56 | kryo.register(classOf[NonUniformDiscreteDist[_]]) 57 | kryo.register(Class.forName("org.apache.spark.sql.execution.columnar.CachedBatch")) 58 | kryo.register(Class.forName("[[B")) 59 | kryo.register(Class.forName("org.apache.spark.sql.catalyst.expressions.GenericInternalRow")) 60 | kryo.register(Class.forName("org.apache.spark.unsafe.types.UTF8String")) 61 | kryo.register(classOf[Array[Object]]) 62 | kryo.register(classOf[Array[EntRecCluster]]) 63 | kryo.register(classOf[RecordsCache]) 64 | kryo.register(classOf[Parameters]) 65 | kryo.register(classOf[AliasSampler]) 66 | kryo.register(classOf[MersenneTwister]) 67 | kryo.register(classOf[DistortionProbs]) 68 | } 69 | } 70 | // 71 | 
//class AugmentedRecordSerializer extends Serializer[EntRecPair] { 72 | // override def write(kryo: Kryo, output: Output, `object`: EntRecPair): Unit = { 73 | // output.writeLong(`object`.latId, true) 74 | // output.writeString(`object`.recId) 75 | // output.writeString(`object`.fileId) 76 | // val numLatFields = `object`.latFieldValues.length 77 | // output.writeInt(numLatFields) 78 | // `object`.recFieldValues.foreach(v => output.writeString(v)) 79 | // `object`.distortions.foreach(v => output.writeBoolean(v)) 80 | // `object`.latFieldValues.foreach(v => output.writeString(v)) 81 | // } 82 | // 83 | // override def read(kryo: Kryo, input: Input, `type`: Class[EntRecPair]): EntRecPair = { 84 | // val latId = input.readLong(true) 85 | // val recId = input.readString() 86 | // val fileId = input.readString() 87 | // val numLatFields = input.readInt() 88 | // val numRecFields = if (recId == null) 0 else numLatFields 89 | // val fieldValues = Array.fill(numRecFields)(input.readString()) 90 | // val distortions = Array.fill(numRecFields)(input.readBoolean()) 91 | // val latFieldValues = Array.fill(numLatFields)(input.readString()) 92 | // LatRecPair(latId, latFieldValues, recId, fileId, fieldValues, distortions) 93 | // } 94 | //} 95 | -------------------------------------------------------------------------------- /src/test/scala/com/github/cleanzr/dblink/AttributeIndexTest.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
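Two of the components above are wired together when dblink runs on Spark: the Kryo registrator is supplied through the Spark configuration, and a `KDTreePartitioner` is fitted to the record attribute values to assign partition ids. The sketch below shows that wiring on a toy RDD; the attribute values are made up, whereas the real pipeline builds the records RDD from the configured data files.

```scala
import com.github.cleanzr.dblink.partitioning.KDTreePartitioner
import org.apache.spark.{SparkConf, SparkContext}

object PartitionerExample extends App {
  val conf = new SparkConf()
    .setMaster("local[2]")
    .setAppName("dblink partitioner sketch")
    // Register dblink's classes with Kryo, mirroring the registrator defined above
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .set("spark.kryo.registrator", "com.github.cleanzr.dblink.CustomKryoRegistrator")
  val sc = SparkContext.getOrCreate(conf)

  // Toy records: each row holds the attribute values used for partitioning
  val records = sc.parallelize(Seq(
    Array("Smith", "VIC"), Array("Jones", "NSW"),
    Array("Smith", "QLD"), Array("Brown", "VIC")
  ))

  // Split twice, cycling over attributes 0 and 1, which yields up to 4 partitions
  val partitioner = KDTreePartitioner[String](numLevels = 2, attributeIds = Seq(0, 1))
  partitioner.fit(records)

  records.collect().foreach { values =>
    println(s"${values.mkString("|")} -> partition ${partitioner.getPartitionId(values)}")
  }
  sc.stop()
}
```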
19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator} 23 | import org.apache.spark.{SparkConf, SparkContext} 24 | import org.scalatest._ 25 | import com.github.cleanzr.dblink.SimilarityFn._ 26 | 27 | class AttributeIndexTest extends FlatSpec with AttributeIndexBehaviors with Matchers { 28 | 29 | implicit val rand: RandomGenerator = new MersenneTwister(0) 30 | implicit val sc: SparkContext = { 31 | val conf = new SparkConf() 32 | .setMaster("local[1]") 33 | .setAppName("d-blink test") 34 | SparkContext.getOrCreate(conf) 35 | } 36 | sc.setLogLevel("WARN") 37 | 38 | lazy val stateWeights = Map("Australian Capital Territory" -> 0.410, 39 | "New South Wales" -> 7.86, "Northern Territory" -> 0.246, "Queensland" -> 4.92, 40 | "South Australia" -> 1.72, "Tasmania" -> 0.520, "Victoria" -> 6.32, 41 | "Western Australia" -> 2.58) 42 | 43 | lazy val constantIndex = AttributeIndex(stateWeights, ConstantSimilarityFn) 44 | 45 | lazy val nonConstantIndex = AttributeIndex(stateWeights, LevenshteinSimilarityFn(5.0, 10.0)) 46 | 47 | /** The following results are for an index with LevenshteinSimilarityFn(5.0, 10.0) */ 48 | def stateSimNormalizations = Map("Australian Capital Territory" -> 0.0027140755302269004, 49 | "New South Wales" -> 1.4193905286944585E-4, 50 | "Northern Territory" -> 0.00451528932619675, 51 | "Queensland" -> 2.2673706056780077E-4, 52 | "South Australia" -> 6.465919296781136E-4, 53 | "Tasmania" -> 0.00214117348291189, 54 | "Victoria" -> 1.7651936247903708E-4, 55 | "Western Australia" -> 4.317863538883541E-4) 56 | 57 | def simValuesSA = Map(7 -> 39.813678188084864, 4 -> 22026.465794806718) 58 | 59 | val expSimSAWA = 39.813678188084864 60 | val expSimVICTAS = 1.0 61 | 62 | "An attribute index (with a constant similarity function)" should behave like genericAttributeIndex(constantIndex, stateWeights) 63 | 64 | it should "return a similarity normalization constant of 1.0 for all values" in { 65 | assert((0 until constantIndex.numValues).forall(valueId => constantIndex.simNormalizationOf(valueId) == 1.0)) 66 | } 67 | 68 | it should "return no similar values for all values" in { 69 | assert((0 until constantIndex.numValues).forall(valueId => constantIndex.simValuesOf(valueId).isEmpty)) 70 | } 71 | 72 | it should "return an exponentiated similarity score of 1.0 for all value pairs" in { 73 | val allValueIds = 0 until constantIndex.numValues 74 | val allValueIdPairs = allValueIds.flatMap(valueId1 => allValueIds.map(valueId2 => (valueId1, valueId2))) 75 | assert(allValueIdPairs.forall { case (valueId1, valueId2) => constantIndex.expSimOf(valueId1, valueId2) == 1.0 }) 76 | } 77 | 78 | "An attribute index (with a non-constant similarity function)" should behave like genericAttributeIndex(nonConstantIndex, stateWeights) 79 | 80 | it should "return the correct similarity normalization constants for all values" in { 81 | assert(stateSimNormalizations.forall { case (stringValue, trueSimNorm) => 82 | nonConstantIndex.simNormalizationOf(nonConstantIndex.valueIdxOf(stringValue)) === (trueSimNorm +- 1e-4) 83 | }) 84 | } 85 | 86 | it should "return the correct similar values for a query value" in { 87 | val testSimValues = nonConstantIndex.simValuesOf(nonConstantIndex.valueIdxOf("South Australia")) 88 | assert(simValuesSA.keySet === testSimValues.keySet) 89 | assert(simValuesSA.forall { case (valueId, trueExpSim) => testSimValues(valueId) === (trueExpSim +- 1e-4) }) 90 | } 91 | 92 | it should "return the correct exponeniated similarity 
score for a query value pair" in { 93 | val valueIdSA = nonConstantIndex.valueIdxOf("South Australia") 94 | val valueIdWA = nonConstantIndex.valueIdxOf("Western Australia") 95 | assert(nonConstantIndex.expSimOf(valueIdSA, valueIdWA) === (expSimSAWA +- 1e-4)) 96 | val valueIdVIC = nonConstantIndex.valueIdxOf("Victoria") 97 | val valueIdTAS = nonConstantIndex.valueIdxOf("Tasmania") 98 | assert(nonConstantIndex.expSimOf(valueIdVIC, valueIdTAS) === (expSimVICTAS +- 1e-4)) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/package.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr 21 | 22 | import com.github.cleanzr.dblink.SimilarityFn.ConstantSimilarityFn 23 | 24 | /** Types/case classes used throughout */ 25 | package object dblink { 26 | import org.apache.spark.rdd.RDD 27 | 28 | type FileId = String 29 | type RecordId = String 30 | type PartitionId = Int 31 | type EntityId = Int 32 | type ValueId = Int 33 | type AttributeId = Int 34 | type Partitions = RDD[(PartitionId, EntRecCluster)] 35 | type AggDistortions = Map[(AttributeId, FileId), Long] 36 | type Cluster = Set[RecordId] 37 | type RecordPair = (RecordId, RecordId) 38 | 39 | case class MostProbableCluster(recordId: RecordId, cluster: Set[RecordId], frequency: Double) 40 | 41 | /** Record (row) in the input data 42 | * 43 | * @param id a unique identifier for the record (must be unique across 44 | * _all_ files) 45 | * @param fileId file identifier for the record 46 | * @param values attribute values for the record 47 | * @tparam T value type 48 | */ 49 | case class Record[T](id: RecordId, 50 | fileId: FileId, 51 | values: Array[T]) { 52 | def mkString: String = s"Record(id=$id, fileId=$fileId, values=${values.mkString(",")})" 53 | } 54 | 55 | 56 | /** Latent entity 57 | * 58 | * @param values attribute values for the entity. 
59 | */ 60 | case class Entity(values: Array[ValueId]) { 61 | def mkString: String = s"Entity(${values.mkString(",")})" 62 | } 63 | 64 | 65 | /** Attribute value subject to distortion 66 | * 67 | * @param value the attribute value 68 | * @param distorted whether the value is distorted 69 | */ 70 | case class DistortedValue(value: ValueId, distorted: Boolean) { 71 | def mkString: String = s"DistortedValue(value=$value, distorted=$distorted" 72 | } 73 | 74 | 75 | /** 76 | * Entity-record cluster 77 | * @param entity a latent entity 78 | * @param records records linked to the entity 79 | */ 80 | case class EntRecCluster(entity: Entity, 81 | records: Option[Array[Record[DistortedValue]]]) { 82 | def mkString: String = { 83 | records match { 84 | case Some(r) => s"${entity.mkString}\t->\t${r.mkString(", ")}" 85 | case None => entity.mkString 86 | } 87 | } 88 | } 89 | 90 | /** Linkage state 91 | * 92 | * Represents the linkage structure within a partition for a single iteration. 93 | */ 94 | case class LinkageState(iteration: Long, 95 | partitionId: PartitionId, 96 | linkageStructure: Seq[Seq[RecordId]]) 97 | 98 | 99 | /** Partition-entity-record cluster triple 100 | * 101 | * This class is used in the Dataset representation of the partitions---not 102 | * the RDD representation. 103 | * 104 | * @param partitionId identifier for the partition 105 | * @param entRecCluster entity cluster of records 106 | */ 107 | case class PartEntRecCluster(partitionId: PartitionId, entRecCluster: EntRecCluster) 108 | 109 | /** Container for the summary variables 110 | * 111 | * @param numIsolates 112 | * @param logLikelihood 113 | * @param aggDistortions 114 | * @param recDistortions 115 | */ 116 | case class SummaryVars(numIsolates: Long, 117 | logLikelihood: Double, 118 | aggDistortions: AggDistortions, 119 | recDistortions: Map[Int, Long]) 120 | 121 | 122 | /** Specifications for an attribute 123 | * 124 | * @param name column name of the attribute (should match original DataFrame) 125 | * @param similarityFn an attribute similarity function 126 | * @param distortionPrior prior for the distortion 127 | */ 128 | case class Attribute(name: String, 129 | similarityFn: SimilarityFn, 130 | distortionPrior: BetaShapeParameters) { 131 | /** Whether the similarity function for this attribute is constant or not*/ 132 | def isConstant: Boolean = { 133 | similarityFn match { 134 | case ConstantSimilarityFn => true 135 | case _ => false 136 | } 137 | } 138 | } 139 | 140 | /** Specifications for an indexed attribute 141 | * 142 | * @param name column name of the attribute (should match original DataFrame) 143 | * @param similarityFn an attribute similarity function 144 | * @param distortionPrior prior for the distortion 145 | * @param index index for the attribute 146 | */ 147 | case class IndexedAttribute(name: String, 148 | similarityFn: SimilarityFn, 149 | distortionPrior: BetaShapeParameters, 150 | index: AttributeIndex) { 151 | /** Whether the similarity function for this attribute is constant or not*/ 152 | def isConstant: Boolean = { 153 | similarityFn match { 154 | case ConstantSimilarityFn => true 155 | case _ => false 156 | } 157 | } 158 | } 159 | 160 | 161 | /** Container for shape parameters of the Beta distribution 162 | * 163 | * @param alpha "alpha" shape parameter 164 | * @param beta "beta" shape parameter 165 | */ 166 | case class BetaShapeParameters(alpha: Double, beta: Double) { 167 | require(alpha > 0 && beta > 0, "shape parameters must be positive") 168 | 169 | def mkString: String = 
s"BetaShapeParameters(alpha=$alpha, beta=$beta)" 170 | } 171 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/RecordsCache.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.accumulators.MapLongAccumulator 23 | import org.apache.spark.SparkContext 24 | import org.apache.spark.rdd.RDD 25 | 26 | /** Container to store statistics/metadata for a collection of records and 27 | * facilitate sampling from the attribute domains. 28 | * 29 | * This container is broadcast to each executor. 30 | * 31 | * @param indexedAttributes indexes for the attributes 32 | * @param fileSizes number of records in each file 33 | */ 34 | class RecordsCache(val indexedAttributes: IndexedSeq[IndexedAttribute], 35 | val fileSizes: Map[FileId, Long], 36 | val missingCounts: Option[Map[(FileId, AttributeId), Long]] = None) extends Serializable { 37 | 38 | def distortionPrior: Iterator[BetaShapeParameters] = indexedAttributes.iterator.map(_.distortionPrior) 39 | 40 | /** Number of records across all files */ 41 | val numRecords: Long = fileSizes.values.sum 42 | 43 | /** Number of attributes used for matching */ 44 | def numAttributes: Int = indexedAttributes.length 45 | 46 | /** Transform to value ids 47 | * 48 | * @param records an RDD of records in raw form (string attribute values) 49 | * @return an RDD of records where the string attribute values are replaced 50 | * by integer value ids 51 | */ 52 | def transformRecords(records: RDD[Record[String]]): RDD[Record[ValueId]] = 53 | RecordsCache._transformRecords(records, indexedAttributes) 54 | } 55 | 56 | object RecordsCache extends Logging { 57 | 58 | /** Build RecordsCache 59 | * 60 | * @param records an RDD of records in raw form (string attribute values) 61 | * @param attributeSpecs specifications for each record attribute. Must match 62 | * the order of attributes in `records`. 63 | * @param expectedMaxClusterSize largest expected record cluster size. Used 64 | * as a hint when precaching distributions 65 | * over the non-constant attribute domain. 
66 | * @return a RecordsCache 67 | */ 68 | def apply(records: RDD[Record[String]], 69 | attributeSpecs: IndexedSeq[Attribute], 70 | expectedMaxClusterSize: Int): RecordsCache = { 71 | val firstRecord = records.take(1).head 72 | require(firstRecord.values.length == attributeSpecs.length, "attribute specifications do not match the records") 73 | 74 | /** Use accumulators to gather record stats in one pass */ 75 | val accFileSizes = new MapLongAccumulator[FileId] 76 | implicit val sc: SparkContext = records.sparkContext 77 | sc.register(accFileSizes, "number of records per file") 78 | 79 | val accMissingCounts = new MapLongAccumulator[(FileId, AttributeId)] 80 | sc.register(accMissingCounts, s"missing counts per file and attribute") 81 | 82 | val accValueCounts = attributeSpecs.map{ attribute => 83 | val acc = new MapLongAccumulator[String] 84 | sc.register(acc, s"value counts for attribute ${attribute.name}") 85 | acc 86 | } 87 | 88 | info("Gathering statistics from source data files.") 89 | /** Get file and value counts in a single foreach action */ 90 | records.foreach { case Record(_, fileId, values) => 91 | accFileSizes.add((fileId, 1L)) 92 | values.zipWithIndex.foreach { case (value, attrId) => 93 | if (value != null) accValueCounts(attrId).add(value, 1L) 94 | else accMissingCounts.add(((fileId, attrId), 1L)) 95 | } 96 | } 97 | 98 | val missingCounts = accMissingCounts.value 99 | val fileSizes = accFileSizes.value 100 | val totalRecords = fileSizes.values.sum 101 | val percentageMissing = 100.0 * missingCounts.values.sum / (totalRecords * attributeSpecs.size) 102 | if (percentageMissing >= 0.0) { 103 | info(f"Finished gathering statistics from $totalRecords records across ${fileSizes.size} file(s). $percentageMissing%.3f%% of the record attribute values are missing.") 104 | } else { 105 | info(s"Finished gathering statistics from $totalRecords records across ${fileSizes.size} file(s). 
There are no missing record attribute values.") 106 | } 107 | 108 | /** Build an index for each attribute (generates a mapping from strings -> integers) */ 109 | val indexedAttributes = attributeSpecs.zipWithIndex.map { case (attribute, attrId) => 110 | val valuesWeights = accValueCounts(attrId).value.mapValues(_.toDouble) 111 | info(s"Indexing attribute '${attribute.name}' with ${valuesWeights.size} unique values.") 112 | val index = AttributeIndex(valuesWeights, attribute.similarityFn, 113 | Some(1 to expectedMaxClusterSize)) 114 | IndexedAttribute(attribute.name, attribute.similarityFn, attribute.distortionPrior, index) 115 | } 116 | 117 | new RecordsCache(indexedAttributes, fileSizes, Some(missingCounts)) 118 | } 119 | 120 | private def _transformRecords(records: RDD[Record[String]], 121 | indexedAttributes: IndexedSeq[IndexedAttribute]): RDD[Record[ValueId]] = { 122 | val firstRecord = records.take(1).head 123 | require(firstRecord.values.length == indexedAttributes.length, "attribute specifications do not match the records") 124 | 125 | records.mapPartitions( partition => { 126 | partition.map { record => 127 | val mappedValues = record.values.zipWithIndex.map { case (stringValue, attrId) => 128 | if (stringValue != null) indexedAttributes(attrId).index.valueIdxOf(stringValue) 129 | else -1 130 | } 131 | Record[ValueId](record.id, record.fileId, mappedValues) 132 | } 133 | }, preservesPartitioning = true) 134 | } 135 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/Sampler.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.util.{BufferedRDDWriter, PeriodicRDDCheckpointer} 23 | import org.apache.spark.SparkContext 24 | import org.apache.spark.sql.SparkSession 25 | 26 | object Sampler extends Logging { 27 | 28 | /** Generates posterior samples by successively applying the Markov transition operator starting from a given 29 | * initial state. The samples are written to the path provided. 30 | * 31 | * @param initialState The initial state of the Markov chain. 32 | * @param sampleSize A positive integer specifying the desired number of samples (after burn-in and thinning) 33 | * @param outputPath A string specifying the path to save output (includes samples and diagnostics). HDFS and 34 | * local filesystems are supported. 35 | * @param burninInterval A non-negative integer specifying the number of initial samples to discard as burn-in. 36 | * The default is 0, which means no burn-in is applied. 
37 | * @param thinningInterval A positive integer specifying the period for saving samples to disk. The default value is 38 | * 1, which means no thinning is applied. 39 | * @param checkpointInterval A non-negative integer specifying the period for checkpointing. This prevents the 40 | * lineage of the RDD (internal to state) from becoming too long. Smaller values require 41 | * more frequent writing to disk, larger values require more CPU/memory. The default 42 | * value of 20, is a reasonable trade-off. 43 | * @param writeBufferSize A positive integer specifying the number of samples to queue in memory before writing to 44 | * disk. 45 | * @param collapsedEntityIds A Boolean specifying whether to collapse the distortions when updating the entity ids. 46 | * Defaults to false. 47 | * @param collapsedEntityValues A Boolean specifying whether to collapse the distotions when updating the entity 48 | * values. Defaults to true. 49 | * @return The final state of the Markov chain. 50 | */ 51 | def sample(initialState: State, 52 | sampleSize: Int, 53 | outputPath: String, 54 | burninInterval: Int = 0, 55 | thinningInterval: Int = 1, 56 | checkpointInterval: Int = 20, 57 | writeBufferSize: Int = 10, 58 | collapsedEntityIds: Boolean = false, 59 | collapsedEntityValues: Boolean = true, 60 | sequential: Boolean = false): State = { 61 | require(sampleSize > 0, "`sampleSize` must be positive.") 62 | require(burninInterval >= 0, "`burninInterval` must be non-negative.") 63 | require(thinningInterval > 0, "`thinningInterval` must be positive.") 64 | require(checkpointInterval >= 0, "`checkpointInterval` must be non-negative.") 65 | require(writeBufferSize > 0, "`writeBufferSize` must be positive.") 66 | // TODO: ensure that savePath is a valid directory 67 | 68 | var sampleCtr = 0 // counter for number of samples produced (excludes burn-in/thinning) 69 | var state = initialState // current state 70 | val initialIteration = initialState.iteration // initial iteration (need not be zero) 71 | val continueChain = initialIteration != 0 // whether we're continuing a previous chain 72 | 73 | implicit val spark: SparkSession = SparkSession.builder().getOrCreate() 74 | implicit val sc: SparkContext = spark.sparkContext 75 | import spark.implicits._ 76 | 77 | // Set-up writers 78 | val linkagePath = outputPath + "linkage-chain.parquet" 79 | var linkageWriter = BufferedRDDWriter[LinkageState](writeBufferSize, linkagePath, continueChain) 80 | val diagnosticsPath = outputPath + "diagnostics.csv" 81 | val diagnosticsWriter = new DiagnosticsWriter(diagnosticsPath, continueChain) 82 | val checkpointer = new PeriodicRDDCheckpointer[(PartitionId, EntRecCluster)](checkpointInterval, sc) 83 | 84 | if (!continueChain && burninInterval == 0) { 85 | // Need to record initial state 86 | checkpointer.update(state.partitions) 87 | linkageWriter = linkageWriter.append(state.getLinkageStructure()) 88 | diagnosticsWriter.writeRow(state) 89 | } 90 | 91 | if (burninInterval > 0) info(s"Running burn-in for $burninInterval iterations.") 92 | while (sampleCtr < sampleSize) { 93 | state = state.nextState(checkpointer = checkpointer, collapsedEntityIds, collapsedEntityValues, sequential) 94 | 95 | //newState.partitions.persist(StorageLevel.MEMORY_ONLY_SER) 96 | //state = newState 97 | val completedIterations = state.iteration - initialIteration 98 | 99 | if (completedIterations - 1 == burninInterval) { 100 | if (burninInterval > 0) info("Burn-in complete.") 101 | info(s"Generating $sampleSize sample(s) with 
thinningInterval=$thinningInterval.") 102 | } 103 | 104 | if (completedIterations >= burninInterval) { 105 | // Finished burn-in, so start writing samples to disk (accounting for thinning) 106 | if ((completedIterations - burninInterval)%thinningInterval == 0) { 107 | linkageWriter = linkageWriter.append(state.getLinkageStructure()) 108 | diagnosticsWriter.writeRow(state) 109 | sampleCtr += 1 110 | } 111 | } 112 | 113 | // Ensure writer is kept alive (may die if burninInterval/thinningInterval is large) 114 | diagnosticsWriter.progress() 115 | } 116 | 117 | info("Sampling complete. Writing final state and remaining samples to disk.") 118 | linkageWriter = linkageWriter.flush() 119 | diagnosticsWriter.close() 120 | state.save(outputPath) 121 | checkpointer.deleteAllCheckpoints() 122 | info(s"Finished writing to disk at $outputPath") 123 | state 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/util/PeriodicCheckpointer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.github.cleanzr.dblink.util 19 | 20 | import scala.collection.mutable 21 | 22 | import org.apache.hadoop.conf.Configuration 23 | import org.apache.hadoop.fs.Path 24 | 25 | import org.apache.spark.SparkContext 26 | import org.apache.spark.internal.Logging 27 | 28 | 29 | /** 30 | * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs 31 | * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to 32 | * the distributed data type (RDD, Graph, etc.). 33 | * 34 | * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, 35 | * as well as unpersisting and removing checkpoint files. 36 | * 37 | * Users should call update() when a new Dataset has been created, 38 | * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are 39 | * responsible for materializing the Dataset to ensure that persisting and checkpointing actually 40 | * occur. 41 | * 42 | * When update() is called, this does the following: 43 | * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. 44 | * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. 45 | * - If using checkpointing and the checkpoint interval has been reached, 46 | * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. 47 | * - Remove older checkpoints. 48 | * 49 | * WARNINGS: 50 | * - This class should NOT be copied (since copies may conflict on which Datasets should be 51 | * checkpointed). 
52 | * - This class removes checkpoint files once later Datasets have been checkpointed. 53 | * However, references to the older Datasets will still return isCheckpointed = true. 54 | * 55 | * @param checkpointInterval Datasets will be checkpointed at this interval. 56 | * If this interval was set as -1, then checkpointing will be disabled. 57 | * @param sc SparkContext for the Datasets given to this checkpointer 58 | * @tparam T Dataset type, such as RDD[Double] 59 | */ 60 | abstract class PeriodicCheckpointer[T](val checkpointInterval: Int, 61 | val sc: SparkContext) extends Logging { 62 | 63 | /** FIFO queue of past checkpointed Datasets */ 64 | private val checkpointQueue = mutable.Queue[T]() 65 | 66 | /** FIFO queue of past persisted Datasets */ 67 | private val persistedQueue = mutable.Queue[T]() 68 | 69 | /** Number of times [[update()]] has been called */ 70 | private var updateCount = 0 71 | 72 | /** 73 | * Update with a new Dataset. Handle persistence and checkpointing as needed. 74 | * Since this handles persistence and checkpointing, this should be called before the Dataset 75 | * has been materialized. 76 | * 77 | * @param newData New Dataset created from previous Datasets in the lineage. 78 | */ 79 | def update(newData: T): Unit = { 80 | persist(newData) 81 | persistedQueue.enqueue(newData) 82 | // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: 83 | // Users should call [[update()]] when a new Dataset has been created, 84 | // before the Dataset has been materialized. 85 | while (persistedQueue.size > 3) { 86 | val dataToUnpersist = persistedQueue.dequeue() 87 | unpersist(dataToUnpersist) 88 | } 89 | updateCount += 1 90 | 91 | // Handle checkpointing (after persisting) 92 | if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0 93 | && sc.getCheckpointDir.nonEmpty) { 94 | // Add new checkpoint before removing old checkpoints. 95 | checkpoint(newData) 96 | checkpointQueue.enqueue(newData) 97 | // Remove checkpoints before the latest one. 98 | var canDelete = true 99 | while (checkpointQueue.size > 1 && canDelete) { 100 | // Delete the oldest checkpoint only if the next checkpoint exists. 101 | if (isCheckpointed(checkpointQueue.head)) { 102 | removeCheckpointFile() 103 | } else { 104 | canDelete = false 105 | } 106 | } 107 | } 108 | } 109 | 110 | /** Checkpoint the Dataset */ 111 | protected def checkpoint(data: T): Unit 112 | 113 | /** Return true iff the Dataset is checkpointed */ 114 | protected def isCheckpointed(data: T): Boolean 115 | 116 | /** 117 | * Persist the Dataset. 118 | * Note: This should handle checking the current [[StorageLevel]] of the Dataset. 119 | */ 120 | protected def persist(data: T): Unit 121 | 122 | /** Unpersist the Dataset */ 123 | protected def unpersist(data: T): Unit 124 | 125 | /** Get list of checkpoint files for this given Dataset */ 126 | protected def getCheckpointFiles(data: T): Iterable[String] 127 | 128 | /** 129 | * Call this to unpersist the Dataset. 130 | */ 131 | def unpersistDataSet(): Unit = { 132 | while (persistedQueue.nonEmpty) { 133 | val dataToUnpersist = persistedQueue.dequeue() 134 | unpersist(dataToUnpersist) 135 | } 136 | } 137 | 138 | /** 139 | * Call this at the end to delete any remaining checkpoint files. 140 | */ 141 | def deleteAllCheckpoints(): Unit = { 142 | while (checkpointQueue.nonEmpty) { 143 | removeCheckpointFile() 144 | } 145 | } 146 | 147 | /** 148 | * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint. 
149 | * Note that there may not be any checkpoints at all. 150 | */ 151 | def deleteAllCheckpointsButLast(): Unit = { 152 | while (checkpointQueue.size > 1) { 153 | removeCheckpointFile() 154 | } 155 | } 156 | 157 | /** 158 | * Get all current checkpoint files. 159 | * This is useful in combination with [[deleteAllCheckpointsButLast()]]. 160 | */ 161 | def getAllCheckpointFiles: Array[String] = { 162 | checkpointQueue.flatMap(getCheckpointFiles).toArray 163 | } 164 | 165 | /** 166 | * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. 167 | * This prints a warning but does not fail if the files cannot be removed. 168 | */ 169 | private def removeCheckpointFile(): Unit = { 170 | val old = checkpointQueue.dequeue() 171 | // Since the old checkpoint is not deleted by Spark, we manually delete it. 172 | getCheckpointFiles(old).foreach( 173 | PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration)) 174 | } 175 | } 176 | 177 | object PeriodicCheckpointer extends Logging { 178 | 179 | /** Delete a checkpoint file, and log a warning if deletion fails. */ 180 | def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = { 181 | try { 182 | val path = new Path(checkpointFile) 183 | val fs = path.getFileSystem(conf) 184 | fs.delete(path, true) 185 | } catch { 186 | case e: Exception => 187 | logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + 188 | checkpointFile) 189 | } 190 | } 191 | } -------------------------------------------------------------------------------- /docs/guide.md: -------------------------------------------------------------------------------- 1 | # Step-by-step guide 2 | This guide will take you through the steps involved in running dblink on a 3 | small test data set. To make the guide accessible, we assume that you're 4 | running dblink on your local machine. Of course, in practical applications 5 | you'll likely want to run dblink on a cluster. We'll provide some pointers 6 | for this option as we go along. 7 | 8 | ## 0. Install Java 9 | The following two steps require that Java 8+ is installed on your system. 10 | To check whether it is installed on a macOS or Linux system, run the command 11 | ```bash 12 | $ java -version 13 | ``` 14 | You should see a version number of the form 8.x (or equivalently 1.8.x). 15 | Installation instructions for Oracle JDK on Windows, macOS and Linux are 16 | available [here](https://java.com/en/download/help/download_options.xml). 17 | 18 | _Note: As of April 2019, the licensing terms of the Oracle JDK have changed. 19 | We recommend using an open source alternative such as the OpenJDK. Packages 20 | are available in many Linux distributions. Instructions for macOS are 21 | available [here](macos-java8.md)._ 22 | 23 | ## 1. Get access to a Spark cluster 24 | Since dblink is implemented as a Spark application, you'll need access to a 25 | Spark cluster in order to run it. 26 | Setting up a Spark cluster from scratch can be quite involved and is beyond 27 | the scope of this guide. 28 | We refer interested readers to the Spark 29 | [documentation](https://spark.apache.org/docs/latest/#launching-on-a-cluster), 30 | which discusses various deployment options. 
31 | An easier route for most users is to use a preconfigured Spark cluster 32 | available through public cloud providers, such as 33 | [Amazon EMR](https://aws.amazon.com/emr/), 34 | [Azure HDInsight](https://azure.microsoft.com/en-us/services/hdinsight/), 35 | and [Google Cloud Dataproc](https://cloud.google.com/dataproc/). 36 | In this guide, we take an even simpler approach: we'll run Spark in 37 | _pseudocluster mode_ on your local machine. 38 | This is fine for testing purposes or for small data sets. 39 | 40 | We'll now take you through detailed instructions for setting up Spark in 41 | pseudocluster mode on a macOS or Linux system. 42 | 43 | First, download the prebuilt 2.4.5 release from the Spark 44 | [release archive](https://archive.apache.org/dist/spark/). 45 | ```bash 46 | $ wget https://archive.apache.org/dist/spark/spark-2.4.5/spark-2.4.5-bin-hadoop2.7.tgz 47 | ``` 48 | Then extract the archive: 49 | ```bash 50 | $ tar -xvf spark-2.4.5-bin-hadoop2.7.tgz 51 | ``` 52 | 53 | Move the Spark folder to `/opt` and create a symbolic link so that you can 54 | easily switch to another version in the future. 55 | ```bash 56 | $ sudo mv spark-2.4.5-bin-hadoop2.7 /opt 57 | $ sudo ln -s /opt/spark-2.4.5-bin-hadoop2.7/ /opt/spark 58 | ``` 59 | 60 | Define the `SPARK_HOME` variable and add the Spark binaries to your `PATH`. 61 | The way that this is done depends on your operating system and/or shell. 62 | Assuming environment variables are defined in `~/.profile`, you can 63 | run the following commands: 64 | ```bash 65 | $ echo 'export SPARK_HOME=/opt/spark' >> ~/.profile 66 | $ echo 'export PATH=$PATH:$SPARK_HOME/bin' >> ~/.profile 67 | ``` 68 | 69 | After appending these two lines, run the following command to update your 70 | path for the current session: 71 | ```bash 72 | $ source ~/.profile 73 | ``` 74 | 75 | Notes: 76 | * If using Bash on Debian, Fedora or RHEL derivatives, environment 77 | variables are typically defined in `~/.bash_profile` rather than 78 | `~/.profile` 79 | * If using ZSH, environment variables are typically defined in 80 | `~/.zprofile` 81 | * You can check which shell you're using by running `echo $SHELL` 82 | 83 | ## 2. Obtain the dblink JAR file 84 | In this step you'll obtain the dblink fat JAR, which has the file name 85 | `dblink-assembly-0.2.0.jar`. 86 | It contains all of the class files and resources for dblink, packed together 87 | with any dependencies. 88 | 89 | There are two options: 90 | * (Recommended) Download a prebuilt JAR from [here](https://github.com/ngmarchant/dblink/releases). 91 | This has been built against Spark 2.4.5 and is not guaranteed to work with 92 | other versions of Spark. 93 | * Build the fat JAR file from source as explained in the section below. 94 | 95 | ### 2.1. Building the fat JAR 96 | The build tool used for dblink is called sbt, which you'll need to install 97 | on your system. Installation instructions for Windows, macOS and Linux are 98 | available in the sbt 99 | [documentation](https://www.scala-sbt.org/1.x/docs/Setup.html). 100 | 101 | On macOS or Linux, you can verify that sbt is installed correctly by 102 | running:
103 | ```bash 104 | $ sbt about 105 | ``` 106 | 107 | Once you've successfully installed sbt, get the dblink source code from 108 | GitHub: 109 | ```bash 110 | $ git clone https://github.com/cleanzr/dblink.git 111 | ``` 112 | then change into the dblink directory and build the package: 113 | ```bash 114 | $ cd dblink 115 | $ sbt assembly 116 | ``` 117 | This should produce a fat JAR at `./target/scala-2.11/dblink-assembly-0.2.0.jar`. 118 | 119 | _Note: [IntelliJ IDEA](https://www.jetbrains.com/idea/) can also be used to 120 | build the fat JAR. It is arguably more user-friendly, as it has a GUI and 121 | users can avoid installing sbt._ 122 | 123 | ## 3. Run dblink 124 | Having completed the above two steps, you're now ready to launch dblink. 125 | This is done using the [`spark-submit`](https://spark.apache.org/docs/latest/submitting-applications.html) 126 | interface, which supports all types of Spark deployments. 127 | 128 | As a test, let's try running the RLdata500 example provided with the source 129 | code on your local machine. 130 | From within the `dblink` directory, run the following command: 131 | ```bash 132 | $SPARK_HOME/bin/spark-submit \ 133 | --master "local[1]" \ 134 | --conf "spark.driver.extraJavaOptions=-Dlog4j.configuration=log4j.properties" \ 135 | --conf "spark.driver.extraClassPath=./target/scala-2.11/dblink-assembly-0.2.0.jar" \ 136 | ./target/scala-2.11/dblink-assembly-0.2.0.jar \ 137 | ./examples/RLdata500.conf 138 | ``` 139 | This will run Spark in pseudocluster (local) mode with 1 core. You can increase 140 | the number of cores available by changing `local[1]` to `local[n]`, where `n` 141 | is the number of cores, or to `local[*]` to use all available cores. 142 | To run dblink on other data sets you will need to edit the config file (called 143 | `RLdata500.conf` above). 144 | Instructions for doing this are provided [here](configuration.md). 145 | 146 | ## 4. Output of dblink 147 | dblink saves its output to a specified directory. In the RLdata500 example from 148 | above, the output is written to `./examples/RLdata500_results/`. 149 | 150 | Below we provide a brief description of the files: 151 | 152 | * `run.txt`: contains details about the job (MCMC run). This includes the 153 | data files, the attributes used, parameter settings etc. 154 | * `partitions-state.parquet` and `driver-state`: store the final state of 155 | the Markov chain, so that MCMC can be resumed (e.g. you can run the Markov 156 | chain for longer without starting from scratch). 157 | * `diagnostics.csv`: contains summary statistics along the chain which can be 158 | used to assess convergence/mixing. 159 | * `linkage-chain.parquet`: contains posterior samples of the linkage structure 160 | in Parquet format. 161 | 162 | Optional files: 163 | 164 | * `evaluation-results.txt`: contains output from an "evaluate" step (e.g. 165 | precision, recall, other measures). Requires ground truth entity identifiers 166 | in the data files. 167 | * `cluster-size-distribution.csv`: contains the cluster size distribution 168 | along the chain (rows are iterations, columns contain counts for each 169 | cluster/entity size). Only appears if requested in a "summarize" step. 170 | * `partition-sizes.csv`: contains the partition sizes along the chain (rows 171 | are iterations, columns are counts of the number of entities residing in each 172 | partition). Only appears if requested in a "summarize" step.
173 | * `shared-most-probable-clusters.csv` is a point estimate of the linkage 174 | structure computed from the posterior samples. Each line in the file contains 175 | a comma-separated list of record identifiers which are assigned to the same 176 | cluster/entity. 177 | -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/ProjectStep.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.analysis.{ClusteringMetrics, PairwiseMetrics} 23 | import com.github.cleanzr.dblink.util.BufferedFileWriter 24 | import com.github.cleanzr.dblink.LinkageChain._ 25 | import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} 26 | import org.apache.spark.storage.StorageLevel 27 | 28 | trait ProjectStep { 29 | def execute(): Unit 30 | 31 | def mkString: String 32 | } 33 | 34 | object ProjectStep { 35 | private val supportedSamplers = Set("PCG-I", "PCG-II", "Gibbs", "Gibbs-Sequential") 36 | private val supportedEvaluationMetrics = Set("pairwise", "cluster") 37 | private val supportedSummaryQuantities = Set("cluster-size-distribution", "partition-sizes", "shared-most-probable-clusters") 38 | 39 | class SampleStep(project: Project, sampleSize: Int, burninInterval: Int, 40 | thinningInterval: Int, resume: Boolean, sampler: String) extends ProjectStep with Logging { 41 | require(sampleSize > 0, "sampleSize must be positive") 42 | require(burninInterval >= 0, "burninInterval must be non-negative") 43 | require(thinningInterval >= 0, "thinningInterval must be non-negative") 44 | require(supportedSamplers.contains(sampler), s"sampler must be one of ${supportedSamplers.mkString("", ", ", "")}.") 45 | 46 | override def execute(): Unit = { 47 | info(mkString) 48 | val initialState = if (resume) { 49 | project.savedState.getOrElse(project.generateInitialState) 50 | } else { 51 | project.generateInitialState 52 | } 53 | sampler match { 54 | case "PCG-I" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = false, collapsedEntityValues = true, sequential = false) 55 | case "PCG-II" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = true, collapsedEntityValues = true, sequential = false) 56 | case "Gibbs" => Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = false, collapsedEntityValues = false, sequential = false) 57 | case "Gibbs-Sequential" => 
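// Fully uncollapsed Gibbs updates with sequential = true (contrast with the "Gibbs" case above, which passes sequential = false).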
Sampler.sample(initialState, sampleSize, project.outputPath, burninInterval=burninInterval, thinningInterval=thinningInterval, collapsedEntityIds = false, collapsedEntityValues = false, sequential = true) 58 | } 59 | } 60 | 61 | override def mkString: String = { 62 | if (resume) s"SampleStep: Evolving the chain from saved state with sampleSize=$sampleSize, burninInterval=$burninInterval, thinningInterval=$thinningInterval and sampler=$sampler" 63 | else s"SampleStep: Evolving the chain from new initial state with sampleSize=$sampleSize, burninInterval=$burninInterval, thinningInterval=$thinningInterval and sampler=$sampler" 64 | } 65 | } 66 | 67 | class EvaluateStep(project: Project, lowerIterationCutoff: Int, metrics: Traversable[String], 68 | useExistingSMPC: Boolean) extends ProjectStep with Logging { 69 | require(project.entIdAttribute.isDefined, "Ground truth entity ids are required for evaluation") 70 | require(lowerIterationCutoff >=0, "lowerIterationCutoff must be non-negative") 71 | require(metrics.nonEmpty, "metrics must be non-empty") 72 | require(metrics.forall(m => supportedEvaluationMetrics.contains(m)), s"metrics must be one of ${supportedEvaluationMetrics.mkString("{", ", ", "}")}.") 73 | 74 | override def execute(): Unit = { 75 | info(mkString) 76 | 77 | // Get ground truth clustering 78 | val trueClusters = project.trueClusters match { 79 | case Some(clusters) => clusters.persist() 80 | case None => 81 | error("Ground truth clusters are unavailable") 82 | return 83 | } 84 | 85 | import analysis._ 86 | 87 | // Get predicted clustering (using sMPC method) 88 | val sMPC = if (useExistingSMPC && project.sharedMostProbableClustersOnDisk) { 89 | // Read saved sMPC from disk 90 | project.savedSharedMostProbableClusters 91 | } else { 92 | // Try to compute sMPC using saved linkage chain (and save to disk) 93 | project.savedLinkageChain(lowerIterationCutoff) match { 94 | case Some(chain) => 95 | val sMPC = sharedMostProbableClusters(chain).persist() 96 | sMPC.saveCsv(project.outputPath + "shared-most-probable-clusters.csv") 97 | chain.unpersist() 98 | Some(sMPC) 99 | case None => 100 | error("No linkage chain") 101 | None 102 | } 103 | } 104 | 105 | sMPC match { 106 | case Some(predictedClusters) => 107 | val results = metrics.map { 108 | case metric if metric == "pairwise" => 109 | PairwiseMetrics(predictedClusters.toPairwiseLinks, trueClusters.toPairwiseLinks).mkString 110 | case metric if metric == "cluster" => 111 | ClusteringMetrics(predictedClusters, trueClusters).mkString 112 | } 113 | val writer = BufferedFileWriter(project.outputPath + "evaluation-results.txt", append = false, project.sparkContext) 114 | writer.write(results.mkString("", "\n", "\n")) 115 | writer.close() 116 | case None => error("Predicted clusters are unavailable") 117 | } 118 | } 119 | 120 | override def mkString: String = { 121 | if (useExistingSMPC) s"EvaluateStep: Evaluating saved sMPC clusters using ${metrics.map("'" + _ + "'").mkString("{", ", ", "}")} metrics" 122 | else s"EvaluateStep: Evaluating sMPC clusters (computed from the chain for iterations >= $lowerIterationCutoff) using ${metrics.map("'" + _ + "'").mkString("{", ", ", "}")} metrics" 123 | } 124 | } 125 | 126 | class SummarizeStep(project: Project, lowerIterationCutoff: Int, 127 | quantities: Traversable[String]) extends ProjectStep with Logging { 128 | require(lowerIterationCutoff >= 0, "lowerIterationCutoff must be non-negative") 129 | require(quantities.nonEmpty, "quantities must be non-empty") 130 | require(quantities.forall(q => 
supportedSummaryQuantities.contains(q)), s"quantities must be one of ${supportedSummaryQuantities.mkString("{", ", ", "}")}.") 131 | 132 | override def execute(): Unit = { 133 | info(mkString) 134 | project.savedLinkageChain(lowerIterationCutoff) match { 135 | case Some(chain) => 136 | quantities.foreach { 137 | case "cluster-size-distribution" => 138 | val clustSizeDist = clusterSizeDistribution(chain) 139 | saveClusterSizeDistribution(clustSizeDist, project.outputPath) 140 | case "partition-sizes" => 141 | val partSizes = partitionSizes(chain) 142 | savePartitionSizes(partSizes, project.outputPath) 143 | case "shared-most-probable-clusters" => 144 | import analysis._ 145 | val smpc = sharedMostProbableClusters(chain) 146 | smpc.saveCsv(project.outputPath + "shared-most-probable-clusters.csv") 147 | } 148 | case None => error("No linkage chain") 149 | } 150 | } 151 | 152 | override def mkString: String = { 153 | s"SummarizeStep: Calculating summary quantities ${quantities.map("'" + _ + "'").mkString("{", ", ", "}")} along the chain for iterations >= $lowerIterationCutoff" 154 | } 155 | } 156 | 157 | class CopyFilesStep(project: Project, fileNames: Traversable[String], destinationPath: String, 158 | overwrite: Boolean, deleteSource: Boolean) extends ProjectStep with Logging { 159 | 160 | override def execute(): Unit = { 161 | info(mkString) 162 | 163 | val conf = project.sparkContext.hadoopConfiguration 164 | val srcParent = new Path(project.outputPath) 165 | val srcFs = FileSystem.get(srcParent.toUri, conf) 166 | val dstParent = new Path(destinationPath) 167 | val dstFs = FileSystem.get(dstParent.toUri, conf) 168 | fileNames.map(fName => new Path(srcParent.toString + Path.SEPARATOR + fName)) 169 | .filter(srcFs.exists) 170 | .foreach { src => 171 | val dst = new Path(dstParent.toString + Path.SEPARATOR + src.getName) 172 | FileUtil.copy(srcFs, src, dstFs, dst, deleteSource, overwrite, conf) 173 | } 174 | } 175 | 176 | override def mkString: String = { 177 | s"CopyFilesStep: Copying ${fileNames.mkString("{",", ","}")} to destination $destinationPath" 178 | } 179 | } 180 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/LinkageChain.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Australian Bureau of Statistics 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.util.BufferedFileWriter 23 | import org.apache.spark.sql.{Dataset, SparkSession} 24 | 25 | import scala.collection.mutable 26 | 27 | object LinkageChain extends Logging { 28 | 29 | /** Read samples of the linkage structure 30 | * 31 | * @param path path to the output directory. 
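   *             The chain is read from `linkage-chain.parquet` inside this directory.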
32 | * @return a linkage chain: an RDD containing samples of the linkage 33 | * structure (by partition) along the Markov chain. 34 | */ 35 | def readLinkageChain(path: String): Dataset[LinkageState] = { 36 | // TODO: check path 37 | val spark = SparkSession.builder().getOrCreate() 38 | import spark.implicits._ 39 | 40 | spark.read.format("parquet") 41 | .load(path + "linkage-chain.parquet") 42 | .as[LinkageState] 43 | } 44 | 45 | 46 | /** 47 | * Computes the most probable clustering for each record 48 | * 49 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions. 50 | * @return A Dataset containing the most probable cluster for each record. 51 | */ 52 | def mostProbableClusters(linkageChain: Dataset[LinkageState]): Dataset[MostProbableCluster] = { 53 | val spark = linkageChain.sparkSession 54 | import spark.implicits._ 55 | 56 | val numSamples = linkageChain.map(_.iteration).distinct().count() 57 | linkageChain.rdd 58 | .flatMap(_.linkageStructure.iterator.collect {case cluster if cluster.nonEmpty => (cluster.toSet, 1.0/numSamples)}) 59 | .reduceByKey(_ + _) 60 | .flatMap {case (recIds, freq) => recIds.iterator.map(recId => (recId, (recIds, freq)))} 61 | .reduceByKey((x, y) => if (x._2 >= y._2) x else y) 62 | .map(x => MostProbableCluster(x._1, x._2._1, x._2._2)) 63 | .toDS() 64 | } 65 | 66 | 67 | /** 68 | * Computes a point estimate of the most likely clustering that obeys transitivity constraints. The method was 69 | * introduced by Steorts et al. (2016), where it is referred to as the method of shared most probable maximal 70 | * matching sets. 71 | * 72 | * @param mostProbableClusters A Dataset containing the most probable cluster for each record. 73 | * @return A Dataset of record clusters 74 | */ 75 | def sharedMostProbableClusters(mostProbableClusters: Dataset[MostProbableCluster]): Dataset[Cluster] = { 76 | val spark = mostProbableClusters.sparkSession 77 | import spark.implicits._ 78 | 79 | mostProbableClusters.rdd 80 | .map(x => (x.cluster, x.recordId)) 81 | .aggregateByKey(zeroValue = Set.empty[RecordId])( 82 | seqOp = (recordIds, recordId) => recordIds + recordId, 83 | combOp = (recordIdsA, recordIdsB) => recordIdsA union recordIdsB 84 | ) 85 | // key = most probable cluster | value = aggregated recIds 86 | .map(_._2) 87 | .toDS() 88 | // .flatMap[Set[RecordIdType]] { 89 | // case (cluster, recIds) if cluster == recIds => Iterator(recIds) 90 | // // cluster is shared most probable for all records it contains 91 | // case (cluster, recIds) if cluster != recIds => recIds.iterator.map(Set(_)) 92 | // // cluster isn't shared most probable -- output each record as a 93 | // // separate cluster 94 | // } 95 | } 96 | 97 | 98 | /** 99 | * Computes a point estimate of the most likely clustering that obeys transitivity constraints. The method was 100 | * introduced by Steorts et al. (2016), where it is referred to as the method of shared most probable maximal 101 | * matching sets. 102 | * 103 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions. 
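   *                     This overload first computes the most probable cluster for each record, then applies the shared most probable clusters reduction defined above.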
104 | * @return A Dataset of record clusters 105 | */ 106 | def sharedMostProbableClusters(linkageChain: Dataset[LinkageState])(implicit i1: DummyImplicit): Dataset[Cluster] = { 107 | val mpc = mostProbableClusters(linkageChain) 108 | sharedMostProbableClusters(mpc) 109 | } 110 | 111 | 112 | /** Computes the partition sizes along the linkage chain 113 | * 114 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions. 115 | * @return A Dataset containing the partition sizes at each iteration. The key is the iteration and the value is a 116 | * map from partition ids to their corresponding sizes. 117 | */ 118 | def partitionSizes(linkageChain: Dataset[LinkageState]): Dataset[(Long, Map[PartitionId, Int])] = { 119 | val spark = linkageChain.sparkSession 120 | import spark.implicits._ 121 | linkageChain.rdd 122 | .map(x => (x.iteration, (x.partitionId, x.linkageStructure.size))) 123 | .aggregateByKey(Map.empty[PartitionId, Int])( 124 | seqOp = (m, v) => m + v, 125 | combOp = (m1, m2) => m1 ++ m2 126 | ) 127 | .toDS() 128 | } 129 | 130 | 131 | /** Computes the cluster size frequency distribution along the linkage chain 132 | * 133 | * @param linkageChain A Dataset representing the linkage structure across iterations and partitions. 134 | * @return A Dataset containing the cluster size frequency distribution at each iteration. The key is the iteration 135 | * and the value is a map from cluster sizes to their corresponding frequencies. 136 | */ 137 | def clusterSizeDistribution(linkageChain: Dataset[LinkageState]): Dataset[(Long, mutable.Map[Int, Long])] = { 138 | // Compute distribution separately for each partition, then combine the results 139 | val spark = linkageChain.sparkSession 140 | import spark.implicits._ 141 | 142 | linkageChain.rdd 143 | .map(x => { 144 | val clustSizes = mutable.Map[Int, Long]().withDefaultValue(0L) 145 | x.linkageStructure.foreach { cluster => clustSizes(cluster.size) += 1L } 146 | (x.iteration, clustSizes) 147 | }) 148 | .reduceByKey((a, b) => { 149 | val combined = mutable.Map[Int, Long]().withDefaultValue(0L) 150 | (a.keySet ++ b.keySet).foreach(k => combined(k) = a(k) + b(k)) 151 | combined // combine maps 152 | }) 153 | .toDS() 154 | } 155 | 156 | 157 | /** Computes the cluster size frequency distribution along the linkage chain 158 | * and saves the result to `cluster-size-distribution.csv` in the output directory. 159 | * 160 | * @param path path to working directory 161 | */ 162 | def saveClusterSizeDistribution(clusterSizeDistribution: Dataset[(Long, mutable.Map[Int, Long])], path: String): Unit = { 163 | val sc = clusterSizeDistribution.sparkSession.sparkContext 164 | 165 | val distAlongChain = clusterSizeDistribution.collect().sortBy(_._1) // collect on driver and sort by iteration 166 | 167 | // Get the size of the largest cluster in the samples 168 | val maxClustSize = distAlongChain.aggregate(0)( 169 | seqop = (currentMax, x) => math.max(currentMax, x._2.keySet.max), 170 | combop = (a, b) => math.max(a, b) 171 | ) 172 | // Output file can be created from Hadoop file system. 
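    // CSV layout: header row "iteration,0,1,...,<maxClustSize>", followed by one row per iteration giving the frequency of each cluster size.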
173 | val fullPath = path + "cluster-size-distribution.csv" 174 | info(s"Writing cluster size frequency distribution along the chain to $fullPath") 175 | val writer = BufferedFileWriter(fullPath, append = false, sc) 176 | 177 | // Write CSV header 178 | writer.write("iteration" + "," + (0 to maxClustSize).mkString(",") + "\n") 179 | // Write rows (one for each iteration) 180 | distAlongChain.foreach { case (iteration, kToCounts) => 181 | val countsArray = (0 to maxClustSize).map(k => kToCounts.getOrElse(k, 0L)) 182 | writer.write(iteration.toString + "," + countsArray.mkString(",") + "\n") 183 | } 184 | writer.close() 185 | } 186 | 187 | 188 | /** Computes the sizes of the partitions along the chain and saves the result to 189 | * `partition-sizes.csv` in the output directory. 190 | * 191 | * @param path path to output directory 192 | */ 193 | def savePartitionSizes(partitionSizes: Dataset[(Long, Map[PartitionId, Int])], path: String): Unit = { 194 | val sc = partitionSizes.sparkSession.sparkContext 195 | val partSizesAlongChain = partitionSizes.collect().sortBy(_._1) 196 | 197 | val partIds = partSizesAlongChain.map(_._2.keySet).reduce(_ ++ _).toArray.sorted 198 | 199 | // Output file can be created from Hadoop file system. 200 | val fullPath = path + "partition-sizes.csv" 201 | val writer = BufferedFileWriter(fullPath, append = false, sc) 202 | 203 | // Write CSV header 204 | writer.write("iteration" + "," + partIds.mkString(",") + "\n") 205 | // Write rows (one for each iteration) 206 | partSizesAlongChain.foreach { case (iteration, m) => 207 | val sizes = partIds.map(partId => m.getOrElse(partId, 0)) 208 | writer.write(iteration.toString + "," + sizes.mkString(",") + "\n") 209 | } 210 | writer.close() 211 | } 212 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/AttributeIndex.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import com.github.cleanzr.dblink.random.DiscreteDist 23 | import com.github.cleanzr.dblink.random.DiscreteDist 24 | import org.apache.commons.math3.random.RandomGenerator 25 | import org.apache.spark.SparkContext 26 | 27 | import scala.collection.mutable 28 | import scala.math.{exp, pow} 29 | 30 | /** An index for an attribute domain. 
31 | * It includes: 32 | * - an index from raw strings in the domain to integer value ids 33 | * - the empirical distribution over the domain 34 | * - an index that accepts a query value id and returns a set of 35 | * "similar" value ids (whose truncated similarity is non-zero) 36 | * - an index over pairs of value ids that returns exponentiated 37 | * truncated similarity scores 38 | */ 39 | trait AttributeIndex extends Serializable { 40 | 41 | /** Number of distinct attribute values */ 42 | def numValues: Int 43 | 44 | /** Object for empirical distribution */ 45 | def distribution: DiscreteDist[Int] 46 | 47 | /** Get the probability mass associated with a value id 48 | * 49 | * @param valueId value id. 50 | * @return probability mass. Returns 0.0 if the value id does not exist. 51 | */ 52 | def probabilityOf(valueId: ValueId): Double 53 | 54 | 55 | /** Draw a value id according to the empirical distribution 56 | * 57 | * @return a value id. 58 | */ 59 | def draw()(implicit rand: RandomGenerator): ValueId 60 | 61 | 62 | /** Get the value id for a given string value 63 | * 64 | * @param value original string value 65 | * @return integer value id. Returns `-1` if value does not exist in the 66 | * index. 67 | */ 68 | def valueIdxOf(value: String): ValueId 69 | 70 | 71 | /** Get the similarity normalization corresponding to the value id. 72 | * 73 | * @param valueId integer value id 74 | * @return sum_{w}(probabilities(w) * exp(sim(w,value)) 75 | */ 76 | def simNormalizationOf(valueId: ValueId): Double 77 | 78 | 79 | /** Get the value ids that are "similar" (above the similarity threshold) 80 | * to the given value id, along with the exponentiated similarity scores. 81 | * 82 | * @param valueId original value string 83 | * @return 84 | */ 85 | def simValuesOf(valueId: ValueId): scala.collection.Map[ValueId, Double] 86 | 87 | 88 | /** Get the exponentiated similarity score for a pair of value ids 89 | * 90 | * @param valueId1 integer value id 1 91 | * @param valueId2 integer value id 2 92 | * @return exp(sim(valueId1, valueId2)) 93 | */ 94 | def expSimOf(valueId1: ValueId, valueId2: ValueId): Double 95 | 96 | 97 | /** Get distribution of the form: 98 | * p(v) \propto probabilityOf(v) * pow(simNormalizationOf(v), power) 99 | * 100 | * @param power the similarity normalization is raised to this power 101 | * @return a distribution 102 | */ 103 | def getSimNormDist(power: Int): DiscreteDist[Int] 104 | } 105 | 106 | object AttributeIndex { 107 | def apply(valuesWeights: Map[String, Double], 108 | similarityFn: SimilarityFn, 109 | precachePowers: Option[Traversable[Int]] = None) 110 | (implicit sc: SparkContext): AttributeIndex = { 111 | require(valuesWeights.nonEmpty, "index cannot be empty") 112 | 113 | val valuesWeights_sorted = valuesWeights.toArray.sortBy(_._1) 114 | val totalWeight = valuesWeights_sorted.foldLeft(0.0){ case (sum, (_, weight)) => sum + weight } 115 | val probs = valuesWeights_sorted.map(_._2/totalWeight) 116 | val stringToId = valuesWeights_sorted.iterator.zipWithIndex.map(x => (x._1._1, x._2)).toMap 117 | 118 | similarityFn match { 119 | case SimilarityFn.ConstantSimilarityFn => new ConstantAttributeIndex(stringToId, probs) 120 | case simFn => 121 | val simValueIndex = computeSimValueIndex(stringToId, simFn) 122 | val simNormalizations = computeSimNormalizations(simValueIndex, probs) 123 | new GenericAttributeIndex(stringToId, probs, simValueIndex, simNormalizations, precachePowers) 124 | } 125 | } 126 | 127 | /** Implementation for attributes with constant similarity 
functions */ 128 | private class ConstantAttributeIndex(protected val stringToId: Map[String, ValueId], 129 | protected val probs: Array[Double]) extends AttributeIndex { 130 | 131 | /** Empirical distribution over the field values */ 132 | val distribution: DiscreteDist[Int] = DiscreteDist(probs) 133 | 134 | val numValues: Int = stringToId.size 135 | 136 | override def probabilityOf(valueId: ValueId): Double = { 137 | require(valueId >= 0 && valueId < numValues, "valueId is not in the index") 138 | distribution.probabilityOf(valueId) 139 | } 140 | 141 | override def draw()(implicit rand: RandomGenerator): ValueId = distribution.sample() 142 | 143 | override def valueIdxOf(value: String): ValueId = stringToId(value) 144 | 145 | /** Assuming exp(sim(v1, v2)) = 1.0 for all v1, v2 */ 146 | override def simNormalizationOf(valueId: ValueId): Double = { 147 | require(valueId >= 0 && valueId < numValues, "valueId is not in the index") 148 | 1.0 149 | } 150 | 151 | /** Assuming exp(sim(v1, v2)) = 1.0 for all v1, v2 */ 152 | override def simValuesOf(valueId: ValueId): scala.collection.Map[ValueId, Double] = { 153 | require(valueId >= 0 && valueId < numValues, "valueId is not in the index") 154 | Map.empty[ValueId, Double] 155 | } 156 | 157 | /** Assuming exp(sim(v1, v2)) = 1.0 for all v1, v2 */ 158 | override def expSimOf(valueId1: ValueId, valueId2: ValueId): Double = { 159 | require(valueId1 >= 0 && valueId1 < numValues, "valueId1 is not in the index") 160 | require(valueId2 >= 0 && valueId2 < numValues, "valueId2 is not in the index") 161 | 1.0 162 | } 163 | 164 | /** For a constant similarity function, this is the same as the empirical distribution */ 165 | override def getSimNormDist(power: Int): DiscreteDist[Int] = { 166 | require(power > 0, "power must be a positive integer") 167 | distribution 168 | } 169 | } 170 | 171 | /** Implementation for attributes with non-constant attribute similarity functions */ 172 | private class GenericAttributeIndex(stringToId: Map[String, ValueId], 173 | probs: Array[Double], 174 | private val simValueIndex: Array[Map[ValueId, Double]], 175 | private val simNormalizations: Array[Double], 176 | precachePowers: Option[Traversable[Int]]) 177 | extends ConstantAttributeIndex(stringToId, probs) { 178 | 179 | override def simNormalizationOf(valueId: ValueId): Double = simNormalizations(valueId) 180 | 181 | override def simValuesOf(valueId: ValueId): scala.collection.Map[ValueId, Double] = simValueIndex(valueId) 182 | 183 | override def expSimOf(valueId1: ValueId, valueId2: ValueId): Double = { 184 | require(valueId2 >= 0 && valueId2 < numValues, "valueId2 is not in the index") 185 | simValuesOf(valueId1).getOrElse(valueId2, 1.0) 186 | } 187 | 188 | private val cachedSimNormDist = precachePowers match { 189 | case Some(powers) => powers.foldLeft(mutable.Map.empty[Int, DiscreteDist[Int]]) { 190 | case (m, power) if power > 0 => 191 | m + (power -> computeSimNormDist(probs, simNormalizations, power)) 192 | case (m, _) => m 193 | } 194 | case None => mutable.Map.empty[Int, DiscreteDist[Int]] 195 | } 196 | 197 | override def getSimNormDist(power: Int): DiscreteDist[Int] = { 198 | require(power > 0, "power must be a positive integer") 199 | cachedSimNormDist.get(power) match { 200 | case Some(dist) => dist 201 | case None => 202 | val dist = computeSimNormDist(probs, simNormalizations, power) 203 | cachedSimNormDist.update(power, dist) 204 | dist 205 | } 206 | } 207 | } 208 | 209 | 210 | private def computeSimNormDist(probs: Array[Double], simNormalizations: 
Array[Double], 211 | power: Int): DiscreteDist[Int] = { 212 | val simNormWeights = Array.tabulate(probs.length) { valueId => 213 | probs(valueId) * pow(simNormalizations(valueId), power) 214 | } 215 | DiscreteDist(simNormWeights) 216 | } 217 | 218 | 219 | private def computeSimValueIndex(stringToId: Map[String, ValueId], 220 | similarityFn: SimilarityFn) 221 | (implicit sc: SparkContext): Array[Map[ValueId, Double]] = { 222 | val valuesRDD = sc.parallelize(stringToId.toSeq) // TODO: how to set numSlices? Take into account number of executors. 223 | val valuePairsRDD = valuesRDD.cartesian(valuesRDD) 224 | 225 | val simValues = valuePairsRDD.map { case ((a, a_id), (b, b_id)) => (a_id, (b_id, exp(similarityFn.getSimilarity(a, b)))) } 226 | .filter(_._2._2 > 1.0) // non-zero truncated similarity 227 | .aggregateByKey(Map.empty[Int, Double])(seqOp = _ + _, combOp = _ ++ _) 228 | .collect() 229 | 230 | simValues.sortBy(_._1).map(_._2) 231 | } 232 | 233 | 234 | private def computeSimNormalizations(simValueIndex: Array[Map[ValueId, Double]], 235 | probs: Array[Double]): Array[Double] = { 236 | simValueIndex.map { simValues => 237 | var valueId = 0 238 | var norm = 0.0 239 | while (valueId < probs.length) { 240 | norm += probs(valueId) * simValues.getOrElse(valueId, 1.0) 241 | valueId += 1 242 | } 243 | 1.0/norm 244 | } 245 | } 246 | } -------------------------------------------------------------------------------- /src/main/scala/com/github/cleanzr/dblink/Project.scala: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2018 Neil Marchant 2 | // 3 | // Author: Neil Marchant 4 | // 5 | // This file is part of dblink. 6 | // 7 | // This program is free software: you can redistribute it and/or modify 8 | // it under the terms of the GNU General Public License as published by 9 | // the Free Software Foundation, either version 3 of the License, or 10 | // (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, 13 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | // GNU General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program. If not, see . 
19 | 20 | package com.github.cleanzr.dblink 21 | 22 | import SimilarityFn.{ConstantSimilarityFn, LevenshteinSimilarityFn} 23 | import com.github.cleanzr.dblink.analysis._ 24 | import partitioning.{KDTreePartitioner, PartitionFunction} 25 | import com.typesafe.config.{Config, ConfigException, ConfigObject} 26 | import org.apache.hadoop.fs.{FileSystem, Path} 27 | import org.apache.spark.SparkContext 28 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} 29 | 30 | import scala.collection.JavaConverters._ 31 | import scala.collection.mutable 32 | import scala.util.Try 33 | 34 | /** An entity resolution project 35 | * 36 | * @param dataPath path to source records (in CSV format) 37 | * @param outputPath path to project directory 38 | * @param checkpointPath path for saving Spark checkpoints 39 | * @param recIdAttribute name of record identifier column in dataFrame (must be unique across all files) 40 | * @param fileIdAttribute name of file identifier column in dataFrame (optional) 41 | * @param entIdAttribute name of entity identifier column in dataFrame (optional: if ground truth is available) 42 | * @param matchingAttributes attribute specifications to use for matching 43 | * @param partitionFunction partition function (determines how entities are partitioned across executors) 44 | * @param randomSeed random seed 45 | * @param populationSize size of the latent population 46 | * @param expectedMaxClusterSize expected size of the largest record cluster (used as a hint to improve precaching) 47 | * @param dataFrame data frame containing source records 48 | */ 49 | case class Project(dataPath: String, outputPath: String, checkpointPath: String, 50 | recIdAttribute: String, fileIdAttribute: Option[String], 51 | entIdAttribute: Option[String], matchingAttributes: IndexedSeq[Attribute], 52 | partitionFunction: PartitionFunction[ValueId], randomSeed: Long, populationSize: Option[Int], 53 | expectedMaxClusterSize: Int, dataFrame: DataFrame) extends Logging { 54 | require(expectedMaxClusterSize >= 0, "expectedMaxClusterSize must be non-negative") 55 | 56 | def sparkContext: SparkContext = dataFrame.sparkSession.sparkContext 57 | 58 | def mkString: String = { 59 | val lines = mutable.ArrayBuffer.empty[String] 60 | lines += "Data settings" 61 | lines += "-------------" 62 | lines += s" * Using data files located at '$dataPath'" 63 | lines += s" * The record identifier attribute is '$recIdAttribute'" 64 | fileIdAttribute match { 65 | case Some(fId) => lines += s" * The file identifier attribute is '$fId'" 66 | case None => lines += " * There is no file identifier" 67 | } 68 | entIdAttribute match { 69 | case Some(eId) => lines += s" * The entity identifier attribute is '$eId'" 70 | case None => lines += " * There is no entity identifier" 71 | } 72 | lines += s" * The matching attributes are ${matchingAttributes.map("'" + _.name + "'").mkString(", ")}" 73 | lines += "" 74 | 75 | lines += "Hyperparameter settings" 76 | lines += "-----------------------" 77 | lines ++= matchingAttributes.zipWithIndex.map { case (attribute, attributeId) => 78 | s" * '${attribute.name}' (id=$attributeId) with ${attribute.similarityFn.mkString} and ${attribute.distortionPrior.mkString}" 79 | } 80 | lines += s" * Size of latent population is ${populationSize.toString}" 81 | lines += "" 82 | 83 | lines += "Partition function settings" 84 | lines += "---------------------------" 85 | lines += " * " + partitionFunction.mkString 86 | lines += "" 87 | 88 | lines += "Project settings" 89 | lines += "----------------" 90 | 
lines += s" * Using randomSeed=$randomSeed" 91 | lines += s" * Using expectedMaxClusterSize=$expectedMaxClusterSize" 92 | lines += s" * Saving Markov chain and complete final state to '$outputPath'" 93 | lines += s" * Saving Spark checkpoints to '$checkpointPath'" 94 | 95 | lines.mkString("","\n","\n") 96 | } 97 | 98 | def sharedMostProbableClustersOnDisk: Boolean = { 99 | val hdfs = FileSystem.get(sparkContext.hadoopConfiguration) 100 | val fSMPC = new Path(outputPath + "shared-most-probable-clusters.csv") 101 | hdfs.exists(fSMPC) 102 | } 103 | 104 | def savedLinkageChain(lowerIterationCutoff: Int = 0): Option[Dataset[LinkageState]] = { 105 | val savedLinkageChainExists = { 106 | val hdfs = FileSystem.get(sparkContext.hadoopConfiguration) 107 | val file = new Path(outputPath + "linkage-chain.parquet") 108 | hdfs.exists(file) 109 | } 110 | if (savedLinkageChainExists) { 111 | val chain = if (lowerIterationCutoff == 0) LinkageChain.readLinkageChain(outputPath) 112 | else LinkageChain.readLinkageChain(outputPath).filter(_.iteration >= lowerIterationCutoff) 113 | if (chain.take(1).isEmpty) None 114 | else Some(chain) 115 | } else None 116 | } 117 | 118 | def savedState: Option[State] = { 119 | val savedStateExists = { 120 | val hdfs = FileSystem.get(sparkContext.hadoopConfiguration) 121 | val fDriverState = new Path(outputPath + "driver-state") 122 | val fPartitionState = new Path(outputPath + "partitions-state.parquet") 123 | hdfs.exists(fDriverState) && hdfs.exists(fPartitionState) 124 | } 125 | if (savedStateExists) { 126 | Some(State.read(outputPath)) 127 | } else None 128 | } 129 | 130 | def generateInitialState: State = { 131 | info("Generating new initial state") 132 | val parameters = Parameters( 133 | maxClusterSize = expectedMaxClusterSize 134 | ) 135 | State.deterministic( 136 | records = dataFrame, 137 | recIdColname = recIdAttribute, 138 | fileIdColname = fileIdAttribute, 139 | populationSize = populationSize, 140 | attributeSpecs = matchingAttributes, 141 | parameters = parameters, 142 | partitionFunction = partitionFunction, 143 | randomSeed = randomSeed 144 | ) 145 | } 146 | 147 | def savedSharedMostProbableClusters: Option[Dataset[Cluster]] = { 148 | if (sharedMostProbableClustersOnDisk) { 149 | Some(readClustersCSV(outputPath + "shared-most-probable-clusters.csv")) 150 | } else None 151 | } 152 | 153 | /** 154 | * Loads the ground truth clustering, if available. 
155 | */ 156 | def trueClusters: Option[Dataset[Cluster]] = { 157 | entIdAttribute match { 158 | case Some(eId) => 159 | val spark = dataFrame.sparkSession 160 | val recIdName = recIdAttribute 161 | import spark.implicits._ 162 | val membership = dataFrame.map(r => (r.getAs[RecordId](recIdName), r.getAs[EntityId](eId))) 163 | Some(membershipToClusters(membership)) 164 | case _ => None 165 | } 166 | } 167 | } 168 | 169 | object Project { 170 | def apply(config: Config): Project = { 171 | val dataPath = config.getString("dblink.data.path") 172 | 173 | val dataFrame: DataFrame = { 174 | val spark = SparkSession.builder().getOrCreate() 175 | spark.read.format("csv") 176 | .option("header", "true") 177 | .option("mode", "DROPMALFORMED") 178 | .option("nullValue", config.getString("dblink.data.nullValue")) 179 | .load(dataPath) 180 | } 181 | 182 | val matchingAttributes = 183 | parseMatchingAttributes(config.getObjectList("dblink.data.matchingAttributes")) 184 | 185 | Project( 186 | dataPath = dataPath, 187 | outputPath = config.getString("dblink.outputPath"), 188 | checkpointPath = config.getString("dblink.checkpointPath"), 189 | recIdAttribute = config.getString("dblink.data.recordIdentifier"), 190 | fileIdAttribute = Try {Some(config.getString("dblink.data.fileIdentifier"))} getOrElse None, 191 | entIdAttribute = Try {Some(config.getString("dblink.data.entityIdentifier"))} getOrElse None, 192 | matchingAttributes = matchingAttributes, 193 | partitionFunction = parsePartitioner(config.getConfig("dblink.partitioner"), matchingAttributes.map(_.name)), 194 | randomSeed = config.getLong("dblink.randomSeed"), 195 | populationSize = Try {Some(config.getInt("dblink.populationSize"))} getOrElse None, 196 | expectedMaxClusterSize = Try {config.getInt("dblink.expectedMaxClusterSize")} getOrElse 10, 197 | dataFrame = dataFrame 198 | ) 199 | } 200 | 201 | implicit def toConfigTraversable[T <: ConfigObject](objectList: java.util.List[T]): Traversable[Config] = objectList.asScala.map(_.toConfig) 202 | 203 | private def parseMatchingAttributes(configList: Traversable[Config]): Array[Attribute] = { 204 | configList.map { c => 205 | val simFn = c.getString("similarityFunction.name") match { 206 | case "ConstantSimilarityFn" => ConstantSimilarityFn 207 | case "LevenshteinSimilarityFn" => 208 | LevenshteinSimilarityFn(c.getDouble("similarityFunction.parameters.threshold"), c.getDouble("similarityFunction.parameters.maxSimilarity")) 209 | case _ => throw new ConfigException.BadValue(c.origin(), "similarityFunction.name", "unsupported value") 210 | } 211 | val distortionPrior = BetaShapeParameters( 212 | c.getDouble("distortionPrior.alpha"), 213 | c.getDouble("distortionPrior.beta") 214 | ) 215 | Attribute(c.getString("name"), simFn, distortionPrior) 216 | }.toArray 217 | } 218 | 219 | private def parsePartitioner(config: Config, attributeNames: Seq[String]): KDTreePartitioner[ValueId] = { 220 | if (config.getString("name") == "KDTreePartitioner") { 221 | val numLevels = config.getInt("parameters.numLevels") 222 | val attributeIds = config.getStringList("parameters.matchingAttributes").asScala.map( n => 223 | attributeNames.indexOf(n) 224 | ) 225 | KDTreePartitioner[ValueId](numLevels, attributeIds) 226 | } else { 227 | throw new ConfigException.BadValue(config.origin(), "name", "unsupported value") 228 | } 229 | } 230 | } --------------------------------------------------------------------------------
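The `Project.apply(config)` factory above reads its settings from a Typesafe Config (HOCON) object. The following is a minimal sketch of how those keys fit together when constructing a project programmatically; the object name `ProjectConfigSketch`, all paths, the `surname`/`rec_id` column names, the prior parameters, and the `numLevels = 0` partitioner setting are illustrative assumptions rather than values taken from the bundled configs.

```scala
import com.github.cleanzr.dblink.Project
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.SparkSession

object ProjectConfigSketch {
  def main(args: Array[String]): Unit = {
    // Project.apply builds its DataFrame via SparkSession.builder().getOrCreate(),
    // so a session must exist before the config is parsed.
    val spark = SparkSession.builder()
      .appName("dblink-config-sketch")
      .master("local[*]")
      .getOrCreate()

    // Minimal HOCON covering the keys read by Project.apply.
    // All paths, attribute names and prior parameters below are placeholders.
    val config = ConfigFactory.parseString(
      """
        |dblink {
        |  data {
        |    path = "/tmp/records.csv"        // hypothetical CSV of source records
        |    nullValue = ""
        |    recordIdentifier = "rec_id"
        |    matchingAttributes = [
        |      {
        |        name = "surname"
        |        similarityFunction = {name = "ConstantSimilarityFn"}
        |        distortionPrior = {alpha = 1.0, beta = 50.0}
        |      }
        |    ]
        |  }
        |  outputPath = "/tmp/dblink-output/"
        |  checkpointPath = "/tmp/dblink-checkpoints/"
        |  partitioner = {
        |    name = "KDTreePartitioner"
        |    parameters = {numLevels = 0, matchingAttributes = []}  // assumed to mean no tree splits
        |  }
        |  randomSeed = 42
        |}
      """.stripMargin)

    // fileIdentifier, entityIdentifier, populationSize and expectedMaxClusterSize are
    // optional; Project.apply falls back to None (or a default of 10) when they are absent.
    val project = Project(config)
    println(project.mkString)

    spark.stop()
  }
}
```

The bundled `examples/RLdata500.conf` and `examples/RLdata10000.conf` files show complete, working configurations in the same format.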