├── .gitignore ├── README.md ├── build.sbt ├── config ├── log4j.xml └── paramtoy.in ├── param.f1 ├── param.in ├── project ├── Dependencies.scala ├── build.properties └── plugins.sbt └── src ├── main └── java │ └── edu │ └── shanghaitech │ └── ai │ └── nlp │ ├── data │ ├── LVeGCorpus.java │ ├── ObjectFileManager.java │ └── StateTreeList.java │ ├── eval │ └── EnglishPennTreebankParseEvaluator.java │ ├── lveg │ ├── LVeGPCFG.java │ ├── LVeGTester.java │ ├── LVeGTesterImp.java │ ├── LVeGTesterSim.java │ ├── LVeGToy.java │ ├── LVeGTrainer.java │ ├── LVeGTrainerImp.java │ ├── LearnerConfig.java │ ├── PCFGMaxRule.java │ ├── impl │ │ ├── BinaryGrammarRule.java │ │ ├── DiagonalGaussianDistribution.java │ │ ├── DiagonalGaussianMixture.java │ │ ├── GaussFactory.java │ │ ├── GeneralGaussianDistribution.java │ │ ├── GeneralGaussianMixture.java │ │ ├── LVeGInferencer.java │ │ ├── LVeGParser.java │ │ ├── MaxRuleInferencer.java │ │ ├── MaxRuleParser.java │ │ ├── MoGFactory.java │ │ ├── PCFGInferencer.java │ │ ├── PCFGMaxRuleInferencer.java │ │ ├── PCFGMaxRuleParser.java │ │ ├── PCFGParser.java │ │ ├── RuleTable.java │ │ ├── SimpleLVeGGrammar.java │ │ ├── SimpleLVeGLexicon.java │ │ ├── UnaryGrammarRule.java │ │ └── Valuator.java │ └── model │ │ ├── ChartCell.java │ │ ├── GaussianDistribution.java │ │ ├── GaussianMixture.java │ │ ├── GrammarRule.java │ │ ├── Inferencer.java │ │ ├── LVeGGrammar.java │ │ ├── LVeGLexicon.java │ │ └── Parser.java │ ├── optimization │ ├── Batch.java │ ├── Gradient.java │ ├── Optimizer.java │ ├── ParallelOptimizer.java │ ├── SGDMinimizer.java │ ├── SimpleMinimizer.java │ └── SimpleOptimizer.java │ ├── syntax │ ├── State.java │ └── Tree.java │ └── util │ ├── Debugger.java │ ├── ErrorUtil.java │ ├── Executor.java │ ├── FunUtil.java │ ├── GradientChecker.java │ ├── Indexer.java │ ├── LogUtil.java │ ├── MutableInteger.java │ ├── Numberer.java │ ├── ObjectPool.java │ ├── Option.java │ ├── OptionParser.java │ ├── Recorder.java │ └── ThreadPool.java └── test └── 
java └── edu └── shanghaitech └── ai └── nlp ├── data ├── ConstraintTester.java └── F1er.java ├── lveg ├── Java.java ├── LVeGLearnerTest.java ├── LVeGPCFGTest.java ├── LVeGTesterTest.java ├── LVeGToyTest.java ├── impl │ ├── BinaryGrammarTest.java │ ├── DiagonalGaussianDistributionTest.java │ ├── LVeGGrammarTest.java │ ├── RuleTableGeneric.java │ ├── RuleTableTest.java │ └── UnaryGrammarRuleTest.java └── model │ ├── GaussianDistributionTest.java │ ├── GaussianMixtureTest.java │ └── InferencerTest.java ├── optimization ├── ParallelOptimizerTest.java └── SGDForMoGTest.java ├── syntax └── StateTest.java └── util ├── FunUtilTest.java ├── LogUtilTest.java └── ObjectToolTest.java /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | log/ 3 | libs/ 4 | pool/ 5 | target/ 6 | script/ 7 | classes/ 8 | berkeley/ 9 | .settings/ 10 | implementation_notes.txt 11 | .classpath 12 | .project 13 | .log 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Latent Vector Grammars (LVeGs) 2 | 3 | Code for [Gaussian Mixture Latent Vector Grammars](https://arxiv.org/abs/1805.04688). 4 | 5 | ## How to Run 6 | 7 | Specify parameter values for learning (param.in) and for inference (param.f1). 8 | 9 | ### Learning 10 | 11 | ``` 12 | $ sbt "run-main edu.shanghaitech.ai.nlp.lveg.LVeGTrainerImp param.in" 13 | ``` 14 | 15 | ### Inference 16 | 17 | ``` 18 | $ sbt "run-main edu.shanghaitech.ai.nlp.lveg.LVeGTesterImp param.f1" 19 | ``` 20 | 21 | ## Data 22 | 23 | Parsing data available at [Google Drive](https://drive.google.com/open?id=1sSwTaVgKJe-oA7jsoM6Mbij_j-xDUdH7). 24 | 25 | ## Models 26 | 27 | Parsing models available at [Google Drive](https://drive.google.com/open?id=1CqWOMn7xWfax5Sj5lP_ypKYieasjviKd). 28 | 29 | ## Dependencies 30 | 31 | sbt-0.13.10, java-1.8.0, and scala-2.12.0. 
-------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import Version._ 2 | 3 | lazy val commonSettings = Seq( 4 | organization := "edu.shanghaitech.ai.nlp", 5 | version := "0.0.1", 6 | 7 | scalaVersion := Version.scala, 8 | crossScalaVersions := Seq("2.12.0", "2.10.4"), 9 | 10 | libraryDependencies ++= Seq( 11 | Library.log4j, 12 | Library.junit, 13 | Library.pool2, 14 | Library.BerkeleyParser 15 | ), 16 | javacOptions ++= Seq("-source", "1.8"), 17 | javaOptions += "-Xmx10g", 18 | fork := true 19 | ) 20 | 21 | lazy val root = project 22 | .in(file(".")) 23 | .settings(commonSettings: _*) 24 | .settings( 25 | name := "lveg" 26 | ) -------------------------------------------------------------------------------- /config/log4j.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /config/paramtoy.in: -------------------------------------------------------------------------------- 1 | -datadir, E:/SourceCode/ParsersData/wsj/, 2 | -train, wsj_s2-21_tree, 3 | -test, wsj_s23_tree, 4 | -dev, wsj_s22_tree, 5 | -inCorpus, wsj_all.tree, 6 | -outCorpus, wsj_all.tree, 7 | -saveCorpus, false, 8 | -loadCorpus, false, 9 | 10 | -outGrammar, lveg, 11 | -nbatchSave, 10, 12 | -saveGrammar, true, 13 | 14 | -lr, 1, 15 | -reg, false, 16 | -clip, false, 17 | -absmax, 5.0, 18 | -wdecay, 1e-4, 19 | -l1, true, 20 | -minmw, 1e-9, 21 | -epsilon, 1e-8, 22 | -choice, SGD, 23 | -lambda, 0.9, 24 | -lambda1, 0.9, 25 | -lambda2, 0.999, 26 | 27 | -ntcyker, 6, 28 | -ntbatch, 6, 29 | -ntgrad, 6, 30 | -nteval, 6, 31 | -nttest, 6, 32 | -pclose, true, 33 | -pcyker, false, 34 | 
-pbatch, true, 35 | -peval, true, 36 | -pgrad, true, 37 | -pmode, THREAD_POOL, 38 | -pverbose, false, 39 | 40 | -iosprune, false, 41 | -sampling, false, 42 | -riserate, 2, 43 | -maxnbig, -30, 44 | -rtratio, -0.2, 45 | -hardcut, false, 46 | -expzero, 1e-6, 47 | -bsize, 1, 48 | -nepoch, 3, 49 | -maxsample, 1, 50 | -maxslen, 50, 51 | -nAllowedDrop, 6, 52 | -maxramdom, 1, 53 | -maxmw, 4, 54 | -nwratio, 0.8, 55 | -maxmu, 1, 56 | -nmratio, 0.5, 57 | -maxvar, 5, 58 | -nvratio, 0.8, 59 | -ncomponent, 2, 60 | -dim, 1, 61 | -resetw, false, 62 | -resetc, false, 63 | -mwfactor, 1.0, 64 | -usemasks, false, 65 | -tgbase, 3, 66 | -tgratio, 0.2, 67 | -tgprob, 1e-8, 68 | -iomask, true, 69 | 70 | -eratio, -0.1, 71 | -efraction, 1.0, 72 | -eonlylen, -1, 73 | -efirstk, -1, 74 | -eontrain, true, 75 | -eondev, false, 76 | -eonextradev, false, 77 | -ellprune, false, 78 | -ellimwrite, false, 79 | -epochskipk, 1, 80 | -enbatchdev, 1, 81 | 82 | -pf1, false, 83 | -ef1tag, , 84 | -inGrammar, lveg_best.gr, 85 | -loadGrammar, false, 86 | -ef1prune, false, 87 | -ef1ondev, false, 88 | -ef1ontrain, false, 89 | -ef1imwrite, false, 90 | 91 | -runtag, toy_0827_complex_test1_2_1_bs1, 92 | -logroot, log/, 93 | -logtype, 0, 94 | -verbose, false, 95 | -nbatch, 1, 96 | -precision, 3, 97 | -rndomseed, 0, 98 | 99 | -dgradnbatch, 1, 100 | 101 | -imgprefix, lveg -------------------------------------------------------------------------------- /param.f1: -------------------------------------------------------------------------------- 1 | -datadir, F:/SourceCode/ParsersData/wsj/, 2 | -train, wsj_s2-21_tree, 3 | -test, wsj_s23_tree, 4 | -dev, wsj_s22_tree, 5 | -inCorpus, wsj_all_r0.tree, 6 | -outCorpus, wsj_all.tree, 7 | -saveCorpus, false, 8 | -loadCorpus, true, 9 | 10 | -ntcyker, 6, 11 | -nttest, 6, 12 | -pclose, false, 13 | -pcyker, false, 14 | 15 | -riserate, 2, 16 | -rtratio, -0.2, 17 | -hardcut, false, 18 | -expzero, 1e-3, 19 | -maxramdom, 1, 20 | -maxmw, 10, 21 | -nwratio, 0.8, 22 | -maxmu, 0.1, 23 | 
-nmratio, 0.5, 24 | -maxvar, 5, 25 | -nvratio, 0.8, 26 | -ncomponent, 1, 27 | -dim, 3, 28 | -maxnbig, 10, 29 | -maxslen, 70, 30 | -usemasks, true, 31 | -tgbase, 3, 32 | -tgratio, 1.0, 33 | -tgprob, 1e-4, 34 | -iomask, false, 35 | -sexp, 0.35, 36 | 37 | -eratio, -0.1, 38 | -eonlylen, 200, 39 | -efirstk, 3000, 40 | -eonextradev, false, 41 | 42 | -pf1, false, 43 | -ef1tag, debug, 44 | -consfile, gr.sophis.200.23.cons, 45 | -inGrammar, lveg_final.gr, 46 | -eusestag, true, 47 | -ef1prune, true, 48 | -ef1ondev, false, 49 | -ef1ontrain, false, 50 | -ef1imwrite, false, 51 | 52 | -runtag, gr_0827_1437_50_4_3_180_nb40_p0_mt55_mf8_l1_r1, 53 | -logroot, log/, 54 | -logtype, 0, 55 | -verbose, false, 56 | -precision, 3, 57 | -rndomseed, 0, 58 | 59 | -imgprefix, lveg -------------------------------------------------------------------------------- /param.in: -------------------------------------------------------------------------------- 1 | -datadir, E:/SourceCode/ParsersData/wsj/, 2 | -train, wsj_s2-21_tree, 3 | -test, wsj_s23_tree, 4 | -dev, wsj_s22_tree, 5 | -inCorpus, wsj_all_r1.tree, 6 | -outCorpus, wsj_all_r0.tree, 7 | -saveCorpus, false, 8 | -loadCorpus, true, 9 | 10 | -outGrammar, lveg, 11 | -nbatchSave, 20, 12 | -saveGrammar, true, 13 | 14 | -lr, 0.01, 15 | -reg, false, 16 | -clip, false, 17 | -absmax, 5.0, 18 | -wdecay, 0.1, 19 | -l1, true, 20 | -minmw, 1e-10, 21 | -epsilon, 1e-8, 22 | -choice, ADAM, 23 | -lambda, 0.9, 24 | -lambda1, 0.9, 25 | -lambda2, 0.999, 26 | 27 | -ntcyker, 6, 28 | -ntbatch, 6, 29 | -ntgrad, 6, 30 | -nteval, 6, 31 | -nttest, 6, 32 | -pclose, false, 33 | -pcyker, false, 34 | -pbatch, true, 35 | -peval, true, 36 | -pgrad, true, 37 | -pmode, THREAD_POOL, 38 | -pverbose, false, 39 | 40 | -iosprune, true, 41 | -sampling, false, 42 | -riserate, 2, 43 | -maxnbig, 40, 44 | -rtratio, -0.2, 45 | -hardcut, false, 46 | -expzero, 1e-3, 47 | -bsize, 180, 48 | -nepoch, 30, 49 | -maxsample, 1, 50 | -maxslen, 60, 51 | -nAllowedDrop, 0, 52 | -maxramdom, 1, 53 | 
-maxmw, 10, 54 | -nwratio, 0.8, 55 | -maxmu, 0.1, 56 | -nmratio, 0.5, 57 | -maxvar, 5, 58 | -nvratio, 0.8, 59 | -ncomponent, 1, 60 | -dim, 3, 61 | -resetw, true, 62 | -resetc, true, 63 | -mwfactor, 8.0, 64 | -usemasks, true, 65 | -tgbase, 3, 66 | -tgratio, 1.0, 67 | -tgprob, 5e-5, 68 | -iomask, false, 69 | -sexp, 0.35, 70 | -pivota, 0, 71 | -pivotb, 4000, 72 | -resetl, true, 73 | -resetp, false, 74 | 75 | -eratio, -0.1, 76 | -efraction, -0.25, 77 | -eonlylen, 3, 78 | -efirstk, 200, 79 | -eontrain, false, 80 | -eondev, true, 81 | -eonextradev, true, 82 | -ellprune, true, 83 | -ellimwrite, false, 84 | -epochskipk, 1, 85 | -enbatchdev, 1, 86 | 87 | -pf1, true, 88 | -ef1tag, final, 89 | -inGrammar, lveg_final.gr, 90 | -loadGrammar, false, 91 | -ef1prune, true, 92 | -ef1ondev, true, 93 | -ef1ontrain, false, 94 | -ef1imwrite, false, 95 | 96 | -consfile, , 97 | -runtag, 2018_0212_4_3_3_200_test2, 98 | -logroot, log/, 99 | -logtype, 0, 100 | -verbose, false, 101 | -nbatch, 0, 102 | -precision, 3, 103 | -rndomseed, 0, 104 | 105 | -dgradnbatch, -1, 106 | 107 | -imgprefix, lveg -------------------------------------------------------------------------------- /project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | 3 | object Version { 4 | val scala = "2.12.0" 5 | val junit = "4.12" 6 | val log4j = "1.2.16" 7 | val pool2 = "2.4.2" 8 | val BerkeleyParser = "1.7" 9 | } 10 | 11 | object Library { 12 | val junit = "junit" % "junit" % Version.junit 13 | val log4j = "log4j" % "log4j" % Version.log4j 14 | val pool2 = "org.apache.commons" % "commons-pool2" % Version.pool2 15 | val BerkeleyParser = "BerkeleyParser" % "BerkeleyParser" % Version.BerkeleyParser from "https://raw.githubusercontent.com/slavpetrov/berkeleyparser/master/BerkeleyParser-1.7.jar" 16 | } -------------------------------------------------------------------------------- /project/build.properties: 
-------------------------------------------------------------------------------- 1 | sbt.version=0.13.10 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaoyanpeng/lveg/d2765585dfb095d5a8c862896d944576aa4ac9a8/project/plugins.sbt -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/data/LVeGCorpus.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.data; 2 | 3 | import java.io.Serializable; 4 | import java.util.List; 5 | 6 | import edu.berkeley.nlp.syntax.Tree; 7 | import edu.berkeley.nlp.util.Counter; 8 | import edu.shanghaitech.ai.nlp.lveg.impl.SimpleLVeGLexicon; 9 | import edu.shanghaitech.ai.nlp.syntax.State; 10 | 11 | /** 12 | * @author Yanpeng Zhao 13 | * 14 | */ 15 | public class LVeGCorpus implements Serializable { 16 | /** 17 | * 18 | */ 19 | private static final long serialVersionUID = 1278753896112455981L; 20 | 21 | /** 22 | * @param trees the parse tree 23 | * @param lexicon temporary data holder (wordIndexer) 24 | * @param rareThreshold words with frequencies lower than this value will be replaced with its signature 25 | */ 26 | public static void replaceRareWords( 27 | StateTreeList trees, SimpleLVeGLexicon lexicon, int rareThreshold) { 28 | Counter wordCounter = new Counter<>(); 29 | for (Tree tree : trees) { 30 | List words = tree.getYield(); 31 | for (State word : words) { 32 | String name = word.getName(); 33 | wordCounter.incrementCount(name, 1.0); 34 | lexicon.wordIndexer.add(name); 35 | } 36 | } 37 | 38 | for (Tree tree : trees) { 39 | List words = tree.getYield(); 40 | int pos = 0; 41 | for (State word : words) { 42 | String name = word.getName(); 43 | if (wordCounter.getCount(name) <= rareThreshold) { 44 | name = 
lexicon.getCachedSignature(name, pos); 45 | word.setName(name); 46 | } 47 | pos++; 48 | } 49 | } 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/data/ObjectFileManager.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.data; 2 | 3 | import java.io.File; 4 | import java.io.FileInputStream; 5 | import java.io.FileOutputStream; 6 | import java.io.IOException; 7 | import java.io.ObjectInputStream; 8 | import java.io.ObjectOutputStream; 9 | import java.io.Serializable; 10 | import java.text.SimpleDateFormat; 11 | import java.util.Date; 12 | import java.util.Set; 13 | import java.util.zip.GZIPInputStream; 14 | import java.util.zip.GZIPOutputStream; 15 | 16 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 17 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon; 18 | import edu.shanghaitech.ai.nlp.util.Numberer; 19 | 20 | public class ObjectFileManager { 21 | 22 | public static class ObjectFile implements Serializable { 23 | /** 24 | * 25 | */ 26 | private static final long serialVersionUID = 1852249590891181238L; 27 | 28 | public boolean save(String filename) { 29 | /* 30 | if (new File(filename).exists()) { 31 | filename += new SimpleDateFormat(".yyyyMMddHHmmss").format(new Date()); 32 | } 33 | */ 34 | try { 35 | FileOutputStream fos = new FileOutputStream(filename); 36 | GZIPOutputStream gos = new GZIPOutputStream(fos); 37 | ObjectOutputStream oos = new ObjectOutputStream(gos); 38 | oos.writeObject(this); 39 | oos.flush(); 40 | oos.close(); 41 | gos.close(); 42 | fos.close(); 43 | } catch (IOException e) { 44 | e.printStackTrace(); 45 | return false; 46 | } 47 | return true; 48 | } 49 | 50 | public static Object load(String filename) { 51 | Object o = null; 52 | try { 53 | FileInputStream fis = new FileInputStream(filename); 54 | GZIPInputStream gis = new GZIPInputStream(fis); 55 | 
ObjectInputStream ois = new ObjectInputStream(gis); 56 | o = ois.readObject(); 57 | ois.close(); 58 | gis.close(); 59 | fis.close(); 60 | } catch (IOException | ClassNotFoundException e) { 61 | e.printStackTrace(); 62 | return null; 63 | } 64 | return o; 65 | } 66 | } 67 | 68 | 69 | public static class CorpusFile extends ObjectFile { 70 | /** 71 | * 72 | */ 73 | private static final long serialVersionUID = -5871763111246836457L; 74 | private StateTreeList train; 75 | private StateTreeList test; 76 | private StateTreeList dev; 77 | private Numberer numberer; 78 | 79 | public CorpusFile(StateTreeList train, StateTreeList test, StateTreeList dev, Numberer numberer) { 80 | this.numberer = numberer; 81 | this.train = train; 82 | this.test = test; 83 | this.dev = dev; 84 | } 85 | 86 | public StateTreeList getTrain() { 87 | return train; 88 | } 89 | 90 | public StateTreeList getTest() { 91 | return test; 92 | } 93 | 94 | public StateTreeList getDev() { 95 | return dev; 96 | } 97 | 98 | public Numberer getNumberer() { 99 | return numberer; 100 | } 101 | } 102 | 103 | 104 | public static class GrammarFile extends ObjectFile { 105 | /** 106 | * 107 | */ 108 | private static final long serialVersionUID = -8119623200836693006L; 109 | private LVeGGrammar grammar; 110 | private LVeGLexicon lexicon; 111 | 112 | public GrammarFile(LVeGGrammar grammar, LVeGLexicon lexicon) { 113 | this.grammar = grammar; 114 | this.lexicon = lexicon; 115 | } 116 | 117 | public LVeGGrammar getGrammar() { 118 | return grammar; 119 | } 120 | 121 | public LVeGLexicon getLexicon() { 122 | return lexicon; 123 | } 124 | } 125 | 126 | 127 | public static class Constraint extends ObjectFile { 128 | /** 129 | * 130 | */ 131 | private static final long serialVersionUID = -235658934972194254L; 132 | private Set[][][] constraints; 133 | 134 | public Constraint(Set[][][] constraints) { 135 | this.constraints = constraints; 136 | } 137 | 138 | public Set[][][] getConstraints() { 139 | return constraints; 140 | } 
141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/data/StateTreeList.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.data; 2 | 3 | import java.io.Serializable; 4 | import java.util.AbstractCollection; 5 | import java.util.ArrayList; 6 | import java.util.Collections; 7 | import java.util.Iterator; 8 | import java.util.List; 9 | import java.util.Random; 10 | 11 | import edu.berkeley.nlp.syntax.Tree; 12 | import edu.shanghaitech.ai.nlp.syntax.State; 13 | import edu.shanghaitech.ai.nlp.util.Numberer; 14 | 15 | /** 16 | * @author Yanpeng Zhao 17 | * 18 | */ 19 | public class StateTreeList extends AbstractCollection> implements Serializable { 20 | /** 21 | * 22 | */ 23 | private static final long serialVersionUID = 6424710107698526446L; 24 | private final static short ID_WORD = -1; 25 | private List> trees; 26 | 27 | 28 | public class StateTreeListIterator implements Iterator> { 29 | 30 | Tree currentTree; 31 | Iterator> treeListIter; 32 | 33 | public StateTreeListIterator() { 34 | treeListIter = trees.iterator(); 35 | currentTree = null; 36 | } 37 | 38 | @Override 39 | public boolean hasNext() { 40 | // TODO Auto-generated method stub 41 | if (currentTree != null) { 42 | // TODO 43 | } 44 | return treeListIter.hasNext(); 45 | } 46 | 47 | @Override 48 | public Tree next() { 49 | // TODO Auto-generated method stub 50 | currentTree = treeListIter.next(); 51 | return currentTree; 52 | } 53 | 54 | @Override 55 | public void remove() { 56 | treeListIter.remove(); 57 | } 58 | 59 | } 60 | 61 | 62 | @Override 63 | public Iterator> iterator() { 64 | // TODO Auto-generated method stub 65 | return new StateTreeListIterator(); 66 | } 67 | 68 | @Override 69 | public int size() { 70 | // TODO Auto-generated method stub 71 | return trees.size(); 72 | } 73 | 74 | @Override 75 | public boolean isEmpty() { 76 | return 
trees.isEmpty(); 77 | } 78 | 79 | @Override 80 | public boolean add(Tree tree) { 81 | return trees.add(tree); 82 | } 83 | 84 | public Tree get(int i) { 85 | return trees.get(i); 86 | } 87 | 88 | 89 | public StateTreeList() { 90 | this.trees = new ArrayList<>(); 91 | } 92 | 93 | 94 | public StateTreeList(StateTreeList stateTreeList) { 95 | this.trees = new ArrayList<>(); 96 | for (Tree tree: stateTreeList.trees) { 97 | trees.add(copyTreeButLeaf(tree)); 98 | } 99 | } 100 | 101 | 102 | /** 103 | * The leaf is copied by reference. It's the same as 104 | * {@code edu.berkeley.nlp.PCFGLA.StateSetTreeList.resizeStateSetTree} 105 | * 106 | * @param tree a parse tree 107 | * @return 108 | * 109 | */ 110 | private Tree copyTreeButLeaf(Tree tree) { 111 | if (tree.isLeaf()) { return tree; } 112 | State state = new State(tree.getLabel(), false); 113 | List> children = new ArrayList<>(); 114 | for (Tree child : tree.getChildren()) { 115 | children.add(copyTreeButLeaf(child)); 116 | } 117 | return new Tree(state, children); 118 | } 119 | 120 | 121 | public StateTreeList copy() { 122 | StateTreeList stateTreeList = new StateTreeList(); 123 | for (Tree tree : trees) { 124 | stateTreeList.add(copyTree(tree)); 125 | } 126 | return stateTreeList; 127 | } 128 | 129 | 130 | private Tree copyTree(Tree tree) { 131 | List> children = new ArrayList<>(tree.getChildren().size()); 132 | for (Tree child : tree.getChildren()) { 133 | children.add(copyTree(child)); 134 | } 135 | return new Tree(tree.getLabel().copy(), children); 136 | } 137 | 138 | 139 | public StateTreeList(List> trees, Numberer numberer) { 140 | this.trees = new ArrayList<>(); 141 | for (Tree tree : trees) { 142 | this.trees.add(stringTreeToStateTree(tree, numberer)); 143 | tree = null; // clean the memory 144 | } 145 | } 146 | 147 | 148 | /** 149 | * @param trees parse trees 150 | * @param numberer recording the ids of tags 151 | * 152 | */ 153 | public static void initializeNumbererTag( 154 | List> trees, Numberer numberer) { 
155 | for (Tree tree : trees) { 156 | stringTreeToStateTree(tree, numberer); 157 | } 158 | } 159 | 160 | 161 | public static void stringTreeToStateTree( 162 | List> trees, Numberer numberer) { 163 | for (Tree tree : trees) { 164 | stringTreeToStateTree(tree, numberer); 165 | } 166 | } 167 | 168 | 169 | /** 170 | * @param tree a parse tree 171 | * @param numberer record the ids of tags 172 | * @return parse tree represented by the state list 173 | * 174 | */ 175 | public static Tree stringTreeToStateTree(Tree tree, Numberer numberer) { 176 | Tree result = stringTreeToStateTree(tree, numberer, 0, tree.getYield().size()); 177 | List words = result.getYield(); 178 | for (short pos = 0; pos < words.size(); pos++) { 179 | words.get(pos).from = pos; 180 | words.get(pos).to = (short) (pos + 1); 181 | } 182 | return result; 183 | } 184 | 185 | 186 | public void shuffle(Random rnd) { 187 | Collections.shuffle(trees, rnd); 188 | } 189 | 190 | 191 | public void reset() { 192 | for (Tree tree : trees) { 193 | reset(tree); 194 | } 195 | } 196 | 197 | 198 | public void reset(Tree tree) { 199 | if (tree.isLeaf()) { return; } 200 | if (tree.getLabel() != null) { 201 | tree.getLabel().clear(false); 202 | } 203 | for (Tree child : tree.getChildren()) { 204 | reset(child); 205 | } 206 | } 207 | 208 | 209 | /** 210 | * Convert a state tree to a string tree. 
211 | * 212 | * @param tree a state tree 213 | * @param numberer which records the ids of tags 214 | * @return 215 | */ 216 | public static Tree stateTreeToStringTree(Tree tree, Numberer numberer) { 217 | if (tree.isLeaf()) { 218 | String name = tree.getLabel().getName(); 219 | return new Tree(name); 220 | } 221 | 222 | String name = (String) numberer.object(tree.getLabel().getId()); 223 | Tree newTree = new Tree(name); 224 | List> children = new ArrayList<>(); 225 | 226 | for (Tree child : tree.getChildren()) { 227 | Tree newChild = stateTreeToStringTree(child, numberer); 228 | children.add(newChild); 229 | } 230 | newTree.setChildren(children); 231 | return newTree; 232 | } 233 | 234 | 235 | /** 236 | * Convert a string tree to a state tree. 237 | * 238 | * @param tree a parse tree 239 | * @param numberer which records the ids of tags 240 | * @param from starting point of the span 241 | * @param to ending point of the span 242 | * @return parse tree represented by the state list 243 | * 244 | */ 245 | private static Tree stringTreeToStateTree(Tree tree, Numberer numberer, int from, int to) { 246 | if (tree.isLeaf()) { 247 | State state = new State(tree.getLabel().intern(), (short) ID_WORD, (short) from, (short) to); 248 | return new Tree(state); 249 | } 250 | /* numberer is initialized here */ 251 | short id = (short) numberer.number(tree.getLabel()); 252 | 253 | // System.out.println(tree.getLabel().intern()); // tag name 254 | State state = new State(null, id, (short) from, (short) to); 255 | Tree newTree = new Tree(state); 256 | List> children = new ArrayList<>(); 257 | for (Tree child : tree.getChildren()) { 258 | short length = (short) child.getYield().size(); 259 | Tree newChild = stringTreeToStateTree(child, numberer, from, from + length); 260 | from += length; 261 | children.add(newChild); 262 | } 263 | newTree.setChildren(children); 264 | return newTree; 265 | } 266 | 267 | } 268 | 
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/LVeGPCFG.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
import edu.berkeley.nlp.syntax.Tree;
import edu.shanghaitech.ai.nlp.data.StateTreeList;
import edu.shanghaitech.ai.nlp.data.ObjectFileManager.GrammarFile;
import edu.shanghaitech.ai.nlp.eval.EnglishPennTreebankParseEvaluator;
import edu.shanghaitech.ai.nlp.lveg.impl.PCFGParser;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar;
import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon;
import edu.shanghaitech.ai.nlp.optimization.Optimizer;
import edu.shanghaitech.ai.nlp.syntax.State;
import edu.shanghaitech.ai.nlp.util.FunUtil;
import edu.shanghaitech.ai.nlp.util.Numberer;
import edu.shanghaitech.ai.nlp.util.OptionParser;
import edu.shanghaitech.ai.nlp.util.ThreadPool;

/**
 * Command-line driver that loads a saved grammar/lexicon pair, parses the loaded tree sets with
 * {@link PCFGParser} and logs the labeled-constituent scores reported by the evaluator.
 * <p>
 * NOTE(review): generic type arguments were stripped in the copy of this file that was reviewed;
 * the ones below are restored from usage and should be confirmed against the repository.
 */
public class LVeGPCFG extends LearnerConfig {
    /**
     *
     */
    private static final long serialVersionUID = 8232031232691463175L;

    protected static EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> scorer;
    protected static StateTreeList trainTrees;
    protected static StateTreeList testTrees;
    protected static StateTreeList devTrees;

    protected static LVeGGrammar grammar;
    protected static LVeGLexicon lexicon;

    protected static PCFGParser<?, ?> pcfgParser;
    protected static ThreadPool mparser;

    protected static String treeFile;
    protected static Options opts;

    /**
     * Reads the comma-separated option list from the file named by {@code args[0]}, initializes
     * the configuration, loads the data and runs the evaluation.
     *
     * @param args {@code args[0]} is the path of the parameter file
     * @throws IllegalArgumentException if no parameter file is given
     * @throws IOException              if the parameter file cannot be read
     * @throws Exception                anything raised while loading data or evaluating
     */
    public static void main(String[] args) throws Exception {
        if (args == null || args.length == 0) {
            throw new IllegalArgumentException("Missing argument: path of the parameter file");
        }
        String fparams = args[0];
        try {
            args = readFile(fparams, StandardCharsets.UTF_8).split(",");
        } catch (IOException e) {
            // Previously the error was only printed and the run went on with the file path itself
            // as the option list; without the parameter file there is nothing sensible to parse.
            throw new IOException("Cannot read the parameter file '" + fparams + "'", e);
        }
        OptionParser optionParser = new OptionParser(Options.class);
        opts = (Options) optionParser.parse(args, true);
        // configurations
        initialize(opts, true); // logger can only be used after the initialization
        logger.info("Calling with " + optionParser.getParsedOptions() + "\n");

        // loading data
        Numberer wrapper = new Numberer();
        Map<String, StateTreeList> trees = loadData(wrapper, opts);
        // training
        long startTime = System.currentTimeMillis();
        train(trees, wrapper);
        long endTime = System.currentTimeMillis();
        logger.trace("[total time consumed by LVeG tester] " + (endTime - startTime) / 1000.0 + "\n");
    }


    /**
     * Loads the grammar and lexicon named by the options, builds the parser and its thread pool,
     * and logs the scores for the test set and, when enabled, the training and dev sets.
     *
     * @param trees   tree sets keyed by {@code ID_TRAIN}, {@code ID_TEST} and {@code ID_DEV}
     * @param wrapper numberer from which the tag-set numberer is obtained
     * @throws Exception anything raised while loading the grammar or parsing
     */
    private static void train(Map<String, StateTreeList> trees, Numberer wrapper) throws Exception {
        trainTrees = trees.get(ID_TRAIN);
        testTrees = trees.get(ID_TEST);
        devTrees = trees.get(ID_DEV);

        treeFile = sublogroot + opts.imgprefix;

        Numberer numberer = wrapper.getGlobalNumberer(KEY_TAG_SET);

        /* to ease the parameters tuning */
        GaussianMixture.config(opts.maxnbig, opts.expzero, opts.maxmw, opts.ncomponent,
                opts.nwratio, opts.riserate, opts.rtratio, opts.hardcut, random, mogPool);
        GaussianDistribution.config(opts.maxmu, opts.maxvar, opts.dim, opts.nmratio, opts.nvratio, random, gaussPool);
        Optimizer.config(opts.choice, random, opts.maxsample, opts.bsize, opts.minmw, opts.sampling); // FIXME no errors, just alert you...

        // load grammar
        logger.trace("--->Loading grammars from \'" + subdatadir + opts.inGrammar + "\'...\n");
        GrammarFile gfile = (GrammarFile) GrammarFile.load(subdatadir + opts.inGrammar);
        grammar = gfile.getGrammar();
        lexicon = gfile.getLexicon();

        /*
        logger.trace(grammar);
        logger.trace(lexicon);
        System.exit(0);
        */

        lexicon.labelTrees(trainTrees); // FIXME no errors, just alert you to pay attention to it
        lexicon.labelTrees(testTrees); // save the search time cost by finding a specific tag-word
        lexicon.labelTrees(devTrees); // pair in in Lexicon.score(...)

        scorer = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
                new HashSet<String>(Arrays.asList(new String[] { "ROOT", "PSEUDO" })),
                new HashSet<String>(Arrays.asList(new String[] { "''", "``", ".", ":", "," })));
        pcfgParser = new PCFGParser<Tree<State>, Tree<String>>(grammar, lexicon, opts.maxslen,
                opts.ntcyker, opts.pcyker, opts.ef1prune, false);
        mparser = new ThreadPool(pcfgParser, opts.nttest);

        try {
            logger.info("\n---F1 CONFIG---\n[parallel: batch-" + opts.pbatch + ", grad-" +
                    opts.pgrad + ", eval-" + opts.peval + ", test-" + opts.pf1 + "]\n\n");

            sorter = new PriorityQueue<>(opts.bsize + 5, wcomparator);

            StringBuilder sb = new StringBuilder();
            sb.append("[test ]" + f1entry(testTrees, numberer, false) + "\n");
            if (opts.ef1ontrain) {
                scorer.reset();
                sb.append("[train]" + f1entry(trainTrees, numberer, true) + "\n");
            }
            if (opts.ef1ondev) {
                scorer.reset();
                sb.append("[dev ]" + f1entry(devTrees, numberer, false) + "\n");
            }
            logger.info("[summary]\n" + sb.toString() + "\n");
        } finally {
            // kill threads -- also when scoring fails, so they are not left running
            grammar.shutdown();
            lexicon.shutdown();
            mparser.shutdown();
        }
    }


    /**
     * Scores one tree set, using the thread pool when {@code opts.pf1} is set and the single
     * parser otherwise.
     *
     * @return the summary string produced by the scorer
     */
    public static String f1entry(StateTreeList trees, Numberer numberer, boolean istrain) {
        return opts.pf1
                ? parallelFscore(opts, mparser, trees, numberer, istrain)
                : serialFscore(opts, pcfgParser, trees, numberer, istrain);
    }


    /**
     * Parses the filtered trees through the thread pool and feeds every result to the scorer.
     * <p>
     * NOTE(review): the i-th result taken from the pool is paired with the i-th filtered tree,
     * i.e. this relies on the pool handing results back in submission order -- confirm.
     *
     * @return the summary string produced by the scorer
     */
    public static String parallelFscore(Options opts, ThreadPool mparser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {
        Tree<State> goldTree = null;
        Tree<String> parsedTree = null;
        int nUnparsable = 0, idx = 0;
        List<Tree<State>> trees = new ArrayList<>(stateTreeList.size());
        filterTrees(opts, stateTreeList, trees, numberer, istrain);

        for (Tree<State> tree : trees) {
            mparser.execute(tree);
            // consume whatever has finished so far while more work is being submitted
            while (mparser.hasNext()) {
                goldTree = trees.get(idx);
                parsedTree = (Tree<String>) mparser.getNext();
                if (!saveTree(goldTree, parsedTree, numberer, idx)) {
                    nUnparsable++;
                }
                idx++;
            }
        }
        // everything is submitted; keep polling until the pool reports completion
        while (!mparser.isDone()) {
            while (mparser.hasNext()) {
                goldTree = trees.get(idx);
                parsedTree = (Tree<String>) mparser.getNext();
                if (!saveTree(goldTree, parsedTree, numberer, idx)) {
                    nUnparsable++;
                }
                idx++;
            }
        }
        mparser.reset();
        String summary = scorer.display();
        logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
        logger.trace(summary + "\n\n");
        return summary;
    }


    /**
     * Parses the filtered trees one after another with {@code mrParser} and feeds every result
     * to the scorer.
     *
     * @return the summary string produced by the scorer
     */
    public static String serialFscore(Options opts, PCFGParser<?, ?> mrParser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {

        int nUnparsable = 0, idx = 0;

        List<Tree<State>> trees = new ArrayList<>(stateTreeList.size());
        filterTrees(opts, stateTreeList, trees, numberer, istrain);

        for (Tree<State> tree : trees) {
            Tree<String> parsedTree = mrParser.parse(tree);
            if (!saveTree(tree, parsedTree, numberer, idx)) {
                nUnparsable++;
            }
            idx++; // index the State tree
        }
        String summary = scorer.display();
        logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
        logger.trace(summary + "\n\n");
        return summary;
    }


    /**
     * Un-annotates the gold and the parsed tree and passes both to the scorer; when
     * {@code opts.ef1imwrite} is set the trees are also written out as images before and after
     * un-annotation.
     *
     * @param tree       the gold state tree
     * @param parsedTree the tree produced by the parser
     * @param idx        index of the sample, used in file names and log lines
     * @return {@code true} if the pair was evaluated, {@code false} if any step threw
     */
    public static boolean saveTree(Tree<State> tree, Tree<String> parsedTree, Numberer numberer, int idx) {
        try {
            Tree<String> goldTree = null;
            if (opts.ef1imwrite) {
                String treename = treeFile + "_gd_" + idx;
                goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
                FunUtil.saveTree2image(null, treename, goldTree, numberer);
                goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
                FunUtil.saveTree2image(null, treename + "_ua", goldTree, numberer);

                treename = treeFile + "_te_" + idx;
                FunUtil.saveTree2image(null, treename, parsedTree, numberer);
                parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
                FunUtil.saveTree2image(null, treename + "_ua", parsedTree, numberer);
            } else {
                goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
                goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
                parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
            }
            scorer.evaluate(parsedTree, goldTree);
            logger.trace(idx + "\tgold : " + goldTree + "\n");
            logger.trace(idx + "\tparsed: " + parsedTree + "\n");
            return true;
        } catch (Exception e) {
            // a sample that cannot be converted or scored is counted as unparsable by the caller
            e.printStackTrace();
            return false;
        }
    }
}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/PCFGMaxRule.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
import edu.berkeley.nlp.syntax.Tree;
import edu.shanghaitech.ai.nlp.data.StateTreeList;
import edu.shanghaitech.ai.nlp.data.ObjectFileManager.GrammarFile;
import edu.shanghaitech.ai.nlp.eval.EnglishPennTreebankParseEvaluator;
import edu.shanghaitech.ai.nlp.lveg.impl.PCFGMaxRuleParser;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar;
import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon;
import edu.shanghaitech.ai.nlp.optimization.Optimizer;
import edu.shanghaitech.ai.nlp.syntax.State;
import edu.shanghaitech.ai.nlp.util.FunUtil;
import edu.shanghaitech.ai.nlp.util.Numberer;
import edu.shanghaitech.ai.nlp.util.OptionParser;
import edu.shanghaitech.ai.nlp.util.ThreadPool;

/**
 * Command-line driver that loads a saved grammar/lexicon pair, parses the loaded tree sets with
 * {@link PCFGMaxRuleParser} and logs the labeled-constituent scores reported by the evaluator.
 * <p>
 * NOTE(review): generic type arguments were stripped in the copy of this file that was reviewed;
 * the ones below are restored from usage and should be confirmed against the repository.
 */
public class PCFGMaxRule extends LearnerConfig {
    /**
     *
     */
    private static final long serialVersionUID = 8232031232691463175L;

    protected static EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> scorer;
    protected static StateTreeList trainTrees;
    protected static StateTreeList testTrees;
    protected static StateTreeList devTrees;

    protected static LVeGGrammar grammar;
    protected static LVeGLexicon lexicon;

    protected static PCFGMaxRuleParser<?, ?> pcfgMrParser;
    protected static ThreadPool mparser;

    protected static String treeFile;
    protected static Options opts;

    /**
     * Reads the comma-separated option list from the file named by {@code args[0]}, initializes
     * the configuration, loads the data and runs the evaluation.
     *
     * @param args {@code args[0]} is the path of the parameter file
     * @throws IllegalArgumentException if no parameter file is given
     * @throws IOException              if the parameter file cannot be read
     * @throws Exception                anything raised while loading data or evaluating
     */
    public static void main(String[] args) throws Exception {
        if (args == null || args.length == 0) {
            throw new IllegalArgumentException("Missing argument: path of the parameter file");
        }
        String fparams = args[0];
        try {
            args = readFile(fparams, StandardCharsets.UTF_8).split(",");
        } catch (IOException e) {
            // Previously the error was only printed and the run went on with the file path itself
            // as the option list; without the parameter file there is nothing sensible to parse.
            throw new IOException("Cannot read the parameter file '" + fparams + "'", e);
        }
        OptionParser optionParser = new OptionParser(Options.class);
        opts = (Options) optionParser.parse(args, true);
        // configurations
        initialize(opts, true); // logger can only be used after the initialization
        logger.info("Calling with " + optionParser.getParsedOptions() + "\n");

        // loading data
        Numberer wrapper = new Numberer();
        Map<String, StateTreeList> trees = loadData(wrapper, opts);
        // training
        long startTime = System.currentTimeMillis();
        train(trees, wrapper);
        long endTime = System.currentTimeMillis();
        logger.trace("[total time consumed by LVeG tester] " + (endTime - startTime) / 1000.0 + "\n");
    }


    /**
     * Loads the grammar and lexicon named by the options, builds the parser and its thread pool,
     * and logs the scores for the test set and, when enabled, the training and dev sets.
     *
     * @param trees   tree sets keyed by {@code ID_TRAIN}, {@code ID_TEST} and {@code ID_DEV}
     * @param wrapper numberer from which the tag-set numberer is obtained
     * @throws Exception anything raised while loading the grammar or parsing
     */
    private static void train(Map<String, StateTreeList> trees, Numberer wrapper) throws Exception {
        trainTrees = trees.get(ID_TRAIN);
        testTrees = trees.get(ID_TEST);
        devTrees = trees.get(ID_DEV);

        treeFile = sublogroot + opts.imgprefix;

        Numberer numberer = wrapper.getGlobalNumberer(KEY_TAG_SET);

        /* to ease the parameters tuning */
        GaussianMixture.config(opts.maxnbig, opts.expzero, opts.maxmw, opts.ncomponent,
                opts.nwratio, opts.riserate, opts.rtratio, opts.hardcut, random, mogPool);
        GaussianDistribution.config(opts.maxmu, opts.maxvar, opts.dim, opts.nmratio, opts.nvratio, random, gaussPool);
        Optimizer.config(opts.choice, random, opts.maxsample, opts.bsize, opts.minmw, opts.sampling); // FIXME no errors, just alert you...

        // load grammar
        logger.trace("--->Loading grammars from \'" + subdatadir + opts.inGrammar + "\'...\n");
        GrammarFile gfile = (GrammarFile) GrammarFile.load(subdatadir + opts.inGrammar);
        grammar = gfile.getGrammar();
        lexicon = gfile.getLexicon();

        /*
        logger.trace(grammar);
        logger.trace(lexicon);
        System.exit(0);
        */

        lexicon.labelTrees(trainTrees); // FIXME no errors, just alert you to pay attention to it
        lexicon.labelTrees(testTrees); // save the search time cost by finding a specific tag-word
        lexicon.labelTrees(devTrees); // pair in in Lexicon.score(...)

        scorer = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
                new HashSet<String>(Arrays.asList(new String[] { "ROOT", "PSEUDO" })),
                new HashSet<String>(Arrays.asList(new String[] { "''", "``", ".", ":", "," })));
        pcfgMrParser = new PCFGMaxRuleParser<Tree<State>, Tree<String>>(grammar, lexicon, opts.maxslen,
                opts.ntcyker, opts.pcyker, opts.ef1prune, false);
        mparser = new ThreadPool(pcfgMrParser, opts.nttest);

        try {
            logger.info("\n---F1 CONFIG---\n[parallel: batch-" + opts.pbatch + ", grad-" +
                    opts.pgrad + ", eval-" + opts.peval + ", test-" + opts.pf1 + "]\n\n");

            sorter = new PriorityQueue<>(opts.bsize + 5, wcomparator);

            StringBuilder sb = new StringBuilder();
            sb.append("[test ]" + f1entry(testTrees, numberer, false) + "\n");
            if (opts.ef1ontrain) {
                scorer.reset();
                sb.append("[train]" + f1entry(trainTrees, numberer, true) + "\n");
            }
            if (opts.ef1ondev) {
                scorer.reset();
                sb.append("[dev ]" + f1entry(devTrees, numberer, false) + "\n");
            }
            logger.info("[summary]\n" + sb.toString() + "\n");
        } finally {
            // kill threads -- also when scoring fails, so they are not left running
            grammar.shutdown();
            lexicon.shutdown();
            mparser.shutdown();
        }
    }


    /**
     * Scores one tree set, using the thread pool when {@code opts.pf1} is set and the single
     * parser otherwise.
     *
     * @return the summary string produced by the scorer
     */
    public static String f1entry(StateTreeList trees, Numberer numberer, boolean istrain) {
        return opts.pf1
                ? parallelFscore(opts, mparser, trees, numberer, istrain)
                : serialFscore(opts, pcfgMrParser, trees, numberer, istrain);
    }


    /**
     * Parses the filtered trees through the thread pool and feeds every result to the scorer.
     * <p>
     * NOTE(review): the i-th result taken from the pool is paired with the i-th filtered tree,
     * i.e. this relies on the pool handing results back in submission order -- confirm.
     *
     * @return the summary string produced by the scorer
     */
    public static String parallelFscore(Options opts, ThreadPool mparser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {
        Tree<State> goldTree = null;
        Tree<String> parsedTree = null;
        int nUnparsable = 0, idx = 0;
        List<Tree<State>> trees = new ArrayList<>(stateTreeList.size());
        filterTrees(opts, stateTreeList, trees, numberer, istrain);

        for (Tree<State> tree : trees) {
            mparser.execute(tree);
            // consume whatever has finished so far while more work is being submitted
            while (mparser.hasNext()) {
                goldTree = trees.get(idx);
                parsedTree = (Tree<String>) mparser.getNext();
                if (!saveTree(goldTree, parsedTree, numberer, idx)) {
                    nUnparsable++;
                }
                idx++;
            }
        }
        // everything is submitted; keep polling until the pool reports completion
        while (!mparser.isDone()) {
            while (mparser.hasNext()) {
                goldTree = trees.get(idx);
                parsedTree = (Tree<String>) mparser.getNext();
                if (!saveTree(goldTree, parsedTree, numberer, idx)) {
                    nUnparsable++;
                }
                idx++;
            }
        }
        mparser.reset();
        String summary = scorer.display();
        logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
        logger.trace(summary + "\n\n");
        return summary;
    }


    /**
     * Parses the filtered trees one after another with {@code mrParser} and feeds every result
     * to the scorer.
     *
     * @return the summary string produced by the scorer
     */
    public static String serialFscore(Options opts, PCFGMaxRuleParser<?, ?> mrParser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {

        int nUnparsable = 0, idx = 0;

        List<Tree<State>> trees = new ArrayList<>(stateTreeList.size());
        filterTrees(opts, stateTreeList, trees, numberer, istrain);

        for (Tree<State> tree : trees) {
            Tree<String> parsedTree = mrParser.parse(tree);
            if (!saveTree(tree, parsedTree, numberer, idx)) {
                nUnparsable++;
            }
            idx++; // index the State tree
        }
        String summary = scorer.display();
        logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
        logger.trace(summary + "\n\n");
        return summary;
    }


    /**
     * Un-annotates the gold and the parsed tree and passes both to the scorer; when
     * {@code opts.ef1imwrite} is set the trees are also written out as images before and after
     * un-annotation.
     *
     * @param tree       the gold state tree
     * @param parsedTree the tree produced by the parser
     * @param idx        index of the sample, used in file names and log lines
     * @return {@code true} if the pair was evaluated, {@code false} if any step threw
     */
    public static boolean saveTree(Tree<State> tree, Tree<String> parsedTree, Numberer numberer, int idx) {
        try {
            Tree<String> goldTree = null;
            if (opts.ef1imwrite) {
                String treename = treeFile + "_gd_" + idx;
                goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
                FunUtil.saveTree2image(null, treename, goldTree, numberer);
                goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
                FunUtil.saveTree2image(null, treename + "_ua", goldTree, numberer);

                treename = treeFile + "_te_" + idx;
                FunUtil.saveTree2image(null, treename, parsedTree, numberer);
                parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
                FunUtil.saveTree2image(null, treename + "_ua", parsedTree, numberer);
            } else {
                goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
                goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
                parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
            }
            scorer.evaluate(parsedTree, goldTree);
            logger.trace(idx + "\tgold : " + goldTree + "\n");
            logger.trace(idx + "\tparsed: " + parsedTree + "\n");
            return true;
        } catch (Exception e) {
            // a sample that cannot be converted or scored is counted as unparsable by the caller
            e.printStackTrace();
            return false;
        }
    }
}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/BinaryGrammarRule.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg.impl;

import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule;

/**
 * A grammar rule with one left-hand side and two right-hand side nonterminals. Identity
 * ({@link #equals(Object)}, {@link #hashCode()}, {@link #compareTo(Object)}) is defined by the
 * three nonterminal ids; the weight does not take part in it.
 *
 * @author Yanpeng Zhao
 *
 */
public class BinaryGrammarRule extends GrammarRule implements Comparable<Object> {
    /**
     *
     */
    private static final long serialVersionUID = 7092883784519182728L;
    /**
     * the IDs of the two right-hand side nonterminals
     */
    public short lchild;
    public short rchild;


    /** Creates the rule without a weight. */
    public BinaryGrammarRule(short lhs, short lchild, short rchild) {
        this.lhs = lhs;
        this.lchild = lchild;
        this.rchild = rchild;
        this.type = RuleType.LRBRULE;
    }


    /**
     * Creates the rule and, if {@code init} is set, gives it a weight created by
     * {@link #initializeWeight(RuleType, short, short)} with {@code -1} for both sizes.
     */
    public BinaryGrammarRule(short lhs, short lchild, short rchild, boolean init) {
        this(lhs, lchild, rchild);
        if (init) { initializeWeight(RuleType.LRBRULE, (short) -1, (short) -1); }
    }


    /** Creates the rule with the given weight (the reference is stored, not copied). */
    public BinaryGrammarRule(short lhs, short lchild, short rchild, GaussianMixture weight) {
        this(lhs, lchild, rchild);
        this.weight = weight;
    }


    /**
     * Replaces the weight with one created by {@code rndRuleWeight}. The {@code type} argument
     * is ignored: a binary rule always asks for an {@code LRBRULE} weight.
     */
    @Override
    public void initializeWeight(RuleType type, short ncomponent, short ndim) {
        weight = rndRuleWeight(RuleType.LRBRULE, ncomponent, ndim);
    }


    /** Returns a new rule with the same nonterminals and a deep copy of the weight. */
    @Override
    public GrammarRule copy() {
        BinaryGrammarRule rule = new BinaryGrammarRule(lhs, lchild, rchild);
        rule.weight = weight.copy(true);
        return rule;
    }


    @Override
    public boolean isUnary() {
        return false;
    }


    @Override
    public int hashCode() {
        return (lhs << 16) ^ (lchild << 8) ^ (rchild);
    }


    @Override
    public boolean equals(Object o) {
        if (this == o) { return true; }

        if (o instanceof BinaryGrammarRule) {
            BinaryGrammarRule rule = (BinaryGrammarRule) o;
            if (lhs == rule.lhs && lchild == rule.lchild && rchild == rule.rchild && type == rule.type) {
                return true;
            }
        }
        return false;
    }


    /**
     * Orders rules by left-hand side, then left child, then right child.
     *
     * @throws ClassCastException if {@code o} is not a {@code BinaryGrammarRule}
     */
    @Override
    public int compareTo(Object o) {
        BinaryGrammarRule rule = (BinaryGrammarRule) o;
        if (lhs < rule.lhs) { return -1; }
        if (lhs > rule.lhs) { return 1; }
        if (lchild < rule.lchild) { return -1; }
        if (lchild > rule.lchild) { return 1; }
        if (rchild < rule.rchild) { return -1; }
        if (rchild > rule.rchild) { return 1; }
        return 0;
    }


    @Override
    public String toString() {
        return "B-Rule [P: " + lhs +", LC: " + lchild + ", RC: " + rchild + "]";
    }

}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/DiagonalGaussianMixture.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg.impl;

import
java.util.EnumMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;

import edu.shanghaitech.ai.nlp.lveg.LVeGTrainer;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit;
import edu.shanghaitech.ai.nlp.util.FunUtil;

/**
 * This one only differs from the GaussianMixture when in the need of creating the new specific instance.
 * <p>
 * NOTE(review): generic type arguments were stripped in the copy of this file that was reviewed;
 * the ones below are restored from usage and should be confirmed against the repository.
 *
 * @author Yanpeng Zhao
 *
 */
public class DiagonalGaussianMixture extends GaussianMixture {
    /**
     *
     */
    private static final long serialVersionUID = 1083077972374093199L;

    /** Creates a mixture declared with zero components. */
    public DiagonalGaussianMixture() {
        super((short) 0);
    }

    /** Creates a mixture with {@code ncomponent} components and initializes them. */
    public DiagonalGaussianMixture(short ncomponent) {
        super(ncomponent);
        initialize();
    }

    /** Creates a mixture with {@code ncomponent} components, initialized only if {@code init}. */
    public DiagonalGaussianMixture(short ncomponent, boolean init) {
        super(ncomponent);
        if (init) { initialize(); }
    }

    @Override
    public GaussianMixture instance(short ncomponent, boolean init) {
        return new DiagonalGaussianMixture(ncomponent, init);
    }

    /**
     * Creates a mixture from explicit parts: the i-th component is built from {@code weights[i]}
     * and {@code mixture[i]}, so {@code mixture} needs at least as many entries as
     * {@code weights}.
     */
    public DiagonalGaussianMixture(
            short ncomponent, List<Double> weights, List<EnumMap<RuleUnit, Set<GaussianDistribution>>> mixture) {
        this();
        this.ncomponent = ncomponent;
        for (int i = 0; i < weights.size(); i++) {
            this.components.add(new Component((short) i, weights.get(i), mixture.get(i)));
        }
    }


    /**
     * Borrows a mixture from the default object pool; if borrowing fails, the failure is logged
     * and a freshly constructed mixture is returned instead ({@code -1} selects the default
     * number of components).
     *
     * @param ncomponent key handed to the pool, i.e. the requested number of components
     * @return a mixture, pooled or newly created
     */
    public static DiagonalGaussianMixture borrowObject(short ncomponent) {
        GaussianMixture obj = null;
        try {
            obj = defObjectPool.borrowObject(ncomponent);
        } catch (Exception e) {
            logger.error("---------Borrow GM " + e + "\n");
            try {
                // NOTE(review): 'obj' is still null here because the borrow itself failed, and
                // this is LVeGTrainer.mogPool rather than defObjectPool -- verify the call is needed.
                LVeGTrainer.mogPool.invalidateObject(ncomponent, obj);
            } catch (Exception e1) {
                // report the failure of the invalidation itself (the outer 'e' was logged before)
                logger.error("---------Borrow GM(invalidate) " + e1 + "\n");
            }
            ncomponent = ncomponent == -1 ? defNcomponent : ncomponent;
            obj = new DiagonalGaussianMixture(ncomponent);
        }
        return (DiagonalGaussianMixture) obj;
    }

    /**
     * Adds {@code ncomponent} components, each with weight 0 and an empty map of Gaussians.
     */
    @Override
    protected void initialize() {
        for (int i = 0; i < ncomponent; i++) {
            // the random draw is overwritten just below, but it still advances defRnd
            double weight = (defRnd.nextDouble() - defNegWRatio) * defMaxmw;
            EnumMap<RuleUnit, Set<GaussianDistribution>> multivnd = new EnumMap<>(RuleUnit.class);
            weight = /*-0.69314718056*/ 0; // mixing weight 0.5, 1, 2
            components.add(new Component((short) i, weight, multivnd));
        }
    }

    @Override
    public DiagonalGaussianMixture copy(boolean deep) {
        DiagonalGaussianMixture gm = new DiagonalGaussianMixture();
        copy(gm, deep);
        return gm;
    }


    @Override
    public DiagonalGaussianMixture replaceKeys(EnumMap<RuleUnit, RuleUnit> keys) {
        DiagonalGaussianMixture gm = new DiagonalGaussianMixture();
        replaceKeys(gm, keys);
        return gm;
    }


    @Override
    public DiagonalGaussianMixture replaceAllKeys(RuleUnit newkey) {
        DiagonalGaussianMixture gm = new DiagonalGaussianMixture();
        replaceAllKeys(gm, newkey);
        return gm;
    }


    @Override
    public DiagonalGaussianMixture multiply(GaussianMixture multiplier) {
        DiagonalGaussianMixture gm = new DiagonalGaussianMixture();
        multiply(gm, multiplier);
        return gm;
    }

    /**
     * For every component of the destination that has a Gaussian under {@code key}: adds to its
     * weight the log-sum, over the components of {@code gm}, of that component's weight plus the
     * result of {@code mulAndMarginalize} between the two Gaussians, then removes {@code key}
     * from the component.
     *
     * @param gm   the mixture to combine with
     * @param des  destination; if {@code null}, a deep copy of {@code this} (when {@code deep})
     *             or {@code this} itself is used; if given and {@code deep}, it is cleared and
     *             overwritten with a deep copy of {@code this}
     * @param key  the portion to be marginalized
     * @param deep whether to work on a copy instead of on the given/this object
     * @return the destination mixture
     */
    @Override
    public GaussianMixture mulAndMarginalize(GaussianMixture gm, GaussianMixture des, RuleUnit key, boolean deep) {
        // 'des' is exactly the same as 'this' when deep is false
        if (des != null) { // placeholder
            if (deep) {
                des.clear(false);
                copy(des, true);
            }
        } else { // new memo space
            des = deep ? copy(true) : this;
        }
        // the following is the general case
        for (Component comp : des.components()) {
            double logsum = Double.NEGATIVE_INFINITY;
            GaussianDistribution gd = comp.squeeze(key);
            // w(ROOT->X) has no P portion in computing outside score
            if (gd == null) { continue; }
            for (Component comp1 : gm.components()) {
                GaussianDistribution gd1 = comp1.squeeze(null);
                double logcomp = comp1.getWeight() + gd.mulAndMarginalize(gd1);
                logsum = FunUtil.logAdd(logsum, logcomp);
            }
            comp.setWeight(comp.getWeight() + logsum);
            comp.getMultivnd().remove(key);
        }
        return des;
    }

    /**
     * Log-adds, over all components, the component weight plus the sum over its portions of
     * {@link #mulAndMarginalize(GaussianMixture, GaussianDistribution)} applied to the mixture in
     * {@code counts} for that portion and the first Gaussian stored under it.
     *
     * @param counts one mixture per portion; NOTE(review): a portion missing from this map would
     *               pass {@code null} to the static helper -- confirm callers always supply it
     * @return the accumulated value, or negative infinity if {@code counts} is {@code null}
     */
    @Override
    public double mulAndMarginalize(EnumMap<RuleUnit, GaussianMixture> counts) {
        if (counts == null) { return Double.NEGATIVE_INFINITY; }
        double values = Double.NEGATIVE_INFINITY;
        for (Component comp : components) {
            double value = 0.0, vtmp = 0.0;
            for (Entry<RuleUnit, Set<GaussianDistribution>> node : comp.getMultivnd().entrySet()) {
                vtmp = 0.0;
                GaussianMixture gm = counts.get(node.getKey());
                // only the first Gaussian of the set is used
                for (GaussianDistribution gd : node.getValue()) {
                    vtmp = mulAndMarginalize(gm, gd); // head (tail) variable & outside (inside) score
                    break;
                }
                value += vtmp;
            }
            value += comp.getWeight();
            values = FunUtil.logAdd(values, value);
        }
        return values;
    }


    /**
     * Log-adds, over the components of {@code gm}, the component weight plus the result of
     * {@code gd.mulAndMarginalize} with the component's squeezed Gaussian.
     *
     * @return the accumulated value; negative infinity if {@code gm} has no components
     */
    public static double mulAndMarginalize(GaussianMixture gm, GaussianDistribution gd) {
        double value = Double.NEGATIVE_INFINITY, vtmp;
        for (Component comp : gm.components()) {
            GaussianDistribution ios = comp.squeeze(null);
            vtmp = gd.mulAndMarginalize(ios);
            vtmp += comp.getWeight();
            value = FunUtil.logAdd(value, vtmp);
        }
        return value;
    }

}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/GaussFactory.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg.impl;

import java.util.List;
import java.util.Random;

import org.apache.commons.pool2.KeyedPooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;

import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;

/**
 * Keyed pool factory for {@link GaussianDistribution} objects; the key is the number of
 * dimensions, with {@code -1} standing for the default passed to the constructor.
 * <p>
 * NOTE(review): generic type arguments were stripped in the copy of this file that was reviewed;
 * the ones below are restored from usage and should be confirmed against the repository.
 */
public class GaussFactory implements KeyedPooledObjectFactory<Short, GaussianDistribution> {

    protected short ndimension;
    protected double nmratio;
    protected double nvratio;
    protected double maxvar;
    protected double maxmu;
    protected Random rnd;

    /**
     * @param ndimension default number of dimensions, used for key {@code -1}
     * @param maxmu      scale of the random mean values
     * @param maxvar     scale of the random variance values
     * @param nmratio    offset subtracted from the uniform draw for means
     * @param nvratio    offset subtracted from the uniform draw for variances
     * @param rnd        source of randomness
     */
    public GaussFactory(short ndimension, double maxmu , double maxvar, double nmratio, double nvratio, Random rnd) {
        this.ndimension = ndimension;
        this.nmratio = nmratio;
        this.nvratio = nvratio;
        this.maxvar = maxvar;
        this.maxmu = maxmu;
        this.rnd = rnd;
    }

    /**
     * Appends {@code ndim} random values to the object's mean list and then {@code ndim} random
     * values to its variance list; each value is {@code (uniform - ratio) * max}.
     */
    @Override
    public void activateObject(Short key, PooledObject<GaussianDistribution> po) throws Exception {
        short ndim = key == -1 ? ndimension : key;
        List<Double> mus = po.getObject().getMus();
        for (int i = 0; i < ndim; i++) {
            double rndn = (rnd.nextDouble() - nmratio) * maxmu;
            // rndn = 0.5;
            mus.add(rndn);
        }
        List<Double> vars = po.getObject().getVars();
        for (int i = 0; i < ndim; i++) {
            double rndn = (rnd.nextDouble() - nvratio) * maxvar;
            // rndn = 0.5;
            vars.add(rndn);
        }
    }

    @Override
    public void destroyObject(Short key, PooledObject<GaussianDistribution> po) throws Exception {
        po.getObject().destroy(key);
    }

    /** Creates an uninitialized diagonal Gaussian and records the pool key on it. */
    @Override
    public PooledObject<GaussianDistribution> makeObject(Short key) throws Exception {
        short ndim = key == -1 ? ndimension : key;
        GaussianDistribution gauss = new DiagonalGaussianDistribution(ndim, false);
        gauss.setKey(key);
        return new DefaultPooledObject<GaussianDistribution>(gauss);
    }

    @Override
    public void passivateObject(Short key, PooledObject<GaussianDistribution> po) throws Exception {
        po.getObject().clear(key);
    }

    @Override
    public boolean validateObject(Short key, PooledObject<GaussianDistribution> po) {
        GaussianDistribution obj = po.getObject();
        return obj != null && obj.isValid(key);
    }

}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/GeneralGaussianDistribution.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg.impl;

import java.util.List;

import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;

/**
 * Placeholder subclass of {@link GaussianDistribution}: evaluation and marginalization return
 * constant zero and initialization does nothing.
 */
public class GeneralGaussianDistribution extends GaussianDistribution {
    /**
     *
     */
    private static final long serialVersionUID = 3855732607147998162L;

    public GeneralGaussianDistribution() {
        super((short) 0);
    }

    public GeneralGaussianDistribution(short ndimension) {
        super(ndimension);
        initialize();
    }

    public GeneralGaussianDistribution(short ndimension, boolean init) {
        super(ndimension);
        if (init) { initialize(); }
    }

    @Override
    public GaussianDistribution instance(short ndimension, boolean init) {
        return new GeneralGaussianDistribution(ndimension, init);
    }

    /** Not implemented: always returns {@code 0.0}. */
    protected double eval(List<Double> sample, boolean normal) {
        return 0.0;
    }

    /** Not implemented: always returns {@code 0}. */
    @Override
    public double mulAndMarginalize(GaussianDistribution gd) {
        // TODO Auto-generated method stub
        return 0;
    }

    /** Not implemented: leaves the object untouched. */
    @Override
    protected void initialize() {
        // TODO Auto-generated method stub

    }

}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/GeneralGaussianMixture.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg.impl;

import java.util.EnumMap;

import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit;

/**
 * Placeholder subclass of {@link GaussianMixture}: initialization does nothing and both
 * {@code mulAndMarginalize} variants return constants.
 * <p>
 * NOTE(review): the type arguments of {@code EnumMap} were stripped in the copy of this file that
 * was reviewed; they are restored from the sibling implementation and should be confirmed.
 */
public class GeneralGaussianMixture extends GaussianMixture {
    /**
     *
     */
    private static final long serialVersionUID = -6052750581733945498L;

    public GeneralGaussianMixture() {
        super((short) 0);
    }

    public GeneralGaussianMixture(short ncomponent) {
        super(ncomponent);
        initialize();
    }

    public GeneralGaussianMixture(short ncomponent, boolean init) {
        super(ncomponent);
        if (init) { initialize(); }
    }

    @Override
    public GaussianMixture instance(short ncomponent, boolean init) {
        return new GeneralGaussianMixture(ncomponent, init);
    }

    /** Not implemented: always returns {@code 0}. */
    @Override
    public double mulAndMarginalize(EnumMap<RuleUnit, GaussianMixture> counts) {
        return 0;
    }

    /** Not implemented: adds no components. */
    @Override
    protected void initialize() {
    }

    /** Not implemented: always returns {@code null}, so callers must not dereference the result. */
    @Override
    public GaussianMixture mulAndMarginalize(GaussianMixture gm, GaussianMixture des, RuleUnit key, boolean deep) {
        return null;
    }

}
-------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/MaxRuleParser.java: --------------------------------------------------------------------------------
package edu.shanghaitech.ai.nlp.lveg.impl;

import java.util.List;
import java.util.Set;

import edu.berkeley.nlp.syntax.Tree;
import edu.shanghaitech.ai.nlp.data.StateTreeList;
import edu.shanghaitech.ai.nlp.lveg.model.ChartCell.Chart;
import edu.shanghaitech.ai.nlp.lveg.LVeGTrainer;
import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 11 | import edu.shanghaitech.ai.nlp.lveg.model.Inferencer; 12 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon; 13 | import edu.shanghaitech.ai.nlp.lveg.model.Parser; 14 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 15 | import edu.shanghaitech.ai.nlp.syntax.State; 16 | 17 | public class MaxRuleParser extends Parser { 18 | /** 19 | * 20 | */ 21 | private static final long serialVersionUID = 514004588461969299L; 22 | private MaxRuleInferencer inferencer; 23 | 24 | 25 | private MaxRuleParser(MaxRuleParser parser) { 26 | super(parser.maxslen, parser.nthread, parser.parallel, parser.iosprune, parser.usemask); 27 | this.inferencer = parser.inferencer; 28 | this.chart = new Chart(parser.maxslen, true, true, parser.usemask); 29 | this.usestag = parser.usestag; 30 | this.masks = parser.masks; 31 | } 32 | 33 | 34 | public MaxRuleParser(LVeGGrammar grammar, LVeGLexicon lexicon, short maxLenParsing, short nthread, 35 | boolean parallel, boolean iosprune, boolean usemasks, boolean usestag, Set[][][] masks) { 36 | super(maxLenParsing, nthread, parallel, iosprune, usemasks); 37 | this.inferencer = new MaxRuleInferencer(grammar, lexicon); 38 | this.chart = new Chart(maxLenParsing, true, true, usemasks); 39 | this.usestag = usestag; 40 | this.masks = masks; 41 | } 42 | 43 | 44 | @Override 45 | public MaxRuleParser newInstance() { 46 | return new MaxRuleParser(this); 47 | } 48 | 49 | 50 | @Override 51 | public synchronized Object call() throws Exception { 52 | Tree sample = (Tree) task; 53 | Tree parsed = parse(sample, itask); 54 | Meta cache = new Meta(itask, parsed); 55 | synchronized (caches) { 56 | caches.add(cache); 57 | caches.notify(); 58 | } 59 | task = null; 60 | return itask; 61 | } 62 | 63 | 64 | /** 65 | * Dedicated to error handling while recovering the recorded best parse path. 
66 | * 67 | * @param tree the golden parse tree 68 | * @return parse tree given the sentence 69 | */ 70 | public Tree parse(Tree tree, int itree) { 71 | Tree parsed = null; 72 | try { // do NOT expect it to crash 73 | boolean valid = evalMaxRuleCount(tree, itree); 74 | if (valid) { 75 | parsed = StateTreeList.stateTreeToStringTree(tree, Inferencer.grammar.numberer); 76 | parsed = Inferencer.extractBestMaxRuleParse(chart, parsed.getYield()); 77 | } else { 78 | parsed = new Tree(Inferencer.DUMMY_TAG); 79 | } 80 | } catch (Exception e) { 81 | parsed = new Tree(Inferencer.DUMMY_TAG); 82 | e.printStackTrace(); 83 | } 84 | return parsed; 85 | } 86 | 87 | 88 | /** 89 | * Compute pseudo counts of grammar rules, and find a best parse path. 90 | * 91 | * @param tree the golden parse tree 92 | * @return whether the sentence can be parsed (true) of not (false) 93 | */ 94 | private boolean evalMaxRuleCount(Tree tree, int itree) { 95 | List sentence = tree.getYield(); 96 | int nword = sentence.size(); 97 | double scoreS = doInsideOutside(tree, sentence, itree, nword); 98 | // logger.trace("\nInside scores with the sentence...\n\n"); // DEBUG 99 | // FunUtil.debugChart(chart.getChart(true), (short) -1, tree.getYield().size()); // DEBUG 100 | // logger.trace("\nOutside scores with the sentence...\n\n"); // DEBUG 101 | // FunUtil.debugChart(chart.getChart(false), (short) -1, tree.getYield().size()); // DEBUG 102 | 103 | if (Double.isFinite(scoreS)) { 104 | // logger.trace("\nEval rule count with the sentence...\n"); // DEBUG 105 | inferencer.evalMaxRuleCount(chart, sentence, nword, scoreS); 106 | return true; 107 | } 108 | return false; 109 | } 110 | 111 | 112 | /** 113 | * Inside/outside score calculation is required by MaxRule parser. 
114 | * 115 | * @param tree the golden parse tree 116 | * @param sentence the sentence need to be parsed 117 | * @param nword length of the sentence 118 | * @return 119 | */ 120 | private double doInsideOutside(Tree tree, List sentence, int itree, int nword) { 121 | List goldentag = usestag ? tree.getPreTerminalYield() : null; 122 | if (chart != null) { 123 | chart.clear(nword); 124 | } else { 125 | chart = new Chart(nword, true, true, usemask); 126 | } 127 | if (usemask) { 128 | boolean status = usemask; 129 | if (masks != null) { // obtained from kbest PCFG parsing 130 | status = createKbestMask(nword, chart, masks[itree]); 131 | } 132 | if (!status || masks == null) { // posterior probability 133 | createPCFGMask(nword, chart, sentence); 134 | } 135 | } 136 | if (parallel) { 137 | cpool.reset(); 138 | Inferencer.insideScore(chart, sentence, nword, iosprune, cpool); 139 | Inferencer.setRootOutsideScore(chart); 140 | cpool.reset(); 141 | Inferencer.outsideScore(chart, sentence, nword, iosprune, cpool); 142 | } else { 143 | // logger.trace("\nInside score...\n"); // DEBUG 144 | Inferencer.insideScore(chart, sentence, goldentag, nword, iosprune, usemask, LVeGTrainer.iomask); 145 | // FunUtil.debugChart(chart.getChart(true), (short) -1, tree.getYield().size()); // DEBUG 146 | 147 | Inferencer.setRootOutsideScore(chart); 148 | // logger.trace("\nOutside score...\n"); // DEBUG 149 | Inferencer.outsideScore(chart, sentence, nword, iosprune, usemask, LVeGTrainer.iomask); 150 | // FunUtil.debugChart(chart.getChart(false), (short) -1, tree.getYield().size()); // DEBUG 151 | } 152 | double scoreS = Double.NEGATIVE_INFINITY; 153 | GaussianMixture score = chart.getInsideScore((short) 0, Chart.idx(0, 1)); 154 | if (score != null) { 155 | scoreS = score.eval(null, true); 156 | } else if (usemask && masks != null) { // re-parse using pcfg pruning, assuming k-best pruning fails 157 | chart.clear(nword); 158 | 159 | createPCFGMask(nword, chart, sentence); 160 | 161 | if (parallel) 
{ 162 | cpool.reset(); 163 | Inferencer.insideScore(chart, sentence, nword, iosprune, cpool); 164 | Inferencer.setRootOutsideScore(chart); 165 | cpool.reset(); 166 | Inferencer.outsideScore(chart, sentence, nword, iosprune, cpool); 167 | } else { 168 | Inferencer.insideScore(chart, sentence, goldentag, nword, iosprune, usemask, LVeGTrainer.iomask); 169 | Inferencer.setRootOutsideScore(chart); 170 | Inferencer.outsideScore(chart, sentence, nword, iosprune, usemask, LVeGTrainer.iomask); 171 | } 172 | score = chart.getInsideScore((short) 0, Chart.idx(0, 1)); 173 | if (score != null) { 174 | scoreS = score.eval(null, true); 175 | } 176 | } 177 | return scoreS; 178 | } 179 | 180 | 181 | private static boolean createKbestMask(int nword, Chart chart, Set[][] mask) { 182 | int len = mask.length, idx, layer; 183 | if (nword != len) { return false; } 184 | for (int i = 0; i < len; i++) { 185 | for (int j = i; j < len; j++) { 186 | layer = nword - j + i; // nword - (j - i) 187 | idx = layer * (layer - 1) / 2 + i; // (nword - 1 + 1)(nword - 1) / 2 188 | for (String label : mask[i][j]) { 189 | short ikey = (short) Inferencer.grammar.numberer.number(label); 190 | chart.addPosteriorMask(ikey, idx); 191 | } 192 | } 193 | } 194 | return true; 195 | } 196 | 197 | 198 | private static void createPCFGMask(int nword, Chart chart, List sentence) { 199 | PCFGInferencer.insideScore(chart, sentence, nword, LVeGTrainer.iomask, LVeGTrainer.tgBase, LVeGTrainer.tgRatio); 200 | PCFGInferencer.setRootOutsideScore(chart); 201 | PCFGInferencer.outsideScore(chart, sentence, nword, LVeGTrainer.iomask, LVeGTrainer.tgBase, LVeGTrainer.tgRatio); 202 | if (!LVeGTrainer.iomask) { // not use inside/outside score masks 203 | double score = chart.getInsideScoreMask((short) 0, Chart.idx(0, 1)); 204 | PCFGInferencer.createPosteriorMask(nword, chart, score, LVeGTrainer.tgProb); 205 | } 206 | } 207 | 208 | } 209 | -------------------------------------------------------------------------------- 
/src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/MoGFactory.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.EnumMap; 4 | import java.util.List; 5 | import java.util.Random; 6 | import java.util.Set; 7 | 8 | import org.apache.commons.pool2.KeyedPooledObjectFactory; 9 | import org.apache.commons.pool2.PooledObject; 10 | import org.apache.commons.pool2.impl.DefaultPooledObject; 11 | 12 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution; 13 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 14 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture.Component; 15 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 16 | 17 | public class MoGFactory implements KeyedPooledObjectFactory { 18 | 19 | protected short ncomponent; 20 | protected double nwratio; 21 | protected double maxmw; 22 | protected Random rnd; 23 | 24 | public MoGFactory(short ncomponent, double maxmw, double nwratio, Random rnd) { 25 | this.ncomponent = ncomponent; 26 | this.nwratio = nwratio; 27 | this.maxmw = maxmw; 28 | this.rnd = rnd; 29 | } 30 | 31 | /** 32 | * https://commons.apache.org/proper/commons-pool/apidocs/index.html?org/apache/commons/pool2/class-use/PooledObject.html 33 | * The above link tells us that activateObject(K, org.apache.commons.pool2.PooledObject) is invoked on every instance 34 | * that has been passivated before it is borrowed from the pool, which is DEFINITELY NOT consist with the codes. In line 35 | * 395 of GenericKeyedObjectPool, every non-null empty object, including the newly-created one, will be passed 36 | * into {@code activateObject(K, T)}. TODO modify and rebuild our own pool2 library. 37 | */ 38 | @Override 39 | public void activateObject(Short key, PooledObject po) throws Exception { 40 | short ncomp = key == -1 ? 
ncomponent : key; 41 | List components = po.getObject().components(); 42 | for (int i = 0; i < ncomp; i++) { 43 | double weight = (rnd.nextDouble() - nwratio) * maxmw; 44 | EnumMap> multivnd = new EnumMap<>(RuleUnit.class); 45 | // weight = /*-0.69314718056*/ 0; // mixing weight 0.5, 1, 2 46 | components.add(new Component((short) i, weight, multivnd)); 47 | } 48 | } 49 | 50 | @Override 51 | public void destroyObject(Short key, PooledObject po) throws Exception { 52 | po.getObject().destroy(key); 53 | } 54 | 55 | @Override 56 | public PooledObject makeObject(Short key) throws Exception { 57 | short ncomp = key == -1 ? ncomponent : key; 58 | GaussianMixture mog = new DiagonalGaussianMixture(ncomp, false); 59 | mog.setKey(key); 60 | return new DefaultPooledObject(mog); 61 | } 62 | 63 | @Override 64 | public void passivateObject(Short key, PooledObject po) throws Exception { 65 | po.getObject().clear(key); 66 | } 67 | 68 | @Override 69 | public boolean validateObject(Short key, PooledObject po) { 70 | GaussianMixture obj = po.getObject(); 71 | return obj != null && obj.isValid(key); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/PCFGMaxRuleParser.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.List; 4 | 5 | import edu.berkeley.nlp.syntax.Tree; 6 | import edu.shanghaitech.ai.nlp.data.StateTreeList; 7 | import edu.shanghaitech.ai.nlp.lveg.model.ChartCell.Chart; 8 | import edu.shanghaitech.ai.nlp.lveg.model.Inferencer; 9 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 10 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon; 11 | import edu.shanghaitech.ai.nlp.lveg.model.Parser; 12 | import edu.shanghaitech.ai.nlp.syntax.State; 13 | import edu.shanghaitech.ai.nlp.util.Executor; 14 | 15 | public class PCFGMaxRuleParser extends Parser { 16 | /** 17 | * 
18 | */ 19 | private static final long serialVersionUID = -2668797817745391783L; 20 | private PCFGMaxRuleInferencer inferencer; 21 | 22 | 23 | private PCFGMaxRuleParser(PCFGMaxRuleParser parser) { 24 | super(parser.maxslen, parser.nthread, parser.parallel, parser.iosprune, parser.usemask); 25 | this.inferencer = parser.inferencer; // shared by multiple threads 26 | this.chart = new Chart(parser.maxslen, false, true, true); 27 | } 28 | 29 | 30 | public PCFGMaxRuleParser(LVeGGrammar grammar, LVeGLexicon lexicon, short maxslen, short nthread, 31 | boolean parallel, boolean iosprune, boolean usemasks) { 32 | super(maxslen, nthread, parallel, iosprune, usemasks); 33 | this.inferencer = new PCFGMaxRuleInferencer(grammar, lexicon); 34 | this.chart = new Chart(maxslen, false, true, true); 35 | } 36 | 37 | 38 | @Override 39 | public Executor newInstance() { 40 | return new PCFGMaxRuleParser(this); 41 | } 42 | 43 | 44 | @Override 45 | public synchronized Object call() throws Exception { 46 | Tree sample = (Tree) task; 47 | Tree parsed = parse(sample); 48 | Meta cache = new Meta(itask, parsed); 49 | synchronized (caches) { 50 | caches.add(cache); 51 | caches.notify(); 52 | } 53 | task = null; 54 | return itask; 55 | } 56 | 57 | 58 | /** 59 | * Dedicated to error handling while recovering the recorded best parse path. 
60 | * 61 | * @param tree the golden parse tree 62 | * @return parse tree given the sentence 63 | */ 64 | public Tree parse(Tree tree) { 65 | Tree parsed = null; 66 | try { // do NOT expect it to crash 67 | boolean valid = evalMaxRuleCount(tree); 68 | if (valid) { 69 | parsed = StateTreeList.stateTreeToStringTree(tree, Inferencer.grammar.numberer); 70 | parsed = Inferencer.extractBestMaxRuleParse(chart, parsed.getYield()); 71 | } else { 72 | parsed = new Tree(Inferencer.DUMMY_TAG); 73 | } 74 | } catch (Exception e) { 75 | parsed = new Tree(Inferencer.DUMMY_TAG); 76 | e.printStackTrace(); 77 | } 78 | return parsed; 79 | } 80 | 81 | 82 | private boolean evalMaxRuleCount(Tree tree) { 83 | List sentence = tree.getYield(); 84 | int nword = sentence.size(); 85 | double scoreS = doInsideOutside(tree, sentence, nword); 86 | // logger.trace("\nInside scores with the sentence...\n\n"); // DEBUG 87 | // FunUtil.debugChart(chart.getChart(true), (short) -1, tree.getYield().size()); // DEBUG 88 | // logger.trace("\nOutside scores with the sentence...\n\n"); // DEBUG 89 | // FunUtil.debugChart(chart.getChart(false), (short) -1, tree.getYield().size()); // DEBUG 90 | 91 | if (Double.isFinite(scoreS)) { 92 | // logger.trace("\nEval rule count with the sentence...\n"); // DEBUG 93 | inferencer.evalMaxRuleProdCount(chart, sentence, nword, scoreS); 94 | // inferencer.evalMaxRuleSumCount(chart, sentence, nword, scoreS); 95 | return true; 96 | } 97 | return false; 98 | } 99 | 100 | 101 | private double doInsideOutside(Tree tree, List sentence, int nword) { 102 | if (chart != null) { 103 | chart.clear(nword); 104 | } else { 105 | chart = new Chart(nword, false, true, true); 106 | } 107 | 108 | // logger.trace("\nInside score...\n"); // DEBUG 109 | PCFGInferencer.insideScore(chart, sentence, nword, false, -1, -1); 110 | // FunUtil.debugChart(chart.getChart(true), (short) -1, tree.getYield().size()); // DEBUG 111 | 112 | PCFGInferencer.setRootOutsideScore(chart); 113 | // 
logger.trace("\nOutside score...\n"); // DEBUG 114 | PCFGInferencer.outsideScore(chart, sentence, nword, false, -1, -1); 115 | // FunUtil.debugChart(chart.getChart(false), (short) -1, tree.getYield().size()); // DEBUG 116 | 117 | double scoreS = Double.NEGATIVE_INFINITY; 118 | if (chart.containsKeyMask((short) 0, Chart.idx(0, 1), true)) { 119 | scoreS = chart.getInsideScoreMask((short) 0, Chart.idx(0, 1)); 120 | } 121 | return scoreS; 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/PCFGParser.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.List; 4 | 5 | import edu.berkeley.nlp.syntax.Tree; 6 | import edu.shanghaitech.ai.nlp.data.StateTreeList; 7 | import edu.shanghaitech.ai.nlp.lveg.model.ChartCell.Chart; 8 | import edu.shanghaitech.ai.nlp.lveg.model.Inferencer; 9 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 10 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon; 11 | import edu.shanghaitech.ai.nlp.lveg.model.Parser; 12 | import edu.shanghaitech.ai.nlp.syntax.State; 13 | import edu.shanghaitech.ai.nlp.util.Executor; 14 | 15 | public class PCFGParser extends Parser { 16 | /** 17 | * 18 | */ 19 | private static final long serialVersionUID = -2668797817745391783L; 20 | private PCFGInferencer inferencer; 21 | 22 | 23 | private PCFGParser(PCFGParser parser) { 24 | super(parser.maxslen, parser.nthread, parser.parallel, parser.iosprune, parser.usemask); 25 | this.inferencer = parser.inferencer; // shared by multiple threads 26 | this.chart = new Chart(parser.maxslen, false, true, false); 27 | } 28 | 29 | 30 | public PCFGParser(LVeGGrammar grammar, LVeGLexicon lexicon, short maxslen, short nthread, 31 | boolean parallel, boolean iosprune, boolean usemasks) { 32 | super(maxslen, nthread, parallel, iosprune, usemasks); 33 | this.inferencer = 
new PCFGInferencer(grammar, lexicon); 34 | this.chart = new Chart(maxslen, false, true, false); 35 | } 36 | 37 | 38 | @Override 39 | public Executor newInstance() { 40 | return new PCFGParser(this); 41 | } 42 | 43 | 44 | @Override 45 | public synchronized Object call() throws Exception { 46 | Tree sample = (Tree) task; 47 | Tree parsed = parse(sample); 48 | Meta cache = new Meta(itask, parsed); 49 | synchronized (caches) { 50 | caches.add(cache); 51 | caches.notify(); 52 | } 53 | task = null; 54 | return itask; 55 | } 56 | 57 | 58 | /** 59 | * Dedicated to error handling while recovering the recorded best parse path. 60 | * 61 | * @param tree the golden parse tree 62 | * @return parse tree given the sentence 63 | */ 64 | public Tree parse(Tree tree) { 65 | Tree parsed = null; 66 | try { // do NOT expect it to crash 67 | viterbiParse(tree); 68 | parsed = StateTreeList.stateTreeToStringTree(tree, Inferencer.grammar.numberer); 69 | parsed = Inferencer.extractBestMaxRuleParse(chart, parsed.getYield()); 70 | } catch (Exception e) { 71 | parsed = new Tree(Inferencer.DUMMY_TAG); 72 | e.printStackTrace(); 73 | } 74 | return parsed; 75 | } 76 | 77 | 78 | /** 79 | * Compute and record a viterbi parse path. 
80 | * 81 | * @param tree the golden parse tree 82 | */ 83 | private void viterbiParse(Tree tree) { 84 | List sentence = tree.getYield(); 85 | int nword = sentence.size(); 86 | if (chart != null) { 87 | chart.clear(nword); 88 | } else { 89 | chart = new Chart(nword, false, true, false); 90 | } 91 | inferencer.viterbiParse(chart, sentence, nword); 92 | } 93 | 94 | } 95 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/RuleTable.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 9 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 10 | 11 | /** 12 | * Map grammar rules to their counts. Implementation in the generic way could be better. 13 | * 14 | * @author Yanpeng Zhao 15 | * 16 | */ 17 | public class RuleTable implements Serializable { 18 | /** 19 | * 20 | */ 21 | private static final long serialVersionUID = 7379371632425330796L; 22 | Class type; 23 | Map table; 24 | 25 | 26 | public RuleTable(Class type) { 27 | this.table = new HashMap<>(); 28 | this.type = type; 29 | } 30 | 31 | 32 | public int size() { 33 | return table.size(); 34 | } 35 | 36 | 37 | public void clear() { 38 | table.clear(); 39 | } 40 | 41 | 42 | public boolean isEmpty() { 43 | return size() == 0; 44 | } 45 | 46 | 47 | public Set keySet() { 48 | return table.keySet(); 49 | } 50 | 51 | 52 | public boolean containsKey(GrammarRule key) { 53 | return table.containsKey(key); 54 | } 55 | 56 | 57 | /** 58 | * Type-specific instance. 
59 | * 60 | * @param key search keyword 61 | * @return 62 | * 63 | */ 64 | public boolean isCompatible(GrammarRule key) { 65 | return type.isInstance(key); 66 | } 67 | 68 | 69 | public GaussianMixture getCount(GrammarRule key) { 70 | return table.get(key); 71 | } 72 | 73 | 74 | public void setCount(GrammarRule key, GaussianMixture value) { 75 | if (isCompatible(key)) { 76 | table.put(key, value); 77 | } 78 | } 79 | 80 | 81 | public void addCount(GrammarRule key, double increment) { 82 | GaussianMixture count = getCount(key); 83 | if (count == null) { 84 | GaussianMixture gm = new DiagonalGaussianMixture(); 85 | gm.add(increment); 86 | setCount(key, gm); 87 | return; 88 | } 89 | count.add(increment); 90 | } 91 | 92 | 93 | public void addCount(GrammarRule key, GaussianMixture increment, boolean prune) { 94 | GaussianMixture count = getCount(key); 95 | if (count == null) { 96 | GaussianMixture gm = new DiagonalGaussianMixture(); 97 | gm.add(increment, prune); 98 | setCount(key, gm); 99 | return; 100 | } 101 | count.add(increment, prune); 102 | } 103 | 104 | 105 | /** 106 | * @param deep deep copy or shallow copy 107 | * @return 108 | */ 109 | @SuppressWarnings({ "rawtypes", "unchecked" }) 110 | public RuleTable copy(boolean deep) { 111 | RuleTable ruleTable = new RuleTable(type); 112 | for (GrammarRule rule : table.keySet()) { 113 | // copy key by reference, when only the count (value) varies 114 | if (!deep) { 115 | ruleTable.addCount(rule, null, false); 116 | } else { 117 | ruleTable.addCount(rule.copy(), table.get(rule).copy(true), false); 118 | } 119 | } 120 | return ruleTable; 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/UnaryGrammarRule.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 4 | import 
edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 5 | 6 | 7 | /** 8 | * @author Yanpeng Zhao 9 | * 10 | */ 11 | public class UnaryGrammarRule extends GrammarRule implements Comparable { 12 | /** 13 | * 14 | */ 15 | private static final long serialVersionUID = 7796142278063439713L; 16 | /** 17 | * the ID of the right-hand side nonterminal 18 | */ 19 | public int rhs; 20 | 21 | 22 | public UnaryGrammarRule(short lhs, int rhs) { 23 | this.lhs = lhs; 24 | this.rhs = rhs; 25 | this.type = RuleType.LRURULE; 26 | } 27 | 28 | 29 | public UnaryGrammarRule(short lhs, int rhs, RuleType type) { 30 | this.lhs = lhs; 31 | this.rhs = rhs; 32 | this.type = type; 33 | } 34 | 35 | 36 | public UnaryGrammarRule(short lhs, int rhs, RuleType type, boolean init) { 37 | this(lhs, rhs, type); 38 | if (init) { initializeWeight(type, (short) -1, (short) -1); } 39 | } 40 | 41 | 42 | public UnaryGrammarRule(short lhs, int rhs, RuleType type, GaussianMixture weight) { 43 | this(lhs, rhs, type); 44 | this.weight = weight; 45 | } 46 | 47 | 48 | @Override 49 | public void initializeWeight(RuleType type, short ncomponent, short ndim) { 50 | weight = rndRuleWeight(type, ncomponent, ndim); 51 | } 52 | 53 | 54 | @Override 55 | public GrammarRule copy() { 56 | UnaryGrammarRule rule = new UnaryGrammarRule(lhs, rhs); 57 | rule.weight = weight.copy(true); 58 | rule.type = type; 59 | return rule; 60 | } 61 | 62 | 63 | @Override 64 | public boolean isUnary() { 65 | return true; 66 | } 67 | 68 | 69 | @Override 70 | public int hashCode() { 71 | return (lhs << 18) ^ (rhs); 72 | } 73 | 74 | 75 | @Override 76 | public boolean equals(Object o) { 77 | if (this == o) { return true; } 78 | 79 | if (o instanceof UnaryGrammarRule) { 80 | UnaryGrammarRule rule = (UnaryGrammarRule) o; 81 | if (lhs == rule.lhs && rhs == rule.rhs && type == rule.type) { 82 | return true; 83 | } 84 | } 85 | return false; 86 | } 87 | 88 | 89 | @Override 90 | public int compareTo(Object o) { 91 | // TODO Auto-generated method stub 92 | 
UnaryGrammarRule rule = (UnaryGrammarRule) o; 93 | if (lhs < rule.lhs) { return -1; } 94 | if (lhs > rule.lhs) { return 1; } 95 | if (rhs < rule.rhs) { return -1; } 96 | if (rhs > rule.rhs) { return 1; } 97 | return 0; 98 | } 99 | 100 | 101 | @Override 102 | public String toString() { 103 | return "U-Rule [P: " + lhs +", UC: " + rhs + ", T: " + type + "]"; 104 | } 105 | 106 | } 107 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/Valuator.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.List; 4 | 5 | import edu.berkeley.nlp.syntax.Tree; 6 | import edu.shanghaitech.ai.nlp.lveg.model.ChartCell.Chart; 7 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 8 | import edu.shanghaitech.ai.nlp.lveg.model.Inferencer; 9 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon; 10 | import edu.shanghaitech.ai.nlp.lveg.model.Parser; 11 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 12 | import edu.shanghaitech.ai.nlp.syntax.State; 13 | 14 | public class Valuator extends Parser { 15 | /** 16 | * 17 | */ 18 | private static final long serialVersionUID = -1069775066639086250L; 19 | private LVeGInferencer inferencer; 20 | 21 | 22 | private Valuator(Valuator valuator) { 23 | super(valuator.maxslen, valuator.nthread, valuator.parallel, valuator.iosprune, false); 24 | this.inferencer = valuator.inferencer; 25 | this.chart = new Chart(valuator.maxslen, true, false, false); 26 | } 27 | 28 | 29 | public Valuator(LVeGGrammar grammar, LVeGLexicon lexicon, short maxLenParsing, short nthread, 30 | boolean parallel, boolean iosprune, boolean usemasks) { 31 | super(maxLenParsing, nthread, parallel, iosprune, false); 32 | this.inferencer = new LVeGInferencer(grammar, lexicon); 33 | this.chart = new Chart(maxLenParsing, true, false, false); 34 | } 35 | 36 | 37 | @Override 38 | public 
Valuator newInstance() { 39 | return new Valuator(this); 40 | } 41 | 42 | 43 | @Override 44 | public synchronized Object call() { 45 | Tree sample = (Tree) task; 46 | double ll = Double.NEGATIVE_INFINITY; 47 | synchronized (sample) { 48 | ll = probability(sample); 49 | } 50 | Meta cache = new Meta(itask, ll); 51 | synchronized (caches) { 52 | caches.add(cache); 53 | caches.notify(); 54 | } 55 | task = null; 56 | return itask; 57 | } 58 | 59 | 60 | /** 61 | * Compute \log p(t | s) = \log {p(t, s) / p(s)}, where s denotes the 62 | * sentence, t is the parse tree. 63 | * 64 | * @param tree the parse tree 65 | * @return logarithmic conditional probability of the parse tree given the sentence 66 | */ 67 | public double probability(Tree tree) { 68 | double ll = Double.NEGATIVE_INFINITY; 69 | try { // do NOT except it to crash 70 | double jointdist = scoreTree(tree); 71 | double partition = scoreSentence(tree); 72 | ll = jointdist - partition; 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | } 76 | return ll; 77 | } 78 | 79 | 80 | /** 81 | * Compute p(t, s), where s denotes the sentence, t is a parse tree. 82 | * 83 | * @param tree the parse tree 84 | * @return score of the parse tree 85 | */ 86 | protected double scoreTree(Tree tree) { 87 | LVeGInferencer.insideScoreWithTree(tree); 88 | double scoreT = Double.NEGATIVE_INFINITY; 89 | GaussianMixture score = tree.getLabel().getInsideScore(); 90 | if (score != null) { 91 | scoreT = score.eval(null, true); 92 | } 93 | return scoreT; 94 | } 95 | 96 | 97 | /** 98 | * Compute \sum_{t \in T} p(t, s), where T is the space of the parse tree. 
99 | * 100 | * @param tree in which only the sentence is used 101 | * @return the sentence score 102 | */ 103 | protected double scoreSentence(Tree tree) { 104 | List sentence = tree.getYield(); 105 | int nword = sentence.size(); 106 | if (chart != null) { 107 | chart.clear(nword); 108 | } else { 109 | chart = new Chart(nword, true, false, false); 110 | } 111 | if (parallel) { 112 | cpool.reset(); 113 | Inferencer.insideScore(chart, sentence, nword, iosprune, cpool); 114 | } else { 115 | Inferencer.insideScore(chart, sentence, null, nword, iosprune, false, false); 116 | } 117 | double scoreS = Double.NEGATIVE_INFINITY; 118 | GaussianMixture score = chart.getInsideScore((short) 0, Chart.idx(0, 1)); 119 | if (score != null) { 120 | scoreS = score.eval(null, true); 121 | } 122 | return scoreS; 123 | } 124 | 125 | } 126 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/model/GrammarRule.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.model; 2 | 3 | import java.io.Serializable; 4 | import java.util.EnumMap; 5 | import java.util.HashSet; 6 | import java.util.Set; 7 | 8 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianDistribution; 9 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianMixture; 10 | 11 | /** 12 | * @author Yanpeng Zhao 13 | * 14 | */ 15 | public abstract class GrammarRule implements Serializable { 16 | /** 17 | * 18 | */ 19 | private static final long serialVersionUID = -1464935410648668068L; 20 | public enum RuleType { 21 | LRURULE(0), LHSPACE(1), RHSPACE(2), LRBRULE(3); 22 | private final int id; 23 | 24 | private RuleType(int id) { 25 | this.id = id; 26 | } 27 | 28 | public String id() { 29 | return String.valueOf(id); 30 | } 31 | 32 | @Override 33 | public String toString() { 34 | return String.valueOf(id); 35 | } 36 | } 37 | 38 | public enum RuleUnit { 39 | P, C, LC, RC, UC, RM 40 
| } 41 | 42 | /** 43 | * the ID of the left-hand side nonterminal 44 | */ 45 | public short lhs; 46 | public RuleType type; 47 | public GaussianMixture weight; 48 | 49 | /* 50 | public final static byte LRBRULE = 3; // left and right hand sides, binary rule 51 | public final static byte RHSPACE = 2; 52 | public final static byte LHSPACE = 1; 53 | public final static byte LRURULE = 0; // left and right hand sides, unary rule 54 | */ 55 | 56 | /* 57 | public static class Unit { 58 | public final static String P = "p"; 59 | public final static String C = "c"; 60 | public final static String LC = "lc"; 61 | public final static String RC = "rc"; 62 | public final static String UC = "uc"; 63 | public final static String RM = "rm"; 64 | } 65 | */ 66 | 67 | public GrammarRule() { 68 | // TODO 69 | } 70 | 71 | public abstract boolean isUnary(); 72 | public abstract GrammarRule copy(); 73 | public abstract void initializeWeight(RuleType type, short ncomponent, short ndim); 74 | 75 | 76 | public void addWeightComponent(RuleType type, short increment, short ndim) { 77 | short defNcomp = increment > 0 ? increment : GaussianMixture.defNcomponent; 78 | short defNdim = ndim > 0 ? ndim : GaussianDistribution.defNdimension; 79 | if (weight == null) { 80 | weight = rndRuleWeight(type, defNcomp, defNdim); 81 | } else { 82 | GaussianMixture aweight = new DiagonalGaussianMixture(defNcomp); 83 | rndRuleWeight(type, defNcomp, defNdim, aweight); 84 | weight.add(aweight, false); 85 | weight.rectifyId(); // required 86 | } 87 | } 88 | 89 | 90 | public static GaussianMixture rndRuleWeight(RuleType type, short ncomponent, short ndim) { 91 | short defNcomp = ncomponent > 0 ? ncomponent : GaussianMixture.defNcomponent; 92 | short defNdim = ndim > 0 ? 
ndim : GaussianDistribution.defNdimension; 93 | GaussianMixture aweight = new DiagonalGaussianMixture(defNcomp); 94 | rndRuleWeight(type, defNcomp, defNdim, aweight); 95 | return aweight; 96 | } 97 | 98 | 99 | private static void rndRuleWeight(RuleType type, short ncomponent, short dim, GaussianMixture weight) { 100 | switch (type) { 101 | case RHSPACE: // rules for the root since it does not have subtypes 102 | for (int i = 0; i < ncomponent; i++) { 103 | Set set = new HashSet<>(1, 1); 104 | set.add(new DiagonalGaussianDistribution(dim)); 105 | weight.add(i, RuleUnit.C, set); 106 | } 107 | break; 108 | case LHSPACE: // rules in the preterminal layer (discarded) 109 | for (int i = 0; i < ncomponent; i++) { 110 | Set set = new HashSet<>(1, 1); 111 | set.add(new DiagonalGaussianDistribution(dim)); 112 | weight.add(i, RuleUnit.P, set); 113 | } 114 | break; 115 | case LRURULE: // general unary rules 116 | for (int i = 0; i < ncomponent; i++) { 117 | EnumMap> map = new EnumMap<>(RuleUnit.class); 118 | Set set0 = new HashSet<>(1, 1); 119 | Set set1 = new HashSet<>(1, 1); 120 | set0.add(new DiagonalGaussianDistribution(dim)); 121 | set1.add(new DiagonalGaussianDistribution(dim)); 122 | map.put(RuleUnit.P, set0); 123 | map.put(RuleUnit.UC, set1); 124 | weight.add(i, map); 125 | } 126 | break; 127 | case LRBRULE: // general binary rules 128 | for (int i = 0; i < ncomponent; i++) { 129 | EnumMap> map = new EnumMap<>(RuleUnit.class); 130 | Set set0 = new HashSet<>(1, 1); 131 | Set set1 = new HashSet<>(1, 1); 132 | Set set2 = new HashSet<>(1, 1); 133 | set0.add(new DiagonalGaussianDistribution(dim)); 134 | set1.add(new DiagonalGaussianDistribution(dim)); 135 | set2.add(new DiagonalGaussianDistribution(dim)); 136 | map.put(RuleUnit.P, set0); 137 | map.put(RuleUnit.LC, set1); 138 | map.put(RuleUnit.RC, set2); 139 | weight.add(i, map); 140 | } 141 | break; 142 | default: 143 | throw new RuntimeException("Not consistent with any grammar rule type. 
Type: " + type); 144 | } 145 | } 146 | 147 | 148 | public RuleType getType() { 149 | return type; 150 | } 151 | 152 | 153 | public void setType(RuleType type) { 154 | this.type = type; 155 | } 156 | 157 | 158 | public short getLhs() { 159 | return lhs; 160 | } 161 | 162 | 163 | public void setLhs(short lhs) { 164 | this.lhs = lhs; 165 | } 166 | 167 | 168 | public GaussianMixture getWeight() { 169 | return weight; 170 | } 171 | 172 | 173 | public void setWeight(GaussianMixture weight) { 174 | this.weight = weight; 175 | } 176 | 177 | } 178 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/model/LVeGGrammar.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.model; 2 | 3 | import java.io.Serializable; 4 | import java.util.EnumMap; 5 | import java.util.List; 6 | import java.util.Map; 7 | import java.util.Set; 8 | 9 | import edu.berkeley.nlp.syntax.Tree; 10 | import edu.shanghaitech.ai.nlp.lveg.LVeGTrainer; 11 | import edu.shanghaitech.ai.nlp.lveg.impl.BinaryGrammarRule; 12 | import edu.shanghaitech.ai.nlp.lveg.impl.RuleTable; 13 | import edu.shanghaitech.ai.nlp.lveg.impl.UnaryGrammarRule; 14 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleType; 15 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 16 | import edu.shanghaitech.ai.nlp.optimization.Optimizer; 17 | import edu.shanghaitech.ai.nlp.syntax.State; 18 | import edu.shanghaitech.ai.nlp.util.Numberer; 19 | import edu.shanghaitech.ai.nlp.util.Recorder; 20 | 21 | public abstract class LVeGGrammar extends Recorder implements Serializable { 22 | /** 23 | * 24 | */ 25 | private static final long serialVersionUID = 5874243553526905936L; 26 | protected RuleTable uRuleTable; 27 | protected RuleTable bRuleTable; 28 | 29 | protected List[] uRulesWithP; 30 | protected List[] uRulesWithC; 31 | 32 | protected List[] bRulesWithP; 33 | protected List[] 
bRulesWithLC; 34 | protected List[] bRulesWithRC; 35 | 36 | protected Optimizer optimizer; 37 | public Numberer numberer; 38 | public int ntag; 39 | 40 | /** 41 | * Needed when we want to find a rule and access its statistics. 42 | * we first construct a rule, which is used as the key, and use 43 | * the key to find the real rule that contains more information. 44 | */ 45 | protected Map uRuleMap; 46 | protected Map bRuleMap; 47 | 48 | /** 49 | * For any nonterminals A \neq B \neq C, p(A->B) is computed as 50 | * p(A->B) + \sum_{C} p(A->C) \times p(C->B), in which p(A->B) 51 | * is zero if A->B does not exist, and the resulting new rules 52 | * are added to the unary rule set. Fields containing 'sum' are 53 | * dedicated to the general CYK algorithm, and are dedicated to 54 | * to the Viterbi algorithm if they contain 'Max'. However, the 55 | * point is how to define the maximum between two MoGs. 56 | */ 57 | protected Set chainSumUnaryRules; 58 | protected List[] chainSumUnaryRulesWithP; 59 | protected List[] chainSumUnaryRulesWithC; 60 | 61 | protected abstract void initialize(); 62 | 63 | public abstract void postInitialize(); 64 | 65 | public abstract void initializeOptimizer(); 66 | 67 | public abstract void tallyStateTree(Tree tree); 68 | 69 | public abstract void addURule(UnaryGrammarRule rule); 70 | 71 | public void addBRule(BinaryGrammarRule rule) {} 72 | 73 | /** 74 | * @param idParent id of the left hand side of the binary rule 75 | * @param idlChild id of the left child in the binary rule 76 | * @param idrChild id of the right child in the binary rule 77 | * @param context null can be returned (true) or not (false), which can be used to check if the given query rule is valid or not 78 | * @return 79 | */ 80 | public GaussianMixture getBRuleWeight(short idParent, short idlChild, short idrChild, boolean context) { 81 | GrammarRule rule = getBRule(idParent, idlChild, idrChild); 82 | if (rule != null) { 83 | return rule.getWeight(); 84 | } 85 | if 
(!context) { 86 | // when calculating inside and outside scores, we do not want the rule weight to be null, so just set it to zero 87 | // if the given query rule is not valid (never appears in the training set). 88 | logger.warn("\nBinary Rule NOT Found: [P: " + idParent + ", LC: " + idlChild + ", RC: " + idrChild + "]\n"); 89 | GaussianMixture weight = GrammarRule.rndRuleWeight(RuleType.LRBRULE, (short) -1, (short) -1); 90 | /*weight.setWeights(Double.NEGATIVE_INFINITY);*/ 91 | weight.setWeights(LVeGTrainer.minmw); 92 | return weight; 93 | } else { // 94 | return null; 95 | } 96 | } 97 | 98 | public GaussianMixture getURuleWeight(short idParent, short idChild, RuleType type, boolean context) { 99 | GrammarRule rule = getURule(idParent, idChild, type); 100 | if (rule != null) { 101 | return rule.getWeight(); 102 | } 103 | if (!context) { 104 | logger.warn("\nUnary Rule NOT Found: [P: " + idParent + ", UC: " + idChild + ", TYPE: " + type + "]\n"); 105 | GaussianMixture weight = GrammarRule.rndRuleWeight(type, (short) -1, (short) -1); 106 | /*weight.setWeights(Double.NEGATIVE_INFINITY);*/ 107 | weight.setWeights(LVeGTrainer.minmw); 108 | return weight; 109 | } else { 110 | return null; 111 | } 112 | } 113 | 114 | public GrammarRule getBRule(short idParent, short idlChild, short idrChild) { 115 | GrammarRule rule = new BinaryGrammarRule(idParent, idlChild, idrChild); 116 | return bRuleMap.get(rule); 117 | } 118 | 119 | public GrammarRule getURule(short idParent, int idChild, RuleType type) { 120 | GrammarRule rule = new UnaryGrammarRule(idParent, idChild, type); 121 | return uRuleMap.get(rule); 122 | } 123 | 124 | public Map getBRuleMap() { 125 | return bRuleMap; 126 | } 127 | 128 | public Map getURuleMap() { 129 | return uRuleMap; 130 | } 131 | 132 | public List getChainSumUnaryRulesWithC(int itag) { 133 | return chainSumUnaryRulesWithC[itag]; 134 | } 135 | 136 | public List getChainSumUnaryRulesWithP(int itag) { 137 | return chainSumUnaryRulesWithP[itag]; 138 | } 
139 | 140 | public List getBRuleWithRC(int itag) { 141 | return bRulesWithRC[itag]; 142 | } 143 | 144 | public List getBRuleWithLC(int itag) { 145 | return bRulesWithLC[itag]; 146 | } 147 | 148 | public List getBRuleWithP(int itag) { 149 | return bRulesWithP[itag]; 150 | } 151 | 152 | public List getURuleWithP(int itag) { 153 | return uRulesWithP[itag]; 154 | } 155 | 156 | public List getURuleWithC(int itag) { 157 | return uRulesWithC[itag]; 158 | } 159 | 160 | public boolean containsBRule(short idParent, short idlChild, short idrChild) { 161 | GrammarRule rule = new BinaryGrammarRule(idParent, idlChild, idrChild); 162 | return bRuleTable.containsKey(rule); 163 | } 164 | 165 | public boolean containsURule(short idParent, int idChild, RuleType type) { 166 | GrammarRule rule = new UnaryGrammarRule(idParent, idChild, type); 167 | return uRuleTable.containsKey(rule); 168 | } 169 | 170 | public void addCount(short idParent, short idlChild, short idrChild, EnumMap count, short isample, boolean withTree) { 171 | GrammarRule rule = getBRule(idParent, idlChild, idrChild); 172 | addCount(rule, count, isample, withTree); 173 | } 174 | 175 | public Map>> getCount(short idParent, short idlChild, short idrChild, boolean withTree) { 176 | GrammarRule rule = getBRule(idParent, idlChild, idrChild); 177 | return getCount(rule, withTree); 178 | } 179 | 180 | public void addCount(short idParent, int idChild, EnumMap count, RuleType type, short isample, boolean withTree) { 181 | GrammarRule rule = new UnaryGrammarRule(idParent, idChild, type); 182 | addCount(rule, count, isample, withTree); 183 | } 184 | 185 | public Map>> getCount(short idParent, int idChild, boolean withTree, RuleType type) { 186 | GrammarRule rule = new UnaryGrammarRule(idParent, idChild, type); 187 | return getCount(rule, withTree); 188 | } 189 | 190 | public void addCount(GrammarRule rule, EnumMap count, short isample, boolean withTree) { 191 | optimizer.addCount(rule, count, isample, withTree); 192 | } 193 | 194 
| public Map>> getCount(GrammarRule rule, boolean withTree) { 195 | return optimizer.getCount(rule, withTree); 196 | } 197 | 198 | public void setOptimizer(Optimizer optimizer) { 199 | this.optimizer = optimizer; 200 | } 201 | 202 | public Optimizer getOptimizer() { 203 | return optimizer; 204 | } 205 | 206 | public void evalGradients(List scoreOfST) { 207 | optimizer.evalGradients(scoreOfST); 208 | } 209 | 210 | /** 211 | * Apply stochastic gradient descent. 212 | */ 213 | public void applyGradientDescent(List scoreOfST) { 214 | optimizer.applyGradientDescent(scoreOfST); 215 | } 216 | 217 | /** 218 | * Get the set of the rules. 219 | */ 220 | public Set getRuleSet() { 221 | return optimizer.getRuleSet(); 222 | } 223 | 224 | public void shutdown() { 225 | if (optimizer != null) { 226 | optimizer.shutdown(); 227 | } 228 | } 229 | 230 | } 231 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/lveg/model/Parser.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.model; 2 | 3 | import java.util.PriorityQueue; 4 | import java.util.Set; 5 | 6 | import edu.shanghaitech.ai.nlp.lveg.model.ChartCell.Chart; 7 | import edu.shanghaitech.ai.nlp.lveg.model.Inferencer.InputToSubCYKer; 8 | import edu.shanghaitech.ai.nlp.lveg.model.Inferencer.SubCYKer; 9 | import edu.shanghaitech.ai.nlp.util.Executor; 10 | import edu.shanghaitech.ai.nlp.util.Recorder; 11 | import edu.shanghaitech.ai.nlp.util.ThreadPool; 12 | 13 | public abstract class Parser extends Recorder implements Executor { 14 | /** 15 | * 16 | */ 17 | private static final long serialVersionUID = -7112164011234304607L; 18 | protected short maxslen = 120; 19 | 20 | protected int idx; 21 | protected short nthread; 22 | protected boolean parallel; 23 | protected boolean iosprune; 24 | protected boolean usemask; 25 | protected boolean usestag; // use golden tags in the sentence 26 | 27 
| protected I task; 28 | protected int itask; 29 | protected Chart chart; 30 | protected PriorityQueue> caches; 31 | protected Set[][][] masks; 32 | 33 | protected transient ThreadPool cpool; 34 | /** Stores the parsing options; a negative nthread is replaced by 1, and the worker pool is only created when parallel is true. */ 35 | protected Parser(short maxslen, short nthread, boolean parallel, boolean iosprune, boolean usemask) { 36 | this.maxslen = maxslen; 37 | this.iosprune = iosprune; 38 | this.usemask = usemask; 39 | this.usestag = false; 40 | this.parallel = parallel; 41 | this.nthread = nthread < 0 ? 1 : nthread; 42 | if (parallel) { 43 | SubCYKer subCYKer = new SubCYKer(); 44 | this.cpool = new ThreadPool(subCYKer, this.nthread); /* size the pool with the sanitized field, not the raw argument */ 45 | } 46 | } 47 | 48 | @Override 49 | public void setIdx(int idx, PriorityQueue> caches) { 50 | this.idx = idx; 51 | this.caches = caches; 52 | } 53 | 54 | @Override 55 | public void setNextTask(int itask, I task) { 56 | this.task = task; 57 | this.itask = itask; 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/optimization/Batch.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.EnumMap; 6 | import java.util.HashMap; 7 | import java.util.List; 8 | import java.util.Map; 9 | import java.util.Set; 10 | 11 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 13 | 14 | public class Batch implements Serializable { 15 | /** 16 | * 17 | */ 18 | private static final long serialVersionUID = -5018573183821755031L; 19 | protected Map>> batch; 20 | 21 | /** 22 | * @param maxsize initial capacity of the backing map; the default HashMap capacity is used when it is zero or negative 23 | */ 24 | public Batch(int maxsize) { 25 | if (maxsize > 0) { 26 | batch = new HashMap<>(maxsize, 1); 27 | } else { 28 | batch = new HashMap<>(); 29 | } 30 | } 31 |
32 | protected void add(short idx, EnumMap cnt) { 33 | List> cnts = null; 34 | if ((cnts = batch.get(idx)) != null) { 35 | cnts.add(cnt); 36 | } else { 37 | cnts = new ArrayList<>(); 38 | cnts.add(cnt); 39 | batch.put(idx, cnts); 40 | } 41 | } 42 | 43 | protected List> get(short i) { 44 | return batch.get(i); 45 | } 46 | 47 | protected boolean containsKey(short i) { 48 | return batch.containsKey(i); 49 | } 50 | 51 | protected Set keySet() { 52 | return batch.keySet(); 53 | } 54 | 55 | protected void clear() { 56 | batch.clear(); 57 | } 58 | 59 | protected int size() { 60 | return batch.size(); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/optimization/Optimizer.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.io.Serializable; 4 | import java.util.EnumMap; 5 | import java.util.List; 6 | import java.util.Map; 7 | import java.util.Random; 8 | import java.util.Set; 9 | 10 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 11 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 13 | import edu.shanghaitech.ai.nlp.util.Recorder; 14 | 15 | /** 16 | * @author Yanpeng Zhao 17 | * 18 | */ 19 | public abstract class Optimizer extends Recorder implements Serializable { 20 | /** 21 | * 22 | */ 23 | private static final long serialVersionUID = -6185772433718644490L; 24 | public enum OptChoice { 25 | NORMALIZED, SGD, MOMENTUM, ADAGRAD, RMSPROP, ADADELTA, ADAM 26 | } 27 | /** 28 | * Pseudo counts of grammar rules given the parse tree (countWithT) or the sentence (countWithS). 29 | */ 30 | protected Map cntsWithT; 31 | protected Map cntsWithS; 32 | 33 | /** 34 | * ruleset contains all the rules that are need to be optimized, it is 35 | * used to quickly index the rules. 
36 | */ 37 | protected Set ruleSet; 38 | protected static Random rnd; 39 | protected static int maxsample = 1; 40 | protected static short batchsize = 1; 41 | protected static double minexp = Math.log(1e-6); 42 | protected static boolean sampling = false; 43 | protected static OptChoice choice = OptChoice.ADAM; 44 | 45 | /** 46 | * @param scoresST the parse tree scores and the sentence scores, interleaved. Documented as tree score at odd index and sentence score at even index, but SimpleMinimizer.optimize reads the tree score at index 2*i and the sentence score at index 2*i+1. TODO confirm the intended order. 47 | */ 48 | public abstract void evalGradients(List scoresST); 49 | 50 | /** 51 | * Stochastic gradient descent. 52 | * 53 | * @param scoresST the parse tree scores and the sentence scores, interleaved. Documented as tree score at odd index and sentence score at even index, but SimpleMinimizer.optimize reads the tree score at index 2*i and the sentence score at index 2*i+1. TODO confirm the intended order. 54 | */ 55 | public abstract void applyGradientDescent(List scoresST); 56 | 57 | /** 58 | * @param rule the rule that needs optimizing. 59 | */ 60 | public abstract void addRule(GrammarRule rule); 61 | 62 | protected abstract void reset(); 63 | 64 | public Object debug(GrammarRule rule, boolean debug) { return null; } 65 | public void shutdown() { /* NULL */ } 66 | 67 | /** 68 | * @param rule the grammar rule 69 | * @param cnt which contains 1) key RuleUnit.P maps to the outside score of the parent node 70 | * 2) key RuleUnit.UC/C (LC) maps to the inside score (of the left node) if the rule is unary (binary) 71 | * 3) key RuleUnit.RC maps to the inside score of the right node if the rule is binary, otherwise null 72 | * @param idx index of the sample in this batch 73 | * @param withT true to record the count in cntsWithT (counts given the parse tree), false to record it in cntsWithS (counts given the sentence) 74 | */ 75 | public void addCount(GrammarRule rule, EnumMap cnt, short idx, boolean withT) { 76 | Batch batch = null; 77 | Map cnts = withT ? cntsWithT : cntsWithS; 78 | if (rule != null && (batch = cnts.get(rule)) != null) { 79 | batch.add(idx, cnt); 80 | } else { 81 | logger.info("Not a valid grammar rule or the rule was not found. Rule: " + rule + "\n"); 82 | } 83 | } 84 | 85 | 86 | /** 87 | * The method for debugging. 
88 | * 89 | * @param rule the grammar rule 90 | * @param withT type of the expected count 91 | * @return 92 | */ 93 | public Map>> getCount(GrammarRule rule, boolean withT) { 94 | Batch batch = null; 95 | Map cnts = withT ? cntsWithT : cntsWithS; 96 | if (rule != null && (batch = cnts.get(rule)) != null) { 97 | return batch.batch; 98 | } else { 99 | logger.info("Not a valid grammar rule or the rule was not found. Rule: " + rule + "\n"); 100 | return null; 101 | } 102 | } 103 | 104 | 105 | /** 106 | * Sets the static configuration fields of Optimizer. Terrible data-structure design: object saving leaves out the static members of the object. 107 | * FIXME not an error, just a reminder to pay attention to it and improve it in the future. 108 | * Besides the parameters documented below, achoice is stored in choice, log(minweight) in minexp, and spling in sampling. 109 | * @param random the random number generator stored in rnd 110 | * @param msample the value stored in maxsample 111 | * @param bsize the value stored in batchsize 112 | */ 113 | public static void config(OptChoice achoice, Random random, int msample, short bsize, double minweight, boolean spling) { 114 | choice = achoice; 115 | rnd = random; 116 | batchsize = bsize; 117 | maxsample = msample; 118 | minexp = Math.log(minweight); 119 | sampling = spling; 120 | } 121 | 122 | 123 | /** 124 | * Get set of the rules. 125 | * 126 | * @return 127 | */ 128 | public Set getRuleSet() { 129 | return ruleSet; 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/optimization/SGDMinimizer.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * TODO figure out how the optimizers in the Stanford parser work. 
7 | * 8 | * @author Yanpeng Zhao 9 | * 10 | */ 11 | public class SGDMinimizer implements Serializable { 12 | /** 13 | * 14 | */ 15 | private static final long serialVersionUID = 2195142486292205901L; 16 | 17 | public SGDMinimizer() { 18 | 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/optimization/SimpleMinimizer.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.EnumMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.Map.Entry; 9 | import java.util.Random; 10 | 11 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 13 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleType; 14 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 15 | import edu.shanghaitech.ai.nlp.util.Recorder; 16 | 17 | /** 18 | * Naive. The dedicated implementation is used to enable the program to run asap. 19 | * 20 | * @author Yanpeng Zhao 21 | * 22 | */ 23 | public class SimpleMinimizer extends Recorder implements Serializable { 24 | /** 25 | * 26 | */ 27 | private static final long serialVersionUID = -4030933989131874669L; 28 | protected EnumMap> truths; 29 | protected EnumMap> sample; 30 | protected EnumMap> ggrads; 31 | protected List wgrads; 32 | protected double wgrad; 33 | 34 | 35 | /** 36 | * To avoid the excessive 'new' operations. 
37 | */ 38 | public SimpleMinimizer() { 39 | this.sample = new EnumMap>(RuleUnit.class); 40 | sample.put(RuleUnit.P, new ArrayList<>()); 41 | sample.put(RuleUnit.C, new ArrayList<>()); 42 | sample.put(RuleUnit.UC, new ArrayList<>()); 43 | sample.put(RuleUnit.LC, new ArrayList<>()); 44 | sample.put(RuleUnit.RC, new ArrayList<>()); 45 | this.truths = new EnumMap<>(RuleUnit.class); 46 | truths.put(RuleUnit.P, new ArrayList<>()); 47 | truths.put(RuleUnit.C, new ArrayList<>()); 48 | truths.put(RuleUnit.UC, new ArrayList<>()); 49 | truths.put(RuleUnit.LC, new ArrayList<>()); 50 | truths.put(RuleUnit.RC, new ArrayList<>()); 51 | this.wgrads = new ArrayList(); 52 | this.ggrads = new EnumMap>(RuleUnit.class); 53 | ggrads.put(RuleUnit.P, new ArrayList<>()); 54 | ggrads.put(RuleUnit.C, new ArrayList<>()); 55 | ggrads.put(RuleUnit.UC, new ArrayList<>()); 56 | ggrads.put(RuleUnit.LC, new ArrayList<>()); 57 | ggrads.put(RuleUnit.RC, new ArrayList<>()); 58 | } 59 | 60 | 61 | public SimpleMinimizer(Random random, int msample, short bsize) { 62 | this(); 63 | } 64 | 65 | 66 | /** 67 | * TODO revise the computation in logarithm. 
68 | * 69 | * @param scoreT in logarithmic form 70 | * @param scoreS in logarithmic form 71 | * @param ioScoreWithT 72 | * @param ioScoreWithS 73 | * @return 74 | */ 75 | private double derivateRuleWeight( 76 | double scoreT, 77 | double scoreS, 78 | List> ioScoreWithT, 79 | List> ioScoreWithS) { 80 | double countWithT = 0.0, countWithS = 0.0, cnt, part, dRuleW; 81 | if (ioScoreWithT != null) { 82 | for (Map iosWithT : ioScoreWithT) { 83 | cnt = 1.0; 84 | boolean found = false; 85 | for (Entry ios : iosWithT.entrySet()) { 86 | part = ios.getValue().evalInsideOutside(truths.get(ios.getKey()), false); 87 | cnt *= part; 88 | found = true; 89 | } 90 | if (found) { countWithT += cnt; } 91 | } 92 | } 93 | if (ioScoreWithS != null) { 94 | for (Map iosWithS : ioScoreWithS) { 95 | cnt = 1.0; 96 | boolean found = false; 97 | for (Entry ios : iosWithS.entrySet()) { 98 | part = ios.getValue().evalInsideOutside(truths.get(ios.getKey()), false); 99 | cnt *= part; 100 | found = true; 101 | } 102 | if (found) { countWithS += cnt; } 103 | } 104 | } 105 | dRuleW = Math.exp(Math.log(countWithS) - scoreS) - Math.exp(Math.log(countWithT) - scoreT); 106 | return dRuleW; 107 | } 108 | 109 | 110 | /** 111 | * @param rule 112 | * @param ioScoreWithT 113 | * @param ioScoreWithS 114 | * @param scoresOfSAndT 115 | */ 116 | public void optimize( 117 | GrammarRule rule, 118 | Batch ioScoreWithT, 119 | Batch ioScoreWithS, 120 | List scoresSandT) { 121 | int batchsize = scoresSandT.size() / 2; 122 | GaussianMixture ruleW = rule.getWeight(); 123 | List> iosWithT, iosWithS; 124 | boolean removed = false, cumulative, updated; 125 | double scoreT, scoreS, dRuleW; 126 | RuleType uRuleType = null; 127 | 128 | for (int icomponent = 0; icomponent < ruleW.ncomponent(); icomponent++) { 129 | updated = false; // 130 | for (short isample = 0; isample < Optimizer.maxsample; isample++) { 131 | switch (rule.getType()) { 132 | case LRBRULE: { 133 | sample(sample.get(RuleUnit.P), ruleW.dim(icomponent, 
RuleUnit.P)); 134 | sample(sample.get(RuleUnit.LC), ruleW.dim(icomponent, RuleUnit.LC)); 135 | sample(sample.get(RuleUnit.RC), ruleW.dim(icomponent, RuleUnit.RC)); 136 | break; 137 | } 138 | case LRURULE: { 139 | sample(sample.get(RuleUnit.P), ruleW.dim(icomponent, RuleUnit.P)); 140 | sample(sample.get(RuleUnit.UC), ruleW.dim(icomponent, RuleUnit.UC)); 141 | break; 142 | } 143 | case LHSPACE: { 144 | sample(sample.get(RuleUnit.P), ruleW.dim(icomponent, RuleUnit.P)); 145 | uRuleType = RuleType.LHSPACE; 146 | break; 147 | } 148 | case RHSPACE: { 149 | sample(sample.get(RuleUnit.C), ruleW.dim(icomponent, RuleUnit.C)); 150 | break; 151 | } 152 | default: { 153 | logger.error("Not a valid unary grammar rule.\n"); 154 | } 155 | } 156 | ruleW.restoreSample(icomponent, sample, truths); 157 | for (short i = 0; i < batchsize; i++) { 158 | iosWithT = ioScoreWithT.get(i); 159 | iosWithS = ioScoreWithS.get(i); 160 | if (iosWithT == null && iosWithS == null) { continue; } // zero counts 161 | scoreT = scoresSandT.get(i * 2); 162 | scoreS = scoresSandT.get(i * 2 + 1); 163 | /** 164 | * For the rule A->w, count(A->w) = o(A->w) * i(A->w) = o(A->w) * w(A->w). 165 | * For the sub-type rule r of A->w, count(a->w) = o(a->w) * w(a->w). The derivative 166 | * of the objective function w.r.t w(r) is (count(r | T_S) - count(r | S)) / w(r), 167 | * which contains the term 1 / w(r), thus we could eliminate w(r) when computing it. 168 | */ 169 | if (!removed && uRuleType == RuleType.LHSPACE) { 170 | if (iosWithT != null) { 171 | for (Map ios : iosWithT) { ios.remove(RuleUnit.C); } 172 | } 173 | if (iosWithS != null) { 174 | for (Map ios : iosWithS) { ios.remove(RuleUnit.C); } 175 | } 176 | } 177 | // cumulative = (isample + i) > 0; // incorrect version 178 | // CHECK when to clear old gradients and accumulate new gradients 179 | cumulative = isample > 0 ? true : (updated ? 
true : false); 180 | dRuleW = derivateRuleWeight(scoreT, scoreS, iosWithT, iosWithS); 181 | ruleW.derivative(cumulative, icomponent, dRuleW, sample, ggrads, wgrads, true); 182 | updated = true; // CHECK do we need to update in the case where derivative() was not invoked. 183 | } 184 | removed = true; // CHECK avoid impossible remove 185 | } 186 | if (updated) { 187 | ruleW.update(icomponent, ggrads, wgrads, Optimizer.minexp); 188 | } 189 | } 190 | } 191 | 192 | 193 | protected void sample(List slice, int dim) { 194 | slice.clear(); 195 | for (int i = 0; i < dim; i++) { 196 | slice.add(Optimizer.rnd.nextGaussian()); 197 | } 198 | } 199 | 200 | } 201 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/optimization/SimpleOptimizer.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.util.HashMap; 4 | import java.util.HashSet; 5 | import java.util.List; 6 | import java.util.Random; 7 | 8 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 9 | 10 | /** 11 | * @author Yanpeng Zhao 12 | * 13 | */ 14 | public class SimpleOptimizer extends Optimizer { 15 | /** 16 | * 17 | */ 18 | private static final long serialVersionUID = -1474006225613054835L; 19 | private SimpleMinimizer minimizer; 20 | 21 | private SimpleOptimizer() { 22 | this.cntsWithS = new HashMap<>(); 23 | this.cntsWithT = new HashMap<>(); 24 | this.ruleSet = new HashSet<>(); 25 | } 26 | 27 | 28 | public SimpleOptimizer(Random random) { 29 | this(); 30 | rnd = random; 31 | this.minimizer = new SimpleMinimizer(random, maxsample, batchsize); 32 | } 33 | 34 | 35 | public SimpleOptimizer(Random random, short msample, short bsize) { 36 | this(); 37 | rnd = random; 38 | batchsize = bsize; 39 | maxsample = msample; 40 | this.minimizer = new SimpleMinimizer(random, maxsample, batchsize); 41 | } 42 | 43 | 44 | @Override 45 | public void 
applyGradientDescent(List scoresST) { 46 | if (scoresST.size() == 0) { return; } 47 | Batch cntWithT, cntWithS; 48 | for (GrammarRule rule : ruleSet) { 49 | cntWithT = cntsWithT.get(rule); 50 | cntWithS = cntsWithS.get(rule); 51 | if (cntWithT.size() == 0 && cntWithS.size() == 0) { continue; } 52 | minimizer.optimize(rule, cntWithT, cntWithS, scoresST); 53 | } 54 | reset(); 55 | } 56 | 57 | 58 | @Override 59 | public void evalGradients(List scoreSandT) { 60 | return; 61 | } 62 | 63 | 64 | @Override 65 | public void addRule(GrammarRule rule) { 66 | ruleSet.add(rule); 67 | Batch batchWithT = new Batch(-1); 68 | Batch batchWithS = new Batch(-1); 69 | cntsWithT.put(rule, batchWithT); 70 | cntsWithS.put(rule, batchWithS); 71 | } 72 | 73 | 74 | @Override 75 | public void reset() { 76 | Batch cntWithT, cntWithS; 77 | for (GrammarRule rule : ruleSet) { 78 | if ((cntWithT = cntsWithT.get(rule)) != null) { cntWithT.clear(); } 79 | if ((cntWithS = cntsWithS.get(rule)) != null) { cntWithS.clear(); } 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/syntax/State.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.syntax; 2 | 3 | import java.io.Serializable; 4 | 5 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 6 | import edu.shanghaitech.ai.nlp.util.Numberer; 7 | 8 | /** 9 | * Represent the nodes, non-terminals or terminals (words), of parse tree. 10 | * Each State encapsulates a tag (word), recording the id and the name of 11 | * the tag (word). 
12 | * 13 | * @author Yanpeng Zhao 14 | * 15 | */ 16 | public class State implements Serializable { 17 | /** 18 | * 19 | */ 20 | private static final long serialVersionUID = -7306026179545408514L; 21 | private String name; 22 | private short id; 23 | 24 | public short from; 25 | public short to; 26 | 27 | /* if the node is a terminal */ 28 | public int wordIdx; // index of the word 29 | public int signIdx; // signature index of the word 30 | 31 | protected GaussianMixture insideScore; 32 | protected GaussianMixture outsideScore; 33 | 34 | 35 | /** 36 | * @param name name of the tag, null for non-terminals, and word itself for terminals 37 | * @param id id of the tag 38 | * @param from starting point of the span of the current state 39 | * @param to ending point of the span of the current state 40 | * 41 | */ 42 | public State(String name, short id, short from, short to) { 43 | this.name = name; 44 | this.from = from; 45 | this.to = to; 46 | this.id = id; 47 | this.wordIdx = -1; 48 | this.signIdx = -1; 49 | } 50 | 51 | 52 | public State(State state, boolean copyScore) { 53 | this.name = state.name; 54 | this.from = state.from; 55 | this.to = state.to; 56 | this.id = state.id; 57 | this.wordIdx = state.wordIdx; 58 | this.signIdx = state.signIdx; 59 | 60 | if (copyScore) { 61 | this.insideScore = state.insideScore; 62 | this.outsideScore = state.outsideScore; 63 | } 64 | } 65 | 66 | 67 | public State copy() { 68 | return new State(this, false); 69 | } 70 | 71 | 72 | public State copy(boolean copyScore) { 73 | return new State(this, copyScore); 74 | } 75 | 76 | 77 | public String getName() { 78 | return name; 79 | } 80 | 81 | 82 | public void setName(String name) { 83 | this.name = name; 84 | } 85 | 86 | 87 | public short getId() { 88 | return id; 89 | } 90 | 91 | 92 | public void setId(short id) { 93 | this.id = id; 94 | } 95 | 96 | 97 | public GaussianMixture getInsideScore() { 98 | return insideScore; 99 | } 100 | 101 | 102 | public void 
setInsideScore(GaussianMixture insideScore) { 103 | this.insideScore = insideScore; 104 | } 105 | 106 | 107 | public GaussianMixture getOutsideScore() { 108 | return outsideScore; 109 | } 110 | 111 | 112 | public void setOutsideScore(GaussianMixture outsideScore) { 113 | this.outsideScore = outsideScore; 114 | } 115 | 116 | 117 | private void resetScore() { 118 | if (insideScore != null) { 119 | insideScore.clear(true); 120 | } 121 | if (outsideScore != null) { 122 | outsideScore.clear(true); 123 | } 124 | this.insideScore = null; 125 | this.outsideScore = null; 126 | } 127 | 128 | 129 | public void clear(boolean deep) { 130 | if (deep) { 131 | this.name = null; 132 | this.from = -1; 133 | this.to = -1; 134 | this.id = -1; 135 | this.wordIdx = -1; 136 | this.signIdx = -1; 137 | } 138 | resetScore(); 139 | } 140 | 141 | 142 | public void clear() { 143 | clear(true); 144 | } 145 | 146 | 147 | public String toString(boolean simple, short nfirst, Numberer numberer) { 148 | if (simple) { 149 | return toString(); 150 | } else { 151 | StringBuffer sb = new StringBuffer(); 152 | name = name != null ? name : (String) numberer.object(id); 153 | sb.append("State [name=" + name + ", id=" + id + ", from=" + from + ", to=" + to + "]"); 154 | if (insideScore != null) { 155 | sb.append("->[iscore="); 156 | sb.append(insideScore.toString(!simple, nfirst)); 157 | sb.append("]"); 158 | } else { 159 | sb.append("->[iscore=null]"); 160 | } 161 | if (outsideScore != null) { 162 | sb.append("->[oscore="); 163 | sb.append(outsideScore.toString(!simple, nfirst)); 164 | sb.append("]"); 165 | } else { 166 | sb.append("->[oscore=null]"); 167 | } 168 | return sb.toString(); 169 | } 170 | } 171 | 172 | 173 | public String toString(Numberer numberer) { 174 | name = name != null ? 
name : (String) (String) numberer.object(id); 175 | return "State [name=" + name + ", id=" + id + ", from=" + from + ", to=" + to + "]"; 176 | } 177 | 178 | } 179 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/syntax/Tree.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.syntax; 2 | 3 | 4 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/ErrorUtil.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | public class ErrorUtil extends Error { 4 | 5 | /** 6 | * 7 | */ 8 | private static final long serialVersionUID = 1L; 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/Executor.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.Serializable; 4 | import java.util.PriorityQueue; 5 | import java.util.concurrent.Callable; 6 | 7 | public interface Executor extends Callable, Serializable { 8 | public static class Meta { 9 | public int id; 10 | public O value; 11 | public Meta(int id, O value) { 12 | this.id = id; 13 | this.value = value; 14 | } 15 | public O value() { return value; } 16 | @Override 17 | public String toString() { 18 | return "Meta [id=" + id + ", value=" + value + "]"; 19 | } 20 | } 21 | 22 | public abstract Executor newInstance(); 23 | 24 | public abstract void setNextTask(int itask, I task); 25 | 26 | public abstract void setIdx(int idx, PriorityQueue> caches); 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/FunUtil.java: 
-------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.awt.AlphaComposite; 4 | import java.awt.Graphics2D; 5 | import java.awt.geom.Rectangle2D; 6 | import java.awt.image.BufferedImage; 7 | import java.io.File; 8 | import java.text.DecimalFormat; 9 | import java.text.NumberFormat; 10 | import java.util.ArrayList; 11 | import java.util.Comparator; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.Random; 15 | import java.util.Map.Entry; 16 | 17 | import javax.imageio.ImageIO; 18 | 19 | import edu.berkeley.nlp.syntax.Tree; 20 | import edu.berkeley.nlp.ui.TreeJPanel; 21 | import edu.shanghaitech.ai.nlp.data.StateTreeList; 22 | import edu.shanghaitech.ai.nlp.lveg.LVeGTrainer; 23 | import edu.shanghaitech.ai.nlp.util.Numberer; 24 | import edu.shanghaitech.ai.nlp.syntax.State; 25 | 26 | /** 27 | * Useful methods for debugging or ... 28 | * 29 | * @author Yanpeng Zhao 30 | * 31 | */ 32 | public class FunUtil extends Recorder { 33 | /** 34 | * 35 | */ 36 | private static final long serialVersionUID = -9216654024276471124L; 37 | public final static double LOG_ZERO = -1.0e20; 38 | public final static double LOG_TINY = -0.5e20; 39 | public final static double EXP_ZERO = -Math.log(-LOG_ZERO); 40 | public final static NumberFormat formatter = new DecimalFormat("0.###E0"); 41 | 42 | private static Random random = new Random(LVeGTrainer.randomseed); 43 | 44 | public static Comparator> keycomparator = new Comparator>() { 45 | @Override 46 | public int compare(Entry o1, Entry o2) { 47 | return o1.getKey() - o2.getKey(); 48 | } 49 | }; 50 | 51 | public static class KeyComparator implements Comparator { 52 | Map map; 53 | public KeyComparator(Map map) { 54 | this.map = map; 55 | } 56 | @Override 57 | public int compare(Integer o1, Integer o2) { 58 | return o1 - o2; 59 | } 60 | } 61 | 62 | 63 | /** 64 | * @param dir the dir that needs to be created 65 | * @return true if 
created successfully 66 | */ 67 | public static boolean mkdir(String dir) { 68 | File file = new File(dir); 69 | if (!file.exists()) { 70 | if (file.mkdir() || file.mkdirs()) { 71 | return true; 72 | } 73 | } 74 | return false; 75 | } 76 | 77 | 78 | /** 79 | * Return log(a + b) given log(a) and log(b). 80 | * 81 | * @param x in logarithm 82 | * @param y in logarithm 83 | * @return 84 | */ 85 | public static double logAdd(double x, double y) { 86 | double tmp, diff; 87 | if (x < y) { 88 | tmp = x; 89 | x = y; 90 | y = tmp; 91 | } 92 | diff = y - x; // <= 0 93 | if (diff < EXP_ZERO) { 94 | // if y is far smaller than x 95 | return x < LOG_TINY ? LOG_ZERO : x; 96 | } else { 97 | return x + Math.log(1.0 + Math.exp(diff)); 98 | } 99 | } 100 | 101 | 102 | /** 103 | * Match a number with optional '-' and decimal. 104 | * 105 | * @param str the string 106 | * @return 107 | */ 108 | public static boolean isNumeric(String str){ 109 | return str.matches("[-+]?\\d*\\.?\\d+"); 110 | } 111 | 112 | 113 | /** 114 | * @param stateTree the state parse tree 115 | * @param filename image name 116 | * @param stringTree the string parse tree 117 | * @throws Exception oops 118 | */ 119 | public static void saveTree2image(Tree stateTree, String filename, Tree stringTree, Numberer numberer) throws Exception { 120 | TreeJPanel tjp = new TreeJPanel(); 121 | if (stringTree == null) { 122 | stringTree = StateTreeList.stateTreeToStringTree(stateTree, numberer); 123 | logger.trace("\nSTRING PARSE TREE: " + stringTree + "\n"); 124 | } 125 | 126 | tjp.setTree(stringTree); 127 | BufferedImage bi = new BufferedImage(tjp.width(), tjp.height(), BufferedImage.TYPE_INT_ARGB); 128 | 129 | Graphics2D g2 = bi.createGraphics(); 130 | g2.setComposite(AlphaComposite.getInstance(AlphaComposite.CLEAR, 1.0f)); 131 | Rectangle2D.Double rect = new Rectangle2D.Double(0, 0, tjp.width(), tjp.height()); 132 | g2.fill(rect); 133 | g2.setComposite(AlphaComposite.getInstance(AlphaComposite.SRC_OVER, 1.0f)); 134 | 
tjp.paintComponent(g2); 135 | g2.dispose(); 136 | 137 | ImageIO.write(bi, "png", new File(filename + ".png")); 138 | } 139 | 140 | 141 | /** 142 | * @param list item container 143 | * @param type Doubel.class or Integer.class 144 | * @param length number of items in the list 145 | * @param maxint maximum for integer, and 1 for double 146 | * @param nonzero zero inclusive (false) or exclusive (true) 147 | * @param negative allow the negative (true) or not allow (false) 148 | */ 149 | public static void randomInitList(Random rnd, List list, Class type, int length, 150 | int maxint, double ratio, boolean nonzero, boolean negative) { 151 | Double obj = new Double(0); 152 | for (int i = 0; i < length; i++) { 153 | double tmp = rnd.nextDouble() * maxint; 154 | if (nonzero) { while (tmp == 0.0) { tmp = rnd.nextDouble() * maxint; } } 155 | if (negative && tmp < ratio) { tmp = 0 - tmp; } 156 | list.add(type.isInstance(obj) ? type.cast(tmp) : type.cast((int) tmp)); 157 | } 158 | } 159 | 160 | 161 | /** 162 | * @param array item container 163 | * @param type Double.class or Integer.class 164 | * @param maxint maximum for integer, and 1 for double 165 | * 166 | */ 167 | public static void randomInitArray(T[] array, Class type, int maxint) { 168 | Double obj = new Double(0); 169 | for (int i = 0; i < array.length; i++) { 170 | double tmp = random.nextDouble() * maxint; 171 | array[i] = type.isInstance(obj) ? 
type.cast(tmp) : type.cast((int) tmp); 172 | } 173 | 174 | } 175 | 176 | 177 | public static void randomInitArrayInt(int[] array, int maxint) { 178 | for (int i = 0; i < array.length; i++) { 179 | array[i] = (int) (random.nextDouble() * maxint); 180 | } 181 | } 182 | 183 | 184 | public static void randomInitArrayDouble(double[] array) { 185 | for (int i = 0; i < array.length; i++) { 186 | array[i] = random.nextDouble(); 187 | } 188 | } 189 | 190 | 191 | /** 192 | * @param list a list of doubles 193 | * @param precision double precision 194 | * @param nfirst print first # of items 195 | * @param exponential whether the list should be read in the exponential form or not 196 | * @return 197 | */ 198 | public static List double2str(List list, int precision, int nfirst, boolean exponential, boolean scientific) { 199 | List strs = new ArrayList(); 200 | String format = "%." + precision + "f", str; 201 | if (nfirst < 0 || nfirst > list.size()) { nfirst = list.size(); } 202 | for (int i = 0; i < nfirst; i++) { 203 | double value = exponential ? Math.exp(list.get(i)) : list.get(i); 204 | str = scientific ? formatter.format(value) : String.format(format, value); 205 | strs.add(str); 206 | } 207 | return strs; 208 | } 209 | 210 | 211 | /** 212 | * @param list a list of doubles 213 | * @param exponential whether the list should be read in the exponential form or not 214 | * @return 215 | */ 216 | public static double sum(List list, boolean exponential) { 217 | double sum = 0.0; 218 | for (Double d : list) { 219 | sum += exponential ? 
Math.exp(d) : d; 220 | } 221 | return sum; 222 | } 223 | 224 | 225 | public static void printArrayInt(int[] array) { 226 | String string = "["; 227 | for (int i = 0; i < array.length - 1; i++) { 228 | string += array[i] + ", "; 229 | } 230 | string += array[array.length - 1] + "]"; 231 | logger.trace(string + "\n"); 232 | } 233 | 234 | 235 | public static void printArrayDouble(double[] array) { 236 | String string = "["; 237 | for (int i = 0; i < array.length - 1; i++) { 238 | string += array[i] + ", "; 239 | } 240 | string += array[array.length - 1] + "]"; 241 | logger.trace(string + "\n"); 242 | } 243 | 244 | 245 | public static void printArray(T[] array) { 246 | if (isEmpty(array)) { return; } 247 | String string = "["; 248 | for (int i = 0; i < array.length - 1; i++) { 249 | string += array[i] + ", "; 250 | } 251 | string += array[array.length - 1] + "]"; 252 | logger.trace(string + "\n"); 253 | } 254 | 255 | 256 | public static void printList(List list) { 257 | if (isEmpty(list)) { return; } 258 | String string = "["; 259 | for (int i = 0; i < list.size() - 1; i++) { 260 | string += list.get(i) + ", "; 261 | } 262 | string += list.get(list.size() - 1) + "]"; 263 | logger.trace(string + "\n"); 264 | } 265 | 266 | 267 | public static boolean isEmpty(T[] array) { 268 | if (array == null || array.length == 0) { 269 | logger.error("[null or empty]\n"); 270 | return true; 271 | } 272 | return false; 273 | } 274 | 275 | 276 | public static boolean isEmpty(List list) { 277 | if (list == null || list.isEmpty()) { 278 | logger.error("[null or empty]\n"); 279 | return true; 280 | } 281 | return false; 282 | } 283 | 284 | } 285 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/GradientChecker.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.util.ArrayList; 4 | import java.util.EnumMap; 5 | import 
java.util.List; 6 | import java.util.Map; 7 | import java.util.Map.Entry; 8 | 9 | import edu.berkeley.nlp.syntax.Tree; 10 | import edu.shanghaitech.ai.nlp.lveg.impl.LVeGParser; 11 | import edu.shanghaitech.ai.nlp.lveg.impl.Valuator; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution; 13 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 14 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 15 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleType; 16 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 17 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 18 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon; 19 | import edu.shanghaitech.ai.nlp.optimization.Gradient.Grads; 20 | import edu.shanghaitech.ai.nlp.syntax.State; 21 | 22 | public class GradientChecker extends Recorder { 23 | /** 24 | * 25 | */ 26 | private static final long serialVersionUID = -1234982154206455805L; 27 | 28 | 29 | public static void gradcheck(LVeGGrammar grammar, LVeGLexicon lexicon, LVeGParser lvegParser, 30 | Valuator valuator, Tree tree, double maxsample) { 31 | double delta = 1e-5; 32 | Map uRuleMap = grammar.getURuleMap(); 33 | for (Map.Entry entry : uRuleMap.entrySet()) { 34 | gradcheck(grammar, lexicon, entry, lvegParser, valuator, tree, delta, maxsample); 35 | } 36 | uRuleMap = lexicon.getURuleMap(); 37 | for (Map.Entry entry : uRuleMap.entrySet()) { 38 | gradcheck(grammar, lexicon, entry, lvegParser, valuator, tree, delta, maxsample); 39 | } 40 | } 41 | 42 | 43 | public static void gradcheck(LVeGGrammar grammar, LVeGLexicon lexicon, Map.Entry entry, 44 | LVeGParser lvegParser, Valuator valuator, Tree tree, double delta, double maxsample) { 45 | GaussianMixture gm = entry.getValue().getWeight(); 46 | double src = gm.getWeight(0); 47 | 48 | double ltInit = lvegParser.doInsideOutsideWithTree(tree); 49 | double lsInit = lvegParser.doInsideOutside(tree, -1); 50 | double llInit = ltInit - lsInit; 51 | 52 | // w.r.t. 
mixing weight 53 | gm.setWeight(0, gm.getWeight(0) + delta); 54 | 55 | 56 | // double llBefore = valuator.probability(tree); 57 | double ltBefore = lvegParser.doInsideOutsideWithTree(tree); 58 | double lsBefore = lvegParser.doInsideOutside(tree, -1); 59 | double llBefore = ltBefore - lsBefore; 60 | 61 | 62 | double t1 = gm.getWeight(0); 63 | 64 | gm.setWeight(0, gm.getWeight(0) - 2 * delta); 65 | 66 | 67 | // double llAfter = valuator.probability(tree); 68 | double ltAfter = lvegParser.doInsideOutsideWithTree(tree); 69 | double lsAfter = lvegParser.doInsideOutside(tree, -1); 70 | double llAfter = ltAfter - lsAfter; 71 | /* 72 | logger.trace( 73 | "\nltI: " + ltInit + "\tlsI: " + lsInit + "\tllI: " + llInit + "\n" + 74 | "ltB: " + ltBefore + "\tlsB: " + lsBefore + "\tllB: " + llBefore + "\n" + 75 | "ltA: " + ltAfter + "\tlsA: " + lsAfter + "\tllA: " + llAfter); 76 | */ 77 | double t2 = gm.getWeight(0); 78 | 79 | // restore 80 | gm.setWeight(0, gm.getWeight(0) + delta); 81 | double des = gm.getWeight(0); 82 | double numericalGrad = -(llBefore - llAfter) / ((t1 - t2)); 83 | 84 | logger.trace("\n-----\nRule: " + entry.getKey() + "\nGrad Weight: " + 85 | numericalGrad + "=(" + llBefore + " - " + llAfter + ")/(" + (t1 - t2) + ")\n" + 86 | "B : " + src + "\tA : " + des + "\t(B - A) =" + (des - src) + "\n" + 87 | "t1: " + t1 + "\tt2: " + t2 + "\t(t1 - t2)=" + (t1 - t2) + "\n-----\n"); 88 | 89 | 90 | gradcheckmu(entry, lvegParser, valuator, tree, delta); 91 | gradcheckvar(entry, lvegParser, valuator, tree, delta); 92 | 93 | 94 | Object gradients = null; 95 | if (entry.getKey().type != RuleType.LHSPACE) { 96 | gradients = grammar.getOptimizer().debug(entry.getKey(), false); 97 | } else { 98 | gradients = lexicon.getOptimizer().debug(entry.getKey(), false); 99 | } 100 | // divide it by # of samplings 101 | StringBuffer sb = new StringBuffer(); 102 | if (gradients != null) { 103 | Grads grads = (Grads) gradients; 104 | sb.append("\n---\nWgrads: "); 105 | List wgrads = new 
ArrayList<>(grads.wgrads.size()); 106 | for (Double dw : grads.wgrads) { 107 | wgrads.add(dw / maxsample); 108 | } 109 | List>> ggrads = new ArrayList<>(grads.ggrads.size()); 110 | for (Map> comp : grads.ggrads) { 111 | EnumMap> gauss = new EnumMap<>(RuleUnit.class); 112 | for (Entry> gaussian : comp.entrySet()) { 113 | List params = new ArrayList<>(gaussian.getValue().size()); 114 | for (Double dg : gaussian.getValue()) { 115 | params.add(dg / maxsample); 116 | } 117 | gauss.put(gaussian.getKey(), params); 118 | } 119 | ggrads.add(gauss); 120 | } 121 | logger.trace("\n---\nWgrads: " + wgrads + "\nGgrads: " + ggrads + "\n---\n"); 122 | } 123 | } 124 | 125 | 126 | public static void gradcheckmu(Map.Entry entry, 127 | LVeGParser lvegParser, Valuator valuator, Tree tree, double delta) { 128 | GaussianMixture gm = entry.getValue().getWeight(); 129 | 130 | double ltInit = lvegParser.doInsideOutsideWithTree(tree); 131 | double lsInit = lvegParser.doInsideOutside(tree, -1); 132 | double llInit = ltInit - lsInit; 133 | 134 | // w.r.t. 
mixing weight 135 | GaussianDistribution gd = gm.getComponent((short) 0).squeeze(null); 136 | List mus = gd.getMus(); 137 | 138 | double src = mus.get(0); 139 | mus.set(0, mus.get(0) + delta); 140 | 141 | 142 | // double llBefore = valuator.probability(tree); 143 | double ltBefore = lvegParser.doInsideOutsideWithTree(tree); 144 | double lsBefore = lvegParser.doInsideOutside(tree, -1); 145 | double llBefore = ltBefore - lsBefore; 146 | 147 | 148 | double t1 = mus.get(0); 149 | 150 | mus.set(0, mus.get(0) - 2 * delta); 151 | 152 | 153 | // double llAfter = valuator.probability(tree); 154 | double ltAfter = lvegParser.doInsideOutsideWithTree(tree); 155 | double lsAfter = lvegParser.doInsideOutside(tree, -1); 156 | double llAfter = ltAfter - lsAfter; 157 | /* 158 | logger.trace( 159 | "\nltI: " + ltInit + "\tlsI: " + lsInit + "\tllI: " + llInit + "\n" + 160 | "ltB: " + ltBefore + "\tlsB: " + lsBefore + "\tllB: " + llBefore + "\n" + 161 | "ltA: " + ltAfter + "\tlsA: " + lsAfter + "\tllA: " + llAfter); 162 | */ 163 | double t2 = mus.get(0); 164 | 165 | // restore 166 | mus.set(0, mus.get(0) + delta); 167 | double des = mus.get(0); 168 | double numericalGrad = -(llBefore - llAfter) / ((t1 - t2)); 169 | 170 | logger.trace("\n-----\nRule: " + entry.getKey() + "\nGrad MU : " + 171 | numericalGrad + "=(" + llBefore + " - " + llAfter + ")/(" + (t1 - t2) + ")\n" + 172 | "B : " + src + "\tA : " + des + "\t(B - A) =" + (des - src) + "\n" + 173 | "t1: " + t1 + "\tt2: " + t2 + "\t(t1 - t2)=" + (t1 - t2) + "\n-----\n"); 174 | } 175 | 176 | 177 | public static void gradcheckvar(Map.Entry entry, 178 | LVeGParser lvegParser, Valuator valuator, Tree tree, double delta) { 179 | GaussianMixture gm = entry.getValue().getWeight(); 180 | 181 | double ltInit = lvegParser.doInsideOutsideWithTree(tree); 182 | double lsInit = lvegParser.doInsideOutside(tree, -1); 183 | double llInit = ltInit - lsInit; 184 | 185 | // w.r.t. 
mixing weight 186 | GaussianDistribution gd = gm.getComponent((short) 0).squeeze(null); 187 | List vars = gd.getVars(); 188 | 189 | double src = vars.get(0); 190 | vars.set(0, vars.get(0) + delta); 191 | 192 | 193 | // double llBefore = valuator.probability(tree); 194 | double ltBefore = lvegParser.doInsideOutsideWithTree(tree); 195 | double lsBefore = lvegParser.doInsideOutside(tree, -1); 196 | double llBefore = ltBefore - lsBefore; 197 | 198 | 199 | double t1 = vars.get(0); 200 | 201 | vars.set(0, vars.get(0) - 2 * delta); 202 | 203 | 204 | // double llAfter = valuator.probability(tree); 205 | double ltAfter = lvegParser.doInsideOutsideWithTree(tree); 206 | double lsAfter = lvegParser.doInsideOutside(tree, -1); 207 | double llAfter = ltAfter - lsAfter; 208 | /* 209 | logger.trace( 210 | "\nltI: " + ltInit + "\tlsI: " + lsInit + "\tllI: " + llInit + "\n" + 211 | "ltB: " + ltBefore + "\tlsB: " + lsBefore + "\tllB: " + llBefore + "\n" + 212 | "ltA: " + ltAfter + "\tlsA: " + lsAfter + "\tllA: " + llAfter); 213 | */ 214 | double t2 = vars.get(0); 215 | 216 | // restore 217 | vars.set(0, vars.get(0) + delta); 218 | double des = vars.get(0); 219 | double numericalGrad = -(llBefore - llAfter) / ((t1 - t2) /*2 * delta*/); 220 | 221 | logger.trace("\n-----\nRule: " + entry.getKey() + "\nGrad VAR : " + 222 | numericalGrad + "=(" + llBefore + " - " + llAfter + ")/(" + (t1 - t2) + ")\n" + 223 | "B : " + src + "\tA : " + des + "\t(B - A) =" + (des - src) + "\n" + 224 | "t1: " + t1 + "\tt2: " + t2 + "\t(t1 - t2)=" + (t1 - t2) + "\n-----\n"); 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/Indexer.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.Serializable; 4 | import java.util.AbstractList; 5 | import java.util.ArrayList; 6 | import java.util.Collection; 7 | import 
/**
 * A list of distinct elements with constant-time lookup of an element's
 * position. Elements are kept in insertion order in {@code objects};
 * {@code indexes} maps each element back to its position.
 *
 * Only appending is supported; a locked indexer refuses new elements.
 *
 * @param <E> element type
 */
public class Indexer<E> extends AbstractList<E> implements Serializable {
    private static final long serialVersionUID = 1L;
    private boolean locked = false;
    private List<E> objects;
    private Map<E, Integer> indexes;

    public Indexer() {
        objects = new ArrayList<E>();
        indexes = new HashMap<E, Integer>();
    }

    /** Creates an indexer holding the distinct non-null elements of {@code c}, in iteration order. */
    public Indexer(Collection<? extends E> c) {
        this();
        for (E e : c) { index(e); }
    }

    /**
     * Look up the index of the given element, would be added
     * if the element does not exist.
     *
     * @param e the element to be looked up
     * @return the index of the given element, or -1 if it is null, or if it
     *         is unknown and the indexer is locked
     */
    public int index(E e) {
        if (e == null) { return -1; }
        Integer idx = indexes.get(e);
        if (idx == null) {
            if (locked) { return -1; }
            idx = size();
            objects.add(e);
            indexes.put(e, idx);
        }
        return idx;
    }

    /**
     * Read-only lookup.
     *
     * @return the index of {@code o}, or -1 if it is not present
     */
    public int indexof(Object o) {
        Integer idx = indexes.get(o);
        return idx == null ? -1 : idx;
    }

    /**
     * Same result as the inherited linear scan, answered from the map.
     * Elements are distinct, so first and last occurrence coincide.
     */
    @Override
    public int indexOf(Object o) {
        return indexof(o);
    }

    @Override
    public int lastIndexOf(Object o) {
        return indexof(o);
    }

    /**
     * Appends {@code e} unless it is already present.
     *
     * @return true if the element was added
     * @throws IllegalStateException if the indexer is locked
     */
    @Override
    public boolean add(E e) {
        if (locked) {
            throw new IllegalStateException("Tried to add to locked indexer");
        }
        if (contains(e)) { return false; }
        indexes.put(e, size());
        objects.add(e);
        return true;
    }

    @Override
    public boolean contains(Object o) {
        return indexes.containsKey(o);
    }

    @Override
    public E get(int index) {
        return objects.get(index);
    }

    @Override
    public int size() {
        return objects.size();
    }

    @Override
    public void clear() {
        objects.clear();
        indexes.clear();
    }

    /**
     * @return the backing list itself, not a copy; modifying it directly
     *         leaves the index map out of date
     */
    public List<E> getObjects() {
        return objects;
    }

    public void lock() {
        this.locked = true;
    }

    public void unlock() {
        this.locked = false;
    }

    @Override
    public String toString() {
        return "Indexer [locked=" + locked + ", objects=" + objects + ", indexes=" + indexes + "]";
    }

}
13 | * 14 | * @author Yanpeng Zhao 15 | * 16 | */ 17 | public class LogUtil implements Serializable { 18 | /** 19 | * 20 | */ 21 | private static final long serialVersionUID = -113853579415943859L; 22 | 23 | private static LogUtil instance; 24 | 25 | private static Logger logCons = null; 26 | private static Logger logFile = null; 27 | private static Logger logBoth = null; 28 | 29 | private final static String KEY = "log.name"; 30 | private final static String LOGGER_XML = "config/log4j.xml"; 31 | 32 | 33 | private LogUtil() { 34 | logFile = Logger.getLogger("FILE"); 35 | logCons = Logger.getLogger("CONSOLE"); 36 | logBoth = Logger.getRootLogger(); 37 | } 38 | 39 | 40 | public static LogUtil getLogger() { 41 | if (instance == null) { 42 | instance = new LogUtil(); 43 | } 44 | return instance; 45 | } 46 | 47 | 48 | public Logger getBothLogger(String log) { 49 | System.setProperty(KEY, log); 50 | DOMConfigurator.configure(LOGGER_XML); 51 | return logBoth; 52 | } 53 | 54 | 55 | public Logger getConsoleLogger() { 56 | DOMConfigurator.configure(LOGGER_XML); 57 | return logCons; 58 | } 59 | 60 | 61 | public Logger getFileLogger(String log) { 62 | System.setProperty(KEY, log); 63 | DOMConfigurator.configure(LOGGER_XML); 64 | return logFile; 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/MutableInteger.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | /** 4 | * @author Dan Klein 5 | * 6 | */ 7 | public class MutableInteger extends Number implements Comparable { 8 | /** 9 | * 10 | */ 11 | private static final long serialVersionUID = 6612815442080956970L; 12 | private int i; 13 | 14 | public MutableInteger() { 15 | this(0); 16 | } 17 | 18 | public MutableInteger(int i) { 19 | this.i = i; 20 | } 21 | 22 | public void set(int i) { 23 | this.i = i; 24 | } 25 | 26 | public int 
compareTo(MutableInteger mi) { 27 | return (i < mi.i ? -1 : (i == mi.i ? 0 : 1)); 28 | } 29 | 30 | @Override 31 | public int compareTo(Object o) { 32 | return compareTo((MutableInteger) o); 33 | } 34 | 35 | @Override 36 | public int hashCode() { 37 | return i; 38 | } 39 | 40 | @Override 41 | public boolean equals(Object obj) { 42 | if (obj instanceof MutableInteger) { 43 | return i == ((MutableInteger) obj).i; 44 | } 45 | return false; 46 | } 47 | 48 | @Override 49 | public int intValue() { 50 | return i; 51 | } 52 | 53 | @Override 54 | public byte byteValue() { 55 | return (byte) i; 56 | } 57 | 58 | @Override 59 | public long longValue() { 60 | return i; 61 | } 62 | 63 | @Override 64 | public short shortValue() { 65 | return (short) i; 66 | } 67 | 68 | @Override 69 | public float floatValue() { 70 | return i; 71 | } 72 | 73 | @Override 74 | public double doubleValue() { 75 | return i; 76 | } 77 | 78 | @Override 79 | public String toString() { 80 | return Integer.toString(i); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/Numberer.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | import java.util.NoSuchElementException; 7 | import java.util.Set; 8 | 9 | /** 10 | * Rewrite, changing static fields to non-static ones, to be able to save to the object. 11 | * Note that we have to ensure {@link #idx} is thread-safe by synchronizing it in {@link #object(int)}. 
12 | * 13 | * @author Dan Klein 14 | * @author Yanpeng Zhao 15 | * 16 | */ 17 | public class Numberer implements Serializable { 18 | /** 19 | * 20 | */ 21 | private static final long serialVersionUID = 2309277618797328501L; 22 | private Map numbererMap = new HashMap(); 23 | private boolean locked = false; 24 | private MutableInteger idx; 25 | private Map idx2obj; 26 | private Map obj2idx; 27 | private int count; 28 | 29 | public Numberer() {} 30 | 31 | public Numberer(boolean placeholder) { 32 | this.count = 0; 33 | this.idx2obj = new HashMap(); 34 | this.obj2idx = new HashMap(); 35 | this.idx = new MutableInteger(); 36 | } 37 | 38 | public void put(String key, Object o) { 39 | numbererMap.put(key, o); 40 | } 41 | 42 | public Numberer getGlobalNumberer(String key) { 43 | Numberer value = (Numberer) numbererMap.get(key); 44 | if (value == null) { 45 | value = new Numberer(true); 46 | numbererMap.put(key, value); 47 | } 48 | return value; 49 | } 50 | 51 | public int translate(String src, String des, int i) { 52 | return getGlobalNumberer(des).number( 53 | getGlobalNumberer(src).object(i)); 54 | } 55 | 56 | public boolean containsIdx(int i) { 57 | return idx2obj.containsKey(i); 58 | } 59 | 60 | public boolean containsObj(Object o) { 61 | return obj2idx.containsKey(o); // CHECK 62 | } 63 | 64 | public int number(String key, Object o) { 65 | return getGlobalNumberer(key).number(o); 66 | } 67 | 68 | 69 | public Object object(String key, int i) { 70 | return getGlobalNumberer(key).object(i); 71 | } 72 | 73 | 74 | public void setNumbererMap(Map numbererMap) { 75 | this.numbererMap = numbererMap; 76 | } 77 | 78 | 79 | public int number(Object o) { 80 | // CHECK do we really need it as the proxy of integer? 
81 | MutableInteger anidx = (MutableInteger) obj2idx.get(o); 82 | if (anidx == null) { 83 | if (locked) { 84 | throw new NoSuchElementException("no object: " + o); 85 | } 86 | anidx = new MutableInteger(count); 87 | count++; 88 | obj2idx.put(o, anidx); 89 | idx2obj.put(anidx, o); 90 | } 91 | return anidx.intValue(); 92 | } 93 | 94 | /** 95 | * We do need to make "idx" thread-safe. 96 | * 97 | * @param i index of the object 98 | * @return object corresponding to the index 99 | */ 100 | public Object object(int i) { 101 | synchronized (idx) { 102 | idx.set(i); 103 | return idx2obj.get(idx); 104 | } 105 | } 106 | 107 | public Map getNumbererMap() { 108 | return numbererMap; 109 | } 110 | 111 | public Set objects() { 112 | return obj2idx.keySet(); 113 | } 114 | 115 | public void lock() { 116 | this.locked = true; 117 | } 118 | 119 | public void unlock() { 120 | this.locked = false; 121 | } 122 | 123 | public int size() { 124 | return count; 125 | } 126 | 127 | @Override 128 | public String toString() { 129 | StringBuffer sb = new StringBuffer(); 130 | sb.append("["); 131 | for (int i = 0; i < count; i++) { 132 | sb.append(i); 133 | sb.append("->"); 134 | sb.append(object(i)); 135 | if (i < count - 1) { 136 | sb.append(", "); 137 | } 138 | } 139 | sb.append("]"); 140 | return sb.toString(); 141 | } 142 | 143 | } 144 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/ObjectPool.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.Serializable; 4 | 5 | import org.apache.commons.pool2.KeyedPooledObjectFactory; 6 | import org.apache.commons.pool2.impl.GenericKeyedObjectPool; 7 | import org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig; 8 | 9 | public class ObjectPool extends GenericKeyedObjectPool implements Serializable { 10 | public ObjectPool(KeyedPooledObjectFactory factory) { 11 | 
super(factory); 12 | } 13 | 14 | public ObjectPool(KeyedPooledObjectFactory factory, GenericKeyedObjectPoolConfig config) { 15 | super(factory, config); 16 | } 17 | 18 | /** 19 | * 20 | */ 21 | private static final long serialVersionUID = -1542223532794120768L; 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/Option.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | @Retention(RetentionPolicy.RUNTIME) 9 | @Target(ElementType.FIELD) 10 | public @interface Option { 11 | String name(); 12 | String usage() default ""; 13 | String defaultValue() default ""; 14 | boolean required() default false; 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/OptionParser.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.lang.reflect.Constructor; 4 | import java.lang.reflect.Field; 5 | import java.util.Arrays; 6 | import java.util.HashMap; 7 | import java.util.HashSet; 8 | import java.util.Map; 9 | import java.util.Set; 10 | 11 | public class OptionParser extends Recorder { 12 | /** 13 | * 14 | */ 15 | private static final long serialVersionUID = -3167531289091868339L; 16 | private final Map options = new HashMap<>(); 17 | private final Map fields = new HashMap<>(); 18 | private final Map, Type> types = new HashMap, Type>(); 19 | private final Set required = new HashSet<>(); 20 | private final Class clsOption; 21 | private enum Type { 22 | CHAR, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, BOOLEAN, STRING 23 | } 24 | 25 | private StringBuilder parsedOpts; 
26 | 27 | public OptionParser(Class clsOption) { 28 | this.clsOption = clsOption; 29 | for (Field field : clsOption.getDeclaredFields()) { 30 | Option option = field.getAnnotation(Option.class); 31 | if (option == null) { continue; } 32 | options.put(option.name(), option); 33 | fields.put(option.name(), field); 34 | if (option.required()) { 35 | required.add(option.name()); 36 | } 37 | } 38 | types.put(char.class, Type.CHAR); 39 | types.put(byte.class, Type.BYTE); 40 | types.put(short.class, Type.SHORT); 41 | types.put(int.class, Type.INT); 42 | types.put(long.class, Type.LONG); 43 | types.put(float.class, Type.FLOAT); 44 | types.put(double.class, Type.DOUBLE); 45 | types.put(boolean.class, Type.BOOLEAN); 46 | types.put(String.class, Type.STRING); 47 | } 48 | 49 | public String getParsedOptions() { 50 | return parsedOpts.toString(); 51 | } 52 | 53 | public Object parse(String[] args) { 54 | return parse(args, false, false); 55 | } 56 | 57 | public Object parse(String[] args, boolean exitIfFailed) { 58 | return parse(args, exitIfFailed, false); 59 | } 60 | 61 | private void usage() { 62 | System.err.println(); // cannot use logger since it has not been initialized 63 | for (Option option : options.values()) { 64 | System.err.printf("%-30s%s", option.name(), option.usage()); 65 | if (option.required()) { System.err.printf(" [required]"); } 66 | System.err.println(); 67 | } 68 | System.err.printf("%-30shelp message\n", "-h"); 69 | System.err.println(); 70 | System.exit(2); 71 | } 72 | 73 | private void error(boolean exitIfFailed, String value) { 74 | if (exitIfFailed) { 75 | throw new RuntimeException("Cannot recognize option |" + value + "|"); 76 | } else { 77 | logger.warn("Cannot recognize option |" + value + "|"); 78 | } 79 | } 80 | 81 | public Object parse(String[] args, boolean exitIfFailed, boolean parrot) { 82 | if (parrot) { logger.info("Calling with " + Arrays.deepToString(args)); } 83 | try { 84 | Set seenOpts = new HashSet(); 85 | parsedOpts = new 
StringBuilder("{"); 86 | Object option = clsOption.newInstance(); 87 | for (int i = 0; i < args.length; i++) { 88 | String key = args[i], value = args[i + 1]; 89 | if (key != null) { key = key.trim(); } 90 | if (value != null) { value = value.trim(); } 91 | if ("-h".equals(key)) { usage(); } 92 | seenOpts.add(key); 93 | Option opt = options.get(key); 94 | if (opt == null) { 95 | error(exitIfFailed, key); 96 | continue; 97 | } 98 | 99 | Field field = fields.get(key); 100 | Class ftype = field.getType(); 101 | if (!ftype.isEnum()) { 102 | switch (types.get(ftype)) { 103 | case CHAR: { 104 | field.setChar(option, value.charAt(0)); 105 | break; 106 | } 107 | case BYTE: { 108 | field.setByte(option, Byte.parseByte(value)); 109 | break; 110 | } 111 | case SHORT: { 112 | field.setShort(option, Short.parseShort(value)); 113 | break; 114 | } 115 | case INT: { 116 | field.setInt(option, Integer.parseInt(value)); 117 | break; 118 | } 119 | case LONG: { 120 | field.setLong(option, Long.parseLong(value)); 121 | break; 122 | } 123 | case FLOAT: { 124 | field.setFloat(option, Float.parseFloat(value)); 125 | break; 126 | } 127 | case DOUBLE: { 128 | field.setDouble(option, Double.parseDouble(value)); 129 | break; 130 | } 131 | case STRING: { 132 | field.set(option, value); 133 | break; 134 | } 135 | case BOOLEAN: { 136 | field.setBoolean(option, Boolean.parseBoolean(value)); 137 | break; 138 | } 139 | default: { 140 | try { 141 | Constructor cst = ftype.getConstructor(new Class[] {String.class}); 142 | field.set(option, cst.newInstance((Object[]) (new String[] { value }))); 143 | } catch (NoSuchMethodException e) { 144 | logger.error("Cannot construct object of type " + ftype.getCanonicalName() + " from the string.\n"); 145 | } 146 | } 147 | } 148 | } else { 149 | Object[] values = ftype.getEnumConstants(); 150 | boolean found = false; 151 | for (Object val : values) { 152 | String name = ((Enum) val).name(); 153 | if (value.equals(name)) { 154 | field.set(option, val); 155 | 
found = true; 156 | break; 157 | } 158 | } 159 | if (!found) { error(exitIfFailed, value); } 160 | } 161 | // parsedOpts.append(String.format(" %s => %s", opt.name(), value)); 162 | parsedOpts.append(String.format("\n%s, %s,", opt.name(), value)); 163 | i++; 164 | } 165 | // parsedOpts.append(" }"); 166 | parsedOpts.append("\n}"); 167 | Set leftOpts = new HashSet(required); 168 | leftOpts.removeAll(seenOpts); 169 | if (!leftOpts.isEmpty()) { 170 | logger.error("Failed to specify: " + leftOpts + "\n"); 171 | usage(); 172 | } 173 | return option; 174 | } catch (Exception e) { 175 | e.printStackTrace(); 176 | } 177 | return null; 178 | } 179 | 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/Recorder.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.Serializable; 4 | 5 | import org.apache.log4j.Logger; 6 | 7 | public class Recorder implements Serializable { 8 | /** 9 | * 10 | */ 11 | private static final long serialVersionUID = 7851876682201361771L; 12 | protected static Logger logger = null; 13 | protected static LogUtil logUtil = LogUtil.getLogger(); 14 | 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/edu/shanghaitech/ai/nlp/util/ThreadPool.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.Serializable; 4 | import java.util.Comparator; 5 | import java.util.PriorityQueue; 6 | import java.util.concurrent.ExecutionException; 7 | import java.util.concurrent.ExecutorService; 8 | import java.util.concurrent.Executors; 9 | import java.util.concurrent.Future; 10 | import java.util.concurrent.TimeUnit; 11 | 12 | import edu.shanghaitech.ai.nlp.util.Executor.Meta; 13 | 14 | public class ThreadPool extends Recorder implements 
Serializable { 15 | /** 16 | * 17 | */ 18 | private static final long serialVersionUID = -4870873867836614814L; 19 | public final static String MAIN_THREAD = "MAIN_THREAD"; 20 | protected static Comparator> idcomparator = new Comparator>() { 21 | @Override 22 | public int compare(Meta o1, Meta o2) { 23 | return o1.id - o2.id; 24 | } 25 | }; 26 | protected int lastReturn; 27 | protected int lastSubmission; 28 | 29 | protected int nthread; 30 | protected Future[] submits; 31 | protected Executor[] executors; 32 | protected ExecutorService pool; 33 | protected PriorityQueue> scores; 34 | 35 | 36 | public ThreadPool(Executor executor, int nthread) { 37 | this.nthread = nthread; 38 | this.submits = new Future[nthread]; 39 | this.executors = new Executor[nthread]; 40 | this.pool = Executors.newFixedThreadPool(nthread); 41 | this.scores = new PriorityQueue>(idcomparator); 42 | this.lastReturn = -1; 43 | this.lastSubmission = 0; 44 | for (int i = 0; i < nthread; i++) { 45 | executors[i] = executor.newInstance(); 46 | executors[i].setIdx(i, scores); 47 | } 48 | } 49 | 50 | 51 | public void execute(Object task) { 52 | synchronized (scores) { 53 | while (true) { 54 | for (int i = 0; i < nthread; i++) { 55 | if (submits[i] == null || submits[i].isDone()) { 56 | 57 | // int iret = 1 << 30; 58 | // boolean isnull = submits[i] == null; 59 | // if (!isnull) { 60 | // try { // get the index of the finished task, should be larger than or equal to 0 61 | // iret = (int) submits[i].get(); 62 | // } catch (InterruptedException | ExecutionException e) { 63 | // e.printStackTrace(); 64 | // iret = 1 << 30; 65 | // logger.error("OOPS_BUG: get()\n"); 66 | // throw new RuntimeException("OOPS_BUG: excute\n"); 67 | // } 68 | // } 69 | 70 | // if (MAIN_THREAD.equals(Thread.currentThread().getName())) { 71 | // logger.trace("\n---3------last ret: " + lastReturn + ", last submission: " + lastSubmission + 72 | // ", size: " + scores.size() + ", active: " + Thread.activeCount() + ", isnull: " + 
isnull + ", iret: " + iret); 73 | // } 74 | 75 | executors[i].setNextTask(lastSubmission, task); 76 | submits[i] = pool.submit(executors[i]); 77 | lastSubmission++; 78 | return; 79 | } 80 | } 81 | try { 82 | scores.wait(); 83 | } catch (InterruptedException e) { 84 | e.printStackTrace(); 85 | } 86 | } 87 | } 88 | } 89 | 90 | 91 | 92 | /** 93 | * A safe implementation of task submission, but is less efficient. 94 | * 95 | * @param task the task to be executed 96 | */ 97 | public void executeSafe(Object task) { 98 | synchronized (scores) { 99 | while (true) { 100 | int iworker = 0, iret = -1; 101 | boolean isnull = false, isdone = false, wait = false; 102 | for (; iworker < nthread; iworker++) { 103 | isnull = submits[iworker] == null; 104 | if (!isnull) { 105 | isdone = submits[iworker].isDone(); 106 | } 107 | if (isnull || isdone) { 108 | break; 109 | } 110 | } // find the free executor 111 | 112 | if (isnull || isdone) { 113 | if (isdone) { 114 | try { // get the index of the finished task, should be larger than or equal to 0 115 | iret = (int) submits[iworker].get(); 116 | } catch (InterruptedException | ExecutionException e) { 117 | iret = -1; 118 | e.printStackTrace(); 119 | throw new IllegalStateException("OOPS_BUG: no return value."); 120 | } 121 | if (iret < 0) { 122 | wait = true; 123 | } 124 | } 125 | // lastSubmission: # of total submitted tasks 126 | // lastReturn + 1: # of total returned tasks 127 | // scores.size() : # of total tasks that are finished and waiting to be retrieved 128 | if (lastSubmission - lastReturn - 1 - scores.size() >= nthread) { 129 | wait = true; // this error should be handled in Callable.call() 130 | throw new IllegalStateException("OOPS_BUG: Number of submissions is larger than that of available threads."); 131 | } 132 | // if (MAIN_THREAD.equals(Thread.currentThread().getName())) { 133 | // logger.trace("\n---3------last ret: " + lastReturn + ", last submission: " + 134 | // lastSubmission + ", size: " + scores.size() + ", 
active: " + Thread.activeCount() + 135 | // ", isnull: " + isnull + ", iret: " + iret); 136 | // } 137 | } else { // no free worker 138 | wait = true; 139 | } 140 | 141 | if (wait) { 142 | try { 143 | scores.wait(); 144 | } catch (InterruptedException e) { 145 | e.printStackTrace(); 146 | } 147 | } else { 148 | executors[iworker].setNextTask(lastSubmission, task); 149 | submits[iworker] = pool.submit(executors[iworker]); 150 | lastSubmission++; 151 | return; 152 | } 153 | } 154 | } 155 | } 156 | 157 | 158 | public Object getNext() { 159 | if (!hasNext()) { 160 | throw new IllegalStateException("OOPS_BUG: Can only be invoked when there are available results to retrieve."); 161 | } 162 | synchronized (scores) { 163 | // if (MAIN_THREAD.equals(Thread.currentThread().getName())) { 164 | // logger.trace("\n---2------last ret: " + lastReturn + ", last submission: " + lastSubmission + ", size: " + scores.size()); 165 | // } 166 | Meta score = scores.poll(); 167 | lastReturn++; 168 | return score.value(); 169 | } 170 | } 171 | 172 | 173 | public boolean hasNext() { 174 | synchronized (scores) { 175 | // if (MAIN_THREAD.equals(Thread.currentThread().getName())) { 176 | // logger.trace("\n---1------last ret: " + lastReturn + ", last submission: " + lastSubmission + ", size: " + scores.size()); 177 | // } 178 | if (scores.isEmpty()) { return false; } 179 | Meta score = scores.peek(); 180 | return score.id == (lastReturn + 1); 181 | } 182 | } 183 | 184 | 185 | public boolean isDone() { 186 | return (lastSubmission - 1) == lastReturn; 187 | } 188 | 189 | 190 | public void sleep() { 191 | synchronized (scores) { 192 | try { 193 | scores.wait(); 194 | } catch (InterruptedException e) { 195 | e.printStackTrace(); 196 | } 197 | } 198 | } 199 | 200 | 201 | public void shutdown() { 202 | reset(); 203 | if (pool != null) { 204 | try { 205 | pool.shutdown(); 206 | pool.awaitTermination(10, TimeUnit.MILLISECONDS); 207 | } catch (InterruptedException e) { 208 | e.printStackTrace(); 209 | 
} 210 | } 211 | } 212 | 213 | 214 | public void reset() { 215 | lastReturn = -1; 216 | lastSubmission = 0; 217 | scores.clear(); 218 | } 219 | 220 | } 221 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/data/F1er.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.data; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileReader; 6 | import java.io.FileWriter; 7 | import java.util.regex.Matcher; 8 | import java.util.regex.Pattern; 9 | 10 | import org.junit.Test; 11 | 12 | public class F1er { 13 | 14 | private final static String PARSED = "^(\\d+).*?(parsed).*?:(.*)"; 15 | private final static String GOLDEN = "^(\\d+).*?(gold).*?:(.*)"; 16 | 17 | 18 | @Test 19 | public void testF1er() throws Exception { 20 | String root = "E:/SourceCode/ParsersData/"; 21 | String infile = root + "/gr_0615_1225_50_3_4_180_nb30_p5_mt4_f1_ep10_40_mt4_n25.log"; 22 | String outfile = root + "/bit"; 23 | splitData(infile, outfile); 24 | } 25 | 26 | private void splitData(String infile, String outfile) throws Exception { 27 | String line = null; 28 | FileWriter writer = null; 29 | FileWriter[] writers = new FileWriter[4]; 30 | writers[0] = new FileWriter(outfile + ".tst.parsed"); 31 | writers[1] = new FileWriter(outfile + ".tst.golded"); 32 | writers[2] = new FileWriter(outfile + ".dev.parsed"); 33 | writers[3] = new FileWriter(outfile + ".dev.golded"); 34 | 35 | Matcher matr = null; 36 | Pattern pat1 = Pattern.compile(PARSED); 37 | Pattern pat2 = Pattern.compile(GOLDEN); 38 | 39 | BufferedReader reader = new BufferedReader(new FileReader(new File(infile))); 40 | 41 | int a0 = 0, a1 = 0, cnt0 = -1, cnt1 = -1; 42 | while ((line = reader.readLine()) != null) { 43 | matr = pat1.matcher(line); 44 | if (matr.find()) { 45 | if (Integer.valueOf(matr.group(1)) == 0) { 46 | cnt0++; 47 | System.out.println("---P->cnt0: " + 
cnt0 + "\ta0: " + a0); 48 | a0 = 0; 49 | } 50 | writer = writers[cnt0 * 2]; 51 | writer.write(matr.group(3).trim()); 52 | writer.write("\n"); 53 | a0++; 54 | } else { 55 | matr = pat2.matcher(line); 56 | if (matr.find()) { 57 | if (Integer.valueOf(matr.group(1)) == 0) { 58 | cnt1++; 59 | System.out.println("---G->cnt1: " + cnt1 + "\ta1: " + a1); 60 | a1 = 0; 61 | } 62 | writer = writers[cnt1 * 2 + 1]; 63 | writer.write(matr.group(3).trim()); 64 | writer.write("\n"); 65 | a1++; 66 | } 67 | } 68 | } 69 | 70 | reader.close(); 71 | for (FileWriter w : writers) { 72 | if (w != null) { w.close(); } 73 | } 74 | 75 | System.out.println("---G->cnt1: " + cnt1 + "\ta1: " + a1); 76 | System.out.println("---P->cnt0: " + cnt0 + "\ta0: " + a0); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/Java.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg; 2 | 3 | import java.text.DecimalFormat; 4 | import java.text.NumberFormat; 5 | import java.text.SimpleDateFormat; 6 | import java.util.ArrayList; 7 | import java.util.Collection; 8 | import java.util.Collections; 9 | import java.util.Date; 10 | import java.util.HashMap; 11 | import java.util.HashSet; 12 | import java.util.Iterator; 13 | import java.util.List; 14 | import java.util.Map; 15 | import java.util.PriorityQueue; 16 | import java.util.Set; 17 | import java.util.regex.Matcher; 18 | import java.util.regex.Pattern; 19 | 20 | import org.junit.Test; 21 | 22 | import edu.shanghaitech.ai.nlp.lveg.impl.BinaryGrammarRule; 23 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture.Component; 24 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 25 | import edu.shanghaitech.ai.nlp.util.Debugger; 26 | import edu.shanghaitech.ai.nlp.util.FunUtil; 27 | 28 | public class Java { 29 | 30 | // @Test 31 | public void testJava() { 32 | Ghost ghost0 = new Ghost((short) 
0); 33 | System.out.println(ghost0); 34 | Ghost ghost1 = new Ghost((short) 1); 35 | System.out.println(ghost1); 36 | Ghost ghost2 = new Ghost((short) 2); 37 | System.out.println(ghost2); 38 | 39 | System.out.println(ghost2.soul == ghost1.soul); 40 | ghost2.increaseSoul(); 41 | System.out.println(ghost0); 42 | System.out.println(ghost1.soul == ghost0.soul); 43 | } 44 | 45 | 46 | protected static class Ghost { 47 | static short id; 48 | static Soul soul = new Soul(); 49 | 50 | public Ghost(short id) { 51 | this.id = id; 52 | } 53 | 54 | public void increaseSoul() { 55 | soul.increase(); 56 | } 57 | 58 | public String toString() { 59 | return "-" + id + soul; 60 | } 61 | } 62 | 63 | protected static class Soul { 64 | static short id = 0; 65 | String soul = "-xiasini-"; 66 | 67 | public Soul() { 68 | id++; 69 | System.out.println("Do you have the soul?"); 70 | } 71 | 72 | public void increase() { 73 | id++; 74 | } 75 | 76 | public String toString() { 77 | return soul + "-" + id; 78 | } 79 | } 80 | 81 | enum XXX { 82 | a(1), 83 | b(2), 84 | c(3); 85 | 86 | int x; 87 | 88 | XXX(int x) { 89 | this.x = x; 90 | } 91 | 92 | int value() { 93 | return x; 94 | } 95 | 96 | @Override 97 | public String toString() { 98 | return String.valueOf(x); 99 | } 100 | } 101 | 102 | enum YYY { 103 | a, b, c 104 | } 105 | 106 | @Test 107 | public void c() { 108 | // Debugger.debugTreebank(); 109 | 110 | String format = "%.3f"; 111 | NumberFormat formatter = new DecimalFormat("0.###E0"); 112 | double xx = Double.NEGATIVE_INFINITY; 113 | System.out.println(String.format(format, xx)); 114 | System.out.println(formatter.format(xx)); 115 | 116 | double yy = Double.NaN; 117 | System.out.println(String.format(format, yy)); 118 | System.out.println(formatter.format(yy)); 119 | 120 | GrammarRule rule0 = new BinaryGrammarRule((short) 6, (short) 8, (short)8); 121 | GrammarRule rule1 = new BinaryGrammarRule((short) 6, (short) 8, (short)8); 122 | 123 | System.out.println(rule0 = rule1); 124 | 
System.out.println(rule0.equals(rule1)); 125 | 126 | System.out.println(XXX.b); 127 | System.out.println(YYY.a.ordinal()); 128 | 129 | XXX x1 = XXX.a; 130 | XXX x2 = XXX.b; 131 | System.out.println(x1 == x2); 132 | System.out.println(x1.equals(x2)); 133 | 134 | } 135 | 136 | // @Test 137 | public void a() { 138 | int x = 1; 139 | b(x++); 140 | System.out.println("in a: " + x); 141 | } 142 | 143 | public void b(int a) { 144 | System.out.println("in b: " + a); 145 | } 146 | 147 | // @Test 148 | public void testConcurrentModifier() { 149 | List mylist = new ArrayList(); 150 | mylist.add("1"); 151 | mylist.add("2"); 152 | mylist.add("3"); 153 | mylist.add("4"); 154 | mylist.add("5"); 155 | Iterator it = mylist.iterator(); 156 | while (it.hasNext()) { 157 | String val = it.next(); 158 | System.out.println("list val: " + val); 159 | if ("3".equals(val)) { 160 | // mylist.remove(val); 161 | it.remove(); 162 | } 163 | } 164 | for (String val : mylist) { 165 | System.out.println(val); 166 | } 167 | } 168 | 169 | // @Test 170 | public void testRegx() { 171 | String regx = ".*?(\\d+)"; 172 | String line = "lveg_4.gr"; 173 | 174 | Pattern pat = Pattern.compile(regx); 175 | Matcher matcher; 176 | matcher = pat.matcher(line); 177 | if (matcher.find()) { 178 | System.out.println(matcher.group(1) + "\t" + matcher.groupCount()); 179 | } 180 | 181 | 182 | Set set0 = new HashSet(); 183 | Set set1 = new HashSet(); 184 | 185 | set0.add(1); 186 | set0.add(2); 187 | set0.add(3); 188 | 189 | set1.add(2); 190 | set1.add(3); 191 | set1.add(4); 192 | 193 | set0.retainAll(set1); 194 | System.out.println(set0); 195 | 196 | double m = 1.0 - Double.NEGATIVE_INFINITY; 197 | System.out.println(m); 198 | List x = new ArrayList(); 199 | x.add(m); 200 | System.out.println("scores: " + FunUtil.double2str(x, 3, -1, false, true)); 201 | } 202 | 203 | // @Test 204 | public void testQueue() { 205 | List list = new ArrayList(5); 206 | PriorityQueue sorted = new PriorityQueue(5); 207 | for (int i = 0; i < 
5; i++) { 208 | sorted.add(i); 209 | list.add(i); 210 | } 211 | System.out.println(sorted); 212 | for (Integer it : sorted) { 213 | if (it < 3) { 214 | continue; 215 | } 216 | sorted.remove(it); // this is a bad coding example 217 | } 218 | System.out.println(sorted); 219 | 220 | // for (int i = 1; i < 5; i++) { 221 | // list.remove(i); 222 | // } 223 | list.subList(0, list.size()).clear(); 224 | list.add(4); 225 | list.subList(1, list.size()).clear(); 226 | System.out.println(list); 227 | } 228 | 229 | 230 | // @Test 231 | public void testTest() { 232 | int a = 3, b = 3; 233 | assert(a == b); 234 | 235 | // double d0 = -15.332610488842343; 236 | // double d1 = -8.00196935333965; 237 | // double d2 = -14.834674151593733; 238 | // 239 | // double x = FunUtil.logAdd(d0, d1); 240 | // x = FunUtil.logAdd(x, d2); 241 | // System.out.println(x); 242 | // 243 | // double y = Math.log(Math.exp(d0) + Math.exp(d1) + Math.exp(d2)); 244 | // System.out.println(y); 245 | 246 | Map map = new HashMap(); 247 | map.put(0, 3.0); 248 | map.put(4, 1.0); 249 | map.put(2, 2.0); 250 | map.put(3, 9.0); 251 | map.put(7, 5.0); 252 | map.put(8, 6.0); 253 | int k = 4; 254 | Collection values = map.values(); 255 | PriorityQueue queue = new PriorityQueue(); 256 | for (Double d : values) { 257 | queue.offer(d); 258 | if (queue.size() > k) { queue.poll(); } 259 | } 260 | double val = queue.peek(); 261 | System.out.println(val); 262 | while (!queue.isEmpty()) { 263 | System.out.println(queue.poll()); 264 | } 265 | System.out.println(); 266 | queue.clear(); 267 | queue.addAll(values); 268 | while (!queue.isEmpty()) { 269 | System.out.println(queue.poll()); 270 | } 271 | // Collections.sort(values); 272 | } 273 | 274 | // @Test 275 | public void testDate() { 276 | String timeStamp = new SimpleDateFormat(".yyyyMMddHHmmss").format(new Date()); 277 | System.out.println(timeStamp); 278 | 279 | String xx = " "; 280 | System.out.println(xx.trim().equals("")); 281 | 282 | // double l1 = 
(-5.0620584989385815 + 4.368911318378636); 283 | // double l2 = (-5.0620384989385805 + 4.368891318378635); 284 | 285 | // double l3 = (-5.062048498938581 + 4.368901318378636); 286 | // double l1 = (-5.061048498938581 + 4.367901318378635); 287 | // double l2 = (-5.063048498938581 + 4.369901318378636); 288 | 289 | double l3 = ((-5.063048498938581 + 4.369401193378641) - (-5.061048498938581 + 4.368401193378641)) / (0.002); 290 | // System.out.println(String.format("%.10f", l1)); 291 | // System.out.println(String.format("%.10f", l2)); 292 | System.out.println(String.format("%.10f", l3)); 293 | 294 | System.out.println(0.0 / Double.NEGATIVE_INFINITY); 295 | 296 | } 297 | 298 | 299 | // @Test 300 | public void testMain() { 301 | class LL { 302 | int x = 0; 303 | public LL(int x) { 304 | this.x = x; 305 | } 306 | public String toString() { 307 | return String.valueOf(x); 308 | } 309 | } 310 | Set container = new HashSet(); 311 | LL l0 = new LL(0); 312 | LL l1 = new LL(1); 313 | container.add(l0); 314 | container.add(l1); 315 | 316 | LL[] array = container.toArray(new LL[0]); 317 | for (int i = 0; i < array.length; i++) { 318 | System.out.println(array[i]); 319 | } 320 | array[1].x = 50; 321 | System.out.println(container); 322 | 323 | } 324 | } 325 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/LVeGLearnerTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg; 2 | 3 | import org.junit.Test; 4 | 5 | public class LVeGLearnerTest { 6 | 7 | @Test 8 | public void testLVeGLearner() { 9 | String[] args = {"param.in"}; 10 | try { 11 | // LVeGTrainer.main(args); 12 | LVeGTrainerImp.main(args); 13 | } catch (Exception e) { 14 | e.printStackTrace(); 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/LVeGPCFGTest.java: 
-------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg; 2 | 3 | import org.junit.Test; 4 | 5 | public class LVeGPCFGTest { 6 | 7 | @Test 8 | public void testLVeGPCFG() { 9 | String[] args = {"param.in"}; 10 | try { 11 | LVeGPCFG.main(args); 12 | } catch (Exception e) { 13 | e.printStackTrace(); 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/LVeGTesterTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg; 2 | 3 | import org.junit.Test; 4 | 5 | public class LVeGTesterTest { 6 | 7 | @Test 8 | public void testLVeGLearner() { 9 | String[] args = {"param.f1"}; 10 | try { 11 | // LVeGTester.main(args); 12 | LVeGTesterImp.main(args); 13 | // LVeGTesterSim.main(args); 14 | } catch (Exception e) { 15 | e.printStackTrace(); 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/LVeGToyTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg; 2 | 3 | import org.junit.Test; 4 | 5 | public class LVeGToyTest { 6 | 7 | @Test 8 | public void testLVeGLearner() { 9 | String[] args = {"config/paramtoy.in"}; 10 | try { 11 | LVeGToy.main(args); 12 | } catch (Exception e) { 13 | e.printStackTrace(); 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/impl/BinaryGrammarTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.EnumMap; 4 | import java.util.HashSet; 5 | import java.util.Set; 6 | 7 | import org.junit.Test; 8 | 9 | import 
edu.shanghaitech.ai.nlp.lveg.LVeGTrainer; 10 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianDistribution; 11 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianMixture; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution; 13 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 14 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 15 | 16 | public class BinaryGrammarTest { 17 | @Test 18 | public void testBinaryGrammarTest() { 19 | GaussianMixture gm = new DiagonalGaussianMixture(LVeGTrainer.ncomponent); 20 | for (int i = 0; i < LVeGTrainer.ncomponent; i++) { 21 | EnumMap> map = new EnumMap<>(RuleUnit.class); 22 | Set list0 = new HashSet(); 23 | Set list1 = new HashSet(); 24 | Set list2 = new HashSet(); 25 | list0.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 26 | list1.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 27 | list2.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 28 | map.put(RuleUnit.P, list0); 29 | map.put(RuleUnit.LC, list1); 30 | map.put(RuleUnit.RC, list2); 31 | gm.add(i, map); 32 | } 33 | System.out.println(gm); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/impl/LVeGGrammarTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.List; 4 | 5 | import org.junit.Test; 6 | 7 | import edu.shanghaitech.ai.nlp.lveg.impl.BinaryGrammarRule; 8 | import edu.shanghaitech.ai.nlp.lveg.impl.SimpleLVeGGrammar; 9 | import edu.shanghaitech.ai.nlp.lveg.impl.UnaryGrammarRule; 10 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 11 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleType; 12 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar; 13 | import edu.shanghaitech.ai.nlp.util.Debugger; 14 | 15 | public class LVeGGrammarTest { 16 | 17 | // static { 18 | // 
circle -3-5-7- 19 | GrammarRule r0 = new UnaryGrammarRule((short) 1, (short) 0, RuleType.LRURULE); 20 | GrammarRule r1 = new UnaryGrammarRule((short) 3, (short) 0, RuleType.LRURULE); 21 | GrammarRule r2 = new UnaryGrammarRule((short) 2, (short) 1, RuleType.LRURULE); 22 | GrammarRule r3 = new UnaryGrammarRule((short) 4, (short) 1, RuleType.LRURULE); 23 | GrammarRule r4 = new UnaryGrammarRule((short) 8, (short) 2, RuleType.LRURULE); 24 | UnaryGrammarRule r5 = new UnaryGrammarRule((short) 9, (short) 2, RuleType.LRURULE); 25 | UnaryGrammarRule r6 = new UnaryGrammarRule((short) 5, (short) 3, RuleType.LRURULE); 26 | UnaryGrammarRule r7 = new UnaryGrammarRule((short) 8, (short) 3, RuleType.LRURULE); 27 | UnaryGrammarRule r8 = new UnaryGrammarRule((short) 9, (short) 3, RuleType.LRURULE); 28 | UnaryGrammarRule r9 = new UnaryGrammarRule((short) 6, (short) 4, RuleType.LRURULE); 29 | UnaryGrammarRule r10 = new UnaryGrammarRule((short) 9, (short) 4, RuleType.LRURULE); 30 | UnaryGrammarRule r11 = new UnaryGrammarRule((short) 7, (short) 5, RuleType.LRURULE); 31 | UnaryGrammarRule r12 = new UnaryGrammarRule((short) 8, (short) 6, RuleType.LRURULE); 32 | UnaryGrammarRule r13 = new UnaryGrammarRule((short) 3, (short) 7, RuleType.LRURULE); 33 | UnaryGrammarRule r14 = new UnaryGrammarRule((short) 8, (short) 7, RuleType.LRURULE); 34 | UnaryGrammarRule r15 = new UnaryGrammarRule((short) 9, (short) 8, RuleType.LRURULE); 35 | UnaryGrammarRule r25 = new UnaryGrammarRule((short) 9, (short) 8, RuleType.LRURULE); 36 | 37 | BinaryGrammarRule rule16= new BinaryGrammarRule((short) 1, (short) 2, (short) 3); 38 | BinaryGrammarRule rule17= new BinaryGrammarRule((short) 2, (short) 4, (short) 6); 39 | BinaryGrammarRule rule18= new BinaryGrammarRule((short) 3, (short) 5, (short) 8); 40 | BinaryGrammarRule rule19= new BinaryGrammarRule((short) 2, (short) 8, (short) 9); 41 | BinaryGrammarRule rule20= new BinaryGrammarRule((short) 5, (short) 1, (short) 2); 42 | BinaryGrammarRule rule21= new 
BinaryGrammarRule((short) 8, (short) 4, (short) 7); 43 | BinaryGrammarRule rule23= new BinaryGrammarRule((short) 3, (short) 6, (short) 7); 44 | BinaryGrammarRule rule24= new BinaryGrammarRule((short) 7, (short) 8, (short) 9); 45 | 46 | // } 47 | 48 | int nTag = 10; 49 | LVeGGrammar grammar = new SimpleLVeGGrammar(null, nTag); 50 | 51 | @Test 52 | public void testLVeGGrammar() { 53 | grammar.addURule((UnaryGrammarRule) r0); 54 | grammar.addURule((UnaryGrammarRule) r1); 55 | grammar.addURule((UnaryGrammarRule) r2); 56 | grammar.addURule((UnaryGrammarRule) r3); 57 | grammar.addURule((UnaryGrammarRule) r4); 58 | grammar.addURule(r5); 59 | grammar.addURule(r6); 60 | grammar.addURule(r7); 61 | grammar.addURule(r8); 62 | grammar.addURule(r9); 63 | grammar.addURule(r10); 64 | grammar.addURule(r11); 65 | grammar.addURule(r12); 66 | grammar.addURule(r13); 67 | grammar.addURule(r14); 68 | grammar.addURule(r15); 69 | // addURule() test 70 | grammar.addURule(r15); 71 | grammar.addURule(r25); 72 | 73 | /* 74 | for (int i = 0; i < nTag; i++) { 75 | List rules = grammar.getUnaryRuleWithC(i); 76 | System.out.println("Rules with child: " + i); 77 | System.out.println(rules); 78 | } 79 | */ 80 | 81 | Debugger.checkUnaryRuleCircle(grammar, null, true); 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/impl/RuleTableGeneric.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianMixture; 9 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 10 | 11 | /** 12 | * Map grammar rules to their counts. {@link #isCompatible(Object)} must 13 | * work well since it is the key. 
This kind of implementation does not 14 | * have the practical significance. See the test case in RuleTableTest. 15 | * 16 | * @author Yanpeng Zhao 17 | * 18 | */ 19 | public class RuleTableGeneric implements Serializable { 20 | 21 | /** 22 | * 23 | */ 24 | private static final long serialVersionUID = 1L; 25 | 26 | Class type; 27 | Map table; 28 | 29 | 30 | public RuleTableGeneric(Class type) { 31 | this.table = new HashMap(); 32 | this.type = type; 33 | } 34 | 35 | 36 | public int size() { 37 | return table.size(); 38 | } 39 | 40 | 41 | public void clear() { 42 | table.clear(); 43 | } 44 | 45 | 46 | public boolean isEmpty() { 47 | return size() == 0; 48 | } 49 | 50 | 51 | public Set keySet() { 52 | return table.keySet(); 53 | } 54 | 55 | 56 | public boolean containsKey(T key) { 57 | return table.containsKey(key); 58 | } 59 | 60 | 61 | /** 62 | * Type-specific instance. 63 | * 64 | * @param key search keyword 65 | * @return 66 | * 67 | */ 68 | public boolean isCompatible(T key) { 69 | return type.isInstance(key); 70 | } 71 | 72 | 73 | public GaussianMixture getCount(T key) { 74 | return table.get(key); 75 | } 76 | 77 | 78 | public void setCount(T key, GaussianMixture value) { 79 | if (isCompatible(key)) { 80 | table.put(key, value); 81 | } 82 | } 83 | 84 | 85 | public void increaseCount(T key, double increment) { 86 | GaussianMixture count = getCount(key); 87 | if (count == null) { 88 | GaussianMixture mog = new DiagonalGaussianMixture((short) 0); 89 | mog.add(increment); 90 | setCount(key, mog); 91 | return; 92 | } 93 | count.add(increment); 94 | } 95 | 96 | 97 | public void increaseCount(T key, GaussianMixture increment, boolean prune) { 98 | GaussianMixture count = getCount(key); 99 | if (count == null) { 100 | GaussianMixture mog = new DiagonalGaussianMixture((short) 0); 101 | mog.add(increment, prune); 102 | setCount(key, mog); 103 | return; 104 | } 105 | count.add(increment, prune); 106 | } 107 | } 108 | 
-------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/impl/RuleTableTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import java.util.HashMap; 6 | import java.util.HashSet; 7 | import java.util.Map; 8 | import java.util.Set; 9 | 10 | import org.junit.Test; 11 | 12 | import edu.shanghaitech.ai.nlp.lveg.impl.BinaryGrammarRule; 13 | import edu.shanghaitech.ai.nlp.lveg.impl.RuleTable; 14 | import edu.shanghaitech.ai.nlp.lveg.impl.UnaryGrammarRule; 15 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 16 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleType; 17 | 18 | public class RuleTableTest { 19 | 20 | @Test 21 | public void testRuleCounter() { 22 | 23 | Map rules = new HashMap(); 24 | 25 | Set ruleSet = new HashSet(); 26 | RuleTable unaryRuleTable = new RuleTable(UnaryGrammarRule.class); 27 | RuleTable binaryRuleTable = new RuleTable(BinaryGrammarRule.class); 28 | 29 | // RuleTable unaryRuleTable = new RuleTable(UnaryGrammarRule.class); 30 | // RuleTable binaryRuleTable = new RuleTable(BinaryGrammarRule.class); 31 | 32 | GrammarRule rule0 = new UnaryGrammarRule((short) 1, (short) 2, RuleType.LRURULE, true); 33 | GrammarRule rule1 = new UnaryGrammarRule((short) 1, (short) 2, RuleType.LRURULE, false); 34 | GrammarRule rule5 = new UnaryGrammarRule((short) 1, (short) 2, RuleType.RHSPACE, true); 35 | GrammarRule rule8 = new UnaryGrammarRule((short) 1, (short) 2); 36 | 37 | GrammarRule rule3 = new BinaryGrammarRule((short) 1, (short) 2, (short) 3, true); 38 | GrammarRule rule4 = new BinaryGrammarRule((short) 1, (short) 2, (short) 3); 39 | GrammarRule rule10= new BinaryGrammarRule((short) 1, (short) 2, (short) 3); 40 | 41 | GrammarRule rule6 = new UnaryGrammarRule((short) 4, (short) 2); 42 | GrammarRule rule7 = new UnaryGrammarRule((short) 4, 
(short) 2); 43 | 44 | GrammarRule rule9 = new UnaryGrammarRule((short) 7, (short) 2); 45 | GrammarRule rule2 = new UnaryGrammarRule((short) 7, (short) 2); 46 | 47 | 48 | ruleSet.add(rule0); 49 | ruleSet.add(rule5); 50 | ruleSet.add(rule3); 51 | 52 | GrammarRule[] ruleArray = ruleSet.toArray(new GrammarRule[0]); 53 | for (int i = 0; i < ruleArray.length; i++) { 54 | System.out.println(ruleArray[i]); 55 | } 56 | // ruleArray[1].type = 3; 57 | // System.out.println(ruleSet); 58 | 59 | 60 | assertFalse(rule0 instanceof BinaryGrammarRule); 61 | assertTrue(rule0 instanceof UnaryGrammarRule); 62 | assertTrue(rule0 instanceof GrammarRule); 63 | 64 | 65 | rules.put(rule0, -1.0); 66 | rules.put(rule3, 20.0); 67 | assertFalse(rules.containsKey(rule5)); 68 | assertTrue(rules.containsKey(rule1)); 69 | assertTrue(rules.containsKey(rule4)); 70 | for (GrammarRule rule : rules.keySet()) { 71 | System.out.print("Is Unary: " + rule.isUnary() + "\t"); 72 | System.out.println(rules.get(rule)); 73 | } 74 | System.out.println("Return Value: " + rules.get(rule5)); 75 | System.out.println(rules); 76 | 77 | 78 | assertTrue(unaryRuleTable.isCompatible(rule0)); 79 | assertFalse(unaryRuleTable.isCompatible(rule3)); 80 | 81 | assertTrue(binaryRuleTable.isCompatible(rule4)); 82 | assertFalse(binaryRuleTable.isCompatible(rule1)); 83 | 84 | 85 | unaryRuleTable.addCount(rule0, 1); 86 | if (unaryRuleTable.containsKey(rule1)) { 87 | System.out.println("UnaryGrammarRule: It works."); 88 | } else { 89 | System.err.println("UnaryGrammarRule: Oops."); 90 | } 91 | assertTrue(unaryRuleTable.containsKey(rule1)); 92 | assertFalse(unaryRuleTable.containsKey(rule5)); 93 | assertTrue(unaryRuleTable.containsKey(rule8)); 94 | 95 | unaryRuleTable.addCount(rule6, 1); 96 | assertTrue(unaryRuleTable.containsKey(rule7)); 97 | 98 | unaryRuleTable.addCount(rule3, 1); 99 | unaryRuleTable.containsKey(rule3); 100 | assertFalse(unaryRuleTable.containsKey(rule4)); 101 | 102 | 103 | binaryRuleTable.addCount(rule3, 1); 104 | 
if (binaryRuleTable.containsKey(rule4)) { 105 | System.out.println("BinaryGrammarRule: It works."); 106 | } else { 107 | System.err.println("BinaryGrammarRule:Oops."); 108 | } 109 | assertTrue(binaryRuleTable.containsKey(rule4)); 110 | assertTrue(binaryRuleTable.containsKey(rule10)); 111 | 112 | // test the reference mechanism of add() method of Set 113 | Set set = new HashSet(); 114 | set.add((UnaryGrammarRule) rule0); 115 | System.out.println(set); 116 | rule0.setLhs((short) 100); 117 | System.out.println(set); 118 | 119 | 120 | RuleTableGeneric uTable = new RuleTableGeneric(UnaryGrammarRule.class); 121 | RuleTableGeneric bTable = new RuleTableGeneric(BinaryGrammarRule.class); 122 | 123 | 124 | GrammarRule rule12 = new UnaryGrammarRule((short) 1, (short) 2); 125 | UnaryGrammarRule rule13= new UnaryGrammarRule((short) 1, (short) 2); 126 | 127 | GrammarRule rule14 = new BinaryGrammarRule((short) 1, (short) 2, (short) 3, false); 128 | BinaryGrammarRule rule15 = new BinaryGrammarRule((short) 1, (short) 2, (short) 3, false); 129 | System.out.println("Hello. 
" + rule14); 130 | 131 | // assertTrue(uTable.isCompatible(rule12)); 132 | // assertTrue(uTable.isCompatible(rule14)); 133 | 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/impl/UnaryGrammarRuleTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.impl; 2 | 3 | import java.util.EnumMap; 4 | import java.util.HashSet; 5 | import java.util.Set; 6 | 7 | import org.junit.Test; 8 | 9 | import edu.shanghaitech.ai.nlp.lveg.LVeGTrainer; 10 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianDistribution; 11 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianMixture; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution; 13 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 14 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule.RuleUnit; 15 | 16 | public class UnaryGrammarRuleTest { 17 | 18 | @Test 19 | public void testGMadd0() { 20 | GaussianMixture gm = new DiagonalGaussianMixture(LVeGTrainer.ncomponent); 21 | for (int i = 0; i < LVeGTrainer.ncomponent; i++) { 22 | Set set = new HashSet(); 23 | set.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 24 | gm.add(i, RuleUnit.C, set); 25 | } 26 | System.out.println(gm); 27 | } 28 | 29 | 30 | @Test 31 | public void testGMadd1() { 32 | GaussianMixture gm = new DiagonalGaussianMixture(LVeGTrainer.ncomponent); 33 | for (int i = 0; i < LVeGTrainer.ncomponent; i++) { 34 | Set set = new HashSet(); 35 | set.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 36 | gm.add(i, RuleUnit.P, set); 37 | } 38 | System.out.println(gm); 39 | } 40 | 41 | @Test 42 | public void testGMadd2() { 43 | GaussianMixture gm = new DiagonalGaussianMixture(LVeGTrainer.ncomponent); 44 | for (int i = 0; i < LVeGTrainer.ncomponent; i++) { 45 | EnumMap> map = new EnumMap<>(RuleUnit.class); 46 | Set set0 = new HashSet(); 47 | Set set1 = 
new HashSet(); 48 | set0.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 49 | set1.add(new DiagonalGaussianDistribution(LVeGTrainer.dim)); 50 | map.put(RuleUnit.P, set0); 51 | map.put(RuleUnit.UC, set1); 52 | gm.add(i, map); 53 | } 54 | System.out.println(gm); 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/model/GaussianDistributionTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.model; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | 8 | import org.junit.Test; 9 | 10 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianDistribution; 11 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution; 12 | 13 | public class GaussianDistributionTest { 14 | 15 | @Test 16 | public void testGaussianDistribution() { 17 | GaussianDistribution gd = new DiagonalGaussianDistribution(); 18 | System.out.println(gd); 19 | } 20 | 21 | 22 | @Test 23 | public void testInstanceEqual() { 24 | GaussianDistribution gd0 = new DiagonalGaussianDistribution(); 25 | GaussianDistribution gd1 = new DiagonalGaussianDistribution(); 26 | GaussianDistribution gd2 = new DiagonalGaussianDistribution((short) 5); 27 | 28 | assertFalse(gd0 == gd1); 29 | assertTrue(gd0.equals(gd1)); 30 | 31 | gd0.getMus().add(2.0); 32 | gd0.getMus().add(3.0); 33 | gd0.getVars().add(4.0); 34 | 35 | gd1.getMus().add(2.0); 36 | gd1.getMus().add(3.0); 37 | gd1.getVars().add(4.0); 38 | 39 | assertTrue(gd0.equals(gd1)); 40 | 41 | gd1.getMus().add(2.0); 42 | assertFalse(gd0.equals(gd1)); 43 | } 44 | 45 | 46 | @Test 47 | public void testDoubleEqual() { 48 | List xx = new ArrayList(); 49 | List yy = new ArrayList(); 50 | 51 | xx.add(1.0); 52 | xx.add(2.0); 53 | 54 | yy.addAll(xx); 55 | 56 | assertTrue(xx.equals(yy)); 57 | 58 | xx.set(0, 3.0); 59 | 60 | 
System.out.println(xx); 61 | System.out.println(yy); 62 | xx.clear(); 63 | System.out.println(yy); 64 | } 65 | 66 | 67 | @Test 68 | public void testStringEqual() { 69 | String str0 = new String("nihaoa"); 70 | String str1 = "nihaoa"; 71 | String str2 = str1; 72 | String str3 = new String("nihaoa"); 73 | 74 | assertFalse(str0 == str3); 75 | assertTrue(str0.equals(str3)); 76 | 77 | assertTrue(str1 == str2); 78 | assertTrue(str0.equals(str1)); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/lveg/model/InferencerTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.lveg.model; 2 | 3 | import org.junit.Test; 4 | 5 | import edu.shanghaitech.ai.nlp.lveg.impl.LVeGInferencer; 6 | import edu.shanghaitech.ai.nlp.lveg.model.ChartCell.Chart; 7 | 8 | public class InferencerTest { 9 | 10 | private static LVeGInferencer inferencer = new LVeGInferencer(null, null); 11 | 12 | // @Test 13 | public void testInferencer() { 14 | // inferencer.insideScore(null, false); 15 | // inferencer.outsideScore(null); 16 | 17 | Chart chart = new Chart(5, true, false, false); 18 | // insideScore(chart, 5, 0, 4); 19 | for (int i = 0; i < 5; i++) { 20 | outsideScore(chart, 5, i, i); 21 | System.out.println(); 22 | } 23 | } 24 | 25 | 26 | @Test 27 | public void testInsideLoop() { 28 | // inside score 29 | int nword = 5; 30 | for (int ilayer = 1; ilayer < nword; ilayer++) { 31 | for (int left = 0; left < nword - ilayer; left++) { 32 | for (int right = left; right < left + ilayer; right++) { 33 | System.out.println(left + "," + (left + ilayer) + "(" + left + "," + right + "), (" + 34 | (right + 1) + "," + (left + ilayer) + ")"); 35 | } 36 | } 37 | System.out.println(); 38 | } 39 | } 40 | 41 | 42 | // @Test 43 | public void insideScore(Chart chart, int nword, int begin, int end) { 44 | 45 | if (begin == end) { 46 | System.out.println("(" + 
begin + ", " + end + ")"); 47 | } 48 | 49 | for (int split = begin; split < end; split++) { 50 | 51 | int x0 = begin, y0 = split; 52 | int x1 = split + 1, y1 = end; 53 | int l0 = nword - (y0 - x0), l1 = nword - (y1 - x1); 54 | 55 | if (!chart.getStatus(Chart.idx(x0, l0), true)) { 56 | insideScore(chart, nword, begin, split); 57 | } 58 | if (!chart.getStatus(Chart.idx(x1, l1), true)) { 59 | insideScore(chart, nword, split + 1, end); 60 | } 61 | 62 | System.out.println("-(" + begin + ", " + end + ")--" + 63 | "-(" + x0 + ", " + y0 + ")" + "-(" + x1 + ", " + y1 + ")"); 64 | } 65 | 66 | chart.setStatus(Chart.idx(begin, nword - (end - begin)), true, true); 67 | System.out.println(); 68 | } 69 | 70 | 71 | // @Test 72 | public void outsideScore(Chart chart, int nword, int begin, int end) { 73 | 74 | if (begin == 0 && end == nword - 1) { 75 | if (!chart.getStatus(Chart.idx(begin, nword - (end - begin)), true)) { 76 | System.out.println("(" + begin + ", " + end + ")"); 77 | chart.setStatus(Chart.idx(begin, nword - (end - begin)), true, true); 78 | } 79 | } 80 | 81 | int x0, y0, x1, y1, l0, l1; 82 | 83 | for (int right = end + 1; right < nword; right++) { 84 | x0 = begin; 85 | y0 = right; 86 | x1 = end + 1; 87 | y1 = right; 88 | l0 = y0 - x0; 89 | l1 = y1 - x1; 90 | if (!chart.getStatus(Chart.idx(x0, l0), true)) { 91 | outsideScore(chart, nword, x0, y0); 92 | } 93 | 94 | System.out.println("-(" + begin + ", " + end + ")--" + 95 | "-(" + x0 + ", " + y0 + ")" + "-(" + x1 + ", " + y1 + ")"); 96 | } 97 | 98 | for (int left = 0; left < begin; left++) { 99 | x0 = left; 100 | y0 = end; 101 | x1 = left; 102 | y1 = begin - 1; 103 | l0 = y0 - x0; 104 | l1 = y1 - x1; 105 | if (!chart.getStatus(Chart.idx(x0, l0), true)) { 106 | outsideScore(chart, nword, x0, y0); 107 | } 108 | 109 | System.out.println("=(" + begin + ", " + end + ")--" + 110 | "-(" + x0 + ", " + y0 + ")" + "-(" + x1 + ", " + y1 + ")"); 111 | } 112 | 113 | chart.setStatus(Chart.idx(begin, nword - (end - begin)), true, 
true); 114 | System.out.println(); 115 | } 116 | 117 | 118 | // @Test 119 | public void testOutsideLoop() { 120 | int nword = 5; 121 | for (int ilayer = nword - 1; ilayer >= 0; ilayer--) { 122 | for (int left = 0; left < nword - ilayer; left++) { 123 | 124 | for (int right = left + ilayer + 1; right < nword; right++) { 125 | System.out.println(left + "," + (left + ilayer) + "(" + left + "," + right + "), (" + 126 | (left + ilayer + 1) + "," + (right) + ")"); 127 | } 128 | 129 | 130 | for (int right = 0; right < left; right++) { 131 | System.out.println(left + "," + (left + ilayer) + "(" + right + "," + (left + ilayer) + "), (" + 132 | (right) + "," + (left - 1) + ")---"); 133 | } 134 | } 135 | System.out.println(); 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/optimization/ParallelOptimizerTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | import java.util.concurrent.Callable; 6 | import java.util.concurrent.ExecutorService; 7 | import java.util.concurrent.Executors; 8 | import java.util.concurrent.TimeUnit; 9 | 10 | import org.junit.Test; 11 | 12 | import edu.shanghaitech.ai.nlp.lveg.model.Parser; 13 | import edu.shanghaitech.ai.nlp.util.ThreadPool; 14 | 15 | public class ParallelOptimizerTest { 16 | 17 | protected static int THREADS_NUM = 2; 18 | 19 | 20 | protected static class Muppet { 21 | protected static int maxiter = 5; 22 | 23 | public static void staticPrint(int i, int isample) { 24 | for (; i < maxiter; i++) { 25 | // for (int i = 0; i < 5; i++) { 26 | System.out.println("Muppet.static:\tisample_" + isample + "\t" + Thread.currentThread().getId() + ": " + i); 27 | } 28 | } 29 | 30 | public void nonStaticPrint(int i, int isample) { 31 | for (; i < maxiter; i++) { 32 | // for (int i = 0; i < 5; i++) 
{ 33 | System.out.println("Muppet.non-static:\tisample_" + isample + "\t" + Thread.currentThread().getId() + ": " + i); 34 | } 35 | } 36 | } 37 | 38 | protected static class Puppet extends Parser { 39 | /** 40 | * 41 | */ 42 | private static final long serialVersionUID = 3714346641593051149L; 43 | protected static int maxiter = 3; 44 | protected Muppet muppet; 45 | protected String name; 46 | 47 | private Puppet(Puppet puppet) { 48 | super((short) 0, (short) 1, false, false, false); 49 | this.muppet = puppet.muppet; 50 | this.name = puppet.name; 51 | } 52 | 53 | public Puppet(String name) { 54 | super((short) 0, (short) 1, false, false, false); 55 | this.name = name; 56 | this.muppet = new Muppet(); 57 | } 58 | 59 | public void staticPrint(int i, int isample) { 60 | Muppet.staticPrint(i, isample); 61 | } 62 | 63 | public void nonStaticPrint(int i, int isample) { 64 | muppet.nonStaticPrint(i, isample); 65 | } 66 | 67 | protected void setContext(boolean test) {} 68 | 69 | @Override 70 | public synchronized Object call() throws Exception { 71 | String ll = getName(); 72 | 73 | // staticPrint(idx, isample); // 0: uncomment to test static method accessing 74 | 75 | Meta cache = new Meta(itask, ll); 76 | 77 | synchronized (muppet) { 78 | 79 | // staticPrint(idx, isample); // 1: uncomment to test static method accessing 80 | 81 | muppet.nonStaticPrint(idx, itask); 82 | muppet.notifyAll(); 83 | } 84 | 85 | synchronized (caches) { 86 | 87 | // nonStaticPrint(idx, isample); 88 | 89 | caches.add(cache); 90 | caches.notifyAll(); 91 | } 92 | task = null; 93 | return null; 94 | } 95 | 96 | public String getName() { 97 | return name + "_" + idx; 98 | } 99 | 100 | @Override 101 | public Parser newInstance() { 102 | return new Puppet(this); 103 | } 104 | } 105 | 106 | 107 | // @Test 108 | public void testMultiThreadedPool() { 109 | // static or non-static methods accessing test 110 | String ll = null; 111 | int nthread = 2, nfailed = 0; 112 | Puppet puppet = new Puppet("puppet"); 
113 | ThreadPool mpuppet = new ThreadPool(puppet, nthread); 114 | for (int i = 0; i < 4; i++) { 115 | mpuppet.execute(null); 116 | while (mpuppet.hasNext()) { 117 | ll = (String) mpuppet.getNext(); 118 | if (ll == null) { 119 | nfailed++; 120 | } else { 121 | System.out.println("~~~>name: " + ll); 122 | } 123 | } 124 | } 125 | while (!mpuppet.isDone()) { 126 | while (mpuppet.hasNext()) { 127 | ll = (String) mpuppet.getNext(); 128 | if (ll == null) { 129 | nfailed++; 130 | } else { 131 | System.out.println("~~~>name: " + ll); 132 | } 133 | } 134 | } 135 | System.out.println("---summary: nfailed=" + nfailed); 136 | mpuppet.shutdown(); 137 | } 138 | 139 | 140 | @Test 141 | public void testParallelOptimizer() { 142 | ExecutorService pool = Executors.newFixedThreadPool(THREADS_NUM); 143 | List> tasks = new ArrayList>(2); 144 | 145 | Puppet p0 = new Puppet("0"); 146 | Puppet p1 = p0; 147 | 148 | tasks.add(new Callable() { 149 | 150 | @Override 151 | public Boolean call() throws Exception { 152 | p0.nonStaticPrint(1, 0); 153 | return true; 154 | } 155 | 156 | }); 157 | 158 | tasks.add(new Callable() { 159 | 160 | @Override 161 | public Boolean call() throws Exception { 162 | p1.nonStaticPrint(2, 1); 163 | return true; 164 | } 165 | 166 | }); 167 | 168 | // see the comments in ParallelOptimizer.useCustomizedBlock(); 169 | for (Callable task : tasks) { 170 | pool.submit(task); 171 | } 172 | boolean exit = true; 173 | try { 174 | pool.shutdown(); 175 | exit = pool.awaitTermination(0, TimeUnit.MILLISECONDS); 176 | } catch (InterruptedException e) { 177 | e.printStackTrace(); 178 | } 179 | System.out.println("exit: " + exit + "... 
" + pool.isTerminated()); 180 | 181 | /* 182 | try { 183 | pool.invokeAll(tasks); 184 | } catch (InterruptedException e) { 185 | e.printStackTrace(); 186 | } 187 | */ 188 | } 189 | 190 | } 191 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/optimization/SGDForMoGTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.optimization; 2 | 3 | import java.util.HashSet; 4 | import java.util.Set; 5 | 6 | import org.junit.Test; 7 | 8 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule; 9 | 10 | public class SGDForMoGTest { 11 | 12 | @Test 13 | public void testSGDForMoG() { 14 | // String key0 = GrammarRule.Unit.P; 15 | // String key1 = "2"; 16 | // 17 | // int x = Integer.valueOf(key0); 18 | int a = 24; 19 | int b = 24; 20 | assert(a == b); 21 | 22 | char x = (char) -1; 23 | // x = 1; 24 | System.out.println(Integer.toBinaryString(x)); 25 | System.out.println(Integer.toBinaryString((short) x)); 26 | System.out.println((short) x); 27 | System.out.println(x == 1); 28 | System.out.println((int) x == -1); 29 | 30 | char z = 1; 31 | System.out.println(Integer.toBinaryString(z)); 32 | System.out.println(Integer.toBinaryString((short) z)); 33 | System.out.println(z == 1); 34 | 35 | byte y = -1; 36 | System.out.println(Byte.MAX_VALUE + ", " + Byte.MIN_VALUE + ", " + Integer.toBinaryString(y)); 37 | 38 | Byte f = -1; 39 | System.out.println(f); 40 | 41 | int xx = 98; 42 | int mm = 100; 43 | System.out.println(Integer.toBinaryString(xx)); 44 | xx = (1 << 31) + (xx << 16); 45 | System.out.println(Integer.toBinaryString(xx)); 46 | xx += mm; 47 | System.out.println(Integer.toBinaryString(mm)); 48 | 49 | System.out.println(Integer.toBinaryString(xx) + "\txx < 0: " + (xx < 0)); 50 | 51 | xx = ((xx << 1) >>> 1); 52 | System.out.println(Integer.toBinaryString(xx)); 53 | int xxx = (xx >>> 16); 54 | 
System.out.println(Integer.toBinaryString(xxx)); 55 | int mmm = ((xx << 16) >>> 16); 56 | System.out.println(Integer.toBinaryString(mmm)); 57 | 58 | System.out.println(Math.exp(Math.log(0.0) - 3)); 59 | 60 | Set tmp = new HashSet(); 61 | tmp.add((short) 0); 62 | System.out.println(tmp.contains((short) 0)); 63 | 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/syntax/StateTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.syntax; 2 | 3 | import static org.junit.Assert.assertTrue; 4 | 5 | import org.junit.Test; 6 | 7 | import edu.shanghaitech.ai.nlp.lveg.impl.DiagonalGaussianMixture; 8 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 9 | 10 | public class StateTest { 11 | 12 | @Test 13 | public void testState() { 14 | State state = new State("hello", (short) 0, (short) 1, (short) 2); 15 | GaussianMixture iscore = new DiagonalGaussianMixture((short) 2); 16 | GaussianMixture oscore = new DiagonalGaussianMixture((short) 2); 17 | state.setInsideScore(iscore); 18 | state.setOutsideScore(oscore); 19 | 20 | // System.out.println(state.toString(false, (short) 2)); 21 | assertTrue(state.getInsideScore() != null); 22 | state.clear(false); 23 | // System.out.println(state.toString(false, (short) 2)); 24 | assertTrue(state.getInsideScore() == null); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/util/FunUtilTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.io.IOException; 4 | import java.nio.charset.StandardCharsets; 5 | import java.util.ArrayList; 6 | import java.util.Collections; 7 | import java.util.List; 8 | import java.util.Random; 9 | 10 | import org.junit.Test; 11 | 12 | import 
edu.shanghaitech.ai.nlp.lveg.LVeGTrainer; 13 | import edu.shanghaitech.ai.nlp.lveg.LearnerConfig; 14 | 15 | public class FunUtilTest { 16 | 17 | @Test 18 | public void testFunUtil() { 19 | int maxint = 10; 20 | Integer[] arrayInt = new Integer[5]; 21 | FunUtil.randomInitArray(arrayInt, Integer.class, maxint); 22 | FunUtil.printArray(arrayInt); 23 | 24 | Double[] arrayDouble = new Double[5]; 25 | FunUtil.randomInitArray(arrayDouble, Double.class, 1); 26 | FunUtil.printArray(arrayDouble); 27 | 28 | List listInt = new ArrayList(); 29 | FunUtil.randomInitList(LVeGTrainer.random, listInt, Integer.class, 5, maxint, 0.5, false, true); 30 | FunUtil.printList(listInt); 31 | 32 | List listDouble = new ArrayList(); 33 | FunUtil.randomInitList(LVeGTrainer.random, listDouble, Double.class, 5, 1, 0.5, false, true); 34 | FunUtil.printList(listDouble); 35 | 36 | int[] arrayint = new int[5]; 37 | FunUtil.randomInitArrayInt(arrayint, maxint); 38 | FunUtil.printArrayInt(arrayint); 39 | 40 | double[] arraydouble = new double[5]; 41 | FunUtil.randomInitArrayDouble(arraydouble); 42 | FunUtil.printArrayDouble(arraydouble); 43 | 44 | Double[] xarray = new Double[0]; 45 | FunUtil.printArray(xarray); 46 | 47 | double x = -30; 48 | double y = -6; 49 | double z = FunUtil.logAdd(x, y); 50 | double m = Math.log((Math.exp(x) + Math.exp(y))); 51 | System.out.println("Precision of the logAdd method is: " + (m - z) + ", [m = " + m + ", z =" + z + "]"); 52 | } 53 | 54 | // @Test 55 | public void testShuffle() { 56 | List listInt = new ArrayList(); 57 | FunUtil.randomInitList(LVeGTrainer.random, listInt, Integer.class, 5, 2, 0.5, false, true); 58 | System.out.println("---shuffle test---"); 59 | System.out.println(listInt); 60 | Collections.shuffle(listInt, new Random(0)); 61 | System.out.println(listInt); 62 | } 63 | 64 | @Test 65 | public void stringFormat() { 66 | double x = Double.parseDouble("1e-3"); 67 | System.out.println(x); 68 | 69 | 70 | try { 71 | String str = 
LearnerConfig.readFile("param.ini", StandardCharsets.UTF_8); 72 | System.out.println(str); 73 | String[] arr = str.split(","); 74 | for (int i = 0; i < arr.length; i++) { 75 | System.out.println(i + " | " + arr[i].trim() + ")"); 76 | } 77 | } catch (IOException e) { 78 | e.printStackTrace(); 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/util/LogUtilTest.java: -------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import org.apache.log4j.Logger; 4 | import org.junit.Test; 5 | 6 | public class LogUtilTest extends Recorder { 7 | 8 | @Test 9 | public void testLogUtils() { 10 | String logFile = "log/0_logger_test"; 11 | /* get console logger before the file logger */ 12 | 13 | // Logger logger0 = logUtil.getConsoleLogger(); 14 | Logger logger1 = logUtil.getFileLogger(logFile); 15 | Logger logger0 = logUtil.getConsoleLogger(); 16 | 17 | logger1.fatal("File Fatal 1\n"); 18 | 19 | logger0.trace("Trace\n"); 20 | logger0.debug("Debug\n"); 21 | logger0.info("Info\n"); 22 | logger0.warn("Warn\n"); 23 | logger0.error("Error\n"); 24 | logger0.fatal("Fatal\n"); 25 | 26 | logger1.trace("Trace\n"); 27 | logger1.debug("Debug\n"); 28 | logger1.info("Info\n"); 29 | logger1.warn("Warn\n"); 30 | logger1.error("Error\n"); 31 | logger1.fatal("Fatal\n"); 32 | 33 | logger0.fatal("File Fatal 1\n"); 34 | 35 | } 36 | 37 | // @Test 38 | public void testBothLogger() { 39 | String logfile = "log/0_logger_test"; 40 | Logger logger0 = logUtil.getBothLogger(logfile); 41 | 42 | logger0.trace("Trace\n"); 43 | logger0.debug("Debug\n"); 44 | logger0.info("Info\n"); 45 | logger0.warn("Warn\n"); 46 | logger0.error("Error\n"); 47 | logger0.fatal("Fatal\n"); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/test/java/edu/shanghaitech/ai/nlp/util/ObjectToolTest.java: 
-------------------------------------------------------------------------------- 1 | package edu.shanghaitech.ai.nlp.util; 2 | 3 | import java.util.Random; 4 | 5 | import org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig; 6 | import org.junit.Before; 7 | import org.junit.Test; 8 | 9 | import edu.shanghaitech.ai.nlp.lveg.impl.GaussFactory; 10 | import edu.shanghaitech.ai.nlp.lveg.impl.MoGFactory; 11 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution; 12 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture; 13 | 14 | public class ObjectToolTest { 15 | 16 | private ObjectPool mogPool; 17 | private ObjectPool gaussPool; 18 | 19 | @Before 20 | public void setup() throws Exception { 21 | GenericKeyedObjectPoolConfig config = new GenericKeyedObjectPoolConfig(); 22 | // config.setMaxTotalPerKey(Integer.MAX_VALUE); 23 | // config.setMaxTotal(Integer.MAX_VALUE); 24 | 25 | config.setMaxTotalPerKey(8); 26 | config.setMaxTotal(2000); 27 | 28 | 29 | config.setBlockWhenExhausted(true); 30 | config.setMaxWaitMillis(1500); 31 | config.setTestOnBorrow(true); 32 | // config.setTestOnCreate(false); 33 | // config.setTestOnReturn(false); 34 | 35 | StringBuffer sb = new StringBuffer(); 36 | sb.append("---max idle per key: " + config.getMaxIdlePerKey() + "\n" 37 | + "---max total: " + config.getMaxTotal() + "\n" 38 | + "---max total per key: " + config.getMaxTotalPerKey() + "\n" 39 | + "---min idle per key: " + config.getMinIdlePerKey() + "\n"); 40 | 41 | short defaultval = 2; 42 | Random rnd = new Random(0); 43 | MoGFactory mfactory = new MoGFactory(defaultval, defaultval, 0.5, rnd); 44 | GaussFactory gfactory = new GaussFactory(defaultval, defaultval, defaultval, 0.5, 0.5, rnd); 45 | 46 | mogPool = new ObjectPool(mfactory, config); 47 | gaussPool = new ObjectPool(gfactory, config); 48 | 49 | sb.append("---max active: " + mogPool.getNumActive() + "\n" 50 | + "---block: " + mogPool.getBlockWhenExhausted() + "\n" 51 | + "---wait: " + 
mogPool.getMaxWaitMillis() + "\n" 52 | + "---waitb: " + mogPool.getMaxBorrowWaitTimeMillis() + "\n"); 53 | System.out.println(sb.toString()); 54 | } 55 | 56 | 57 | // @Test 58 | public void testStress() throws Exception { 59 | int total = 2044612, factor = 4; 60 | long beginTime = System.currentTimeMillis(); 61 | for (int i= 0; i < total; i++) { 62 | mogPool.borrowObject((short) -1); 63 | for (int j = 0; j < factor; j++) { 64 | gaussPool.borrowObject((short) -1); 65 | } 66 | if (i % 100 == 0) { 67 | System.out.println("-" + (i / 100)); 68 | } 69 | } 70 | long endTime = System.currentTimeMillis(); 71 | System.out.println("------->it consumed " + (endTime - beginTime) / 1000.0 + "s\n"); 72 | } 73 | 74 | 75 | @Test 76 | public void testObjectTool() { 77 | GaussianMixture lastone = null; 78 | int total = 20, cnt = 0, mid = 8; 79 | for (int i = 0; cnt < total; i++, cnt++) { 80 | if (cnt < mid) { 81 | i = 0; 82 | } else { 83 | i = cnt - mid + 1; 84 | } 85 | System.out.println("-----------------" + cnt + "\t" + i); 86 | try { 87 | GaussianMixture mog = mogPool.borrowObject((short) i); 88 | lastone = mog; 89 | // gaussPool.borrowObject((short) i); 90 | System.out.println(mog); 91 | } catch (Exception e) { 92 | e.printStackTrace(); 93 | } 94 | } 95 | 96 | mogPool.returnObject(lastone.getKey(), lastone); 97 | 98 | } 99 | 100 | } 101 | --------------------------------------------------------------------------------