├── LICENSE ├── README.md ├── acl2016_nyt50_eval_doc_list.txt ├── berkeley-doc-summarizer-assembly-1.jar ├── build.sbt ├── lib ├── berkeley-entity-assembly-1.jar ├── glpk-java-1.1.0.jar └── trove.jar ├── project └── plugins.sbt ├── rouge └── ROUGE │ ├── README.txt │ ├── RELEASE-NOTE.txt │ ├── ROUGE-1.5.5.pl │ ├── data │ ├── WordNet-1.6-Exceptions │ │ ├── WordNet-1.6.exc.db │ │ ├── adj.exc │ │ ├── adv.exc │ │ ├── buildExeptionDB.pl │ │ ├── noun.exc │ │ └── verb.exc │ ├── WordNet-1.6.exc.db │ ├── WordNet-2.0-Exceptions │ │ ├── WordNet-2.0.exc.db │ │ ├── adj.exc │ │ ├── adv.exc │ │ ├── buildExeptionDB.pl │ │ ├── noun.exc │ │ └── verb.exc │ ├── WordNet-2.0.exc.db │ └── smart_common_words.txt │ └── rouge-gillick.sh ├── run-glpk-test.sh ├── run-summarizer.sh ├── src └── main │ └── scala │ ├── edu │ └── berkeley │ │ └── nlp │ │ └── summ │ │ ├── BigramCountSummarizer.scala │ │ ├── CompressiveAnaphoraSummarizer.scala │ │ ├── CompressiveAnaphoraSummarizerILP.scala │ │ ├── CorefUtils.scala │ │ ├── DiscourseSummarizer.scala │ │ ├── GLPKTest.java │ │ ├── GeneralTrainer.scala │ │ ├── Main.scala │ │ ├── Pair.java │ │ ├── RougeComputer.scala │ │ ├── RougeFileMunger.scala │ │ ├── Summarizer.scala │ │ ├── compression │ │ ├── SentenceCompressor.java │ │ ├── SyntacticCompressor.scala │ │ └── TreeProcessor.java │ │ ├── data │ │ ├── DepParse.scala │ │ ├── DepParseDoc.scala │ │ ├── DiscourseDepEx.scala │ │ ├── DiscourseTree.scala │ │ ├── DiscourseTreeReader.scala │ │ ├── EDUAligner.scala │ │ ├── StopwordDict.scala │ │ ├── SummDoc.scala │ │ └── SummaryAligner.scala │ │ ├── demo │ │ ├── SummarizerDemo.scala │ │ └── TreeJPanel.java │ │ └── preprocess │ │ ├── DiscourseDependencyParser.scala │ │ ├── EDUSegmenter.scala │ │ ├── EDUSegmenterSemiMarkov.scala │ │ └── StandoffAnnotationHandler.scala │ └── mstparser │ ├── Feature.java │ ├── FeatureVector.java │ ├── KsummDependencyDecoder.java │ ├── KsummKBestParseForest.java │ └── ParseForestItem.java └── test ├── 1818952 ├── 1832050 ├── 1848490 ├── 1855662 └── government.txt /README.md: -------------------------------------------------------------------------------- 1 | berkeley-doc-summarizer 2 | ======================= 3 | 4 | The Berkeley Document Summarizer is a learning-based single-document 5 | summarization system. It compresses source document text based on constraints 6 | from constituency parses and RST discourse parses. Moreover, it can improve 7 | summary clarity by reexpressing pronouns whose antecedents would otherwise be 8 | deleted or unclear. 9 | 10 | NOTE: If all you're interested in is the New York Times dataset, you do *not* 11 | need to do most of the setup and preprocessing below. Instead, use the pre-built 12 | .jar and run the commands in the "New York Times Dataset" section under "Training" 13 | below. 14 | 15 | 16 | 17 | ## Preamble 18 | 19 | The Berkeley Document Summarizer is described in: 20 | 21 | "Learning-Based Single-Document Summarization with Compression and Anaphoricity Constraints" 22 | Greg Durrett, Taylor Berg-Kirkpatrick, and Dan Klein. ACL 2016. 23 | 24 | See http://www.eecs.berkeley.edu/~gdurrett/ for papers and BibTeX. 25 | 26 | Questions? Bugs? Email me at gdurrett@eecs.berkeley.edu 27 | 28 | 29 | 30 | ## License 31 | 32 | Copyright (c) 2013-2016 Greg Durrett. All Rights Reserved. 
33 | 34 | This program is free software: you can redistribute it and/or modify 35 | it under the terms of the GNU General Public License as published by 36 | the Free Software Foundation, either version 3 of the License, or 37 | (at your option) any later version. 38 | 39 | This program is distributed in the hope that it will be useful, 40 | but WITHOUT ANY WARRANTY; without even the implied warranty of 41 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 42 | GNU General Public License for more details. 43 | 44 | You should have received a copy of the GNU General Public License 45 | along with this program. If not, see http://www.gnu.org/licenses/ 46 | 47 | 48 | ## Setup 49 | 50 | #### Models and Data 51 | 52 | Models are not included in GitHub due to their large size. Download the latest 53 | models from http://nlp.cs.berkeley.edu/projects/summarizer.shtml. These 54 | are necessary for both training the system (you need the EDU segmenter, discourse 55 | parser, and coreference model) as well as running it (you need the EDU segmenter, 56 | discourse parser, and summarization model, which contains the coreference model). 57 | All of these are expected in the models/ subdirectory. 58 | 59 | We also require [number and gender data](http://www.cs.utexas.edu/~gdurrett/data/gender.data.tgz) 60 | produced by Shane Bergsma and Dekang Lin in "Bootstrapping Path-Based Pronoun Resolution". 61 | Download this, untar/gzip it, and put it at `data/gender.data` (the default path where the system 62 | expects to find it). 63 | 64 | #### GLPK 65 | 66 | For solving ILPs, our system relies on GLPK, specifically [GLPK for Java](http://glpk-java.sourceforge.net/). 67 | For OS X, the easiest way to install GLPK is with [homebrew](http://brew.sh/). On Linux, 68 | you should run ```sudo apt-get install glpk-utils libglpk-dev libglpk-java```. 69 | 70 | Both the libglpk-java and Java Native Interface (JNI) libraries need to be in 71 | your Java library path (see below for how to test this); these libraries allow 72 | Java to interact with the native GLPK code. Additionally, when running the 73 | system, you must have ```glpk-java-1.1.0.jar``` on the build path; this is 74 | included in the lib directory and bundled with the distributed jar, and will 75 | continue to be included automatically if you build with sbt. 76 | 77 | You can test whether the system can call GLPK successfully with 78 | ```run-glpk-test.sh```, which tries to solve a small ILP defined in 79 | ```edu.berkeley.nlp.summ.GLPKTest```. The script attempts to augment the 80 | library path with ```/usr/local/lib/jni```, which is sometimes where the JNI 81 | library is located on OS X. If this script reports an error, you may need to 82 | augment the Java library path with the location of either the JNI or the 83 | libglpk_java libraries as follows: 84 | 85 | -Djava.library.path=":" 86 | 87 | #### Building from source 88 | 89 | The easiest way to build is with SBT: 90 | https://github.com/harrah/xsbt/wiki/Getting-Started-Setup 91 | 92 | then run 93 | 94 | sbt assembly 95 | 96 | which will compile everything and build a runnable jar. 97 | 98 | You can also import it into Eclipse and use the Scala IDE plug-in for Eclipse 99 | http://scala-ide.org 100 | 101 | 102 | 103 | ## Running the system 104 | 105 | The two most useful main classes are ```edu.berkeley.nlp.summ.Main``` and 106 | ```edu.berkeley.nlp.summ.Summarizer```. 
The former is a more involved harness 107 | for training and evaluating the system on the New York Times corpus (see below 108 | for how to acquire this corpus), and the latter simply takes a trained model 109 | and runs it. Both files contain descriptions of their functionality and command-line 110 | arguments. 111 | 112 | An example run on new data is included in ```run-summarizer.sh```. The main 113 | prerequisite for running the summarizer on new data is having that data preprocessed 114 | in the CoNLL format with constituency parses, NER, and coreference. For a system that 115 | does this, see the [Berkeley Entity Resolution System](https://github.com/gregdurrett/berkeley-entity). 116 | The ```test/``` directory already contains a few such files. 117 | 118 | The summarizer then does additional processing with EDU segmentation and discourse parsing. 119 | These use the models that are by default located in ```models/edusegmenter.ser.gz``` and 120 | ```models/discoursedep.ser.gz```. You can control these with command-line switches. 121 | 122 | The system is distributed with several pre-trained variants: 123 | 124 | * ```summarizer-extractive.ser.gz```: a sentence-extractive summarizer 125 | * ```summarizer-extractive-compressive.ser.gz```: an extractive-compressive summarizer 126 | * ```summarizer-full.ser.gz```: an extractive-compressive summarizer with the ability to rewrite pronouns 127 | and additional coreference features and constraints 128 | 129 | 130 | 131 | ## Training 132 | 133 | #### New York Times Dataset 134 | 135 | The primary corpus we use for training and evaluation is the New York Times Annotated Corpus 136 | (Sandhaus, 2007), LDC2008T19. We distribute our preprocessing as standoff annotations which 137 | replace words with (line, char start, char end) triples, except for some cases where words are 138 | included manually (e.g. when tokenization makes our data non-recoverable from the original 139 | file). A few scattered tokens are included explicitly, plus roughly 1% of files that our 140 | system couldn't find a suitable alignment for. 141 | 142 | To prepare the dataset, first you need to extract all the XML files from 2003-2007 and flatten 143 | them into a single directory. Not all files have summaries, so not all of these will 144 | be used. Next, run 145 | 146 | mkdir train_corefner 147 | java -Xmx3g -cp edu.berkeley.nlp.summ.preprocess.StandoffAnnotationHandler \ 148 | -inputDir train_corefner_standoff/ -rawXMLDir -outputDir train_corefner/ 149 | 150 | This will take the train standoff annotation files and reconstitute 151 | the real files using the XML data, writing to the output directory. Use ```eval``` instead of ```train``` 152 | to reconstitute the test set. 153 | 154 | To reconstitute abstracts, run: 155 | 156 | java -Xmx3g -cp edu.berkeley.nlp.summ.preprocess.StandoffAnnotationHandler \ 157 | -inputDir train_abstracts_standoff/ -rawXMLDir -outputDir train_abstracts/ \ 158 | -tagName "abstract" 159 | 160 | and similarly swap out for ```eval``` appropriately. 161 | 162 | #### ROUGE Scorer 163 | 164 | We bundle the system with a version of the ROUGE scorer that will be called during 165 | execution. ```rouge-gillick.sh``` hardcodes command-line arguments used in this work and 166 | in Hirao et al. (2013)'s work. The system expects this in the ```rouge/ROUGE/``` directory 167 | under the execution directory, along with the appropriate data files (which we've also 168 | bundled with this release). 
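The numbers reported under "Training the system" below are ROUGE-1 and ROUGE-2 recall. Purely as a point of reference, the sketch below shows what ROUGE-2 recall measures; it is *not* the bundled `ROUGE-1.5.5.pl` scorer (which applies additional normalization via the flags hardcoded in `rouge-gillick.sh`), and the object and method names here are made up for illustration.

```scala
// Toy illustration of ROUGE-2 recall: the fraction of reference (abstract) bigrams
// that also appear in the candidate summary, with clipped counts. This is only meant
// to make the reported metric concrete; actual evaluation in this repository goes
// through the bundled Perl scorer.
object ToyRouge2 {
  private def bigrams(tokens: Seq[String]): Seq[(String, String)] = tokens.zip(tokens.drop(1))

  def rouge2Recall(candidate: Seq[String], reference: Seq[String]): Double = {
    val refBigrams = bigrams(reference)
    if (refBigrams.isEmpty) return 0.0
    val candCounts = bigrams(candidate).groupBy(identity).map { case (bg, occs) => bg -> occs.size }
    val refCounts = refBigrams.groupBy(identity).map { case (bg, occs) => bg -> occs.size }
    // Each reference bigram is credited at most as many times as it occurs in the candidate
    val overlap = refCounts.map { case (bg, count) => math.min(count, candCounts.getOrElse(bg, 0)) }.sum
    overlap.toDouble / refBigrams.size
  }

  def main(args: Array[String]): Unit = {
    val reference = "the senate passed the budget bill".split(" ").toSeq
    val candidate = "the senate passed a spending bill".split(" ").toSeq
    println(f"ROUGE-2 recall = ${rouge2Recall(candidate, reference)}%.3f") // 0.400
  }
}
```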
169 | 170 | See ```edu.berkeley.nlp.summ.RougeComputer.evaluateRougeNonTok``` for a method you can 171 | use to evaluate ROUGE in a manner consistent with our evaluation. 172 | 173 | #### Training the system 174 | 175 | To train the full system, run: 176 | 177 | java -Xmx80g -cp -Djava.library.path=:/usr/local/lib/jni edu.berkeley.nlp.summ.Main \ 178 | -trainDocsPath -trainAbstractsPath \ 179 | -evalDocsPath -evalAbstractsPath -abstractsAreConll \ 180 | -modelPath "models/trained-model.ser.gz" -corefModelPath "models/coref-onto.ser.gz" \ 181 | -printSummaries -printSummariesForTurk \ 182 | 183 | where ``````, ``````, and the data paths are instantiated accordingly. The system requires a lot 184 | of memory due to caching 25,000 training documents with annotations. 185 | 186 | To train the sentence extractive version of the system, add: 187 | 188 | -doPronounReplacement false -useFragilePronouns false -noRst 189 | 190 | To train the extractive-compressive version, add: 191 | 192 | -doPronounReplacement false -useFragilePronouns false 193 | 194 | 195 | The results you get using this command should be: 196 | 197 | * extractive: ROUGE-1 recall: 38.6 / ROUGE-2 recall: 23.3 198 | * extractive-compressive: ROUGE-1 recall: 42.2 / ROUGE-2 recall: 26.1 199 | * full: ROUGE-1 recall: 41.9 / ROUGE-2 recall: 25.7 200 | 201 | (Results are slightly different from those in the paper due to minor changes for this 202 | release.) 203 | 204 | -------------------------------------------------------------------------------- /berkeley-doc-summarizer-assembly-1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/berkeley-doc-summarizer-assembly-1.jar -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import AssemblyKeys._ // put this at the top of the file 2 | 3 | name := "berkeley-doc-summarizer" 4 | 5 | version := "1" 6 | 7 | scalaVersion := "2.11.2" 8 | 9 | assemblySettings 10 | 11 | -------------------------------------------------------------------------------- /lib/berkeley-entity-assembly-1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/lib/berkeley-entity-assembly-1.jar -------------------------------------------------------------------------------- /lib/glpk-java-1.1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/lib/glpk-java-1.1.0.jar -------------------------------------------------------------------------------- /lib/trove.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/lib/trove.jar -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) 2 | 3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.10.0") 4 | 
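A note on the build definition above: `sbt assembly` produces the fat jar that the run scripts use. If you instead launch classes via `sbt run`, the GLPK native libraries described in the Setup section still need to be on the Java library path. One way to arrange that is sketched below using standard sbt settings (`fork`, `javaOptions`); this snippet is not part of the repository's `build.sbt`, and the `/usr/local/lib/jni` path is an assumption that should be replaced with wherever `libglpk_java` is installed on your machine.

```scala
// Hypothetical additions to build.sbt (not in the repository): fork a separate JVM
// for `sbt run` and pass it the GLPK JNI library path, mirroring what the run
// scripts do with -Djava.library.path. Adjust the path for your installation.
fork in run := true

javaOptions in run ++= Seq(
  "-Xmx3g",
  "-Djava.library.path=/usr/local/lib/jni"
)
```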
-------------------------------------------------------------------------------- /rouge/ROUGE/RELEASE-NOTE.txt: -------------------------------------------------------------------------------- 1 | # Revision Note: 05/26/2005, Chin-Yew LIN 2 | # 1.5.5 3 | # (1) Correct stemming on multi-token BE heads and modifiers. 4 | # Previously, only single token heads and modifiers were assumed. 5 | # (2) Correct the resampling routine which ignores the last evaluation 6 | # item in the evaluation list. Therefore, the average scores reported 7 | # by ROUGE is only based on the first N-1 evaluation items. 8 | # Thanks Barry Schiffman at Columbia University to report this bug. 9 | # This bug only affects ROUGE-1.5.X. For pre-1.5 ROUGE, it only affects 10 | # the computation of confidence interval (CI) estimation, i.e. CI is only 11 | # estimated by the first N-1 evaluation items, but it *does not* affect 12 | # average scores. 13 | # (3) Change read_text and read_text_LCS functions to read exact words or 14 | # bytes required by users. Previous versions carry out whitespace 15 | # compression and other string clear up actions before enforce the length 16 | # limit. 17 | # 1.5.4.1 18 | # (1) Minor description change about "-t 0" option. 19 | # 1.5.4 20 | # (1) Add easy evalution mode for single reference evaluations with -z 21 | # option. 22 | # 1.5.3 23 | # (1) Add option to compute ROUGE score based on SIMPLE BE format. Given 24 | # a set of peer and model summary file in BE format with appropriate 25 | # options, ROUGE will compute matching scores based on BE lexical 26 | # matches. 27 | # There are 6 options: 28 | # 1. H : Head only match. This is similar to unigram match but 29 | # only BE Head is used in matching. BEs generated by 30 | # Minipar-based breaker do not include head-only BEs, 31 | # therefore, the score will always be zero. Use HM or HMR 32 | # optiions instead. 33 | # 2. HM : Head and modifier match. This is similar to bigram or 34 | # skip bigram but it's head-modifier bigram match based on 35 | # parse result. Only BE triples with non-NIL modifier are 36 | # included in the matching. 37 | # 3. HMR : Head, modifier, and relation match. This is similar to 38 | # trigram match but it's head-modifier-relation trigram 39 | # match based on parse result. Only BE triples with non-NIL 40 | # relation are included in the matching. 41 | # 4. HM1 : This is combination of H and HM. It is similar to unigram + 42 | # bigram or skip bigram with unigram match but it's 43 | # head-modifier bigram match based on parse result. 44 | # In this case, the modifier field in a BE can be "NIL" 45 | # 5. HMR1 : This is combination of HM and HMR. It is similar to 46 | # trigram match but it's head-modifier-relation trigram 47 | # match based on parse result. In this case, the relation 48 | # field of the BE can be "NIL". 49 | # 6. HMR2 : This is combination of H, HM and HMR. It is similar to 50 | # trigram match but it's head-modifier-relation trigram 51 | # match based on parse result. In this case, the modifier and 52 | # relation fields of the BE can both be "NIL". 53 | # 1.5.2 54 | # (1) Add option to compute ROUGE score by token using the whole corpus 55 | # as average unit instead of individual sentences. Previous versions of 56 | # ROUGE uses sentence (or unit) boundary to break counting unit and takes 57 | # the average score from the counting unit as the final score. 
58 | # Using the whole corpus as one single counting unit can potentially 59 | # improve the reliablity of the final score that treats each token as 60 | # equally important; while the previous approach considers each sentence as 61 | # equally important that ignores the length effect of each individual 62 | # sentences (i.e. long sentences contribute equal weight to the final 63 | # score as short sentences.) 64 | # +v1.2 provide a choice of these two counting modes that users can 65 | # choose the one that fits their scenarios. 66 | # 1.5.1 67 | # (1) Add precision oriented measure and f-measure to deal with different lengths 68 | # in candidates and references. Importance between recall and precision can 69 | # be controled by 'alpha' parameter: 70 | # alpha -> 0: recall is more important 71 | # alpha -> 1: precision is more important 72 | # Following Chapter 7 in C.J. van Rijsbergen's "Information Retrieval". 73 | # http://www.dcs.gla.ac.uk/Keith/Chapter.7/Ch.7.html 74 | # F = 1/(alpha * (1/P) + (1 - alpha) * (1/R)) ;;; weighted harmonic mean 75 | # 1.4.2 76 | # (1) Enforce length limit at the time when summary text is read. Previously (before 77 | # and including v1.4.1), length limit was enforced at tokenization time. 78 | # 1.4.1 79 | # (1) Fix potential over counting in ROUGE-L and ROUGE-W 80 | # In previous version (i.e. 1.4 and order), LCS hit is computed 81 | # by summing union hit over all model sentences. Each model sentence 82 | # is compared with all peer sentences and mark the union LCS. The 83 | # length of the union LCS is the hit of that model sentence. The 84 | # final hit is then sum over all model union LCS hits. This potentially 85 | # would over count a peer sentence which already been marked as contributed 86 | # to some other model sentence. Therefore, double counting is resulted. 87 | # This is seen in evalution where ROUGE-L score is higher than ROUGE-1 and 88 | # this is not correct. 89 | # ROUGEeval-1.4.1.pl fixes this by add a clip function to prevent 90 | # double counting. 91 | # 1.4 92 | # (1) Remove internal Jackknifing procedure: 93 | # Now the ROUGE script will use all the references listed in the 94 | # section in each section and no 95 | # automatic Jackknifing is performed. 96 | # If Jackknifing procedure is required when comparing human and system 97 | # performance, then users have to setup the procedure in the ROUGE 98 | # evaluation configuration script as follows: 99 | # For example, to evaluate system X with 4 references R1, R2, R3, and R4. 100 | # We do the following computation: 101 | # 102 | # for system: and for comparable human: 103 | # s1 = X vs. R1, R2, R3 h1 = R4 vs. R1, R2, R3 104 | # s2 = X vs. R1, R3, R4 h2 = R2 vs. R1, R3, R4 105 | # s3 = X vs. R1, R2, R4 h3 = R3 vs. R1, R2, R4 106 | # s4 = X vs. R2, R3, R4 h4 = R1 vs. R2, R3, R4 107 | # 108 | # Average system score for X = (s1+s2+s3+s4)/4 and for human = (h1+h2+h3+h4)/4 109 | # Implementation of this in a ROUGE evaluation configuration script is as follows: 110 | # Instead of writing all references in a evaluation section as below: 111 | # 112 | # ... 113 | # 114 | #

systemX 115 | # 116 | # 117 | # R1 118 | # R2 119 | # R3 120 | # R4 121 | # 122 | # 123 | # we write the following: 124 | # 125 | # 126 | #

systemX 127 | # 128 | # 129 | # R2 130 | # R3 131 | # R4 132 | # 133 | # 134 | # 135 | # 136 | #

systemX 137 | # 138 | # 139 | # R1 140 | # R3 141 | # R4 142 | # 143 | # 144 | # 145 | # 146 | #

systemX 147 | # 148 | # 149 | # R1 150 | # R2 151 | # R4 152 | # 153 | # 154 | # 155 | # 156 | #

systemX 157 | # 158 | # 159 | # R1 160 | # R2 161 | # R3 162 | # 163 | # 164 | # 165 | # In this case, the system and human numbers are comparable. 166 | # ROUGE as it is implemented for summarization evaluation is a recall-based metric. 167 | # As we increase the number of references, we are increasing the number of 168 | # count units (n-gram or skip-bigram or LCSes) in the target pool (i.e. 169 | # the number ends up in the denominator of any ROUGE formula is larger). 170 | # Therefore, a candidate summary has more chance to hit but it also has to 171 | # hit more. In the end, this means lower absolute ROUGE scores when more 172 | # references are used and using different sets of rerferences should not 173 | # be compared to each other. There is no nomalization mechanism in ROUGE 174 | # to properly adjust difference due to different number of references used. 175 | # 176 | # In the ROUGE implementations before v1.4 when there are N models provided for 177 | # evaluating system X in the ROUGE evaluation script, ROUGE does the 178 | # following: 179 | # (1) s1 = X vs. R2, R3, R4, ..., RN 180 | # (2) s2 = X vs. R1, R3, R4, ..., RN 181 | # (3) s3 = X vs. R1, R2, R4, ..., RN 182 | # (4) s4 = X vs. R1, R2, R3, ..., RN 183 | # (5) ... 184 | # (6) sN= X vs. R1, R2, R3, ..., RN-1 185 | # And the final ROUGE score is computed by taking average of (s1, s2, s3, 186 | # s4, ..., sN). When we provide only three references for evaluation of a 187 | # human summarizer, ROUGE does the same thing but using 2 out 3 188 | # references, get three numbers, and then take the average as the final 189 | # score. Now ROUGE (after v1.4) will use all references without this 190 | # internal Jackknifing procedure. The speed of the evaluation should improve 191 | # a lot, since only one set instead of four sets of computation will be 192 | # conducted. 193 | # 1.3 194 | # (1) Add skip bigram 195 | # (2) Add an option to specify the number of sampling point (default is 1000) 196 | # 1.2.3 197 | # (1) Correct the enviroment variable option: -e. Now users can specify evironment 198 | # variable ROUGE_EVAL_HOME using the "-e" option; previously this option is 199 | # not active. Thanks Zhouyan Li of Concordia University, Canada pointing this 200 | # out. 201 | # 1.2.2 202 | # (1) Correct confidence interval calculation for median, maximum, and minimum. 203 | # Line 390. 204 | # 1.2.1 205 | # (1) Add sentence per line format input format. See files in Verify-SPL for examples. 206 | # (2) Streamline command line arguments. 207 | # (3) Use bootstrap resampling to estimate confidence intervals instead of using t-test 208 | # or z-test which assume a normal distribution. 209 | # (4) Add LCS (longest common subsequence) evaluation method. 210 | # (5) Add WLCS (weighted longest common subsequence) evaluation method. 211 | # (6) Add length cutoff in bytes. 212 | # (7) Add an option to specify the longest ngram to compute. The default is 4. 213 | # 1.2 214 | # (1) Change zero condition check in subroutine &computeNGramScores when 215 | # computing $gram1Score from 216 | # if($totalGram2Count!=0) to 217 | # if($totalGram1Count!=0) 218 | # Thanks Ken Litkowski for this bug report. 219 | # This original script will set gram1Score to zero if there is no 220 | # bigram matches. This should rarely has significant affect the final score 221 | # since (a) there are bigram matches most of time; (b) the computation 222 | # of gram1Score is using Jackknifing procedure. 
However, this definitely 223 | # did not compute the correct $gram1Score when there is no bigram matches. 224 | # Therefore, users of version 1.1 should definitely upgrade to newer 225 | # version of the script that does not contain this bug. 226 | # Note: To use this script, two additional data files are needed: 227 | # (1) smart_common_words.txt - contains stopword list from SMART IR engine 228 | # (2) WordNet-1.6.exc.db - WordNet 1.6 exception inflexion database 229 | # These two files have to be put in a directory pointed by the environment 230 | # variable: "ROUGE_EVAL_HOME". 231 | # If environment variable ROUGE_EVAL_HOME does not exist, this script will 232 | # will assume it can find these two database files in the current directory. 233 | -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-1.6-Exceptions/WordNet-1.6.exc.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/rouge/ROUGE/data/WordNet-1.6-Exceptions/WordNet-1.6.exc.db -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-1.6-Exceptions/adv.exc: -------------------------------------------------------------------------------- 1 | best well 2 | better well 3 | deeper deeply 4 | farther far 5 | further far 6 | harder hard 7 | hardest hard 8 | -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-1.6-Exceptions/buildExeptionDB.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use DB_File; 3 | @ARGV!=3&&die "Usage: buildExceptionDB.pl WordNet-exception-file-directory exception-file-extension output-file\n"; 4 | opendir(DIR,$ARGV[0])||die "Cannot open directory $ARGV[0]\n"; 5 | tie %exceptiondb,'DB_File',"$ARGV[2]",O_CREAT|O_RDWR,0640,$DB_HASH or 6 | die "Cannot open exception db file for output: $ARGV[2]\n"; 7 | while(defined($file=readdir(DIR))) { 8 | if($file=~/\.$ARGV[1]$/o) { 9 | print $file,"\n"; 10 | open(IN,"$file")||die "Cannot open exception file: $file\n"; 11 | while(defined($line=)) { 12 | chomp($line); 13 | @tmp=split(/\s+/,$line); 14 | $exceptiondb{$tmp[0]}=$tmp[1]; 15 | print $tmp[0],"\n"; 16 | } 17 | close(IN); 18 | } 19 | } 20 | untie %exceptiondb; 21 | 22 | -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-1.6.exc.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/rouge/ROUGE/data/WordNet-1.6.exc.db -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-2.0-Exceptions/WordNet-2.0.exc.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/rouge/ROUGE/data/WordNet-2.0-Exceptions/WordNet-2.0.exc.db -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-2.0-Exceptions/adv.exc: -------------------------------------------------------------------------------- 1 | best well 2 | better well 3 | deeper deeply 4 | farther far 5 | further far 6 | harder hard 7 | hardest hard 8 | 
-------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-2.0-Exceptions/buildExeptionDB.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use DB_File; 3 | @ARGV!=3&&die "Usage: buildExceptionDB.pl WordNet-exception-file-directory exception-file-extension output-file\n"; 4 | opendir(DIR,$ARGV[0])||die "Cannot open directory $ARGV[0]\n"; 5 | tie %exceptiondb,'DB_File',"$ARGV[2]",O_CREAT|O_RDWR,0640,$DB_HASH or 6 | die "Cannot open exception db file for output: $ARGV[2]\n"; 7 | while(defined($file=readdir(DIR))) { 8 | if($file=~/\.$ARGV[1]$/o) { 9 | print $file,"\n"; 10 | open(IN,"$file")||die "Cannot open exception file: $file\n"; 11 | while(defined($line=)) { 12 | chomp($line); 13 | @tmp=split(/\s+/,$line); 14 | $exceptiondb{$tmp[0]}=$tmp[1]; 15 | print $tmp[0],"\n"; 16 | } 17 | close(IN); 18 | } 19 | } 20 | untie %exceptiondb; 21 | 22 | -------------------------------------------------------------------------------- /rouge/ROUGE/data/WordNet-2.0.exc.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregdurrett/berkeley-doc-summarizer/3c32487b7419e8b76a5ec2e5233de7c75ca93fa8/rouge/ROUGE/data/WordNet-2.0.exc.db -------------------------------------------------------------------------------- /rouge/ROUGE/data/smart_common_words.txt: -------------------------------------------------------------------------------- 1 | reuters 2 | ap 3 | jan 4 | feb 5 | mar 6 | apr 7 | may 8 | jun 9 | jul 10 | aug 11 | sep 12 | oct 13 | nov 14 | dec 15 | tech 16 | news 17 | index 18 | mon 19 | tue 20 | wed 21 | thu 22 | fri 23 | sat 24 | 's 25 | a 26 | a's 27 | able 28 | about 29 | above 30 | according 31 | accordingly 32 | across 33 | actually 34 | after 35 | afterwards 36 | again 37 | against 38 | ain't 39 | all 40 | allow 41 | allows 42 | almost 43 | alone 44 | along 45 | already 46 | also 47 | although 48 | always 49 | am 50 | amid 51 | among 52 | amongst 53 | an 54 | and 55 | another 56 | any 57 | anybody 58 | anyhow 59 | anyone 60 | anything 61 | anyway 62 | anyways 63 | anywhere 64 | apart 65 | appear 66 | appreciate 67 | appropriate 68 | are 69 | aren't 70 | around 71 | as 72 | aside 73 | ask 74 | asking 75 | associated 76 | at 77 | available 78 | away 79 | awfully 80 | b 81 | be 82 | became 83 | because 84 | become 85 | becomes 86 | becoming 87 | been 88 | before 89 | beforehand 90 | behind 91 | being 92 | believe 93 | below 94 | beside 95 | besides 96 | best 97 | better 98 | between 99 | beyond 100 | both 101 | brief 102 | but 103 | by 104 | c 105 | c'mon 106 | c's 107 | came 108 | can 109 | can't 110 | cannot 111 | cant 112 | cause 113 | causes 114 | certain 115 | certainly 116 | changes 117 | clearly 118 | co 119 | com 120 | come 121 | comes 122 | concerning 123 | consequently 124 | consider 125 | considering 126 | contain 127 | containing 128 | contains 129 | corresponding 130 | could 131 | couldn't 132 | course 133 | currently 134 | d 135 | definitely 136 | described 137 | despite 138 | did 139 | didn't 140 | different 141 | do 142 | does 143 | doesn't 144 | doing 145 | don't 146 | done 147 | down 148 | downwards 149 | during 150 | e 151 | each 152 | edu 153 | eg 154 | e.g. 155 | eight 156 | either 157 | else 158 | elsewhere 159 | enough 160 | entirely 161 | especially 162 | et 163 | etc 164 | etc. 
165 | even 166 | ever 167 | every 168 | everybody 169 | everyone 170 | everything 171 | everywhere 172 | ex 173 | exactly 174 | example 175 | except 176 | f 177 | far 178 | few 179 | fifth 180 | five 181 | followed 182 | following 183 | follows 184 | for 185 | former 186 | formerly 187 | forth 188 | four 189 | from 190 | further 191 | furthermore 192 | g 193 | get 194 | gets 195 | getting 196 | given 197 | gives 198 | go 199 | goes 200 | going 201 | gone 202 | got 203 | gotten 204 | greetings 205 | h 206 | had 207 | hadn't 208 | happens 209 | hardly 210 | has 211 | hasn't 212 | have 213 | haven't 214 | having 215 | he 216 | he's 217 | hello 218 | help 219 | hence 220 | her 221 | here 222 | here's 223 | hereafter 224 | hereby 225 | herein 226 | hereupon 227 | hers 228 | herself 229 | hi 230 | him 231 | himself 232 | his 233 | hither 234 | hopefully 235 | how 236 | howbeit 237 | however 238 | i 239 | i'd 240 | i'll 241 | i'm 242 | i've 243 | ie 244 | i.e. 245 | if 246 | ignored 247 | immediate 248 | in 249 | inasmuch 250 | inc 251 | indeed 252 | indicate 253 | indicated 254 | indicates 255 | inner 256 | insofar 257 | instead 258 | into 259 | inward 260 | is 261 | isn't 262 | it 263 | it'd 264 | it'll 265 | it's 266 | its 267 | itself 268 | j 269 | just 270 | k 271 | keep 272 | keeps 273 | kept 274 | know 275 | knows 276 | known 277 | l 278 | lately 279 | later 280 | latter 281 | latterly 282 | least 283 | less 284 | lest 285 | let 286 | let's 287 | like 288 | liked 289 | likely 290 | little 291 | look 292 | looking 293 | looks 294 | ltd 295 | m 296 | mainly 297 | many 298 | may 299 | maybe 300 | me 301 | mean 302 | meanwhile 303 | merely 304 | might 305 | more 306 | moreover 307 | most 308 | mostly 309 | mr. 310 | ms. 311 | much 312 | must 313 | my 314 | myself 315 | n 316 | namely 317 | nd 318 | near 319 | nearly 320 | necessary 321 | need 322 | needs 323 | neither 324 | never 325 | nevertheless 326 | new 327 | next 328 | nine 329 | no 330 | nobody 331 | non 332 | none 333 | noone 334 | nor 335 | normally 336 | not 337 | nothing 338 | novel 339 | now 340 | nowhere 341 | o 342 | obviously 343 | of 344 | off 345 | often 346 | oh 347 | ok 348 | okay 349 | old 350 | on 351 | once 352 | one 353 | ones 354 | only 355 | onto 356 | or 357 | other 358 | others 359 | otherwise 360 | ought 361 | our 362 | ours 363 | ourselves 364 | out 365 | outside 366 | over 367 | overall 368 | own 369 | p 370 | particular 371 | particularly 372 | per 373 | perhaps 374 | placed 375 | please 376 | plus 377 | possible 378 | presumably 379 | probably 380 | provides 381 | q 382 | que 383 | quite 384 | qv 385 | r 386 | rather 387 | rd 388 | re 389 | really 390 | reasonably 391 | regarding 392 | regardless 393 | regards 394 | relatively 395 | respectively 396 | right 397 | s 398 | said 399 | same 400 | saw 401 | say 402 | saying 403 | says 404 | second 405 | secondly 406 | see 407 | seeing 408 | seem 409 | seemed 410 | seeming 411 | seems 412 | seen 413 | self 414 | selves 415 | sensible 416 | sent 417 | serious 418 | seriously 419 | seven 420 | several 421 | shall 422 | she 423 | should 424 | shouldn't 425 | since 426 | six 427 | so 428 | some 429 | somebody 430 | somehow 431 | someone 432 | something 433 | sometime 434 | sometimes 435 | somewhat 436 | somewhere 437 | soon 438 | sorry 439 | specified 440 | specify 441 | specifying 442 | still 443 | sub 444 | such 445 | sup 446 | sure 447 | t 448 | t's 449 | take 450 | taken 451 | tell 452 | tends 453 | th 454 | than 455 | thank 456 | thanks 457 | thanx 458 | that 459 | 
that's 460 | thats 461 | the 462 | their 463 | theirs 464 | them 465 | themselves 466 | then 467 | thence 468 | there 469 | there's 470 | thereafter 471 | thereby 472 | therefore 473 | therein 474 | theres 475 | thereupon 476 | these 477 | they 478 | they'd 479 | they'll 480 | they're 481 | they've 482 | think 483 | third 484 | this 485 | thorough 486 | thoroughly 487 | those 488 | though 489 | three 490 | through 491 | throughout 492 | thru 493 | thus 494 | to 495 | together 496 | too 497 | took 498 | toward 499 | towards 500 | tried 501 | tries 502 | truly 503 | try 504 | trying 505 | twice 506 | two 507 | u 508 | un 509 | under 510 | unfortunately 511 | unless 512 | unlikely 513 | until 514 | unto 515 | up 516 | upon 517 | us 518 | use 519 | used 520 | useful 521 | uses 522 | using 523 | usually 524 | uucp 525 | v 526 | value 527 | various 528 | very 529 | via 530 | viz 531 | vs 532 | w 533 | want 534 | wants 535 | was 536 | wasn't 537 | way 538 | we 539 | we'd 540 | we'll 541 | we're 542 | we've 543 | welcome 544 | well 545 | went 546 | were 547 | weren't 548 | what 549 | what's 550 | whatever 551 | when 552 | whence 553 | whenever 554 | where 555 | where's 556 | whereafter 557 | whereas 558 | whereby 559 | wherein 560 | whereupon 561 | wherever 562 | whether 563 | which 564 | while 565 | whither 566 | who 567 | who's 568 | whoever 569 | whole 570 | whom 571 | whose 572 | why 573 | will 574 | willing 575 | wish 576 | with 577 | within 578 | without 579 | won't 580 | wonder 581 | would 582 | would 583 | wouldn't 584 | x 585 | y 586 | yes 587 | yet 588 | you 589 | you'd 590 | you'll 591 | you're 592 | you've 593 | your 594 | yours 595 | yourself 596 | yourselves 597 | z 598 | zero 599 | -------------------------------------------------------------------------------- /rouge/ROUGE/rouge-gillick.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /Users/gdurrett/n/berkeley-doc-summarizer/rouge/ROUGE 4 | config_file=$1 5 | perl ROUGE-1.5.5.pl -e data/ -n 2 -x -m -s $config_file 1 | grep Average 6 | 7 | -------------------------------------------------------------------------------- /run-glpk-test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | jarpath="berkeley-doc-summarizer-assembly-1.jar" 4 | echo "Running with jar located at $jarpath" 5 | # Extracts the existing java library path and adds /usr/local/lib/jni to it 6 | java_lib_path=$(java -cp $jarpath edu.berkeley.nlp.summ.GLPKTest noglpk | head -1) 7 | java_lib_path="$java_lib_path:/usr/local/lib/jni" 8 | echo "Using the following library path: $java_lib_path" 9 | java -ea -server -Xmx3g -Djava.library.path=$java_lib_path -cp $jarpath edu.berkeley.nlp.summ.GLPKTest 10 | 11 | -------------------------------------------------------------------------------- /run-summarizer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | jarpath="berkeley-doc-summarizer-assembly-1.jar" 4 | echo "Running with jar located at $jarpath" 5 | echo "If the assembled project jar isn't located here, edit run-summarizer.sh to point to the correct location" 6 | # Extracts the existing java library path 7 | java_lib_path=$(java -cp $jarpath edu.berkeley.nlp.summ.GLPKTest noglpk | head -1) 8 | java_lib_path="$java_lib_path:/usr/local/lib/jni" 9 | echo "Using the following library path: $java_lib_path" 10 | 11 | if [ -d "test-summaries-extractive" ]; then 12 | rm -rf test-summaries-extractive 
13 | fi 14 | mkdir test-summaries-extractive/ 15 | # See edu.berkeley.nlp.summ.Summarizer for additional command line arguments 16 | java -ea -server -Xmx3g -Djava.library.path=$java_lib_path -cp $jarpath edu.berkeley.nlp.summ.Summarizer -inputDir "test/" -outputDir "test-summaries-extractive/" -modelPath "models/summarizer-extractive.ser.gz" -noRst 17 | 18 | if [ -d "test-summaries-extractive-compressive" ]; then 19 | rm -rf test-summaries-extractive-compressive 20 | fi 21 | mkdir test-summaries-extractive-compressive 22 | java -ea -server -Xmx3g -Djava.library.path=$java_lib_path -cp $jarpath edu.berkeley.nlp.summ.Summarizer -inputDir "test/" -outputDir "test-summaries-extractive-compressive" -modelPath "models/summarizer-extractive-compressive.ser.gz" 23 | 24 | if [ -d "test-summaries-full" ]; then 25 | rm -rf test-summaries-full 26 | fi 27 | mkdir test-summaries-full/ 28 | java -ea -server -Xmx3g -Djava.library.path=$java_lib_path -cp $jarpath edu.berkeley.nlp.summ.Summarizer -inputDir "test/" -outputDir "test-summaries-full" 29 | 30 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/BigramCountSummarizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import edu.berkeley.nlp.summ.data.DiscourseDepExProcessed 4 | import edu.berkeley.nlp.summ.data.FragilePronoun 5 | import edu.berkeley.nlp.summ.data.PronounReplacement 6 | 7 | /** 8 | * Baseline summarizer based on Gillick and Favre (2009) that extracts a set of sentences 9 | * aiming to maximize coverage of high-scoring bigram types in the source document. Specifically, 10 | * we optimize for 11 | * 12 | * sum_{bigram in summary bigrams} count(bigram) 13 | * 14 | * @author gdurrett 15 | */ 16 | @SerialVersionUID(1L) 17 | class BigramCountSummarizer(useUnigrams: Boolean = false) extends DiscourseDepExSummarizer { 18 | 19 | def decodeBigramCounts(ex: DiscourseDepExProcessed, budget: Int): (Seq[Int], Seq[Int], Seq[Int], Seq[Int], Seq[Int], Double) = { 20 | val allPronReplacements = Seq[PronounReplacement]() 21 | val leafScores = DiscourseDepExSummarizer.biasTowardsEarlier(Array.fill(ex.eduAlignments.size)(0.0)) 22 | // val pronReplacementScores = Array.tabulate(allPronReplacements.size)(i => ex.getBigramRecallDelta(allPronReplacements(i), useUnigramRouge)) 23 | val pronReplacementScores = Array.fill(allPronReplacements.size)(0.0) 24 | val bigramSeq = ex.getDocBigramsSeq(useUnigrams) 25 | val bigramCounts = ex.getDocBigramCounts(useUnigrams) 26 | val bigramScores = bigramSeq.map(bigram => bigramCounts.getCount(bigram)) 27 | val results = CompressiveAnaphoraSummarizerILP.summarizeILPWithGLPK(ex, ex.getParents("flat"), leafScores, bigramScores, allPronReplacements, pronReplacementScores, budget, 28 | 0, 0, false, Seq[FragilePronoun](), useUnigrams) 29 | results 30 | } 31 | 32 | def summarize(ex: DiscourseDepExProcessed, budget: Int, cleanUpForHumans: Boolean = true): Seq[String] = { 33 | val edus = decodeBigramCounts(ex, budget)._1 34 | ex.getSummaryTextWithPronounsReplaced(edus, Seq(), cleanUpForHumans) 35 | } 36 | 37 | def summarizeOracle(ex: DiscourseDepExProcessed, budget: Int): Seq[String] = { 38 | throw new RuntimeException("Unimplemented") 39 | } 40 | 41 | def display(ex: DiscourseDepExProcessed, budget: Int) { 42 | throw new RuntimeException("Unimplemented") 43 | } 44 | 45 | def printStatistics() { 46 | 47 | } 48 | } 49 | 
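Before moving on to the ILP, it may help to see the bigram-coverage objective from `BigramCountSummarizer` in isolation. The class above scores a summary by the summed document counts of the bigram types it covers and hands the optimization to the exact ILP solver in `CompressiveAnaphoraSummarizerILP` (next file). The sketch below is a simplified, greedy stand-in for that objective; the helper names and the greedy strategy are mine, not the repository's, and it only illustrates what is being maximized under the word budget.

```scala
// Greedy sketch of the Gillick and Favre (2009) bigram-coverage objective used by
// BigramCountSummarizer: every document bigram type is worth its document count, a
// summary earns the counts of the bigram types it covers, and total length is capped
// by a word budget. The real system maximizes this exactly with an ILP; this greedy
// loop is only an illustration of the objective, not the implementation.
object BigramCoverageSketch {
  def summarize(eduWords: Seq[Seq[String]],                 // words of each EDU, in document order
                bigramCounts: Map[(String, String), Int],   // document-level bigram counts
                budget: Int): Seq[Int] = {
    def bigrams(words: Seq[String]): Set[(String, String)] = words.zip(words.drop(1)).toSet
    var covered = Set.empty[(String, String)]
    var chosen = Vector.empty[Int]
    var lengthUsed = 0
    var candidates = eduWords.indices.toSet
    var improved = true
    while (improved) {
      // Score each remaining EDU by the not-yet-covered bigram mass it adds per word
      val scored = candidates.toSeq
        .filter(i => lengthUsed + eduWords(i).size <= budget)
        .map { i =>
          val gain = (bigrams(eduWords(i)) -- covered).toSeq.map(bg => bigramCounts.getOrElse(bg, 0)).sum
          i -> (gain.toDouble / math.max(1, eduWords(i).size))
        }
        .filter(_._2 > 0.0)
      improved = scored.nonEmpty
      if (improved) {
        val (best, _) = scored.maxBy(_._2)
        chosen :+= best
        covered ++= bigrams(eduWords(best))
        lengthUsed += eduWords(best).size
        candidates -= best
      }
    }
    chosen.sorted
  }
}
```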
-------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/CompressiveAnaphoraSummarizerILP.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import org.gnu.glpk.GLPK 5 | import org.gnu.glpk.GLPKConstants 6 | import edu.berkeley.nlp.summ.data.FragilePronoun 7 | import edu.berkeley.nlp.summ.data.PronounReplacement 8 | import edu.berkeley.nlp.summ.data.DiscourseDepExProcessed 9 | 10 | /** 11 | * ILP implementation for the main summarizer. 12 | */ 13 | object CompressiveAnaphoraSummarizerILP { 14 | 15 | def summarizeILPWithGLPK(ex: DiscourseDepExProcessed, 16 | parents: Seq[Int], 17 | eduScores: Seq[Double], 18 | bigramScores: Seq[Double], 19 | pronReplacements: Seq[PronounReplacement], 20 | pronReplacementScores: Seq[Double], 21 | budget: Int, 22 | numEqualsConstraints: Int, 23 | numParentConstraints: Int, 24 | doPronounConstraints: Boolean, 25 | fragilePronouns: Seq[FragilePronoun], 26 | useUnigrams: Boolean): (Seq[Int], Seq[Int], Seq[Int], Seq[Int], Seq[Int], Double) = { 27 | require(pronReplacements.isEmpty || useUnigrams, "Can't do pronoun replacement correctly with bigrams right now") 28 | val constraints = TreeKnapsackSummarizer.makeConstraints(parents, ex.parentLabels, numEqualsConstraints, numParentConstraints) 29 | val leafSizes = ex.leafSizes 30 | val parentLabels = ex.parentLabels 31 | 32 | // val docBigramsSeq = ex.getEduBigramsMap.reduce(_ ++ _).toSet.toSeq.sorted 33 | val docBigramsSeq = ex.getDocBigramsSeq(useUnigrams) 34 | 35 | val edusContainingBigram = docBigramsSeq.map(bigram => { 36 | (0 until ex.eduAlignments.size).filter(eduIdx => { 37 | val words = ex.getEduWords(eduIdx) 38 | val poss = ex.getEduPoss(eduIdx) 39 | val bigrams = if (useUnigrams) RougeComputer.getUnigramsNoStopwords(words, poss) else RougeComputer.getBigramsNoStopwords(words, poss) 40 | bigrams.contains(bigram) 41 | }) 42 | }) 43 | val pronRepsContainingBigram = docBigramsSeq.map(bigram => new ArrayBuffer[Int]) 44 | for (pronReplacementIdx <- 0 until pronReplacements.size) { 45 | val pronReplacement = pronReplacements(pronReplacementIdx) 46 | // N.B. We currently don't handle deleted bigrams; it's assumed there really won't 47 | // be any because the pronoun replacement should only replace pronouns 48 | val addedDeletedBigrams = ex.getBigramDelta(pronReplacement, useUnigrams) 49 | for (bigram <- addedDeletedBigrams._1.toSeq.sorted) { 50 | val bigramIdx = if (pronReplacement.addedGenitive) { 51 | // This should always be fine because 's is its own token and is a stopword, so 52 | // the only time it appears in a token should be if we made a genitive alteration... 53 | docBigramsSeq.indexOf(bigram._1.replace("'s", "") -> bigram._2.replace("'s", "")) 54 | } else { 55 | docBigramsSeq.indexOf(bigram) 56 | } 57 | // Very occasional non-unique things are possible due to the genitive modification... 
58 | // it did cause a crash at one point 59 | if (bigramIdx >= 0 && !pronRepsContainingBigram(bigramIdx).contains(pronReplacementIdx)) { 60 | pronRepsContainingBigram(bigramIdx) += pronReplacementIdx 61 | } 62 | } 63 | } 64 | 65 | val debug = false 66 | 67 | // Turn off output 68 | GLPK.glp_term_out(GLPKConstants.GLP_OFF) 69 | val lp = GLPK.glp_create_prob(); 70 | GLPK.glp_set_prob_name(lp, "myProblem"); 71 | 72 | // Variables 73 | val numEdus = leafSizes.size 74 | val numBigrams = docBigramsSeq.size 75 | val numProns = pronReplacements.size 76 | val numVariables = numEdus + numBigrams + numProns 77 | val bigramOffset = numEdus 78 | val pronOffset = numEdus + numBigrams 79 | val cutOffset = numEdus + numBigrams + numProns 80 | val restartOffset = numEdus + numBigrams + numProns 81 | GLPK.glp_add_cols(lp, numVariables); 82 | for (i <- 0 until numVariables) { 83 | GLPK.glp_set_col_name(lp, i+1, "x" + i); 84 | GLPK.glp_set_col_kind(lp, i+1, GLPKConstants.GLP_BV); 85 | } 86 | // Objective weights 87 | GLPK.glp_set_obj_name(lp, "obj"); 88 | GLPK.glp_set_obj_dir(lp, GLPKConstants.GLP_MAX); 89 | for (i <- 0 until numEdus) { 90 | GLPK.glp_set_obj_coef(lp, i+1, eduScores(i)); 91 | } 92 | for (i <- 0 until numBigrams) { 93 | GLPK.glp_set_obj_coef(lp, bigramOffset+i+1, bigramScores(i)); 94 | } 95 | for (i <- 0 until numProns) { 96 | GLPK.glp_set_obj_coef(lp, pronOffset+i+1, pronReplacementScores(i)); 97 | } 98 | // Constraints 99 | // val numBigramConstraints = numBigrams + edusContainingBigram.map(_.size).foldLeft(0)(_ + _) 100 | // val numPronConstraints = pronReplacements.size * 2 + pronReplacements.map(_.prevEDUsContainingEntity.size).foldLeft(0)(_ + _) 101 | // val numConstraints = constraints.size + 1 + numBigramConstraints + numPronConstraints 102 | // GLPK.glp_add_rows(lp, numConstraints+1); 103 | var constraintIdx = 0 104 | val ind = GLPK.new_intArray(numVariables+1); 105 | val values = GLPK.new_doubleArray(numVariables+1); 106 | // Binary constraints, usually from parent-child relationships 107 | for (i <- 0 until constraints.size) { 108 | val parent = constraints(i)._1 109 | val child = constraints(i)._2 110 | GLPK.intArray_setitem(ind, 1, parent+1); 111 | GLPK.intArray_setitem(ind, 2, child+1); 112 | GLPK.doubleArray_setitem(values, 1, 1); 113 | GLPK.doubleArray_setitem(values, 2, -1); 114 | GLPK.glp_add_rows(lp, 1); 115 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_LO, 0, 0); 116 | GLPK.glp_set_mat_row(lp, constraintIdx+1, 2, ind, values); 117 | constraintIdx += 1 118 | } 119 | // BIGRAMS 120 | // Constraints representing bigrams being in or out 121 | // bigram >= edu for each edu containing that bigram AND for each pron rep implying it 122 | // bigram - edu/pronrep >= 0 123 | for (i <- 0 until numBigrams) { 124 | val edusContaining = edusContainingBigram(i) 125 | val pronRepsContaining = pronRepsContainingBigram(i) 126 | // val pronRepsContaining = Seq[Int]() 127 | for (eduContaining <- edusContaining) { 128 | GLPK.intArray_setitem(ind, 1, bigramOffset+i+1); 129 | GLPK.intArray_setitem(ind, 2, eduContaining+1); 130 | GLPK.doubleArray_setitem(values, 1, 1); 131 | GLPK.doubleArray_setitem(values, 2, -1); 132 | GLPK.glp_add_rows(lp, 1); 133 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_LO, 0, 0); 134 | GLPK.glp_set_mat_row(lp, constraintIdx+1, 2, ind, values); 135 | constraintIdx += 1 136 | } 137 | for (pronRepContaining <- pronRepsContainingBigram(i)) { 138 | GLPK.intArray_setitem(ind, 1, bigramOffset+i+1); 139 | GLPK.intArray_setitem(ind, 2, 
pronOffset+pronRepContaining+1); 140 | GLPK.doubleArray_setitem(values, 1, 1); 141 | GLPK.doubleArray_setitem(values, 2, -1); 142 | GLPK.glp_add_rows(lp, 1); 143 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_LO, 0, 0); 144 | GLPK.glp_set_mat_row(lp, constraintIdx+1, 2, ind, values); 145 | constraintIdx += 1 146 | } 147 | // bigram <= sum of all edus containing the bigram AND sum of all pron replacements containing bigram 148 | // sum - bigram >= 0 149 | for (eduContainingIdx <- 0 until edusContaining.size) { 150 | GLPK.intArray_setitem(ind, eduContainingIdx+1, edusContaining(eduContainingIdx)+1); 151 | GLPK.doubleArray_setitem(values, eduContainingIdx+1, 1); 152 | } 153 | for (pronRepContainingIdx <- 0 until pronRepsContaining.size) { 154 | GLPK.intArray_setitem(ind, edusContaining.size + pronRepContainingIdx + 1, pronOffset + pronRepsContaining(pronRepContainingIdx)+1); 155 | GLPK.doubleArray_setitem(values, edusContaining.size + pronRepContainingIdx + 1, 1); 156 | } 157 | GLPK.intArray_setitem(ind, edusContaining.size + pronRepsContaining.size + 1, bigramOffset+i+1); 158 | GLPK.doubleArray_setitem(values, edusContaining.size + pronRepsContaining.size + 1, -1); 159 | GLPK.glp_add_rows(lp, 1); 160 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_LO, 0, 0); 161 | GLPK.glp_set_mat_row(lp, constraintIdx+1, edusContaining.size + pronRepsContaining.size + 1, ind, values); 162 | constraintIdx += 1 163 | } 164 | // PRONOUNS 165 | // Pronoun constraints 166 | for (pronReplacementIdx <- 0 until pronReplacements.size) { 167 | val pronReplacement = pronReplacements(pronReplacementIdx) 168 | // Pronoun is only on if the EDU is on: edu - pron >= 0 (pron <= edu) 169 | GLPK.intArray_setitem(ind, 1, pronReplacement.eduIdx+1); 170 | GLPK.intArray_setitem(ind, 2, pronOffset + pronReplacementIdx + 1); 171 | GLPK.doubleArray_setitem(values, 1, 1); 172 | GLPK.doubleArray_setitem(values, 2, -1); 173 | GLPK.glp_add_rows(lp, 1); 174 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_LO, 0, 0); 175 | GLPK.glp_set_mat_row(lp, constraintIdx+1, 2, ind, values); 176 | constraintIdx += 1 177 | // If the current edu is on, then the pronoun is on if none of the previous 178 | // instantiations of it are included. 
179 | // pron >= 1 - sum(prev) 180 | // sum(prev_edu) + var >= 1 IF edu is on 181 | // sum(prev_edu) + var >= 0 otherwise 182 | // => var - edu + sum(prev_edu) >= 0 183 | val prevEdus = pronReplacement.prevEDUsContainingEntity 184 | GLPK.intArray_setitem(ind, 1, pronOffset + pronReplacementIdx + 1); 185 | GLPK.doubleArray_setitem(values, 1, 1); 186 | GLPK.intArray_setitem(ind, 2, pronReplacement.eduIdx+1); 187 | GLPK.doubleArray_setitem(values, 2, -1); 188 | for (i <- 0 until prevEdus.size) { 189 | GLPK.intArray_setitem(ind, i+3, prevEdus(i)+1); 190 | GLPK.doubleArray_setitem(values, i+3, 1); 191 | } 192 | GLPK.glp_add_rows(lp, 1); 193 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_LO, 0, 0); 194 | GLPK.glp_set_mat_row(lp, constraintIdx+1, prevEdus.size+2, ind, values); 195 | constraintIdx += 1 196 | // Pronoun is off if any previous instantiation is included 197 | // pron <= (1 - edu) 198 | // pron + edu <= 1 199 | for (i <- 0 until prevEdus.size) { 200 | GLPK.intArray_setitem(ind, 1, pronOffset + pronReplacementIdx + 1); 201 | GLPK.intArray_setitem(ind, 2, prevEdus(i) + 1); 202 | GLPK.doubleArray_setitem(values, 1, 1); 203 | GLPK.doubleArray_setitem(values, 2, 1); 204 | GLPK.glp_add_rows(lp, 1); 205 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_UP, 0, 1); 206 | GLPK.glp_set_mat_row(lp, constraintIdx+1, 2, ind, values); 207 | constraintIdx += 1 208 | } 209 | // For the version where pronouns are constraint, set pronoun <= 0 so we can't 210 | // use pronoun replacements at all (just avoid including pronouns with bad anaphora) 211 | if (doPronounConstraints) { 212 | GLPK.intArray_setitem(ind, 1, pronOffset + pronReplacementIdx + 1); 213 | GLPK.doubleArray_setitem(values, 1, 1); 214 | GLPK.glp_add_rows(lp, 1); 215 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_UP, 0, 0); 216 | GLPK.glp_set_mat_row(lp, constraintIdx+1, 1, ind, values); 217 | constraintIdx += 1 218 | } 219 | } 220 | // FRAGILE PRONOUN CONSTRAINTS 221 | for (fragilePronoun <- fragilePronouns) { 222 | // eduIdx <= 1/(n-0.001)(prev_edu_sum) + pronRep 223 | // eduIdx - stuff - pronRep <= 0 224 | GLPK.intArray_setitem(ind, 1, fragilePronoun.eduIdx + 1); 225 | GLPK.doubleArray_setitem(values, 1, 1); 226 | var currConstIdx = 2 227 | // Subtract from the denominator so the constraint doesn't have floating point issues 228 | for (pastEduIdx <- fragilePronoun.antecedentEdus) { 229 | GLPK.intArray_setitem(ind, currConstIdx, pastEduIdx + 1); 230 | GLPK.doubleArray_setitem(values, currConstIdx, -1.0/(fragilePronoun.antecedentEdus.size - 0.01)); 231 | currConstIdx += 1 232 | } 233 | val correspondingPronRepIndices = pronReplacements.zipWithIndex.filter(_._1.mentIdx == fragilePronoun.mentIdx).map(_._2) 234 | if (!correspondingPronRepIndices.isEmpty) { 235 | GLPK.intArray_setitem(ind, currConstIdx, pronOffset + correspondingPronRepIndices.head + 1); 236 | GLPK.doubleArray_setitem(values, currConstIdx, -1.0); 237 | correspondingPronRepIndices.head 238 | currConstIdx += 1 239 | } 240 | GLPK.glp_add_rows(lp, 1); 241 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_UP, 0, 0); 242 | GLPK.glp_set_mat_row(lp, constraintIdx+1, currConstIdx-1, ind, values); 243 | constraintIdx += 1 244 | } 245 | // CUTS AND RESTARTS 246 | val leq = true 247 | val geq = false 248 | def addConstraint(vars: Seq[Int], coeffs: Seq[Double], isLeq: Boolean, result: Int) { 249 | for (i <- 0 until vars.size) { 250 | GLPK.intArray_setitem(ind, i+1, vars(i)); 251 | GLPK.doubleArray_setitem(values, i+1, coeffs(i)); 
252 | } 253 | GLPK.glp_add_rows(lp, 1); 254 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, if (isLeq) GLPKConstants.GLP_UP else GLPKConstants.GLP_LO, if (isLeq) 0 else result, if (isLeq) result else 0); 255 | GLPK.glp_set_mat_row(lp, constraintIdx+1, vars.size, ind, values); 256 | constraintIdx += 1 257 | } 258 | // LENGTH 259 | for (i <- 0 until numEdus) { 260 | GLPK.intArray_setitem(ind, i+1, i+1); 261 | GLPK.doubleArray_setitem(values, i+1, leafSizes(i)); 262 | } 263 | for (i <- 0 until numProns) { 264 | val colIdx = numEdus+i+1 265 | GLPK.intArray_setitem(ind, colIdx, pronOffset+i+1); 266 | // Logger.logss("Additional words: " + (pronReplacements(i-numEdus).replacementWords.size - 1)) 267 | GLPK.doubleArray_setitem(values, colIdx, pronReplacements(i).replacementWords.size - 1); 268 | } 269 | GLPK.glp_add_rows(lp, 1); 270 | GLPK.glp_set_row_bnds(lp, constraintIdx+1, GLPKConstants.GLP_UP, 0, budget); 271 | GLPK.glp_set_mat_row(lp, constraintIdx+1, numEdus + numProns, ind, values); 272 | GLPK.delete_doubleArray(values); 273 | GLPK.delete_intArray(ind); 274 | // require(constraintIdx+1 == numConstraints, constraintIdx+1 + " " + numConstraints) 275 | 276 | val (soln, score) = TreeKnapsackSummarizer.solveILPAndReport(lp, numVariables) 277 | GLPK.glp_delete_prob(lp) 278 | // (soln.filter(_ < numEdus), 279 | // soln.filter(idx => idx >= pronOffset && idx < cutOffset).map(idx => idx - pronOffset), 280 | // soln.filter(idx => idx >= bigramOffset && idx < pronOffset).map(idx => idx - bigramOffset), 281 | // score) 282 | (soln.filter(_ < numEdus), 283 | soln.filter(idx => idx >= pronOffset && idx < cutOffset).map(idx => idx - pronOffset), 284 | soln.filter(idx => idx >= bigramOffset && idx < pronOffset).map(idx => idx - bigramOffset), 285 | soln.filter(idx => idx >= cutOffset && idx < restartOffset).map(idx => idx - cutOffset), 286 | soln.filter(idx => idx >= restartOffset).map(idx => idx - restartOffset), 287 | score) 288 | } 289 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/CorefUtils.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import scala.collection.mutable.HashMap 5 | import edu.berkeley.nlp.entity.DepConstTree 6 | import edu.berkeley.nlp.entity.coref.Mention 7 | import edu.berkeley.nlp.entity.coref.PronounDictionary 8 | import edu.berkeley.nlp.entity.coref.MentionType 9 | import edu.berkeley.nlp.entity.coref.CorefDoc 10 | import edu.berkeley.nlp.entity.GUtil 11 | import edu.berkeley.nlp.futile.math.SloppyMath 12 | 13 | object CorefUtils { 14 | 15 | def getAntecedent(corefDoc: CorefDoc, predictor: edu.berkeley.nlp.entity.coref.PairwiseScorer, index: Int) = { 16 | val posteriors = computePosteriors(corefDoc, predictor, Seq(index)) 17 | GUtil.argMaxIdx(posteriors(0)) 18 | } 19 | 20 | def computePosteriors(corefDoc: CorefDoc, predictor: edu.berkeley.nlp.entity.coref.PairwiseScorer, indicesOfInterest: Seq[Int]): Array[Array[Double]] = { 21 | val docGraph = new edu.berkeley.nlp.entity.coref.DocumentGraph(corefDoc, false) 22 | Array.tabulate(indicesOfInterest.size)(idxIdxOfInterest => { 23 | val idx = indicesOfInterest(idxIdxOfInterest) 24 | val scores = Array.tabulate(idx+1)(antIdx => predictor.score(docGraph, idx, antIdx, false).toDouble) 25 | val logNormalizer = scores.foldLeft(Double.NegativeInfinity)(SloppyMath.logAdd(_, _)) 26 | for (antIdx <- 0 until scores.size) { 27 | scores(antIdx) = 
scores(antIdx) - logNormalizer 28 | } 29 | scores 30 | }) 31 | } 32 | 33 | /** 34 | * This exists to make results consistent with what was there before 35 | */ 36 | def remapMentionType(ment: Mention) = { 37 | val newMentionType = if (ment.endIdx - ment.startIdx == 1 && PronounDictionary.isDemonstrative(ment.rawDoc.words(ment.sentIdx)(ment.headIdx))) { 38 | MentionType.DEMONSTRATIVE; 39 | } else if (ment.endIdx - ment.startIdx == 1 && PronounDictionary.isPronLc(ment.rawDoc.words(ment.sentIdx)(ment.headIdx))) { 40 | MentionType.PRONOMINAL; 41 | } else if (ment.rawDoc.pos(ment.sentIdx)(ment.headIdx) == "NNS" || ment.rawDoc.pos(ment.sentIdx)(ment.headIdx) == "NNPS") { 42 | MentionType.PROPER; 43 | } else { 44 | MentionType.NOMINAL; 45 | } 46 | new Mention(ment.rawDoc, 47 | ment.mentIdx, 48 | ment.sentIdx, 49 | ment.startIdx, 50 | ment.endIdx, 51 | ment.headIdx, 52 | ment.allHeadIndices, 53 | ment.isCoordinated, 54 | newMentionType, 55 | ment.nerString, 56 | ment.number, 57 | ment.gender) 58 | 59 | } 60 | 61 | def getMentionText(ment: Mention) = ment.rawDoc.words(ment.sentIdx).slice(ment.startIdx, ment.endIdx) 62 | 63 | def getMentionNerSpan(ment: Mention): Option[(Int,Int)] = { 64 | // Smallest NER chunk that contains the head 65 | val conllDoc = ment.rawDoc 66 | val matchingChunks = conllDoc.nerChunks(ment.sentIdx).filter(chunk => chunk.start <= ment.headIdx && ment.headIdx < chunk.end); 67 | if (!matchingChunks.isEmpty) { 68 | val smallestChunk = matchingChunks.sortBy(chunk => chunk.end - chunk.start).head; 69 | Some(smallestChunk.start -> smallestChunk.end) 70 | } else { 71 | None 72 | } 73 | } 74 | 75 | def getSpanHeads(tree: DepConstTree, startIdx: Int, endIdx: Int): Seq[Int] = getSpanHeads(tree.childParentDepMap, startIdx, endIdx); 76 | 77 | def getSpanHeads(childParentDepMap: HashMap[Int,Int], startIdx: Int, endIdx: Int): Seq[Int] = { 78 | // If it's a constituent, only one should have a head outside 79 | val outsidePointing = new ArrayBuffer[Int]; 80 | for (i <- startIdx until endIdx) { 81 | val ptr = childParentDepMap(i); 82 | if (ptr < startIdx || ptr >= endIdx) { 83 | outsidePointing += i; 84 | } 85 | } 86 | outsidePointing 87 | } 88 | 89 | def isDefinitelyPerson(str: String): Boolean = { 90 | val canonicalization = PronounDictionary.canonicalize(str) 91 | // N.B. 
Don't check "we" or "they" because those might be used in inanimate cases 92 | canonicalization == "i" || canonicalization == "you" || canonicalization == "he" || canonicalization == "she" 93 | } 94 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/DiscourseSummarizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import scala.collection.JavaConverters.asScalaSetConverter 4 | import scala.collection.mutable.ArrayBuffer 5 | import scala.collection.mutable.HashMap 6 | import scala.collection.mutable.HashSet 7 | import org.gnu.glpk.GLPK 8 | import org.gnu.glpk.GLPKConstants 9 | import org.gnu.glpk.glp_iocp 10 | import org.gnu.glpk.glp_prob 11 | import edu.berkeley.nlp.futile.util.Counter 12 | import edu.berkeley.nlp.futile.util.Logger 13 | import edu.berkeley.nlp.summ.data.DiscourseTree 14 | import edu.berkeley.nlp.summ.data.DiscourseNode 15 | import edu.berkeley.nlp.summ.data.StopwordDict 16 | import edu.berkeley.nlp.summ.data.DiscourseDepExProcessed 17 | 18 | /** 19 | * Generic discourse-aware summarizer 20 | */ 21 | trait DiscourseDepExSummarizer extends Serializable { 22 | def summarize(ex: DiscourseDepExProcessed, budget: Int, cleanUpForHumans: Boolean = true): Seq[String] 23 | def summarizeOracle(ex: DiscourseDepExProcessed, budget: Int): Seq[String] 24 | def display(ex: DiscourseDepExProcessed, budget: Int); 25 | def printStatistics(); 26 | } 27 | 28 | object DiscourseDepExSummarizer { 29 | 30 | // Used for tie-breaking in the ILP scores. This seems to be the best way to do it; works better than a similar 31 | // version which biases towards earlier content but generally including more content 32 | def biasTowardsEarlier(leafScores: Seq[Double]) = { 33 | (0 until leafScores.size).map(i => leafScores(i) - (i+1) * 0.000001) 34 | } 35 | } 36 | 37 | /** 38 | * Tree knapsack system of Hirao et al. (2013) and Yoshida et al. (2014), depending on whether 39 | * it is instantiated over gold or predicted discourse dependency trees. EDUs are scored heuristically 40 | * and an EDU can only be included if its parent is included as well. 
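 *
 * A sketch of the ILP that summarizeILPWithGLPK builds below (one binary variable x_i per EDU i;
 * scores and lengths come from the caller):
 *   maximize    sum_i score_i * x_i
 *   subject to  x_parent(i) >= x_i           for every parent-child constraint
 *               sum_i len_i * x_i <= budget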
41 | */ 42 | object TreeKnapsackSummarizer { 43 | 44 | def computeEduValuesUseStopwordSet(leafWordss: Seq[Seq[String]], parents: Seq[Int]): Array[Double] = { 45 | computeEduValues(leafWordss, parents, (word: String) => StopwordDict.stopwords.contains(word)) 46 | } 47 | 48 | def computeEduValuesUsePoss(leafWordss: Seq[Seq[String]], leafPoss: Seq[Seq[String]], parents: Seq[Int]): Array[Double] = { 49 | val leafPossDict = new HashMap[String,String] 50 | for (i <- 0 until leafWordss.size) { 51 | for (j <- 0 until leafWordss(i).size) { 52 | leafPossDict += leafWordss(i)(j) -> leafPoss(i)(j) 53 | } 54 | } 55 | computeEduValues(leafWordss, parents, (word: String) => StopwordDict.stopwordTags.contains(leafPossDict(word))) 56 | } 57 | 58 | def computeEduValues(leafWordss: Seq[Seq[String]], parents: Seq[Int], stopwordTest: (String => Boolean)): Array[Double] = { 59 | val depths = DiscourseTree.computeDepths(parents, Array.fill(parents.size)(""), true) 60 | val wordFreqs = new Counter[String] 61 | for (leafWords <- leafWordss) { 62 | for (word <- leafWords) { 63 | wordFreqs.incrementCount(word, 1.0) 64 | } 65 | } 66 | // Use log counts 67 | for (word <- wordFreqs.keySet.asScala) { 68 | val containsLetter = word.map(c => Character.isLetter(c)).foldLeft(false)(_ || _) 69 | val isStopword = stopwordTest(word) 70 | val isWordValid = containsLetter && !isStopword 71 | if (isWordValid) { 72 | wordFreqs.setCount(word, Math.log(1 + wordFreqs.getCount(word))/Math.log(2)) 73 | } else { 74 | wordFreqs.setCount(word, 0.0) 75 | } 76 | } 77 | Array.tabulate(leafWordss.size)(i => { 78 | val wordSet = leafWordss(i).toSet 79 | var totalCount = 0.0 80 | for (word <- wordSet) { 81 | totalCount += wordFreqs.getCount(word) 82 | } 83 | totalCount.toDouble / depths(i) 84 | }) 85 | } 86 | 87 | def scoreHypothesis(tree: DiscourseTree, budget: Int, eduScores: Array[Double], edusOn: Seq[Int]) = { 88 | var totalScore = 0.0 89 | var budgetUsed = 0 90 | var constraintsViolated = 0 91 | for (edu <- edusOn) { 92 | totalScore += eduScores(edu) 93 | budgetUsed += tree.leaves(edu).leafWords.size 94 | if (tree.parents(edu) != -1 && !edusOn.contains(tree.parents(edu))) { 95 | Logger.logss("Violated constraint! 
Included " + edu + " without the parent " + tree.parents(edu)) 96 | constraintsViolated += 1 97 | } 98 | } 99 | Logger.logss("Score: " + totalScore + " with budget " + budgetUsed + "/" + budget + "; " + constraintsViolated + " constraints violated") 100 | } 101 | 102 | def summarizeFirstK(leaves: Seq[DiscourseNode], budget: Int): Seq[String] = { 103 | val sents = new ArrayBuffer[String] 104 | var budgetUsed = 0 105 | var leafIdx = 0 106 | var done = false 107 | while (!done) { 108 | val currLeafWords = leaves(leafIdx).leafWords 109 | if (budgetUsed + currLeafWords.size < budget) { 110 | sents += currLeafWords.reduce(_ + " " + _) 111 | budgetUsed += currLeafWords.size 112 | } else { 113 | done = true 114 | // Comment or uncomment this to take partial EDUs 115 | // sents += currLeafWords.slice(0, budget - budgetUsed).reduce(_ + " " + _) 116 | budgetUsed = budget 117 | } 118 | leafIdx += 1 119 | } 120 | sents 121 | } 122 | 123 | def summarizeILP(leafSizes: Seq[Int], parents: Seq[Int], budget: Int, eduScores: Seq[Double], useGurobi: Boolean): (Seq[Int], Double) = { 124 | summarizeILP(leafSizes, parents, Array.fill(parents.size)(""), budget, eduScores, 1, 1) 125 | } 126 | 127 | def summarizeILP(leafSizes: Seq[Int], parents: Seq[Int], labels: Seq[String], budget: Int, eduScores: Seq[Double], numEqualsConstraints: Int, numParentConstraints: Int): (Seq[Int], Double) = { 128 | summarizeILPWithGLPK(leafSizes, budget, eduScores, makeConstraints(parents, labels, numEqualsConstraints, numParentConstraints)) 129 | } 130 | 131 | def makeConstraints(parents: Seq[Int], labels: Seq[String], numEqualsConstraints: Int, numParentConstraints: Int): ArrayBuffer[(Int,Int)] = { 132 | val constraints = new ArrayBuffer[(Int,Int)] 133 | for (i <- 0 until parents.size) { 134 | val isEq = labels(i).startsWith("=") 135 | val isList = labels(i) == "=List" 136 | if (parents(i) != -1 && labels(i).startsWith("=")) { 137 | if (numEqualsConstraints == 1) { 138 | constraints += (parents(i) -> i) 139 | } else if (numEqualsConstraints == 2) { 140 | constraints += (parents(i) -> i) 141 | constraints += (i -> parents(i)) 142 | } 143 | } else { 144 | if (parents(i) != -1) { 145 | if (numParentConstraints == 1) { 146 | constraints += (parents(i) -> i) 147 | } 148 | } 149 | } 150 | } 151 | constraints 152 | } 153 | 154 | // constraints = sequence of (a, b) pairs where b's inclusion implies a's as well (a = parent, b = child in the standard case) 155 | def summarizeILPWithGLPK(leafSizes: Seq[Int], budget: Int, eduScores: Seq[Double], constraints: Seq[(Int,Int)]): (Seq[Int], Double) = { 156 | val debug = false 157 | // Turn off output 158 | GLPK.glp_term_out(GLPKConstants.GLP_OFF) 159 | val lp = GLPK.glp_create_prob(); 160 | GLPK.glp_set_prob_name(lp, "myProblem"); 161 | 162 | // Variables 163 | val numVariables = leafSizes.size 164 | GLPK.glp_add_cols(lp, numVariables); 165 | for (i <- 0 until numVariables) { 166 | GLPK.glp_set_col_name(lp, i+1, "x" + i); 167 | GLPK.glp_set_col_kind(lp, i+1, GLPKConstants.GLP_BV); 168 | } 169 | // Objective weights 170 | GLPK.glp_set_obj_name(lp, "obj"); 171 | GLPK.glp_set_obj_dir(lp, GLPKConstants.GLP_MAX); 172 | for (i <- 0 until leafSizes.size) { 173 | GLPK.glp_set_obj_coef(lp, i+1, eduScores(i)); 174 | } 175 | // Constraints 176 | val numConstraints = constraints.size 177 | GLPK.glp_add_rows(lp, numConstraints+1); 178 | val ind = GLPK.new_intArray(numVariables+1); 179 | val values = GLPK.new_doubleArray(numVariables+1); 180 | // Binary constraints, usually from parent-child relationships 181 | for 
(i <- 0 until constraints.size) { 182 | val parent = constraints(i)._1 183 | val child = constraints(i)._2 184 | GLPK.intArray_setitem(ind, 1, parent+1); 185 | GLPK.intArray_setitem(ind, 2, child+1); 186 | GLPK.doubleArray_setitem(values, 1, 1); 187 | GLPK.doubleArray_setitem(values, 2, -1); 188 | GLPK.glp_set_row_name(lp, i+1, "c" + (i+1)); 189 | GLPK.glp_set_row_bnds(lp, i+1, GLPKConstants.GLP_LO, 0, 0); 190 | GLPK.glp_set_mat_row(lp, i+1, 2, ind, values); 191 | } 192 | // Length constraint 193 | for (j <- 0 until leafSizes.size) { 194 | GLPK.intArray_setitem(ind, j+1, j+1); 195 | GLPK.doubleArray_setitem(values, j+1, leafSizes(j)); 196 | } 197 | GLPK.glp_set_row_name(lp, numConstraints + 1, "clen"); 198 | GLPK.glp_set_row_bnds(lp, numConstraints + 1, GLPKConstants.GLP_UP, 0, budget); 199 | GLPK.glp_set_mat_row(lp, numConstraints + 1, leafSizes.size, ind, values); 200 | GLPK.delete_doubleArray(values); 201 | GLPK.delete_intArray(ind); 202 | 203 | val (soln, score) = solveILPAndReport(lp, leafSizes.size) 204 | GLPK.glp_delete_prob(lp) 205 | (soln, score) 206 | } 207 | 208 | def solveILPAndReport(lp: glp_prob, numEdus: Int): (Seq[Int],Double) = { 209 | val iocp = new glp_iocp(); 210 | GLPK.glp_init_iocp(iocp); 211 | iocp.setPresolve(GLPKConstants.GLP_ON); 212 | val edusChosen = new ArrayBuffer[Int] 213 | val ret = GLPK.glp_intopt(lp, iocp); 214 | if (ret == 0) { 215 | for (i <- 0 until numEdus) { 216 | val colValue = GLPK.glp_mip_col_val(lp, i+1) 217 | if (colValue == 1) { 218 | edusChosen += i 219 | } 220 | } 221 | } else { 222 | throw new RuntimeException("Couldn't solve!") 223 | } 224 | edusChosen.toSeq -> GLPK.glp_mip_obj_val(lp) 225 | } 226 | 227 | def extractSummary(leafWords: Seq[Seq[String]], selected: Seq[Int]) = { 228 | selected.map(leafIdx => leafWords(leafIdx).reduce(_ + " " + _)) 229 | } 230 | 231 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/GLPKTest.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ; 2 | 3 | import org.gnu.glpk.GLPK; 4 | import org.gnu.glpk.glp_prob; 5 | 6 | /** 7 | * Small class to let you easily test whether your Java library path settings 8 | * are correct. Run with 9 | * 10 | * -Djava.library.path=": implements Serializable { 11 | static final long serialVersionUID = 42; 12 | 13 | F first; 14 | S second; 15 | 16 | public F getFirst() { 17 | return first; 18 | } 19 | 20 | public S getSecond() { 21 | return second; 22 | } 23 | 24 | public void setFirst(F pFirst) { 25 | first = pFirst; 26 | } 27 | 28 | public void setSecond(S pSecond) { 29 | second = pSecond; 30 | } 31 | 32 | public Pair reverse() { 33 | return new Pair(second, first); 34 | } 35 | 36 | public boolean equals(Object o) { 37 | if (this == o) 38 | return true; 39 | if (!(o instanceof Pair)) 40 | return false; 41 | 42 | final Pair pair = (Pair) o; 43 | 44 | if (first != null ? !first.equals(pair.first) : pair.first != null) 45 | return false; 46 | if (second != null ? !second.equals(pair.second) : pair.second != null) 47 | return false; 48 | 49 | return true; 50 | } 51 | 52 | public int hashCode() { 53 | int result; 54 | result = (first != null ? first.hashCode() : 0); 55 | result = 29 * result + (second != null ? 
second.hashCode() : 0); 56 | return result; 57 | } 58 | 59 | public String toString() { 60 | return "(" + getFirst() + ", " + getSecond() + ")"; 61 | } 62 | 63 | public Pair(F first, S second) { 64 | this.first = first; 65 | this.second = second; 66 | } 67 | 68 | // Compares only first values 69 | public static class FirstComparator, T> 70 | implements Comparator> { 71 | public int compare(Pair p1, Pair p2) { 72 | return p1.getFirst().compareTo(p2.getFirst()); 73 | } 74 | } 75 | 76 | public static class ReverseFirstComparator, T> 77 | implements Comparator> { 78 | public int compare(Pair p1, Pair p2) { 79 | return p2.getFirst().compareTo(p1.getFirst()); 80 | } 81 | } 82 | 83 | // Compares only second values 84 | public static class SecondComparator> 85 | implements Comparator> { 86 | public int compare(Pair p1, Pair p2) { 87 | return p1.getSecond().compareTo(p2.getSecond()); 88 | } 89 | } 90 | 91 | public static class ReverseSecondComparator> 92 | implements Comparator> { 93 | public int compare(Pair p1, Pair p2) { 94 | return p2.getSecond().compareTo(p1.getSecond()); 95 | } 96 | } 97 | 98 | public static Pair makePair(S first, T second) { 99 | return new Pair(first, second); 100 | } 101 | 102 | public static class LexicographicPairComparator implements Comparator> { 103 | Comparator firstComparator; 104 | Comparator secondComparator; 105 | 106 | public int compare(Pair pair1, Pair pair2) { 107 | int firstCompare = firstComparator.compare(pair1.getFirst(), pair2.getFirst()); 108 | if (firstCompare != 0) 109 | return firstCompare; 110 | return secondComparator.compare(pair1.getSecond(), pair2.getSecond()); 111 | } 112 | 113 | public LexicographicPairComparator(Comparator firstComparator, Comparator secondComparator) { 114 | this.firstComparator = firstComparator; 115 | this.secondComparator = secondComparator; 116 | } 117 | } 118 | 119 | public static class DefaultLexicographicPairComparator,S extends Comparable> 120 | implements Comparator> { 121 | 122 | public int compare(Pair o1, Pair o2) { 123 | int firstCompare = o1.getFirst().compareTo(o2.getFirst()); 124 | if (firstCompare != 0) { 125 | return firstCompare; 126 | } 127 | return o2.getSecond().compareTo(o2.getSecond()); 128 | } 129 | 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/RougeComputer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import java.io.File 4 | import java.nio.file.Files 5 | import java.nio.file.Paths 6 | import scala.collection.JavaConverters.asScalaBufferConverter 7 | import scala.collection.mutable.ArrayBuffer 8 | import scala.collection.mutable.HashSet 9 | import scala.sys.process.Process 10 | import edu.berkeley.nlp.futile.classify.ClassifyUtils 11 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 12 | import edu.berkeley.nlp.futile.util.Logger 13 | import edu.berkeley.nlp.summ.data.StopwordDict 14 | 15 | /** 16 | * Contains methods for both computing ROUGE losses programmatically as well as 17 | * dispatching to the actual ROUGE scorer for evaluation. 
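 *
 * A minimal usage sketch (variable names are illustrative; "rouge/ROUGE" is assumed to be the
 * directory holding the bundled scorer scripts):
 * {{{
 * // programmatic ROUGE-2 recall: both arguments are tokenized sentences (Seq[Seq[String]])
 * val (hit, total) = RougeComputer.computeRouge2SuffStats(systemSentTokens, referenceSentTokens)
 * val rouge2Recall = hit.toDouble / total
 * // full scorer: each summary is a Seq[String] of space-joined sentences; returns
 * // Array(R1 P, R1 R, R1 F1, R2 P, R2 R, R2 F1)
 * val scores = RougeComputer.evaluateRougeNonTok(Seq(sysSummary), Seq(refSummary), "rouge/ROUGE")
 * }}}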
18 | */ 19 | object RougeComputer { 20 | 21 | def getBigrams(words: Seq[String]): Seq[(String,String)] = { 22 | (0 until words.size - 1).map(i => words(i) -> words(i+1)) 23 | } 24 | 25 | def getUnigramsNoStopwords(words: Seq[String], poss: Seq[String]): Seq[(String,String)] = { 26 | val nonStopIndices = (0 until words.size).filter(i => !StopwordDict.stopwordTags.contains(poss(i))) 27 | // Unigrams are encoded as bigrams because... 28 | nonStopIndices.map(idx => words(idx) -> "") 29 | } 30 | 31 | def getBigramsNoStopwords(words: Seq[String], poss: Seq[String]): Seq[(String,String)] = { 32 | // val nonStopIndices = (0 until words.size).filter(i => !stopwordTags.contains(poss(i))) 33 | val nonStopIndices = (0 until words.size).filter(i => !StopwordDict.stopwordTags.contains(poss(i))) 34 | (0 until nonStopIndices.size - 1).map(i => words(nonStopIndices(i)) -> words(nonStopIndices(i+1))) 35 | } 36 | 37 | def computeRouge1SuffStats(sents: Seq[Seq[String]], summSents: Seq[Seq[String]]) = { 38 | val sourceUnigrams = sents.map(_.toSet).foldLeft(new HashSet[String])(_ ++ _) 39 | val targetUnigrams = summSents.map(_.toSet).foldLeft(new HashSet[String])(_ ++ _) 40 | val numHit = (targetUnigrams & sourceUnigrams).size 41 | (numHit, targetUnigrams.size) 42 | } 43 | 44 | def computeRouge2SuffStats(sents: Seq[Seq[String]], summSents: Seq[Seq[String]]) = { 45 | val sourceBigrams = getBigramSet(sents) 46 | val targetBigrams = getBigramSet(summSents) 47 | val numHit = (targetBigrams & sourceBigrams).size 48 | (numHit, targetBigrams.size) 49 | } 50 | 51 | def getBigramSet(sents: Seq[Seq[String]]) = { 52 | val bigrams = new HashSet[(String,String)] 53 | for (sent <- sents) { 54 | for (i <- 0 until sent.size - 1) { 55 | bigrams += sent(i) -> sent(i+1) 56 | } 57 | } 58 | bigrams 59 | } 60 | 61 | def readSummDocJustTextBySents(fileName: String) = { 62 | val lines = IOUtils.readLines(fileName).asScala 63 | val results = lines.map(_.trim).filter(_.size != 0).map(_.split("\\s+").toSeq) 64 | results 65 | } 66 | 67 | def write(fileName: String, data: Seq[Seq[String]]) { 68 | val printWriter = IOUtils.openOutHard(fileName) 69 | data.foreach(sent => printWriter.println(sent.foldLeft("")(_ + " " + _).trim)) 70 | printWriter.close() 71 | } 72 | 73 | /** 74 | * Takes system and reference summaries, writes them to an output directory (which might be deleted), 75 | * and call the ROUGE scorer on them. 
Note that although the summaries aren't "tokenized" in the sense 76 | * that each line is a single String, they should be the result of taking a tokenized string and joining 77 | * it with spaces, or the ROUGE scorer won't work correctly 78 | */ 79 | def evaluateRougeNonTok(sysSumms: Seq[Seq[String]], 80 | refSumms: Seq[Seq[String]], 81 | rougeDirPath: String, 82 | skipRougeEvaluation: Boolean = false, 83 | keepRougeDirs: Boolean = false, 84 | suppressOutput: Boolean = false): Array[Double] = { 85 | var unigramRecallNum = 0 86 | var unigramRecallDenom = 0 87 | var bigramRecallNum = 0 88 | var bigramRecallDenom = 0 89 | var totalWordsUsed = 0 90 | 91 | val tmpDirPath = Files.createTempDirectory(Paths.get(rougeDirPath), "outputs") 92 | val tmpDirAbsPath = tmpDirPath.toAbsolutePath.toString 93 | val tmpDir = new File(tmpDirAbsPath) 94 | if (!keepRougeDirs) tmpDir.deleteOnExit() 95 | val sysDir = new File(tmpDirAbsPath + "/system") 96 | sysDir.mkdir() 97 | if (!keepRougeDirs) sysDir.deleteOnExit() 98 | val refDir = new File(tmpDirAbsPath + "/reference") 99 | refDir.mkdir() 100 | if (!keepRougeDirs) refDir.deleteOnExit() 101 | val settingsFile = File.createTempFile("settings", ".xml", new File(rougeDirPath)) 102 | if (!keepRougeDirs) settingsFile.deleteOnExit() 103 | 104 | for (i <- 0 until sysSumms.size) { 105 | val fileName = "" + i 106 | val sysSumm = sysSumms(i) 107 | val refSumm = refSumms(i) 108 | 109 | val systemPath = tmpDirAbsPath + "/system/" + fileName + "_system1.txt" 110 | 111 | totalWordsUsed += sysSumm.map(_.split("\\s+").size).foldLeft(0)(_ + _) 112 | val unigramRecallSuffStats = RougeComputer.computeRouge1SuffStats(sysSumm.map(_.split("\\s+").toSeq), refSumm.map(_.split("\\s+").toSeq)) 113 | unigramRecallNum += unigramRecallSuffStats._1 114 | unigramRecallDenom += unigramRecallSuffStats._2 115 | val bigramRecallSuffStats = RougeComputer.computeRouge2SuffStats(sysSumm.map(_.split("\\s+").toSeq), refSumm.map(_.split("\\s+").toSeq)) 116 | bigramRecallNum += bigramRecallSuffStats._1 117 | bigramRecallDenom += bigramRecallSuffStats._2 118 | RougeFileMunger.writeSummary(fileName, sysSumm, systemPath, keepRougeDirs) 119 | // write(systemPath, if (runFirstK) firstK else modelSents) 120 | // val refPath = outDir + "/reference/" + cleanedFileName + "_reference1.txt" 121 | val refPath = tmpDirAbsPath + "/reference/" + fileName + "_reference1.txt" 122 | RougeFileMunger.writeSummary(fileName, refSumm, refPath, keepRougeDirs) 123 | // write(refPath, summ) 124 | } 125 | if (!suppressOutput) Logger.logss("Unigram recall: " + ClassifyUtils.renderNumerDenom(unigramRecallNum, unigramRecallDenom)) 126 | if (!suppressOutput) Logger.logss("Bigram recall: " + ClassifyUtils.renderNumerDenom(bigramRecallNum, bigramRecallDenom)) 127 | if (!suppressOutput) Logger.logss(totalWordsUsed + " words used") 128 | evaluateRouge(settingsFile.getAbsolutePath, tmpDirAbsPath, rougeDirPath, skipRougeEvaluation, suppressOutput) 129 | } 130 | 131 | def evaluateRouge(settingsFileAbsPath: String, outDirAbsPath: String, rougeDirPath: String, skipScoring: Boolean, suppressOutput: Boolean): Array[Double] = { 132 | RougeFileMunger.writeSettings(settingsFileAbsPath, outDirAbsPath) 133 | if (!suppressOutput) Logger.logss("ROUGE OUTPUT: files written to " + outDirAbsPath + " with settings file in " + settingsFileAbsPath) 134 | if (!skipScoring) { 135 | import scala.sys.process._ 136 | val output = Process(Seq(rougeDirPath + "/rouge-gillick.sh", settingsFileAbsPath)).lines; 137 | val lines = new ArrayBuffer[String] 138 | 
output.foreach(lines += _) 139 | // lines.foreach(Logger.logss(_)) 140 | val nums = lines.map(line => { 141 | val startIdx = line.indexOf(":") + 2 142 | val endIdx = line.indexOf("(") - 1 143 | line.substring(startIdx, endIdx) 144 | }) 145 | if (!suppressOutput) Logger.logss("ROUGE 1 P/R/F1: " + nums(1) + ", " + nums(0) + ", " + nums(2) + "; ROUGE 2 P/R/F1: " + nums(4) + ", " + nums(3) + ", " + nums(5)) 146 | val numsDoubles = nums.map(_.toDouble) 147 | Array(numsDoubles(1), numsDoubles(0), numsDoubles(2), numsDoubles(4), numsDoubles(3), numsDoubles(5)) 148 | } else { 149 | Logger.logss("...skipping evaluation, you have to run it yourself") 150 | Array[Double]() 151 | } 152 | } 153 | 154 | def bootstrap(worseSumms: Seq[Seq[String]], 155 | betterSumms: Seq[Seq[String]], 156 | refSumms: Seq[Seq[String]], 157 | rougeDirPath: String) { 158 | val size = worseSumms.size 159 | require(worseSumms.size == betterSumms.size) 160 | val worseSuffStats = (0 until size).map(i => evaluateRougeNonTok(Seq(worseSumms(i)), Seq(refSumms(i)), rougeDirPath, suppressOutput = true)) 161 | val betterSuffStats = (0 until size).map(i => evaluateRougeNonTok(Seq(betterSumms(i)), Seq(refSumms(i)), rougeDirPath, suppressOutput = true)) 162 | val numStats = worseSuffStats(0).size 163 | // Use macro-averages 164 | val origDiff = (0 until numStats).map(i => betterSuffStats.map(_(i)).foldLeft(0.0)(_ + _)/size - worseSuffStats.map(_(i)).foldLeft(0.0)(_ + _)/size) 165 | val bootstrapSamples = 10000 166 | val rng = new scala.util.Random(0) 167 | val numSig = Array.fill(numStats)(0) 168 | for (sampleIdx <- 0 until bootstrapSamples) { 169 | val resampled = (0 until size).map(i => rng.nextInt(size)); 170 | val newDiff = (0 until numStats).map(i => resampled.map(idx => betterSuffStats(idx)(i)).foldLeft(0.0)(_ + _)/size - resampled.map(idx => worseSuffStats(idx)(i)).foldLeft(0.0)(_ + _)/size) 171 | for (i <- 0 until numStats) { 172 | if (origDiff(i) >= 0 && newDiff(i) < 2 * origDiff(i)) { 173 | numSig(i) += 1 174 | } 175 | } 176 | } 177 | Logger.logss("ROUGE 1 P/R/F1; 2 P/R/F1: " + (0 until numSig.size).map(i => ClassifyUtils.renderNumerDenom(numSig(i), bootstrapSamples))) 178 | } 179 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/RougeFileMunger.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import java.io.File 4 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 5 | import scala.collection.JavaConverters._ 6 | 7 | /** 8 | * Handles writing ROUGE-style XML files 9 | */ 10 | object RougeFileMunger { 11 | 12 | val input = "data/RSTDiscourse/sample-outputs/" 13 | val output = "data/RSTDiscourse/sample-outputs-rouge/" 14 | val settingsPath = "data/RSTDiscourse/rouge-settings.xml" 15 | val detokenize = true 16 | 17 | def writeSummary(fileName: String, sents: Seq[String], outPath: String, keepFile: Boolean) { 18 | val outFile = new File(outPath) 19 | if (!keepFile) outFile.deleteOnExit() 20 | val outWriter = IOUtils.openOutHard(outFile) 21 | outWriter.println("") 22 | outWriter.println("" + fileName + "") 23 | outWriter.println("<") 24 | var counter = 1 25 | for (sent <- sents) { 26 | outWriter.println("[" + counter + "] " + sent + "") 27 | counter += 1 28 | } 29 | outWriter.println("") 30 | outWriter.println("") 31 | outWriter.close 32 | } 33 | 34 | def detokenizeSentence(line: String) = { 35 | line.replace(" ,", ",").replace(" .", ".").replace(" !", "!").replace(" ?", 
"?").replace(" :", ":").replace(" ;", ";"). 36 | replace("`` ", "``").replace(" ''", "''").replace(" '", "'").replace(" \"", "\"").replace("$ ", "$") 37 | } 38 | 39 | def processFiles(rootPath: String, subDir: String) = { 40 | val refFiles = new File(rootPath + "/" + subDir).listFiles 41 | for (refFile <- refFiles) { 42 | val rawName = refFile.getName() 43 | val name = rawName.substring(0, if (rawName.indexOf("_") == -1) rawName.size else rawName.indexOf("_")) 44 | val lines = IOUtils.readLinesHard(refFile.getAbsolutePath()).asScala.map(sent => if (detokenize) detokenizeSentence(sent) else sent) 45 | writeSummary(name, lines, output + "/" + subDir + "/" + refFile.getName, true) 46 | } 47 | } 48 | 49 | def writeSettings(settingsPath: String, dirPaths: String) { 50 | val outWriter = IOUtils.openOutHard(settingsPath) 51 | outWriter.println("""""") 52 | val rawDirName = new File(dirPaths).getName() 53 | val docs = new File(dirPaths + "/reference").listFiles 54 | var idx = 0 55 | for (doc <- docs) { 56 | val rawName = doc.getName().substring(0, doc.getName.indexOf("_")) 57 | outWriter.println("") 58 | outWriter.println("" + rawDirName + "/reference") 59 | outWriter.println("" + rawDirName + "/system") 60 | outWriter.println(" ") 61 | outWriter.println("") 62 | outWriter.println("
<P ID=\"1\">" + rawName + "_system1.txt</P>") 63 | outWriter.println("</PEERS>") 64 | outWriter.println("<MODELS>") 65 | outWriter.println("<M ID=\"1\">" + rawName + "_reference1.txt</M>") 66 | outWriter.println("</MODELS>") 67 | outWriter.println("</EVAL>
") 68 | idx += 1 69 | } 70 | outWriter.println("") 71 | outWriter.close 72 | } 73 | 74 | def main(args: Array[String]) { 75 | processFiles(input, "reference") 76 | processFiles(input, "system") 77 | writeSettings(settingsPath, output) 78 | } 79 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/Summarizer.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ 2 | 3 | import java.io.File 4 | import edu.berkeley.nlp.entity.ConllDocReader 5 | import edu.berkeley.nlp.entity.coref.CorefDocAssembler 6 | import edu.berkeley.nlp.entity.coref.MentionPropertyComputer 7 | import edu.berkeley.nlp.entity.coref.NumberGenderComputer 8 | import edu.berkeley.nlp.entity.lang.EnglishCorefLanguagePack 9 | import edu.berkeley.nlp.entity.lang.Language 10 | import edu.berkeley.nlp.futile.LightRunner 11 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 12 | import edu.berkeley.nlp.futile.util.Logger 13 | import edu.berkeley.nlp.summ.data.SummDoc 14 | import edu.berkeley.nlp.summ.preprocess.DiscourseDependencyParser 15 | import edu.berkeley.nlp.summ.preprocess.EDUSegmenter 16 | import edu.berkeley.nlp.summ.data.DiscourseDepExProcessed 17 | 18 | /** 19 | * Main class for running the summarizer on unlabeled data. See run-summarizer.sh for 20 | * example usage. The most useful arguments are: 21 | * -inputDir: directory of files (in CoNLL format, with parses/coref/NER) to summarize 22 | * -outputDir: directory to write summaries 23 | * -modelPath if you want to use a different version of the summarizer. 24 | * 25 | * Any member of this class can be passed as a command-line argument to the 26 | * system if it is preceded with a dash, e.g. 27 | * -budget 100 28 | */ 29 | object Summarizer { 30 | 31 | val numberGenderPath = "data/gender.data"; 32 | val segmenterPath = "models/edusegmenter.ser.gz" 33 | val discourseParserPath = "models/discoursedep.ser.gz" 34 | val modelPath = "models/summarizer-full.ser.gz" 35 | 36 | val inputDir = "" 37 | val outputDir = "" 38 | 39 | // Indicates that we shouldn't do any discourse preprocessing; this is only appropriate 40 | // for the sentence-extractive version of the system 41 | val noRst = false 42 | 43 | // Summary budget, in words. Set this to whatever you want it to. 
44 | val budget = 50 45 | 46 | def main(args: Array[String]) { 47 | LightRunner.initializeOutput(Summarizer.getClass()) 48 | LightRunner.populateScala(Summarizer.getClass(), args) 49 | 50 | Logger.logss("Loading model...") 51 | val model = IOUtils.readObjFile(modelPath).asInstanceOf[CompressiveAnaphoraSummarizer] 52 | Logger.logss("Model loaded!") 53 | val (segmenter, discourseParser) = if (noRst) { 54 | (None, None) 55 | } else { 56 | Logger.logss("Loading segmenter...") 57 | val tmpSegmenter = IOUtils.readObjFile(segmenterPath).asInstanceOf[EDUSegmenter] 58 | Logger.logss("Segmenter loaded!") 59 | Logger.logss("Loading discourse parser...") 60 | val tmpDiscourseParser = IOUtils.readObjFile(discourseParserPath).asInstanceOf[DiscourseDependencyParser] 61 | Logger.logss("Discourse parser loaded!") 62 | (Some(tmpSegmenter), Some(tmpDiscourseParser)) 63 | } 64 | 65 | val numberGenderComputer = NumberGenderComputer.readBergsmaLinData(numberGenderPath); 66 | val mpc = new MentionPropertyComputer(Some(numberGenderComputer)) 67 | 68 | val reader = new ConllDocReader(Language.ENGLISH) 69 | val assembler = new CorefDocAssembler(new EnglishCorefLanguagePack, true) 70 | val filesToSummarize = new File(inputDir).listFiles() 71 | for (file <- filesToSummarize) { 72 | val conllDoc = reader.readConllDocs(file.getAbsolutePath).head 73 | val corefDoc = assembler.createCorefDoc(conllDoc, mpc) 74 | val summDoc = SummDoc.makeSummDoc(conllDoc.docID, corefDoc, Seq()) 75 | val ex = if (noRst) { 76 | DiscourseDepExProcessed.makeTrivial(summDoc) 77 | } else { 78 | DiscourseDepExProcessed.makeWithEduAndSyntactic(summDoc, segmenter.get, discourseParser.get) 79 | } 80 | val summaryLines = model.summarize(ex, budget, true) 81 | val outWriter = IOUtils.openOutHard(outputDir + "/" + file.getName) 82 | for (summLine <- summaryLines) { 83 | outWriter.println(summLine) 84 | } 85 | outWriter.close 86 | } 87 | LightRunner.finalizeOutput() 88 | } 89 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/compression/SyntacticCompressor.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.compression 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import edu.berkeley.nlp.futile.util.Logger 5 | import scala.collection.mutable.HashMap 6 | import scala.collection.JavaConverters 7 | import edu.berkeley.nlp.futile.LightRunner 8 | import java.util.IdentityHashMap 9 | import edu.berkeley.nlp.futile.syntax.Tree 10 | import scala.collection.JavaConverters._ 11 | import edu.berkeley.nlp.summ.data.SummDoc 12 | 13 | object SyntacticCompressor { 14 | 15 | def findComprParents(comprs: Seq[(Int,Int)]): Seq[Int] = { 16 | var someUnset = true 17 | val parents = Array.tabulate(comprs.size)(i => -2) 18 | while (someUnset) { 19 | var someUnsetThisItr = false 20 | for (idx <- 0 until comprs.size) { 21 | if (parents(idx) == -2) { 22 | var containedInNone = true 23 | var containedIdx = -1 24 | var containerSize = Int.MaxValue 25 | val compr = comprs(idx) 26 | for (comparisonIdx <- 0 until comprs.size) { 27 | if (comparisonIdx != idx) { 28 | val possibleParent = comprs(comparisonIdx) 29 | val possibleParentSize = possibleParent._2 - possibleParent._1 30 | if (isContained(compr, possibleParent) && possibleParentSize < containerSize) { 31 | containedInNone = false 32 | containedIdx = comparisonIdx 33 | containerSize = possibleParentSize 34 | } 35 | } 36 | } 37 | if (containedInNone) { 38 | parents(idx) = -1 39 
| } else if (containedIdx != -1) { 40 | parents(idx) = containedIdx 41 | } else { 42 | someUnsetThisItr = true 43 | } 44 | } 45 | } 46 | someUnset = someUnsetThisItr 47 | } 48 | parents 49 | } 50 | 51 | def isContained(containee: (Int,Int), container: (Int,Int)) = { 52 | container._1 <= containee._1 && containee._2 <= container._2 53 | } 54 | 55 | def identifyCuts(tree: Tree[String]) = { 56 | val processedTree = TreeProcessor.processTree(tree) 57 | require(processedTree.getYield.size == tree.getYield.size) 58 | val indicesToCut = new ArrayBuffer[(Int,Int)] 59 | val parentsMap = SentenceCompressor.getParentTrees(processedTree) 60 | identifyCutsHelper(processedTree, 0, processedTree.getYield().size, parentsMap) 61 | } 62 | 63 | val emptyCuts = new ArrayBuffer[(Int,Int)] 64 | 65 | def identifyCutsHelper(tree: Tree[String], startIdx: Int, endIdx: Int, parentsMap: IdentityHashMap[Tree[String],Tree[String]]): ArrayBuffer[(Int,Int)] = { 66 | if (tree.isLeaf) { 67 | emptyCuts 68 | } else { 69 | val legalCuts = new ArrayBuffer[(Int,Int)] 70 | val thisTreeFeats = SentenceCompressor.getCutFeatures(tree, startIdx, endIdx, parentsMap) 71 | if (thisTreeFeats.size > 0) { 72 | legalCuts += (startIdx -> endIdx) 73 | } 74 | val children = tree.getChildren() 75 | var currStartIdx = startIdx 76 | for (child <- children.asScala) { 77 | legalCuts ++= identifyCutsHelper(child, currStartIdx, currStartIdx + child.getYield.size, parentsMap) 78 | currStartIdx += child.getYield.size 79 | } 80 | legalCuts 81 | } 82 | } 83 | 84 | def compress(rawDoc: SummDoc): (Seq[((Int,Int),(Int,Int))], Seq[Int], Seq[String]) = { 85 | val conllDoc = rawDoc.corefDoc.rawDoc 86 | val possibleCompressionsEachSent: Seq[Set[(Int,Int)]] = (0 until conllDoc.numSents).map(i => identifyCuts(conllDoc.trees(i).constTree).toSet) 87 | compress(possibleCompressionsEachSent, conllDoc.words.map(_.size)) 88 | } 89 | 90 | def compress(possibleCompressionsEachSent: Seq[Set[(Int,Int)]], sentLens: Seq[Int]): (Seq[((Int,Int),(Int,Int))], Seq[Int], Seq[String]) = { 91 | val numSents = sentLens.size 92 | val chunks = new ArrayBuffer[((Int,Int),(Int,Int))] 93 | val parents = new ArrayBuffer[Int] 94 | val labels = new ArrayBuffer[String] 95 | for (sentIdx <- 0 until numSents) { 96 | val (sentChunks, sentParents, sentLabels) = compressSentence(possibleCompressionsEachSent(sentIdx), sentLens(sentIdx)) 97 | val sentOffset = chunks.size 98 | chunks ++= sentChunks.map(chunk => (sentIdx -> chunk._1) -> (sentIdx -> chunk._2)) 99 | // Offset all chunks by the number of chunks occurring earlier in the document, same with 100 | parents ++= sentParents.map(parent => if (parent == -1) -1 else parent + sentOffset) 101 | labels ++= sentLabels 102 | } 103 | (chunks.toSeq, parents.toSeq, labels.toSeq) 104 | } 105 | 106 | def compressSentence(possibleCompressions: Set[(Int,Int)], sentLen: Int) = { 107 | // Logger.logss(possibleCompressions + " " + sentLen) 108 | val orderedComprs = Seq((0, sentLen)) ++ possibleCompressions.toSeq.sortBy(_._1) 109 | val sentComprParents = findComprParents(orderedComprs) 110 | // Segments start at 0 and wherever something happens with the compressions. 
The last always ends 111 | // at the end of the sentence but this doesn't appear in the list 112 | val boundaries = (Seq(0) ++ orderedComprs.map(_._1) ++ orderedComprs.map(_._2).filter(_ != sentLen)).toSet.toSeq.sorted 113 | val segmentToCompressionMapping = Array.tabulate(boundaries.size)(idx => { 114 | val start = boundaries(idx) 115 | val end = if (idx == boundaries.size - 1) sentLen else boundaries(idx+1) 116 | var segIdx = -1 117 | var comprLen = Int.MaxValue 118 | // Find the smallest compression containing this segment 119 | for (comprIdx <- 0 until orderedComprs.size) { 120 | // if (isContained((start, end), orderedComprs(comprIdx)) && (sentComprParents(comprIdx) == -1 || !isContained((start, end), orderedComprs(sentComprParents(comprIdx))))) { 121 | val newComprLen = orderedComprs(comprIdx)._2 - orderedComprs(comprIdx)._1 122 | if (isContained((start, end), orderedComprs(comprIdx)) && newComprLen < comprLen) { 123 | segIdx = comprIdx 124 | comprLen = newComprLen 125 | } 126 | } 127 | segIdx 128 | }) 129 | // Logger.logss("B: " + boundaries) 130 | // Logger.logss("STCM: " + segmentToCompressionMapping.toSeq) 131 | val compressionToSegmentMapping = Array.tabulate(orderedComprs.size)(comprIdx => { 132 | (0 until segmentToCompressionMapping.size).filter(segIdx => segmentToCompressionMapping(segIdx) == comprIdx) 133 | }) 134 | val sentParents = new ArrayBuffer[Int] 135 | val sentLabels = new ArrayBuffer[String] 136 | // If it's not in a compression, hook it to the first one of the sentence with =SameUnit 137 | // If it is a compression, hook it to its immediately larger compression or to the first one 138 | // of the sentence with Compression 139 | for (i <- 0 until boundaries.size) { 140 | val myCompr = segmentToCompressionMapping(i) 141 | val myOtherSegs = compressionToSegmentMapping(myCompr) 142 | if (myOtherSegs.indexOf(i) == 0) { 143 | if (sentComprParents(myCompr) == -1) { 144 | sentParents += -1 145 | sentLabels += "" 146 | } else { 147 | var parentWithSegs = sentComprParents(myCompr) 148 | while (parentWithSegs != -1 && compressionToSegmentMapping(parentWithSegs).isEmpty) { 149 | parentWithSegs = sentComprParents(parentWithSegs) 150 | } 151 | if (parentWithSegs == -1) { 152 | sentParents += -1 153 | sentLabels += "Compression" 154 | } else { 155 | sentParents += compressionToSegmentMapping(parentWithSegs).head 156 | sentLabels += "Compression" 157 | } 158 | } 159 | // Return the parent 160 | } else { 161 | sentParents += myOtherSegs(0) 162 | sentLabels += "=SameUnit" 163 | } 164 | } 165 | val sentChunks = (0 until boundaries.size).map(i => boundaries(i) -> (if (i == boundaries.size - 1) sentLen else boundaries(i+1))) 166 | (sentChunks, sentParents, sentLabels) 167 | } 168 | 169 | def refineEdus(rawDoc: SummDoc, eduAlignments: Seq[((Int,Int),(Int,Int))], parents: Seq[Int], labels: Seq[String]): (Seq[((Int,Int),(Int,Int))], Seq[Int], Seq[String]) = { 170 | val conllDoc = rawDoc.corefDoc.rawDoc 171 | val possibleCompressionsEachSent: Seq[Set[(Int,Int)]] = (0 until conllDoc.numSents).map(i => identifyCuts(conllDoc.trees(i).constTree).toSet) 172 | refineEdus(possibleCompressionsEachSent, conllDoc.words.map(_.size), eduAlignments, parents, labels) 173 | } 174 | 175 | def refineEdus(possibleCompressionsEachSent: Seq[Set[(Int,Int)]], sentLens: Seq[Int], edus: Seq[((Int,Int),(Int,Int))], parents: Seq[Int], labels: Seq[String]) = { 176 | require(edus.size == parents.size && edus.size == labels.size, edus.size + " " + parents.size + " " + labels.size) 177 | val numSents = sentLens.size 
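    // For each sentence: pull out that sentence's EDUs, re-index their parents locally,
    // split each EDU further at syntactic compression points (refineEdusInSentence), and
    // splice the refined EDUs back together using document-level chunk indices.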
178 | val newChunks = new ArrayBuffer[((Int,Int),(Int,Int))] 179 | val newParents = new ArrayBuffer[Int] 180 | val newLabels = new ArrayBuffer[String] 181 | for (sentIdx <- 0 until numSents) { 182 | val sentEduStartIdx = edus.filter(_._1._1 < sentIdx).size 183 | val sentEduEndIdx = edus.filter(_._1._1 <= sentIdx).size 184 | require(sentEduEndIdx > sentEduStartIdx) 185 | val origEdus = edus.slice(sentEduStartIdx, sentEduEndIdx).map(edu => edu._1._2 -> edu._2._2) 186 | val adjustedParents = parents.slice(sentEduStartIdx, sentEduEndIdx).map(parent => if (parent < sentEduStartIdx || parent >= sentEduEndIdx) -1 else (parent - sentEduStartIdx)) 187 | val (sentChunks, sentParents, sentLabels) = refineEdusInSentence(possibleCompressionsEachSent(sentIdx), origEdus, adjustedParents, labels.slice(sentEduStartIdx, sentEduEndIdx)) 188 | val sentOffset = newChunks.size 189 | newChunks ++= sentChunks.map(chunk => (sentIdx -> chunk._1) -> (sentIdx -> chunk._2)) 190 | // Offset all chunks by the number of chunks occurring earlier in the document, same with 191 | newParents ++= sentParents.map(parent => if (parent == -1) -1 else parent + sentOffset) 192 | newLabels ++= sentLabels 193 | } 194 | (newChunks.toSeq, newParents.toSeq, newLabels.toSeq) 195 | } 196 | 197 | def refineEdusInSentence(possibleCompressions: Set[(Int,Int)], edus: Seq[(Int,Int)], parents: Seq[Int], labels: Seq[String]): (Seq[(Int,Int)], Seq[Int], Seq[String]) = { 198 | require(edus.size == parents.size && edus.size == labels.size, edus.size + " " + parents.size + " " + labels.size) 199 | // For each EDU, try to refine it and make a little parent structure 200 | // val newParents = new ArrayBuffer[Int] 201 | // val newLabels = new ArrayBuffer[String] 202 | val newGroupEdus = new ArrayBuffer[Seq[(Int,Int)]] 203 | val newGroupParents = new ArrayBuffer[Seq[Int]] 204 | val newGroupLabels = new ArrayBuffer[Seq[String]] 205 | 206 | var parentOffset = 0 207 | for (eduIdx <- 0 until edus.size) { 208 | val edu = edus(eduIdx) 209 | val possibleCompressionsThisEdu = possibleCompressions.filter(compr => compr._1 >= edu._1 && compr._2 <= edu._2 && compr != edu).map(compr => (compr._1 - edu._1) -> (compr._2 - edu._1)) 210 | if (possibleCompressionsThisEdu.isEmpty) { 211 | newGroupEdus += Seq(edu) 212 | newGroupParents += Seq(-1) 213 | require(eduIdx < labels.size, edus.size + " " + parents.size + " " + labels.size + " " + eduIdx) 214 | newGroupLabels += Seq(labels(eduIdx)) 215 | } else { 216 | val (eduChunks, eduParents, eduLabels) = compressSentence(possibleCompressionsThisEdu, edu._2 - edu._1) 217 | newGroupEdus += eduChunks.map(eduChunk => (eduChunk._1 + edu._1) -> (eduChunk._2 + edu._1)) 218 | newGroupParents += eduParents 219 | newGroupLabels += eduLabels 220 | } 221 | } 222 | require(newGroupEdus.size == parents.size, newGroupEdus.size + " " + parents.size) 223 | val cumEdus = new ArrayBuffer[Int] 224 | var currNum = 0 225 | for (i <- 0 until newGroupEdus.size) { 226 | require(newGroupParents(i).contains(-1), "No root of the subtree! NGE: " + newGroupEdus + "; NGP: " + newGroupParents + "; NGL: " + newGroupLabels) 227 | cumEdus += currNum 228 | currNum += newGroupEdus(i).size 229 | } 230 | // Note that we can have multiple heads if you have (0, 6) and (6, 9) fully spanning an EDU. 231 | val eduHeads = Array.tabulate(newGroupEdus.size)(i => { 232 | (0 until newGroupEdus(i).size).filter(j => newGroupParents(i)(j) == -1).map(cumEdus(i) + _) 233 | }) 234 | // We need to update the parents. 
Within each EDUs sub-parts, we have parents 235 | val newParents = new ArrayBuffer[Int] 236 | val newLabels = new ArrayBuffer[String] 237 | for (i <- 0 until newGroupEdus.size) { 238 | for (j <- 0 until newGroupEdus(i).size) { 239 | if (newGroupParents(i)(j) == -1) { 240 | val parent = parents(i) 241 | if (parent == -1) { 242 | newParents += -1 243 | } else { 244 | newParents += eduHeads(parent).head 245 | } 246 | // When you have multiple heads of an EDU, that means it's fully spanned by possible 247 | // compressions. Handle this appropriately 248 | if (eduHeads(i).size > 1) { 249 | newLabels += "Compression" 250 | } else { 251 | newLabels += labels(i) 252 | } 253 | } else { 254 | newParents += cumEdus(i) + newGroupParents(i)(j) 255 | newLabels += newGroupLabels(i)(j) 256 | } 257 | } 258 | } 259 | val newEdus = newGroupEdus.flatten.toSeq 260 | (newEdus, newParents.toSeq, newLabels) 261 | } 262 | 263 | def main(args: Array[String]) { 264 | LightRunner.initializeOutput(SyntacticCompressor.getClass()) 265 | Logger.logss(", stuff".drop(1).trim) 266 | // Should be 2, -1, 1, -1 267 | val cuts = Seq((0, 1), (0, 3), (0, 2), (5, 6)) 268 | println(findComprParents(cuts).toSeq) 269 | // Should be -1, 3, 0, 2, 0 270 | println(findComprParents(Seq((0, 6), (0, 1), (0, 3), (0, 2), (5, 6))).toSeq) 271 | println(compress(Seq(cuts.toSet), Seq(6))) 272 | println(compress(Seq(cuts.toSet), Seq(7))) 273 | val cuts2 = Seq((0, 3), (1, 2), (5, 6)) 274 | println(compress(Seq(cuts2.toSet), Seq(7))) 275 | println(compress(Seq(cuts.toSet, cuts2.toSet), Seq(7, 7))) 276 | val cuts3 = Seq((0, 6), (6, 9)) 277 | println(compress(Seq(cuts3.toSet), Seq(9))) 278 | val cuts4 = Seq((0, 6), (6, 9), (0, 9)) 279 | println(compress(Seq(cuts4.toSet), Seq(10))) 280 | 281 | // Refining EDUs 282 | println(refineEdusInSentence(Set((0, 1), (6, 7)), Seq((0, 5), (5, 10)), Seq(-1, 0), Seq("None", "Elaboration"))) 283 | LightRunner.finalizeOutput() 284 | 285 | } 286 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/compression/TreeProcessor.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.compression; 2 | 3 | 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | import edu.berkeley.nlp.futile.syntax.Tree; 8 | 9 | public class TreeProcessor { 10 | 11 | public static Tree processTree(Tree tree) { 12 | tree = tree.clone(); 13 | bundleFinalAttributionClause(tree); 14 | bundleCCAndCoordinatedPhrase(tree); 15 | // displayTree(tree); 16 | return tree; 17 | } 18 | 19 | public static void bundleFinalAttributionClause(Tree tree) { 20 | if (tree.isLeaf()) return; 21 | for (Tree child : tree.getChildren()) bundleFinalAttributionClause(child); 22 | 23 | int numChildren = tree.getChildren().size(); 24 | if ((tree.getLabel().equals("S") || tree.getLabel().equals("SINV")) && numChildren >= 3 && tree.getChild(numChildren-3).getLabel().equals("S") && tree.getChild(numChildren-2).getLabel().equals("NP") && tree.getChild(numChildren-1).getLabel().equals("VP")) { 25 | Tree NPTree = tree.getChild(numChildren-2); 26 | Tree VPTree = tree.getChild(numChildren-1); 27 | List VPChildLabels = new ArrayList(); 28 | for (Tree child : VPTree.getChildren()) { 29 | VPChildLabels.add(child.getLabel()); 30 | } 31 | String VPHeadWord = SentenceCompressor.getHeadWord(VPTree); 32 | 33 | Tree NPTreeInObjectPosition = null; 34 | Tree SorSBARTreeInObjectPosition = null; 35 | if (VPTree.getChildren().size() >= 2) { 36 | for 
(int c=1; c newChild = new Tree("SATTR", new ArrayList>()); 51 | newChild.getChildren().add(NPTree); 52 | newChild.getChildren().add(VPTree); 53 | tree.getChildren().remove(numChildren-1); 54 | tree.getChildren().remove(numChildren-2); 55 | tree.getChildren().add(newChild); 56 | } 57 | } 58 | } 59 | if ((tree.getLabel().equals("S") || tree.getLabel().equals("SINV")) && numChildren >= 3 && tree.getChild(numChildren-3).getLabel().equals("S") && tree.getChild(numChildren-2).getLabel().equals("VP") && tree.getChild(numChildren-1).getLabel().equals("NP")) { 60 | Tree NPTree = tree.getChild(numChildren-1); 61 | Tree VPTree = tree.getChild(numChildren-2); 62 | List VPChildLabels = new ArrayList(); 63 | for (Tree child : VPTree.getChildren()) { 64 | VPChildLabels.add(child.getLabel()); 65 | } 66 | String VPHeadWord = SentenceCompressor.getHeadWord(VPTree); 67 | 68 | if (VPTree.getYield().size() == 1) { 69 | if (SentenceCompressor.attributionVPHeads.contains(VPHeadWord)) { 70 | Tree newChild = new Tree("SATTR", new ArrayList>()); 71 | newChild.getChildren().add(VPTree); 72 | newChild.getChildren().add(NPTree); 73 | tree.getChildren().remove(numChildren-1); 74 | tree.getChildren().remove(numChildren-2); 75 | tree.getChildren().add(newChild); 76 | } 77 | } 78 | } 79 | } 80 | 81 | public static void bundleCCAndCoordinatedPhrase(Tree tree) { 82 | if (tree.isLeaf()) return; 83 | for (Tree child : tree.getChildren()) bundleCCAndCoordinatedPhrase(child); 84 | 85 | List> oldChildren = tree.getChildren(); 86 | List> newChildren = new ArrayList>(); 87 | for (int c=0; c child = oldChildren.get(c); 89 | if (child.getLabel().equals("CC") && c < oldChildren.size()-1 && oldChildren.get(c+1).getLabel().equals(tree.getLabel())) { 90 | oldChildren.get(c+1).getChildren().add(0, child); 91 | } else { 92 | newChildren.add(child); 93 | } 94 | } 95 | tree.setChildren(newChildren); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/DepParse.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import java.io.BufferedReader 4 | import scala.collection.JavaConverters.asScalaBufferConverter 5 | import scala.collection.mutable.ArrayBuffer 6 | import edu.berkeley.nlp.entity.ConllDoc 7 | import edu.berkeley.nlp.entity.DepConstTree 8 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 9 | import edu.berkeley.nlp.futile.syntax.Trees.PennTreeReader 10 | import edu.berkeley.nlp.entity.lang.Language 11 | import edu.berkeley.nlp.entity.ConllDocReader 12 | 13 | /** 14 | * Named DepParse for legacy reasons -- actually just a tagged sentence. 
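 *
 * A minimal usage sketch (the CoNLL file path is hypothetical); each DepParse exposes the
 * words and POS tags of one sentence:
 * {{{
 * val sents = DepParse.readFromConllFile("data/example.conll")
 * sents.foreach(s => println(s.getWords.zip(s.getPoss).map(p => p._1 + "/" + p._2).mkString(" ")))
 * }}}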
15 | */ 16 | trait DepParse extends Serializable { 17 | def size: Int 18 | def getWord(idx: Int): String; 19 | def getWords: Array[String]; 20 | def getPos(idx: Int): String; 21 | def getPoss: Array[String]; 22 | } 23 | 24 | object DepParse { 25 | 26 | def readFromFile(file: String): Seq[DepParseRaw] = { 27 | val reader = IOUtils.openInHard(file) 28 | val parses = readFromFile(reader) 29 | reader.close() 30 | parses 31 | } 32 | 33 | def readFromFile(reader: BufferedReader): Seq[DepParseRaw] = { 34 | val lineItr = IOUtils.lineIterator(reader) 35 | val sents = new ArrayBuffer[DepParseRaw]; 36 | val currWords = new ArrayBuffer[String]; 37 | val currPoss = new ArrayBuffer[String]; 38 | val currParents = new ArrayBuffer[Int]; 39 | val currLabels = new ArrayBuffer[String]; 40 | while (lineItr.hasNext) { 41 | val line = lineItr.next 42 | if (line.trim.isEmpty) { 43 | sents += new DepParseRaw(currWords.toArray, currPoss.toArray) 44 | currWords.clear() 45 | currPoss.clear() 46 | currParents.clear() 47 | currLabels.clear() 48 | } else { 49 | val splitLine = line.split("\\s+") 50 | // println(splitLine.toSeq) 51 | require(splitLine.size == 10, "Wrong number of fields in split line " + splitLine.size + "; splits = " + splitLine.toSeq) 52 | currWords += splitLine(1) 53 | currPoss += splitLine(4) 54 | currParents += splitLine(6).toInt - 1 55 | currLabels += splitLine(7) 56 | } 57 | } 58 | if (!currWords.isEmpty) { 59 | sents += new DepParseRaw(currWords.toArray, currPoss.toArray) 60 | } 61 | sents 62 | } 63 | 64 | def readFromConstFile(file: String): Seq[DepParseRaw] = { 65 | val sents = new ArrayBuffer[DepParseRaw]; 66 | val reader = IOUtils.openInHard(file) 67 | val lineItr = IOUtils.lineIterator(reader) 68 | while (lineItr.hasNext) { 69 | val tree = PennTreeReader.parseHard(lineItr.next, false) 70 | val words = tree.getYield().asScala 71 | val pos = tree.getYield().asScala 72 | sents += new DepParseRaw(words.toArray, pos.toArray) 73 | } 74 | reader.close() 75 | sents 76 | } 77 | 78 | def readFromConllFile(file: String): Seq[DepParseConllWrapped] = { 79 | val conllDoc = new ConllDocReader(Language.ENGLISH).readConllDocs(file, -1).head 80 | (0 until conllDoc.numSents).map(i => new DepParseConllWrapped(conllDoc, i)) 81 | } 82 | 83 | def fromDepConstTree(depConstTree: DepConstTree) = { 84 | // Build a dep parse with no labels 85 | val parents = Array.tabulate(depConstTree.size)(i => depConstTree.childParentDepMap(i)) 86 | new DepParseRaw(depConstTree.words.toArray, depConstTree.pos.toArray) 87 | } 88 | } 89 | 90 | @SerialVersionUID(1565751158494431431L) 91 | class DepParseRaw(val words: Array[String], 92 | val poss: Array[String]) extends DepParse { 93 | 94 | def size = words.size 95 | def getWord(idx: Int): String = words(idx) 96 | def getWords = words 97 | def getPos(idx: Int): String = poss(idx) 98 | def getPoss = poss 99 | } 100 | 101 | @SerialVersionUID(4946839606729754922L) 102 | class DepParseConllWrapped(val conllDoc: ConllDoc, 103 | val sentIdx: Int) extends DepParse { 104 | 105 | def size = conllDoc.words(sentIdx).size 106 | def getWord(idx: Int): String = conllDoc.words(sentIdx)(idx) 107 | def getWords = conllDoc.words(sentIdx).toArray 108 | def getPos(idx: Int): String = conllDoc.pos(sentIdx)(idx) 109 | def getPoss = conllDoc.pos(sentIdx).toArray 110 | } 111 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/DepParseDoc.scala: -------------------------------------------------------------------------------- 1 | package 
edu.berkeley.nlp.summ.data 2 | 3 | import scala.collection.mutable.HashSet 4 | import scala.collection.mutable.ArrayBuffer 5 | 6 | trait DepParseDoc extends Serializable { 7 | 8 | def name: String 9 | def doc: Seq[DepParse] 10 | def summary: Seq[DepParse] 11 | 12 | override def toString() = { 13 | toString(Int.MaxValue) 14 | } 15 | 16 | def toString(maxNumSentences: Int) = { 17 | "DOCUMENT:\n" + doc.map(_.getWords.reduce(_ + " " + _)).slice(0, Math.min(maxNumSentences, doc.size)).reduce(_ + "\n" + _) + 18 | "\nSUMMARY:\n" + summary.map(_.getWords.reduce(_ + " " + _)).slice(0, Math.min(maxNumSentences, doc.size)).reduce(_ + "\n" + _) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/DiscourseTree.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import scala.collection.mutable.HashMap 5 | import scala.collection.mutable.HashSet 6 | 7 | import edu.berkeley.nlp.futile.tokenizer.PTBLineLexer 8 | 9 | 10 | case class DiscourseTree(val name: String, 11 | val rootNode: DiscourseNode) extends Serializable { 12 | 13 | // Set parent of each node 14 | DiscourseTree.setParentsHelper(rootNode) 15 | 16 | // Set head of each node 17 | // This is according to the method described in Hirao et al. (2013), 18 | // but it doesn't actually lead to sensible heads. 19 | val leaves = DiscourseTree.getLeaves(rootNode) 20 | def numLeaves = leaves.size 21 | def leafWords = leaves.map(_.leafWords.toSeq) 22 | val leafStatuses = leaves.map(_.label) 23 | private def setHiraoHeadsHelper(node: DiscourseNode) { 24 | // TODO: Actually percolate heads, only break ties if two Ns 25 | var leftmostN = -1 26 | for (idx <- node.span._1 until node.span._2) { 27 | if (leftmostN == -1 && leafStatuses(idx) == "Nucleus") { 28 | leftmostN = idx; 29 | } 30 | } 31 | node.hiraoHead = if (leftmostN == -1) { 32 | node.span._1 33 | } else { 34 | leftmostN 35 | } 36 | for (child <- node.children) { 37 | setHiraoHeadsHelper(child) 38 | } 39 | } 40 | setHiraoHeadsHelper(rootNode) 41 | DiscourseTree.setRecursiveHeadsHelper(rootNode) 42 | 43 | // Determine dependency structure 44 | // This is the method specified in Hirao et al. but it 45 | // doesn't seem to do well; you tend to end up with pretty 46 | // shallow structures and they look a bit weird overall. 47 | val hiraoParents = Array.tabulate(leafStatuses.size)(i => { 48 | val leaf = leaves(i) 49 | if (leafStatuses(i) == DiscourseNode.Nucleus) { 50 | // Find the nearest dominating S, then assign to 51 | // the head that S's parent. This is because only 52 | // S's are in subordinating relations so we need to find 53 | // one in order to establish the hierarchy. 
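      // (so a Nucleus EDU attaches to the head of its nearest Satellite ancestor's parent,
      // and becomes a root (-1) if it has no Satellite ancestor)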
54 | var node = leaf.parent 55 | while (node != null && node.label != DiscourseNode.Satellite) { 56 | node = node.parent 57 | } 58 | if (node == null) { 59 | -1 60 | } else { 61 | node.parent.head 62 | } 63 | } else { 64 | require(leafStatuses(i) == DiscourseNode.Satellite) 65 | // Find the first parent with a head that's not this 66 | var node = leaf.parent 67 | while (node.head == leaf.span._1) { 68 | node = node.parent 69 | } 70 | node.head 71 | } 72 | }) 73 | // "Best" parent method, where your depth ends up being the number of 74 | // Ss between you and the root + 1 75 | private val advancedDepTree = DiscourseTree.setAdvancedParents(rootNode, leafStatuses.size, false) 76 | private val advancedParents = advancedDepTree.parents 77 | private val advancedLabels = advancedDepTree.labels.map(DiscourseTree.getCoarseLabel(_)) 78 | private val apRootIndices = advancedParents.zipWithIndex.filter(_._1 == -1).map(_._2) 79 | private val advancedParentsOneRoot = Array.tabulate(advancedParents.size)(i => if (apRootIndices.contains(i) && apRootIndices(0) != i) apRootIndices(0) else advancedParents(i)) 80 | 81 | private val advancedDepTreeMNLinks = DiscourseTree.setAdvancedParents(rootNode, leafStatuses.size, true) 82 | private val advancedParentsMNLinks = advancedDepTreeMNLinks.parents 83 | private val advancedLabelsMNLinks = advancedDepTreeMNLinks.labels.map(DiscourseTree.getCoarseLabel(_)) 84 | 85 | val parents = advancedParentsOneRoot // Current best 86 | val parentsMultiRoot = advancedParents 87 | val labels = advancedLabels 88 | val childrenMap = DiscourseTree.makeChildrenMap(parents) 89 | 90 | 91 | def getParents(useMultinuclearLinks: Boolean) = { 92 | if (useMultinuclearLinks) advancedDepTreeMNLinks.parents else advancedParentsOneRoot 93 | } 94 | 95 | def getLabels(useMultinuclearLinks: Boolean) = { 96 | if (useMultinuclearLinks) advancedLabelsMNLinks else advancedLabels 97 | } 98 | } 99 | 100 | object DiscourseTree { 101 | 102 | def getCoarseLabel(label: String) = { 103 | if (label == null) { 104 | "root" 105 | } else if (label.contains("-")) { 106 | label.substring(0, label.indexOf("-")) 107 | } else { 108 | label 109 | } 110 | } 111 | 112 | def getLeaves(rootNode: DiscourseNode): ArrayBuffer[DiscourseNode] = { 113 | val leaves = new ArrayBuffer[DiscourseNode] 114 | getLeavesHelper(rootNode, leaves) 115 | } 116 | 117 | def getLeavesHelper(rootNode: DiscourseNode, leaves: ArrayBuffer[DiscourseNode]): ArrayBuffer[DiscourseNode] = { 118 | if (rootNode.isLeaf) { 119 | leaves += rootNode 120 | leaves 121 | } else { 122 | for (child <- rootNode.children) { 123 | getLeavesHelper(child, leaves) 124 | } 125 | leaves 126 | } 127 | } 128 | 129 | def setParentsHelper(node: DiscourseNode) { 130 | for (child <- node.children) { 131 | child.parent = node 132 | setParentsHelper(child) 133 | } 134 | } 135 | 136 | def setRecursiveHeadsHelper(node: DiscourseNode): Int = { 137 | if (node.isLeaf) { 138 | node.head = node.span._1 139 | node.head 140 | } else { 141 | var parentHeadIdx = -1 142 | for (child <- node.children) { 143 | val childHead = setRecursiveHeadsHelper(child) 144 | if (parentHeadIdx == -1 && child.label == DiscourseNode.Nucleus) { 145 | parentHeadIdx = childHead 146 | } 147 | } 148 | require(parentHeadIdx != -1) 149 | node.head = parentHeadIdx 150 | parentHeadIdx 151 | } 152 | } 153 | 154 | def setAdvancedParents(node: DiscourseNode, numLeaves: Int, addMultinuclearLinks: Boolean): DiscourseDependencyTree = { 155 | val depTree = new DiscourseDependencyTree(Array.fill(numLeaves)(-1), new 
Array[String](numLeaves), new ArrayBuffer[(Int,Int)]) 156 | setAdvancedParentsHelper(node, depTree, addMultinuclearLinks) 157 | depTree 158 | } 159 | 160 | /** 161 | * Set parents according to the "advanced" strategy, which by definition 162 | * produces a tree such that the depth of each node is 1 + the number of Ss 163 | * between it and the root. This helper method returns the set of unbound 164 | * nodes at this point in the recursion; ordinarily in parsing this would just 165 | * be one head, but it can be multiple in the case of N => N N rules. 166 | */ 167 | def setAdvancedParentsHelper(node: DiscourseNode, depTree: DiscourseDependencyTree, addMultinuclearLinks: Boolean): Seq[Int] = { 168 | // Leaf node 169 | if (node.children.size == 0) { 170 | Seq(node.span._1) 171 | } else if (node.children.size == 2) { 172 | //////////// 173 | // BINARY // 174 | //////////// 175 | // Identify the satellite (if it exists) and link up all exposed heads from 176 | // the satellite to the nucleus. The rel2par of the satellite encodes the relation. 177 | val leftExposed = setAdvancedParentsHelper(node.children(0), depTree, addMultinuclearLinks) 178 | val rightExposed = setAdvancedParentsHelper(node.children(1), depTree, addMultinuclearLinks) 179 | val ruleType = node.children(0).label + " " + node.children(1).label 180 | // BINUCLEAR 181 | if (ruleType == DiscourseNode.Nucleus + " " + DiscourseNode.Nucleus) { 182 | if (addMultinuclearLinks) { 183 | require(leftExposed.size == 1 && rightExposed.size == 1, "Bad structure!") 184 | depTree.parents(rightExposed(0)) = leftExposed.head 185 | // All labels of multinuclear things start with = 186 | depTree.labels(rightExposed(0)) = "=" + node.children(1).rel2par 187 | leftExposed 188 | } else { 189 | if (node.children(0).rel2par == "Same-Unit" && node.children(1).rel2par == "Same-Unit") { 190 | // There can be multiple if one Same-Unit contains some coordination 191 | for (leftIdx <- leftExposed) { 192 | for (rightIdx <- rightExposed) { 193 | depTree.sameUnitPairs += leftIdx -> rightIdx 194 | } 195 | } 196 | } 197 | leftExposed ++ rightExposed 198 | } 199 | } else if (ruleType == DiscourseNode.Nucleus + " " + DiscourseNode.Satellite) { 200 | // Mononuclear, left-headed 201 | val head = leftExposed.head 202 | // val head = leftExposed.last // This works a bit worse 203 | for (rightIdx <- rightExposed) { 204 | depTree.parents(rightIdx) = head 205 | depTree.labels(rightIdx) = node.children(1).rel2par 206 | } 207 | leftExposed 208 | } else { 209 | // Mononuclear, right-headed 210 | require(ruleType == DiscourseNode.Satellite + " " + DiscourseNode.Nucleus) 211 | val head = rightExposed.head 212 | for (leftIdx <- leftExposed) { 213 | depTree.parents(leftIdx) = head 214 | depTree.labels(leftIdx) = node.children(0).rel2par 215 | } 216 | rightExposed 217 | } 218 | } else { 219 | ////////////////// 220 | // HIGHER ARITY // 221 | ////////////////// 222 | val allChildrenAreNuclei = !node.children.map(_.label == DiscourseNode.Satellite).reduce(_ || _) 223 | val oneChildIsNucleus = node.children.map(_.label).filter(_ == DiscourseNode.Nucleus).size == 1 224 | require(allChildrenAreNuclei || oneChildIsNucleus, "Bad higher-arity: " + node.children.map(_.label).toSeq) 225 | // Higher-arity, all nuclei. 
Can be Same-Unit, mostly List 226 | if (allChildrenAreNuclei) { 227 | val allChildrenExposedIndices = node.children.map(child => setAdvancedParentsHelper(child, depTree, addMultinuclearLinks)) 228 | // Link up all pairs of exposed indices across the children 229 | val allExposed = new ArrayBuffer[Int] 230 | if (addMultinuclearLinks) { 231 | // Add links in sequence a <- b <- c ... (child points to parent here) 232 | // There should only be one exposed index in this case 233 | for (childIdx <- 0 until allChildrenExposedIndices.size) { 234 | require(allChildrenExposedIndices(childIdx).size == 1) 235 | if (childIdx > 0) { 236 | depTree.parents(allChildrenExposedIndices(childIdx).head) = allChildrenExposedIndices(childIdx - 1).head 237 | // All labels of multinuclear things start with = 238 | depTree.labels(allChildrenExposedIndices(childIdx).head) = "=" + node.children(childIdx).rel2par 239 | } 240 | } 241 | allExposed += allChildrenExposedIndices(0).head 242 | } else { 243 | // Pass all children up 244 | for (exposedIndices <- allChildrenExposedIndices) { 245 | allExposed ++= exposedIndices 246 | } 247 | } 248 | allExposed 249 | } else { 250 | // Higher-arity, one nucleus. Typically standard relations that simply have arity > 2 251 | val nucleusIdx = node.children.map(_.label).zipWithIndex.filter(_._1 == DiscourseNode.Nucleus).head._2 252 | val nucleusExposed = setAdvancedParentsHelper(node.children(nucleusIdx), depTree, addMultinuclearLinks) 253 | for (i <- 0 until node.children.size) { 254 | if (i != nucleusIdx) { 255 | val satelliteExposed = setAdvancedParentsHelper(node.children(i), depTree, addMultinuclearLinks) 256 | // val nucleusHead = if (i < nucleusIdx) nucleusExposed.head else nucleusExposed.last // This works a bit worse 257 | val nucleusHead = nucleusExposed.head 258 | for (satelliteIdx <- satelliteExposed) { 259 | depTree.parents(satelliteIdx) = nucleusHead 260 | depTree.labels(satelliteIdx) = node.children(i).rel2par 261 | } 262 | } 263 | } 264 | nucleusExposed 265 | } 266 | } 267 | } 268 | 269 | def makeChildrenMap(parents: Seq[Int]) = { 270 | val childrenMap = new HashMap[Int,ArrayBuffer[Int]] 271 | for (i <- 0 until parents.size) { 272 | childrenMap.put(i, new ArrayBuffer[Int]) 273 | } 274 | for (i <- 0 until parents.size) { 275 | if (parents(i) != -1) { 276 | childrenMap(parents(i)) += i 277 | } 278 | } 279 | childrenMap 280 | } 281 | 282 | def computeDepths(parents: Seq[Int]): Array[Int] = computeDepths(parents, Array.fill(parents.size)(""), false) 283 | 284 | def computeDepths(parents: Seq[Int], labels: Seq[String], flattenMultinuclear: Boolean): Array[Int] = { 285 | val depths = Array.tabulate(parents.size)(i => -1) 286 | var unassignedDepths = true 287 | while (unassignedDepths) { 288 | unassignedDepths = false 289 | for (i <- 0 until parents.size) { 290 | if (depths(i) == -1) { 291 | if (parents(i) == -1) { 292 | depths(i) = 1 293 | } else if (depths(parents(i)) != -1) { 294 | depths(i) = if (flattenMultinuclear && labels(i).startsWith("=")) depths(parents(i)) else depths(parents(i)) + 1 295 | } else { 296 | unassignedDepths = true 297 | } 298 | } 299 | } 300 | } 301 | // for (i <- 0 until depths.size) { 302 | // require(depths(i) == computeDepth(parents, labels, flattenMultinuclear, i)) 303 | // } 304 | depths 305 | } 306 | 307 | def computeDepth(parents: Seq[Int], labels: Seq[String], flattenMultinuclear: Boolean, idx: Int) = { 308 | var node = idx 309 | var depth = 0 310 | // The root of the tree is at depth 1 311 | while (node != -1) { 312 | if 
(!flattenMultinuclear || !labels(node).startsWith("=")) { 313 | depth += 1 314 | } 315 | node = parents(node) 316 | } 317 | depth 318 | } 319 | 320 | def computeNumDominated(parents: Seq[Int], idx: Int) = { 321 | val childrenMap = makeChildrenMap(parents) 322 | val children = childrenMap(idx) 323 | var totalChildren = 0 324 | var newFrontier = new HashSet[Int] ++ children 325 | var frontier = new HashSet[Int] 326 | while (!newFrontier.isEmpty) { 327 | frontier = newFrontier 328 | newFrontier = new HashSet[Int] 329 | for (child <- frontier) { 330 | totalChildren += 1 331 | newFrontier ++= childrenMap(child) 332 | } 333 | } 334 | totalChildren 335 | } 336 | } 337 | 338 | case class DiscourseDependencyTree(val parents: Array[Int], 339 | val labels: Array[String], 340 | val sameUnitPairs: ArrayBuffer[(Int,Int)]) { 341 | } 342 | 343 | case class DiscourseNode(val label: String, 344 | val rel2par: String, 345 | val span: (Int,Int), 346 | val leafText: String, 347 | val children: ArrayBuffer[DiscourseNode]) extends Serializable { 348 | var head: Int = -1 349 | var hiraoHead: Int = -1 350 | var parent: DiscourseNode = null 351 | 352 | // N.B. If anything changes here, should rerun EDUAligner and make sure things aren't worse 353 | val leafTextPreTok = leafText.replace("
<P>
", "") 354 | val leafWordsWhitespace = leafTextPreTok.split("\\s+").filter(_ != "
<P>
") 355 | // Adding the period fixes a bug where "buy-outs" is treated differently sentence-internally than it is 356 | // when it ends an utterance; generally this makes the tokenizer more consistent on fragments 357 | val leafWordsPTBLL = if (leafTextPreTok.split("\\s+").last.contains("-")) { 358 | new PTBLineLexer().tokenize(leafTextPreTok + " .").toArray(Array[String]()).dropRight(1) 359 | } else { 360 | new PTBLineLexer().tokenize(leafTextPreTok).toArray(Array[String]()) 361 | } 362 | // There are still some spaces in some tokens; get rid of these 363 | val leafWordsPTB = if (leafTextPreTok != "") { 364 | leafWordsPTBLL.flatMap(_.split("\\s+")).filter(_ != "
<P>
") 365 | } else { 366 | Array[String]() 367 | } 368 | 369 | // def leafWords = leafWordsWhitespace 370 | def leafWords = leafWordsPTB 371 | // def leafWords = leafWordsPTBLL 372 | 373 | def isLeaf = span._2 - span._1 == 1 374 | 375 | } 376 | 377 | object DiscourseNode { 378 | val Nucleus = "Nucleus" 379 | val Satellite = "Satellite" 380 | val Root = "Root" 381 | } 382 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/DiscourseTreeReader.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import java.io.File 4 | 5 | import scala.collection.JavaConverters.asScalaBufferConverter 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | import edu.berkeley.nlp.entity.coref.MentionPropertyComputer 9 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 10 | import edu.berkeley.nlp.futile.util.Logger 11 | 12 | object DiscourseTreeReader { 13 | 14 | def readDisFile(path: String) = { 15 | val file = new File(path) 16 | val reader = IOUtils.openInHard(file) 17 | val iter = IOUtils.lineIterator(reader) 18 | 19 | val currNodeStack = new ArrayBuffer[DiscourseNode] 20 | while (iter.hasNext) { 21 | var line = iter.next 22 | if (line.trim.startsWith(")")) { 23 | val finishedNode = currNodeStack.remove(currNodeStack.size - 1) 24 | if (!currNodeStack.isEmpty) { 25 | currNodeStack.last.children += finishedNode 26 | } else { 27 | currNodeStack += finishedNode 28 | } 29 | } else { // ( 30 | val text = if (line.contains("(text _!")) { 31 | val textStart = line.indexOf("_!") + 2 32 | val textEnd = line.lastIndexOf("_!") 33 | val result = line.substring(textStart, textEnd) 34 | line = line.substring(0, textStart) + line.substring(textEnd) 35 | result 36 | } else { 37 | "" 38 | } 39 | val lineCut = line.trim.drop(1).trim 40 | val label = lineCut.substring(0, lineCut.indexOf(" ")) 41 | require(label == DiscourseNode.Nucleus || label == DiscourseNode.Satellite || label == DiscourseNode.Root, "Bad label: " + label + "; " + line) 42 | val spanText = readField(line, "span") 43 | var isLeaf = false 44 | val span = if (!spanText.isEmpty) { 45 | isLeaf = false 46 | makeSpan(spanText) 47 | } else { 48 | isLeaf = true 49 | val leafText = readField(line, "leaf") 50 | require(!leafText.isEmpty, line) 51 | (leafText.toInt - 1) -> leafText.toInt 52 | } 53 | val rel2par = readField(line, "rel2par") 54 | val newNode = new DiscourseNode(label, rel2par, span, text, ArrayBuffer[DiscourseNode]()) 55 | if (currNodeStack.isEmpty || !isLeaf) { 56 | currNodeStack += newNode 57 | } else { 58 | currNodeStack.last.children += newNode 59 | } 60 | } 61 | } 62 | require(currNodeStack.size == 1, currNodeStack.size) 63 | reader.close() 64 | new DiscourseTree(file.getName, currNodeStack.head) 65 | } 66 | 67 | def writeDisFile(tree: DiscourseTree) { 68 | writeDisFileHelper(tree.rootNode, "") 69 | Logger.logss("PARENTS: " + tree.parents.toSeq) 70 | } 71 | 72 | def writeDisFileHelper(node: DiscourseNode, currStr: String) { 73 | if (node.children.isEmpty) { 74 | val leafIdx = node.span._1 // node.span._2 is what is used in the dataset 75 | Logger.logss(currStr + "( " + node.label + " (leaf " + leafIdx + ") (rel2par " + node.rel2par + ") (text _!" + node.leafText + "_!) 
) (HEAD = " + node.head + ")") 76 | } else { 77 | val spanStart = node.span._1 // node.span._1 + 1 is what is used in the dataset 78 | Logger.logss(currStr + "( " + node.label + " (span " + spanStart + " " + node.span._2 + ") (rel2par " + node.rel2par + ") (HEAD = " + node.head + ")") 79 | for (child <- node.children) { 80 | writeDisFileHelper(child, currStr + " ") 81 | } 82 | Logger.logss(currStr + ")") 83 | } 84 | } 85 | 86 | def makeSpan(str: String) = { 87 | val strSplit = str.split("\\s+") 88 | require(strSplit.size == 2, strSplit) 89 | (strSplit(0).toInt - 1) -> strSplit(1).toInt 90 | } 91 | 92 | def readField(line: String, label: String) = { 93 | if (line.contains("(" + label + " ")) { 94 | val start = line.indexOf("(" + label + " ") 95 | val end = line.indexOf(")", start) 96 | // Logger.logss(line + " " + start + " " + end) 97 | line.substring(start + label.size + 2, end) 98 | } else { 99 | "" 100 | } 101 | } 102 | 103 | def readAllAlignAndFilter(preprocDocsPath: String, discourseTreesPath: String, mpc: MentionPropertyComputer): Seq[DiscourseDepEx] = { 104 | val allTreeFiles = new File(discourseTreesPath).listFiles.sortBy(_.getName).filter(_.getName.endsWith(".out.dis")) 105 | val allTrees = allTreeFiles.map(file => DiscourseTreeReader.readDisFile(file.getAbsolutePath)) 106 | val allSummDocFiles = new File(preprocDocsPath).listFiles.sortBy(_.getName) 107 | val allSummDocs = allSummDocFiles.map(file => SummDoc.readSummDocNoAbstract(file.getAbsolutePath, mpc, filterSpuriousDocs = false, filterSpuriousSummSents = false)) 108 | require(allTrees.size == allSummDocs.size) 109 | val badFiles = new ArrayBuffer[String] 110 | val exs = new ArrayBuffer[DiscourseDepEx] 111 | for (i <- 0 until allTrees.size) { 112 | require(allTreeFiles(i).getName.dropRight(4) == allSummDocFiles(i).getName, allTreeFiles(i).getName.dropRight(4) + " " + allSummDocFiles(i).getName) 113 | Logger.logss(allSummDocFiles(i).getName) 114 | try { 115 | val alignments = EDUAligner.align(allTrees(i).leafWords, allSummDocs(i).doc) 116 | exs += new DiscourseDepEx(allSummDocs(i), allTrees(i), alignments) 117 | } catch { 118 | case e: Exception => { 119 | Logger.logss(e) 120 | badFiles += allSummDocFiles(i).getName 121 | } 122 | } 123 | } 124 | Logger.logss("Read in " + exs.size + " out of " + allTrees.size + " possible") 125 | Logger.logss(badFiles.size + " bad files: " + badFiles) 126 | exs; 127 | } 128 | 129 | def readHiraoParents(file: String) = { 130 | val lines = IOUtils.readLines(file).asScala 131 | val parents = Array.fill(lines.size)(-100) 132 | for (line <- lines) { 133 | val childParent = line.replace("leaf", "").replace("->", "").split(" ") 134 | parents(childParent(0).toInt - 1) = childParent(1).toInt - 1 135 | } 136 | require(parents.filter(_ == -100).size == 0, "Null lingered") 137 | parents 138 | } 139 | 140 | def main(args: Array[String]) { 141 | val allFiles = new File("data/RSTDiscourse/data/RSTtrees-WSJ-main-1.0/ALL-FILES/").listFiles.filter(_.getName.endsWith(".out.dis")) 142 | val hiraoDir = new File("data/dependency/") 143 | val hiraoFiles = hiraoDir.listFiles.map(_.getName) 144 | var totalEduLen = 0 145 | var numEdus = 0 146 | for (file <- allFiles) { 147 | println("====================") 148 | println(file.getName) 149 | val tree = readDisFile(file.getAbsolutePath) 150 | if (hiraoFiles.contains(file.getName)) { 151 | val referenceParents = readHiraoParents(hiraoDir.getAbsolutePath + "/" + file.getName) 152 | Logger.logss("REFERENCE: " + referenceParents.toSeq) 153 | Logger.logss(" NORMAL: " + 
tree.parents.toSeq) 154 | require(tree.parents.size == referenceParents.size, tree.parents.size + " " + referenceParents.size) 155 | def hammingDist(candParents: Seq[Int]) = (0 until referenceParents.size).map(i => if (candParents(i) != referenceParents(i)) 1 else 0).reduce(_ + _) 156 | Logger.logss("Normal parents: Hamming dist on " + file.getName + ": " + hammingDist(tree.parents) + " / " + referenceParents.size) 157 | } 158 | // writeDisFile(tree) 159 | totalEduLen += tree.leaves.foldLeft(0)(_ + _.leafWords.size) 160 | numEdus += tree.leaves.size 161 | } 162 | Logger.logss(totalEduLen + " / " + numEdus + " = " + (totalEduLen.toDouble / numEdus)) 163 | // val root = readDisFile("data/RSTDiscourse/data/RSTtrees-WSJ-main-1.0/TRAINING/wsj_0600.out.dis") 164 | // writeDisFile(root) 165 | } 166 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/EDUAligner.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import java.io.File 4 | 5 | import scala.collection.mutable.ArrayBuffer 6 | 7 | import edu.berkeley.nlp.entity.coref.MentionPropertyComputer 8 | import edu.berkeley.nlp.entity.coref.NumberGenderComputer 9 | import edu.berkeley.nlp.futile.util.Logger 10 | 11 | object EDUAligner { 12 | 13 | def align(leafWords: Seq[Seq[String]], docSents: Seq[DepParse]) = { 14 | var currSentIdx = 0 15 | var currWordIdx = 0 16 | val leafSpans = new ArrayBuffer[((Int,Int),(Int,Int))] 17 | for (i <- 0 until leafWords.size) { 18 | val start = (currSentIdx, currWordIdx) 19 | val currLen = docSents(currSentIdx).size 20 | require(currWordIdx + leafWords(i).size <= currLen, 21 | currWordIdx + " " + leafWords(i).size + " " + currLen + "\nsent = " + docSents(currSentIdx).getWords.toSeq + ", leaf words = " + leafWords(i).toSeq) 22 | var leafWordIdx = 0 23 | while (leafWordIdx < leafWords(i).size) { 24 | val docWord = docSents(currSentIdx).getWord(currWordIdx) 25 | val leafWord = leafWords(i)(leafWordIdx) 26 | val currWordsEqual = docWord == leafWord 27 | val currWordsEffectivelyEqual = docWord.contains("'") || docWord.contains("`") // Ignore some punc symbols because they're weird 28 | // Spurious period but last thing ended in period, so it was probably added by the tokenizer (like "Ltd. .") 29 | if (!currWordsEqual && docWord == "." && currWordIdx > 0 && docSents(currSentIdx).getWord(currWordIdx - 1).endsWith(".")) { 30 | currWordIdx += 1 31 | if (currWordIdx == docSents(currSentIdx).size) { 32 | currSentIdx += 1 33 | currWordIdx = 0 34 | } 35 | // N.B. 
don't advance leafWordIdx 36 | } else { 37 | require(currWordsEqual || currWordsEffectivelyEqual, docWord + " :: " + leafWord + "\nsent = " + docSents(currSentIdx).getWords.toSeq + ", leaf words = " + leafWords(i).toSeq) 38 | currWordIdx += 1 39 | if (currWordIdx == docSents(currSentIdx).size) { 40 | currSentIdx += 1 41 | currWordIdx = 0 42 | } 43 | leafWordIdx += 1 44 | } 45 | } 46 | val end = if (currWordIdx == 0) { 47 | (currSentIdx - 1, docSents(currSentIdx - 1).size) 48 | } else { 49 | (currSentIdx, currWordIdx) 50 | } 51 | leafSpans += start -> end 52 | // if (currWordIdx == docSents(currSentIdx).size) { 53 | // currSentIdx += 1 54 | // currWordIdx = 0 55 | // } 56 | } 57 | leafSpans 58 | // } 59 | } 60 | 61 | def main(args: Array[String]) { 62 | val allTreeFiles = new File("data/RSTDiscourse/data/RSTtrees-WSJ-main-1.0/ALL-FILES/").listFiles.sortBy(_.getName).filter(_.getName.endsWith(".out.dis")) 63 | val allTrees = allTreeFiles.map(file => DiscourseTreeReader.readDisFile(file.getAbsolutePath)) 64 | // val allSummDocs = new File("data/RSTDiscourse/data/RSTtrees-WSJ-main-1.0/ALL-FILES-PREPROC/").listFiles.sortBy(_.getName)) 65 | val numberGenderComputer = NumberGenderComputer.readBergsmaLinData("data/gender.data"); 66 | val mpc = new MentionPropertyComputer(Some(numberGenderComputer)) 67 | val allSummDocFiles = new File("data/RSTDiscourse/data/RSTtrees-WSJ-main-1.0/ALL-FILES-PROC2/").listFiles.sortBy(_.getName) 68 | val allSummDocs = allSummDocFiles.map(file => SummDoc.readSummDocNoAbstract(file.getAbsolutePath, mpc, filterSpuriousDocs = false, filterSpuriousSummSents = false)) 69 | val summNames = new File("data/RSTDiscourse/data/RSTtrees-WSJ-main-1.0/SUMM-SUBSET-PROC/").listFiles.map(_.getName) 70 | require(allTrees.size == allSummDocs.size) 71 | val badFiles = new ArrayBuffer[String] 72 | for (i <- 0 until allTrees.size) { 73 | require(allTreeFiles(i).getName.dropRight(4) == allSummDocFiles(i).getName, allTreeFiles(i).getName.dropRight(4) + " " + allSummDocFiles(i).getName) 74 | Logger.logss(allSummDocFiles(i).getName) 75 | try { 76 | align(allTrees(i).leafWords, allSummDocs(i).doc) 77 | } catch { 78 | case e: Exception => { 79 | Logger.logss(e) 80 | badFiles += allSummDocFiles(i).getName 81 | } 82 | } 83 | } 84 | Logger.logss(badFiles.size + " bad files: " + badFiles) 85 | val badSummDocs = (badFiles.toSet & summNames.toSet) 86 | Logger.logss(badSummDocs.size + " bad summarized files: " + badSummDocs.toSeq.sorted) 87 | } 88 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/StopwordDict.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import scala.collection.mutable.HashSet 4 | 5 | object StopwordDict { 6 | 7 | // N.B. 
This set was extracted from the RST treebank (train and test) mostly to reproduce 8 | // Hirao's results; it shouldn't really be used for other things 9 | val stopwords = Set("!", "", "#", "$", "%", "&", "'", "''", "'S", "'s", "()", ",", "-", "--", "-owned", ".", "", ":", ";", "<", "?", "", 10 | "A", "A.", "", "AND", "After", "All", "Am", "An", "And", "Any", "As", "At", "BE", "Between", "Both", "But", "By", "Each", 11 | "Few", "For", "From", "Had", "He", "Here", "How", "I", "If", "In", "Is", "It", "Its", "MORE", "More", "Most", "NO", "No", "No.", 12 | "Not", "OF", "Of", "On", "One", "Only", "Or", "Other", "Our", "Over", "She", "So", "Some", "Such", "THE", "Than", "That", "The", 13 | "Their", "Then", "There", "These", "They", "Those", "To", "UPS", "Under", "Until", "WHY", "We", "What", "When", "While", "Why", 14 | "Would", "You", "`It", "``", "a", "about", "above", "after", "again", "again.", "", "against", "all", "am", "an", "and", "any", 15 | "as", "at", "be", "been", "being", "below", "between", "both", "but", "by", "ca", "can", "could", "did", "do", "doing", "down", 16 | "each", "few", "for", "from", "further", "had", "have", "having", "he", "her", "here", "herself", "him", "him.", "", "himself", 17 | "how", "if", "in", "into", "is", "it", "its", "itself", "let", "lets", "me", "more", "most", "must", "my", "n't", "no", "nor", 18 | "not", "of", "off", "on", "one", "ones", "only", "or", "other", "others", "ought", "our", "out", "over", "own", "owned", "owns", 19 | "same", "she", "should", "so", "some", "such", "than", "that", "the", "their", "them", "then", "there", "these", "they", "those", 20 | "through", "to", "too", "under", "until", "up", "very", "we", "were", "what", "when", "where", "which", "while", "who", "whom", 21 | "why", "with", "wo", "would", "you", "your", "yourself", "{", "}") 22 | // Leave $ in there 23 | val stopwordTags = new HashSet[String] ++ Array("CC", "DT", "EX", "IN", "LS", "MD", "PDT", "POS", "PRN", "PRP", "PRP$", "RP", "SYM", 24 | "TO", "WDT", "WP", "WP$", "WRB", ".", ",", "``", "''", ";", ":", "-LRB-", "-RRB-", "-LSB-", "-RSB-", "-LCB-", "-RCB-") 25 | 26 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/SummDoc.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import java.io.File 4 | 5 | import scala.collection.mutable.ArrayBuffer 6 | 7 | import edu.berkeley.nlp.entity.ConllDocReader 8 | import edu.berkeley.nlp.entity.coref.CorefDoc 9 | import edu.berkeley.nlp.entity.coref.CorefDocAssembler 10 | import edu.berkeley.nlp.entity.coref.MentionPropertyComputer 11 | import edu.berkeley.nlp.entity.lang.EnglishCorefLanguagePack 12 | import edu.berkeley.nlp.entity.lang.Language 13 | import edu.berkeley.nlp.futile.util.Counter 14 | import edu.berkeley.nlp.futile.util.Logger 15 | import edu.berkeley.nlp.summ.CorefUtils 16 | 17 | @SerialVersionUID(2350732155930072470L) 18 | case class SummDoc(val name: String, 19 | val corefDoc: CorefDoc, 20 | val doc: Seq[DepParse], 21 | val summSents: Seq[Seq[String]], 22 | val summary: Seq[DepParse]) extends DepParseDoc { 23 | 24 | def getSentMents(sentIdx: Int) = corefDoc.goldMentions.filter(_.sentIdx == sentIdx) 25 | 26 | val sentMentStartIndices = (0 to corefDoc.rawDoc.numSents).map(sentIdx => { 27 | val mentsFollowing = corefDoc.goldMentions.filter(_.sentIdx >= sentIdx) 28 | if (mentsFollowing.isEmpty) corefDoc.goldMentions.size else mentsFollowing.head.mentIdx; 29 | }) 30 
| val entitiesPerSent = (0 until corefDoc.rawDoc.numSents).map(sentIdx => getEntitiesInSpan(sentIdx, 0, doc(sentIdx).size)) 31 | 32 | def getMentionsInSpan(sentIdx: Int, startIdx: Int, endIdx: Int) = { 33 | val mentsInSent = corefDoc.goldMentions.slice(sentMentStartIndices(sentIdx), sentMentStartIndices(sentIdx+1)) 34 | mentsInSent.filter(ment => startIdx <= ment.startIdx && ment.endIdx <= endIdx) 35 | } 36 | 37 | def getEntitiesInSpan(sentIdx: Int, startIdx: Int, endIdx: Int) = { 38 | getMentionsInSpan(sentIdx, startIdx, endIdx).map(ment => corefDoc.goldClustering.getClusterIdx(ment.mentIdx)).distinct.sorted 39 | } 40 | // 41 | val entitiesBySize = corefDoc.goldClustering.clusters.zipWithIndex.map(clusterAndIdx => clusterAndIdx._1.size -> clusterAndIdx._2).sortBy(- _._1).map(_._2) 42 | val entitySemanticTypes = (0 until corefDoc.goldClustering.clusters.size).map(clusterIdx => { 43 | val cluster = corefDoc.goldClustering.clusters(clusterIdx) 44 | val types = new Counter[String] 45 | cluster.foreach(mentIdx => types.incrementCount(corefDoc.goldMentions(mentIdx).nerString, 1.0)) 46 | types.removeKey("O") 47 | types.argMax() 48 | }) 49 | } 50 | 51 | object SummDoc { 52 | 53 | def makeSummDoc(name: String, corefDoc: CorefDoc, summSents: Seq[Seq[String]]): SummDoc = { 54 | val doc = (0 until corefDoc.rawDoc.numSents).map(i => { 55 | // DepParse.fromDepConstTree(corefDoc.rawDoc.trees(i)) 56 | new DepParseConllWrapped(corefDoc.rawDoc, i) 57 | }) 58 | val summary = (0 until summSents.size).map(i => { 59 | new DepParseRaw(summSents(i).toArray, Array.tabulate(summSents(i).size)(i => "-")) 60 | }) 61 | new SummDoc(name, corefDoc, doc, summSents, summary) 62 | } 63 | 64 | def readSummDocNoAbstract(docPath: String, 65 | mentionPropertyComputer: MentionPropertyComputer, 66 | filterSpuriousDocs: Boolean = false, 67 | filterSpuriousSummSents: Boolean = false) = { 68 | val doc = new ConllDocReader(Language.ENGLISH).readConllDocs(docPath)(0) 69 | val assembler = new CorefDocAssembler(new EnglishCorefLanguagePack, true) 70 | val corefDoc = assembler.createCorefDoc(doc, mentionPropertyComputer) 71 | makeSummDoc(new File(docPath).getName, corefDoc, Seq[Seq[String]]()) 72 | } 73 | 74 | def readSummDocs(docsPath: String, 75 | abstractsPath: String, 76 | abstractsAreConll: Boolean, 77 | maxFiles: Int = -1, 78 | mentionPropertyComputer: MentionPropertyComputer, 79 | filterSpuriousDocs: Boolean = false, 80 | filterSpuriousSummSents: Boolean = false, 81 | docFilter: (SummDoc => Boolean)) = { 82 | val docFiles = new File(docsPath).listFiles.sorted 83 | val processedDocs = new ArrayBuffer[SummDoc] 84 | var docIdx = 0 85 | var numSummSentsFiltered = 0 86 | var iter = docFiles.iterator 87 | val assembler = new CorefDocAssembler(new EnglishCorefLanguagePack, true) 88 | var filteredByDocFilter = 0 89 | var filteredBySpurious = 0 90 | // XXX 91 | // val otherMpc = new singledoc.coref.MentionPropertyComputer(Some(HorribleCorefMunger.reverseMungeNumberGenderComputer(mentionPropertyComputer.maybeNumGendComputer.get))) 92 | // val maybeCorefModel = Some(GUtil.load("models/coref-onto.ser.gz").asInstanceOf[PairwiseScorer]) 93 | // XXX 94 | while ((maxFiles == -1 || docIdx < maxFiles) && iter.hasNext) { 95 | val docFile = iter.next 96 | if (docIdx % 500 == 0) { 97 | Logger.logss(" Processing document " + docIdx + "; kept " + processedDocs.size + " so far") 98 | } 99 | docIdx += 1 100 | val fileName = docFile.getName() 101 | val abstractFile = new File(abstractsPath + "/" + fileName) 102 | require(abstractFile.exists(), "Couldn't 
find abstract file at " + abstractsPath + "/" + fileName) 103 | // val doc = DepParse.readFromFile(docFile.getAbsolutePath()) 104 | val doc = new ConllDocReader(Language.ENGLISH).readConllDocs(docFile.getAbsolutePath)(0) 105 | val summRaw = if (abstractsAreConll) { 106 | DepParse.readFromConllFile(abstractFile.getAbsolutePath()) 107 | } else { 108 | DepParse.readFromConstFile(abstractFile.getAbsolutePath()) 109 | } 110 | val summ = if (filterSpuriousSummSents) { 111 | summRaw.filter(sentParse => !SummaryAligner.identifySpuriousSentence(sentParse.getWords)) 112 | } else { 113 | summRaw 114 | } 115 | numSummSentsFiltered += (summRaw.size - summ.size) 116 | if (summ.size == 0 || (filterSpuriousDocs && SummaryAligner.identifySpuriousSummary(summRaw(0).getWords))) { 117 | // val corefDoc = assembler.createCorefDoc(doc, mentionPropertyComputer) 118 | // val summDoc = makeSummDoc(fileName, corefDoc, summ.map(_.getWords.toSeq), docCompressor) 119 | // if (docFilter(summDoc)) { 120 | // filteredBySpurious += 1 121 | // } else { 122 | // filteredByDocFilter += 1 123 | // } 124 | // Do nothing 125 | } else { 126 | val rawCorefDoc = assembler.createCorefDoc(doc, mentionPropertyComputer) 127 | val rawGoldMents = rawCorefDoc.goldMentions.map(CorefUtils.remapMentionType(_)) 128 | val corefDoc = new CorefDoc(rawCorefDoc.rawDoc, rawGoldMents, rawCorefDoc.goldClustering, rawGoldMents) 129 | val summDoc = makeSummDoc(fileName, corefDoc, summ.map(_.getWords.toSeq)) 130 | if (docFilter(summDoc)) { 131 | processedDocs += summDoc 132 | } else { 133 | filteredByDocFilter += 1 134 | } 135 | } 136 | } 137 | Logger.logss("Read docs from " + docsPath + " and abstracts from " + abstractsPath + "; filtered " + 138 | (docFiles.size - processedDocs.size) + "/" + docFiles.size + " docs (" + filteredByDocFilter + 139 | " from doc filter (len, etc.) and " + filteredBySpurious + " from spurious detection (article, etc.)) and " + numSummSentsFiltered + " sentences") 140 | processedDocs 141 | } 142 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/data/SummaryAligner.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.data 2 | 3 | import scala.collection.JavaConverters._ 4 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 5 | import edu.berkeley.nlp.futile.util.Counter 6 | import edu.berkeley.nlp.futile.util.Logger 7 | import edu.berkeley.nlp.futile.LightRunner 8 | import edu.berkeley.nlp.futile.EditDistance.EditOp 9 | import edu.berkeley.nlp.summ.RougeComputer 10 | 11 | object SummaryAligner { 12 | 13 | def alignDocAndSummary(depParseDoc: DepParseDoc, verbose: Boolean): Array[Int] = { 14 | alignDocAndSummary(depParseDoc.doc.map(_.getWords.toSeq), depParseDoc.summary.map(_.getWords.toSeq), depParseDoc.name, verbose) 15 | } 16 | 17 | def getEditDistanceWithDeletion(docSent: Seq[String], summSent: Seq[String]) = { 18 | edu.berkeley.nlp.futile.EditDistance.editDistance(docSent.map(_.toLowerCase).asJava, summSent.map(_.toLowerCase).asJava, 1.0, 0.0, 1.0, true) 19 | } 20 | 21 | def getEditDistanceOpsWithDeletion(docSent: Seq[String], summSent: Seq[String]): Array[EditOp] = { 22 | edu.berkeley.nlp.futile.EditDistance.getEditDistanceOperations(docSent.map(_.toLowerCase).asJava, summSent.map(_.toLowerCase).asJava, 1.0, 0.0, 1.0, true) 23 | } 24 | 25 | /** 26 | * Produces a one-to-many alignment between the doc and the summary (i.e. 
each summary 27 | * sentence is aligned to at most one document sentence). Length is the length of the 28 | * summary (so summary is the target). 29 | */ 30 | def alignDocAndSummary(docSentences: Seq[Seq[String]], summary: Seq[Seq[String]], docName: String = "", verbose: Boolean = false): Array[Int] = { 31 | val alignments = Array.tabulate(summary.size)(i => -1) 32 | var numSentsAligned = 0 33 | for (summSentIdx <- 0 until summary.size) { 34 | var someAlignment = false; 35 | var bestAlignmentEd = Int.MaxValue 36 | var bestAlignmentChoice = -1 37 | for (docSentIdx <- 0 until docSentences.size) { 38 | val ed = edu.berkeley.nlp.futile.EditDistance.editDistance(docSentences(docSentIdx).asJava, summary(summSentIdx).asJava, 1.0, 0.0, 1.0, true) 39 | if (ed < bestAlignmentEd) { 40 | bestAlignmentEd = ed.toInt 41 | bestAlignmentChoice = docSentIdx 42 | } 43 | } 44 | if (verbose) { 45 | Logger.logss(summSentIdx + ": best alignment choice = " + bestAlignmentChoice + ", ed = " + bestAlignmentEd) 46 | } 47 | if (bestAlignmentEd < summary(summSentIdx).size * 0.5) { 48 | someAlignment = true 49 | alignments(summSentIdx) = bestAlignmentChoice 50 | if (verbose) { 51 | Logger.logss("ALIGNED: " + summSentIdx + " -> " + bestAlignmentChoice) 52 | Logger.logss("S1: " + docSentences(bestAlignmentChoice).reduce(_ + " " + _)) 53 | Logger.logss("S2: " + summary(summSentIdx).reduce(_ + " " + _)) 54 | Logger.logss("ED: " + bestAlignmentEd) 55 | } 56 | } 57 | if (!someAlignment) { 58 | // Logger.logss("UNALIGNED: " + summSentIdx + " " + summary(summSentIdx).reduce(_ + " " + _)); 59 | } else { 60 | numSentsAligned += 1 61 | } 62 | } 63 | if (verbose && numSentsAligned > 0) { 64 | Logger.logss(">1 alignment for " + docName) 65 | } 66 | alignments 67 | } 68 | 69 | def alignDocAndSummaryOracleRouge(depParseDoc: DepParseDoc, summSizeCutoff: Int): Array[Int] = { 70 | alignDocAndSummaryOracleRouge(depParseDoc.doc.map(_.getWords.toSeq), depParseDoc.summary.map(_.getWords.toSeq), summSizeCutoff) 71 | } 72 | 73 | def alignDocAndSummaryOracleRouge(docSentences: Seq[Seq[String]], summary: Seq[Seq[String]], summSizeCutoff: Int): Array[Int] = { 74 | val choices = Array.tabulate(summary.size)(i => { 75 | if (summary(i).size >= summSizeCutoff) { 76 | val summSent = summary(i) 77 | var bestRougeSourceIdx = -1 78 | var bestRougeScore = 0 79 | for (j <- 0 until docSentences.size) { 80 | var score = RougeComputer.computeRouge2SuffStats(Seq(docSentences(j)), Seq(summSent))._1 81 | if (score > bestRougeScore) { 82 | bestRougeSourceIdx = j 83 | bestRougeScore = score 84 | } 85 | } 86 | bestRougeSourceIdx 87 | } else { 88 | -1 89 | } 90 | }) 91 | choices 92 | } 93 | 94 | def identifySpuriousSummary(firstSentence: Seq[String]) = { 95 | val firstWords = firstSentence.slice(0, Math.min(10, firstSentence.size)).map(_.toLowerCase) 96 | val firstWordsNoPlurals = firstWords.map(word => if (word.endsWith("s")) word.dropRight(1) else word) 97 | firstWordsNoPlurals.contains("letter") || firstWordsNoPlurals.contains("article") || firstWordsNoPlurals.contains("column") || 98 | firstWordsNoPlurals.contains("review") || firstWordsNoPlurals.contains("interview") || firstWordsNoPlurals.contains("profile") 99 | } 100 | 101 | def identifySpuriousSentence(sentence: Seq[String]) = { 102 | val sentenceLcNoPlurals = sentence.map(_.toLowerCase).map(word => if (word.endsWith("s")) word.dropRight(1) else word) 103 | // sentenceLcNoPlurals.contains("photo") 104 | sentenceLcNoPlurals.contains("photo") || sentenceLcNoPlurals.contains("photo.") 105 | } 106 | } 
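// ---- Added usage sketch; not part of the original source ----
// A minimal illustration of how the greedy edit-distance alignment above might be
// inspected. The DepParseDoc passed in is assumed to have been constructed elsewhere
// (e.g. by the data-loading code in this package). alignDocAndSummary returns one
// entry per summary sentence: the index of the best-matching document sentence, or
// -1 when nothing aligned closely enough.
object SummaryAlignerUsageSketch {
  def printAlignments(doc: DepParseDoc) {
    val alignments = SummaryAligner.alignDocAndSummary(doc, false)
    for (summIdx <- 0 until alignments.size) {
      val docSentIdx = alignments(summIdx)
      if (docSentIdx == -1) {
        println("Summary sentence " + summIdx + ": unaligned")
      } else {
        println("Summary sentence " + summIdx + " -> document sentence " + docSentIdx)
      }
    }
  }
}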
-------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/demo/TreeJPanel.java: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.demo; 2 | 3 | import java.awt.Color; 4 | import java.awt.FontMetrics; 5 | import java.awt.Graphics2D; 6 | import java.awt.geom.Line2D; 7 | import java.awt.geom.Point2D; 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | 11 | import javax.swing.JPanel; 12 | 13 | import edu.berkeley.nlp.futile.fig.basic.Pair; 14 | import edu.berkeley.nlp.futile.syntax.Tree; 15 | 16 | /** 17 | * Class for displaying a Tree. 18 | * 19 | * @author Dan Klein 20 | */ 21 | 22 | public class TreeJPanel extends JPanel { 23 | 24 | private static final long serialVersionUID = 1L; 25 | 26 | static double sisterSkip = 2.5; 27 | static double parentSkip = 0.5; 28 | static double belowLineSkip = 0.1; 29 | static double aboveLineSkip = 0.1; 30 | 31 | static boolean drawTrianglesAtBottom = true; 32 | 33 | public static String nodeToString(Tree t) { 34 | if (t == null) { 35 | return " "; 36 | } 37 | Object l = t.getLabel(); 38 | if (l == null) { 39 | return " "; 40 | } 41 | String str = (String) l; 42 | if (str == null) { 43 | return " "; 44 | } 45 | return str; 46 | } 47 | 48 | static class WidthResult { 49 | double width = 0.0; 50 | double nodeTab = 0.0; 51 | double nodeCenter = 0.0; 52 | double childTab = 0.0; 53 | } 54 | 55 | public static WidthResult widthResult(Tree tree, FontMetrics fM) { 56 | WidthResult wr = new WidthResult(); 57 | if (tree == null) { 58 | wr.width = 0.0; 59 | wr.nodeTab = 0.0; 60 | wr.nodeCenter = 0.0; 61 | wr.childTab = 0.0; 62 | return wr; 63 | } 64 | double local = fM.stringWidth(nodeToString(tree)); 65 | if (tree.isLeaf()) { 66 | wr.width = local; 67 | wr.nodeTab = 0.0; 68 | wr.nodeCenter = local / 2.0; 69 | wr.childTab = 0.0; 70 | return wr; 71 | } 72 | double sub = 0.0; 73 | double nodeCenter = 0.0; 74 | double childTab = 0.0; 75 | for (int i = 0; i < tree.getChildren().size(); i++) { 76 | WidthResult subWR = widthResult(tree.getChildren() 77 | .get(i), fM); 78 | if (i == 0) { 79 | nodeCenter += (sub + subWR.nodeCenter) / 2.0; 80 | } 81 | if (i == tree.getChildren().size() - 1) { 82 | nodeCenter += (sub + subWR.nodeCenter) / 2.0; 83 | } 84 | sub += subWR.width; 85 | if (i < tree.getChildren().size() - 1) { 86 | sub += sisterSkip * fM.stringWidth(" "); 87 | } 88 | } 89 | double localLeft = local / 2.0; 90 | double subLeft = nodeCenter; 91 | double totalLeft = Math.max(localLeft, subLeft); 92 | double localRight = local / 2.0; 93 | double subRight = sub - nodeCenter; 94 | double totalRight = Math.max(localRight, subRight); 95 | wr.width = totalLeft + totalRight; 96 | wr.childTab = totalLeft - subLeft; 97 | wr.nodeTab = totalLeft - localLeft; 98 | wr.nodeCenter = nodeCenter + wr.childTab; 99 | return wr; 100 | } 101 | 102 | /** 103 | * GREG ADDED 104 | */ 105 | public static boolean isContained(int spanStart, int spanEnd, List> spans) { 106 | for (Pair span : spans) { 107 | if (span.getFirst().intValue() <= spanStart && spanEnd <= span.getSecond().intValue()) { 108 | return true; 109 | } 110 | } 111 | return false; 112 | } 113 | 114 | 115 | public static List getContainedIndices(int spanStart, int spanEnd, List> subSpans) { 116 | List containedIndices = new ArrayList(); 117 | for (int i = 0; i < subSpans.size(); i++) { 118 | Pair span = subSpans.get(i); 119 | if (spanStart <= span.getFirst().intValue() && 
span.getSecond().intValue() <= spanEnd) { 120 | containedIndices.add(i); 121 | } 122 | } 123 | return containedIndices; 124 | } 125 | 126 | /** 127 | * GREG'S VERSION 128 | */ 129 | public static Pair paintTreeModified(Tree t, Point2D start, Graphics2D g2, FontMetrics fM, int charOffset, Color color, 130 | List> compressedSpans, List> pronReplacedSpans, List pronReplacements) { 131 | g2.setColor(color); 132 | String nodeStr = nodeToString(t); 133 | double nodeWidth = fM.stringWidth(nodeStr); 134 | double nodeHeight = fM.getHeight(); 135 | double nodeAscent = fM.getAscent(); 136 | WidthResult wr = widthResult(t, fM); 137 | double treeWidth = wr.width; 138 | double nodeTab = wr.nodeTab; 139 | double childTab = wr.childTab; 140 | double nodeCenter = wr.nodeCenter; 141 | g2.drawString(nodeStr, (float) (nodeTab + start.getX()), (float) (start.getY() + nodeAscent)); 142 | if (t.isLeaf()) { 143 | return Pair.makePair(new Double(nodeWidth), new Double(nodeHeight)); 144 | } 145 | double layerMultiplier = (1.0 + belowLineSkip + aboveLineSkip + parentSkip); 146 | double layerHeight = nodeHeight * layerMultiplier; 147 | double childStartX = start.getX() + childTab; 148 | double childStartY = start.getY() + layerHeight; 149 | double lineStartX = start.getX() + nodeCenter; 150 | double lineStartY = start.getY() + nodeHeight * (1.0 + belowLineSkip); 151 | double lineEndY = lineStartY + nodeHeight * parentSkip; 152 | // recursively draw children 153 | int currCharOffset = charOffset; 154 | double maxChildHeight = 0; 155 | for (int i = 0; i < t.getChildren().size(); i++) { 156 | Tree child = t.getChildren().get(i); 157 | int childSize = 0; 158 | for (String leaf : child.getYield()) { 159 | childSize += leaf.length() + 1; 160 | } 161 | boolean isChildCompressed = isContained(currCharOffset, currCharOffset + childSize - 1, compressedSpans); 162 | Color childColor = (isChildCompressed ? Color.gray : child.isLeaf() ? SummarizerDemo.alignmentColor() : color); 163 | Pair childWidthHeight = paintTreeModified(child, new Point2D.Double(childStartX, childStartY), g2, fM, currCharOffset, childColor, 164 | compressedSpans, pronReplacedSpans, pronReplacements); 165 | double cWidth = childWidthHeight.getFirst().doubleValue(); 166 | maxChildHeight = Math.max(maxChildHeight, childWidthHeight.getSecond().doubleValue()); 167 | // draw connectors 168 | wr = widthResult(child, fM); 169 | g2.setColor((isChildCompressed ? 
Color.gray : color)); 170 | if (drawTrianglesAtBottom && child.isLeaf()) { 171 | double triangleStartX = childStartX; 172 | double triangleEndX = childStartX + wr.width; 173 | int[] xs = {(int)lineStartX, (int)triangleStartX, (int)triangleEndX, (int)lineStartX}; 174 | int[] ys = {(int)lineStartY, (int)lineEndY, (int)lineEndY, (int)lineStartY}; 175 | g2.drawPolyline(xs, ys, 4); 176 | } else { 177 | double lineEndX = childStartX + wr.nodeCenter; 178 | g2.draw(new Line2D.Double(lineStartX, lineStartY, lineEndX, 179 | lineEndY)); 180 | } 181 | childStartX += cWidth; 182 | if (i < t.getChildren().size() - 1) { 183 | childStartX += sisterSkip * fM.stringWidth(" "); 184 | } 185 | currCharOffset += childSize; 186 | } 187 | return Pair.makePair(new Double(treeWidth), new Double(lineEndY - start.getY() + maxChildHeight)); 188 | } 189 | } -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/preprocess/EDUSegmenterSemiMarkov.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.preprocess 2 | 3 | import scala.collection.JavaConverters._ 4 | import scala.collection.mutable.ArrayBuffer 5 | import scala.collection.mutable.HashMap 6 | import edu.berkeley.nlp.futile.LightRunner 7 | import edu.berkeley.nlp.futile.classify.ClassifyUtils 8 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 9 | import edu.berkeley.nlp.futile.fig.basic.Indexer 10 | import edu.berkeley.nlp.futile.syntax.Trees.PennTreeRenderer 11 | import edu.berkeley.nlp.futile.util.Logger 12 | import edu.berkeley.nlp.entity.ConllDoc 13 | import edu.berkeley.nlp.summ.data.DiscourseDepEx 14 | import edu.berkeley.nlp.summ.LikelihoodAndGradientComputer 15 | 16 | @SerialVersionUID(1L) 17 | class EDUSegmenterSemiMarkovFeaturizer(val featIdx: Indexer[String], 18 | val wrappedFeaturizer: EDUSegmenterFeaturizer) extends Serializable { 19 | 20 | private def maybeAdd(feats: ArrayBuffer[Int], addToIndexer: Boolean, feat: String) { 21 | if (addToIndexer) { 22 | feats += featIdx.getIndex(feat) 23 | } else { 24 | val idx = featIdx.indexOf(feat) 25 | if (idx != -1) { 26 | feats += idx 27 | } 28 | } 29 | } 30 | 31 | def extractFeaturesCached(ex: DiscourseDepEx, addToIndexer: Boolean): Array[Array[Array[Array[Int]]]] = { 32 | // Only featurize sentence-internal boundaries (sentence boundaries are trivially EDU segments) 33 | if (ex.cachedEduSemiMarkovFeatures == null) { 34 | ex.cachedEduSemiMarkovFeatures = extractFeatures(ex.conllDoc, addToIndexer) 35 | } 36 | ex.cachedEduSemiMarkovFeatures 37 | } 38 | 39 | def extractFeatures(doc: ConllDoc, addToIndexer: Boolean): Array[Array[Array[Array[Int]]]] = { 40 | // Features on boundaries from the binary version 41 | val wrappedFeats = wrappedFeaturizer.extractFeatures(doc, addToIndexer) 42 | Array.tabulate(doc.numSents)(sentIdx => { 43 | Array.tabulate(doc.words(sentIdx).size)(startIdx => { 44 | Array.tabulate(doc.words(sentIdx).size + 1)(endIdx => { 45 | if (endIdx > startIdx) { 46 | extractFeatures(doc, sentIdx, startIdx, endIdx, wrappedFeats, addToIndexer) 47 | } else { 48 | Array[Int]() 49 | } 50 | }) 51 | }) 52 | }) 53 | } 54 | 55 | private def extractFeatures(doc: ConllDoc, sentIdx: Int, startIdx: Int, endIdx: Int, wrappedFeats: Array[Array[Array[Int]]], addToIndexer: Boolean): Array[Int] = { 56 | val feats = new ArrayBuffer[Int] 57 | def add(feat: String) = maybeAdd(feats, addToIndexer, feat) 58 | val bucketedLen = wrappedFeaturizer.bucket(endIdx - startIdx) 59 | if (startIdx > 
0) { 60 | // Don't add these features because they'll be the end of some other span by definition 61 | // feats ++= wrappedFeaturizer.extractFeatures(doc, sentIdx, endIdx - 1, addToIndexer) 62 | } else { 63 | add("StartSent,Len=" + bucketedLen) 64 | } 65 | if (endIdx < doc.words(sentIdx).size - 1) { 66 | feats ++= wrappedFeats(sentIdx)(endIdx - 1) 67 | } else { 68 | add("EndSent,Len=" + bucketedLen) 69 | if (startIdx == 0) { 70 | add("WholeSent,Len=" + bucketedLen) 71 | } 72 | } 73 | if (endIdx - startIdx == 1) { 74 | add("SingleWord=" + doc.words(sentIdx)(startIdx)) 75 | } else if (endIdx - startIdx == 2) { 76 | add("TwoWords,First=" + doc.words(sentIdx)(startIdx)) 77 | add("TwoWords,Second=" + doc.words(sentIdx)(startIdx+1)) 78 | } 79 | // Look at first and last, also the context words 80 | val beforePos = if (startIdx == 0) "" else doc.pos(sentIdx)(startIdx-1) 81 | val firstPos = doc.pos(sentIdx)(startIdx) 82 | val lastPos = doc.pos(sentIdx)(endIdx - 1) 83 | val afterPos = if (endIdx == doc.pos(sentIdx).size) "" else doc.pos(sentIdx)(endIdx) 84 | add("FirstLastPOS=" + firstPos + "-" + afterPos) 85 | add("BeforeAfterPOS=" + beforePos + "-" + afterPos) 86 | // add("BFLAPOS=" + beforePos + "-" + firstPos + "-" + lastPos + "-" + afterPos) 87 | var dominatingConstituents = doc.trees(sentIdx).getAllConstituentTypes(startIdx, endIdx) 88 | if (dominatingConstituents.isEmpty) { 89 | dominatingConstituents = Seq("None") 90 | } else { 91 | // None of these dependency features seem to help 92 | // // We have a valid span, fire features on dependencies 93 | // val headIdx = doc.trees(sentIdx).getSpanHead(startIdx, endIdx) 94 | //// add("HeadWord=" + doc.words(sentIdx)(headIdx)) 95 | // add("HeadPos=" + doc.pos(sentIdx)(headIdx)) 96 | // val parentIdx = doc.trees(sentIdx).childParentDepMap(headIdx) 97 | // if (parentIdx == -1) { 98 | // add("Parent=ROOT") 99 | // } else { 100 | //// add("ParentWord=" + doc.words(sentIdx)(parentIdx)) 101 | // add("ParentPos=" + doc.pos(sentIdx)(parentIdx)) 102 | // add("ParentDist=" + Math.signum(parentIdx - headIdx) + ":" + wrappedFeaturizer.bucket(parentIdx - headIdx)) 103 | // } 104 | } 105 | // Fire features on constituent labels (or None if it isn't a constituent) 106 | for (constituent <- dominatingConstituents) { 107 | add("DominatingConstituent=" + constituent) 108 | add("DominatingConstituentLength=" + constituent + "-" + bucketedLen) 109 | add("DominatingConstituentBefore=" + constituent + "-" + beforePos) 110 | add("DominatingConstituentAfter=" + constituent + "-" + afterPos) 111 | // This makes it way slower and doesn't help 112 | // val maybeParent = doc.trees(sentIdx).getParent(startIdx, endIdx) 113 | // if (!maybeParent.isDefined) { 114 | // add("DominatingParent=None") 115 | // } else { 116 | // val (parent, childIdx) = maybeParent.get 117 | // val childrenStr = (0 until parent.getChildren().size).map(i => (if (childIdx == i) ">" else "") + parent.getChildren().get(i).getLabel()).foldLeft("")(_ + " " + _) 118 | //// Logger.logss(parent.getLabel() + " ->" + childrenStr) 119 | //// add("DominatingRule=" + parent.getLabel() + " ->" + childrenStr) 120 | // add("DominatingParent=" + parent.getLabel() + " -> " + constituent) 121 | // } 122 | } 123 | feats.toArray 124 | } 125 | } 126 | 127 | @SerialVersionUID(1L) 128 | class EDUSegmenterSemiMarkovComputer(val featurizer: EDUSegmenterSemiMarkovFeaturizer, 129 | val wholeSpanLossScale: Double = 4.0) extends LikelihoodAndGradientComputer[DiscourseDepEx] with Serializable { 130 | 131 | def 
getInitialWeights(initialWeightsScale: Double): Array[Double] = Array.tabulate(featurizer.featIdx.size)(i => 0.0) 132 | 133 | def accumulateGradientAndComputeObjective(ex: DiscourseDepEx, weights: Array[Double], gradient: Array[Double]): Double = { 134 | val (predSegs, predScore) = decode(ex, weights, 1.0); 135 | // val recomputedPredScore = scoreParse(ex, weights, predParents, 1.0) 136 | val goldSegs = ex.goldEduSpans 137 | val goldScore = scoreSegmentation(ex, weights, goldSegs, 1.0) 138 | // Logger.logss("Pred score: " + predScore + ", recomputed pred score: " + recomputedPredScore + ", gold score: " + goldScore) 139 | for (sentIdx <- 0 until ex.conllDoc.numSents) { 140 | for (startIdx <- 0 until ex.conllDoc.words(sentIdx).size) { 141 | for (endIdx <- startIdx + 1 to ex.conllDoc.words(sentIdx).size) { 142 | val seg = startIdx -> endIdx 143 | val increment = (if (goldSegs(sentIdx).contains(seg)) 1 else 0) + (if (predSegs(sentIdx).contains(seg)) -1 else 0) 144 | if (increment != 0) { 145 | val feats = ex.cachedEduSemiMarkovFeatures(sentIdx)(startIdx)(endIdx) 146 | for (feat <- feats) { 147 | gradient(feat) += increment 148 | } 149 | } 150 | } 151 | } 152 | } 153 | predScore - goldScore 154 | } 155 | 156 | def computeObjective(ex: DiscourseDepEx, weights: Array[Double]): Double = accumulateGradientAndComputeObjective(ex, weights, Array.tabulate(weights.size)(i => 0.0)) 157 | 158 | def decode(ex: DiscourseDepEx, weights: Array[Double]): Array[Array[Boolean]] = { 159 | EDUSegmenterSemiMarkov.convertSegsToBooleanArray(decode(ex, weights, 0)._1) 160 | } 161 | 162 | def decode(ex: DiscourseDepEx, weights: Array[Double], lossWeight: Double): (Array[Seq[(Int,Int)]], Double) = { 163 | val feats = featurizer.extractFeaturesCached(ex, false) 164 | var cumScore = 0.0 165 | val allPreds = Array.tabulate(ex.conllDoc.numSents)(sentIdx => { 166 | val result = decodeSentence(feats(sentIdx), ex.conllDoc.words(sentIdx).size, weights, lossWeight, Some(ex.goldEduSpans(sentIdx))) 167 | cumScore += result._2 168 | result._1 169 | }) 170 | (allPreds, cumScore) 171 | } 172 | 173 | def decodeSentence(feats: Array[Array[Array[Int]]], sentLen: Int, weights: Array[Double], lossWeight: Double, goldSpans: Option[Seq[(Int,Int)]]): (Seq[(Int,Int)], Double) = { 174 | val chart = Array.tabulate(sentLen + 1)(i => if (i == 0) 0.0 else Double.NegativeInfinity) 175 | val backptrs = Array.tabulate(sentLen + 1)(i => -1) 176 | for (endIdx <- 1 to sentLen) { 177 | for (startIdx <- 0 until endIdx) { 178 | val isGold = if (goldSpans.isDefined) goldSpans.get.contains(startIdx -> endIdx) else false 179 | val lossScore = if (!isGold) { 180 | if (startIdx == 0 && endIdx == sentLen) { 181 | // lossWeight 182 | lossWeight * wholeSpanLossScale 183 | } else { 184 | lossWeight 185 | } 186 | } else { 187 | 0.0 188 | } 189 | val score = ClassifyUtils.scoreIndexedFeats(feats(startIdx)(endIdx), weights) + lossScore 190 | if (chart(startIdx) + score > chart(endIdx)) { 191 | backptrs(endIdx) = startIdx 192 | chart(endIdx) = chart(startIdx) + score 193 | } 194 | } 195 | } 196 | // Recover the gold derivation 197 | val pairs = new ArrayBuffer[(Int,Int)] 198 | var ptr = sentLen 199 | while (ptr > 0) { 200 | pairs.prepend(backptrs(ptr) -> ptr) 201 | ptr = backptrs(ptr) 202 | } 203 | (pairs.toSeq, chart(sentLen)) 204 | } 205 | 206 | private def scoreSegmentation(ex: DiscourseDepEx, weights: Array[Double], segmentation: Seq[Seq[(Int,Int)]], lossWeight: Double) = { 207 | var score = 0.0 208 | val feats = featurizer.extractFeaturesCached(ex, false) 209 | 
for (sentIdx <- 0 until ex.conllDoc.numSents) { 210 | for (segment <- segmentation(sentIdx)) { 211 | val isGold = ex.goldEduSpans(sentIdx).contains(segment) 212 | score += ClassifyUtils.scoreIndexedFeats(feats(sentIdx)(segment._1)(segment._2), weights) + (if (!isGold) lossWeight else 0.0) 213 | } 214 | } 215 | score 216 | } 217 | } 218 | 219 | @SerialVersionUID(1L) 220 | class EDUSegmenterSemiMarkov(val computer: EDUSegmenterSemiMarkovComputer, 221 | val weights: Array[Double]) extends EDUSegmenter { 222 | 223 | def decode(ex: DiscourseDepEx) = computer.decode(ex, weights) 224 | 225 | def decode(doc: ConllDoc) = { 226 | val feats = computer.featurizer.extractFeatures(doc, false) 227 | val result = Array.tabulate(feats.size)(i => { 228 | computer.decodeSentence(feats(i), doc.words(i).size, weights, 0.0, None)._1 229 | }) 230 | EDUSegmenterSemiMarkov.convertSegsToBooleanArray(result) 231 | } 232 | } 233 | 234 | object EDUSegmenterSemiMarkov { 235 | 236 | def convertSegsToBooleanArray(segments: Seq[Seq[(Int,Int)]]): Array[Array[Boolean]] = { 237 | Array.tabulate(segments.size)(i => { 238 | val seq = segments(i) 239 | val starts = seq.map(_._1) 240 | Array.tabulate(seq.last._2 - 1)(i => starts.contains(i+1)) 241 | }) 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /src/main/scala/edu/berkeley/nlp/summ/preprocess/StandoffAnnotationHandler.scala: -------------------------------------------------------------------------------- 1 | package edu.berkeley.nlp.summ.preprocess 2 | 3 | import edu.berkeley.nlp.entity.ConllDocReader 4 | import edu.berkeley.nlp.entity.lang.Language 5 | import edu.berkeley.nlp.futile.fig.basic.IOUtils 6 | import scala.collection.JavaConverters._ 7 | import edu.berkeley.nlp.futile.util.Logger 8 | import edu.berkeley.nlp.futile.util.Counter 9 | import scala.collection.mutable.ArrayBuffer 10 | import scala.collection.mutable.HashMap 11 | import edu.berkeley.nlp.entity.ConllDoc 12 | import java.io.File 13 | import edu.berkeley.nlp.entity.ConllDocWriter 14 | import edu.berkeley.nlp.futile.LightRunner 15 | 16 | /** 17 | * Handles combining standoff annotations with data files from the New York Times corpus. 
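* (Added documentation summarizing the main method below: with readAnnotations = true,
* the standoff annotation files in inputDir are combined with the raw NYT XML in
* rawXMLDir to reconstitute full CoNLL-format documents in outputDir; with
* readAnnotations = false, CoNLL files in inputDir are aligned against the raw XML to
* produce the standoff annotations in the first place.)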
18 | */ 19 | object StandoffAnnotationHandler { 20 | 21 | val readAnnotations = true 22 | val inputDir = "" 23 | val outputDir = "" 24 | val rawXMLDir = "" 25 | 26 | val maxNumFiles = Int.MaxValue 27 | 28 | val tagName = "block class=\"full_text\"" 29 | 30 | def main(args: Array[String]) { 31 | LightRunner.initializeOutput(StandoffAnnotationHandler.getClass()) 32 | LightRunner.populateScala(StandoffAnnotationHandler.getClass(), args) 33 | if (inputDir.isEmpty || outputDir.isEmpty || rawXMLDir.isEmpty) { 34 | Logger.logss("Need all three of inputDir, outputDir, and rawXMLDir to be specified!") 35 | } 36 | // Reconstituting documents 37 | if (readAnnotations) { 38 | var numReadFailures = 0 39 | val corefNerFilesToAlign = new File(inputDir).listFiles() 40 | for (corefNerFile <- corefNerFilesToAlign) { 41 | val rawXMLPath = rawXMLDir + "/" + corefNerFile.getName() + ".xml" 42 | Logger.logss("Dealing with " + corefNerFile.getAbsolutePath) 43 | val maybeReconstitutedConllDoc = reverseStandoffAnnotations(corefNerFile.getAbsolutePath, rawXMLPath) 44 | if (!maybeReconstitutedConllDoc.isDefined) { 45 | numReadFailures += 1 46 | } else if (maybeReconstitutedConllDoc.get.docID == "") { 47 | val path = outputDir + "/" + corefNerFile.getName 48 | val outWriter = IOUtils.openOutHard(path) 49 | Logger.logss("Wrote to " + path) 50 | outWriter.close 51 | } else { 52 | val path = outputDir + "/" + maybeReconstitutedConllDoc.get.docID 53 | val outWriter = IOUtils.openOutHard(path) 54 | ConllDocWriter.writeDoc(outWriter, maybeReconstitutedConllDoc.get) 55 | Logger.logss("Wrote to " + path) 56 | outWriter.close 57 | } 58 | } 59 | Logger.logss("Total read failures: " + numReadFailures + " / " + corefNerFilesToAlign.size) 60 | } else { 61 | // Producing the standoff annotations in the first place 62 | val rawCorefNerFilesToAlign = new File(inputDir).listFiles() 63 | val corefNerFilesToAlign = rawCorefNerFilesToAlign.slice(0, Math.min(maxNumFiles, rawCorefNerFilesToAlign.size)) 64 | for (corefNerFile <- corefNerFilesToAlign) { 65 | val rawXMLPath = rawXMLDir + "/" + corefNerFile.getName() + ".xml" 66 | Logger.logss("=========================") 67 | Logger.logss("Aligning " + corefNerFile.getAbsolutePath + " " + rawXMLPath) 68 | val standoffConllDoc = makeStandoffAnnotations(corefNerFile.getAbsolutePath, rawXMLPath, tagName) 69 | // This happens if we have a blank file 70 | if (standoffConllDoc.docID.isEmpty) { 71 | // Write nothing 72 | val path = outputDir + "/" + corefNerFile.getName 73 | val outWriter = IOUtils.openOutHard(path) 74 | outWriter.close 75 | } else { 76 | val path = outputDir + "/" + standoffConllDoc.docID 77 | val outWriter = IOUtils.openOutHard(path) 78 | ConllDocWriter.writeDoc(outWriter, standoffConllDoc) 79 | Logger.logss("Wrote to " + path) 80 | outWriter.close 81 | } 82 | } 83 | } 84 | LightRunner.finalizeOutput() 85 | } 86 | 87 | val tokenizationMapping = new HashMap[String,String] ++ Seq("(" -> "-LRB-", ")" -> "-RRB-", 88 | "[" -> "-LSB-", "]" -> "-RSB-", 89 | "{" -> "-LCB-", "}" -> "-RCB-", 90 | "&" -> "&") 91 | val reverseTokenizationMapping = new HashMap[String,String] 92 | for (entry <- tokenizationMapping) { 93 | reverseTokenizationMapping += entry._2 -> entry._1 94 | } 95 | 96 | // def extractWords(alignments: ArrayBuffer[ArrayBuffer[(Int,Int,Int)]], docLines: Seq[String]) = { 97 | // alignments.map(_.map(alignment => docLines(alignment._1).substring(alignment._2, alignment._3))) 98 | // } 99 | 100 | // Use :: as delimiter 101 | val delimiter = "::" 102 | val alignmentRe = ("[0-9]+" 
+ delimiter + "[0-9]+" + delimiter + "[0-9]+").r 103 | 104 | def extractWord(alignment: String, docLines: Seq[String]) = { 105 | if (alignmentRe.findFirstIn(alignment) != None) { 106 | val alignmentSplit = alignment.split(delimiter).map(_.toInt) 107 | val word = docLines(alignmentSplit(0)).substring(alignmentSplit(1), alignmentSplit(2)) 108 | if (tokenizationMapping.contains(word)) { 109 | tokenizationMapping(word) 110 | } else { 111 | word 112 | } 113 | } else { 114 | alignment 115 | } 116 | } 117 | 118 | def extractWords(alignments: ArrayBuffer[ArrayBuffer[String]], docLines: Seq[String]) = { 119 | alignments.map(_.map(alignment => extractWord(alignment, docLines))) 120 | } 121 | 122 | def fetchCurrWord(corefNerDoc: ConllDoc, sentIdx: Int, wordIdx: Int) = { 123 | val word = corefNerDoc.words(sentIdx)(wordIdx) 124 | if (reverseTokenizationMapping.contains(word)) { 125 | reverseTokenizationMapping(word) 126 | } else { 127 | word 128 | } 129 | } 130 | 131 | def doMatch(lineChar: Char, docChar: Char) = { 132 | lineChar == docChar || Character.toLowerCase(lineChar) == Character.toLowerCase(docChar) || 133 | ((lineChar == ';' || lineChar == '.' || lineChar == '?' || lineChar == '!') && (docChar == ';' || docChar == '.' || docChar == '?' || docChar == '!')) || 134 | ((lineChar == ''' || lineChar == '`') && (docChar == ''' || docChar == '`')) 135 | 136 | } 137 | 138 | def makeStandoffAnnotations(corefNerFile: String, rawXMLFile: String, tagName: String) = { 139 | val corefNerDoc = new ConllDocReader(Language.ENGLISH).readConllDocs(corefNerFile).head 140 | val rawXMLLines = IOUtils.readLinesHard(rawXMLFile).asScala 141 | // Only do the alignment if there are a nonzero number of words 142 | val newConllDoc = if (corefNerDoc.words.size > 0) { 143 | val alignments = align(corefNerDoc, rawXMLLines, tagName: String) 144 | new ConllDoc(corefNerDoc.docID, 145 | corefNerDoc.docPartNo, 146 | alignments, 147 | corefNerDoc.pos, 148 | corefNerDoc.trees, 149 | corefNerDoc.nerChunks, 150 | corefNerDoc.corefChunks, 151 | corefNerDoc.speakers) 152 | } else { 153 | corefNerDoc 154 | } 155 | newConllDoc 156 | } 157 | 158 | def align(corefNerDoc: ConllDoc, rawXMLLines: Seq[String], tagName: String) = { 159 | val closeTagName = tagName.substring(0, if (tagName.contains(" ")) tagName.indexOf(" ") else tagName.size) 160 | var sentPtr = 0 161 | var wordPtr = 0 162 | var charPtr = 0 163 | var currWord = fetchCurrWord(corefNerDoc, sentPtr, wordPtr) 164 | val alignments = new ArrayBuffer[ArrayBuffer[String]] 165 | alignments += new ArrayBuffer[String] 166 | var inText = false 167 | var badCharactersSkipped = 0 168 | // Iterate through the XML file and advance through the CoNLL document simultaneously 169 | for (lineIdx <- 0 until rawXMLLines.size) { 170 | val line = rawXMLLines(lineIdx) 171 | if (line.contains("")) { 172 | inText = false 173 | } 174 | if (inText) { 175 | // Logger.logss(sentPtr) 176 | if (!line.contains("

") || !line.contains("

")) { 177 | Logger.logss("ANOMALOUS LINE") 178 | Logger.logss(line) 179 | inText = false 180 | } 181 | val lineStart = line.indexOf("

") + 3 182 | val relevantLine = if (line.contains("

") && line.contains("

")) line.substring(lineStart, line.indexOf("

")) else line 183 | // Logger.logss("CURRENT LINE: " + relevantLine) 184 | var linePtr = 0 185 | while (inText && linePtr < relevantLine.size) { 186 | // Logger.logss("Checking line index " + linePtr + "; looking for " + corefNerDoc.words(sentPtr)(wordPtr)(charPtr) + " and was " + relevantLine.substring(linePtr, linePtr+1)) 187 | if (doMatch(relevantLine.charAt(linePtr), currWord(charPtr))) { 188 | // Logger.logss("Matching " + relevantLine.substring(linePtr, linePtr+1)) 189 | charPtr += 1 190 | } else if (!Character.isWhitespace(relevantLine.charAt(linePtr))) { 191 | badCharactersSkipped += 1 192 | // Logger.logss("Bad character skipped! " + relevantLine.charAt(linePtr) + " " + currWord(charPtr)) 193 | } 194 | if (charPtr == currWord.size) { 195 | // Store the alignment 196 | val alignment = lineIdx + delimiter + (lineStart + linePtr - currWord.size + 1) + delimiter + (lineStart + linePtr + 1) 197 | if (linePtr - currWord.size + 1 >= 0 && relevantLine.substring(linePtr - currWord.size + 1, linePtr + 1) == currWord) { 198 | alignments.last += alignment 199 | } else { 200 | // Logger.logss("Mismatch: :" + currWord + ": :" + relevantLine.substring(linePtr - currWord.size + 1, linePtr + 1) + ":") 201 | alignments.last += currWord 202 | } 203 | // Logger.logss("Storing alignment: " + currWord + " " + alignment) 204 | wordPtr += 1 205 | charPtr = 0 206 | } 207 | if (wordPtr == corefNerDoc.words(sentPtr).size) { 208 | alignments += new ArrayBuffer[String] 209 | sentPtr += 1 210 | // Logger.logss("NEW SENTENCE: " + corefNerDoc.words(sentPtr).reduce( _ + " " + _)) 211 | wordPtr = 0 212 | } 213 | // If we're all done 214 | if (sentPtr >= corefNerDoc.words.size) { 215 | inText = false 216 | } else { 217 | // Otherwise, possibly update the current word we're targeting 218 | currWord = fetchCurrWord(corefNerDoc, sentPtr, wordPtr) 219 | } 220 | linePtr += 1 221 | } 222 | } else if (line.contains("<" + tagName + ">")) { 223 | inText = true 224 | } 225 | } 226 | // Drop the last entry if it's empty, which it will be if everything is consumed. 227 | if (alignments.last.isEmpty && alignments.size >= 2) { 228 | alignments.remove(alignments.size - 1) 229 | } 230 | // Check if all lines are the right length. If not, we should just dump the whole file. 231 | var catastrophicError = false 232 | if (alignments.size != corefNerDoc.words.size) { 233 | Logger.logss("Wrong number of lines! " + alignments.size + " " + corefNerDoc.words.size) 234 | catastrophicError = true 235 | } 236 | for (lineIdx <- 0 until Math.min(alignments.size, corefNerDoc.words.size)) { 237 | val alignmentLen = alignments(lineIdx).size 238 | val corefNerDocLen = corefNerDoc.words(lineIdx).size 239 | if (alignmentLen != corefNerDocLen) { 240 | Logger.logss("Wrong number of words in line " + lineIdx + "! " + alignments(lineIdx).size + " " + corefNerDoc.words(lineIdx).size) 241 | // Primarily useful for repairing sentence-final punctuation that was added during preprocessing 242 | if (alignmentLen == corefNerDocLen - 1 && (alignmentLen == 0 || extractWord(alignments(lineIdx)(alignmentLen - 1), rawXMLLines) == corefNerDoc.words(lineIdx)(alignmentLen - 1))) { 243 | Logger.logss("Repaired!") 244 | alignments(lineIdx) += corefNerDoc.words(lineIdx).last 245 | } else { 246 | catastrophicError = true 247 | } 248 | } 249 | } 250 | // If we've had a catastrophic error, just use the raw words rather than standoff annotations. 
251 | val finalAlignments = if (catastrophicError) { 252 | Logger.logss("XXXXXXXX CATASTROPHIC ERROR XXXXXXXX") 253 | new ArrayBuffer[ArrayBuffer[String]] ++ corefNerDoc.words.map(new ArrayBuffer[String] ++ _) 254 | } else { 255 | alignments 256 | } 257 | // Verify that the words are the same 258 | val reextractedWords = extractWords(finalAlignments, rawXMLLines) 259 | if (reextractedWords.size != corefNerDoc.words.size) { 260 | Logger.logss("UNEXPLAINED mismatch! " + reextractedWords.size + " " + corefNerDoc.words.size) 261 | } 262 | var someMistake = false 263 | var standoffCounter = 0 264 | var missCounter = 0 265 | for (lineIdx <- 0 until finalAlignments.size) { 266 | if (reextractedWords(lineIdx).size != corefNerDoc.words(lineIdx).size) { 267 | Logger.logss("UNEXPLAINED mismatch in line " + lineIdx + ": " + reextractedWords(lineIdx).size + " " + corefNerDoc.words(lineIdx).size) 268 | } 269 | for (wordIdx <- 0 until finalAlignments(lineIdx).size) { 270 | if (alignmentRe.findFirstIn(finalAlignments(lineIdx)(wordIdx)) != None) { 271 | standoffCounter += 1 272 | } else { 273 | missCounter += 1 274 | } 275 | if (reextractedWords(lineIdx)(wordIdx) != corefNerDoc.words(lineIdx)(wordIdx)) { 276 | Logger.logss("Mismatched word! " + reextractedWords(lineIdx)(wordIdx) + " " + corefNerDoc.words(lineIdx)(wordIdx)) 277 | someMistake = true 278 | } 279 | } 280 | } 281 | Logger.logss(badCharactersSkipped + " bad characters skipped") 282 | Logger.logss(standoffCounter + " standoffs, " + missCounter + " raw strings") 283 | if (!someMistake) { 284 | Logger.logss("******** ALIGNED CORRECTLY! ********") 285 | } else { 286 | Logger.logss("XXXXXXXX ALIGNED INCORRECTLY! XXXXXXXX") 287 | } 288 | finalAlignments 289 | } 290 | 291 | def reverseStandoffAnnotations(corefNerFile: String, rawXMLFile: String): Option[ConllDoc] = { 292 | var reconstitutedConllDoc: Option[ConllDoc] = None 293 | try { 294 | val standoffConllDoc = new ConllDocReader(Language.ENGLISH).readConllDocs(corefNerFile).head 295 | val rawXMLLines = IOUtils.readLinesHard(rawXMLFile).asScala 296 | val words = extractWords(new ArrayBuffer[ArrayBuffer[String]] ++ standoffConllDoc.words.map(new ArrayBuffer[String] ++ _), rawXMLLines) 297 | reconstitutedConllDoc = Some(new ConllDoc(standoffConllDoc.docID, 298 | standoffConllDoc.docPartNo, 299 | words, 300 | standoffConllDoc.pos, 301 | standoffConllDoc.trees, 302 | standoffConllDoc.nerChunks, 303 | standoffConllDoc.corefChunks, 304 | standoffConllDoc.speakers)) 305 | } catch { 306 | case e: Exception => Logger.logss("Exception reading standoffs for " + corefNerFile + " " + rawXMLFile + ": " + e.toString) 307 | } 308 | reconstitutedConllDoc 309 | } 310 | 311 | // Only needed once, to convert PTB parses to CoNLL docs 312 | // def convertParsesToDocs 313 | } -------------------------------------------------------------------------------- /src/main/scala/mstparser/Feature.java: -------------------------------------------------------------------------------- 1 | /////////////////////////////////////////////////////////////////////////////// 2 | // Copyright (C) 2007 University of Texas at Austin and (C) 2005 3 | // University of Pennsylvania and Copyright (C) 2002, 2003 University 4 | // of Massachusetts Amherst, Department of Computer Science. 5 | // 6 | // This software is licensed under the terms of the Common Public 7 | // License, Version 1.0 or (at your option) any subsequent version. 
8 | // 9 | // The license is approved by the Open Source Initiative, and is 10 | // available from their website at http://www.opensource.org. 11 | /////////////////////////////////////////////////////////////////////////////// 12 | 13 | package mstparser; 14 | 15 | import gnu.trove.TLinkableAdaptor; 16 | 17 | /** 18 | * A simple class holding a feature index and value that can be used in a TLinkedList. 19 | * 20 | *

<p> 21 | * Created: Sat Nov 10 15:25:10 2001 22 | * </p>
23 | * 24 | * @author Jason Baldridge 25 | * @version $Id: TLinkedList.java,v 1.5 2005/03/26 17:52:56 ericdf Exp $ 26 | * @see mstparser.FeatureVector 27 | */ 28 | 29 | public final class Feature extends TLinkableAdaptor { 30 | 31 | public int index; 32 | 33 | public double value; 34 | 35 | public Feature(int i, double v) { 36 | index = i; 37 | value = v; 38 | } 39 | 40 | @Override 41 | public final Feature clone() { 42 | return new Feature(index, value); 43 | } 44 | 45 | public final Feature negation() { 46 | return new Feature(index, -value); 47 | } 48 | 49 | @Override 50 | public final String toString() { 51 | return index + "=" + value; 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/main/scala/mstparser/FeatureVector.java: -------------------------------------------------------------------------------- 1 | /////////////////////////////////////////////////////////////////////////////// 2 | // Copyright (C) 2007 University of Texas at Austin and (C) 2005 3 | // University of Pennsylvania and Copyright (C) 2002, 2003 University 4 | // of Massachusetts Amherst, Department of Computer Science. 5 | // 6 | // This software is licensed under the terms of the Common Public 7 | // License, Version 1.0 or (at your option) any subsequent version. 8 | // 9 | // The license is approved by the Open Source Initiative, and is 10 | // available from their website at http://www.opensource.org. 11 | /////////////////////////////////////////////////////////////////////////////// 12 | 13 | package mstparser; 14 | 15 | import gnu.trove.TIntArrayList; 16 | import gnu.trove.TIntDoubleHashMap; 17 | import gnu.trove.TLinkedList; 18 | 19 | import java.util.ListIterator; 20 | 21 | /** 22 | * A FeatureVector that can hold up to two FeatureVector instances inside it, 23 | * which allows for a very quick concatenation operation. 24 | * 25 | *

<p> 26 | * Also, in order to avoid copies, the second of these internal FeatureVector instances can 27 | * be negated, so that it has the effect of subtracting any values rather than adding them. 28 | * 29 | * <p>
30 | * Created: Sat Nov 10 15:25:10 2001 31 | * </p>
32 | * 33 | * @author Jason Baldridge 34 | * @version $Id: FeatureVector.java 137 2013-09-10 09:33:47Z wyldfire $ 35 | * @see mstparser.Feature 36 | */ 37 | public final class FeatureVector extends TLinkedList { 38 | private FeatureVector subfv1 = null; 39 | 40 | private FeatureVector subfv2 = null; 41 | 42 | private boolean negateSecondSubFV = false; 43 | 44 | public FeatureVector() { 45 | } 46 | 47 | public FeatureVector(FeatureVector fv1) { 48 | subfv1 = fv1; 49 | } 50 | 51 | public FeatureVector(FeatureVector fv1, FeatureVector fv2) { 52 | subfv1 = fv1; 53 | subfv2 = fv2; 54 | } 55 | 56 | public FeatureVector(FeatureVector fv1, FeatureVector fv2, boolean negSecond) { 57 | subfv1 = fv1; 58 | subfv2 = fv2; 59 | negateSecondSubFV = negSecond; 60 | } 61 | 62 | public FeatureVector(int[] keys) { 63 | for (int i = 0; i < keys.length; i++) 64 | add(new Feature(keys[i], 1.0)); 65 | } 66 | 67 | public void add(int index, double value) { 68 | add(new Feature(index, value)); 69 | } 70 | 71 | public int[] keys() { 72 | TIntArrayList keys = new TIntArrayList(); 73 | addKeysToList(keys); 74 | return keys.toNativeArray(); 75 | } 76 | 77 | private void addKeysToList(TIntArrayList keys) { 78 | if (null != subfv1) { 79 | subfv1.addKeysToList(keys); 80 | 81 | if (null != subfv2) 82 | subfv2.addKeysToList(keys); 83 | } 84 | 85 | ListIterator it = listIterator(); 86 | while (it.hasNext()) 87 | keys.add(((Feature) it.next()).index); 88 | 89 | } 90 | 91 | public final FeatureVector cat(FeatureVector fl2) { 92 | return new FeatureVector(this, fl2); 93 | } 94 | 95 | // fv1 - fv2 96 | public FeatureVector getDistVector(FeatureVector fl2) { 97 | return new FeatureVector(this, fl2, true); 98 | } 99 | 100 | public final double getScore(double[] parameters) { 101 | return getScore(parameters, false); 102 | } 103 | 104 | private final double getScore(double[] parameters, boolean negate) { 105 | double score = 0.0; 106 | 107 | if (null != subfv1) { 108 | score += subfv1.getScore(parameters, negate); 109 | 110 | if (null != subfv2) { 111 | if (negate) { 112 | score += subfv2.getScore(parameters, !negateSecondSubFV); 113 | } else { 114 | score += subfv2.getScore(parameters, negateSecondSubFV); 115 | } 116 | } 117 | } 118 | 119 | ListIterator it = listIterator(); 120 | 121 | if (negate) { 122 | while (it.hasNext()) { 123 | Feature f = (Feature) it.next(); 124 | score -= parameters[f.index] * f.value; 125 | } 126 | } else { 127 | while (it.hasNext()) { 128 | Feature f = (Feature) it.next(); 129 | score += parameters[f.index] * f.value; 130 | } 131 | } 132 | 133 | return score; 134 | } 135 | 136 | public void update(double[] parameters, double[] total, double alpha_k, double upd) { 137 | update(parameters, total, alpha_k, upd, false); 138 | } 139 | 140 | private final void update(double[] parameters, double[] total, double alpha_k, double upd, 141 | boolean negate) { 142 | 143 | if (null != subfv1) { 144 | subfv1.update(parameters, total, alpha_k, upd, negate); 145 | 146 | if (null != subfv2) { 147 | if (negate) { 148 | subfv2.update(parameters, total, alpha_k, upd, !negateSecondSubFV); 149 | } else { 150 | subfv2.update(parameters, total, alpha_k, upd, negateSecondSubFV); 151 | } 152 | } 153 | } 154 | 155 | ListIterator it = listIterator(); 156 | 157 | if (negate) { 158 | while (it.hasNext()) { 159 | Feature f = (Feature) it.next(); 160 | parameters[f.index] -= alpha_k * f.value; 161 | total[f.index] -= upd * alpha_k * f.value; 162 | } 163 | } else { 164 | while (it.hasNext()) { 165 | Feature f = (Feature) it.next(); 166 
| parameters[f.index] += alpha_k * f.value; 167 | total[f.index] += upd * alpha_k * f.value; 168 | } 169 | } 170 | 171 | } 172 | 173 | public double dotProduct(FeatureVector fl2) { 174 | 175 | TIntDoubleHashMap hm1 = new TIntDoubleHashMap(this.size()); 176 | addFeaturesToMap(hm1, false); 177 | hm1.compact(); 178 | 179 | TIntDoubleHashMap hm2 = new TIntDoubleHashMap(fl2.size()); 180 | fl2.addFeaturesToMap(hm2, false); 181 | hm2.compact(); 182 | 183 | int[] keys = hm1.keys(); 184 | 185 | double result = 0.0; 186 | for (int i = 0; i < keys.length; i++) 187 | result += hm1.get(keys[i]) * hm2.get(keys[i]); 188 | 189 | return result; 190 | 191 | } 192 | 193 | private void addFeaturesToMap(TIntDoubleHashMap map, boolean negate) { 194 | if (null != subfv1) { 195 | subfv1.addFeaturesToMap(map, negate); 196 | 197 | if (null != subfv2) { 198 | if (negate) { 199 | subfv2.addFeaturesToMap(map, !negateSecondSubFV); 200 | } else { 201 | subfv2.addFeaturesToMap(map, negateSecondSubFV); 202 | } 203 | } 204 | } 205 | 206 | ListIterator it = listIterator(); 207 | if (negate) { 208 | while (it.hasNext()) { 209 | Feature f = (Feature) it.next(); 210 | if (!map.adjustValue(f.index, -f.value)) 211 | map.put(f.index, -f.value); 212 | } 213 | } else { 214 | while (it.hasNext()) { 215 | Feature f = (Feature) it.next(); 216 | if (!map.adjustValue(f.index, f.value)) 217 | map.put(f.index, f.value); 218 | } 219 | } 220 | } 221 | 222 | @Override 223 | public final String toString() { 224 | StringBuilder sb = new StringBuilder(); 225 | toString(sb); 226 | return sb.toString(); 227 | } 228 | 229 | private final void toString(StringBuilder sb) { 230 | if (null != subfv1) { 231 | subfv1.toString(sb); 232 | 233 | if (null != subfv2) 234 | subfv2.toString(sb); 235 | } 236 | ListIterator it = listIterator(); 237 | while (it.hasNext()) 238 | sb.append(it.next().toString()).append(' '); 239 | } 240 | 241 | } 242 | -------------------------------------------------------------------------------- /src/main/scala/mstparser/KsummKBestParseForest.java: -------------------------------------------------------------------------------- 1 | package mstparser; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import edu.berkeley.nlp.futile.fig.basic.Pair; 7 | 8 | public class KsummKBestParseForest { 9 | 10 | public static int rootType; 11 | 12 | public ParseForestItem[][][][][] chart; 13 | 14 | private int start, end; 15 | 16 | private int K; 17 | 18 | public KsummKBestParseForest(int start, int end, int K) { 19 | this.K = K; 20 | chart = new ParseForestItem[end + 1][end + 1][2][2][K]; 21 | this.start = start; 22 | this.end = end; 23 | } 24 | 25 | public boolean add(int s, int type, int dir, double score, FeatureVector fv) { 26 | 27 | boolean added = false; 28 | 29 | if (chart[s][s][dir][0][0] == null) { 30 | for (int i = 0; i < K; i++) 31 | chart[s][s][dir][0][i] = new ParseForestItem(s, type, dir, Double.NEGATIVE_INFINITY, null); 32 | } 33 | 34 | if (chart[s][s][dir][0][K - 1].prob > score) 35 | return false; 36 | 37 | for (int i = 0; i < K; i++) { 38 | if (chart[s][s][dir][0][i].prob < score) { 39 | ParseForestItem tmp = chart[s][s][dir][0][i]; 40 | chart[s][s][dir][0][i] = new ParseForestItem(s, type, dir, score, fv); 41 | for (int j = i + 1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { 42 | ParseForestItem tmp1 = chart[s][s][dir][0][j]; 43 | chart[s][s][dir][0][j] = tmp; 44 | tmp = tmp1; 45 | } 46 | added = true; 47 | break; 48 | } 49 | } 50 | 51 | return added; 52 | } 53 | 54 | public boolean add(int s, 
int r, int t, int type, int dir, int comp, double score, 55 | FeatureVector fv, ParseForestItem p1, ParseForestItem p2) { 56 | 57 | boolean added = false; 58 | 59 | if (chart[s][t][dir][comp][0] == null) { 60 | for (int i = 0; i < K; i++) 61 | chart[s][t][dir][comp][i] = new ParseForestItem(s, r, t, type, dir, comp, 62 | Double.NEGATIVE_INFINITY, null, null, null); 63 | } 64 | 65 | if (chart[s][t][dir][comp][K - 1].prob > score) 66 | return false; 67 | 68 | for (int i = 0; i < K; i++) { 69 | if (chart[s][t][dir][comp][i].prob < score) { 70 | ParseForestItem tmp = chart[s][t][dir][comp][i]; 71 | chart[s][t][dir][comp][i] = new ParseForestItem(s, r, t, type, dir, comp, score, fv, p1, p2); 72 | for (int j = i + 1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { 73 | ParseForestItem tmp1 = chart[s][t][dir][comp][j]; 74 | chart[s][t][dir][comp][j] = tmp; 75 | tmp = tmp1; 76 | } 77 | added = true; 78 | break; 79 | } 80 | 81 | } 82 | 83 | return added; 84 | 85 | } 86 | 87 | public double getProb(int s, int t, int dir, int comp) { 88 | return getProb(s, t, dir, comp, 0); 89 | } 90 | 91 | public double getProb(int s, int t, int dir, int comp, int i) { 92 | if (chart[s][t][dir][comp][i] != null) 93 | return chart[s][t][dir][comp][i].prob; 94 | return Double.NEGATIVE_INFINITY; 95 | } 96 | 97 | public double[] getProbs(int s, int t, int dir, int comp) { 98 | double[] result = new double[K]; 99 | for (int i = 0; i < K; i++) 100 | result[i] = chart[s][t][dir][comp][i] != null ? chart[s][t][dir][comp][i].prob 101 | : Double.NEGATIVE_INFINITY; 102 | return result; 103 | } 104 | 105 | public ParseForestItem getItem(int s, int t, int dir, int comp) { 106 | return getItem(s, t, dir, comp, 0); 107 | } 108 | 109 | public ParseForestItem getItem(int s, int t, int dir, int comp, int k) { 110 | if (chart[s][t][dir][comp][k] != null) 111 | return chart[s][t][dir][comp][k]; 112 | return null; 113 | } 114 | 115 | public ParseForestItem[] getItems(int s, int t, int dir, int comp) { 116 | if (chart[s][t][dir][comp][0] != null) 117 | return chart[s][t][dir][comp]; 118 | return null; 119 | } 120 | 121 | public Object[] getBestParse() { 122 | Object[] d = new Object[2]; 123 | d[0] = getFeatureVector(chart[0][end][0][0][0]); 124 | d[1] = getDepString(chart[0][end][0][0][0]); 125 | return d; 126 | } 127 | 128 | public List> getBestParseDepLinks() { 129 | return getDepLinks(chart[0][end][0][0][0]); 130 | } 131 | 132 | public Object[][] getBestParses() { 133 | Object[][] d = new Object[K][2]; 134 | for (int k = 0; k < K; k++) { 135 | if (chart[0][end][0][0][k].prob != Double.NEGATIVE_INFINITY) { 136 | d[k][0] = getFeatureVector(chart[0][end][0][0][k]); 137 | d[k][1] = getDepString(chart[0][end][0][0][k]); 138 | } else { 139 | d[k][0] = null; 140 | d[k][1] = null; 141 | } 142 | } 143 | return d; 144 | } 145 | 146 | public FeatureVector getFeatureVector(ParseForestItem pfi) { 147 | if (pfi.left == null) 148 | return pfi.fv; 149 | 150 | return cat(pfi.fv, cat(getFeatureVector(pfi.left), getFeatureVector(pfi.right))); 151 | } 152 | 153 | public String getDepString(ParseForestItem pfi) { 154 | if (pfi.left == null) 155 | return ""; 156 | 157 | if (pfi.comp == 0) { 158 | return (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim(); 159 | } else if (pfi.dir == 0) { 160 | return ((getDepString(pfi.left) + " " + getDepString(pfi.right)).trim() + " " + pfi.s + "|" 161 | + pfi.t + ":" + pfi.type).trim(); 162 | } else { 163 | return (pfi.t + "|" + pfi.s + ":" + pfi.type + " " + (getDepString(pfi.left) + " " + 
getDepString(pfi.right)) 164 | .trim()).trim(); 165 | } 166 | } 167 | 168 | public List> getDepLinks(ParseForestItem pfi) { 169 | List> links = new ArrayList>(); 170 | if (pfi.left == null) 171 | return links; 172 | 173 | links.addAll(getDepLinks(pfi.left)); 174 | links.addAll(getDepLinks(pfi.right)); 175 | if (pfi.comp == 0) { 176 | // Do nothing 177 | } else if (pfi.dir == 0) { 178 | links.add(new Pair(pfi.s, pfi.t)); 179 | } else { 180 | links.add(new Pair(pfi.t, pfi.s)); 181 | } 182 | return links; 183 | } 184 | 185 | public FeatureVector cat(FeatureVector fv1, FeatureVector fv2) { 186 | return fv1.cat(fv2); 187 | } 188 | 189 | // returns pairs of indeces and -1,-1 if < K pairs 190 | public int[][] getKBestPairs(ParseForestItem[] items1, ParseForestItem[] items2) { 191 | // in this case K = items1.length 192 | 193 | boolean[][] beenPushed = new boolean[K][K]; 194 | 195 | int[][] result = new int[K][2]; 196 | for (int i = 0; i < K; i++) { 197 | result[i][0] = -1; 198 | result[i][1] = -1; 199 | } 200 | 201 | if (items1 == null || items2 == null || items1[0] == null || items2[0] == null) 202 | return result; 203 | 204 | BinaryHeap heap = new BinaryHeap(K + 1); 205 | int n = 0; 206 | ValueIndexPair vip = new ValueIndexPair(items1[0].prob + items2[0].prob, 0, 0); 207 | 208 | heap.add(vip); 209 | beenPushed[0][0] = true; 210 | 211 | while (n < K) { 212 | vip = heap.removeMax(); 213 | 214 | if (vip.val == Double.NEGATIVE_INFINITY) 215 | break; 216 | 217 | result[n][0] = vip.i1; 218 | result[n][1] = vip.i2; 219 | 220 | n++; 221 | if (n >= K) 222 | break; 223 | 224 | if (!beenPushed[vip.i1 + 1][vip.i2]) { 225 | heap.add(new ValueIndexPair(items1[vip.i1 + 1].prob + items2[vip.i2].prob, vip.i1 + 1, 226 | vip.i2)); 227 | beenPushed[vip.i1 + 1][vip.i2] = true; 228 | } 229 | if (!beenPushed[vip.i1][vip.i2 + 1]) { 230 | heap.add(new ValueIndexPair(items1[vip.i1].prob + items2[vip.i2 + 1].prob, vip.i1, 231 | vip.i2 + 1)); 232 | beenPushed[vip.i1][vip.i2 + 1] = true; 233 | } 234 | 235 | } 236 | 237 | return result; 238 | } 239 | 240 | private static class ValueIndexPair { 241 | public double val; 242 | 243 | public int i1, i2; 244 | 245 | public ValueIndexPair(double val, int i1, int i2) { 246 | this.val = val; 247 | this.i1 = i1; 248 | this.i2 = i2; 249 | } 250 | 251 | public int compareTo(ValueIndexPair other) { 252 | if (val < other.val) 253 | return -1; 254 | if (val > other.val) 255 | return 1; 256 | return 0; 257 | } 258 | 259 | } 260 | 261 | // Max Heap 262 | // We know that never more than K elements on Heap 263 | private static class BinaryHeap { 264 | private int DEFAULT_CAPACITY; 265 | 266 | private int currentSize; 267 | 268 | private ValueIndexPair[] theArray; 269 | 270 | public BinaryHeap(int def_cap) { 271 | DEFAULT_CAPACITY = def_cap; 272 | theArray = new ValueIndexPair[DEFAULT_CAPACITY + 1]; 273 | // theArray[0] serves as dummy parent for root (who is at 1) 274 | // "largest" is guaranteed to be larger than all keys in heap 275 | theArray[0] = new ValueIndexPair(Double.POSITIVE_INFINITY, -1, -1); 276 | currentSize = 0; 277 | } 278 | 279 | public ValueIndexPair getMax() { 280 | return theArray[1]; 281 | } 282 | 283 | private int parent(int i) { 284 | return i / 2; 285 | } 286 | 287 | private int leftChild(int i) { 288 | return 2 * i; 289 | } 290 | 291 | private int rightChild(int i) { 292 | return 2 * i + 1; 293 | } 294 | 295 | public void add(ValueIndexPair e) { 296 | 297 | // bubble up: 298 | int where = currentSize + 1; // new last place 299 | while 
(e.compareTo(theArray[parent(where)]) > 0) { 300 | theArray[where] = theArray[parent(where)]; 301 | where = parent(where); 302 | } 303 | theArray[where] = e; 304 | currentSize++; 305 | } 306 | 307 | public ValueIndexPair removeMax() { 308 | ValueIndexPair min = theArray[1]; 309 | theArray[1] = theArray[currentSize]; 310 | currentSize--; 311 | boolean switched = true; 312 | // bubble down 313 | for (int parent = 1; switched && parent < currentSize;) { 314 | switched = false; 315 | int leftChild = leftChild(parent); 316 | int rightChild = rightChild(parent); 317 | 318 | if (leftChild <= currentSize) { 319 | // if there is a right child, see if we should bubble down there 320 | int largerChild = leftChild; 321 | if ((rightChild <= currentSize) 322 | && (theArray[rightChild].compareTo(theArray[leftChild])) > 0) { 323 | largerChild = rightChild; 324 | } 325 | if (theArray[largerChild].compareTo(theArray[parent]) > 0) { 326 | ValueIndexPair temp = theArray[largerChild]; 327 | theArray[largerChild] = theArray[parent]; 328 | theArray[parent] = temp; 329 | parent = largerChild; 330 | switched = true; 331 | } 332 | } 333 | } 334 | return min; 335 | } 336 | 337 | } 338 | 339 | } -------------------------------------------------------------------------------- /src/main/scala/mstparser/ParseForestItem.java: -------------------------------------------------------------------------------- 1 | package mstparser; 2 | 3 | public class ParseForestItem { 4 | 5 | public int s, r, t, dir, comp, length, type; 6 | 7 | public double prob; 8 | 9 | public FeatureVector fv; 10 | 11 | public ParseForestItem left, right; 12 | 13 | // productions 14 | public ParseForestItem(int i, int k, int j, int type, int dir, int comp, double prob, 15 | FeatureVector fv, ParseForestItem left, ParseForestItem right) { 16 | this.s = i; 17 | this.r = k; 18 | this.t = j; 19 | this.dir = dir; 20 | this.comp = comp; 21 | this.type = type; 22 | length = 6; 23 | 24 | this.prob = prob; 25 | this.fv = fv; 26 | 27 | this.left = left; 28 | this.right = right; 29 | 30 | } 31 | 32 | // preproductions 33 | public ParseForestItem(int s, int type, int dir, double prob, FeatureVector fv) { 34 | this.s = s; 35 | this.dir = dir; 36 | this.type = type; 37 | length = 2; 38 | 39 | this.prob = prob; 40 | this.fv = fv; 41 | 42 | left = null; 43 | right = null; 44 | 45 | } 46 | 47 | public ParseForestItem() { 48 | } 49 | 50 | public void copyValues(ParseForestItem p) { 51 | p.s = s; 52 | p.r = r; 53 | p.t = t; 54 | p.dir = dir; 55 | p.comp = comp; 56 | p.prob = prob; 57 | p.fv = fv; 58 | p.length = length; 59 | p.left = left; 60 | p.right = right; 61 | p.type = type; 62 | } 63 | 64 | // way forest works, only have to check rule and indeces 65 | // for equality. 66 | public boolean equals(ParseForestItem p) { 67 | return s == p.s && t == p.t && r == p.r && dir == p.dir && comp == p.comp && type == p.type; 68 | } 69 | 70 | public boolean isPre() { 71 | return length == 2; 72 | } 73 | 74 | } 75 | --------------------------------------------------------------------------------