├── LICENSE.txt ├── README.md ├── berkeley-entity-1.0.jar ├── build.sbt ├── config └── base.conf ├── lib ├── BerkeleyParser-1.7.jar ├── bliki-resources │ ├── Messages_en.properties │ └── interwiki.properties ├── bliki.jar ├── futile.jar ├── jwi.jar └── scala-xml-2.11.0-M4.jar ├── project └── plugins.sbt ├── pull-datasets.sh ├── run-test.sh ├── src └── main │ └── java │ └── edu │ └── berkeley │ └── nlp │ └── entity │ ├── Chunk.scala │ ├── ConllDoc.scala │ ├── ConllDocReader.scala │ ├── ConllDocWriter.scala │ ├── DepConstTree.scala │ ├── Driver.java │ ├── EntitySystem.scala │ ├── GUtil.scala │ ├── GeneralTrainer2.scala │ ├── WordNetInterfacer.scala │ ├── bp │ ├── CombinatorialIterator.scala │ ├── CompletePropertyFactor.scala │ ├── ConstantFactor.scala │ ├── Domain.scala │ ├── Factor.scala │ ├── Node.scala │ └── SimpleFactorGraph.scala │ ├── coref │ ├── AuxiliaryFeaturizer.scala │ ├── ConjFeatures.java │ ├── ConjScheme.java │ ├── ConjType.java │ ├── CorefConllScorer.scala │ ├── CorefDoc.scala │ ├── CorefDocAssembler.scala │ ├── CorefDocAssemblerACE.scala │ ├── CorefEvaluator.scala │ ├── CorefFeaturizerTrainer.scala │ ├── CorefMaskMaker.scala │ ├── CorefPruner.scala │ ├── CorefSystem.scala │ ├── DocumentGraph.scala │ ├── DocumentInferencer.scala │ ├── DocumentInferencerBasic.scala │ ├── DocumentInferencerOracle.scala │ ├── DocumentInferencerRahman.scala │ ├── EntityFeaturizer.scala │ ├── ErrorAnalyzer.scala │ ├── FeatureSetSpecification.scala │ ├── Gender.java │ ├── LexicalCountsBundle.scala │ ├── LexicalInferenceExtractor.scala │ ├── LexicalInferenceFeaturizer.scala │ ├── LexicalInferenceFeaturizerMultiThresh.scala │ ├── Mention.scala │ ├── MentionPropertyComputer.scala │ ├── MentionRankingComputer.scala │ ├── MentionRankingComputerSparse.scala │ ├── MentionRankingDocumentComputer.scala │ ├── MentionType.java │ ├── Number.java │ ├── NumberGenderComputer.scala │ ├── OrderedClustering.scala │ ├── OrderedClusteringBound.scala │ ├── PairwiseIndexingFeaturizer.scala │ ├── PairwiseIndexingFeaturizerJoint.scala │ ├── PairwiseLossFunctions.scala │ ├── PairwiseScorer.scala │ ├── PronounDictionary.scala │ ├── PruningStrategy.scala │ └── package.scala │ ├── joint │ ├── FactorGraphFactory.scala │ ├── GeneralTrainer.scala │ ├── JointComputerShared.scala │ ├── JointDoc.scala │ ├── JointDocACE.scala │ ├── JointDocFactorGraph.scala │ ├── JointDocFactorGraphACE.scala │ ├── JointDocFactorGraphOnto.scala │ ├── JointFeaturizerShared.scala │ ├── JointLossFcns.scala │ ├── JointPredictor.scala │ └── JointPredictorACE.scala │ ├── lang │ ├── ArabicTreebankLanguagePack.java │ ├── CorefLanguagePack.scala │ ├── Language.java │ ├── ModArabicHeadFinder.java │ └── ModCollinsHeadFinder.java │ ├── ner │ ├── CorpusCounts.scala │ ├── MCNerExample.scala │ ├── MCNerFeaturizer.scala │ ├── NEEvaluator.scala │ ├── NESentenceMunger.scala │ ├── NerDriver.java │ ├── NerExample.scala │ ├── NerFeaturizer.scala │ ├── NerPruner.scala │ └── NerSystemLabeled.scala │ ├── preprocess │ ├── ConllDocSharder.scala │ ├── PTBToConllMunger.scala │ ├── PreprocessingDriver.java │ ├── Reprocessor.scala │ ├── SentenceSplitter.scala │ ├── SentenceSplitterTokenizerDriver.java │ ├── Tokenizer.scala │ └── TokenizerTest.scala │ ├── sem │ ├── AbbreviationHandler.scala │ ├── BasicWordNetSemClasser.scala │ ├── BrownClusterInterface.scala │ ├── FancyHeadMatcher.scala │ ├── GoogleNgramUtils.scala │ ├── MentionFilter.scala │ ├── QueryCountCollector.scala │ ├── QueryCountsBundle.scala │ ├── SemClass.scala │ └── SemClasser.scala │ ├── sig │ ├── BootstrapDriver.scala 
│ ├── BootstrapDriverNER.scala │ ├── BootstrapDriverWiki.scala │ └── MetricComputer.scala │ ├── wiki │ ├── ACEMunger.scala │ ├── ACETester.scala │ ├── BasicWikifier.scala │ ├── BlikiInterface.java │ ├── FahrniOutputAnalyzer.scala │ ├── FahrniWikifier.scala │ ├── JointQueryDenotationChooser.scala │ ├── Query.scala │ ├── QueryChooser.scala │ ├── WikiAnnotReaderWriter.scala │ ├── WikificationEvaluator.scala │ ├── WikificationFeaturizer.scala │ ├── Wikifier.scala │ ├── WikipediaAuxDB.scala │ ├── WikipediaCategoryDB.scala │ ├── WikipediaInterface.scala │ ├── WikipediaLinkDB.scala │ ├── WikipediaRedirectsDB.scala │ ├── WikipediaTitleGivenSurfaceDB.scala │ └── package.scala │ └── xdistrib │ ├── ComponentFeaturizer.scala │ ├── CorefComputerDistrib.scala │ └── DocumentGraphComponents.scala └── test └── text ├── government.txt └── music.txt /berkeley-entity-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-entity/0022442f51d85ceff092e089aa776d6576706a3a/berkeley-entity-1.0.jar -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "berkeley-entity" 2 | 3 | version := "1" 4 | 5 | scalaVersion := "2.11.2" 6 | 7 | mainClass in assembly := Some("edu.berkeley.nlp.entity.Driver") 8 | 9 | -------------------------------------------------------------------------------- /config/base.conf: -------------------------------------------------------------------------------- 1 | create true 2 | useStandardExecPoolDirStrategy false 3 | overwriteExecDir true 4 | execDir specify_execDir 5 | -------------------------------------------------------------------------------- /lib/BerkeleyParser-1.7.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-entity/0022442f51d85ceff092e089aa776d6576706a3a/lib/BerkeleyParser-1.7.jar -------------------------------------------------------------------------------- /lib/bliki-resources/Messages_en.properties: -------------------------------------------------------------------------------- 1 | wiki.tags.toc.content=Contents 2 | wiki.api.url=http://en.wikipedia.org/w/api.php 3 | wiki.api.media1=Media 4 | wiki.api.media2=Media 5 | wiki.api.special1=Special 6 | wiki.api.special2=Special 7 | wiki.api.talk1=Talk 8 | wiki.api.talk2=Talk 9 | wiki.api.user1=User 10 | wiki.api.user2=User 11 | wiki.api.usertalk1=User_talk 12 | wiki.api.usertalk2=User_talk 13 | wiki.api.meta1=Meta 14 | wiki.api.meta2=Meta 15 | wiki.api.metatalk1=Meta_talk 16 | wiki.api.metatalk2=Meta_talk 17 | wiki.api.image1=Image 18 | wiki.api.image2=File 19 | wiki.api.imagetalk1=Image_talk 20 | wiki.api.imagetalk2=File_talk 21 | wiki.api.mediawiki1=MediaWiki 22 | wiki.api.mediawiki2=MediaWiki 23 | wiki.api.mediawikitalk1=MediaWiki_talk 24 | wiki.api.mediawikitalk2=MediaWiki_talk 25 | wiki.api.template1=Template 26 | wiki.api.template2=Template 27 | wiki.api.templatetalk1=Template_talk 28 | wiki.api.templatetalk2=Template_talk 29 | wiki.api.help1=Help 30 | wiki.api.help2=Help 31 | wiki.api.helptalk1=Help_talk 32 | wiki.api.helptalk2=Help_talk 33 | wiki.api.category1=Category 34 | wiki.api.category2=Category 35 | wiki.api.categorytalk1=Category_talk 36 | wiki.api.categorytalk2=Category_talk -------------------------------------------------------------------------------- /lib/bliki.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-entity/0022442f51d85ceff092e089aa776d6576706a3a/lib/bliki.jar -------------------------------------------------------------------------------- /lib/futile.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-entity/0022442f51d85ceff092e089aa776d6576706a3a/lib/futile.jar -------------------------------------------------------------------------------- /lib/jwi.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-entity/0022442f51d85ceff092e089aa776d6576706a3a/lib/jwi.jar -------------------------------------------------------------------------------- /lib/scala-xml-2.11.0-M4.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-entity/0022442f51d85ceff092e089aa776d6576706a3a/lib/scala-xml-2.11.0-M4.jar -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) 2 | 3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") 4 | -------------------------------------------------------------------------------- /pull-datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir data 4 | 5 | # Number and gender data 6 | rm data/gender.data 7 | wget http://www.cs.utexas.edu/~gdurrett/data/gender.data.tgz 8 | tar -xvf gender.data.tgz 9 | mv gender.data data/ 10 | rm gender.data.tgz 11 | 12 | # Brown clusters 13 | wget http://people.csail.mit.edu/maestro/papers/bllip-clusters.gz 14 | gunzip bllip-clusters.gz 15 | mv bllip-clusters data/ 16 | 17 | # CoNLL scorer 18 | wget http://conll.cemantix.org/download/reference-coreference-scorers.v7.tar.gz 19 | tar -xvf reference-coreference-scorers.v7.tar.gz 20 | mkdir scorer 21 | mv reference-coreference-scorers/v7/ scorer 22 | cp scorer/v7/lib/CorScorer.pm lib/ 23 | cp -r scorer/v7/lib/Algorithm lib/ 24 | rm -rf reference-coreference-scorers* 25 | 26 | -------------------------------------------------------------------------------- /run-test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | jarpath=berkeley-entity-1.0.jar 4 | 5 | mkdir test/scratch 6 | mkdir test/preprocessed 7 | mkdir test/coref 8 | mkdir test/corefner 9 | mkdir test/joint 10 | 11 | # Preprocess the data, no NER 12 | if [ ! -f test/preprocessed/government.txt ]; then 13 | echo "RUNNING PREPROCESSING" 14 | java -Xmx2g -cp $jarpath edu.berkeley.nlp.entity.preprocess.PreprocessingDriver ++config/base.conf -execDir test/scratch/preprocess -inputDir test/text -outputDir test/preprocessed 15 | else 16 | echo "Skipping preprocessing..." 17 | fi 18 | 19 | # The following commands demonstrate running: 20 | # 1) the coref system in isolation 21 | # 2) the coref + NER system 22 | # 3) the full joint system 23 | # Note that the joint system does not depend on either of the earlier two; 24 | # this is merely meant to demonstrate possible modes of operation. 25 | 26 | # Run the coreference system 27 | if [ ! 
-f test/coref/government.txt-0.pred_conll ]; then 28 | echo "RUNNING COREF" 29 | java -Xmx2g -cp $jarpath edu.berkeley.nlp.entity.Driver ++config/base.conf -execDir test/scratch/coref -mode COREF_PREDICT -modelPath models/coref-onto.ser.gz -testPath test/preprocessed -outputPath test/coref -corefDocSuffix "" 30 | else 31 | echo "Skipping coref..." 32 | fi 33 | 34 | # Run the coref+NER system 35 | if [ ! -f test/corefner/output.conll ]; then 36 | echo "RUNNING COREF+NER" 37 | java -Xmx6g -cp $jarpath edu.berkeley.nlp.entity.Driver ++config/base.conf -execDir test/scratch/corefner -mode PREDICT -modelPath models/corefner-onto.ser.gz -testPath test/preprocessed 38 | cp test/scratch/corefner/output*.conll test/corefner/ 39 | else 40 | echo "Skipping coref+ner..." 41 | fi 42 | 43 | # Run the full joint system 44 | # Now run the joint prediction 45 | if [ ! -f test/joint/output.conll ]; then 46 | echo "RUNNING COREF+NER+WIKI" 47 | # First, need to extract the subset of Wikipedia relevant to these documents. We have already 48 | # done this for the test documents, so the extraction step is commented out below. Here is the command used: 49 | #java -Xmx4g -cp $jarpath:lib/bliki-resources edu.berkeley.nlp.entity.wiki.WikipediaInterface -datasetPaths test/preprocessed -docSuffix "" -wikipediaDumpPath data/wikipedia/enwiki-latest-pages-articles.xml -outputPath models/wiki-db-test.ser.gz 50 | java -Xmx8g -cp $jarpath edu.berkeley.nlp.entity.Driver ++config/base.conf -execDir test/scratch/joint -mode PREDICT -modelPath models/joint-onto.ser.gz -testPath test/preprocessed -wikipediaPath models/wiki-db-test.ser.gz 51 | cp test/scratch/joint/output*.conll test/joint/ 52 | else 53 | echo "Skipping coref+ner+wiki..." 54 | fi 55 | 56 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/Chunk.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity 2 | 3 | // Chunks are semi-inclusive intervals, i.e. half-open spans [start, end). 4 | @SerialVersionUID(1L) 5 | case class Chunk[T](val start: Int, 6 | val end: Int, 7 | val label: T); 8 | 9 | object Chunk { 10 | def seqify[T](chunk: Chunk[T]): Chunk[Seq[T]] = new Chunk(chunk.start, chunk.end, Seq(chunk.label)); 11 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ConllDoc.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity 2 | import edu.berkeley.nlp.futile.syntax.Tree 3 | 4 | 5 | case class ConllDocJustWords(val docID: String, 6 | val docPartNo: Int, 7 | val words: Seq[Seq[String]]) { 8 | def wordsArrs = words.map(_.toArray).toArray; 9 | } 10 | 11 | // rawText should only be used to save trouble when outputting the document 12 | // for scoring; never at any other time!
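// Editor's illustration (not part of the original source): a hypothetical usage sketch of the
// NER-chunk lookup defined further down in this file, assuming an existing ConllDoc named `doc`.
// Because chunks are half-open, Chunk(3, 6, "ORG") covers tokens 3, 4 and 5 and is returned only
// for head indices in that range:
//   val maybeNer: Option[Chunk[String]] = doc.getCorrespondingNERChunk(sentIdx = 0, headIdx = 4)
//   maybeNer.foreach(c => println(c.label + " over tokens [" + c.start + ", " + c.end + ")"))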
13 | case class ConllDoc(val docID: String, 14 | val docPartNo: Int, 15 | val words: Seq[Seq[String]], 16 | val pos: Seq[Seq[String]], 17 | val trees: Seq[DepConstTree], 18 | val nerChunks: Seq[Seq[Chunk[String]]], 19 | val corefChunks: Seq[Seq[Chunk[Int]]], 20 | val speakers: Seq[Seq[String]]) { 21 | 22 | val numSents = words.size; 23 | 24 | def uid = docID -> docPartNo; 25 | 26 | def fileName = { 27 | if (docID.contains("/")) { 28 | docID.substring(docID.lastIndexOf("/") + 1); 29 | } else { 30 | docID; 31 | } 32 | } 33 | 34 | def printableDocName = docID + " (part " + docPartNo + ")"; 35 | 36 | def isConversation = docID.startsWith("bc") || docID.startsWith("wb"); 37 | 38 | def getCorrespondingNERChunk(sentIdx: Int, headIdx: Int): Option[Chunk[String]] = ConllDoc.getCorrespondingNERChunk(nerChunks(sentIdx), headIdx); 39 | } 40 | 41 | object ConllDoc { 42 | 43 | def getCorrespondingNERChunk(nerChunks: Seq[Chunk[String]], headIdx: Int): Option[Chunk[String]] = { 44 | val maybeChunk = nerChunks.filter(chunk => chunk.start <= headIdx && headIdx < chunk.end); 45 | if (maybeChunk.size >= 1) Some(maybeChunk.head) else None; 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/bp/CombinatorialIterator.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.bp 2 | 3 | class CombinatorialIterator(domainSizes: Array[Int]) { 4 | val currComb = domainSizes.map(i => 0); 5 | var isStarting = true; 6 | 7 | def next = { 8 | if (!isStarting) { 9 | // Need to advance 10 | var idx = 0; 11 | var isAdvanced = false; 12 | while (!isAdvanced && idx < domainSizes.size) { 13 | if (currComb(idx) < domainSizes(idx) - 1) { 14 | currComb(idx) += 1; 15 | isAdvanced = true; 16 | } else { 17 | currComb(idx) = 0; 18 | idx += 1; 19 | } 20 | } 21 | } 22 | isStarting = false; 23 | currComb; 24 | } 25 | 26 | def isDone = !isStarting && (0 until currComb.size).map(i => currComb(i) == domainSizes(i) - 1).reduce(_ && _); 27 | 28 | def hasNext = !isDone; 29 | } 30 | 31 | object CombinatorialIterator { 32 | def main(args: Array[String]) { 33 | val combTest = new CombinatorialIterator(Array(3, 2, 1, 4)); 34 | while (combTest.hasNext) { 35 | println(combTest.next.toSeq); 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/bp/Domain.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.bp 2 | 3 | case class Domain[T](val entries: Array[T]) { 4 | def size = entries.size 5 | 6 | def indexOf(entry: T) = entries.indexOf(entry); 7 | 8 | def value(idx: Int): T = entries(idx); 9 | 10 | override def toString() = entries.foldLeft("")((str, entry) => str + entry + " ").dropRight(1); 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/bp/Node.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.bp 2 | import scala.collection.mutable.ArrayBuffer 3 | import edu.berkeley.nlp.futile.util.Logger 4 | import edu.berkeley.nlp.entity.GUtil 5 | 6 | class Node[T](val domain: Domain[T]) { 7 | var factors = new ArrayBuffer[Factor](); 8 | var receivedMessages: Array[Array[Double]] = null; 9 | var sentMessages: Array[Array[Double]] = null; 10 | var cachedBeliefsOrMarginals: Array[Double] = 
Array.fill(domain.size)(0.0); 11 | 12 | def registerFactor(factor: Factor) { 13 | factors += factor; 14 | } 15 | 16 | // TODO: Do I need this null thing? 17 | def initializeReceivedMessagesUniform() { 18 | if (receivedMessages == null) { 19 | receivedMessages = new Array[Array[Double]](factors.size); 20 | } else { 21 | for (i <- 0 until receivedMessages.size) { 22 | receivedMessages(i) = null; 23 | } 24 | } 25 | } 26 | 27 | def clearSentMessages() { 28 | this.sentMessages = null; 29 | } 30 | 31 | // This is just here so we can let things be null...At some point, it was a problem because 32 | // the received messages remember which factors sent them, so clearing them for some reason 33 | // caused problems (maybe writing the value 1.0 was problematic when we weren't clearing the 34 | // received messages on the other end?). Can probably get rid of this somehow and just do the 35 | // obvious thing of initializing messages to 1.0. 36 | def receivedMessageValue(i: Int, j: Int): Double = { 37 | if (receivedMessages(i) == null) { 38 | 1.0; 39 | } else { 40 | receivedMessages(i)(j); 41 | } 42 | } 43 | 44 | def receiveMessage(factor: Factor, message: Array[Double]) { 45 | // Lots of checks on well-formedness 46 | require(receivedMessages != null); 47 | require(message.size == domain.size); 48 | var messageIdx = 0; 49 | var total = 0.0 50 | // Message can contain some zeroes but can't be all zeroes 51 | while (messageIdx < message.size) { 52 | if (message(messageIdx).isNaN() || message(messageIdx).isInfinite) { 53 | Logger.logss("For domain: " + domain + ", bad received message: " + message.toSeq + " from " + factor.getClass()); 54 | Logger.logss("Previous message: " + receivedMessages(factors.indexOf(factor)).toSeq); 55 | require(false); 56 | } 57 | total += message(messageIdx); 58 | messageIdx += 1; 59 | } 60 | if (total == 0) { 61 | Logger.logss("For domain: " + domain + ", bad received message: " + message.toSeq + " from " + factor.getClass()); 62 | Logger.logss("Previous message: " + receivedMessages(factors.indexOf(factor)).toSeq); 63 | require(false) 64 | } 65 | // This is what the method actually does 66 | val idx = factors.indexOf(factor); 67 | require(idx != -1 && idx < receivedMessages.size); 68 | receivedMessages(idx) = message; 69 | } 70 | 71 | def sendMessages() { 72 | sendMessages(1.0); 73 | } 74 | 75 | def sendMessages(messageMultiplier: Double) { 76 | sendMessagesUseLogSpace(messageMultiplier); 77 | } 78 | 79 | // Received messages get exponentiated 80 | def sendMessagesUseLogSpace(messageMultiplier: Double) { 81 | for (i <- 0 until cachedBeliefsOrMarginals.size) { 82 | cachedBeliefsOrMarginals(i) = 0.0; 83 | } 84 | require(receivedMessages.size == factors.size); 85 | for (i <- 0 until receivedMessages.size) { 86 | var j = 0; 87 | while (j < cachedBeliefsOrMarginals.size) { 88 | cachedBeliefsOrMarginals(j) += Math.log(receivedMessageValue(i, j)) * messageMultiplier; 89 | j += 1; 90 | } 91 | } 92 | GUtil.logNormalizei(cachedBeliefsOrMarginals); 93 | require(!GUtil.containsNaN(cachedBeliefsOrMarginals), cachedBeliefsOrMarginals.toSeq) 94 | for (i <- 0 until cachedBeliefsOrMarginals.size) { 95 | cachedBeliefsOrMarginals(i) = Math.exp(cachedBeliefsOrMarginals(i)); 96 | } 97 | require(!GUtil.containsNaN(cachedBeliefsOrMarginals), cachedBeliefsOrMarginals.toSeq) 98 | if (sentMessages == null) { 99 | sentMessages = new Array[Array[Double]](factors.size); 100 | } 101 | for (i <- 0 until factors.length) { 102 | if (sentMessages(i) == null) { 103 | sentMessages(i) = new 
Array[Double](domain.size); 104 | } 105 | var j = 0; 106 | var normalizer = 0.0; 107 | while (j < domain.size) { 108 | val rmVal = receivedMessageValue(i, j); 109 | if (rmVal == 0) { 110 | sentMessages(i)(j) = 0; 111 | } else { 112 | val msgVal = cachedBeliefsOrMarginals(j)/rmVal; 113 | normalizer += msgVal; 114 | sentMessages(i)(j) = msgVal; 115 | } 116 | j += 1; 117 | } 118 | require(normalizer > 0, domain.entries.toSeq); 119 | j = 0; 120 | while (j < domain.size) { 121 | sentMessages(i)(j) /= normalizer; 122 | j += 1; 123 | } 124 | factors(i).receiveMessage(this, sentMessages(i)); 125 | } 126 | } 127 | 128 | def getMarginals(): Array[Double] = { 129 | getMarginalsUseLogSpace(1.0); 130 | } 131 | 132 | def getMarginals(messageMultiplier: Double): Array[Double] = { 133 | getMarginalsUseLogSpace(messageMultiplier); 134 | } 135 | 136 | def getMarginalsUseLogSpace(messageMultiplier: Double): Array[Double] = { 137 | for (i <- 0 until cachedBeliefsOrMarginals.size) { 138 | cachedBeliefsOrMarginals(i) = 0.0; 139 | } 140 | for (i <- 0 until cachedBeliefsOrMarginals.size) { 141 | for (j <- 0 until receivedMessages.size) { 142 | cachedBeliefsOrMarginals(i) += Math.log(receivedMessageValue(j, i)) * messageMultiplier; 143 | } 144 | } 145 | // if (domain.size < 20) { 146 | // Logger.logss("Node with domain " + domain.entries.toSeq + " receiving messages from " + factors.size + " factors"); 147 | // receivedMessages.foreach(msg => Logger.logss(if (msg != null) msg.toSeq else "null")); 148 | // } 149 | GUtil.logNormalizei(cachedBeliefsOrMarginals); 150 | for (i <- 0 until cachedBeliefsOrMarginals.size) { 151 | cachedBeliefsOrMarginals(i) = Math.exp(cachedBeliefsOrMarginals(i)); 152 | } 153 | cachedBeliefsOrMarginals 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/AuxiliaryFeaturizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | trait AuxiliaryFeaturizer extends Serializable { 4 | def featurize(docGraph: DocumentGraph, currMentIdx: Int, antecedentIdx: Int): Seq[String]; 5 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/ConjFeatures.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref; 2 | 3 | 4 | public enum ConjFeatures { 5 | NONE, TYPE, TYPE_OR_RAW_PRON, TYPE_OR_CANONICAL_PRON, SEMCLASS_OR_CANONICAL_PRON, SEMCLASS_OR_CANONICAL_PRON_COORD, 6 | NER_OR_CANONICAL_PRON, NERFINE_OR_CANONICAL_PRON, SEMCLASS_NER_OR_CANONICAL_PRON, 7 | CUSTOM_OR_CANONICAL_PRON, CUSTOM_NER_OR_CANONICAL_PRON, CUSTOM_NERMED_OR_CANONICAL_PRON, CUSTOM_NERFINE_OR_CANONICAL_PRON; 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/ConjScheme.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref; 2 | 3 | 4 | public enum ConjScheme { 5 | COARSE_CURRENT_BOTH, BOTH, COARSE_BOTH, COARSE_BOTH_BLACKLIST, COARSE_BOTH_WHITELIST; 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/ConjType.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref; 2 | 3 | // REALLY_ONLY_FINE = don't even add 
the completely unconjoined features 4 | public enum ConjType { 5 | NONE, TYPE, TYPE_OR_RAW_PRON, CANONICAL_NOPRONPRON, 6 | CANONICAL, CANONICAL_AND_SEM, 7 | CANONICAL_ONLY_FINE, CANONICAL_AND_SEM_ONLY_FINE, 8 | CANONICAL_AND_SEM_REALLY_ONLY_FINE; 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/CorefConllScorer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import java.io.File 3 | import java.io.PrintWriter 4 | import java.util.regex.Pattern 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | import scala.collection.mutable.HashMap 8 | import scala.sys.process.stringSeqToProcess 9 | import scala.sys.process.Process 10 | import edu.berkeley.nlp.futile.util.Logger 11 | import edu.berkeley.nlp.entity.Driver 12 | import edu.berkeley.nlp.entity.ConllDoc 13 | import edu.berkeley.nlp.entity.ConllDocWriter 14 | 15 | class CorefConllScorer(val conllEvalScriptPath: String) { 16 | 17 | def renderFinalScore(conllDocs: Seq[ConllDoc], rawPredClusterings: Seq[OrderedClusteringBound], goldClusterings: Seq[OrderedClusteringBound]) = { 18 | val summary = score(conllDocs, rawPredClusterings, goldClusterings, true); 19 | CorefConllScorer.processConllString(summary, false); 20 | } 21 | 22 | def renderSuffStats(conllDoc: ConllDoc, rawPredClustering: OrderedClusteringBound, goldClustering: OrderedClusteringBound) = { 23 | val summary = score(Seq(conllDoc), Seq(rawPredClustering), Seq(goldClustering), false); 24 | CorefConllScorer.processConllString(summary, true); 25 | } 26 | 27 | def score(conllDocs: Seq[ConllDoc], rawPredClusterings: Seq[OrderedClusteringBound], goldClusterings: Seq[OrderedClusteringBound], saveTempFiles: Boolean) = { 28 | val predClusterings = if (Driver.doConllPostprocessing) rawPredClusterings.map(_.postprocessForConll()) else rawPredClusterings; 29 | // var predFile = File.createTempFile("temp", ".conll"); 30 | val (predFile, goldFile) = if (Driver.conllOutputDir != "" && saveTempFiles) { 31 | val pFile = File.createTempFile("temp", ".conll", new File(Driver.conllOutputDir)); 32 | val gFile = new File(pFile.getAbsolutePath() + "-gold"); 33 | Logger.logss("PRED FILE: " + pFile.getAbsolutePath()); 34 | Logger.logss("GOLD FILE: " + gFile.getAbsolutePath()); 35 | Logger.logss("To score, run:"); 36 | Logger.logss("perl scorer.pl all " + gFile.getAbsolutePath() + " " + pFile.getAbsolutePath() + " none"); 37 | (pFile, gFile); 38 | } else { 39 | val pFile = File.createTempFile("temp", ".conll"); 40 | val gFile = new File(pFile.getAbsolutePath() + "-gold"); 41 | pFile.deleteOnExit(); 42 | gFile.deleteOnExit(); 43 | (pFile, gFile); 44 | } 45 | val predWriter = new PrintWriter(predFile); 46 | val goldWriter = new PrintWriter(goldFile); 47 | for (i <- 0 until conllDocs.size) { 48 | ConllDocWriter.writeDoc(predWriter, conllDocs(i), predClusterings(i)); 49 | ConllDocWriter.writeDoc(goldWriter, conllDocs(i), goldClusterings(i)); 50 | } 51 | // Flush and close the buffers 52 | predWriter.close(); 53 | goldWriter.close(); 54 | Logger.logss("Running scoring program..."); 55 | import edu.berkeley.nlp.entity.Driver; 56 | // Build and run the process for the CoNLL eval script script 57 | import scala.sys.process._ 58 | val output = Process(Seq(conllEvalScriptPath, "all", goldFile.getAbsolutePath(), predFile.getAbsolutePath(), "none")).lines; 59 | Logger.logss("Scoring program complete!"); 60 | output.reduce(_ + "\n" + _); 61 | } 62 | 
} 63 | 64 | object CorefConllScorer { 65 | 66 | def processConllString(summary: String, renderSuffStats: Boolean) = { 67 | val pr = Pattern.compile("Coreference:.*\\(([0-9.]+) / ([0-9.]+)\\).*\\(([0-9.]+) / ([0-9.]+)\\)"); 68 | val prMatcher = pr.matcher(summary); 69 | var prCount = 0; 70 | var (mucPNum, mucPDenom, mucRNum, mucRDenom) = (0.0, 0.0, 0.0, 0.0); 71 | var (bcubPNum, bcubPDenom, bcubRNum, bcubRDenom) = (0.0, 0.0, 0.0, 0.0); 72 | var (ceafePNum, ceafePDenom, ceafeRNum, ceafeRDenom) = (0.0, 0.0, 0.0, 0.0); 73 | // Four matches: MUC, B-cubed, CEAFM, CEAFE (BLANC doesn't match because of different formatting) 74 | while (prMatcher.find()) { 75 | if (prCount == 0) { 76 | mucRNum = prMatcher.group(1).toDouble; 77 | mucRDenom = prMatcher.group(2).toDouble; 78 | mucPNum = prMatcher.group(3).toDouble; 79 | mucPDenom = prMatcher.group(4).toDouble; 80 | } 81 | if (prCount == 1) { 82 | bcubRNum = prMatcher.group(1).toDouble; 83 | bcubRDenom = prMatcher.group(2).toDouble; 84 | bcubPNum = prMatcher.group(3).toDouble; 85 | bcubPDenom = prMatcher.group(4).toDouble; 86 | } 87 | if (prCount == 3) { 88 | ceafeRNum = prMatcher.group(1).toDouble; 89 | ceafeRDenom = prMatcher.group(2).toDouble; 90 | ceafePNum = prMatcher.group(3).toDouble; 91 | ceafePDenom = prMatcher.group(4).toDouble; 92 | } 93 | prCount += 1; 94 | } 95 | val mucP = mucPNum/mucPDenom * 100.0; 96 | val mucR = mucRNum/mucRDenom * 100.0; 97 | val mucF = 2 * mucP * mucR/(mucP + mucR); 98 | val bcubP = bcubPNum/bcubPDenom * 100.0; 99 | val bcubR = bcubRNum/bcubRDenom * 100.0; 100 | val bcubF = 2 * bcubP * bcubR/(bcubP + bcubR); 101 | val ceafeP = ceafePNum/ceafePDenom * 100.0; 102 | val ceafeR = ceafeRNum/ceafeRDenom * 100.0; 103 | val ceafeF = 2 * ceafeP * ceafeR/(ceafeP + ceafeR); 104 | val avg = (mucF + bcubF + ceafeF)/3.0; 105 | if (renderSuffStats) { 106 | "MUC/BCUB/CEAFE P/R N/D:\t" + mucPNum + "\t" + mucPDenom + "\t" + mucRNum + "\t" + mucRDenom + "\t" + bcubPNum + "\t" + bcubPDenom + "\t" + bcubRNum + "\t" + bcubRDenom + "\t" + ceafePNum + "\t" + ceafePDenom + "\t" + ceafeRNum + "\t" +ceafeRDenom; 107 | } else { 108 | "MUC P-R-F1, BCUB P-R-F1, CEAFE P-R-F1, Average:\t" + fmt(mucP) + "\t" + fmt(mucR) + "\t" + fmt(mucF) + "\t" + fmt(bcubP) + "\t" + fmt(bcubR) + "\t" + fmt(bcubF) + "\t" + fmt(ceafeP) + "\t" + fmt(ceafeR) + "\t" + fmt(ceafeF) + "\t" + fmt(avg) + "\n" + 109 | "MUC = " + fmt(mucF) + ", BCUB = " + fmt(bcubF) + ", CEAFE = " + fmt(ceafeF) + ", AVG = " + fmt(avg); 110 | } 111 | } 112 | 113 | private def fmt(d: Double): String = { 114 | val str = "" + (d + 0.005); 115 | str.substring(0, Math.min(str.length(), str.indexOf(".") + 3)); 116 | } 117 | 118 | def main(args: Array[String]) { 119 | import scala.sys.process._ 120 | val cmd = Seq("ls", "clean-data/"); 121 | println(cmd.lines.toIndexedSeq); 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/CorefDoc.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import java.io.File 3 | import scala.collection.JavaConverters.asScalaBufferConverter 4 | import scala.collection.JavaConverters.mapAsScalaMapConverter 5 | import scala.collection.mutable.HashSet 6 | import scala.collection.mutable.ArrayBuffer 7 | import scala.collection.mutable.HashMap 8 | import edu.berkeley.nlp.entity.lang.Language 9 | import edu.berkeley.nlp.futile.syntax.Trees.PennTreeRenderer 10 | import 
edu.berkeley.nlp.futile.util.Counter 11 | import edu.berkeley.nlp.futile.util.Logger 12 | import edu.berkeley.nlp.entity.GUtil 13 | import edu.berkeley.nlp.entity.ConllDoc 14 | 15 | case class CorefDoc(val rawDoc: ConllDoc, 16 | val goldMentions: Seq[Mention], 17 | val goldClustering: OrderedClustering, 18 | val predMentions: Seq[Mention]) { 19 | 20 | var oraclePredOrderedClustering: OrderedClustering = null; 21 | 22 | def numPredMents = predMentions.size; 23 | 24 | def isInGold(predMent: Mention): Boolean = { 25 | goldMentions.filter(goldMent => goldMent.sentIdx == predMent.sentIdx && goldMent.startIdx == predMent.startIdx && goldMent.endIdx == predMent.endIdx).size > 0; 26 | } 27 | 28 | def getGoldMentionF1SuffStats: (Int, Int, Int) = { 29 | (predMentions.filter(isInGold(_)).size, predMentions.size, goldMentions.size) 30 | } 31 | 32 | /** 33 | * Determines and caches an "oracle predicted clustering." For each predicted mention: 34 | * --If that mention does not have a corresponding gold mention (start and end indices match): 35 | * --Put the current mention in its own cluster. 36 | * --If that mention does have a corresponding gold mention: 37 | * --Fetch that mention's antecedents (if any) 38 | * --Choose the first with a corresponding predicted mention (if any) 39 | * --Assign this mention as the current mention's parent. 40 | */ 41 | def getOraclePredClustering = { 42 | if (oraclePredOrderedClustering == null) { 43 | val predToGoldIdxMap = new HashMap[Int,Int](); 44 | val goldToPredIdxMap = new HashMap[Int,Int](); 45 | for (pIdx <- 0 until predMentions.size) { 46 | for (gIdx <- 0 until goldMentions.size) { 47 | val pMent = predMentions(pIdx); 48 | val gMent = goldMentions(gIdx); 49 | if (pMent.sentIdx == gMent.sentIdx && pMent.startIdx == gMent.startIdx && pMent.endIdx == gMent.endIdx) { 50 | predToGoldIdxMap.put(pIdx, gIdx); 51 | goldToPredIdxMap.put(gIdx, pIdx); 52 | } 53 | } 54 | } 55 | val oracleClusterIds = new ArrayBuffer[Int]; 56 | var nextClusterId = 0; 57 | for (predIdx <- 0 until predMentions.size) { 58 | // Fetch the parent 59 | var parent = -1; 60 | if (predToGoldIdxMap.contains(predIdx)) { 61 | val correspondingGoldIdx = predToGoldIdxMap(predIdx); 62 | // Find the antecedents of the corresponding gold mention 63 | val goldAntecedentIdxs = goldClustering.getAllAntecedents(correspondingGoldIdx); 64 | // For each one, do a weird data sanitizing check, then try to find a corresponding 65 | // predicted mention to use as the predicted parent 66 | for (goldAntecedentIdx <- goldAntecedentIdxs.reverse) { 67 | val correspondingGold = goldMentions(correspondingGoldIdx); 68 | val goldAntecedent = goldMentions(goldAntecedentIdx); 69 | // wsj_0990 has some duplicate gold mentions, need to handle these... 
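// (Concretely: a gold "antecedent" whose span is identical to the current gold mention is skipped below rather than used as a parent.)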
70 | val sameMention = goldAntecedent.sentIdx == correspondingGold.sentIdx && goldAntecedent.startIdx == correspondingGold.startIdx && goldAntecedent.endIdx == correspondingGold.endIdx 71 | if (!sameMention && goldToPredIdxMap.contains(goldAntecedentIdx)) { 72 | val predAntecedentIdx = goldToPredIdxMap(goldAntecedentIdx) 73 | if (predAntecedentIdx >= predIdx) { 74 | val ment = predMentions(predIdx); 75 | val predAntecedent = predMentions(predAntecedentIdx); 76 | Logger.logss("Monotonicity violated:\n" + 77 | "Antecedent(" + predAntecedentIdx + "): " + predAntecedent.startIdx + " " + predAntecedent.endIdx + " " + predAntecedent.headIdx + "\n" + 78 | "Current(" + predMentions.indexOf(ment) + "): " + ment.startIdx + " " + ment.endIdx + " " + ment.headIdx + "\n" + 79 | "Gold antecedent(" + goldMentions.indexOf(goldAntecedent) + "): " + goldAntecedent.startIdx + " " + goldAntecedent.endIdx + " " + goldAntecedent.headIdx + "\n" + 80 | "Gold current(" + goldMentions.indexOf(correspondingGold) + "): " + correspondingGold.startIdx + " " + correspondingGold.endIdx + " " + correspondingGold.headIdx); 81 | Logger.logss("Setting parent to -1..."); 82 | parent = -1; 83 | } else { 84 | parent = predAntecedentIdx 85 | } 86 | } 87 | } 88 | } 89 | // Now compute the oracle cluster ID 90 | val clusterId = if (parent == -1) { 91 | nextClusterId += 1; 92 | nextClusterId - 1; 93 | } else { 94 | oracleClusterIds(parent); 95 | } 96 | oracleClusterIds += clusterId; 97 | } 98 | oraclePredOrderedClustering = OrderedClustering.createFromClusterIds(oracleClusterIds); 99 | } 100 | oraclePredOrderedClustering 101 | } 102 | } 103 | 104 | object CorefDoc { 105 | def displayMentionPRF1(docs: Seq[CorefDoc]) { 106 | val suffStats = docs.map(_.getGoldMentionF1SuffStats).reduce((a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3)); 107 | Logger.logss(GUtil.renderPRF1(suffStats._1, suffStats._2, suffStats._3)) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/CorefDocAssemblerACE.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | import edu.berkeley.nlp.entity.lang.EnglishCorefLanguagePack 4 | import edu.berkeley.nlp.futile.util.Logger 5 | import scala.collection.mutable.ArrayBuffer 6 | import edu.berkeley.nlp.entity.wiki.ACEMunger 7 | import java.io.File 8 | import edu.berkeley.nlp.entity.ConllDoc 9 | 10 | class CorefDocAssemblerACE(dirPath: String) { 11 | 12 | val langPack = new EnglishCorefLanguagePack() 13 | 14 | def createCorefDoc(rawDoc: ConllDoc, propertyComputer: MentionPropertyComputer): CorefDoc = { 15 | val (goldMentions, goldClustering) = CorefDocAssembler.extractGoldMentions(rawDoc, propertyComputer, langPack); 16 | if (goldMentions.size == 0) { 17 | Logger.logss("WARNING: no gold mentions on document " + rawDoc.printableDocName); 18 | } 19 | // TODO: Load pred mentions here 20 | val mentSpansEachSent: Seq[Seq[(Int,Int)]] = ACEMunger.getPredMentionsBySentence(new File(dirPath + "/" + rawDoc.docID)); 21 | val predMentions = new ArrayBuffer[Mention](); 22 | for (sentIdx <- 0 until mentSpansEachSent.size; mentSpan <- mentSpansEachSent(sentIdx)) { 23 | val headIdx = rawDoc.trees(sentIdx).getSpanHead(mentSpan._1, mentSpan._2); 24 | predMentions += Mention.createMentionComputeProperties(rawDoc, predMentions.size, sentIdx, mentSpan._1, mentSpan._2, headIdx, Seq(headIdx), false, propertyComputer, langPack) 25 | } 26 | 27 | 
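// Debug output: for each sentence, log the gold and predicted mention spans and their words for comparison.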
Logger.logss(rawDoc.docID); 28 | for (i <- 0 until rawDoc.numSents) { 29 | Logger.logss(goldMentions.filter(_.sentIdx == i).map(ment => ment.startIdx -> ment.endIdx)); 30 | Logger.logss(goldMentions.filter(_.sentIdx == i).map(ment => ment.words)); 31 | Logger.logss(predMentions.filter(_.sentIdx == i).map(ment => ment.startIdx -> ment.endIdx)); 32 | Logger.logss(predMentions.filter(_.sentIdx == i).map(ment => ment.words)); 33 | } 34 | 35 | new CorefDoc(rawDoc, goldMentions, goldClustering, predMentions) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/CorefMaskMaker.scala: -------------------------------------------------------------------------------- 1 | //package edu.berkeley.nlp.entity.coref 2 | // 3 | //import scala.collection.mutable.ArrayBuffer 4 | //import scala.collection.mutable.HashMap 5 | //import scala.util.Random 6 | //import edu.berkeley.nlp.entity.Driver; 7 | //import edu.berkeley.nlp.entity.GUtil 8 | //import edu.berkeley.nlp.entity.sem.BasicWordNetSemClasser 9 | //import edu.berkeley.nlp.entity.sem.QueryCountsBundle 10 | //import edu.berkeley.nlp.entity.wiki.WikipediaInterface 11 | //import edu.berkeley.nlp.futile.fig.basic.Indexer 12 | //import edu.berkeley.nlp.futile.util.Logger 13 | // 14 | //object CorefMaskMaker { 15 | // 16 | //} 17 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/DocumentInferencer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import edu.berkeley.nlp.futile.util.Logger 3 | import edu.berkeley.nlp.futile.fig.basic.Indexer 4 | 5 | trait DocumentInferencer { 6 | 7 | def getInitialWeightVector(featureIndexer: Indexer[String]): Array[Float]; 8 | 9 | def computeLikelihood(docGraph: DocumentGraph, 10 | pairwiseScorer: PairwiseScorer, 11 | lossFcn: (CorefDoc, Int, Int) => Float): Float; 12 | 13 | def addUnregularizedStochasticGradient(docGraph: DocumentGraph, 14 | pairwiseScorer: PairwiseScorer, 15 | lossFcn: (CorefDoc, Int, Int) => Float, 16 | gradient: Array[Float]); 17 | 18 | def viterbiDecode(docGraph: DocumentGraph, pairwiseScorer: PairwiseScorer): Array[Int]; 19 | 20 | def finishPrintStats(); 21 | 22 | def viterbiDecodeFormClustering(docGraph: DocumentGraph, pairwiseScorer: PairwiseScorer): (Array[Int], OrderedClustering) = { 23 | val predBackptrs = viterbiDecode(docGraph, pairwiseScorer); 24 | (predBackptrs, OrderedClustering.createFromBackpointers(predBackptrs)); 25 | } 26 | 27 | def viterbiDecodeAll(docGraphs: Seq[DocumentGraph], pairwiseScorer: PairwiseScorer): Array[Array[Int]] = { 28 | val allPredBackptrs = new Array[Array[Int]](docGraphs.size); 29 | for (i <- 0 until docGraphs.size) { 30 | val docGraph = docGraphs(i); 31 | Logger.logs("Decoding " + i); 32 | val predBackptrs = viterbiDecode(docGraph, pairwiseScorer); 33 | allPredBackptrs(i) = predBackptrs; 34 | } 35 | allPredBackptrs; 36 | } 37 | 38 | def viterbiDecodeAllFormClusterings(docGraphs: Seq[DocumentGraph], pairwiseScorer: PairwiseScorer): (Array[Array[Int]], Array[OrderedClustering]) = { 39 | val allPredBackptrs = viterbiDecodeAll(docGraphs, pairwiseScorer); 40 | val allPredClusteringsSeq = (0 until docGraphs.size).map(i => OrderedClustering.createFromBackpointers(allPredBackptrs(i))); 41 | (allPredBackptrs, allPredClusteringsSeq.toArray) 42 | } 43 | } 44 | 
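// Editor's illustration (not part of the original source): a minimal sketch of a DocumentInferencer
// implementation that ignores the learned scores and marks every mention as starting its own cluster.
// The class name is hypothetical; it relies only on members visible above (assuming this file's package
// and imports) and on the convention, also used by DocumentInferencerOracle further below, that a mention
// whose backpointer is its own index opens a new cluster.
//
// class DocumentInferencerAllSingletons extends DocumentInferencer {
//   def getInitialWeightVector(featureIndexer: Indexer[String]): Array[Float] =
//     Array.fill(featureIndexer.size())(0.0F);
//   // No training signal: the likelihood is constant and the gradient is left untouched.
//   def computeLikelihood(docGraph: DocumentGraph,
//                         pairwiseScorer: PairwiseScorer,
//                         lossFcn: (CorefDoc, Int, Int) => Float): Float = 0.0F;
//   def addUnregularizedStochasticGradient(docGraph: DocumentGraph,
//                                          pairwiseScorer: PairwiseScorer,
//                                          lossFcn: (CorefDoc, Int, Int) => Float,
//                                          gradient: Array[Float]) = {}
//   // Every mention points to itself, so viterbiDecodeFormClustering yields all-singleton clusters.
//   def viterbiDecode(docGraph: DocumentGraph, pairwiseScorer: PairwiseScorer): Array[Int] =
//     Array.tabulate(docGraph.size)(i => i);
//   def finishPrintStats() = {}
// }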
-------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/DocumentInferencerBasic.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import edu.berkeley.nlp.futile.fig.basic.Indexer 3 | import edu.berkeley.nlp.entity.Driver; 4 | import edu.berkeley.nlp.entity.GUtil 5 | import edu.berkeley.nlp.futile.util.Logger 6 | import edu.berkeley.nlp.futile.util.Counter 7 | import scala.collection.JavaConverters._ 8 | 9 | class DocumentInferencerBasic extends DocumentInferencer { 10 | 11 | def getInitialWeightVector(featureIndexer: Indexer[String]): Array[Float] = Array.fill(featureIndexer.size())(0.0F); 12 | 13 | /** 14 | * N.B. always returns a reference to the same matrix, so don't call twice in a row and 15 | * attempt to use the results of both computations 16 | */ 17 | def computeMarginals(docGraph: DocumentGraph, 18 | gold: Boolean, 19 | lossFcn: (CorefDoc, Int, Int) => Float, 20 | pairwiseScorer: PairwiseScorer): Array[Array[Float]] = { 21 | computeMarginals(docGraph, gold, lossFcn, docGraph.featurizeIndexAndScoreNonPrunedUseCache(pairwiseScorer)._2) 22 | } 23 | 24 | def computeMarginals(docGraph: DocumentGraph, 25 | gold: Boolean, 26 | lossFcn: (CorefDoc, Int, Int) => Float, 27 | scoresChart: Array[Array[Float]]): Array[Array[Float]] = { 28 | val marginals = docGraph.cachedMarginalMatrix; 29 | for (i <- 0 until docGraph.size) { 30 | var normalizer = 0.0F; 31 | // Restrict to gold antecedents if we're doing gold, but don't load the gold antecedents 32 | // if we're not. 33 | val goldAntecedents: Seq[Int] = if (gold) docGraph.getGoldAntecedentsUnderCurrentPruning(i) else null; 34 | for (j <- 0 to i) { 35 | // If this is a legal antecedent 36 | if (!docGraph.isPruned(i, j) && (!gold || goldAntecedents.contains(j))) { 37 | // N.B. Including lossFcn is okay even for gold because it should be zero 38 | val unnormalizedProb = Math.exp(scoresChart(i)(j) + lossFcn(docGraph.corefDoc, i, j)).toFloat; 39 | marginals(i)(j) = unnormalizedProb; 40 | normalizer += unnormalizedProb; 41 | } else { 42 | marginals(i)(j) = 0.0F; 43 | } 44 | } 45 | for (j <- 0 to i) { 46 | marginals(i)(j) /= normalizer; 47 | } 48 | } 49 | marginals; 50 | } 51 | 52 | def computeLikelihood(docGraph: DocumentGraph, 53 | pairwiseScorer: PairwiseScorer, 54 | lossFcn: (CorefDoc, Int, Int) => Float): Float = { 55 | var likelihood = 0.0F; 56 | val marginals = computeMarginals(docGraph, false, lossFcn, pairwiseScorer); 57 | for (i <- 0 until docGraph.size) { 58 | val goldAntecedents = docGraph.getGoldAntecedentsUnderCurrentPruning(i); 59 | var currProb = 0.0; 60 | for (j <- goldAntecedents) { 61 | currProb += marginals(i)(j); 62 | } 63 | var currLogProb = Math.log(currProb).toFloat; 64 | if (currLogProb.isInfinite()) { 65 | currLogProb = -30; 66 | } 67 | likelihood += currLogProb; 68 | } 69 | likelihood; 70 | } 71 | 72 | def addUnregularizedStochasticGradient(docGraph: DocumentGraph, 73 | pairwiseScorer: PairwiseScorer, 74 | lossFcn: (CorefDoc, Int, Int) => Float, 75 | gradient: Array[Float]) = { 76 | val (featsChart, scoresChart) = docGraph.featurizeIndexAndScoreNonPrunedUseCache(pairwiseScorer); 77 | // N.B. 
Can't have pred marginals and gold marginals around at the same time because 78 | // they both live in the same cached matrix 79 | val predMarginals = this.computeMarginals(docGraph, false, lossFcn, scoresChart); 80 | for (i <- 0 until docGraph.size) { 81 | for (j <- 0 to i) { 82 | if (predMarginals(i)(j) > 1e-20) { 83 | GUtil.addToGradient(featsChart(i)(j), -predMarginals(i)(j), gradient); 84 | } 85 | } 86 | } 87 | val goldMarginals = this.computeMarginals(docGraph, true, lossFcn, scoresChart); 88 | for (i <- 0 until docGraph.size) { 89 | for (j <- 0 to i) { 90 | if (goldMarginals(i)(j) > 1e-20) { 91 | GUtil.addToGradient(featsChart(i)(j), goldMarginals(i)(j), gradient); 92 | } 93 | } 94 | } 95 | } 96 | 97 | def viterbiDecode(docGraph: DocumentGraph, scorer: PairwiseScorer): Array[Int] = { 98 | val (featsChart, scoresChart) = docGraph.featurizeIndexAndScoreNonPrunedUseCache(scorer); 99 | viterbiDecode(scoresChart); 100 | } 101 | 102 | def viterbiDecode(scoresChart: Array[Array[Float]]) = { 103 | val scoreFcn = (idx: Int) => scoresChart(idx); 104 | DocumentInferencerBasic.decodeMax(scoresChart.size, scoreFcn); 105 | } 106 | 107 | def finishPrintStats() = {} 108 | } 109 | 110 | object DocumentInferencerBasic { 111 | 112 | def decodeMax(size: Int, scoreFcn: Int => Array[Float]): Array[Int] = { 113 | val backpointers = new Array[Int](size); 114 | for (i <- 0 until size) { 115 | val allScores = scoreFcn(i); 116 | var bestIdx = -1; 117 | var bestScore = Float.NegativeInfinity; 118 | for (j <- 0 to i) { 119 | val currScore = allScores(j); 120 | if (bestIdx == -1 || currScore > bestScore) { 121 | bestIdx = j; 122 | bestScore = currScore; 123 | } 124 | } 125 | backpointers(i) = bestIdx; 126 | } 127 | backpointers; 128 | } 129 | 130 | } 131 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/DocumentInferencerOracle.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import edu.berkeley.nlp.futile.fig.basic.Indexer 3 | 4 | class DocumentInferencerOracle extends DocumentInferencer { 5 | 6 | def getInitialWeightVector(featureIndexer: Indexer[String]): Array[Float] = Array.fill(featureIndexer.size())(0.0F); 7 | 8 | def computeLikelihood(docGraph: DocumentGraph, 9 | pairwiseScorer: PairwiseScorer, 10 | lossFcn: (CorefDoc, Int, Int) => Float) = { 11 | 0.0F; 12 | } 13 | 14 | def addUnregularizedStochasticGradient(docGraph: DocumentGraph, 15 | pairwiseScorer: PairwiseScorer, 16 | lossFcn: (CorefDoc, Int, Int) => Float, 17 | gradient: Array[Float]) = { 18 | } 19 | 20 | def viterbiDecode(docGraph: DocumentGraph, 21 | pairwiseScorer: PairwiseScorer): Array[Int] = { 22 | val clustering = docGraph.getOraclePredClustering(); 23 | val resultSeq = for (i <- 0 until docGraph.size) yield { 24 | val immediateAntecedentOrMinus1 = clustering.getImmediateAntecedent(i); 25 | if (immediateAntecedentOrMinus1 == -1) { 26 | i; 27 | } else { 28 | docGraph.getMentions.indexOf(immediateAntecedentOrMinus1); 29 | } 30 | } 31 | resultSeq.toArray; 32 | } 33 | 34 | def finishPrintStats() = {} 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/ErrorAnalyzer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | import edu.berkeley.nlp.futile.util.Logger 4 | 5 | object ErrorAnalyzer { 6 | 7 | /** 8 | * Prints 
error analysis on a dev set. Predictions are passed as both predicted backpointers 9 | * as well as the final clusterings those backpointers yield when you take the transitive 10 | * closure. Finally, the scorer (featurizer + weights) are also provided in case you want 11 | * to recompute anything, look at how badly the gold scored, etc. 12 | * 13 | */ 14 | def analyzeErrors(docGraphs: Seq[DocumentGraph], 15 | allPredBackptrs: Seq[Array[Int]], 16 | allPredClusterings: Seq[OrderedClustering], 17 | scorer: PairwiseScorer) { 18 | for (docIdx <- 0 until docGraphs.size) { 19 | val doc = docGraphs(docIdx) 20 | val clustering = allPredClusterings(docIdx) 21 | for (i <- 0 until doc.size) { 22 | val prediction = allPredBackptrs(docIdx)(i) 23 | val ment = doc.getMention(i) 24 | val isCorrect = doc.isGoldCurrentPruning(i, prediction) 25 | val goldLinks = doc.getGoldAntecedentsUnderCurrentPruning(i) 26 | if (!isCorrect) { 27 | Logger.logss(doc.corefDoc.rawDoc.uid) 28 | Logger.logss(" Error on mention " + i + " " + doc.getMention(i).spanToStringWithHeadAndContext(2)) 29 | Logger.logss(" Prediction: " + prediction + " " + (if (prediction == i) "SELF" else doc.getMention(prediction).spanToStringWithHeadAndContext(2))) 30 | if (goldLinks.size == 1 && goldLinks(0) == i) { 31 | Logger.logss(" Gold: singleton") 32 | } else { 33 | Logger.logss(" Gold: " + goldLinks.toSeq) 34 | Logger.logss(" First antecedent: " + doc.getMention(goldLinks(0)).spanToStringWithHeadAndContext(2)) 35 | if (goldLinks.size > 1) { 36 | Logger.logss(" Most recent antecedent: " + doc.getMention(goldLinks.last).spanToStringWithHeadAndContext(2)) 37 | } 38 | } 39 | } 40 | } 41 | } 42 | } 43 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/FeatureSetSpecification.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | import edu.berkeley.nlp.futile.util.Logger 4 | 5 | @SerialVersionUID(1L) 6 | case class FeatureSetSpecification(val featsToUse: Set[String], 7 | val conjScheme: ConjScheme, 8 | val conjFeatures: ConjFeatures, 9 | val conjListedTypePairs: Set[(MentionType,MentionType)], 10 | val conjListedTemplates: Set[String]) 11 | 12 | object FeatureSetSpecification { 13 | 14 | def apply(featsToUse: String, 15 | conjScheme: ConjScheme, 16 | conjFeatures: ConjFeatures): FeatureSetSpecification = { 17 | apply(featsToUse, conjScheme, conjFeatures, "", ""); 18 | } 19 | 20 | def apply(featsToUse: String, 21 | conjScheme: ConjScheme, 22 | conjFeatures: ConjFeatures, 23 | conjListedTypePairs: String, 24 | conjListedTemplates: String): FeatureSetSpecification = { 25 | val typePairs = conjListedTypePairs.split("\\+").toIndexedSeq.filter(!_.isEmpty).map(entry => MentionType.valueOf(entry.split("-")(0)) -> MentionType.valueOf(entry.split("-")(1))).toSet; 26 | val templates = conjListedTemplates.split("\\+").filter(!_.isEmpty).toSet; 27 | Logger.logss("Feature set: " + featsToUse + "\n" + conjScheme + "\n" + conjFeatures + "\nType pairs: " + typePairs + "\nTemplates: " + templates) 28 | new FeatureSetSpecification(featsToUse.split("\\+").toSet, conjScheme, conjFeatures, typePairs, templates); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/Gender.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref; 2 | 3 | public enum 
Gender { 4 | MALE, FEMALE, NEUTRAL, UNKNOWN; 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/LexicalInferenceFeaturizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import java.util.IdentityHashMap 5 | import scala.collection.mutable.HashSet 6 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import edu.berkeley.nlp.entity.GUtil 9 | import scala.collection.mutable.HashMap 10 | 11 | class LexicalInferenceFeaturizer(val lexInfDB: HashMap[(String,String),Seq[String]], 12 | val usePathFeatures: Boolean) extends AuxiliaryFeaturizer { 13 | 14 | override def featurize(docGraph: DocumentGraph, currIdx: Int, antecedentIdx: Int): Seq[String] = { 15 | val feats = new ArrayBuffer[String] 16 | val curr = docGraph.getMention(currIdx) 17 | val ant = docGraph.getMention(antecedentIdx) 18 | if (!curr.mentionType.isClosedClass() && !ant.mentionType.isClosedClass()) { 19 | val currText = curr.spanToString 20 | val antText = ant.spanToString 21 | val forwardContained = lexInfDB.contains(antText -> currText) 22 | feats += "LI=" + forwardContained 23 | if (usePathFeatures && forwardContained) { 24 | val rels = lexInfDB(antText -> currText) 25 | if (!rels.isEmpty) { 26 | for (rel <- rels) { 27 | feats += "LIPathContains=" + rel 28 | } 29 | } 30 | } 31 | val reverseContained = lexInfDB.contains(currText -> antText) 32 | feats += "LIRev=" + reverseContained 33 | if (usePathFeatures && reverseContained) { 34 | val rels = lexInfDB(currText -> antText) 35 | if (!rels.isEmpty) { 36 | for (rel <- rels) { 37 | feats += "LIRevPathContains=" + rel 38 | } 39 | } 40 | } 41 | } 42 | feats 43 | } 44 | } 45 | 46 | //class LexicalInferenceOracleFeaturizer(val lexInfDB: HashMap[(String,String),Seq[String]]) extends AuxiliaryFeaturizer { 47 | // 48 | // override def featurize(docGraph: DocumentGraph, currIdx: Int, antecedentIdx: Int): Seq[String] = { 49 | // val feats = new ArrayBuffer[String] 50 | // val curr = docGraph.getMention(currIdx) 51 | // val ant = docGraph.getMention(antecedentIdx) 52 | // if (!curr.mentionType.isClosedClass() && !ant.mentionType.isClosedClass()) { 53 | // val currText = curr.spanToString 54 | // val antText = ant.spanToString 55 | // val forwardContained = lexInfDB.contains(antText -> currText) 56 | // if (forwardContained) { 57 | // val areGold = docGraph.corefDoc.getOraclePredClustering.areInSameCluster(currIdx, antecedentIdx) 58 | // if (areGold) { 59 | // feats += "OracleIncluded" 60 | // } 61 | // } 62 | // } 63 | // feats 64 | // } 65 | //} 66 | 67 | object LexicalInferenceFeaturizer { 68 | def loadLexInfFeaturizer(lexInfResultsDir: String, usePathFeatures: Boolean) = { 69 | val lexInfDB = new HashMap[(String,String),Seq[String]]; 70 | addFileToSet(lexInfDB, lexInfResultsDir + "/train.txt") 71 | addFileToSet(lexInfDB, lexInfResultsDir + "/dev.txt") 72 | addFileToSet(lexInfDB, lexInfResultsDir + "/test.txt") 73 | Logger.logss("Loaded " + lexInfDB.size + " true positive lexical inference pairs from " + lexInfResultsDir) 74 | new LexicalInferenceFeaturizer(lexInfDB, usePathFeatures) 75 | } 76 | 77 | def addFileToSet(map: HashMap[(String,String),Seq[String]], file: String) { 78 | val lineItr = IOUtils.lineIterator(IOUtils.openInHard(file)) 79 | var corr = 0 80 | var pred = 0 81 | var gold = 0 82 | while (lineItr.hasNext) { 
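// Expected input format (inferred from the parsing below): each non-empty line is tab-separated,
// with the two mention strings in columns 1 and 4 (0-indexed), gold and predicted truth values in
// columns 6 and 7, and an optional column 8 carrying the relation path wrapped in ^ ... $ markers.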
83 | val line = lineItr.next 84 | if (!line.trim.isEmpty) { 85 | val lineSplit = line.split("\\t") 86 | // Gold and prediction 87 | val goldTrue = lineSplit(6).toLowerCase().startsWith("t") 88 | val predTrue = lineSplit(7).toLowerCase().startsWith("t") 89 | if (goldTrue && predTrue) corr += 1 90 | if (goldTrue) gold += 1 91 | if (predTrue) pred += 1 92 | if (lineSplit(7).toLowerCase().startsWith("t")) { 93 | if (lineSplit.size == 9) { 94 | // Drop ^ and $ 95 | val relStr = lineSplit(8).drop(1).dropRight(1) 96 | map += (lineSplit(1) -> lineSplit(4)) -> relStr.split("\\s+").toSeq 97 | } else { 98 | map += (lineSplit(1) -> lineSplit(4)) -> Seq[String]() 99 | } 100 | } 101 | } 102 | } 103 | Logger.logss("Lexical inf accuracy in " + file + ": " + GUtil.renderPRF1(corr, pred, gold)) 104 | } 105 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/MentionPropertyComputer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | class MentionPropertyComputer(val maybeNumGendComputer: Option[NumberGenderComputer]); 4 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/MentionRankingComputer.scala: -------------------------------------------------------------------------------- 1 | //package edu.berkeley.nlp.entity.coref 2 | // 3 | //import edu.berkeley.nlp.entity.LikelihoodAndGradientComputer 4 | //import edu.berkeley.nlp.futile.fig.basic.Indexer 5 | //import edu.berkeley.nlp.entity.GUtil 6 | // 7 | //class MentionRankingComputer(val featIdx: Indexer[String], 8 | // val featurizer: PairwiseIndexingFeaturizer, 9 | // val lossFcn: PairwiseLossFunction, 10 | // val doSps: Boolean = false, 11 | // val useDownstreamLoss: Boolean = false) extends LikelihoodAndGradientComputer[(DocumentGraph,Int)] { 12 | // 13 | // def getInitialWeights(initialWeightsScale: Double): Array[Double] = Array.tabulate(featIdx.size)(i => 0.0) 14 | // 15 | // private def computeFeatsScores(ex: (DocumentGraph,Int), weights: Array[Double]): (Array[Array[Int]], Array[Float]) = { 16 | // val docGraph = ex._1 17 | // val i = ex._2 18 | // val featsChart = docGraph.featurizeIndexNonPrunedUseCache(featurizer)(i) 19 | // val scoreVec = docGraph.cachedScoreMatrix(i); 20 | // for (j <- 0 to i) { 21 | // if (!docGraph.prunedEdges(i)(j)) { 22 | // require(featsChart(j).size > 0); 23 | // scoreVec(j) = GUtil.scoreIndexedFeatsDouble(featsChart(j), weights).toFloat; 24 | // } else { 25 | // scoreVec(j) = Float.NegativeInfinity; 26 | // } 27 | // } 28 | // featsChart -> scoreVec 29 | // } 30 | // 31 | // private def computeMarginals(ex: (DocumentGraph,Int), scores: Array[Float], gold: Boolean) = { 32 | // val docGraph = ex._1 33 | // val i = ex._2 34 | // val marginals = docGraph.cachedMarginalMatrix(i) 35 | // var normalizer = 0.0F; 36 | // // Restrict to gold antecedents if we're doing gold, but don't load the gold antecedents 37 | // // if we're not. 38 | // val goldAntecedents: Seq[Int] = if (gold) docGraph.getGoldAntecedentsUnderCurrentPruning(i) else null; 39 | // val losses = lossFcn.loss(docGraph.corefDoc, i, docGraph.prunedEdges) 40 | // for (j <- 0 to i) { 41 | // // If this is a legal antecedent 42 | // if (!docGraph.isPruned(i, j) && (!gold || goldAntecedents.contains(j))) { 43 | // // N.B. 
Including lossFcn is okay even for gold because it should be zero 44 | // val score = scores(j) + losses(j) 45 | // val unnormalizedProb = Math.exp(score).toFloat 46 | //// val unnormalizedProb = Math.exp(scores(j) + lossFcn.loss(docGraph.corefDoc, i, j)).toFloat; 47 | // marginals(j) = unnormalizedProb; 48 | // normalizer += unnormalizedProb; 49 | // } else { 50 | // marginals(j) = 0.0F; 51 | // } 52 | // } 53 | // for (j <- 0 to i) { 54 | // marginals(j) /= normalizer; 55 | // } 56 | // marginals 57 | // } 58 | // 59 | // private def computeMax(ex: (DocumentGraph,Int), scores: Array[Float], gold: Boolean): (Int, Double) = { 60 | // val docGraph = ex._1 61 | // val i = ex._2 62 | // var bestIdx = -1 63 | // var bestScore = Float.NegativeInfinity; 64 | // // Restrict to gold antecedents if we're doing gold, but don't load the gold antecedents 65 | // // if we're not. 66 | // val goldAntecedents: Seq[Int] = if (gold) docGraph.getGoldAntecedentsUnderCurrentPruning(i) else null; 67 | // val losses = lossFcn.loss(docGraph.corefDoc, i, docGraph.prunedEdges) 68 | // for (j <- 0 to i) { 69 | // // If this is a legal antecedent 70 | // if (!docGraph.isPruned(i, j) && (!gold || goldAntecedents.contains(j))) { 71 | // // N.B. Including lossFcn is okay even for gold because it should be zero 72 | // val score = (scores(j) + losses(j)).toFloat; 73 | //// val score = scores(j) + lossFcn.loss(docGraph.corefDoc, i, j).toFloat; 74 | // if (bestIdx == -1 || score > bestScore) { 75 | // bestIdx = j 76 | // bestScore = score 77 | // } 78 | // } 79 | // } 80 | // bestIdx -> bestScore 81 | // } 82 | // 83 | // def accumulateGradientAndComputeObjective(ex: (DocumentGraph,Int), weights: Array[Double], gradient: Array[Double]): Double = { 84 | // val docGraph = ex._1 85 | // val i = ex._2 86 | // val (featsChart, scores) = computeFeatsScores(ex, weights) 87 | // if (doSps) { 88 | // val (predMax, predScore) = computeMax(ex, scores, false) 89 | // val (goldMax, goldScore) = computeMax(ex, scores, true) 90 | // if (predMax != goldMax) { 91 | // GUtil.addToGradientDouble(featsChart(predMax), -1.0, gradient) 92 | // GUtil.addToGradientDouble(featsChart(goldMax), 1.0, gradient) 93 | // predScore - goldScore 94 | // } else { 95 | // 0.0 // no gap 96 | // } 97 | // } else { 98 | // // N.B. 
pred and gold marginals live in the same marginals matrix so don't have them 99 | // // both around at the same time 100 | // val predMarginals = computeMarginals(ex, scores, false); 101 | // val goldAntecedents = docGraph.getGoldAntecedentsUnderCurrentPruning(i); 102 | // // Pred terms in gradient and likelihood computation 103 | // var currProb = 0.0 104 | // for (j <- 0 to i) { 105 | // if (predMarginals(j) > 1e-20) { 106 | // GUtil.addToGradientDouble(featsChart(j), -predMarginals(j).toDouble, gradient); 107 | // if (goldAntecedents.contains(j)) { 108 | // currProb += predMarginals(j) 109 | // } 110 | // } 111 | // } 112 | // var currLogProb = Math.log(currProb).toFloat; 113 | // if (currLogProb.isInfinite()) { 114 | // currLogProb = -30; 115 | // } 116 | // // Gold terms in gradient 117 | // val goldMarginals = computeMarginals(ex, scores, true); 118 | // for (j <- 0 to i) { 119 | // if (goldMarginals(j) > 1e-20) { 120 | // GUtil.addToGradientDouble(featsChart(j), goldMarginals(j).toDouble, gradient); 121 | // } 122 | // } 123 | // currLogProb 124 | // } 125 | // } 126 | // 127 | // def computeObjective(ex: (DocumentGraph,Int), weights: Array[Double]): Double = { 128 | // accumulateGradientAndComputeObjective(ex, weights, Array.tabulate(weights.size)(i => 0.0)) 129 | // } 130 | //} -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/MentionRankingComputerSparse.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | import edu.berkeley.nlp.futile.fig.basic.Indexer 4 | import edu.berkeley.nlp.entity.GUtil 5 | import edu.berkeley.nlp.entity.LikelihoodAndGradientComputerSparse 6 | import edu.berkeley.nlp.futile.util.IntCounter 7 | import edu.berkeley.nlp.entity.AdagradWeightVector 8 | import edu.berkeley.nlp.futile.util.Logger 9 | import edu.berkeley.nlp.entity.GeneralTrainer2 10 | 11 | class MentionRankingComputerSparse(val featIdx: Indexer[String], 12 | val featurizer: PairwiseIndexingFeaturizer, 13 | val lossFcn: PairwiseLossFunction, 14 | val doSps: Boolean = false) extends LikelihoodAndGradientComputerSparse[(DocumentGraph,Int)] { 15 | 16 | def getInitialWeights(initialWeightsScale: Double): Array[Double] = Array.tabulate(featIdx.size)(i => 0.0) 17 | 18 | private def computeFeatsScores(ex: (DocumentGraph,Int), weights: AdagradWeightVector): (Array[Array[Int]], Array[Float]) = { 19 | val docGraph = ex._1 20 | val i = ex._2 21 | val featsChart = docGraph.featurizeIndexNonPrunedUseCache(featurizer)(i) 22 | val scoreVec = docGraph.cachedScoreMatrix(i); 23 | for (j <- 0 to i) { 24 | if (!docGraph.prunedEdges(i)(j)) { 25 | require(featsChart(j).size > 0); 26 | scoreVec(j) = weights.score(featsChart(j)).toFloat; 27 | } else { 28 | scoreVec(j) = Float.NegativeInfinity; 29 | } 30 | } 31 | featsChart -> scoreVec 32 | } 33 | 34 | private def computeMarginals(ex: (DocumentGraph,Int), scores: Array[Float], gold: Boolean) = { 35 | val docGraph = ex._1 36 | val i = ex._2 37 | val marginals = docGraph.cachedMarginalMatrix(i) 38 | var normalizer = 0.0F; 39 | // Restrict to gold antecedents if we're doing gold, but don't load the gold antecedents 40 | // if we're not. 
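    // What the loop below computes, in both passes, is a loss-augmented softmax over legal antecedents j <= i:
    //   p(j) ∝ exp(scores(j) + losses(j)),
    // normalized over the non-pruned (and, in the gold pass, gold) antecedents. losses(j) should be zero for
    // gold antecedents, so including it in the gold pass does not change those marginals.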
41 | val goldAntecedents: Seq[Int] = if (gold) docGraph.getGoldAntecedentsUnderCurrentPruning(i) else null; 42 | val losses = lossFcn.loss(docGraph.corefDoc, i, docGraph.prunedEdges) 43 | for (j <- 0 to i) { 44 | // If this is a legal antecedent 45 | if (!docGraph.isPruned(i, j) && (!gold || goldAntecedents.contains(j))) { 46 | // N.B. Including lossFcn is okay even for gold because it should be zero 47 | val score = scores(j) + losses(j) 48 | val unnormalizedProb = Math.exp(score).toFloat 49 | // val unnormalizedProb = Math.exp(scores(j) + lossFcn.loss(docGraph.corefDoc, i, j)).toFloat; 50 | marginals(j) = unnormalizedProb; 51 | normalizer += unnormalizedProb; 52 | } else { 53 | marginals(j) = 0.0F; 54 | } 55 | } 56 | for (j <- 0 to i) { 57 | marginals(j) /= normalizer; 58 | } 59 | marginals 60 | } 61 | 62 | private def computeMax(ex: (DocumentGraph,Int), scores: Array[Float], gold: Boolean): (Int, Double) = { 63 | val docGraph = ex._1 64 | val i = ex._2 65 | var bestIdx = -1 66 | var bestScore = Float.NegativeInfinity; 67 | // Restrict to gold antecedents if we're doing gold, but don't load the gold antecedents 68 | // if we're not. 69 | val goldAntecedents: Seq[Int] = if (gold) docGraph.getGoldAntecedentsUnderCurrentPruning(i) else null; 70 | val losses = lossFcn.loss(docGraph.corefDoc, i, docGraph.prunedEdges) 71 | for (j <- 0 to i) { 72 | // If this is a legal antecedent 73 | if (!docGraph.isPruned(i, j) && (!gold || goldAntecedents.contains(j))) { 74 | // N.B. Including lossFcn is okay even for gold because it should be zero 75 | val score = (scores(j) + losses(j)).toFloat; 76 | // val score = scores(j) + lossFcn.loss(docGraph.corefDoc, i, j).toFloat; 77 | if (bestIdx == -1 || score > bestScore) { 78 | bestIdx = j 79 | bestScore = score 80 | } 81 | } 82 | } 83 | bestIdx -> bestScore 84 | } 85 | 86 | def accumulateGradientAndComputeObjective(ex: (DocumentGraph,Int), weights: AdagradWeightVector, gradient: IntCounter): Double = { 87 | val docGraph = ex._1 88 | val i = ex._2 89 | val (featsChart, scores) = computeFeatsScores(ex, weights) 90 | if (doSps) { 91 | val (predMax, predScore) = computeMax(ex, scores, false) 92 | val (goldMax, goldScore) = computeMax(ex, scores, true) 93 | if (predMax != goldMax) { 94 | GeneralTrainer2.addToGradient(featsChart(predMax), -1.0, gradient) 95 | GeneralTrainer2.addToGradient(featsChart(goldMax), 1.0, gradient) 96 | predScore - goldScore 97 | } else { 98 | 0.0 // no gap 99 | } 100 | } else { 101 | // N.B. 
pred and gold marginals live in the same marginals matrix so don't have them 102 | // both around at the same time 103 | val predMarginals = computeMarginals(ex, scores, false); 104 | val goldAntecedents = docGraph.getGoldAntecedentsUnderCurrentPruning(i); 105 | // Pred terms in gradient and likelihood computation 106 | var currProb = 0.0 107 | for (j <- 0 to i) { 108 | if (predMarginals(j) > 1e-20) { 109 | GeneralTrainer2.addToGradient(featsChart(j), -predMarginals(j).toDouble, gradient); 110 | if (goldAntecedents.contains(j)) { 111 | currProb += predMarginals(j) 112 | } 113 | } 114 | } 115 | var currLogProb = Math.log(currProb).toFloat; 116 | if (currLogProb.isInfinite()) { 117 | currLogProb = -30; 118 | } 119 | // Gold terms in gradient 120 | val goldMarginals = computeMarginals(ex, scores, true); 121 | for (j <- 0 to i) { 122 | if (goldMarginals(j) > 1e-20) { 123 | GeneralTrainer2.addToGradient(featsChart(j), goldMarginals(j).toDouble, gradient); 124 | } 125 | } 126 | currLogProb 127 | } 128 | } 129 | 130 | def computeObjective(ex: (DocumentGraph,Int), weights: AdagradWeightVector): Double = { 131 | accumulateGradientAndComputeObjective(ex, weights, new IntCounter) 132 | } 133 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/MentionType.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref; 2 | 3 | 4 | public enum MentionType { 5 | 6 | PROPER(false), NOMINAL(false), PRONOMINAL(true), DEMONSTRATIVE(true); 7 | 8 | private boolean isClosedClass; 9 | 10 | private MentionType(boolean isClosedClass) { 11 | this.isClosedClass = isClosedClass; 12 | } 13 | 14 | public boolean isClosedClass() { 15 | return isClosedClass; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/Number.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref; 2 | 3 | 4 | public enum Number { 5 | SINGULAR, PLURAL, UNKNOWN; 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/OrderedClustering.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import scala.collection.mutable.HashMap 3 | import scala.collection.JavaConverters._ 4 | import scala.collection.mutable.ArrayBuffer 5 | import scala.collection.mutable.HashSet 6 | 7 | @SerialVersionUID(1L) 8 | class OrderedClustering(val clusters: Seq[Seq[Int]]) extends Serializable { 9 | // Elements must be consecutive integers from 0 up to n 10 | private val allIndicesSorted = clusters.foldLeft(new ArrayBuffer[Int])(_ ++ _).sorted; 11 | require(allIndicesSorted.sameElements((0 until allIndicesSorted.size).toSeq), allIndicesSorted); 12 | private val mentionToClusterIdMap = new HashMap[Int,Int]; 13 | private val mentionToClusterMap = new HashMap[Int,Seq[Int]]; 14 | for (clusterIdx <- 0 until clusters.size) { 15 | val cluster = clusters(clusterIdx) 16 | for (i <- cluster) { 17 | mentionToClusterIdMap.put(i, clusterIdx) 18 | mentionToClusterMap.put(i, cluster); 19 | } 20 | } 21 | 22 | def getCluster(idx: Int) = mentionToClusterMap(idx); 23 | 24 | def isSingleton(idx: Int) = mentionToClusterMap(idx).size == 1; 25 | 26 | def startsCluster(idx: Int) = mentionToClusterMap(idx)(0) == idx; 27 | 28 | def 
areInSameCluster(idx1: Int, idx2: Int) = mentionToClusterMap(idx1).contains(idx2); 29 | 30 | def getImmediateAntecedent(idx: Int) = { 31 | val cluster = mentionToClusterMap(idx); 32 | val mentIdxInCluster = cluster.indexOf(idx); 33 | if (mentIdxInCluster == 0) { 34 | -1 35 | } else { 36 | cluster(mentIdxInCluster - 1); 37 | } 38 | } 39 | 40 | def getAllAntecedents(idx: Int) = { 41 | val cluster = mentionToClusterMap(idx); 42 | cluster.slice(0, cluster.indexOf(idx)); 43 | } 44 | 45 | def getAllConsequents(idx: Int) = { 46 | val cluster = mentionToClusterMap(idx); 47 | cluster.slice(cluster.indexOf(idx) + 1, cluster.size); 48 | } 49 | 50 | def getClusterIdxMap = mentionToClusterIdMap 51 | 52 | def getClusterIdx(idx: Int) = mentionToClusterIdMap(idx) 53 | 54 | def getConsistentBackpointers = { 55 | Array.tabulate(allIndicesSorted.size)(i => { 56 | val immediateAnt = getImmediateAntecedent(i) 57 | if (immediateAnt == -1) i else immediateAnt 58 | }) 59 | } 60 | 61 | def getSubclustering(mentIdxsToKeep: Seq[Int]): OrderedClustering = { 62 | val oldIndicesToNewIndicesMap = new HashMap[Int,Int](); 63 | (0 until mentIdxsToKeep.size).map(i => oldIndicesToNewIndicesMap.put(mentIdxsToKeep(i), i)); 64 | val filteredConvertedClusters = clusters.map(cluster => cluster.filter(mentIdxsToKeep.contains(_)).map(mentIdx => oldIndicesToNewIndicesMap(mentIdx))); 65 | val filteredConvertedClustersNoEmpties = filteredConvertedClusters.filter(cluster => !cluster.isEmpty); 66 | new OrderedClustering(filteredConvertedClustersNoEmpties); 67 | } 68 | 69 | def bind(ments: Seq[Mention], doConllPostprocessing: Boolean): OrderedClusteringBound = { 70 | if (doConllPostprocessing) new OrderedClusteringBound(ments, this).postprocessForConll() else new OrderedClusteringBound(ments, this); 71 | } 72 | } 73 | 74 | object OrderedClustering { 75 | 76 | def createFromClusterIds(clusterIds: Seq[Int]) = { 77 | val mentIdAndClusterId = (0 until clusterIds.size).map(i => (i, clusterIds(i))); 78 | val clustersUnsorted = mentIdAndClusterId.groupBy(_._2).values; 79 | val finalClusters = clustersUnsorted.toSeq.sortBy(_.head).map(clusterWithClusterId => clusterWithClusterId.map(_._1)); 80 | new OrderedClustering(finalClusters.toSeq); 81 | } 82 | 83 | def createFromBackpointers(backpointers: Seq[Int]) = { 84 | var nextClusterID = 0; 85 | val clusters = new ArrayBuffer[ArrayBuffer[Int]](); 86 | val mentionToCluster = new HashMap[Int,ArrayBuffer[Int]](); 87 | for (i <- 0 until backpointers.size) { 88 | if (backpointers(i) == i) { 89 | val cluster = ArrayBuffer(i); 90 | clusters += cluster; 91 | mentionToCluster.put(i, cluster); 92 | } else { 93 | val cluster = mentionToCluster(backpointers(i)); 94 | cluster += i; 95 | mentionToCluster.put(i, cluster); 96 | } 97 | } 98 | new OrderedClustering(clusters); 99 | } 100 | } 101 | 102 | class OrderedClusteringFromBackpointers(val backpointers: Seq[Int], 103 | val oc: OrderedClustering) { 104 | val adjacencyMap = new HashMap[Int,ArrayBuffer[Int]] 105 | for (i <- 0 until backpointers.size) { 106 | adjacencyMap(i) = new ArrayBuffer[Int] 107 | } 108 | for (i <- 0 until backpointers.size) { 109 | if (backpointers(i) != i) { 110 | adjacencyMap(i) += backpointers(i) 111 | adjacencyMap(backpointers(i)) += i 112 | } 113 | } 114 | 115 | def computeFromFrontier(seeds: Set[Int], blocked: Set[Int]) = { 116 | var frontier = new HashSet[Int] 117 | frontier ++= seeds 118 | var newFrontier = new HashSet[Int] 119 | val cluster = new HashSet[Int] 120 | while ((frontier -- cluster).size > 0) { 121 | for (node <- 
(frontier -- cluster)) { 122 | cluster += node 123 | newFrontier ++= (adjacencyMap(node) -- blocked) 124 | } 125 | frontier = newFrontier -- cluster 126 | } 127 | cluster.toSet 128 | } 129 | 130 | def changeBackpointerGetClusters(i: Int, newBackpointer: Int): (Seq[Int], Seq[Int]) = { 131 | // Split the cluster out 132 | val (oldAntCluster, partialCurrCluster) = if (backpointers(i) != i) { 133 | val oldAnt = backpointers(i) 134 | val oldAntAdjacent = adjacencyMap(oldAnt) 135 | computeFromFrontier(Set(backpointers(i)), Set(i)) -> computeFromFrontier(Set(i), Set(backpointers(i))) 136 | } else { 137 | (Seq[Int](), oc.getCluster(i)) 138 | } 139 | 140 | val newCurrCluster = if (newBackpointer != i) partialCurrCluster ++ oc.getCluster(newBackpointer) else partialCurrCluster 141 | oldAntCluster.toSeq -> newCurrCluster.toSeq 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/OrderedClusteringBound.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import scala.collection.JavaConverters._ 3 | 4 | @SerialVersionUID(1L) 5 | class OrderedClusteringBound(val ments: Seq[Mention], 6 | val clustering: OrderedClustering) extends Serializable { 7 | 8 | def postprocessForConll(): OrderedClusteringBound = { 9 | val mentIdxsToKeep = (0 until ments.size).filter(i => !clustering.isSingleton(i)); 10 | new OrderedClusteringBound(mentIdxsToKeep.map(i => ments(i)), clustering.getSubclustering(mentIdxsToKeep)); 11 | } 12 | 13 | def getClusterIdx(ment: Mention) = { 14 | clustering.getClusterIdx(ments.indexOf(ment)); 15 | } 16 | 17 | def toSimple = new OrderedClusteringBoundSimple(ments.map(ment => (ment.sentIdx, ment.startIdx, ment.endIdx)), clustering) 18 | } 19 | 20 | class OrderedClusteringBoundSimple(val ments: Seq[(Int,Int,Int)], 21 | val clustering: OrderedClustering) { 22 | 23 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/PairwiseIndexingFeaturizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import edu.berkeley.nlp.futile.fig.basic.Indexer 3 | import edu.berkeley.nlp.futile.util.Counter 4 | import edu.berkeley.nlp.futile.util.Logger 5 | import scala.collection.JavaConverters._ 6 | import edu.berkeley.nlp.entity.sem.QueryCountsBundle 7 | import edu.berkeley.nlp.entity.wiki.WikipediaInterface 8 | 9 | trait PairwiseIndexingFeaturizer { 10 | 11 | def getIndexer(): Indexer[String]; 12 | 13 | def getQueryCountsBundle: Option[QueryCountsBundle]; 14 | 15 | def featurizeIndex(docGraph: DocumentGraph, currMentIdx: Int, antecedentIdx: Int, addToFeaturizer: Boolean): Array[Int]; 16 | } 17 | 18 | object PairwiseIndexingFeaturizer { 19 | 20 | def printFeatureTemplateCounts(indexer: Indexer[String]) { 21 | val templateCounts = new Counter[String](); 22 | for (i <- 0 until indexer.size) { 23 | val template = PairwiseIndexingFeaturizer.getTemplate(indexer.get(i)); 24 | templateCounts.incrementCount(template, 1.0); 25 | } 26 | templateCounts.keepTopNKeys(200); 27 | if (templateCounts.size > 200) { 28 | Logger.logss("Not going to print more than 200 templates"); 29 | } 30 | templateCounts.keySet().asScala.toSeq.sorted.foreach(template => Logger.logss(template + ": " + templateCounts.getCount(template).toInt)); 31 | 32 | val conjCounts = new Counter[String](); 33 | for (i <- 0 until 
indexer.size) { 34 | val currFeatureName = indexer.get(i); 35 | val conjStart = currFeatureName.indexOf("&C"); 36 | if (conjStart == -1) { 37 | conjCounts.incrementCount("No &C", 1.0); 38 | } else { 39 | conjCounts.incrementCount(currFeatureName.substring(conjStart), 1.0); 40 | } 41 | } 42 | conjCounts.keepTopNKeys(1000); 43 | if (conjCounts.size > 1000) { 44 | Logger.logss("Not going to print more than 1000 templates"); 45 | } 46 | conjCounts.keySet().asScala.toSeq.sorted.foreach(conj => Logger.logss(conj + ": " + conjCounts.getCount(conj).toInt)); 47 | } 48 | 49 | def getTemplate(feat: String) = { 50 | val currFeatureTemplateStop = feat.indexOf("="); 51 | if (currFeatureTemplateStop == -1) { 52 | ""; 53 | } else { 54 | feat.substring(0, currFeatureTemplateStop); 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/PairwiseScorer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | 3 | import scala.collection.mutable.HashMap 4 | import edu.berkeley.nlp.entity.GUtil 5 | import edu.berkeley.nlp.futile.util.Counter 6 | import edu.berkeley.nlp.futile.util.Beam 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import edu.berkeley.nlp.futile.fig.basic.Indexer 9 | 10 | @SerialVersionUID(1L) 11 | class PairwiseScorer(val featurizer: PairwiseIndexingFeaturizer, val weights: Array[Float]) extends Serializable { 12 | 13 | def score(docGraph: DocumentGraph, currMentIdx: Int, antMentIdx: Int, addToFeaturizer: Boolean = false) = { 14 | GUtil.scoreIndexedFeats(featurizer.featurizeIndex(docGraph, currMentIdx, antMentIdx, false), weights); 15 | } 16 | 17 | def numWeights = weights.size 18 | 19 | def computeTopFeatsPerTemplate(cutoff: Int): Map[String,Counter[String]] = { 20 | val topFeatsPerTemplate = new HashMap[String,Counter[String]]; 21 | for (featIdx <- 0 until weights.size) { 22 | val featName = featurizer.getIndexer.getObject(featIdx); 23 | val featTemplate = PairwiseIndexingFeaturizer.getTemplate(featName); 24 | val weight = weights(featIdx); 25 | if (!topFeatsPerTemplate.contains(featTemplate)) { 26 | topFeatsPerTemplate.put(featTemplate, new Counter[String]); 27 | } 28 | topFeatsPerTemplate(featTemplate).incrementCount(featName, weight); 29 | } 30 | topFeatsPerTemplate.map(entry => { 31 | val counter = entry._2; 32 | counter.keepTopNKeysByAbsValue(cutoff); 33 | entry._1 -> counter 34 | }).toMap; 35 | } 36 | 37 | def pack: PairwiseScorer = { 38 | if (!featurizer.isInstanceOf[PairwiseIndexingFeaturizerJoint]) { 39 | Logger.logss("Can't pack"); 40 | this; 41 | } else { 42 | val oldFeaturizer = featurizer.asInstanceOf[PairwiseIndexingFeaturizerJoint] 43 | val (newFeatureIndexer, newWeights) = GUtil.packFeaturesAndWeights(featurizer.getIndexer(), weights); 44 | val newFeaturizer = oldFeaturizer.replaceIndexer(newFeatureIndexer); 45 | new PairwiseScorer(newFeaturizer, newWeights); 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/PronounDictionary.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.coref 2 | import scala.collection.mutable.HashMap 3 | 4 | object PronounDictionary { 5 | val firstPersonPronouns = Set("i", "me", "myself", "mine", "my", "we", "us", "ourself", "ourselves", "ours", "our"); 6 | val secondPersonPronouns = Set("you", "yourself", 
"yours", "your", "yourselves"); 7 | val thirdPersonPronouns = Set("he", "him", "himself", "his", "she", "her", "herself", "hers", "her", "it", "itself", "its", "one", "oneself", "one's", "they", "them", "themself", "themselves", "theirs", "their", "they", "them", "'em", "themselves"); 8 | val otherPronouns = Set("who", "whom", "whose", "where", "when","which"); 9 | 10 | val demonstratives = Set("this", "that", "these", "those"); 11 | 12 | // Borrowed from Stanford 13 | val singularPronouns = Set("i", "me", "myself", "mine", "my", "yourself", "he", "him", "himself", "his", "she", "her", "herself", "hers", "her", "it", "itself", "its", "one", "oneself", "one's"); 14 | val pluralPronouns = Set("we", "us", "ourself", "ourselves", "ours", "our", "yourself", "yourselves", "they", "them", "themself", "themselves", "theirs", "their"); 15 | val malePronouns = Set("he", "him", "himself", "his"); 16 | val femalePronouns = Set("her", "hers", "herself", "she"); 17 | val neutralPronouns = Set("it", "its", "itself", "where", "here", "there", "which"); 18 | 19 | 20 | val allPronouns = firstPersonPronouns ++ secondPersonPronouns ++ thirdPersonPronouns ++ otherPronouns; 21 | 22 | // Constructed based on Stanford's Dictionaries class 23 | val canonicalizations = new HashMap[String,String](); 24 | canonicalizations.put("i", "i"); 25 | canonicalizations.put("me", "i"); 26 | canonicalizations.put("my", "i"); 27 | canonicalizations.put("myself", "i"); 28 | canonicalizations.put("mine", "i"); 29 | canonicalizations.put("you", "you"); 30 | canonicalizations.put("your", "you"); 31 | canonicalizations.put("yourself", "you"); 32 | canonicalizations.put("yourselves", "you"); 33 | canonicalizations.put("yours", "you"); 34 | canonicalizations.put("he", "he"); 35 | canonicalizations.put("him", "he"); 36 | canonicalizations.put("his", "he"); 37 | canonicalizations.put("himself", "he"); 38 | canonicalizations.put("she", "she"); 39 | canonicalizations.put("her", "she"); 40 | canonicalizations.put("herself", "she"); 41 | canonicalizations.put("hers", "she"); 42 | 43 | canonicalizations.put("we", "we"); 44 | canonicalizations.put("us", "we"); 45 | canonicalizations.put("our", "we"); 46 | canonicalizations.put("ourself", "we"); 47 | canonicalizations.put("ourselves", "we"); 48 | canonicalizations.put("ours", "we"); 49 | canonicalizations.put("they", "they"); 50 | canonicalizations.put("them", "they"); 51 | canonicalizations.put("their", "they"); 52 | canonicalizations.put("themself", "they"); 53 | canonicalizations.put("themselves", "they"); 54 | canonicalizations.put("theirs", "they"); 55 | canonicalizations.put("'em", "they"); 56 | canonicalizations.put("it", "it"); 57 | canonicalizations.put("itself", "it"); 58 | canonicalizations.put("its", "it"); 59 | canonicalizations.put("one", "one"); 60 | canonicalizations.put("oneself", "one"); 61 | canonicalizations.put("one's", "one"); 62 | 63 | canonicalizations.put("this", "this"); 64 | canonicalizations.put("that", "that"); 65 | canonicalizations.put("these", "these"); 66 | canonicalizations.put("those", "those"); 67 | canonicalizations.put("which", "which"); 68 | canonicalizations.put("who", "who"); 69 | canonicalizations.put("whom", "who"); 70 | // canonicalizations.put("where", "where"); 71 | // canonicalizations.put("whose", "whose"); 72 | // This entry is here just to make results consistent with earlier ones 73 | // on our very small dev set 74 | canonicalizations.put("thy", "thy"); 75 | canonicalizations.put("y'all", "you"); 76 | canonicalizations.put("you're", "you"); 77 
| canonicalizations.put("you'll", "you"); 78 | canonicalizations.put("'s", "'s"); 79 | 80 | def isPronLc(str: String): Boolean = { 81 | !mightBeAcronym(str) && allPronouns.contains(str.toLowerCase()); 82 | } 83 | 84 | def isDemonstrative(str: String): Boolean = { 85 | !mightBeAcronym(str) && demonstratives.contains(str.toLowerCase()); 86 | } 87 | 88 | def mightBeAcronym(str: String) = { 89 | if (str.size <= 4) { 90 | var acronym = true; 91 | var i = 0; 92 | while (acronym && i < str.size) { 93 | if (!Character.isUpperCase(str.charAt(i))) { 94 | acronym = false; 95 | } 96 | i += 1; 97 | } 98 | acronym; 99 | } else { 100 | false; 101 | } 102 | } 103 | 104 | def canonicalize(str: String): String = { 105 | if (!canonicalizations.contains(str.toLowerCase())) { 106 | ""; 107 | } else { 108 | canonicalizations(str.toLowerCase()); 109 | } 110 | } 111 | 112 | def main(args: Array[String]) { 113 | println(PronounDictionary.canonicalizations("'em")); 114 | println(PronounDictionary.isPronLc("them")); 115 | println(PronounDictionary.isPronLc("Them")); 116 | println(PronounDictionary.isPronLc("NotThem")); 117 | println(PronounDictionary.mightBeAcronym("them")); 118 | println(PronounDictionary.mightBeAcronym("Them")); 119 | println(PronounDictionary.mightBeAcronym("THEM")); 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/PruningStrategy.scala: -------------------------------------------------------------------------------- 1 | //package edu.berkeley.nlp.entity.coref 2 | // 3 | //case class PruningStrategy(val strategy: String) { 4 | // 5 | // def getDistanceArgs(): (Int, Int) = { 6 | // require(strategy.startsWith("distance")); 7 | // (splitStrategy(1).toInt, splitStrategy(2).toInt); 8 | // } 9 | // 10 | // def getModelPath: String = PruningStrategy.getModelPath(strategy); 11 | // def getModelLogRatio: Float = PruningStrategy.getModelLogRatio(strategy); 12 | //} 13 | // 14 | //object PruningStrategy { 15 | // 16 | // def buildPruner(strategy: String) { 17 | // if (strategy.startsWith("distance")) { 18 | // val splitStrategy = strategy.split(":"); 19 | // val (maxSent(splitStrategy(1).toInt, splitStrategy(2).toInt); 20 | // } else if (strategy.startsWith("models")) { 21 | // 22 | // } 23 | // } 24 | // 25 | // def getModelPath(strategy: String): String = { 26 | // require(strategy.startsWith("models")); 27 | // strategy.split(":")(1); 28 | // } 29 | // 30 | // def getModelLogRatio(strategy: String): Float = { 31 | // require(strategy.startsWith("models")); 32 | // strategy.split(":")(2).toFloat; 33 | // } 34 | //} 35 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/coref/package.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity 2 | 3 | package object coref { 4 | type UID = (String, Int); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/joint/FactorGraphFactory.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.joint 2 | 3 | import scala.collection.mutable.HashMap 4 | import edu.berkeley.nlp.entity.coref.UID 5 | import edu.berkeley.nlp.futile.fig.basic.Indexer 6 | import edu.berkeley.nlp.entity.ner.NerFeaturizer 7 | import edu.berkeley.nlp.entity.ner.MCNerFeaturizer 8 | import 
edu.berkeley.nlp.entity.coref.CorefDoc 9 | import edu.berkeley.nlp.entity.wiki.WikipediaInterface 10 | 11 | trait FactorGraphFactory[D,G<:JointDocFactorGraph] { 12 | 13 | def getIndexer: Indexer[String]; 14 | 15 | def getDocFactorGraph(obj: D, 16 | isGold: Boolean, 17 | addToIndexer: Boolean, 18 | useCache: Boolean, 19 | corefLossFcn: (CorefDoc, Int, Int) => Float, 20 | nerLossFcn: (String, String) => Float, 21 | wikiLossFcn: (Seq[String], String) => Float): G; 22 | 23 | def getDocFactorGraphHard(obj: D, isGold: Boolean) = { 24 | getDocFactorGraph(obj, isGold, false, true, null, null, null); 25 | } 26 | } 27 | 28 | class FactorGraphFactoryOnto(val featurizer: JointFeaturizerShared[NerFeaturizer], 29 | val wikiDB: Option[WikipediaInterface]) extends FactorGraphFactory[JointDoc,JointDocFactorGraphOnto] { 30 | val goldFactorGraphCache = new HashMap[UID, JointDocFactorGraphOnto](); 31 | val guessFactorGraphCache = new HashMap[UID, JointDocFactorGraphOnto](); 32 | 33 | private def fetchGraphCache(gold: Boolean) = { 34 | if (gold) { 35 | goldFactorGraphCache 36 | } else { 37 | guessFactorGraphCache; 38 | } 39 | } 40 | 41 | def getIndexer = featurizer.indexer; 42 | 43 | def getDocFactorGraph(doc: JointDoc, 44 | gold: Boolean, 45 | addToIndexer: Boolean, 46 | useCache: Boolean, 47 | corefLossFcn: (CorefDoc, Int, Int) => Float, 48 | nerLossFcn: (String, String) => Float, 49 | wikiLossFcn: (Seq[String], String) => Float): JointDocFactorGraphOnto = { 50 | if (useCache) { 51 | val cache = fetchGraphCache(gold); 52 | if (!cache.contains(doc.rawDoc.uid)) { 53 | cache.put(doc.rawDoc.uid, new JointDocFactorGraphOnto(doc, featurizer, wikiDB, gold, addToIndexer, corefLossFcn, nerLossFcn, wikiLossFcn)); 54 | } 55 | cache(doc.rawDoc.uid); 56 | } else { 57 | if (corefLossFcn == null) { 58 | throw new RuntimeException("You called getDocFactorGraphHard but it wasn't in the cache...") 59 | } 60 | new JointDocFactorGraphOnto(doc, featurizer, wikiDB, gold, addToIndexer, corefLossFcn, nerLossFcn, wikiLossFcn) 61 | } 62 | } 63 | } 64 | 65 | class FactorGraphFactoryACE(val featurizer: JointFeaturizerShared[MCNerFeaturizer], 66 | val wikiDB: Option[WikipediaInterface]) extends FactorGraphFactory[JointDocACE,JointDocFactorGraphACE] { 67 | val goldFactorGraphCache = new HashMap[UID, JointDocFactorGraphACE](); 68 | val guessFactorGraphCache = new HashMap[UID, JointDocFactorGraphACE](); 69 | 70 | private def fetchGraphCache(gold: Boolean) = { 71 | if (gold) { 72 | goldFactorGraphCache 73 | } else { 74 | guessFactorGraphCache; 75 | } 76 | } 77 | 78 | def getIndexer = featurizer.indexer; 79 | 80 | def getDocFactorGraph(doc: JointDocACE, 81 | gold: Boolean, 82 | addToIndexer: Boolean, 83 | useCache: Boolean, 84 | corefLossFcn: (CorefDoc, Int, Int) => Float, 85 | nerLossFcn: (String, String) => Float, 86 | wikiLossFcn: (Seq[String], String) => Float): JointDocFactorGraphACE = { 87 | if (useCache) { 88 | val cache = fetchGraphCache(gold); 89 | if (!cache.contains(doc.rawDoc.uid)) { 90 | cache.put(doc.rawDoc.uid, new JointDocFactorGraphACE(doc, featurizer, wikiDB, gold, addToIndexer, corefLossFcn, nerLossFcn, wikiLossFcn)); 91 | } 92 | cache(doc.rawDoc.uid); 93 | } else { 94 | if (corefLossFcn == null) { 95 | throw new RuntimeException("You called getDocFactorGraphHard but it wasn't in the cache...") 96 | } 97 | new JointDocFactorGraphACE(doc, featurizer, wikiDB, gold, addToIndexer, corefLossFcn, nerLossFcn, wikiLossFcn) 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- 
/src/main/java/edu/berkeley/nlp/entity/joint/GeneralTrainer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.joint 2 | import scala.Array.canBuildFrom 3 | import scala.collection.JavaConverters.asScalaBufferConverter 4 | import scala.collection.mutable.HashSet 5 | import edu.berkeley.nlp.futile.math.CachingDifferentiableFunction 6 | import edu.berkeley.nlp.futile.math.LBFGSMinimizer 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import edu.berkeley.nlp.futile.fig.basic.SysInfoUtils 9 | import java.util.Arrays 10 | 11 | trait LikelihoodAndGradientComputer[T] { 12 | def addUnregularizedStochasticGradient(ex: T, weights: Array[Float], gradient: Array[Float]); 13 | def computeLogLikelihood(ex: T, weights: Array[Float]): Float; 14 | } 15 | 16 | class GeneralTrainer[T] { 17 | 18 | var inferenceNanos = 0L; 19 | var adagradNanos = 0L; 20 | 21 | def train(trainExs: Seq[T], 22 | computer: LikelihoodAndGradientComputer[T], 23 | numFeats: Int, 24 | eta: Float, 25 | reg: Float, 26 | batchSize: Int, 27 | numItrs: Int): Array[Float] = { 28 | trainAdagrad(trainExs, computer, numFeats, eta, reg, batchSize, numItrs); 29 | } 30 | 31 | def trainAdagrad(trainExs: Seq[T], 32 | computer: LikelihoodAndGradientComputer[T], 33 | numFeats: Int, 34 | eta: Float, 35 | lambda: Float, 36 | batchSize: Int, 37 | numItrs: Int, 38 | learningCallback: Array[Float] => Unit = (weights: Array[Float]) => {}): Array[Float] = { 39 | // val weights = Array.fill(pairwiseIndexingFeaturizer.featureIndexer.size)(0.0); 40 | val weights = Array.fill(numFeats)(0.0F); 41 | val reusableGradientArray = Array.fill(numFeats)(0.0F); 42 | val diagGt = Array.fill(numFeats)(0.0F); 43 | for (i <- 0 until numItrs) { 44 | Logger.logss("ITERATION " + i); 45 | val startTime = System.nanoTime(); 46 | inferenceNanos = 0; 47 | adagradNanos = 0; 48 | Logger.startTrack("Computing gradient"); 49 | var currIdx = 0; 50 | var currBatchIdx = 0; 51 | val printFreq = (trainExs.size / batchSize) / 10 // Print progress 10 times per pass through the data 52 | while (currIdx < trainExs.size) { 53 | if (printFreq == 0 || currBatchIdx % printFreq == 0) { 54 | Logger.logs("Computing gradient on " + currIdx); 55 | } 56 | takeAdagradStepL1R(trainExs.slice(currIdx, Math.min(trainExs.size, currIdx + batchSize)), 57 | computer, 58 | weights, 59 | reusableGradientArray, 60 | diagGt, 61 | eta, 62 | lambda); 63 | learningCallback(weights) 64 | currIdx += batchSize; 65 | currBatchIdx += 1; 66 | } 67 | Logger.endTrack(); 68 | Logger.logss("NONZERO WEIGHTS: " + weights.foldRight(0)((weight, count) => if (Math.abs(weight) > 1e-15) count + 1 else count)); 69 | Logger.logss("WEIGHT VECTOR NORM: " + weights.foldRight(0.0)((weight, norm) => norm + weight * weight)); 70 | if (i == 0 || i == 1 || i % 5 == 4 || i == numItrs - 1) { 71 | Logger.startTrack("Evaluating objective on train"); 72 | Logger.logss("TRAIN OBJECTIVE: " + computeObjectiveL1R(trainExs, computer, weights, lambda)); 73 | Logger.endTrack(); 74 | } 75 | Logger.logss("MILLIS FOR ITER " + i + ": " + (System.nanoTime() - startTime) / 1000000.0); 76 | Logger.logss("MILLIS INFERENCE FOR ITER " + i + ": " + inferenceNanos / 1000000.0); 77 | Logger.logss("MILLIS ADAGRAD FOR ITER " + i + ": " + adagradNanos / 1000000.0); 78 | Logger.logss("MEMORY AFTER ITER " + i + ": " + SysInfoUtils.getUsedMemoryStr()); 79 | } 80 | weights 81 | } 82 | 83 | def computeObjectiveL1R(trainExs: Seq[T], 84 | computer: LikelihoodAndGradientComputer[T], 85 | weights: 
Array[Float], 86 | lambda: Float): Float = { 87 | var objective = computeLikelihood(trainExs, computer, weights); 88 | for (weight <- weights) { 89 | objective -= lambda * Math.abs(weight); 90 | } 91 | objective; 92 | } 93 | 94 | def computeLikelihood(trainExs: Seq[T], 95 | computer: LikelihoodAndGradientComputer[T], 96 | weights: Array[Float]): Float = { 97 | (trainExs.foldRight(0.0)((ex, likelihood) => likelihood + computer.computeLogLikelihood(ex, weights))).toFloat; 98 | } 99 | 100 | def takeAdagradStepL1R(exs: Seq[T], 101 | computer: LikelihoodAndGradientComputer[T], 102 | weights: Array[Float], 103 | reusableGradientArray: Array[Float], 104 | diagGt: Array[Float], 105 | eta: Float, 106 | lambda: Float) { 107 | Arrays.fill(reusableGradientArray, 0.0F); 108 | var nanoTime = System.nanoTime(); 109 | for (ex <- exs) { 110 | computer.addUnregularizedStochasticGradient(ex, weights, reusableGradientArray); 111 | } 112 | inferenceNanos += (System.nanoTime() - nanoTime); 113 | nanoTime = System.nanoTime(); 114 | // Precompute this so dividing by batch size is a multiply and not a divide 115 | val batchSizeMultiplier = 1.0F/exs.size; 116 | var i = 0; 117 | while (i < reusableGradientArray.size) { 118 | val xti = weights(i); 119 | // N.B. We negate the gradient here because the Adagrad formulas are all for minimizing 120 | // and we're trying to maximize, so think of it as minimizing the negative of the objective 121 | // which has the opposite gradient 122 | // Equation (25) in http://www.cs.berkeley.edu/~jduchi/projects/DuchiHaSi10.pdf 123 | // eta is the step size, lambda is the regularization 124 | val gti = -reusableGradientArray(i) * batchSizeMultiplier; 125 | // Update diagGt 126 | diagGt(i) += gti * gti; 127 | val Htii = 1F + Math.sqrt(diagGt(i)).toFloat; 128 | // Avoid divisions at all costs... 129 | val etaOverHtii = eta / Htii; 130 | val newXti = xti - etaOverHtii * gti; 131 | weights(i) = Math.signum(newXti) * Math.max(0, Math.abs(newXti) - lambda * etaOverHtii); 132 | i += 1; 133 | } 134 | adagradNanos += (System.nanoTime() - nanoTime); 135 | } 136 | 137 | 138 | } 139 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/joint/JointDocACE.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.joint 2 | 3 | import scala.collection.JavaConverters._ 4 | import scala.collection.mutable.ArrayBuffer 5 | import scala.collection.mutable.HashMap 6 | 7 | import edu.berkeley.nlp.entity.Chunk 8 | import edu.berkeley.nlp.entity.ConllDoc 9 | import edu.berkeley.nlp.entity.coref.DocumentGraph 10 | import edu.berkeley.nlp.entity.coref.Mention 11 | import edu.berkeley.nlp.entity.wiki._ 12 | import edu.berkeley.nlp.futile.util.Logger 13 | 14 | class JointDocACE(val rawDoc: ConllDoc, 15 | val docGraph: DocumentGraph, 16 | val goldWikiChunks: Seq[Seq[Chunk[Seq[String]]]]) { 17 | 18 | val goldChunks = (0 until docGraph.corefDoc.rawDoc.numSents).map(sentIdx => { 19 | // Only take the part that's actually the type from each one, "GPE", "LOC", etc. 
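    // getGoldLabel below simply keeps the first three characters of the mention's gold NER string, so a
    // hypothetical gold label such as "GPE.NAM" or "GPE:Population-Center" would both reduce to "GPE".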
20 | docGraph.corefDoc.goldMentions.filter(_.sentIdx == sentIdx).map(ment => new Chunk[String](ment.startIdx, ment.endIdx, getGoldLabel(ment))); 21 | }); 22 | 23 | def getGoldLabel(ment: Mention) = { 24 | if (ment.nerString.size >= 3) { 25 | ment.nerString.substring(0, 3) 26 | } else { 27 | "O" // SHouldn't happen during training 28 | }; 29 | } 30 | 31 | def getGoldWikLabels(ment: Mention): Seq[String] = { 32 | val matchingChunk = goldWikiChunks(ment.sentIdx).filter(chunk => chunk.start == ment.startIdx && chunk.end == ment.endIdx); 33 | if (matchingChunk.size > 0) matchingChunk.head.label else Seq(ExcludeToken); 34 | } 35 | } 36 | 37 | object JointDocACE { 38 | 39 | def apply(rawDoc: ConllDoc, 40 | docGraph: DocumentGraph, 41 | maybeGoldWikiChunks: Option[Seq[Seq[Chunk[Seq[String]]]]]) = { 42 | val goldWikiChunks = if (maybeGoldWikiChunks.isDefined) { 43 | maybeGoldWikiChunks.get 44 | } else { 45 | (0 until rawDoc.numSents).map(i => Seq[Chunk[Seq[String]]]()); 46 | } 47 | new JointDocACE(rawDoc, docGraph, goldWikiChunks); 48 | } 49 | 50 | def assembleJointDocs(docGraphs: Seq[DocumentGraph], 51 | goldWikification: CorpusWikiAnnots) = { 52 | docGraphs.map(docGraph => { 53 | val rawDoc = docGraph.corefDoc.rawDoc; 54 | val goldWiki = if (goldWikification.contains(rawDoc.docID)) { 55 | Some((0 until rawDoc.numSents).map(goldWikification(rawDoc.docID)(_))); 56 | } else { 57 | if (!goldWikification.isEmpty) { 58 | Logger.logss("WARNING: Some Wikification doc entries found, but none for " + rawDoc.docID); 59 | } 60 | None; 61 | } 62 | JointDocACE(rawDoc, docGraph, goldWiki); 63 | }); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/joint/JointDocFactorGraph.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.joint 2 | 3 | import scala.collection.JavaConverters._ 4 | import scala.collection.mutable.ArrayBuffer 5 | import edu.berkeley.nlp.entity.coref.CorefDoc 6 | import edu.berkeley.nlp.entity.coref.DocumentGraph 7 | import edu.berkeley.nlp.entity.GUtil 8 | import edu.berkeley.nlp.entity.coref.PairwiseIndexingFeaturizer 9 | import edu.berkeley.nlp.entity.coref.PairwiseIndexingFeaturizerJoint 10 | import edu.berkeley.nlp.entity.coref.PairwiseScorer 11 | import edu.berkeley.nlp.entity.ner.NerSystemLabeled 12 | import edu.berkeley.nlp.entity.sem.SemClass 13 | import edu.berkeley.nlp.futile.fig.basic.Indexer 14 | import edu.berkeley.nlp.futile.util.Logger 15 | import edu.berkeley.nlp.entity.ner.NerFeaturizer 16 | import edu.berkeley.nlp.entity.Chunk 17 | import edu.berkeley.nlp.entity.bp.BetterPropertyFactor 18 | import edu.berkeley.nlp.entity.bp.Factor 19 | import edu.berkeley.nlp.entity.bp.Node 20 | import edu.berkeley.nlp.entity.bp.UnaryFactorOld 21 | import scala.Array.canBuildFrom 22 | import edu.berkeley.nlp.entity.bp.Domain 23 | import edu.berkeley.nlp.entity.bp.UnaryFactorGeneral 24 | import edu.berkeley.nlp.entity.ner.NerExample 25 | import edu.berkeley.nlp.entity.bp.BinaryFactorGeneral 26 | import edu.berkeley.nlp.entity.Driver 27 | import edu.berkeley.nlp.entity.bp.ConstantBinaryFactor 28 | import edu.berkeley.nlp.entity.bp.SimpleFactorGraph 29 | import scala.collection.mutable.HashMap 30 | 31 | trait JointDocFactorGraph { 32 | 33 | def setWeights(weights: Array[Float]); 34 | 35 | def computeAndStoreMarginals(weights: Array[Float], 36 | exponentiateMessages: Boolean, 37 | numBpIters: Int); 38 | 39 | def computeLogNormalizerApprox: 
Double; 40 | 41 | def scrubMessages(); 42 | 43 | def passMessagesFancy(numItrs: Int, exponentiateMessages: Boolean); 44 | 45 | def addExpectedFeatureCountsToGradient(scale: Float, gradient: Array[Float]); 46 | 47 | def decodeCorefProduceBackpointers: Array[Int]; 48 | 49 | def decodeNERProduceChunks: Seq[Seq[Chunk[String]]]; 50 | 51 | def decodeWikificationProduceChunks: Seq[Seq[Chunk[String]]]; 52 | 53 | def getRepresentativeFeatures: HashMap[String,String]; 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/joint/JointLossFcns.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.joint 2 | 3 | import edu.berkeley.nlp.entity.wiki._ 4 | import edu.berkeley.nlp.entity.Driver 5 | 6 | object JointLossFcns { 7 | 8 | val nerLossFcn = (gold: String, pred: String) => if (gold == pred) 0.0F else Driver.nerLossScale.toFloat; 9 | 10 | val noNerLossFcn = (gold: String, pred: String) => 0.0F 11 | 12 | val wikiLossFcn = (gold: Seq[String], pred: String) => { 13 | if (gold.contains(NilToken) && pred != NilToken) { 14 | Driver.wikiNilKbLoss.toFloat 15 | } else if (!gold.contains(NilToken) && pred == NilToken){ 16 | Driver.wikiKbNilLoss.toFloat 17 | } else if (!isCorrect(gold, pred)) { 18 | Driver.wikiKbKbLoss.toFloat 19 | } else { 20 | 0.0F; 21 | } 22 | } 23 | 24 | val noWikiLossFcn = (gold: Seq[String], pred: String) => 0.0F; 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/joint/JointPredictor.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.joint 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import scala.collection.mutable.HashMap 5 | import edu.berkeley.nlp.entity.Chunk 6 | import edu.berkeley.nlp.entity.ConllDoc 7 | import edu.berkeley.nlp.entity.ConllDocReader 8 | import edu.berkeley.nlp.entity.ConllDocWriter 9 | import edu.berkeley.nlp.entity.GUtil 10 | import edu.berkeley.nlp.entity.coref.CorefDocAssembler 11 | import edu.berkeley.nlp.entity.coref.CorefDocAssemblerACE 12 | import edu.berkeley.nlp.entity.coref.CorefPruner 13 | import edu.berkeley.nlp.entity.coref.CorefSystem 14 | import edu.berkeley.nlp.entity.coref.DocumentGraph 15 | import edu.berkeley.nlp.entity.coref.MentionPropertyComputer 16 | import edu.berkeley.nlp.entity.coref.OrderedClustering 17 | import edu.berkeley.nlp.entity.lang.Language 18 | import edu.berkeley.nlp.entity.ner.NerFeaturizer 19 | import edu.berkeley.nlp.entity.ner.NerPruner 20 | import edu.berkeley.nlp.entity.wiki.WikipediaInterface 21 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 22 | import edu.berkeley.nlp.futile.fig.exec.Execution 23 | import edu.berkeley.nlp.futile.util.Logger 24 | import edu.berkeley.nlp.entity.ner.NEEvaluator 25 | import edu.berkeley.nlp.entity.coref.CorefEvaluator 26 | import edu.berkeley.nlp.entity.Driver 27 | import scala.collection.GenTraversableOnce 28 | import java.io.PrintWriter 29 | 30 | @SerialVersionUID(1L) 31 | class JointPredictor(val jointFeaturizer: JointFeaturizerShared[NerFeaturizer], 32 | val weights: Array[Float], 33 | val corefPruner: CorefPruner, 34 | val nerPruner: NerPruner) extends Serializable { 35 | 36 | def makeIndividualDocPredictionWriter(maybeWikipediaInterface: Option[WikipediaInterface], outWriter: PrintWriter, outWikiWriter: PrintWriter): (JointDoc => Unit) = { 37 | val fgfOnto = new 
FactorGraphFactoryOnto(jointFeaturizer, maybeWikipediaInterface); 38 | val computer = new JointComputerShared(fgfOnto); 39 | (jointDoc: JointDoc) => { 40 | Logger.logss("Decoding " + jointDoc.rawDoc.printableDocName); 41 | // Don't decode if there are no mentions because things will break 42 | if (jointDoc.docGraph.getMentions.size == 0) { 43 | if (jointDoc.rawDoc.numSents > 0) { 44 | Logger.logss("WARNING: Document with zero mentions but nonzero number of sentences, not running NER but there could be NE mentions") 45 | } 46 | ConllDocWriter.writeDoc(outWriter, jointDoc.rawDoc) 47 | } else { 48 | val (backptrs, clustering, nerChunks, wikiChunks) = computer.viterbiDecodeProduceAnnotations(jointDoc, weights); 49 | ConllDocWriter.writeDocWithPredAnnotationsWikiStandoff(outWriter, outWikiWriter, jointDoc.rawDoc, nerChunks, clustering.bind(jointDoc.docGraph.getMentions, Driver.doConllPostprocessing), wikiChunks); 50 | } 51 | } 52 | } 53 | 54 | def decodeWriteOutput(jointTestDocs: Seq[JointDoc], maybeWikipediaInterface: Option[WikipediaInterface], doConllPostprocessing: Boolean) { 55 | decodeWriteOutputMaybeEvaluate(jointTestDocs, maybeWikipediaInterface, doConllPostprocessing, false); 56 | } 57 | 58 | def decodeWriteOutputEvaluate(jointTestDocs: Seq[JointDoc], maybeWikipediaInterface: Option[WikipediaInterface], doConllPostprocessing: Boolean) { 59 | decodeWriteOutputMaybeEvaluate(jointTestDocs, maybeWikipediaInterface, doConllPostprocessing, true); 60 | } 61 | 62 | private def decodeWriteOutputMaybeEvaluate(jointTestDocs: Seq[JointDoc], maybeWikipediaInterface: Option[WikipediaInterface], doConllPostprocessing: Boolean, evaluate: Boolean) { 63 | val fgfOnto = new FactorGraphFactoryOnto(jointFeaturizer, maybeWikipediaInterface); 64 | val computer = new JointComputerShared(fgfOnto); 65 | val outWriter = IOUtils.openOutHard(Execution.getFile("output.conll")) 66 | val outWikiWriter = IOUtils.openOutHard(Execution.getFile("output-wiki.conll")) 67 | val allPredBackptrsAndClusterings = new ArrayBuffer[(Array[Int],OrderedClustering)]; 68 | val predNEChunks = new ArrayBuffer[Seq[Seq[Chunk[String]]]]; 69 | Logger.startTrack("Decoding"); 70 | for (i <- (0 until jointTestDocs.size)) { 71 | Logger.logss("Decoding " + i); 72 | val jointDevDoc = jointTestDocs(i); 73 | val (backptrs, clustering, nerChunks, wikiChunks) = computer.viterbiDecodeProduceAnnotations(jointDevDoc, weights); 74 | ConllDocWriter.writeDocWithPredAnnotationsWikiStandoff(outWriter, outWikiWriter, jointDevDoc.rawDoc, nerChunks, clustering.bind(jointDevDoc.docGraph.getMentions, Driver.doConllPostprocessing), wikiChunks); 75 | if (evaluate) { 76 | allPredBackptrsAndClusterings += (backptrs -> clustering); 77 | predNEChunks += nerChunks; 78 | } 79 | } 80 | outWriter.close(); 81 | outWikiWriter.close(); 82 | Logger.endTrack(); 83 | if (evaluate) { 84 | Logger.logss(CorefEvaluator.evaluateAndRender(jointTestDocs.map(_.docGraph), allPredBackptrsAndClusterings.map(_._1), allPredBackptrsAndClusterings.map(_._2), 85 | Driver.conllEvalScriptPath, "DEV: ", Driver.analysesToPrint)); 86 | NEEvaluator.evaluateChunksBySent(jointTestDocs.flatMap(_.goldNERChunks), predNEChunks.flatten); 87 | NEEvaluator.evaluateOnConll2011(jointTestDocs, predNEChunks, Driver.conll2011Path.split(",").flatMap(path => ConllDocReader.readDocNames(path)).toSet, if (Driver.writeNerOutput) Execution.getFile("ner.txt") else ""); 88 | } 89 | } 90 | 91 | def pack: JointPredictor = { 92 | if (jointFeaturizer.canReplaceIndexer) { 93 | val (newIndexer, newWeights) = 
GUtil.packFeaturesAndWeights(jointFeaturizer.indexer, weights); 94 | new JointPredictor(jointFeaturizer.replaceIndexer(newIndexer), newWeights, corefPruner, nerPruner); 95 | } else { 96 | this; 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/lang/ArabicTreebankLanguagePack.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.lang; 2 | 3 | import edu.berkeley.nlp.futile.treebank.AbstractTreebankLanguagePack; 4 | 5 | 6 | public class ArabicTreebankLanguagePack extends AbstractTreebankLanguagePack { 7 | private static final String[] collinsPunctTags = {"PUNC"}; 8 | 9 | private static final String[] pennPunctTags = {"PUNC"}; 10 | 11 | private static final String[] pennPunctWords = {".","\"",",","-LRB-","-RRB-","-",":","/","?","_","*","%","!",">","-PLUS-","...",";","..","&","=","ر","'","\\","`","......"}; 12 | 13 | private static final String[] pennSFPunctTags = {"PUNC"}; 14 | 15 | private static final String[] pennSFPunctWords = {".", "!", "?"}; 16 | 17 | /** 18 | * The first 3 are used by the Penn Treebank; # is used by the 19 | * BLLIP corpus, and ^ and ~ are used by Klein's lexparser. 20 | * Chris deleted '_' for Arabic as it appears in tags (NO_FUNC). 21 | * June 2006: CDM tested _ again with true (new) Treebank tags to see if it 22 | * was useful for densening up the tag space, but the results were negative. 23 | * Roger added + for Arabic but Chris deleted it again, since unless you've 24 | * recoded determiners, it screws up DET+NOUN, etc. (That is, it would only be useful if 25 | * you always wanted to cut at the first '+', but in practice that is not viable, certainly 26 | * not with the IBM ATB processing either.) 27 | */ 28 | private static final char[] annotationIntroducingChars = {'-', '=', '|', '#', '^', '~'}; 29 | 30 | /** 31 | * This is valid for "BobChrisTreeNormalizer" conventions only. 32 | * wsg: "ROOT" should always be the first value. See {@link #startSymbol} in 33 | * the parent class. 34 | */ 35 | private static final String[] pennStartSymbols = {"ROOT"}; 36 | 37 | 38 | /** 39 | * Returns a String array of punctuation tags for this treebank/language. 40 | * 41 | * @return The punctuation tags 42 | */ 43 | @Override 44 | public String[] punctuationTags() { 45 | return pennPunctTags; 46 | } 47 | 48 | 49 | /** 50 | * Returns a String array of punctuation words for this treebank/language. 51 | * 52 | * @return The punctuation words 53 | */ 54 | @Override 55 | public String[] punctuationWords() { 56 | return pennPunctWords; 57 | } 58 | 59 | 60 | /** 61 | * Returns a String array of sentence final punctuation tags for this 62 | * treebank/language. 63 | * 64 | * @return The sentence final punctuation tags 65 | */ 66 | @Override 67 | public String[] sentenceFinalPunctuationTags() { 68 | return pennSFPunctTags; 69 | } 70 | 71 | /** 72 | * Returns a String array of sentence final punctuation words for this 73 | * treebank/language. 74 | * 75 | * @return The sentence final punctuation tags 76 | */ 77 | public String[] sentenceFinalPunctuationWords() { 78 | return pennSFPunctWords; 79 | } 80 | 81 | /** 82 | * Returns a String array of treebank start symbols. 83 | * 84 | * @return The start symbols 85 | */ 86 | @Override 87 | public String[] startSymbols() { 88 | return pennStartSymbols; 89 | } 90 | 91 | /** 92 | * Returns the extension of treebank files for this treebank. 93 | * This is "tree". 
94 | */ 95 | public String treebankFileExtension() { 96 | return "tree"; 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/lang/CorefLanguagePack.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.lang 2 | 3 | trait CorefLanguagePack { 4 | def getMentionConstituentTypes: Seq[String]; 5 | def getPronominalTags: Seq[String]; 6 | def getProperTags: Seq[String]; 7 | } 8 | 9 | class EnglishCorefLanguagePack extends CorefLanguagePack { 10 | def getMentionConstituentTypes: Seq[String] = Seq("NP", "NML"); 11 | def getPronominalTags: Seq[String] = Seq("PRP", "PRP$"); 12 | def getProperTags: Seq[String] = Seq("NNP"); 13 | } 14 | 15 | class ChineseCorefLanguagePack extends CorefLanguagePack { 16 | def getMentionConstituentTypes: Seq[String] = Seq("NP"); 17 | def getPronominalTags: Seq[String] = Seq("PN"); 18 | def getProperTags: Seq[String] = Seq("NR"); 19 | } 20 | 21 | class ArabicCorefLanguagePack extends CorefLanguagePack { 22 | def getMentionConstituentTypes: Seq[String] = Seq("NP"); 23 | def getPronominalTags: Seq[String] = Seq("PRP", "PRP$"); 24 | def getProperTags: Seq[String] = Seq("NNP"); 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/lang/Language.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.lang; 2 | 3 | 4 | public enum Language { 5 | ENGLISH, ARABIC, CHINESE; 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/CorpusCounts.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner 2 | 3 | import scala.collection.mutable.HashSet 4 | import edu.berkeley.nlp.futile.util.Counter 5 | import edu.berkeley.nlp.futile.util.Logger 6 | 7 | @SerialVersionUID(1L) 8 | class CorpusCounts(val unigramCounts: Counter[String], 9 | val bigramCounts: Counter[(String,String)], 10 | val prefixCounts: Counter[String], 11 | val suffixCounts: Counter[String], 12 | val mostCommonUnigrams: HashSet[String]) extends Serializable { 13 | } 14 | 15 | object CorpusCounts { 16 | 17 | def countUnigramsBigrams(rawSents: Seq[Seq[String]], 18 | unigramThreshold: Int, 19 | bigramThreshold: Int, 20 | prefSuffThreshold: Int, 21 | commonUnigramCount: Int = 100): CorpusCounts = { 22 | val unigramCounts = new Counter[String]; 23 | val prefixCounts = new Counter[String]; 24 | val suffixCounts = new Counter[String]; 25 | def countWord(w: String) = { 26 | unigramCounts.incrementCount(w, 1.0); 27 | prefixCounts.incrementCount(NerFeaturizer.prefixFor(w), 1.0); 28 | suffixCounts.incrementCount(NerFeaturizer.prefixFor(w), 1.0); 29 | } 30 | val bigramCounts = new Counter[(String,String)]; 31 | for (rawSent <- rawSents) { 32 | // Be sure to increment the start- and end-of-word labels 33 | countWord(NerExample.wordAt(rawSent, -1)); 34 | for (i <- 0 to rawSent.size) { 35 | countWord(NerExample.wordAt(rawSent, i)); 36 | bigramCounts.incrementCount(NerExample.wordAt(rawSent, i-1) -> NerExample.wordAt(rawSent, i), 1.0); 37 | } 38 | } 39 | unigramCounts.pruneKeysBelowThreshold(unigramThreshold); 40 | bigramCounts.pruneKeysBelowThreshold(bigramThreshold); 41 | prefixCounts.pruneKeysBelowThreshold(prefSuffThreshold); 42 | suffixCounts.pruneKeysBelowThreshold(prefSuffThreshold); 43 | 
val mostCommonUnigrams = new HashSet[String]; 44 | val unigramPq = unigramCounts.asPriorityQueue(); 45 | while (unigramPq.hasNext && mostCommonUnigrams.size < commonUnigramCount) { 46 | val unigram = unigramPq.next; 47 | mostCommonUnigrams += unigram; 48 | } 49 | Logger.logss(unigramCounts.size + " unigrams kept above " + unigramThreshold); 50 | Logger.logss(bigramCounts.size + " bigrams kept above " + bigramThreshold); 51 | Logger.logss(prefixCounts.size + " prefixes kept above " + prefSuffThreshold); 52 | Logger.logss(suffixCounts.size + " suffixes kept above " + prefSuffThreshold); 53 | new CorpusCounts(unigramCounts, bigramCounts, prefixCounts, suffixCounts, mostCommonUnigrams); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/MCNerExample.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner 2 | 3 | import edu.berkeley.nlp.entity.DepConstTree 4 | import edu.berkeley.nlp.entity.coref.DocumentGraph 5 | 6 | case class MCNerExample(val words: Seq[String], 7 | val poss: Seq[String], 8 | val tree: DepConstTree, 9 | val startIdx: Int, 10 | val headIdx: Int, 11 | val endIdx: Int, 12 | val goldLabel: String) { 13 | def wordAt(i: Int) = NerExample.wordAt(words, i); 14 | def posAt(i: Int) = NerExample.posAt(poss, i); 15 | } 16 | 17 | object MCNerExample { 18 | def apply(docGraph: DocumentGraph, idx: Int) = { 19 | val pm = docGraph.getMention(idx); 20 | val rawDoc = docGraph.corefDoc.rawDoc; 21 | new MCNerExample(rawDoc.words(pm.sentIdx), rawDoc.pos(pm.sentIdx), rawDoc.trees(pm.sentIdx), pm.startIdx, pm.headIdx, pm.endIdx, pm.nerString); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/MCNerFeaturizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner 2 | 3 | import edu.berkeley.nlp.futile.fig.basic.Indexer 4 | import edu.berkeley.nlp.entity.wiki.WikipediaInterface 5 | import edu.berkeley.nlp.futile.util.Counter 6 | import edu.berkeley.nlp.futile.util.Logger 7 | import scala.collection.JavaConverters._ 8 | import scala.collection.mutable.HashSet 9 | 10 | @SerialVersionUID(1L) 11 | class MCNerFeaturizer(val featureSet: Set[String], 12 | val featureIndexer: Indexer[String], 13 | val labelIndexer: Indexer[String], 14 | val corpusCounts: CorpusCounts, 15 | val wikipediaDB: Option[WikipediaInterface] = None, 16 | val brownClusters: Option[Map[String,String]] = None) extends Serializable { 17 | Logger.logss(labelIndexer.getObjects.asScala.toSeq + " label indices"); 18 | val nerFeaturizer = new NerFeaturizer(featureSet, featureIndexer, labelIndexer, corpusCounts, wikipediaDB, brownClusters); 19 | 20 | def replaceIndexer(newIndexer: Indexer[String]) = { 21 | new MCNerFeaturizer(featureSet, newIndexer, labelIndexer, corpusCounts, wikipediaDB, brownClusters); 22 | } 23 | 24 | def featurize(ex: MCNerExample, addToIndexer: Boolean): Array[Array[Int]] = { 25 | // Go from the beginning to the head index only (drop trailing PPs, etc.) 
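// (Inferred from the surrounding code: the call below featurizes only the span from the mention
// start through its head token, yielding one bundle of feature indices per (token, label) pair;
// the map below then regroups these into a single flat feature array per label.)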
26 | val featsEachTokenEachLabel = nerFeaturizer.featurize(new NerExample(ex.words, ex.poss, null), addToIndexer, Some(ex.startIdx -> (ex.headIdx + 1))); 27 | // Append all token features for each label 28 | (0 until labelIndexer.size).map(labelIdx => featsEachTokenEachLabel.map(_(labelIdx)).flatten).toArray; 29 | } 30 | } 31 | 32 | object MCNerFeaturizer { 33 | 34 | val TagSet = IndexedSeq("FAC", "GPE", "LOC", "PER", "ORG", "VEH", "WEA"); 35 | val StdLabelIndexer = new Indexer[String](); 36 | for (tag <- TagSet) { 37 | StdLabelIndexer.add(tag); 38 | } 39 | 40 | def apply(featureSet: Set[String], 41 | featureIndexer: Indexer[String], 42 | labelIndexer: Indexer[String], 43 | rawSents: Seq[Seq[String]], 44 | wikipediaDB: Option[WikipediaInterface], 45 | brownClusters: Option[Map[String,String]], 46 | unigramThreshold: Int = 1, 47 | bigramThreshold: Int = 10, 48 | prefSuffThreshold: Int = 1) = { 49 | val corpusCounts = CorpusCounts.countUnigramsBigrams(rawSents, unigramThreshold, bigramThreshold, prefSuffThreshold); 50 | new MCNerFeaturizer(featureSet, featureIndexer, labelIndexer, corpusCounts, wikipediaDB, brownClusters); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/NESentenceMunger.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner 2 | 3 | import edu.berkeley.nlp.entity.ConllDocReader 4 | import edu.berkeley.nlp.futile.util.Logger 5 | import edu.berkeley.nlp.entity.ConllDoc 6 | import edu.berkeley.nlp.futile.syntax.Trees.PennTreeRenderer 7 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 8 | 9 | object NESentenceMunger { 10 | 11 | def writeSentences(file: String, docs: Seq[ConllDoc]) { 12 | val out = IOUtils.openOutHard(file); 13 | for (doc <- docs; words <- doc.words) { 14 | out.println(words.foldLeft("")(_ + " " + _).trim); 15 | } 16 | out.close(); 17 | } 18 | 19 | def main(args: Array[String]) { 20 | val devDocs = ConllDocReader.loadRawConllDocsWithSuffix("data/conll-2012-en/dev/", -1, "gold_conll"); 21 | writeSentences("data/conll-2012-dev.sentences", devDocs); 22 | val testDocs = ConllDocReader.loadRawConllDocsWithSuffix("data/conll-2012-en/test/", -1, "gold_conll"); 23 | writeSentences("data/conll-2012-test.sentences", testDocs); 24 | val dev2011Docs = ConllDocReader.loadRawConllDocsWithSuffix("data/ontonotes-conll/dev/", -1, "gold_conll"); 25 | writeSentences("data/conll-2011-dev.sentences", dev2011Docs); 26 | val test2011Docs = ConllDocReader.loadRawConllDocsWithSuffix("data/ontonotes-conll/test/", -1, "gold_conll"); 27 | writeSentences("data/conll-2011-test.sentences", test2011Docs); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/NerDriver.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner; 2 | 3 | import edu.berkeley.nlp.futile.fig.basic.Option; 4 | import edu.berkeley.nlp.futile.fig.exec.Execution; 5 | import edu.berkeley.nlp.futile.util.Logger; 6 | 7 | /** 8 | * Driver for training and evaluating a CRF-based NER system. Currently does not run as a standalone 9 | * NER predictor and only integrates with the preprocessor described in 10 | * edu.berkeley.nlp.entity.preprocess.PreprocessingDriver 11 | * 12 | * TRAIN_EVALUATE: Trains and evaluates the NER system.
13 | * Required arguments: -trainPath, -testPath, -brownClustersPath (path to bllip-clusters) 14 | * 15 | * TRAIN_RUN_KFOLD: Generates pruning masks for training and test data. k-fold cross 16 | * validation is used to generate training masks (so we don't train on the same data 17 | * that we evaluate with). 18 | * Required arguments: -kFoldInputPath, -testPath, -brownClustersPath (path to bllip-clusters), 19 | * -marginalsPath (output path), -modelOutPath 20 | * 21 | * Other arguments allow you to tweak the trained model 22 | * 23 | * @author gdurrett 24 | * 25 | */ 26 | public class NerDriver implements Runnable { 27 | @Option(gloss = "") 28 | public static Mode mode = Mode.TRAIN_EVALUATE; 29 | 30 | @Option(gloss = "Path to read/write the model") 31 | public static String modelPath = ""; 32 | 33 | // TRAINING_OPTIONS 34 | @Option(gloss = "Path to CoNLL training set") 35 | public static String trainPath = ""; 36 | @Option(gloss = "Training set size, -1 for all") 37 | public static int trainSize = -1; 38 | @Option(gloss = "Path to CoNLL test set") 39 | public static String testPath = ""; 40 | @Option(gloss = "Test set size, -1 for all") 41 | public static int testSize = -1; 42 | @Option(gloss = "Path to Brown clusters") 43 | public static String brownClustersPath = ""; 44 | @Option(gloss = "Path to write evaluation output to") 45 | public static String outputPath = ""; 46 | 47 | @Option(gloss = "Use predicted POS tags") 48 | public static boolean usePredPos = true; 49 | 50 | @Option(gloss = "Use variational decoding") 51 | public static boolean variational = false; 52 | 53 | @Option(gloss = "Path to serialized NE marginals for the training set") 54 | public static String marginalsPath = ""; 55 | @Option(gloss = "Path to serialized model") 56 | public static String modelOutPath = ""; 57 | 58 | @Option(gloss = "Path to document sets to do the k-fold trick on") 59 | public static String kFoldInputPath = ""; 60 | @Option(gloss = "K-fold set size, -1 for all") 61 | public static int kFoldSize = -1; 62 | @Option(gloss = "Path to write documents after k-fold prediction") 63 | public static String kFoldOutputPath = ""; 64 | @Option(gloss = "Number of folds to use") 65 | public static int numFolds = 5; 66 | 67 | @Option(gloss = "Regularization") 68 | public static double reg = 1e-8; 69 | @Option(gloss = "Iterations to optimize for") 70 | public static int numItrs = 30; 71 | @Option(gloss = "Adagrad batch size") 72 | public static int batchSize = 100; 73 | 74 | @Option(gloss = "Feature set") 75 | public static String featureSet = "bigrams+brown"; 76 | @Option(gloss = "Unigram cutoff threshold") 77 | public static int unigramThreshold = 1; 78 | @Option(gloss = "Bigram cutoff threshold") 79 | public static int bigramThreshold = 10; 80 | @Option(gloss = "Prefix/suffix cutoff threshold") 81 | public static int prefSuffThreshold = 2; 82 | 83 | 84 | public static enum Mode { 85 | TRAIN_EVALUATE, TRAIN_RUN_KFOLD; 86 | } 87 | 88 | public static void main(String[] args) { 89 | NerDriver main = new NerDriver(); 90 | Execution.run(args, main); // add .class here if that class should receive command-line args 91 | } 92 | 93 | public void run() { 94 | Logger.setFig(); 95 | switch (mode) { 96 | case TRAIN_EVALUATE: NerSystemLabeled.trainEvaluateNerSystem(trainPath, trainSize, testPath, testSize); 97 | break; 98 | case TRAIN_RUN_KFOLD: 99 | NerSystemLabeled.trainPredictTokenMarginalsKFold(kFoldInputPath, kFoldSize, brownClustersPath, testPath.split(","), testSize, numFolds, marginalsPath, modelOutPath); 100 | 
break; 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/NerExample.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner 2 | import edu.berkeley.nlp.futile.fig.basic.Indexer 3 | import scala.collection.mutable.ArrayBuffer 4 | import edu.berkeley.nlp.futile.util.Logger 5 | 6 | case class NerExample(val words: Seq[String], 7 | val poss: Seq[String], 8 | val goldLabels: Seq[String]) { 9 | def wordAt(i: Int) = NerExample.wordAt(words, i); 10 | def posAt(i: Int) = NerExample.posAt(poss, i); 11 | } 12 | 13 | object NerExample { 14 | def wordAt(words: Seq[String], i: Int) = if (i < 0) "<>" else if (i >= words.size) "<>" else words(i); 15 | def posAt(poss: Seq[String], i: Int) = if (i < 0) "<>" else if (i >= poss.size) "<>" else poss(i); 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/ner/NerPruner.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.ner 2 | 3 | import scala.collection.mutable.HashMap 4 | import edu.berkeley.nlp.entity.coref.UID 5 | import edu.berkeley.nlp.entity.ConllDoc 6 | import edu.berkeley.nlp.entity.GUtil 7 | import edu.berkeley.nlp.futile.fig.basic.Indexer 8 | import edu.berkeley.nlp.entity.Driver 9 | import edu.berkeley.nlp.futile.util.Logger 10 | 11 | trait NerPruner { 12 | 13 | def pruneSentence(doc: ConllDoc, sentIdx: Int): Array[Array[String]]; 14 | } 15 | 16 | @SerialVersionUID(1L) 17 | class NerPrunerFromModel(val nerModel: NerSystemLabeled, 18 | val pruningThreshold: Double) extends NerPruner with Serializable { 19 | 20 | def pruneSentence(doc: ConllDoc, sentIdx: Int): Array[Array[String]] = { 21 | val sentMarginals = nerModel.computeLogMarginals(doc.words(sentIdx).toArray, doc.pos(sentIdx).toArray); 22 | NerPruner.pruneFromMarginals(sentMarginals, nerModel.labelIndexer, pruningThreshold); 23 | } 24 | } 25 | 26 | @SerialVersionUID(1L) 27 | class NerPrunerFromMarginals(val nerMarginals: HashMap[UID,Seq[Array[Array[Float]]]], 28 | val neLabelIndexer: Indexer[String], 29 | val pruningThreshold: Double) extends NerPruner with Serializable { 30 | 31 | def pruneSentence(doc: ConllDoc, sentIdx: Int): Array[Array[String]] = { 32 | require(nerMarginals.contains(doc.uid), "Doc ID " + doc.uid + " doesn't have precomputed NER marginals" + 33 | " and the NER pruner in this model is configured to rely on these. 
You need to either change" + 34 | " how you specify the pruner (if training) or use a different model entirely (if testing)"); 35 | NerPruner.pruneFromMarginals(nerMarginals(doc.uid)(sentIdx), neLabelIndexer, pruningThreshold); 36 | } 37 | } 38 | 39 | @SerialVersionUID(1L) 40 | class NerPrunerFromMarginalsAndModel(val nerMarginals: HashMap[UID,Seq[Array[Array[Float]]]], 41 | val neLabelIndexer: Indexer[String], 42 | val nerModel: NerSystemLabeled, 43 | val pruningThreshold: Double) extends NerPruner with Serializable { 44 | 45 | def pruneSentence(doc: ConllDoc, sentIdx: Int): Array[Array[String]] = { 46 | val sentMarginals = if (nerMarginals.contains(doc.uid)) { 47 | nerMarginals(doc.uid)(sentIdx) 48 | } else { 49 | nerModel.computeLogMarginals(doc.words(sentIdx).toArray, doc.pos(sentIdx).toArray); 50 | } 51 | NerPruner.pruneFromMarginals(sentMarginals, neLabelIndexer, pruningThreshold); 52 | } 53 | } 54 | 55 | object NerPruner { 56 | 57 | def buildPruner(strategy: String): NerPruner = { 58 | val splitStrategy = strategy.split(":"); 59 | if (splitStrategy(0) == "model") { 60 | val nerModel = GUtil.load(splitStrategy(1)).asInstanceOf[NerSystemLabeled] 61 | val threshold = splitStrategy(2).toDouble; 62 | new NerPrunerFromModel(nerModel, threshold); 63 | } else if (splitStrategy(0) == "marginals") { 64 | val nerMarginals = GUtil.load(splitStrategy(1)).asInstanceOf[HashMap[UID,Seq[Array[Array[Float]]]]]; 65 | val threshold = splitStrategy(2).toDouble; 66 | new NerPrunerFromMarginals(nerMarginals, NerSystemLabeled.StdLabelIndexer, threshold); 67 | } else if (splitStrategy(0) == "marginalsmodel") { 68 | val nerMarginals = GUtil.load(splitStrategy(1)).asInstanceOf[HashMap[UID,Seq[Array[Array[Float]]]]]; 69 | val nerModel = GUtil.load(splitStrategy(2)).asInstanceOf[NerSystemLabeled] 70 | val threshold = splitStrategy(3).toDouble 71 | new NerPrunerFromMarginalsAndModel(nerMarginals, NerSystemLabeled.StdLabelIndexer, nerModel, threshold); 72 | } else if (splitStrategy(0) == "build") { 73 | Logger.logss("----------------------------") 74 | Logger.logss("BUILDING COARSE NER MODELS"); 75 | val marginalsOutPath = splitStrategy(1) 76 | val modelOutPath = if (marginalsOutPath.endsWith(".ser.gz")) marginalsOutPath.dropRight(7) + "-model.ser.gz" else marginalsOutPath + "-model.ser.gz"; 77 | val threshold = splitStrategy(2).toFloat; 78 | val numFolds = splitStrategy(3).toInt; 79 | val (nerMarginals, nerModel) = NerSystemLabeled.trainPredictTokenMarginalsKFold(Driver.trainPath, Driver.trainSize, Driver.brownPath, Array(Driver.testPath), Driver.testSize, numFolds, marginalsOutPath, modelOutPath); 80 | new NerPrunerFromMarginalsAndModel(nerMarginals, NerSystemLabeled.StdLabelIndexer, nerModel, threshold); 81 | } else { 82 | throw new RuntimeException("Unknown NER pruning method") 83 | } 84 | } 85 | 86 | def pruneFromMarginals(sentMarginals: Array[Array[Float]], neLabelIndexer: Indexer[String], pruningThreshold: Double) = { 87 | val sentLen = sentMarginals.size; 88 | val allOptions = Array.tabulate(sentLen)(wordIdx => { 89 | val bestOption = GUtil.argMaxIdxFloat(sentMarginals(wordIdx)); 90 | val bestScore = sentMarginals(wordIdx)(bestOption); 91 | val remainingOptions = (0 until neLabelIndexer.size).toArray.filter(labelIdx => { 92 | !sentMarginals(wordIdx)(labelIdx).isInfinite && sentMarginals(wordIdx)(labelIdx) >= bestScore + pruningThreshold 93 | }).map(i => neLabelIndexer.getObject(i)) 94 | remainingOptions; 95 | }); 96 | allOptions; 97 | } 98 | } 99 | 
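The buildPruner factory above is keyed on colon-separated strategy strings ("model:...", "marginals:...", "marginalsmodel:...", or "build:..."). As a rough usage sketch (the model path and threshold here are hypothetical, not taken from this repository):

import edu.berkeley.nlp.entity.ConllDoc
import edu.berkeley.nlp.entity.ner.NerPruner

object NerPrunerUsageSketch {
  // "model:<serialized NerSystemLabeled>:<threshold>": the threshold is a margin on log marginals
  // relative to the best label at each token, so it is typically negative (e.g. -4.0 keeps every
  // label whose score is within 4 of the best label at that token).
  def allowedLabelsPerSentence(doc: ConllDoc): Seq[Array[Array[String]]] = {
    val pruner = NerPruner.buildPruner("model:models/ner-coarse.ser.gz:-4.0")
    (0 until doc.words.size).map(sentIdx => pruner.pruneSentence(doc, sentIdx))
  }
}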
-------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/preprocess/ConllDocSharder.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.preprocess 2 | 3 | import edu.berkeley.nlp.futile.LightRunner 4 | import edu.berkeley.nlp.entity.ConllDoc 5 | import edu.berkeley.nlp.entity.ConllDocReader 6 | import java.io.File 7 | import edu.berkeley.nlp.entity.lang.Language 8 | import edu.berkeley.nlp.entity.ConllDocWriter 9 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 10 | import edu.berkeley.nlp.futile.util.Logger 11 | 12 | /** 13 | * When given a file as input and a directory as output, reads in CoNLL documents from the 14 | * file and writes each document to its own file in that directory. 15 | * N.B. Currently only works on CoNLL docs without standoff entity link annotations 16 | */ 17 | object ConllDocSharder { 18 | 19 | val inputFile = "" 20 | val outputDirectory = "" 21 | 22 | def main(args: Array[String]) { 23 | LightRunner.populateScala(ConllDocSharder.getClass(), args) 24 | require(!inputFile.isEmpty && !outputDirectory.isEmpty) 25 | var tokens = 0 26 | new ConllDocReader(Language.ENGLISH).readConllDocsProcessStreaming(inputFile, (doc: ConllDoc) => { 27 | val outputName = outputDirectory + "/" + doc.docID 28 | val writer = IOUtils.openOutHard(outputName) 29 | tokens += doc.words.map(_.size).foldLeft(0)(_ + _) 30 | ConllDocWriter.writeDoc(writer, doc) 31 | writer.close 32 | }) 33 | Logger.logss("Wrote " + tokens + " tokens") 34 | } 35 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/preprocess/PTBToConllMunger.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.preprocess 2 | 3 | import edu.berkeley.nlp.futile.LightRunner 4 | import java.io.File 5 | import edu.berkeley.nlp.futile.syntax.Trees.PennTreeReader 6 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 7 | import edu.berkeley.nlp.entity.ConllDoc 8 | import edu.berkeley.nlp.entity.DepConstTree 9 | import scala.collection.mutable.ArrayBuffer 10 | import edu.berkeley.nlp.entity.Chunk 11 | import edu.berkeley.nlp.entity.ConllDocWriter 12 | import edu.berkeley.nlp.entity.coref.OrderedClusteringBound 13 | import edu.berkeley.nlp.entity.coref.Mention 14 | import edu.berkeley.nlp.entity.coref.OrderedClustering 15 | import edu.berkeley.nlp.futile.util.Logger 16 | import edu.berkeley.nlp.futile.syntax.Tree 17 | import scala.collection.JavaConverters._ 18 | import scala.collection.mutable.HashSet 19 | import edu.berkeley.nlp.futile.syntax.Trees 20 | 21 | /** 22 | * Takes either a file or directory as input, reads in PTB parse trees (one per line), 23 | * and writes the resulting CoNLL documents to a file or directory 24 | */ 25 | object PTBToConllMunger { 26 | 27 | val input = "" 28 | val output = "" 29 | // Changes trees from having "Mr ." as two tokens to "Mr."; this appears to be 30 | // a problem when using some tokenizers. Coref systems expect the latter.
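// For illustration (this example is not in the original source): with fixAbbrevs = true, a parse like
//   (NP (NNP Mr) (. .) (NNP Smith))   is rewritten to   (NP (NNP Mr.) (NNP Smith))
// i.e. the abbreviation leaf absorbs the period and the leaf dominating only "." is dropped.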
31 | val fixAbbrevs = false 32 | val abbrevsToFix = new HashSet[String]() ++ Array[String]("Mr", "Mrs", "Ms", "Dr") 33 | 34 | def main(args: Array[String]) { 35 | // LightRunner.initializeOutput(PTBToConllMunger.getClass()); 36 | LightRunner.populateScala(PTBToConllMunger.getClass(), args) 37 | val inputFile = new File(input) 38 | val outputFile = new File(output) 39 | var outputWriter = if (outputFile.isDirectory) null else IOUtils.openOutHard(outputFile) 40 | for (file <- (if (inputFile.isDirectory) inputFile.listFiles.toSeq else Seq(inputFile))) { 41 | val doc = readParsesMakeDoc(file) 42 | if (outputWriter == null) { 43 | outputWriter = IOUtils.openOutHard(outputFile.getAbsolutePath + "/" + doc.docID) 44 | ConllDocWriter.writeDoc(outputWriter, doc) 45 | outputWriter.close 46 | outputWriter = null 47 | } else { 48 | ConllDocWriter.writeDoc(outputWriter, doc) 49 | } 50 | } 51 | if (outputWriter != null) { 52 | outputWriter.close 53 | } 54 | } 55 | 56 | def readParsesMakeDoc(file: File) = { 57 | val inBuffer = IOUtils.openInHard(file) 58 | val lineItr = IOUtils.lineIterator(inBuffer) 59 | val words = new ArrayBuffer[Seq[String]] 60 | val pos = new ArrayBuffer[Seq[String]] 61 | val trees = new ArrayBuffer[DepConstTree] 62 | while (lineItr.hasNext) { 63 | val currLine = lineItr.next.trim 64 | if (currLine != "" && currLine != "(())") { 65 | val currParseRaw = PennTreeReader.parseHard(currLine, false) 66 | val currParse = if (fixAbbrevs) { 67 | fixAbbrevs(currParseRaw) 68 | } else { 69 | currParseRaw 70 | } 71 | val currDepConstTree = DepConstTree.apply(currParse) 72 | words += currDepConstTree.words 73 | pos += currDepConstTree.pos 74 | trees += currDepConstTree 75 | } 76 | } 77 | inBuffer.close() 78 | val nerChunks = (0 until words.size).map(i => Seq[Chunk[String]]()) 79 | val corefChunks = (0 until words.size).map(i => Seq[Chunk[Int]]()) 80 | val speakers = (0 until words.size).map(i => (0 until words(i).size).map(j => "-")) 81 | new ConllDoc(file.getName(), 0, words, pos, trees, nerChunks, corefChunks, speakers) 82 | } 83 | 84 | def fixAbbrevs(tree: Tree[String]): Tree[String] = { 85 | val treeYield = tree.getYield().asScala 86 | val abbrevIndices = new ArrayBuffer[Int] 87 | for (abbrev <- abbrevsToFix) { 88 | var startIdx = 0 89 | var currIdx = treeYield.indexOf(abbrev, startIdx) 90 | while (currIdx != -1) { 91 | abbrevIndices += currIdx 92 | startIdx = currIdx + 1 93 | currIdx = treeYield.indexOf(abbrev, startIdx) 94 | } 95 | } 96 | if (abbrevIndices.size == 0) { 97 | tree 98 | } else { 99 | // The transformation could theoretically product X over X so redo this transformation 100 | new Trees.XOverXRemover().transformTree(transformFixAbbrevs(tree, 0, treeYield.size, abbrevIndices.sorted)) 101 | } 102 | } 103 | 104 | /** 105 | * Need to do two things to fix abbreviations: add . to the abbreviation and remove the . token 106 | */ 107 | def transformFixAbbrevs(tree: Tree[String], startIdx: Int, endIdx: Int, abbrevIndices: Seq[Int]): Tree[String] = { 108 | // Leaf: fix the abbreviation label if necessary 109 | if (tree.isLeaf()) { 110 | if (abbrevIndices.contains(startIdx)) { 111 | new Tree[String](tree.getLabel() + ".") 112 | } else { 113 | tree 114 | } 115 | } else { 116 | // } else if (tree.isPreTerminal()) { 117 | // new Tree[String](tree.getLabel(), transformFixAbbrevs(tree.getChildren().get(0), startIdx, endIdx, abbrevIndices) 118 | // } else { 119 | // Select things that either contain this index or the next (the .) 
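// (Concretely: an abbreviation position idx is kept when either idx itself or idx + 1, the
// position of its trailing "." token, falls inside this subtree's span [startIdx, endIdx).)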
120 | val matchingAbbrevIndices = abbrevIndices.filter(idx => startIdx <= idx + 1 && idx < endIdx) 121 | if (matchingAbbrevIndices.size == 0) { 122 | tree 123 | } else { 124 | val children = tree.getChildren().asScala 125 | var currIdx = startIdx 126 | val newChildren = new ArrayBuffer[Tree[String]] 127 | for (child <- children) { 128 | val childYield = child.getYield().asScala 129 | val childYieldSize = childYield.size 130 | val currEndIdx = currIdx + childYieldSize 131 | // If this child only dominates the offending period 132 | if (matchingAbbrevIndices.contains(currIdx - 1) && childYieldSize == 1 && childYield.head == ".") { 133 | // Delete this child by doing nothing 134 | } else { 135 | // Otherwise proceed as normal 136 | newChildren += transformFixAbbrevs(child, currIdx, currEndIdx, abbrevIndices) 137 | } 138 | currIdx += childYieldSize 139 | } 140 | new Tree[String](tree.getLabel(), newChildren.asJava) 141 | } 142 | } 143 | } 144 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/preprocess/Reprocessor.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.preprocess 2 | 3 | import edu.berkeley.nlp.PCFGLA.CoarseToFineMaxRuleParser 4 | import edu.berkeley.nlp.entity.ConllDoc 5 | import scala.collection.JavaConverters._ 6 | import scala.collection.mutable.ArrayBuffer 7 | import java.io.PrintWriter 8 | import edu.berkeley.nlp.entity.ConllDocReader 9 | import edu.berkeley.nlp.syntax.Tree 10 | import edu.berkeley.nlp.futile.util.Logger 11 | import java.util.Arrays 12 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 13 | import edu.berkeley.nlp.entity.Chunk 14 | import edu.berkeley.nlp.entity.ConllDocWriter 15 | import edu.berkeley.nlp.entity.ner.NerSystemLabeled 16 | 17 | object Reprocessor { 18 | 19 | def redoConllDocument(parser: CoarseToFineMaxRuleParser, backoffParser: CoarseToFineMaxRuleParser, nerSystem: NerSystemLabeled, docReader: ConllDocReader, inputPath: String, outputPath: String) { 20 | val writer = IOUtils.openOutHard(outputPath); 21 | val docs = docReader.readConllDocs(inputPath); 22 | for (doc <- docs) { 23 | Logger.logss("Reprocessing: " + doc.docID + " part " + doc.docPartNo); 24 | val newPos = new ArrayBuffer[Seq[String]](); 25 | val newParses = new ArrayBuffer[edu.berkeley.nlp.futile.syntax.Tree[String]](); 26 | val newNerChunks = new ArrayBuffer[Seq[Chunk[String]]](); 27 | for (sentIdx <- 0 until doc.words.size) { 28 | if (sentIdx % 10 == 0) { 29 | Logger.logss("Sentence " + sentIdx); 30 | } 31 | val sent = doc.words(sentIdx); 32 | var parse = PreprocessingDriver.parse(parser, backoffParser, sent.asJava); 33 | parse = if (parse.getYield().size() != sent.length) { 34 | Logger.logss("Couldn't parse sentence: " + sent.toSeq); 35 | Logger.logss("Using default parse"); 36 | convertFromFutileTree(doc.trees(sentIdx).constTree); 37 | } else { 38 | parse; 39 | } 40 | val posTags = parse.getPreTerminalYield().asScala.toArray; 41 | newPos += posTags; 42 | newParses += convertToFutileTree(parse); 43 | val nerBioLabels = if (nerSystem != null) nerSystem.tagBIO(sent.toArray, posTags) else Array.tabulate(sent.size)(i => "O"); 44 | newNerChunks += convertBioToChunks(nerBioLabels); 45 | } 46 | ConllDocWriter.writeIncompleteConllDoc(writer, doc.docID, doc.docPartNo, doc.words, newPos, newParses, doc.speakers, newNerChunks, doc.corefChunks); 47 | } 48 | writer.close(); 49 | } 50 | 51 | def convertBioToChunks(nerBioLabels: Seq[String]): 
Seq[Chunk[String]] = { 52 | var lastNerStart = -1; 53 | val chunks = new ArrayBuffer[Chunk[String]](); 54 | for (i <- 0 until nerBioLabels.size) { 55 | if (nerBioLabels(i).startsWith("B")) { 56 | if (lastNerStart != -1) { 57 | chunks += new Chunk[String](lastNerStart, i, "MISC"); 58 | } 59 | lastNerStart = i; 60 | } else if (nerBioLabels(i).startsWith("O")) { 61 | if (lastNerStart != -1) { 62 | chunks += new Chunk[String](lastNerStart, i, "MISC"); 63 | lastNerStart = -1; 64 | } 65 | } 66 | } 67 | chunks; 68 | } 69 | 70 | def convertToFutileTree(slavTree: edu.berkeley.nlp.syntax.Tree[String]): edu.berkeley.nlp.futile.syntax.Tree[String] = { 71 | new edu.berkeley.nlp.futile.syntax.Tree[String](slavTree.getLabel(), slavTree.getChildren().asScala.map(convertToFutileTree(_)).asJava); 72 | } 73 | 74 | def convertFromFutileTree(myTree: edu.berkeley.nlp.futile.syntax.Tree[String]): edu.berkeley.nlp.syntax.Tree[String] = { 75 | new edu.berkeley.nlp.syntax.Tree[String](myTree.getLabel(), myTree.getChildren().asScala.map(convertFromFutileTree(_)).asJava); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/preprocess/SentenceSplitterTokenizerDriver.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.preprocess; 2 | 3 | import java.io.PrintWriter; 4 | import java.util.List; 5 | 6 | import edu.berkeley.nlp.futile.tokenizer.PTBLineLexer; 7 | import edu.berkeley.nlp.futile.util.Logger; 8 | import edu.berkeley.nlp.futile.fig.basic.IOUtils; 9 | import edu.berkeley.nlp.futile.fig.basic.Option; 10 | import edu.berkeley.nlp.futile.fig.exec.Execution; 11 | 12 | /** 13 | * Driver for training new sentence splitting models and running just the 14 | * sentence splitting and tokenization phases of the preprocessing pipeline. 15 | * 16 | * TRAIN: Trains a sentence splitter 17 | * Required arguments: -inputPath 18 | * or -trainFromConll, -conllTrainPath, -conllTestPath (can be the same as train) 19 | * 20 | * RUN: Sentence-splits and tokenizes some data 21 | * Required arguments: -inputPath, -modelPath 22 | * 23 | * @author gdurrett 24 | * 25 | */ 26 | public class SentenceSplitterTokenizerDriver implements Runnable { 27 | @Option(gloss = "") 28 | public static Mode mode = Mode.TRAIN; 29 | 30 | @Option(gloss = "Raw text input") 31 | public static String inputPath = ""; 32 | @Option(gloss = "") 33 | public static String outputPath = ""; 34 | @Option(gloss = "") 35 | public static boolean respectInputLineBreaks = false; 36 | @Option(gloss = "") 37 | public static boolean respectInputTwoLineBreaks = true; 38 | 39 | @Option(gloss = "Path to read/write the model") 40 | public static String modelPath = ""; 41 | 42 | // TRAINING OPTIONS 43 | @Option(gloss = "Train the sentence splitter from the CoNLL data. If false, you " + 44 | "must provide your own data in the format\n" + 45 | ". 
<0 or 1>\n" + 46 | "where 0 indicates not a boundary and 1 indicates a boundary.") 47 | public static boolean trainFromConll = true; 48 | 49 | @Option(gloss = "Path to training set") 50 | public static String trainPath = ""; 51 | @Option(gloss = "Path to test set") 52 | public static String testPath = ""; 53 | @Option(gloss = "Path to CoNLL training set") 54 | public static String conllTrainPath = ""; 55 | @Option(gloss = "Training set size, -1 for all") 56 | public static int conllTrainSize = -1; 57 | @Option(gloss = "Path to CoNLL test set") 58 | public static String conllTestPath = ""; 59 | @Option(gloss = "Test set size, -1 for all") 60 | public static int conllTestSize = -1; 61 | 62 | public static enum Mode { 63 | TRAIN, RUN; 64 | } 65 | 66 | public static void main(String[] args) { 67 | SentenceSplitterTokenizerDriver main = new SentenceSplitterTokenizerDriver(); 68 | Execution.run(args, main); // add .class here if that class should receive command-line args 69 | } 70 | 71 | public void run() { 72 | Logger.setFig(); 73 | switch (mode) { 74 | case TRAIN: SentenceSplitter.trainSentenceSplitter(); 75 | break; 76 | case RUN: 77 | SentenceSplitter splitter = SentenceSplitter.loadSentenceSplitter(modelPath); 78 | String[] lines = IOUtils.readLinesHard(inputPath).toArray(new String[0]); 79 | String[] canonicalizedParagraphs = splitter.formCanonicalizedParagraphs(lines, respectInputLineBreaks, respectInputTwoLineBreaks); 80 | String[] sentences = splitter.splitSentences(canonicalizedParagraphs); 81 | String[][] tokenizedSentences = splitter.tokenize(sentences); 82 | PrintWriter writer = IOUtils.openOutHard(outputPath); 83 | for (String[] sentence : tokenizedSentences) { 84 | for (int i = 0; i < sentence.length; i++) { 85 | writer.print(sentence[i]); 86 | if (i < sentence.length - 1) { 87 | writer.print(" "); 88 | } 89 | } 90 | writer.println(); 91 | } 92 | writer.close(); 93 | break; 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/preprocess/Tokenizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.preprocess 2 | 3 | import edu.berkeley.nlp.futile.tokenizer.PTBLineLexer 4 | import scala.collection.mutable.HashMap 5 | import edu.berkeley.nlp.futile.util.Logger 6 | 7 | trait Tokenizer { 8 | def tokenize(sentence: String): Array[String]; 9 | } 10 | 11 | case class StandardPTBTokenizer() extends Tokenizer { 12 | val tokenizer = new PTBLineLexer(); 13 | 14 | def tokenize(sentence: String) = { 15 | tokenizer.tokenize(sentence).toArray(new Array[String](0)); 16 | } 17 | } 18 | 19 | /** 20 | * Taken from 21 | * http://www.cis.upenn.edu/~treebank/tokenizer.sed 22 | * 23 | * Main things this doesn't do 24 | * --Hyphens aren't split out 25 | * --Respects the sentence boundary detector which fails sometimes 26 | */ 27 | case class CustomPTBTokenizer() extends Tokenizer { 28 | 29 | def tokenize(sentence: String) = { 30 | // Beginning and end are padded with a space to make matching simpler 31 | var currSentence = " " + sentence.trim() + " "; 32 | // Fix quotes 33 | currSentence = currSentence.replace(" \"", " `` "); 34 | currSentence = currSentence.replace("\"", "''"); 35 | // Do this before periods 36 | currSentence = currSentence.replace("...", " ... "); 37 | // Split out final periods, possibly followed by ', ", ), ], } 38 | // currSentence = currSentence.replaceAll(".(['\"\\)}\\]]) ", " . 
$1 "); 39 | currSentence = currSentence.replaceAll("\\.(['\"\\)}\\]]|(''))? $", " . $1 "); 40 | // Break out quotes 41 | currSentence = currSentence.replaceAll("''", " '' "); 42 | // Dashes 43 | currSentence = currSentence.replace("--", " -- "); 44 | // Split out punctuation, brackets, etc. 45 | // Exception: keep commas in numbers 46 | currSentence = currSentence.replaceAll("(\\d),(\\d)", "$1COMMAMARKER$2") 47 | for (symbol <- Tokenizer.AllSymbols) { 48 | currSentence = currSentence.replace(symbol, " " + symbol + " "); 49 | } 50 | currSentence = currSentence.replace("COMMAMARKER", ","); 51 | // Fix brackets, etc. 52 | for (entry <- Tokenizer.ReplacementMap) { 53 | currSentence = currSentence.replace(entry._1, entry._2); 54 | } 55 | // Split out suffixes 56 | for (entry <- Tokenizer.SuffixesMap) { 57 | currSentence = currSentence.replace(entry._1, entry._2); 58 | } 59 | // Seems like the tokenizer doesn't do this 60 | // for (entry <- Tokenizer.CompoundsMap) { 61 | // currSentence = currSentence.replace(entry._1, entry._2); 62 | // } 63 | currSentence = currSentence.replaceAll("([^'])' ", "$1 ' "); 64 | currSentence = currSentence.replaceAll(" '([^'\\s])", " ' $1"); 65 | currSentence = currSentence.replaceAll("([^\\s])'([sSmMdD])", "$1 '$2 "); 66 | currSentence.trim.split("\\s+"); 67 | } 68 | } 69 | 70 | object Tokenizer { 71 | val PuncSymbols = Set("?", "!", ",", ";", ":", "@", "#", "$", "%", "&"); 72 | val BracketSymbols = Set("(", ")", "[", "]", "{", "}"); 73 | val AllSymbols = PuncSymbols ++ BracketSymbols; 74 | 75 | val ReplacementMap = HashMap("(" -> "-LRB-", 76 | ")" -> "-RRB-", 77 | "[" -> "-LSB-", 78 | "]" -> "-RSB-", 79 | "{" -> "-LCB-", 80 | "}" -> "-RCB-"); 81 | 82 | val SuffixesMap = HashMap("'ll " -> " 'll ", 83 | "'re " -> " 're ", 84 | "'ve " -> " 've ", 85 | "n't " -> " n't "); 86 | SuffixesMap ++= SuffixesMap.map(entry => entry._1.toUpperCase -> entry._2.toUpperCase); 87 | 88 | val CompoundsMap = HashMap(" Cannot " -> " Can not ", 89 | " D'ye " -> " D' ye ", 90 | " Gimme " -> " Gim me ", 91 | " Gonna " -> " Gon na ", 92 | " Gotta " -> " Got ta ", 93 | " Lemme " -> " Lem me ", 94 | " More'n " -> " More 'n ", 95 | " 'Tis " -> " 'T is ", 96 | " 'Twas " -> " 'T was ", 97 | " Wanna " -> " Wan na "); 98 | CompoundsMap ++= CompoundsMap.map(entry => entry._1.toLowerCase -> entry._2.toLowerCase); 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/preprocess/TokenizerTest.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.preprocess 2 | 3 | import scala.collection.JavaConverters._ 4 | import edu.berkeley.nlp.futile.tokenizer.PTBLineLexer 5 | import scala.io.Source 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | object TokenizerTest { 9 | 10 | 11 | // def main(args: Array[String]) { 12 | // val tokenizer = new PTBLineLexer(); 13 | // val s1 = "In pp collisions at a centre-of-mass energy s=7 TeV the emergence of similar long-range (2<|Δη|<4) near-side (Δφ≈0) correlations was reported in events with significantly higher-than-average particle multiplicity [21]."; 14 | // val s2 = "The coefficient v2p is significantly lower than v2π for 0.52.5 GeV/c." 15 | // println(tokenizer.tokenize(s1).asScala.reduce(_ + "::" + _)); 16 | // println(tokenizer.tokenize(s2).asScala.reduce(_ + "::" + _)); 17 | // } 18 | // 19 | // 20 | 21 | def main(args: Array[String]) = { 22 | // val str = "a.) 
"; 23 | // println(str.replaceAll("\\.(['\"\\)}\\]])? $", " . $1 ")); 24 | // System.exit(0); 25 | 26 | // val str = "a.'' "; 27 | // println(str.replaceAll("\\.(['\"\\)}\\]]|(''))? $", " . $1 ")); 28 | // System.exit(0); 29 | 30 | 31 | val tok1 = new CustomPTBTokenizer; 32 | // To test: possessives (plural and singular), quotes, quote + punc, all punc 33 | // println(tok.tokenize()); 34 | // println(" D'ye ".replaceAll("D'ye", " D' ye")); 35 | // var currSentence = "Jesus' thing and I'm I'd"; 36 | // currSentence = currSentence.replaceAll("([^'])' ", "$1 ' "); 37 | // currSentence = currSentence.replaceAll("'([sSmMdD])", " '$1 "); 38 | // var currSentence = "((("; 39 | // currSentence = currSentence.replace("(", "-LRB-"); 40 | // 41 | // println("things.\" ".replaceAll(".(['\"\\)}\\]]) ", " . $1 ")); 42 | // println("things.' ".replaceAll(".(['\"\\)}\\]]) ", " . $1 ")); 43 | // println("things.+ ".replaceAll(".(['\"\\)}\\]]) ", " . $1 ")); 44 | 45 | val tok2 = new PTBLineLexer; 46 | var count = 0; 47 | 48 | val splitter = SentenceSplitter.loadSentenceSplitter("models/sentsplit.txt.gz"); 49 | val lines = Source.fromFile("data/ace-tokenization-test-2k.txt").getLines.toSeq; 50 | val paras = splitter.formCanonicalizedParagraphs(lines.toArray, false, true); 51 | val sentences = splitter.splitSentences(paras) 52 | println(paras.size + " paras, " + sentences.size + " sents"); 53 | // val sent = "This is a test. No it's really a test. Like actually a test. Like what."; 54 | // println(splitter.splitSentences(Array(sent)).toSeq); 55 | for (sentence <- sentences) { 56 | // println(sentence); 57 | val tokenized1 = new ArrayBuffer[String] ++ tok1.tokenize(sentence); 58 | val tokenized2 = new ArrayBuffer[String] ++ tok2.tokenize(sentence).asScala; 59 | if (tokenized1.sameElements(tokenized2)) { 60 | count += 1; 61 | } else { 62 | println(tokenized1) 63 | println(tokenized2); 64 | } 65 | } 66 | println(count); 67 | 68 | 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/AbbreviationHandler.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | 3 | import edu.berkeley.nlp.entity.coref.DocumentGraph 4 | 5 | object AbbreviationHandler { 6 | 7 | def isAbbreviation(docGraph: DocumentGraph, antIdx: Int, mentIdx: Int): Boolean = { 8 | val antWords = docGraph.getMention(antIdx).words; 9 | val mentWords = docGraph.getMention(antIdx).words; 10 | isAbbreviation(docGraph.getMention(antIdx).headString, docGraph.getMention(mentIdx).words, docGraph.getMention(mentIdx).headIdx - docGraph.getMention(mentIdx).startIdx) || 11 | isAbbreviation(docGraph.getMention(mentIdx).headString, docGraph.getMention(antIdx).words, docGraph.getMention(antIdx).headIdx - docGraph.getMention(antIdx).startIdx) 12 | } 13 | 14 | def testAndCanonicalizeAbbreviation(abbrev: String) = { 15 | if (abbrev.size >= 2 && abbrev.size <= 4) { 16 | val abbrevFiltered = abbrev.filter(c => c != '.'); 17 | if (abbrevFiltered.size > 0 && abbrevFiltered.map(c => Character.isUpperCase(c)).reduce(_ && _)) { 18 | abbrevFiltered; 19 | } else { 20 | ""; 21 | } 22 | } else { 23 | ""; 24 | } 25 | } 26 | 27 | def isAbbreviation(abbrev: String, other: Seq[String], headOffset: Int): Boolean = { 28 | val abbrevCanonical = testAndCanonicalizeAbbreviation(abbrev); 29 | if (abbrevCanonical == "") { 30 | false 31 | } else { 32 | val isAbbrev1 = isAbbreviationType1(abbrevCanonical, other, headOffset); 33 | 
val isAbbrev2 = isAbbreviationType2(abbrevCanonical, other, headOffset); 34 | val isAbbrev = isAbbrev1 || isAbbrev2; 35 | isAbbrev; 36 | } 37 | } 38 | 39 | def isAbbreviationType1(abbrev: String, other: Seq[String], headOffset: Int): Boolean = { 40 | // Abbreviation type 1: consecutive characters of words (African National Congress => ANC) containing the head. 41 | // TODO: Maybe don't uppercase (helps sometimes, hurts sometimes) 42 | val firstPossible = Math.max(0, headOffset - abbrev.size + 1); 43 | val lastPossible = Math.min(other.size, headOffset + abbrev.size); 44 | var isAbbrev = false; 45 | for (i <- firstPossible to lastPossible) { 46 | isAbbrev = isAbbrev || abbrev == other.slice(i, i + abbrev.size).map(word => Character.toUpperCase(word.charAt(0))).foldLeft("")(_ + _); 47 | } 48 | isAbbrev; 49 | } 50 | 51 | def isAbbreviationType2(abbrev: String, other: Seq[String], headOffset: Int): Boolean = { 52 | var otherFirstLettersNoLc = ""; 53 | var headIdxInFinal = -1; 54 | for (i <- 0 until other.size) { 55 | if (Character.isUpperCase(other(i).charAt(0))) { 56 | otherFirstLettersNoLc += other(i).charAt(0); 57 | } 58 | if (i == headOffset) { 59 | headIdxInFinal = otherFirstLettersNoLc.size - 1; 60 | } 61 | } 62 | val firstPossible = Math.max(0, headIdxInFinal - abbrev.size + 1); 63 | val lastPossible = Math.min(other.size, headIdxInFinal + abbrev.size); 64 | otherFirstLettersNoLc.slice(firstPossible, lastPossible).contains(abbrev); 65 | } 66 | 67 | def main(args: Array[String]) { 68 | println(isAbbreviation("ANC", Seq("African", "National", "Congress"), 2)); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/BasicWordNetSemClasser.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | 3 | import edu.berkeley.nlp.entity.WordNetInterfacer 4 | import edu.berkeley.nlp.entity.coref.Mention 5 | 6 | @SerialVersionUID(1L) 7 | class BasicWordNetSemClasser extends SemClasser { 8 | def getSemClass(ment: Mention, wni: WordNetInterfacer): String = { 9 | SemClass.getSemClassNoNer(ment.headStringLc, wni).toString 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/BrownClusterInterface.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | 3 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 4 | import scala.collection.mutable.HashMap 5 | import edu.berkeley.nlp.futile.util.Logger 6 | 7 | object BrownClusterInterface { 8 | 9 | def loadBrownClusters(path: String, cutoff: Int): Map[String,String] = { 10 | val wordsToClusters = new HashMap[String,String]; 11 | val iterator = IOUtils.lineIterator(path); 12 | while (iterator.hasNext) { 13 | val nextLine = iterator.next; 14 | val fields = nextLine.split("\\s+"); 15 | if (fields.size == 3 && fields(fields.size - 1).toInt >= cutoff) { 16 | wordsToClusters.put(fields(1), fields(0)); 17 | } 18 | } 19 | Logger.logss(wordsToClusters.size + " Brown cluster definitions read in"); 20 | wordsToClusters.toMap; 21 | } 22 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/GoogleNgramUtils.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | 3 | object GoogleNgramUtils { 4 | 5 | def 
fastAccessLine(line: String, fieldIdx: Int, gramSize: Int) = { 6 | if (fieldIdx == 0) { 7 | fastAccessFirst(line); 8 | } else if (fieldIdx == gramSize - 1) { 9 | fastAccessLast(line); 10 | } else { 11 | fastAccessLineHelper(line, fieldIdx); 12 | } 13 | } 14 | 15 | def fastAccessCount(line: String) = fastAccessCountDouble(line); 16 | 17 | def fastAccessCountDouble(line: String) = fastAccessCountString(line).toDouble; 18 | 19 | def fastAccessCountLong(line: String) = fastAccessCountString(line).toLong; 20 | 21 | def fastAccessCountString(line: String) = { 22 | var firstSpaceIdxFromEnd = line.size - 1; 23 | while (firstSpaceIdxFromEnd >= 0 && !Character.isWhitespace(line.charAt(firstSpaceIdxFromEnd))) { 24 | firstSpaceIdxFromEnd -= 1; 25 | } 26 | line.slice(firstSpaceIdxFromEnd + 1, line.size) 27 | } 28 | 29 | private def fastAccessLineHelper(line: String, fieldIdx: Int) = { 30 | var wordIdx = 0; 31 | var inSpace = false; 32 | var startIdx = 0; 33 | var endIdx = 0; 34 | var i = 0; 35 | while (i < line.size) { 36 | if (Character.isWhitespace(line.charAt(i))) { 37 | if (!inSpace) { 38 | inSpace = true; 39 | wordIdx += 1; 40 | if (wordIdx == fieldIdx + 1) { 41 | endIdx = i; 42 | } 43 | } 44 | } else { 45 | if (inSpace) { 46 | inSpace = false; 47 | if (wordIdx == fieldIdx) { 48 | startIdx = i; 49 | } 50 | } 51 | } 52 | i += 1; 53 | } 54 | line.slice(startIdx, endIdx); 55 | } 56 | 57 | private def fastAccessFirst(line: String) = { 58 | var firstSpaceIdx = 0; 59 | while (firstSpaceIdx < line.size && !Character.isWhitespace(line.charAt(firstSpaceIdx))) { 60 | firstSpaceIdx += 1; 61 | } 62 | line.slice(0, firstSpaceIdx); 63 | } 64 | 65 | private def fastAccessLast(line: String) = { 66 | // Go past space 67 | var firstSpaceIdxFromEnd = line.size - 1; 68 | var count = 0; 69 | while (firstSpaceIdxFromEnd >= 0 && !Character.isWhitespace(line.charAt(firstSpaceIdxFromEnd))) { 70 | firstSpaceIdxFromEnd -= 1; 71 | } 72 | var endOfWordIdx = firstSpaceIdxFromEnd - 1; 73 | while (endOfWordIdx >= 0 && Character.isWhitespace(line.charAt(endOfWordIdx))) { 74 | endOfWordIdx -= 1; 75 | } 76 | var beginningOfWordIdx = endOfWordIdx - 1; 77 | while (beginningOfWordIdx >= 0 && !Character.isWhitespace(line.charAt(beginningOfWordIdx))) { 78 | beginningOfWordIdx -= 1; 79 | } 80 | line.slice(beginningOfWordIdx + 1, endOfWordIdx + 1); 81 | } 82 | 83 | def main(args: Array[String]) { 84 | val line = "! 
" 1952"; 85 | println(fastAccessLine(line, 0, 3)); 86 | println(fastAccessLine(line, 1, 3)); 87 | println(fastAccessLine(line, 2, 3)); 88 | println(fastAccessCount(line)); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/QueryCountCollector.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | import scala.collection.mutable.ArrayBuffer 3 | import scala.collection.JavaConverters._ 4 | import edu.berkeley.nlp.entity.coref.CorefSystem 5 | import edu.berkeley.nlp.entity.coref.MentionType 6 | import edu.berkeley.nlp.futile.util.Counter 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import scala.collection.mutable.HashSet 9 | import java.io.File 10 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 11 | import java.io.IOException 12 | import edu.berkeley.nlp.entity.sem.GoogleNgramUtils._ 13 | 14 | object QueryCountCollector { 15 | 16 | // TODO: Make this more efficient 17 | // --Some kind of primitive hash map with indexed ints instead of a Counter[(String,String)]? 18 | // TODO: Think about the casing here... 19 | def collectCounts(trainPath: String, trainSize: Int, testPath: String, testSize: Int, countsRootDir: String, ngramsPathFile: String) = { 20 | // Get the head pairs 21 | val docs = CorefSystem.loadCorefDocs(trainPath, trainSize, "auto_conll", null) ++ CorefSystem.loadCorefDocs(testPath, testSize, "auto_conll", null); 22 | val heads = new HashSet[String](); 23 | val headPairs = new HashSet[(String, String)]; 24 | for (doc <- docs; i <- 0 until doc.predMentions.size) { 25 | if (doc.predMentions(i).mentionType != MentionType.PRONOMINAL) { 26 | heads.add(doc.predMentions(i).headString); 27 | for (j <- 0 until i) { 28 | if (doc.predMentions(j).mentionType != MentionType.PRONOMINAL) { 29 | val first = doc.predMentions(j).headString; 30 | val second = doc.predMentions(i).headString; 31 | if (first != second) { 32 | // Logger.logss("Registering pair: " + first + ", " + second); 33 | headPairs += first -> second; 34 | } 35 | } 36 | } 37 | } 38 | } 39 | Logger.logss(heads.size + " distinct heads, " + headPairs.size + " distinct head pairs; some of them are " + headPairs.slice(0, Math.min(10, headPairs.size))); 40 | val headCounts = new Counter[String]; 41 | val headPairCounts = new Counter[(String,String)]; 42 | // Open the outfile early to fail fast 43 | val outWriter = IOUtils.openOutHard(ngramsPathFile) 44 | // Load the n-grams and count them 45 | // Iterate through all 3-grams and 4-grams 46 | try { 47 | var numLinesProcessed = 0; 48 | for (file <- new File(countsRootDir + "/1gms").listFiles) { 49 | Logger.logss("Processing " + file.getAbsolutePath); 50 | val lineIterator = IOUtils.lineIterator(file.getAbsolutePath()); 51 | while (lineIterator.hasNext) { 52 | countUnigram(lineIterator.next, heads, headCounts); 53 | numLinesProcessed += 1; 54 | } 55 | } 56 | Logger.logss(numLinesProcessed + " 1-grams processed"); 57 | numLinesProcessed = 0; 58 | for (file <- new File(countsRootDir + "/3gms").listFiles) { 59 | Logger.logss("Processing " + file.getAbsolutePath); 60 | val lineIterator = IOUtils.lineIterator(file.getAbsolutePath()); 61 | while (lineIterator.hasNext) { 62 | count(lineIterator.next, heads, headPairs, headPairCounts, 3); 63 | numLinesProcessed += 1; 64 | } 65 | } 66 | Logger.logss(numLinesProcessed + " 3-grams processed"); 67 | numLinesProcessed = 0; 68 | for (file <- new File(countsRootDir + 
"/4gms").listFiles) { 69 | Logger.logss("Processing " + file.getAbsolutePath); 70 | val lineIterator = IOUtils.lineIterator(file.getAbsolutePath()); 71 | while (lineIterator.hasNext) { 72 | count(lineIterator.next, heads, headPairs, headPairCounts, 4); 73 | numLinesProcessed += 1; 74 | } 75 | } 76 | Logger.logss(numLinesProcessed + " 4-grams processed"); 77 | } catch { 78 | case e: IOException => throw new RuntimeException(e); 79 | } 80 | // Write to file 81 | Logger.logss("Extracted counts for " + headCounts.size + " heads and " + headPairCounts.size + " head pairs"); 82 | for (word <- headCounts.keySet.asScala.toSeq.sorted) { 83 | val str = word + " " + headCounts.getCount(word).toInt 84 | outWriter.println(str); 85 | } 86 | for (pair <- headPairCounts.keySet.asScala.toSeq.sorted) { 87 | val str = pair._1 + " " + pair._2 + " " + headPairCounts.getCount(pair).toInt 88 | outWriter.println(str); 89 | // Logger.logss(str); 90 | } 91 | outWriter.close; 92 | } 93 | 94 | def countUnigram(line: String, heads: HashSet[String], headCounts: Counter[String]) { 95 | val word = fastAccessLine(line, 0, 1); 96 | if (heads.contains(word)) { 97 | headCounts.incrementCount(word, fastAccessCount(line)); 98 | } 99 | } 100 | 101 | def count(line: String, heads: HashSet[String], headPairs: HashSet[(String,String)], headPairCounts: Counter[(String, String)], gramSize: Int) { 102 | val firstWord = fastAccessLine(line, 0, gramSize); 103 | if (heads.contains(firstWord)) { 104 | val lastWord = fastAccessLine(line, gramSize - 1, gramSize); 105 | if (heads.contains(lastWord)) { 106 | val pair = firstWord -> lastWord; 107 | val pairFlipped = lastWord -> firstWord; 108 | if (headPairs.contains(pair) || headPairs.contains(pairFlipped)) { 109 | if (gramSize == 3) { 110 | val middleWordLc = fastAccessLine(line, 1, gramSize).toLowerCase; 111 | if (middleWordLc == "is" || middleWordLc == "are" || middleWordLc == "was" || middleWordLc == "were") { 112 | // Logger.logss("Matched a pattern: " + line); 113 | headPairCounts.incrementCount(pair, fastAccessCount(line)) 114 | headPairCounts.incrementCount(pairFlipped, fastAccessCount(line)) 115 | } 116 | } 117 | else if (gramSize == 4) { 118 | val secondWordLc = fastAccessLine(line, 1, gramSize).toLowerCase; 119 | val thirdWordLc = fastAccessLine(line, 2, gramSize).toLowerCase; 120 | if ((secondWordLc == "is" || secondWordLc == "are" || secondWordLc == "was" || secondWordLc == "were") 121 | && (thirdWordLc == "a" || thirdWordLc == "an" || thirdWordLc == "the")) { 122 | // Logger.logss("Matched a pattern: " + line); 123 | headPairCounts.incrementCount(pair, fastAccessCount(line)) 124 | headPairCounts.incrementCount(pairFlipped, fastAccessCount(line)) 125 | } 126 | } 127 | } 128 | } 129 | } 130 | 131 | } 132 | 133 | } 134 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/QueryCountsBundle.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 3 | import edu.berkeley.nlp.futile.util.Counter 4 | import edu.berkeley.nlp.futile.util.Logger 5 | import java.io.File 6 | 7 | @SerialVersionUID(1L) 8 | class QueryCountsBundle(val wordCounts: Counter[String], 9 | val pairCounts: Counter[(String,String)]) extends Serializable { 10 | } 11 | 12 | object QueryCountsBundle { 13 | 14 | def createFromFile(path: String) = { 15 | val wordCounts = new Counter[String]; 16 | val pairCounts = new 
Counter[(String,String)]; 17 | val cleanedPath = if (path != path.trim) { 18 | Logger.logss("WARNING: queryCountsFile has spurious spaces for some inexplicable reason; trimming"); 19 | path.trim; 20 | } else { 21 | path; 22 | } 23 | val lineItr = IOUtils.lineIterator(cleanedPath); 24 | while (lineItr.hasNext) { 25 | val line = lineItr.next; 26 | val fields = line.split("\\s+"); 27 | if (fields.size == 2) { 28 | wordCounts.incrementCount(fields(0), fields(1).toDouble); 29 | } else if (fields.size == 3) { 30 | pairCounts.incrementCount(fields(0) -> fields(1), fields(2).toDouble); 31 | } 32 | } 33 | Logger.logss("Loaded " + pairCounts.size + " query counts from " + path); 34 | new QueryCountsBundle(wordCounts, pairCounts); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/SemClass.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | 3 | import edu.berkeley.nlp.entity.WordNetInterfacer 4 | import edu.berkeley.nlp.entity.coref.Mention 5 | import edu.mit.jwi.item.ISynsetID 6 | import edu.berkeley.nlp.entity.coref.NumberGenderComputer 7 | import edu.berkeley.nlp.entity.coref.CorefSystem 8 | import edu.berkeley.nlp.entity.coref.DocumentGraph 9 | import java.util.regex.Pattern 10 | import edu.berkeley.nlp.futile.util.Counter 11 | import java.io.File 12 | import scala.collection.mutable.HashMap 13 | import edu.mit.jwi.item.ISynset 14 | 15 | object SemClass extends Enumeration { 16 | type SemClass = Value 17 | val Person, Location, Organization, Date, Event, Other = Value 18 | 19 | // These do slightly less well than well-trained NER tags 20 | // val DateWords = Set("today", "yesterday", "tomorrow", "day", "days", "week", "weeks", "month", "months", "year", "years", 21 | // "january", "february", "march", "april", "may", "june", "july", "august", "september", "october", "november", "december"); 22 | // val DatePattern = Pattern.compile("[0-9]{4}"); 23 | 24 | def getSemClass(headStringLc: String, wni: WordNetInterfacer): SemClass = getSemClass(headStringLc, "", wni); 25 | 26 | def getSemClassNoNer(headStringLc: String, wni: WordNetInterfacer): SemClass = getSemClass(headStringLc, "", wni); 27 | 28 | def getSemClassNoNer(synsets: Seq[ISynset], wni: WordNetInterfacer): SemClass = { 29 | if (wni.isAnySynsetHypernym(synsets, wni.personSynset, 10)) { 30 | SemClass.Person; 31 | } else if (wni.isAnySynsetHypernym(synsets, wni.locationSynset, 10)) { 32 | SemClass.Location; 33 | } else if (wni.isAnySynsetHypernym(synsets, wni.organizationSynset, 10)) { 34 | SemClass.Organization; 35 | } else { 36 | SemClass.Other; 37 | } 38 | } 39 | 40 | def getSemClassOnlyNer(nerString: String): SemClass = { 41 | if (nerString == "PERSON") { 42 | SemClass.Person 43 | } else if (nerString == "LOC") { 44 | SemClass.Location 45 | } else if (nerString == "ORG" || nerString == "GPE") { 46 | SemClass.Organization 47 | } else if (nerString == "DATE" || nerString == "TIME") { 48 | SemClass.Date; 49 | } else { 50 | SemClass.Other; 51 | } 52 | } 53 | 54 | def getStrSemClassOnlyNerFiner(nerString: String): String = { 55 | if (nerString == "PERSON" || nerString == "LOC" || nerString == "ORG" || nerString == "GPE" || nerString == "NORP" || nerString == "DATE" || nerString == "CARDINAL") { 56 | nerString 57 | } else { 58 | "OTHER" 59 | } 60 | } 61 | 62 | def getSemClass(headStringLc: String, nerString: String, wni: WordNetInterfacer): SemClass = { 63 | if 
(wni.isAnySynsetHypernym(headStringLc, wni.personSynset) || nerString == "PERSON") { 64 | SemClass.Person; 65 | } else if (wni.isAnySynsetHypernym(headStringLc, wni.locationSynset) || nerString == "LOC") { 66 | SemClass.Location; 67 | } else if (wni.isAnySynsetHypernym(headStringLc, wni.organizationSynset) || nerString == "ORG" || nerString == "GPE") { 68 | SemClass.Organization; 69 | // Event seems unhelpful based on an experiment, and it makes for even more features 70 | // } else if (wni.isAnySynsetHypernym(headStringLc, wni.eventSynset) || nerString == "EVENT") { 71 | // SemClass.Event; 72 | } else if (nerString == "DATE" || nerString == "TIME") { 73 | SemClass.Date; 74 | } else { 75 | SemClass.Other; 76 | } 77 | } 78 | 79 | def isDate(headStringLc: String) = false //DateWords.contains(headStringLc) || DatePattern.matcher(headStringLc).matches(); 80 | 81 | } 82 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sem/SemClasser.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sem 2 | 3 | import edu.berkeley.nlp.entity.coref.Mention 4 | import edu.berkeley.nlp.entity.WordNetInterfacer 5 | import edu.berkeley.nlp.entity.coref.CorefDoc 6 | 7 | trait SemClasser extends Serializable { 8 | // We only bother to define these for NOMINAL and PROPER mentions; it shouldn't be 9 | // called for anything else 10 | def getSemClass(ment: Mention, wni: WordNetInterfacer): String; 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sig/BootstrapDriverNER.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sig 2 | 3 | import edu.berkeley.nlp.entity.ner.NEEvaluator 4 | import edu.berkeley.nlp.entity.coref.CorefSystem 5 | import edu.berkeley.nlp.entity.ConllDocReader 6 | import edu.berkeley.nlp.entity.lang.Language 7 | import edu.berkeley.nlp.entity.Chunk 8 | 9 | object BootstrapDriverNER { 10 | 11 | def main(args: Array[String]) { 12 | val goldPath = args(0); 13 | val worseFilePath = args(1); 14 | val betterFilePath = args(2); 15 | val goldDocs = ConllDocReader.loadRawConllDocsWithSuffix(goldPath, -1, "gold_conll", Language.ENGLISH); 16 | // val sentences = goldDocs.flatMap(_.words) 17 | // val worseChunks = NEEvaluator.readIllinoisNEROutput(worseFilePath, sentences) 18 | // val betterChunks = NEEvaluator.readIllinoisNEROutput(betterFilePath, sentences) 19 | val worseChunks = NEEvaluator.readIllinoisNEROutputSoft(worseFilePath) 20 | val betterChunks = NEEvaluator.readIllinoisNEROutputSoft(betterFilePath) 21 | val goldChunks = goldDocs.flatMap(_.nerChunks); 22 | 23 | val worseSuffStats = convertToSuffStats(goldChunks, worseChunks); 24 | val betterSuffStats = convertToSuffStats(goldChunks, betterChunks); 25 | 26 | BootstrapDriver.printSimpleBootstrapPValue(worseSuffStats, betterSuffStats, new F1Computer(0, 1, 0, 2)) 27 | } 28 | 29 | def convertToSuffStats(goldChunks: Seq[Seq[Chunk[String]]], predChunks: Seq[Seq[Chunk[String]]]): Seq[Seq[Double]] = { 30 | for (i <- 0 until goldChunks.size) yield { 31 | Seq(predChunks(i).filter(chunk => goldChunks(i).contains(chunk)).size.toDouble, predChunks(i).size.toDouble, goldChunks(i).size.toDouble); 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sig/BootstrapDriverWiki.scala: 
-------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sig 2 | 3 | object BootstrapDriverWiki { 4 | 5 | def main(args: Array[String]) = { 6 | val worsePath = args(0); 7 | val betterPath = args(1); 8 | 9 | // TODO: Compute whatever I care about... 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/sig/MetricComputer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.sig 2 | 3 | import edu.berkeley.nlp.entity.GUtil 4 | import edu.berkeley.nlp.futile.util.Logger 5 | 6 | trait MetricComputer { 7 | 8 | def isSigDifference(origWorse: Seq[Seq[Double]], origBetter: Seq[Seq[Double]], resampledIndices: Seq[Int]): Boolean = { 9 | val origBetterScore = computeMetric(origBetter); 10 | val origWorseScore = computeMetric(origWorse); 11 | val origDiff = origBetterScore - origWorseScore; 12 | if (origDiff < 0) { 13 | false; 14 | } else { 15 | val newBetterScore = computeMetric(origBetter, resampledIndices); 16 | val newWorseScore = computeMetric(origWorse, resampledIndices); 17 | val newDiff = newBetterScore - newWorseScore; 18 | var sig = newDiff < 2 * origDiff; 19 | // println(GUtil.fmtTwoDigitNumber(origWorseScore, 2) + " " + GUtil.fmtTwoDigitNumber(origBetterScore, 2) + " " + 20 | // GUtil.fmtTwoDigitNumber(newWorseScore, 2) + " " + GUtil.fmtTwoDigitNumber(newBetterScore, 2)) 21 | // newDiff > 0 22 | sig; 23 | } 24 | } 25 | 26 | def computeMetric(results: Seq[Seq[Double]]): Double = computeMetric(results, 0 until results.size); 27 | 28 | def computeMetric(results: Seq[Seq[Double]], idxList: Seq[Int]): Double; 29 | 30 | def computeMetricFull(results: Seq[Seq[Double]]): Seq[Double] = computeMetricFull(results, 0 until results.size); 31 | 32 | // Supports returning multiple metric values so you can retrieve precision, recall, 33 | // and F1 for metrics that have that structure 34 | def computeMetricFull(results: Seq[Seq[Double]], idxList: Seq[Int]): Seq[Double]; 35 | } 36 | 37 | object MetricComputer { 38 | 39 | def fmtMetricValues(suffStats: Seq[Double]): String = { 40 | val strList = suffStats.map(entry => GUtil.fmtTwoDigitNumber(entry, 2)); 41 | strList.foldLeft("")((str, entry) => str + " & " + entry) 42 | } 43 | } 44 | 45 | class F1Computer(val precNumIdx: Int, val precDenomIdx: Int, val recNumIdx: Int, val recDenomIdx: Int) extends MetricComputer { 46 | 47 | def computeMetric(results: Seq[Seq[Double]], idxList: Seq[Int]): Double = { 48 | computeMetricFull(results, idxList)(2); 49 | } 50 | 51 | def computeMetricFull(results: Seq[Seq[Double]], idxList: Seq[Int]): Seq[Double] = { 52 | require(results.size >= 1); 53 | require(results(0).size > precNumIdx && results(0).size > precDenomIdx && results(0).size > recNumIdx && results(0).size > recDenomIdx) 54 | var totalPrecNum = idxList.foldLeft(0.0)((total, idx) => total + results(idx)(precNumIdx)); 55 | var totalPrecDenom = idxList.foldLeft(0.0)((total, idx) => total + results(idx)(precDenomIdx)); 56 | var totalRecNum = idxList.foldLeft(0.0)((total, idx) => total + results(idx)(recNumIdx)); 57 | var totalRecDenom = idxList.foldLeft(0.0)((total, idx) => total + results(idx)(recDenomIdx)); 58 | val prec = totalPrecNum/totalPrecDenom; 59 | val rec = totalRecNum/totalRecDenom; 60 | Seq(prec * 100.0, rec * 100.0, 2 * prec * rec/(prec + rec) * 100.0); 61 | } 62 | 63 | } 64 | 65 | object F1Computer { 66 | def apply() = new F1Computer(0, 1, 0, 
2); // correct, total pred, total gold 67 | } 68 | 69 | class AccuracyComputer extends MetricComputer { 70 | 71 | def computeMetric(results: Seq[Seq[Double]], idxList: Seq[Int]): Double = { 72 | computeMetricFull(results, idxList)(0); 73 | } 74 | 75 | def computeMetricFull(results: Seq[Seq[Double]], idxList: Seq[Int]): Seq[Double] = { 76 | Seq(idxList.map(results(_)(0)).foldLeft(0.0)(_ + _) / idxList.map(results(_)(1)).foldLeft(0.0)(_ + _)); 77 | } 78 | 79 | } 80 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/ACETester.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki 2 | 3 | import scala.collection.mutable.HashMap 4 | import edu.berkeley.nlp.entity.ConllDocReader 5 | import edu.berkeley.nlp.entity.coref.CorefDocAssembler 6 | import edu.berkeley.nlp.entity.GUtil 7 | import edu.berkeley.nlp.entity.coref.MentionPropertyComputer 8 | import edu.berkeley.nlp.entity.lang.Language 9 | import edu.berkeley.nlp.futile.LightRunner 10 | import edu.berkeley.nlp.futile.util.Logger 11 | import edu.berkeley.nlp.futile.fig.basic.Indexer 12 | import scala.collection.mutable.ArrayBuffer 13 | import edu.berkeley.nlp.entity.Chunk 14 | 15 | object ACETester { 16 | 17 | // Command line options 18 | val dataPath = "data/ace05/ace05-all-conll" 19 | val wikiDBPath = "models/wiki-db-ace.ser.gz" 20 | val wikiPath = "data/ace05/ace05-all-conll-wiki" 21 | val useFancyQueryChooser = false; 22 | 23 | def main(args: Array[String]) { 24 | LightRunner.initializeOutput(ACETester.getClass()); 25 | LightRunner.populateScala(ACETester.getClass(), args) 26 | val docs = ConllDocReader.loadRawConllDocsWithSuffix(dataPath, -1, "", Language.ENGLISH); 27 | // val goldWikification = GUtil.load(wikiAnnotsPath).asInstanceOf[CorpusWikiAnnots]; 28 | 29 | val goldWikification = WikiAnnotReaderWriter.readStandoffAnnotsAsCorpusAnnots(wikiPath) 30 | 31 | // Detect mentions, which depend on the NER coarse pass 32 | val assembler = CorefDocAssembler(Language.ENGLISH, true); 33 | val corefDocs = docs.map(doc => assembler.createCorefDoc(doc, new MentionPropertyComputer(None))); 34 | 35 | // This does super, super well but is probably cheating 36 | // val wikiDB = GUtil.load(wikiDBPath).asInstanceOf[WikipediaInterface]; 37 | // val trainDataPath = "data/ace05/train"; 38 | // val trainDocs = ConllDocReader.loadRawConllDocsWithSuffix(trainDataPath, -1, "", Language.ENGLISH); 39 | // val trainCorefDocs = trainDocs.map(doc => assembler.createCorefDoc(doc, new MentionPropertyComputer(None))); 40 | // val wikifier = new BasicWikifier(wikiDB, Some(trainCorefDocs), Some(goldWikification)); 41 | 42 | val queryChooser = if (useFancyQueryChooser) { 43 | GUtil.load("models/querychooser.ser.gz").asInstanceOf[QueryChooser] 44 | } else { 45 | val fi = new Indexer[String]; 46 | fi.getIndex("FirstNonempty=true"); 47 | fi.getIndex("FirstNonempty=false"); 48 | new QueryChooser(fi, Array(1F, -1F)) 49 | } 50 | 51 | 52 | val wikiDB = GUtil.load(wikiDBPath).asInstanceOf[WikipediaInterface]; 53 | val wikifier = new BasicWikifier(wikiDB, Some(queryChooser)); 54 | 55 | // val wikiDB = GUtil.load(wikiDBPath).asInstanceOf[WikipediaInterface]; 56 | // val aceHeads = ACEMunger.mungeACEToGetHeads("data/ace05/ace05-all-copy"); 57 | // val wikifier = new BasicWikifier(wikiDB, None, None, Some(aceHeads)); 58 | 59 | // val wikifier: Wikifier = 
FahrniWikifier.readFahrniWikifier("data/wikipedia/lex.anchor.lowAmbiguity-resolved", 60 | // "data/wikipedia/simTerms"); 61 | 62 | var recalled = 0; 63 | for (corefDoc <- corefDocs) { 64 | val docName = corefDoc.rawDoc.docID 65 | for (i <- 0 until corefDoc.predMentions.size) { 66 | val ment = corefDoc.predMentions(i); 67 | val goldLabel = getGoldWikification(goldWikification(docName), ment) 68 | if (goldLabel.size >= 1 && goldLabel(0) != NilToken) { 69 | wikifier.oracleWikify(docName, ment, goldLabel); 70 | val myTitles = wikifier.wikifyGetTitleSet(docName, ment); 71 | if (containsCorrect(goldLabel, myTitles)) { 72 | recalled += 1; 73 | } 74 | wikifier 75 | } else if (goldLabel.size == 1 && goldLabel(0) == NilToken) { 76 | wikifier.oracleWikifyNil(docName, ment); 77 | } 78 | } 79 | } 80 | Logger.logss("Recalled: " + recalled); 81 | wikifier.printDiagnostics(); 82 | LightRunner.finalizeOutput(); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/BlikiInterface.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki; 2 | 3 | import info.bliki.wiki.filter.PlainTextConverter; 4 | import info.bliki.wiki.model.WikiModel; 5 | import edu.berkeley.nlp.futile.util.Logger; 6 | 7 | 8 | public class BlikiInterface { 9 | 10 | public static String renderPlaintext(String text) { 11 | // Remove references or they will endure after processing 12 | String normalizedText = removeReferences(normalizeGtLtQuot(text)); 13 | WikiModel wikiModel = new WikiModel("http://www.mywiki.com/wiki/${image}", "http://www.mywiki.com/wiki/${title}"); 14 | String plainText = wikiModel.render(new PlainTextConverter(), normalizedText); 15 | return removeCurlyBrackets(plainText); 16 | } 17 | 18 | public static String normalizeGtLtQuot(String line) { 19 | return line.replaceAll("<", "<").replaceAll(">", ">").replaceAll(""", "\""); 20 | } 21 | 22 | public static String removeComments(String line) { 23 | return removeDelimitedContent(line, ""); 24 | } 25 | 26 | public static String removeReferences(String line) { 27 | return removeDelimitedContent(line, ""); 28 | } 29 | 30 | public static String removeSquareBrackets(String line) { 31 | return removeDelimitedContent(line, "[[", "]]"); 32 | } 33 | 34 | public static String removeCurlyBrackets(String line) { 35 | return removeDelimitedContent(line, "{{", "}}"); 36 | } 37 | 38 | public static String removeParentheticals(String line) { 39 | return removeDelimitedContent(line, "(", ")"); 40 | } 41 | 42 | private static String removeDelimitedContent(String line, String startDelim, String endDelim) { 43 | String newLine = line; 44 | int endIdx = line.indexOf(endDelim) + endDelim.length(); 45 | int startIdx = line.lastIndexOf(startDelim, endIdx - endDelim.length()); 46 | // System.out.println(startIdx + " " + endIdx); 47 | while (startIdx >= 0 && endIdx >= 0 && endIdx > startIdx) { 48 | newLine = newLine.substring(0, startIdx) + newLine.substring(endIdx); 49 | endIdx = newLine.indexOf(endDelim) + endDelim.length(); 50 | startIdx = newLine.lastIndexOf(startDelim, endIdx - endDelim.length()); 51 | } 52 | return newLine; 53 | } 54 | 55 | 56 | // public static final String TEST = "This is a [[Hello World]] '''example'''"; 57 | public static final String TEST = "'''Autism''' is a [[Neurodevelopmental disorder|disorder of neural development]] characterized by impaired [[Interpersonal relationship|social interaction]] and [[verbal 
communication|verbal]] and [[non-verbal communication]], and by restricted, repetitive or [[stereotypy|stereotyped]] behavior."; 58 | 59 | public static void main(String[] args) { 60 | WikiModel wikiModel = new WikiModel("http://www.mywiki.com/wiki/${image}", "http://www.mywiki.com/wiki/${title}"); 61 | String plainStr = wikiModel.render(new PlainTextConverter(), TEST); 62 | System.out.print(plainStr); 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/Query.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki 2 | 3 | import scala.collection.JavaConverters._ 4 | import scala.collection.mutable.ArrayBuffer 5 | import edu.berkeley.nlp.entity.coref.Mention 6 | import edu.berkeley.nlp.futile.util.Logger 7 | import edu.berkeley.nlp.futile.util.Counter 8 | 9 | /** 10 | * Simple data structure to store information about a query to the Wikipedia 11 | * title given surface database formed from a particular mention. 12 | * 13 | * @author gdurrett 14 | */ 15 | case class Query(val words: Seq[String], 16 | val originalMent: Mention, 17 | val finalSpan: (Int, Int), 18 | val queryType: String, 19 | val removePuncFromQuery: Boolean = true) { 20 | 21 | def getFinalQueryStr = { 22 | val wordsNoPunc = if (removePuncFromQuery) { 23 | words.map(str => str.filter(c => !Query.PuncList.contains(c))).filter(!_.isEmpty); 24 | } else { 25 | words; 26 | } 27 | if (wordsNoPunc.isEmpty) "" else wordsNoPunc.reduce(_ + " " + _); 28 | } 29 | } 30 | 31 | object Query { 32 | 33 | def makeNilQuery(ment: Mention) = { 34 | new Query(Seq[String]("XXNILXX"), ment, (ment.headIdx + 1, ment.headIdx + 1), "NIL"); 35 | } 36 | 37 | // These parameter settings have been tuned to give best performance on query extraction 38 | // for ACE, so are probably good there but might need to be revisited in other settings. 39 | val CapitalizationQueryExpand = false; 40 | val PluralQueryExpand = true; 41 | val RemovePuncFromQuery = true; 42 | val UseFirstHead = true; 43 | val MaxQueryLen = 4; 44 | val BlackList = Set("the", "a", "my", "your", "his", "her", "our", "their", "its", "this", "that", "these", "those") 45 | val PuncList = Set(',', '.', '!', '?', ':', ';', '\'', '"', '(', ')', '[', ']', '{', '}', ' '); 46 | 47 | /** 48 | * Check if a token is "blacklisted", meaning that we shouldn't form a query that starts with 49 | * it (such queries tend to do weird and bad things 50 | */ 51 | def isBlacklisted(word: String, mentStartIdx: Int) = { 52 | BlackList.contains(word) || (mentStartIdx == 0 && BlackList.contains(word.toLowerCase)); 53 | } 54 | 55 | /** 56 | * Very crappy stemmer 57 | */ 58 | def removePlural(word: String) = { 59 | if (word.endsWith("sses")) { 60 | word.dropRight(2); 61 | } else if (word.endsWith("ies")) { 62 | // Not quite right... 63 | word.substring(0, word.size - 3) + "y"; 64 | } else if (word.endsWith("s")) { 65 | word.dropRight(1); 66 | } else { 67 | word; 68 | } 69 | } 70 | 71 | /** 72 | * Given a mention, extracts the set of possible queries that we'll consider. This is done by 73 | * considering different subsets of the words in the mention and munging capitalization and 74 | * stemming, since lowercasing and dropping a plural-marking "s" are useful for nominals. 
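 * As a concrete illustration (a hypothetical mention, not drawn from any dataset, assuming the
 * head of the span is "dogs"): for the nominal mention "the big dogs", the code below would
 * propose "big dogs" and "Big dogs" (WIKICASED) from the prefix spans, plus the head-only
 * queries "dogs", "Dogs", "dog", and "Dog" (stemmed); the full span "the big dogs" is skipped
 * because it starts with the blacklisted word "the".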
75 | */ 76 | def extractQueriesBest(ment: Mention, addNilQuery: Boolean = false): Seq[Query] = { 77 | val queries = new ArrayBuffer[Query]; 78 | val mentWords = ment.words; 79 | // Try the whole query, then prefixes ending in the head 80 | val relHeadIdx = (if (UseFirstHead) ment.contextTree.getSpanHeadACECustom(ment.startIdx, ment.endIdx) else ment.headIdx) - ment.startIdx; 81 | val indicesToTry = (Seq((0, mentWords.size)) ++ (0 to relHeadIdx).map(i => (i, relHeadIdx + 1))).filter(indices => { 82 | indices._2 - indices._1 == 1 || !isBlacklisted(mentWords(indices._1), ment.startIdx); 83 | }).filter(indices => indices._2 - indices._1 > 0 && indices._2 - indices._1 <= MaxQueryLen).distinct; 84 | for (indices <- indicesToTry) { 85 | // Query the full thing as is 86 | val queriesThisSlice = new ArrayBuffer[Query]; 87 | val query = new Query(mentWords.slice(indices._1, indices._2), ment, indices, "STD", RemovePuncFromQuery); 88 | val firstWord = mentWords(indices._1); 89 | val lastWord = mentWords(indices._2 - 1); 90 | queriesThisSlice += query; 91 | // Handle capitalization: if the first word does not have any uppercase characters 92 | if (!firstWord.map(Character.isUpperCase(_)).reduce(_ || _) && Character.isLowerCase(firstWord(0))) { 93 | queriesThisSlice += new Query(Seq(wikiCase(firstWord)) ++ mentWords.slice(indices._1 + 1, indices._2), ment, indices, "WIKICASED", RemovePuncFromQuery); 94 | } 95 | // Stemming (but only on head alone) 96 | if (PluralQueryExpand && (indices._2 - indices._1) == 1 && firstWord.last == 's') { 97 | queriesThisSlice ++= queriesThisSlice.map(query => new Query(Seq(removePlural(query.words(0))), ment, indices, query.queryType + "-STEM", RemovePuncFromQuery)); 98 | } 99 | queries ++= queriesThisSlice; 100 | } 101 | // Finally, strip punctuation from queries; we don't do this earlier because it makes it hard 102 | // to find the head 103 | // val finalQueries = if (RemovePuncFromQuery) { 104 | // queries.map(_.map(str => str.filter(c => !PuncList.contains(c))).filter(!_.isEmpty)).filter(!_.isEmpty) 105 | // } else { 106 | // queries; 107 | // } 108 | queries.filter(!_.getFinalQueryStr.isEmpty) ++ (if (addNilQuery) Seq(Query.makeNilQuery(ment)) else Seq[Query]()); 109 | } 110 | 111 | def extractDenotationSetWithNil(queries: Seq[Query], queryDisambigs: Seq[Counter[String]], maxDenotations: Int): Seq[String] = { 112 | val choicesEachQuery = queryDisambigs.map(_.getSortedKeys().asScala); 113 | val optionsAndPriorities = (0 until queryDisambigs.size).flatMap(i => { 114 | val sortedKeys = queryDisambigs(i).getSortedKeys().asScala 115 | (0 until sortedKeys.size).map(j => (sortedKeys(j), j * 1000 + i)); 116 | }); 117 | // choicesEachQuery.foreach(Logger.logss(_)); 118 | // Logger.logss(optionsAndPriorities); 119 | val allFinalOptions = Seq(NilToken) ++ optionsAndPriorities.sortBy(_._2).map(_._1).distinct; 120 | val finalOptionsTruncated = allFinalOptions.slice(0, Math.min(allFinalOptions.size, maxDenotations)); 121 | // Logger.logss(finalOptions); 122 | finalOptionsTruncated; 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/WikiAnnotReaderWriter.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki 2 | 3 | import scala.collection.mutable.HashMap 4 | import edu.berkeley.nlp.entity.Chunk 5 | import java.io.PrintWriter 6 | import edu.berkeley.nlp.entity.ConllDocWriter 7 | import 
scala.collection.mutable.ArrayBuffer 8 | import edu.berkeley.nlp.entity.ConllDocReader 9 | import edu.berkeley.nlp.entity.coref.UID 10 | import java.io.File 11 | import edu.berkeley.nlp.futile.util.Logger 12 | 13 | object WikiAnnotReaderWriter { 14 | 15 | // This can never be in a Wikipedia article title 16 | val WikiTitleSeparator = "|"; 17 | 18 | def readAllStandoffAnnots(dirName: String): HashMap[UID,Seq[Seq[Chunk[Seq[String]]]]] = { 19 | val standoffAnnots = new HashMap[UID,Seq[Seq[Chunk[Seq[String]]]]] 20 | for (file <- new File(dirName).listFiles) { 21 | standoffAnnots ++= readStandoffAnnots(file.getAbsolutePath) 22 | } 23 | standoffAnnots 24 | } 25 | 26 | def readStandoffAnnots(fileName: String): HashMap[UID,Seq[Seq[Chunk[Seq[String]]]]] = { 27 | val fcn = (docID: String, docPartNo: Int, docBySentencesByLines: ArrayBuffer[ArrayBuffer[String]]) => { 28 | // type = (UID, Seq[Seq[Chunk[Seq[String]]]]) 29 | new UID(docID, docPartNo) -> docBySentencesByLines.map(assembleWikiChunks(_)); 30 | }; 31 | val uidsWithStandoffAnnots: Seq[(UID, Seq[Seq[Chunk[Seq[String]]]])] = ConllDocReader.readConllDocsGeneral(fileName, fcn); 32 | new HashMap[UID,Seq[Seq[Chunk[Seq[String]]]]] ++ uidsWithStandoffAnnots; 33 | } 34 | 35 | def readStandoffAnnotsAsCorpusAnnots(wikiPath: String) = { 36 | val goldWikification = new CorpusWikiAnnots; 37 | for (entry <- WikiAnnotReaderWriter.readAllStandoffAnnots(wikiPath)) { 38 | val fileName = entry._1._1; 39 | val docAnnots = new DocWikiAnnots; 40 | for (i <- 0 until entry._2.size) { 41 | docAnnots += i -> (new ArrayBuffer[Chunk[Seq[String]]]() ++ entry._2(i)) 42 | } 43 | goldWikification += fileName -> docAnnots 44 | } 45 | goldWikification 46 | } 47 | 48 | def writeStandoffAnnots(writer: PrintWriter, docName: String, docPartNo: Int, annots: HashMap[Int,ArrayBuffer[Chunk[Seq[String]]]], sentLens: Seq[Int]) { 49 | writeStandoffAnnots(writer, docName, docPartNo, (0 until sentLens.size).map(annots(_)), sentLens); 50 | } 51 | 52 | def writeStandoffAnnots(writer: PrintWriter, docName: String, docPartNo: Int, annots: Seq[Seq[Chunk[Seq[String]]]], sentLens: Seq[Int]) { 53 | val numZeroesToAddToPartNo = 3 - docPartNo.toString.size; 54 | writer.println("#begin document (" + docName + "); part " + ("0" * numZeroesToAddToPartNo) + docPartNo); 55 | for (sentBits <- getWikiBits(sentLens, annots)) { 56 | for (bit <- sentBits) { 57 | writer.println(bit); 58 | } 59 | writer.println(); 60 | } 61 | writer.println("#end document"); 62 | } 63 | 64 | def wikiTitleSeqToString(titles: Seq[String]): String = { 65 | if (titles.isEmpty) { 66 | ExcludeToken 67 | } else { 68 | titles.map(_.replace("(", "-LRB-").replace(")", "-RRB-").replace("*", "-STAR-")).reduce(_ + "|" + _); 69 | } 70 | } 71 | 72 | def stringToWikiTitleSeq(str: String): Seq[String] = { 73 | if (str == ExcludeToken) { 74 | Seq[String](); 75 | } else { 76 | str.split("\\|").map(_.replace("-LRB-", "(").replace("-RRB-", ")").replace("-STAR-", "*")).toSeq; 77 | } 78 | } 79 | 80 | def getWikiBits(sentLens: Seq[Int], wikiChunks: Seq[Seq[Chunk[Seq[String]]]]): Seq[Seq[String]] = { 81 | for (sentIdx <- 0 until sentLens.size) yield { 82 | for (tokenIdx <- 0 until sentLens(sentIdx)) yield { 83 | val chunksStartingHere = wikiChunks(sentIdx).filter(chunk => chunk.start == tokenIdx).sortBy(- _.end); 84 | val numChunksEndingHere = wikiChunks(sentIdx).filter(chunk => chunk.end - 1 == tokenIdx).size; 85 | var str = ""; 86 | for (chunk <- chunksStartingHere) { 87 | str += "(" + wikiTitleSeqToString(chunk.label); 88 | } 89 | str += "*"; 90 | 
for (i <- 0 until numChunksEndingHere) { 91 | str += ")"; 92 | } 93 | str; 94 | } 95 | } 96 | } 97 | 98 | def assembleWikiChunks(bits: Seq[String]) = { 99 | // Bits have to look like some number of (LABEL followed by * followed by some number of ) 100 | val chunks = new ArrayBuffer[Chunk[Seq[String]]](); 101 | val currStartIdxStack = new ArrayBuffer[Int]; 102 | val currTypeStack = new ArrayBuffer[Seq[String]]; 103 | for (i <- 0 until bits.size) { 104 | val containsStar = bits(i).contains("*"); 105 | var bitRemainder = bits(i); 106 | while (bitRemainder.startsWith("(")) { 107 | var endIdx = bitRemainder.indexOf("(", 1); 108 | if (endIdx < 0) { 109 | endIdx = (if (containsStar) bitRemainder.indexOf("*") else bitRemainder.indexOf(")")); 110 | } 111 | require(endIdx >= 0, bitRemainder + " " + bits); 112 | currStartIdxStack += i 113 | currTypeStack += stringToWikiTitleSeq(bitRemainder.substring(1, endIdx)); 114 | bitRemainder = bitRemainder.substring(endIdx); 115 | } 116 | if (containsStar) { 117 | require(bitRemainder.startsWith("*"), bitRemainder + " " + bits); 118 | bitRemainder = bitRemainder.substring(1); 119 | } 120 | while (bitRemainder.startsWith(")")) { 121 | require(!currStartIdxStack.isEmpty, "Bad bits: " + bits); 122 | chunks += new Chunk[Seq[String]](currStartIdxStack.last, i+1, currTypeStack.last); 123 | currStartIdxStack.remove(currStartIdxStack.size - 1); 124 | currTypeStack.remove(currTypeStack.size - 1); 125 | bitRemainder = bitRemainder.substring(1); 126 | } 127 | require(bitRemainder.size == 0); 128 | } 129 | chunks; 130 | } 131 | } -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/WikificationFeaturizer.scala: -------------------------------------------------------------------------------- 1 | //package edu.berkeley.nlp.entity.wiki 2 | // 3 | //@SerialVersionUID(1L) 4 | //class WikificationFeaturizer(val wikiDB: WikipediaInterface, 5 | // val wikiCategoryDB: Option[WikipediaCategoryDB], 6 | // val wikiLinkDB: Option[WikipediaLinkDB]) extends Serializable { 7 | // 8 | //} 9 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/Wikifier.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki 2 | 3 | import edu.berkeley.nlp.entity.coref.Mention 4 | import edu.berkeley.nlp.futile.util.Counter 5 | 6 | trait Wikifier { 7 | 8 | def wikify(docName: String, ment: Mention): String; 9 | 10 | def wikifyGetTitleSet(docName: String, ment: Mention): Seq[String]; 11 | 12 | def wikifyGetPriorForJointModel(docName: String, ment: Mention): Counter[String]; 13 | 14 | def oracleWikifyNil(docName: String, ment: Mention); 15 | 16 | def oracleWikify(docName: String, ment: Mention, goldTitles: Seq[String]); 17 | 18 | def printDiagnostics(); 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/WikipediaAuxDB.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki 2 | 3 | import scala.collection.mutable.HashMap 4 | import scala.collection.mutable.ArrayBuffer 5 | import scala.collection.JavaConverters._ 6 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import edu.berkeley.nlp.futile.util.Counter 9 | import edu.berkeley.nlp.PCFGLA.CoarseToFineMaxRuleParser 10 | import 
edu.berkeley.nlp.entity.preprocess.Reprocessor 11 | import edu.berkeley.nlp.entity.preprocess.PreprocessingDriver 12 | import edu.berkeley.nlp.entity.preprocess.SentenceSplitter 13 | import edu.berkeley.nlp.entity.DepConstTree 14 | import scala.collection.mutable.HashSet 15 | 16 | @SerialVersionUID(1L) 17 | class WikipediaAuxDB(val disambiguationSet: HashSet[String]) extends Serializable { 18 | def isDisambiguation(pageTitle: String) = disambiguationSet.contains(pageTitle); 19 | 20 | def purgeDisambiguationAll(counter: Counter[String]) = { 21 | for (key <- counter.keySet.asScala.toSeq) { 22 | if (isDisambiguation(key)) { 23 | // Logger.logss("Purging " + key); 24 | counter.removeKey(key); 25 | } 26 | } 27 | counter; 28 | } 29 | 30 | // Horrifyingly hard-coded but I didn't want to introduce a new dependency.... 31 | def writeToJsonFile(path: String) { 32 | val writer = IOUtils.openOutHard(path) 33 | writer.println("{") 34 | val disSet = disambiguationSet.map(entry => "\"" + entry + "\"") 35 | writer.println(" \"disambiguationSet\": [" + (if (disambiguationSet.isEmpty) "" else disSet.reduce(_ + ", " + _)) + "]") 36 | writer.println("}") 37 | writer.close() 38 | } 39 | } 40 | 41 | object WikipediaAuxDB { 42 | 43 | def processWikipedia(wikipediaPath: String, 44 | pageTitleSetLc: Set[String]): WikipediaAuxDB = { 45 | val lines = IOUtils.lineIterator(IOUtils.openInHard(wikipediaPath)); 46 | var currentPageTitle = ""; 47 | var doneWithThisPage = false; 48 | var isInText = false; 49 | val disambiguationSet = new HashSet[String] 50 | // Extract first line that's not in brackets 51 | while (lines.hasNext) { 52 | val line = lines.next; 53 | if (line.size > 8 && doneWithThisPage) { 54 | // Do nothing 55 | } else { 56 | if (line.contains("<page>")) { 57 | doneWithThisPage = false; 58 | } else if (line.contains("<title>")) { 59 | currentPageTitle = line.substring(line.indexOf("<title>") + 7, line.indexOf("</title>")); 60 | if (!pageTitleSetLc.contains(currentPageTitle.toLowerCase)) { 61 | doneWithThisPage = true; 62 | } 63 | } 64 | if (!doneWithThisPage && (line.startsWith("{{disambiguation}}") || line.startsWith("{{disambiguation|") || 65 | line.startsWith("{{disambig}}") || line.startsWith("{{hndis"))) { 66 | disambiguationSet += currentPageTitle; 67 | doneWithThisPage = true; 68 | } 69 | } 70 | } 71 | new WikipediaAuxDB(disambiguationSet); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/wiki/WikipediaRedirectsDB.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.wiki 2 | 3 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 4 | import scala.collection.mutable.HashMap 5 | import edu.berkeley.nlp.futile.util.Counter 6 | import scala.collection.JavaConverters._ 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import edu.berkeley.nlp.entity.wiki._ 9 | 10 | @SerialVersionUID(1L) 11 | class WikipediaRedirectsDB(val redirects: HashMap[String,String]) extends Serializable { 12 | val redirectsWikicase = new HashMap[String,String]; 13 | redirects.foreach(redirect => redirectsWikicase += wikiCase(redirect._1) -> redirect._2); 14 | val possibleRedirectTargets = redirects.map(_._2).toSet; 15 | val possibleRedirectTargetsLc = redirects.map(_._2.toLowerCase).toSet; 16 | 17 | def followRedirect(title: String) = { 18 | // val print = title == "student_association"; 19 | val print = false 20 | // Try to redirect 21 | val result = if (redirects.contains(title)) { 22 | if 
(print) Logger.logss("1 " + redirects(title)); 23 | redirects(title); 24 | } else if (redirectsWikicase.contains(wikiCase(title))){ 25 | if (print) Logger.logss("3 " + redirectsWikicase(wikiCase(title))); 26 | redirectsWikicase(wikiCase(title)); 27 | } else if (WikipediaRedirectsDB.CapitalizeInitial) { 28 | if (print) Logger.logss("4 " + wikiCase(title)); 29 | wikiCase(title) 30 | } else { 31 | if (print) Logger.logss("5 " + title); 32 | title; 33 | } 34 | WikipediaRedirectsDB.removeWeirdMarkup(result); 35 | } 36 | 37 | def followRedirectsCounter(titleCounts: Counter[String]) = { 38 | val newTitleCounts = new Counter[String]; 39 | for (title <- titleCounts.keySet.asScala) { 40 | newTitleCounts.incrementCount(followRedirect(title), titleCounts.getCount(title)); 41 | } 42 | newTitleCounts; 43 | } 44 | 45 | // Horrifyingly hard-coded but I didn't want to introduce a new dependency.... 46 | def writeToJsonFile(path: String) { 47 | val writer = IOUtils.openOutHard(path) 48 | writer.println("{") 49 | WikipediaInterface.writeMapToJson(writer, redirects, "redirects", "title", "redirectsTo") 50 | writer.println("}") 51 | writer.close() 52 | } 53 | } 54 | 55 | object WikipediaRedirectsDB { 56 | 57 | val CapitalizeInitial = true; 58 | 59 | def removeWeirdMarkup(str: String) = { 60 | str.replace("'", "'"); 61 | } 62 | 63 | def processWikipediaGetRedirects(wikipediaPath: String, redirectCandidates: Set[String]) = { 64 | val redirects = new HashMap[String,String]; 65 | val lines = IOUtils.lineIterator(IOUtils.openInHard(wikipediaPath)); 66 | var lineIdx = 0; 67 | 68 | var currentPageTitle = ""; 69 | var doneWithThisPage = true; 70 | while (lines.hasNext) { 71 | val line = lines.next; 72 | if (lineIdx % 100000 == 0) { 73 | println("Line: " + lineIdx + ", processed"); 74 | } 75 | lineIdx += 1; 76 | if (line.contains("")) { 77 | doneWithThisPage = false; 78 | } else if (!doneWithThisPage && line.contains("")) { 79 | // 7 = "<title>".length() 80 | currentPageTitle = line.substring(line.indexOf("<title>") + 7, line.indexOf("")).replace(" ", "_"); 81 | if (!redirectCandidates.contains(currentPageTitle)) { 82 | doneWithThisPage = true; 83 | } 84 | } else if (!doneWithThisPage && line.contains("" so we just need to catch the next one and skip 114 | // longer lines 115 | if (line.size > 8 && doneWithThisPage) { 116 | // Do nothing 117 | } else { 118 | if (line.contains("")) { 119 | doneWithThisPage = false; 120 | numPagesSeen += 1; 121 | } else if (line.contains("")) { 122 | // 7 = "<title>".length() 123 | currentPageTitle = maybeLc(line.substring(line.indexOf("<title>") + 7, line.indexOf("")), lowercase); 124 | } else if (line.contains(" chunk.start == ment.startIdx && chunk.end == ment.endIdx); 32 | if (matchingChunks.isEmpty) Seq[String]() else matchingChunks(0).label; 33 | } 34 | } 35 | 36 | def isCorrect(gold: Seq[String], guess: String): Boolean = { 37 | // (gold.isEmpty && guess == NilToken) || 38 | (gold.map(_.toLowerCase).contains(guess.toLowerCase.replace(" ", "_"))); // handles the -NIL- case too 39 | } 40 | 41 | def containsCorrect(gold: Seq[String], guesses: Iterable[String]): Boolean = { 42 | guesses.foldLeft(false)(_ || isCorrect(gold, _)); 43 | } 44 | 45 | def accessACEHead(aceHeads: HashMap[String,HashMap[Int,Seq[Chunk[Int]]]], docName: String, sentIdx: Int, startIdx: Int, endIdx: Int) = { 46 | if (aceHeads.contains(docName) && aceHeads(docName).contains(sentIdx)) { 47 | val possibleMatchingChunk = aceHeads(docName)(sentIdx).filter(chunk => chunk.start == startIdx && chunk.end == endIdx); 48 | 
if (possibleMatchingChunk.size >= 1) { 49 | possibleMatchingChunk.head.label; 50 | } else { 51 | -1; 52 | } 53 | } else { 54 | -1; 55 | } 56 | } 57 | 58 | def extractAnnotation[T](annots: CorpusAnnots[T], docName: String, sentIdx: Int, startIdx: Int, endIdx: Int): Option[T] = { 59 | if (annots.contains(docName) && annots(docName).contains(sentIdx)) { 60 | extractChunkLabel(annots(docName)(sentIdx), startIdx, endIdx); 61 | } else { 62 | None; 63 | } 64 | } 65 | 66 | def extractChunkLabel[T](chunks: Seq[Chunk[T]], startIdx: Int, endIdx: Int): Option[T] = { 67 | val maybeResult = chunks.filter(chunk => chunk.start == startIdx && chunk.end == endIdx); 68 | if (maybeResult.size != 1) None else Some(maybeResult.head.label); 69 | } 70 | 71 | def wikiCase(str: String) = { 72 | require(str != null); 73 | if (str.size == 0) "" else Character.toUpperCase(str.charAt(0)) + str.substring(1) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/xdistrib/ComponentFeaturizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.xdistrib 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | 5 | import edu.berkeley.nlp.entity.coref.DocumentGraph 6 | import edu.berkeley.nlp.entity.coref.FeatureSetSpecification 7 | import edu.berkeley.nlp.entity.coref.LexicalCountsBundle 8 | import edu.berkeley.nlp.entity.wiki.WikipediaInterface 9 | import edu.berkeley.nlp.entity.sem.QueryCountsBundle 10 | import edu.berkeley.nlp.entity.sem.SemClasser 11 | import edu.berkeley.nlp.futile.fig.basic.Indexer 12 | 13 | class ComponentFeaturizer(val componentIndexer: Indexer[String], 14 | val featureSet: FeatureSetSpecification, 15 | val lexicalCounts: LexicalCountsBundle, 16 | val queryCounts: Option[QueryCountsBundle], 17 | val wikipediaInterface: Option[WikipediaInterface], 18 | val semClasser: Option[SemClasser]) { 19 | 20 | def featurizeComponents(docGraph: DocumentGraph, idx: Int, addToFeaturizer: Boolean): Array[Int] = { 21 | val feats = new ArrayBuffer[Int]; 22 | def addFeatureShortcut = (featName: String) => { 23 | // Only used in CANONICAL_ONLY_PAIR, so only compute the truth value in this case 24 | if (addToFeaturizer || componentIndexer.contains(featName)) { 25 | feats += componentIndexer.getIndex(featName); 26 | } 27 | } 28 | val ment = docGraph.getMention(idx); 29 | if (!ment.mentionType.isClosedClass) { 30 | // WORDS 31 | if (featureSet.featsToUse.contains("comp-word")) { 32 | if (lexicalCounts.commonHeadWordCounts.containsKey(ment.headStringLc)) { 33 | addFeatureShortcut("CHead=" + ment.headStringLc); 34 | } else { 35 | addFeatureShortcut("CHead=" + ment.headPos); 36 | } 37 | } 38 | // SEMCLASS 39 | if (featureSet.featsToUse.contains("comp-sc") && semClasser.isDefined) { 40 | addFeatureShortcut("CSC=" + semClasser.get.getSemClass(ment, docGraph.cachedWni)); 41 | } 42 | // WIKIPEDIA 43 | if (featureSet.featsToUse.contains("comp-wiki") && wikipediaInterface.isDefined) { 44 | val title = wikipediaInterface.get.disambiguate(ment); 45 | val topCategory = wikipediaInterface.get.getTopKCategoriesByFrequency(title, 1); 46 | if (topCategory.size >= 1) { 47 | addFeatureShortcut("CCategory=" + wikipediaInterface.get.getInfoboxHead(title)); 48 | } 49 | addFeatureShortcut("CInfo=" + wikipediaInterface.get.getInfoboxHead(title)); 50 | addFeatureShortcut("CApp=" + wikipediaInterface.get.getAppositive(title)); 51 | } 52 | } 53 | feats.toArray; 54 | } 55 | } 56 | 
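A note on the addToFeaturizer flag used in ComponentFeaturizer.featurizeComponents above: the component indexer is only allowed to grow when the flag is set (i.e. at training time); at test time only features already present in the indexer fire, so the dimensionality of the learned weight vector stays fixed. Below is a minimal, self-contained sketch of that pattern; ToyIndexer and FeatureSketch are illustrative stand-ins and not part of this repository (the real code uses edu.berkeley.nlp.futile.fig.basic.Indexer).

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap

// Illustrative stand-in for the futile Indexer: maps feature strings to dense integer ids,
// assigning a fresh id the first time a string is indexed.
class ToyIndexer {
  private val ids = new HashMap[String, Int]
  def contains(feat: String): Boolean = ids.contains(feat)
  def getIndex(feat: String): Int = ids.getOrElseUpdate(feat, ids.size)
  def size: Int = ids.size
}

object FeatureSketch {
  // Mirrors addFeatureShortcut: grow the indexer only when addToIndexer is true (training time),
  // otherwise fire only features that are already known (test time).
  def featurize(featNames: Seq[String], indexer: ToyIndexer, addToIndexer: Boolean): Array[Int] = {
    val feats = new ArrayBuffer[Int]
    for (name <- featNames) {
      if (addToIndexer || indexer.contains(name)) {
        feats += indexer.getIndex(name)
      }
    }
    feats.toArray
  }

  def main(args: Array[String]) {
    val indexer = new ToyIndexer
    val train = featurize(Seq("CHead=company", "CSC=ORGANIZATION"), indexer, true)
    val test = featurize(Seq("CHead=company", "CHead=unseenword"), indexer, false)
    // Prints "0,1 / 0": the unseen test-time feature is dropped rather than growing the indexer.
    println(train.mkString(",") + " / " + test.mkString(","))
  }
}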
-------------------------------------------------------------------------------- /src/main/java/edu/berkeley/nlp/entity/xdistrib/DocumentGraphComponents.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.entity.xdistrib 2 | 3 | import edu.berkeley.nlp.entity.coref.DocumentGraph 4 | 5 | class DocumentGraphComponents(val docGraph: DocumentGraph, 6 | val components: Array[Array[Int]]) { 7 | val cachedSummedVects: Array[Array[Float]] = Array.tabulate(docGraph.size)(i => null); 8 | } 9 | 10 | object DocumentGraphComponents { 11 | 12 | def cacheComponents(docGraph: DocumentGraph, componentFeaturizer: ComponentFeaturizer, addToFeaturizer: Boolean) = { 13 | new DocumentGraphComponents(docGraph, Array.tabulate(docGraph.size)(i => componentFeaturizer.featurizeComponents(docGraph, i, addToFeaturizer))); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /test/text/government.txt: -------------------------------------------------------------------------------- 1 | The political standoff in the nation’s capital ended just minutes before a midnight deadline when the government’s ability to borrow money would have expired. Republicans conceded defeat on Wednesday by agreeing to finance the operations of government until Jan. 15 and raise the nation’s debt limit through the middle of February. The Senate passed the legislation first, and the House followed around 10:15 p.m. 2 | 3 | The agreement paves the way for another series of budget negotiations in the weeks ahead, even as conservative Republicans in the House and Senate vowed to renew their fight for cuts in spending and changes to the Affordable Care Act. 4 | 5 | Just hours after Mr. Obama signed the temporary spending measure into law around 12:30 a.m., agencies in Washington and across the country prepared to reopen offices, public parks, research projects and community programs that have been mothballed for more than two weeks. The government’s top personnel officer announced that officials should restart normal functions “in a prompt and orderly manner.” 6 | 7 | In Washington, the city’s subway trains were once again packed with federal workers streaming in from the suburbs, government IDs dangling from lanyards around their necks. At the Lincoln Memorial, tourists waited nearby as a park ranger cut down the signs announcing that the memorial was closed. 8 | 9 | Robert Lagana said Thursday morning he was eager to get back to his job at the International Trade Commission. 10 | 11 | “It beats climbing the walls, wondering where your next paycheck is going to be and how you’re going to make your bills,” Mr. Lagana said as he made his way to his office near L’Enfant Plaza. 12 | 13 | But he also expressed frustration with lawmakers who held up the budget over the new health care law. “They really need to come up with a law where this never happens again,” he said, adding later, “You just feel like you don’t have a voice.” 14 | 15 | At the Environmental Protection Agency headquarters in Washington, Vice President Joseph R. Biden Jr. showed up to see workers who had been furloughed. 16 | 17 | “I brought some muffins!” Mr. Biden said as he arrived at the security desk. When he was asked about the shutdown, he said: “I’m happy it’s ended. It was unnecessary to begin with. I’m happy it’s ended.” 18 | 19 | He greeted returning workers with handshakes and hugs. 
20 | 21 | The Smithsonian Institution announced via Twitter that its museums would reopen to the public on Thursday. The National Zoo’s popular “Panda Cam” was once again broadcasting live streams of the zoo’s newest panda cub by late Thursday morning, and officials said that the zoo would be open to visitors on Friday. 22 | 23 | But how quickly other parts of the government will resume normal operations was not immediately clear. 24 | 25 | Some federal agencies began offering employees guidance for their return to work. A memorandum from officials at the Department of the Interior encouraged returning workers to check their e-mail and voicemail, fill out their timecards and to “check on any refrigerators and throw out any perished food.” 26 | 27 | The Interior memo hinted at how long it will take for the government to be fully functioning. It said snack bars at the main Interior building would be open on Thursday, but the cafeteria would be closed. Shuttles between Interior buildings in the capital will not be operating, the memorandum said. 28 | 29 | Across the country, federal workers returned to work, and visitors returned to national historic sites. 30 | 31 | In New York City, office workers poured in and out of the mammoth building at 26 Federal Plaza in Lower Manhattan on Thursday morning; some had been essential staff who worked through the shutdown, while others were eager to get back to the job – and start being paid again. 32 | 33 | “Put yourself in that situation,” said Regina Napoli, 60, a legal administrator for the Social Security Administration who lives on Long Island. “The bills pile up.” 34 | 35 | Her colleague, Selma Chan, 64, agreed. “We were feeling the strain financially and physically,” said Ms. Chan, whose younger daughter is a student at New York University. “We didn’t know what to do.” 36 | 37 | Ms. Chan said she had mostly stayed at home in Flushing, Queens, and had contemplated borrowing money from her elderly mother to make ends meet. But on Thursday, she was beaming as she held up a brown paper bag with a latte and a grilled cheese sandwich – an indulgence she said was not possible the day before. 38 | 39 | -------------------------------------------------------------------------------- /test/text/music.txt: -------------------------------------------------------------------------------- 1 | Multiple studies link music study to academic achievement. But what is it about serious music training that seems to correlate with outsize success in other fields? 2 | 3 | The connection isn’t a coincidence. I know because I asked. I put the question to top-flight professionals in industries from tech to finance to media, all of whom had serious (if often little-known) past lives as musicians. Almost all made a connection between their music training and their professional achievements. 4 | 5 | The phenomenon extends beyond the math-music association. Strikingly, many high achievers told me music opened up the pathways to creative thinking. And their experiences suggest that music training sharpens other qualities: Collaboration. The ability to listen. A way of thinking that weaves together disparate ideas. The power to focus on the present and the future simultaneously. 6 | 7 | Will your school music program turn your kid into a Paul Allen, the billionaire co-founder of Microsoft (guitar)? Or a Woody Allen (clarinet)? Probably not. These are singular achievers. But the way these and other visionaries I spoke to process music is intriguing. 
As is the way many of them apply music’s lessons of focus and discipline into new ways of thinking and communicating — even problem solving. 8 | 9 | Look carefully and you’ll find musicians at the top of almost any industry. Woody Allen performs weekly with a jazz band. The television broadcaster Paula Zahn (cello) and the NBC chief White House correspondent Chuck Todd (French horn) attended college on music scholarships; NBC’s Andrea Mitchell trained to become a professional violinist. Both Microsoft’s Mr. Allen and the venture capitalist Roger McNamee have rock bands. Larry Page, a co-founder of Google, played saxophone in high school. Steven Spielberg is a clarinetist and son of a pianist. The former World Bank president James D. Wolfensohn has played cello at Carnegie Hall. 10 | 11 | “It’s not a coincidence,” says Mr. Greenspan, who gave up jazz clarinet but still dabbles at the baby grand in his living room. “I can tell you as a statistician, the probability that that is mere chance is extremely small.” The cautious former Fed chief adds, “That’s all that you can judge about the facts. The crucial question is: why does that connection exist?” 12 | 13 | Paul Allen offers an answer. He says music “reinforces your confidence in the ability to create.” Mr. Allen began playing the violin at age 7 and switched to the guitar as a teenager. Even in the early days of Microsoft, he would pick up his guitar at the end of marathon days of programming. The music was the emotional analog to his day job, with each channeling a different type of creative impulse. In both, he says, “something is pushing you to look beyond what currently exists and express yourself in a new way.” 14 | 15 | Mr. Todd says there is a connection between years of practice and competition and what he calls the “drive for perfection.” The veteran advertising executive Steve Hayden credits his background as a cellist for his most famous work, the Apple “1984” commercial depicting rebellion against a dictator. “I was thinking of Stravinsky when I came up with that idea,” he says. He adds that his cello performance background helps him work collaboratively: “Ensemble playing trains you, quite literally, to play well with others, to know when to solo and when to follow.” 16 | 17 | For many of the high achievers I spoke with, music functions as a “hidden language,” as Mr. Wolfensohn calls it, one that enhances the ability to connect disparate or even contradictory ideas. When he ran the World Bank, Mr. Wolfensohn traveled to more than 100 countries, often taking in local performances (and occasionally joining in on a borrowed cello), which helped him understand “the culture of people, as distinct from their balance sheet.” 18 | 19 | --------------------------------------------------------------------------------