├── .gitignore
├── README.md
├── build.sbt
├── config
├── log4j.xml
└── paramtoy.in
├── param.f1
├── param.in
├── project
├── Dependencies.scala
├── build.properties
└── plugins.sbt
└── src
├── main
└── java
│ └── edu
│ └── shanghaitech
│ └── ai
│ └── nlp
│ ├── data
│ ├── LVeGCorpus.java
│ ├── ObjectFileManager.java
│ └── StateTreeList.java
│ ├── eval
│ └── EnglishPennTreebankParseEvaluator.java
│ ├── lveg
│ ├── LVeGPCFG.java
│ ├── LVeGTester.java
│ ├── LVeGTesterImp.java
│ ├── LVeGTesterSim.java
│ ├── LVeGToy.java
│ ├── LVeGTrainer.java
│ ├── LVeGTrainerImp.java
│ ├── LearnerConfig.java
│ ├── PCFGMaxRule.java
│ ├── impl
│ │ ├── BinaryGrammarRule.java
│ │ ├── DiagonalGaussianDistribution.java
│ │ ├── DiagonalGaussianMixture.java
│ │ ├── GaussFactory.java
│ │ ├── GeneralGaussianDistribution.java
│ │ ├── GeneralGaussianMixture.java
│ │ ├── LVeGInferencer.java
│ │ ├── LVeGParser.java
│ │ ├── MaxRuleInferencer.java
│ │ ├── MaxRuleParser.java
│ │ ├── MoGFactory.java
│ │ ├── PCFGInferencer.java
│ │ ├── PCFGMaxRuleInferencer.java
│ │ ├── PCFGMaxRuleParser.java
│ │ ├── PCFGParser.java
│ │ ├── RuleTable.java
│ │ ├── SimpleLVeGGrammar.java
│ │ ├── SimpleLVeGLexicon.java
│ │ ├── UnaryGrammarRule.java
│ │ └── Valuator.java
│ └── model
│ │ ├── ChartCell.java
│ │ ├── GaussianDistribution.java
│ │ ├── GaussianMixture.java
│ │ ├── GrammarRule.java
│ │ ├── Inferencer.java
│ │ ├── LVeGGrammar.java
│ │ ├── LVeGLexicon.java
│ │ └── Parser.java
│ ├── optimization
│ ├── Batch.java
│ ├── Gradient.java
│ ├── Optimizer.java
│ ├── ParallelOptimizer.java
│ ├── SGDMinimizer.java
│ ├── SimpleMinimizer.java
│ └── SimpleOptimizer.java
│ ├── syntax
│ ├── State.java
│ └── Tree.java
│ └── util
│ ├── Debugger.java
│ ├── ErrorUtil.java
│ ├── Executor.java
│ ├── FunUtil.java
│ ├── GradientChecker.java
│ ├── Indexer.java
│ ├── LogUtil.java
│ ├── MutableInteger.java
│ ├── Numberer.java
│ ├── ObjectPool.java
│ ├── Option.java
│ ├── OptionParser.java
│ ├── Recorder.java
│ └── ThreadPool.java
└── test
└── java
└── edu
└── shanghaitech
└── ai
└── nlp
├── data
├── ConstraintTester.java
└── F1er.java
├── lveg
├── Java.java
├── LVeGLearnerTest.java
├── LVeGPCFGTest.java
├── LVeGTesterTest.java
├── LVeGToyTest.java
├── impl
│ ├── BinaryGrammarTest.java
│ ├── DiagonalGaussianDistributionTest.java
│ ├── LVeGGrammarTest.java
│ ├── RuleTableGeneric.java
│ ├── RuleTableTest.java
│ └── UnaryGrammarRuleTest.java
└── model
│ ├── GaussianDistributionTest.java
│ ├── GaussianMixtureTest.java
│ └── InferencerTest.java
├── optimization
├── ParallelOptimizerTest.java
└── SGDForMoGTest.java
├── syntax
└── StateTest.java
└── util
├── FunUtilTest.java
├── LogUtilTest.java
└── ObjectToolTest.java
/.gitignore:
--------------------------------------------------------------------------------
1 | bin/
2 | log/
3 | libs/
4 | pool/
5 | target/
6 | script/
7 | classes/
8 | berkeley/
9 | .settings/
10 | implementation_notes.txt
11 | .classpath
12 | .project
13 | .log
14 |
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Latent Vector Grammars (LVeGs)
2 |
3 | Code for [Gaussian Mixture Latent Vector Grammars](https://arxiv.org/abs/1805.04688).
4 |
5 | ## How to Run
6 |
7 | Specify parameter values for learning (param.in) and for inference (param.f1).
8 |
9 | ### Learning
10 |
11 | ```
12 | $ sbt "run-main edu.shanghaitech.ai.nlp.lveg.LVeGTrainerImp param.in"
13 | ```
14 |
15 | ### Inference
16 |
17 | ```
18 | $ sbt "run-main edu.shanghaitech.ai.nlp.lveg.LVeGTesterImp param.f1"
19 | ```
20 |
21 | ## Data
22 |
23 | Parsing data is available at [Google Drive](https://drive.google.com/open?id=1sSwTaVgKJe-oA7jsoM6Mbij_j-xDUdH7).
24 |
25 | ## Models
26 |
27 | Parsing models are available at [Google Drive](https://drive.google.com/open?id=1CqWOMn7xWfax5Sj5lP_ypKYieasjviKd).
28 |
29 | ## Dependencies
30 |
31 | sbt-0.13.10, java-1.8.0, and scala-2.12.0.
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | import Version._
2 |
3 | lazy val commonSettings = Seq( // settings applied to the root project defined below
4 | organization := "edu.shanghaitech.ai.nlp",
5 | version := "0.0.1",
6 |
7 | scalaVersion := Version.scala, // Version is declared in project/Dependencies.scala
8 | crossScalaVersions := Seq("2.12.0", "2.10.4"),
9 |
10 | libraryDependencies ++= Seq( // coordinates are declared in the Library object (project/Dependencies.scala)
11 | Library.log4j,
12 | Library.junit,
13 | Library.pool2,
14 | Library.BerkeleyParser
15 | ),
16 | javacOptions ++= Seq("-source", "1.8"), // compile the Java sources at language level 1.8
17 | javaOptions += "-Xmx10g", // 10 GB maximum heap for the forked JVM
18 | fork := true // run in a separate JVM; sbt applies javaOptions only to forked runs
19 | )
20 |
21 | lazy val root = project // the only project of this build, rooted at the repository directory
22 | .in(file("."))
23 | .settings(commonSettings: _*)
24 | .settings(
25 | name := "lveg"
26 | )
--------------------------------------------------------------------------------
/config/log4j.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/config/paramtoy.in:
--------------------------------------------------------------------------------
1 | -datadir, E:/SourceCode/ParsersData/wsj/,
2 | -train, wsj_s2-21_tree,
3 | -test, wsj_s23_tree,
4 | -dev, wsj_s22_tree,
5 | -inCorpus, wsj_all.tree,
6 | -outCorpus, wsj_all.tree,
7 | -saveCorpus, false,
8 | -loadCorpus, false,
9 |
10 | -outGrammar, lveg,
11 | -nbatchSave, 10,
12 | -saveGrammar, true,
13 |
14 | -lr, 1,
15 | -reg, false,
16 | -clip, false,
17 | -absmax, 5.0,
18 | -wdecay, 1e-4,
19 | -l1, true,
20 | -minmw, 1e-9,
21 | -epsilon, 1e-8,
22 | -choice, SGD,
23 | -lambda, 0.9,
24 | -lambda1, 0.9,
25 | -lambda2, 0.999,
26 |
27 | -ntcyker, 6,
28 | -ntbatch, 6,
29 | -ntgrad, 6,
30 | -nteval, 6,
31 | -nttest, 6,
32 | -pclose, true,
33 | -pcyker, false,
34 | -pbatch, true,
35 | -peval, true,
36 | -pgrad, true,
37 | -pmode, THREAD_POOL,
38 | -pverbose, false,
39 |
40 | -iosprune, false,
41 | -sampling, false,
42 | -riserate, 2,
43 | -maxnbig, -30,
44 | -rtratio, -0.2,
45 | -hardcut, false,
46 | -expzero, 1e-6,
47 | -bsize, 1,
48 | -nepoch, 3,
49 | -maxsample, 1,
50 | -maxslen, 50,
51 | -nAllowedDrop, 6,
52 | -maxramdom, 1,
53 | -maxmw, 4,
54 | -nwratio, 0.8,
55 | -maxmu, 1,
56 | -nmratio, 0.5,
57 | -maxvar, 5,
58 | -nvratio, 0.8,
59 | -ncomponent, 2,
60 | -dim, 1,
61 | -resetw, false,
62 | -resetc, false,
63 | -mwfactor, 1.0,
64 | -usemasks, false,
65 | -tgbase, 3,
66 | -tgratio, 0.2,
67 | -tgprob, 1e-8,
68 | -iomask, true,
69 |
70 | -eratio, -0.1,
71 | -efraction, 1.0,
72 | -eonlylen, -1,
73 | -efirstk, -1,
74 | -eontrain, true,
75 | -eondev, false,
76 | -eonextradev, false,
77 | -ellprune, false,
78 | -ellimwrite, false,
79 | -epochskipk, 1,
80 | -enbatchdev, 1,
81 |
82 | -pf1, false,
83 | -ef1tag, ,
84 | -inGrammar, lveg_best.gr,
85 | -loadGrammar, false,
86 | -ef1prune, false,
87 | -ef1ondev, false,
88 | -ef1ontrain, false,
89 | -ef1imwrite, false,
90 |
91 | -runtag, toy_0827_complex_test1_2_1_bs1,
92 | -logroot, log/,
93 | -logtype, 0,
94 | -verbose, false,
95 | -nbatch, 1,
96 | -precision, 3,
97 | -rndomseed, 0,
98 |
99 | -dgradnbatch, 1,
100 |
101 | -imgprefix, lveg
--------------------------------------------------------------------------------
/param.f1:
--------------------------------------------------------------------------------
1 | -datadir, F:/SourceCode/ParsersData/wsj/,
2 | -train, wsj_s2-21_tree,
3 | -test, wsj_s23_tree,
4 | -dev, wsj_s22_tree,
5 | -inCorpus, wsj_all_r0.tree,
6 | -outCorpus, wsj_all.tree,
7 | -saveCorpus, false,
8 | -loadCorpus, true,
9 |
10 | -ntcyker, 6,
11 | -nttest, 6,
12 | -pclose, false,
13 | -pcyker, false,
14 |
15 | -riserate, 2,
16 | -rtratio, -0.2,
17 | -hardcut, false,
18 | -expzero, 1e-3,
19 | -maxramdom, 1,
20 | -maxmw, 10,
21 | -nwratio, 0.8,
22 | -maxmu, 0.1,
23 | -nmratio, 0.5,
24 | -maxvar, 5,
25 | -nvratio, 0.8,
26 | -ncomponent, 1,
27 | -dim, 3,
28 | -maxnbig, 10,
29 | -maxslen, 70,
30 | -usemasks, true,
31 | -tgbase, 3,
32 | -tgratio, 1.0,
33 | -tgprob, 1e-4,
34 | -iomask, false,
35 | -sexp, 0.35,
36 |
37 | -eratio, -0.1,
38 | -eonlylen, 200,
39 | -efirstk, 3000,
40 | -eonextradev, false,
41 |
42 | -pf1, false,
43 | -ef1tag, debug,
44 | -consfile, gr.sophis.200.23.cons,
45 | -inGrammar, lveg_final.gr,
46 | -eusestag, true,
47 | -ef1prune, true,
48 | -ef1ondev, false,
49 | -ef1ontrain, false,
50 | -ef1imwrite, false,
51 |
52 | -runtag, gr_0827_1437_50_4_3_180_nb40_p0_mt55_mf8_l1_r1,
53 | -logroot, log/,
54 | -logtype, 0,
55 | -verbose, false,
56 | -precision, 3,
57 | -rndomseed, 0,
58 |
59 | -imgprefix, lveg
--------------------------------------------------------------------------------
/param.in:
--------------------------------------------------------------------------------
1 | -datadir, E:/SourceCode/ParsersData/wsj/,
2 | -train, wsj_s2-21_tree,
3 | -test, wsj_s23_tree,
4 | -dev, wsj_s22_tree,
5 | -inCorpus, wsj_all_r1.tree,
6 | -outCorpus, wsj_all_r0.tree,
7 | -saveCorpus, false,
8 | -loadCorpus, true,
9 |
10 | -outGrammar, lveg,
11 | -nbatchSave, 20,
12 | -saveGrammar, true,
13 |
14 | -lr, 0.01,
15 | -reg, false,
16 | -clip, false,
17 | -absmax, 5.0,
18 | -wdecay, 0.1,
19 | -l1, true,
20 | -minmw, 1e-10,
21 | -epsilon, 1e-8,
22 | -choice, ADAM,
23 | -lambda, 0.9,
24 | -lambda1, 0.9,
25 | -lambda2, 0.999,
26 |
27 | -ntcyker, 6,
28 | -ntbatch, 6,
29 | -ntgrad, 6,
30 | -nteval, 6,
31 | -nttest, 6,
32 | -pclose, false,
33 | -pcyker, false,
34 | -pbatch, true,
35 | -peval, true,
36 | -pgrad, true,
37 | -pmode, THREAD_POOL,
38 | -pverbose, false,
39 |
40 | -iosprune, true,
41 | -sampling, false,
42 | -riserate, 2,
43 | -maxnbig, 40,
44 | -rtratio, -0.2,
45 | -hardcut, false,
46 | -expzero, 1e-3,
47 | -bsize, 180,
48 | -nepoch, 30,
49 | -maxsample, 1,
50 | -maxslen, 60,
51 | -nAllowedDrop, 0,
52 | -maxramdom, 1,
53 | -maxmw, 10,
54 | -nwratio, 0.8,
55 | -maxmu, 0.1,
56 | -nmratio, 0.5,
57 | -maxvar, 5,
58 | -nvratio, 0.8,
59 | -ncomponent, 1,
60 | -dim, 3,
61 | -resetw, true,
62 | -resetc, true,
63 | -mwfactor, 8.0,
64 | -usemasks, true,
65 | -tgbase, 3,
66 | -tgratio, 1.0,
67 | -tgprob, 5e-5,
68 | -iomask, false,
69 | -sexp, 0.35,
70 | -pivota, 0,
71 | -pivotb, 4000,
72 | -resetl, true,
73 | -resetp, false,
74 |
75 | -eratio, -0.1,
76 | -efraction, -0.25,
77 | -eonlylen, 3,
78 | -efirstk, 200,
79 | -eontrain, false,
80 | -eondev, true,
81 | -eonextradev, true,
82 | -ellprune, true,
83 | -ellimwrite, false,
84 | -epochskipk, 1,
85 | -enbatchdev, 1,
86 |
87 | -pf1, true,
88 | -ef1tag, final,
89 | -inGrammar, lveg_final.gr,
90 | -loadGrammar, false,
91 | -ef1prune, true,
92 | -ef1ondev, true,
93 | -ef1ontrain, false,
94 | -ef1imwrite, false,
95 |
96 | -consfile, ,
97 | -runtag, 2018_0212_4_3_3_200_test2,
98 | -logroot, log/,
99 | -logtype, 0,
100 | -verbose, false,
101 | -nbatch, 0,
102 | -precision, 3,
103 | -rndomseed, 0,
104 |
105 | -dgradnbatch, -1,
106 |
107 | -imgprefix, lveg
--------------------------------------------------------------------------------
/project/Dependencies.scala:
--------------------------------------------------------------------------------
1 | import sbt._
2 |
3 | object Version { // version strings referenced by Library below and by build.sbt (Version.scala)
4 | val scala = "2.12.0"
5 | val junit = "4.12"
6 | val log4j = "1.2.16"
7 | val pool2 = "2.4.2" // used for org.apache.commons:commons-pool2
8 | val BerkeleyParser = "1.7"
9 | }
10 |
11 | object Library { // dependency coordinates listed in libraryDependencies in build.sbt
12 | val junit = "junit" % "junit" % Version.junit
13 | val log4j = "log4j" % "log4j" % Version.log4j
14 | val pool2 = "org.apache.commons" % "commons-pool2" % Version.pool2
15 | val BerkeleyParser = "BerkeleyParser" % "BerkeleyParser" % Version.BerkeleyParser from "https://raw.githubusercontent.com/slavpetrov/berkeleyparser/master/BerkeleyParser-1.7.jar" // fetched from this explicit JAR URL
16 | }
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=0.13.10
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaoyanpeng/lveg/d2765585dfb095d5a8c862896d944576aa4ac9a8/project/plugins.sbt
--------------------------------------------------------------------------------
/src/main/java/edu/shanghaitech/ai/nlp/data/LVeGCorpus.java:
--------------------------------------------------------------------------------
1 | package edu.shanghaitech.ai.nlp.data;
2 |
3 | import java.io.Serializable;
4 | import java.util.List;
5 |
6 | import edu.berkeley.nlp.syntax.Tree;
7 | import edu.berkeley.nlp.util.Counter;
8 | import edu.shanghaitech.ai.nlp.lveg.impl.SimpleLVeGLexicon;
9 | import edu.shanghaitech.ai.nlp.syntax.State;
10 |
11 | /** Corpus preprocessing helper: replaces rare words in parse trees with their signatures.
12 | * @author Yanpeng Zhao
13 | *
14 | */
15 | public class LVeGCorpus implements Serializable {
16 | /**
17 | * Serialization version identifier.
18 | */
19 | private static final long serialVersionUID = 1278753896112455981L;
20 |
21 | /** Counts every word in the trees, then renames each word whose count is at most {@code rareThreshold}.
22 | * @param trees the parse trees; the word states in their yields are renamed in place
23 | * @param lexicon temporary data holder: every word is added to its wordIndexer, and it supplies the signatures
24 | * @param rareThreshold words with frequencies lower than or equal to this value will be replaced with their signatures
25 | */
26 | public static void replaceRareWords(
27 | StateTreeList trees, SimpleLVeGLexicon lexicon, int rareThreshold) {
28 | Counter wordCounter = new Counter<>(); // word name -> number of occurrences over all trees
29 | for (Tree tree : trees) {
30 | List words = tree.getYield();
31 | for (State word : words) {
32 | String name = word.getName();
33 | wordCounter.incrementCount(name, 1.0);
34 | lexicon.wordIndexer.add(name); // every word is indexed, including those renamed in the second pass
35 | }
36 | }
37 |
38 | for (Tree tree : trees) { // second pass: counts are complete, so rare words can now be identified
39 | List words = tree.getYield();
40 | int pos = 0; // position of the word within the yield, passed to the signature lookup
41 | for (State word : words) {
42 | String name = word.getName();
43 | if (wordCounter.getCount(name) <= rareThreshold) { // the threshold is inclusive
44 | name = lexicon.getCachedSignature(name, pos);
45 | word.setName(name);
46 | }
47 | pos++;
48 | }
49 | }
50 | }
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/edu/shanghaitech/ai/nlp/data/ObjectFileManager.java:
--------------------------------------------------------------------------------
1 | package edu.shanghaitech.ai.nlp.data;
2 |
3 | import java.io.File;
4 | import java.io.FileInputStream;
5 | import java.io.FileOutputStream;
6 | import java.io.IOException;
7 | import java.io.ObjectInputStream;
8 | import java.io.ObjectOutputStream;
9 | import java.io.Serializable;
10 | import java.text.SimpleDateFormat;
11 | import java.util.Date;
12 | import java.util.Set;
13 | import java.util.zip.GZIPInputStream;
14 | import java.util.zip.GZIPOutputStream;
15 |
16 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar;
17 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon;
18 | import edu.shanghaitech.ai.nlp.util.Numberer;
19 |
20 | public class ObjectFileManager {
21 |
22 | public static class ObjectFile implements Serializable {
23 | /**
24 | *
25 | */
26 | private static final long serialVersionUID = 1852249590891181238L;
27 |
28 | public boolean save(String filename) {
29 | /*
30 | if (new File(filename).exists()) {
31 | filename += new SimpleDateFormat(".yyyyMMddHHmmss").format(new Date());
32 | }
33 | */
34 | try {
35 | FileOutputStream fos = new FileOutputStream(filename);
36 | GZIPOutputStream gos = new GZIPOutputStream(fos);
37 | ObjectOutputStream oos = new ObjectOutputStream(gos);
38 | oos.writeObject(this);
39 | oos.flush();
40 | oos.close();
41 | gos.close();
42 | fos.close();
43 | } catch (IOException e) {
44 | e.printStackTrace();
45 | return false;
46 | }
47 | return true;
48 | }
49 |
50 | public static Object load(String filename) {
51 | Object o = null;
52 | try {
53 | FileInputStream fis = new FileInputStream(filename);
54 | GZIPInputStream gis = new GZIPInputStream(fis);
55 | ObjectInputStream ois = new ObjectInputStream(gis);
56 | o = ois.readObject();
57 | ois.close();
58 | gis.close();
59 | fis.close();
60 | } catch (IOException | ClassNotFoundException e) {
61 | e.printStackTrace();
62 | return null;
63 | }
64 | return o;
65 | }
66 | }
67 |
68 |
69 | public static class CorpusFile extends ObjectFile {
70 | /**
71 | *
72 | */
73 | private static final long serialVersionUID = -5871763111246836457L;
74 | private StateTreeList train;
75 | private StateTreeList test;
76 | private StateTreeList dev;
77 | private Numberer numberer;
78 |
79 | public CorpusFile(StateTreeList train, StateTreeList test, StateTreeList dev, Numberer numberer) {
80 | this.numberer = numberer;
81 | this.train = train;
82 | this.test = test;
83 | this.dev = dev;
84 | }
85 |
86 | public StateTreeList getTrain() {
87 | return train;
88 | }
89 |
90 | public StateTreeList getTest() {
91 | return test;
92 | }
93 |
94 | public StateTreeList getDev() {
95 | return dev;
96 | }
97 |
98 | public Numberer getNumberer() {
99 | return numberer;
100 | }
101 | }
102 |
103 |
104 | public static class GrammarFile extends ObjectFile {
105 | /**
106 | *
107 | */
108 | private static final long serialVersionUID = -8119623200836693006L;
109 | private LVeGGrammar grammar;
110 | private LVeGLexicon lexicon;
111 |
112 | public GrammarFile(LVeGGrammar grammar, LVeGLexicon lexicon) {
113 | this.grammar = grammar;
114 | this.lexicon = lexicon;
115 | }
116 |
117 | public LVeGGrammar getGrammar() {
118 | return grammar;
119 | }
120 |
121 | public LVeGLexicon getLexicon() {
122 | return lexicon;
123 | }
124 | }
125 |
126 |
127 | public static class Constraint extends ObjectFile {
128 | /**
129 | *
130 | */
131 | private static final long serialVersionUID = -235658934972194254L;
132 | private Set[][][] constraints;
133 |
134 | public Constraint(Set[][][] constraints) {
135 | this.constraints = constraints;
136 | }
137 |
138 | public Set[][][] getConstraints() {
139 | return constraints;
140 | }
141 | }
142 | }
143 |
--------------------------------------------------------------------------------
/src/main/java/edu/shanghaitech/ai/nlp/data/StateTreeList.java:
--------------------------------------------------------------------------------
1 | package edu.shanghaitech.ai.nlp.data;
2 |
3 | import java.io.Serializable;
4 | import java.util.AbstractCollection;
5 | import java.util.ArrayList;
6 | import java.util.Collections;
7 | import java.util.Iterator;
8 | import java.util.List;
9 | import java.util.Random;
10 |
11 | import edu.berkeley.nlp.syntax.Tree;
12 | import edu.shanghaitech.ai.nlp.syntax.State;
13 | import edu.shanghaitech.ai.nlp.util.Numberer;
14 |
15 | /** A collection of parse trees, backed by a list, with helpers to convert between String-labeled and State-labeled trees.
16 | * @author Yanpeng Zhao
17 | *
18 | */
19 | public class StateTreeList extends AbstractCollection> implements Serializable {
20 | /**
21 | * Serialization version identifier.
22 | */
23 | private static final long serialVersionUID = 6424710107698526446L;
24 | private final static short ID_WORD = -1; // id given to leaf (word) states in stringTreeToStateTree
25 | private List> trees; // backing list; all collection operations delegate to it
26 |
27 |
28 | public class StateTreeListIterator implements Iterator> { // thin wrapper around the backing list's iterator
29 |
30 | Tree currentTree; // the tree most recently returned by next(), or null before the first call
31 | Iterator> treeListIter;
32 |
33 | public StateTreeListIterator() {
34 | treeListIter = trees.iterator();
35 | currentTree = null;
36 | }
37 |
38 | @Override
39 | public boolean hasNext() {
40 | // the currentTree check below is an empty placeholder; the answer comes from the list iterator
41 | if (currentTree != null) {
42 | // TODO
43 | }
44 | return treeListIter.hasNext();
45 | }
46 |
47 | @Override
48 | public Tree next() {
49 | // remembers the returned tree in currentTree
50 | currentTree = treeListIter.next();
51 | return currentTree;
52 | }
53 |
54 | @Override
55 | public void remove() {
56 | treeListIter.remove();
57 | }
58 |
59 | }
60 |
61 |
62 | @Override
63 | public Iterator> iterator() {
64 | // each call creates a fresh iterator over the backing list
65 | return new StateTreeListIterator();
66 | }
67 |
68 | @Override
69 | public int size() {
70 | // number of trees in the backing list
71 | return trees.size();
72 | }
73 |
74 | @Override
75 | public boolean isEmpty() {
76 | return trees.isEmpty();
77 | }
78 |
79 | @Override
80 | public boolean add(Tree tree) {
81 | return trees.add(tree);
82 | }
83 |
84 | public Tree get(int i) {
85 | return trees.get(i);
86 | }
87 |
88 |
89 | public StateTreeList() { // creates an empty list
90 | this.trees = new ArrayList<>();
91 | }
92 |
93 |
94 | public StateTreeList(StateTreeList stateTreeList) { // copies internal nodes; leaves are shared (see copyTreeButLeaf)
95 | this.trees = new ArrayList<>();
96 | for (Tree tree: stateTreeList.trees) {
97 | trees.add(copyTreeButLeaf(tree));
98 | }
99 | }
100 |
101 |
102 | /**
103 | * The leaf is copied by reference. It's the same as
104 | * {@code edu.berkeley.nlp.PCFGLA.StateSetTreeList.resizeStateSetTree}
105 | *
106 | * @param tree a parse tree
107 | * @return the input itself for a leaf; otherwise a new tree whose label is a new State built from the original label
108 | *
109 | */
110 | private Tree copyTreeButLeaf(Tree tree) {
111 | if (tree.isLeaf()) { return tree; }
112 | State state = new State(tree.getLabel(), false);
113 | List> children = new ArrayList<>();
114 | for (Tree child : tree.getChildren()) {
115 | children.add(copyTreeButLeaf(child));
116 | }
117 | return new Tree(state, children);
118 | }
119 |
120 |
121 | public StateTreeList copy() { // unlike the copy constructor, this also copies the leaves (see copyTree)
122 | StateTreeList stateTreeList = new StateTreeList();
123 | for (Tree tree : trees) {
124 | stateTreeList.add(copyTree(tree));
125 | }
126 | return stateTreeList;
127 | }
128 |
129 |
130 | private Tree copyTree(Tree tree) { // recursive copy; every label, leaves included, is duplicated via copy()
131 | List> children = new ArrayList<>(tree.getChildren().size());
132 | for (Tree child : tree.getChildren()) {
133 | children.add(copyTree(child));
134 | }
135 | return new Tree(tree.getLabel().copy(), children);
136 | }
137 |
138 |
139 | public StateTreeList(List> trees, Numberer numberer) { // converts String-labeled trees into State-labeled trees
140 | this.trees = new ArrayList<>();
141 | for (Tree tree : trees) {
142 | this.trees.add(stringTreeToStateTree(tree, numberer));
143 | tree = null; // NOTE(review): only clears the loop variable; the caller's list still references the tree
144 | }
145 | }
146 |
147 |
148 | /** Converts every tree only for its effect on the numberer; the converted trees are discarded.
149 | * @param trees parse trees
150 | * @param numberer recording the ids of tags
151 | *
152 | */
153 | public static void initializeNumbererTag(
154 | List> trees, Numberer numberer) {
155 | for (Tree tree : trees) {
156 | stringTreeToStateTree(tree, numberer);
157 | }
158 | }
159 |
160 |
161 | public static void stringTreeToStateTree( // same body as initializeNumbererTag: the conversion results are discarded
162 | List> trees, Numberer numberer) {
163 | for (Tree tree : trees) {
164 | stringTreeToStateTree(tree, numberer);
165 | }
166 | }
167 |
168 |
169 | /** Converts a String-labeled tree and then assigns each word of the yield the unit span [pos, pos + 1).
170 | * @param tree a parse tree
171 | * @param numberer record the ids of tags
172 | * @return parse tree represented by the state list
173 | *
174 | */
175 | public static Tree stringTreeToStateTree(Tree tree, Numberer numberer) {
176 | Tree result = stringTreeToStateTree(tree, numberer, 0, tree.getYield().size());
177 | List words = result.getYield();
178 | for (short pos = 0; pos < words.size(); pos++) { // overwrite the spans of the words with their positions
179 | words.get(pos).from = pos;
180 | words.get(pos).to = (short) (pos + 1);
181 | }
182 | return result;
183 | }
184 |
185 |
186 | public void shuffle(Random rnd) { // shuffles the backing list in place using the given random source
187 | Collections.shuffle(trees, rnd);
188 | }
189 |
190 |
191 | public void reset() { // applies reset(Tree) to every tree in the list
192 | for (Tree tree : trees) {
193 | reset(tree);
194 | }
195 | }
196 |
197 |
198 | public void reset(Tree tree) { // calls clear(false) on the label of every non-leaf node
199 | if (tree.isLeaf()) { return; } // leaves are left untouched
200 | if (tree.getLabel() != null) {
201 | tree.getLabel().clear(false);
202 | }
203 | for (Tree child : tree.getChildren()) {
204 | reset(child);
205 | }
206 | }
207 |
208 |
209 | /**
210 | * Convert a state tree to a string tree.
211 | *
212 | * @param tree a state tree
213 | * @param numberer which records the ids of tags
214 | * @return a new tree labeled with the state's name at leaves and with the numberer's object for the state id elsewhere
215 | */
216 | public static Tree stateTreeToStringTree(Tree tree, Numberer numberer) {
217 | if (tree.isLeaf()) {
218 | String name = tree.getLabel().getName();
219 | return new Tree(name);
220 | }
221 |
222 | String name = (String) numberer.object(tree.getLabel().getId()); // look the tag name up by its id
223 | Tree newTree = new Tree(name);
224 | List> children = new ArrayList<>();
225 |
226 | for (Tree child : tree.getChildren()) {
227 | Tree newChild = stateTreeToStringTree(child, numberer);
228 | children.add(newChild);
229 | }
230 | newTree.setChildren(children);
231 | return newTree;
232 | }
233 |
234 |
235 | /**
236 | * Convert a string tree to a state tree.
237 | *
238 | * @param tree a parse tree
239 | * @param numberer which records the ids of tags
240 | * @param from starting point of the span
241 | * @param to ending point of the span
242 | * @return parse tree represented by the state list
243 | *
244 | */
245 | private static Tree stringTreeToStateTree(Tree tree, Numberer numberer, int from, int to) {
246 | if (tree.isLeaf()) { // words keep their (interned) name and get the ID_WORD id
247 | State state = new State(tree.getLabel().intern(), (short) ID_WORD, (short) from, (short) to);
248 | return new Tree(state);
249 | }
250 | /* numberer is initialized here */
251 | short id = (short) numberer.number(tree.getLabel());
252 |
253 | // System.out.println(tree.getLabel().intern()); // tag name
254 | State state = new State(null, id, (short) from, (short) to); // non-leaf states carry an id and no name
255 | Tree newTree = new Tree(state);
256 | List> children = new ArrayList<>();
257 | for (Tree child : tree.getChildren()) {
258 | short length = (short) child.getYield().size(); // number of words covered by this child
259 | Tree newChild = stringTreeToStateTree(child, numberer, from, from + length);
260 | from += length; // the next child starts where this one ends
261 | children.add(newChild);
262 | }
263 | newTree.setChildren(children);
264 | return newTree;
265 | }
266 |
267 | }
268 |
--------------------------------------------------------------------------------
/src/main/java/edu/shanghaitech/ai/nlp/lveg/LVeGPCFG.java:
--------------------------------------------------------------------------------
1 | package edu.shanghaitech.ai.nlp.lveg;
2 |
3 | import java.io.IOException;
4 | import java.nio.charset.StandardCharsets;
5 | import java.util.ArrayList;
6 | import java.util.Arrays;
7 | import java.util.HashSet;
8 | import java.util.List;
9 | import java.util.Map;
10 | import java.util.PriorityQueue;
11 |
12 | import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
13 | import edu.berkeley.nlp.syntax.Tree;
14 | import edu.shanghaitech.ai.nlp.data.StateTreeList;
15 | import edu.shanghaitech.ai.nlp.data.ObjectFileManager.GrammarFile;
16 | import edu.shanghaitech.ai.nlp.eval.EnglishPennTreebankParseEvaluator;
17 | import edu.shanghaitech.ai.nlp.lveg.impl.PCFGParser;
18 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;
19 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
20 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar;
21 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon;
22 | import edu.shanghaitech.ai.nlp.optimization.Optimizer;
23 | import edu.shanghaitech.ai.nlp.syntax.State;
24 | import edu.shanghaitech.ai.nlp.util.FunUtil;
25 | import edu.shanghaitech.ai.nlp.util.Numberer;
26 | import edu.shanghaitech.ai.nlp.util.OptionParser;
27 | import edu.shanghaitech.ai.nlp.util.ThreadPool;
28 |
29 | public class LVeGPCFG extends LearnerConfig {
30 | /**
31 | *
32 | */
33 | private static final long serialVersionUID = 8232031232691463175L;
34 |
35 | protected static EnglishPennTreebankParseEvaluator.LabeledConstituentEval scorer;
36 | protected static StateTreeList trainTrees;
37 | protected static StateTreeList testTrees;
38 | protected static StateTreeList devTrees;
39 |
40 | protected static LVeGGrammar grammar;
41 | protected static LVeGLexicon lexicon;
42 |
43 | protected static PCFGParser, ?> pcfgParser;
44 | protected static ThreadPool mparser;
45 |
46 | protected static String treeFile;
47 | protected static Options opts;
48 |
49 | public static void main(String[] args) throws Exception { // entry point: reads options from a file, loads data, then evaluates
50 | String fparams = args[0]; // first command-line argument: path of the parameter file
51 | try {
52 | args = readFile(fparams, StandardCharsets.UTF_8).split(","); // the file content, split on commas, replaces the CLI arguments
53 | } catch (IOException e) {
54 | e.printStackTrace(); // NOTE(review): execution continues and the original args reach the option parser
55 | }
56 | OptionParser optionParser = new OptionParser(Options.class);
57 | opts = (Options) optionParser.parse(args, true);
58 | // configurations
59 | initialize(opts, true); // logger can only be used after the initialization
60 | logger.info("Calling with " + optionParser.getParsedOptions() + "\n");
61 |
62 | // loading data
63 | Numberer wrapper = new Numberer();
64 | Map trees = loadData(wrapper, opts);
65 | // runs train(...) and measures its wall-clock time
66 | long startTime = System.currentTimeMillis();
67 | train(trees, wrapper);
68 | long endTime = System.currentTimeMillis();
69 | logger.trace("[total time consumed by LVeG tester] " + (endTime - startTime) / 1000.0 + "\n"); // elapsed time in seconds
70 | }
71 |
72 |
73 | private static void train(Map trees, Numberer wrapper) throws Exception { // despite the name, this method loads a saved grammar and reports F1; it calls no optimization step
74 | trainTrees = trees.get(ID_TRAIN);
75 | testTrees = trees.get(ID_TEST);
76 | devTrees = trees.get(ID_DEV);
77 |
78 | treeFile = sublogroot + opts.imgprefix; // filename prefix used by saveTree
79 |
80 | Numberer numberer = wrapper.getGlobalNumberer(KEY_TAG_SET);
81 |
82 | /* to ease the parameters tuning */
83 | GaussianMixture.config(opts.maxnbig, opts.expzero, opts.maxmw, opts.ncomponent,
84 | opts.nwratio, opts.riserate, opts.rtratio, opts.hardcut, random, mogPool);
85 | GaussianDistribution.config(opts.maxmu, opts.maxvar, opts.dim, opts.nmratio, opts.nvratio, random, gaussPool);
86 | Optimizer.config(opts.choice, random, opts.maxsample, opts.bsize, opts.minmw, opts.sampling); // FIXME no errors, just alert you...
87 |
88 | // load grammar
89 | logger.trace("--->Loading grammars from \'" + subdatadir + opts.inGrammar + "\'...\n");
90 | GrammarFile gfile = (GrammarFile) GrammarFile.load(subdatadir + opts.inGrammar); // load(...) returns null on failure, in which case the next line throws NullPointerException
91 | grammar = gfile.getGrammar();
92 | lexicon = gfile.getLexicon();
93 |
94 | /*
95 | logger.trace(grammar);
96 | logger.trace(lexicon);
97 | System.exit(0);
98 | */
99 |
100 | lexicon.labelTrees(trainTrees); // NOTE: not an error, just a reminder; labeling the trees up front
101 | lexicon.labelTrees(testTrees); // saves the time otherwise spent searching for a specific tag-word
102 | lexicon.labelTrees(devTrees); // pair in Lexicon.score(...)
103 |
104 | scorer = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval(
105 | new HashSet(Arrays.asList(new String[] { "ROOT", "PSEUDO" })),
106 | new HashSet(Arrays.asList(new String[] { "''", "``", ".", ":", "," })));
107 | pcfgParser = new PCFGParser, Tree>(grammar, lexicon, opts.maxslen,
108 | opts.ntcyker, opts.pcyker, opts.ef1prune, false);
109 | mparser = new ThreadPool(pcfgParser, opts.nttest); // used by f1entry when opts.pf1 is set
110 |
111 | logger.info("\n---F1 CONFIG---\n[parallel: batch-" + opts.pbatch + ", grad-" +
112 | opts.pgrad + ", eval-" + opts.peval + ", test-" + opts.pf1 + "]\n\n");
113 |
114 | sorter = new PriorityQueue<>(opts.bsize + 5, wcomparator);
115 |
116 | StringBuffer sb = new StringBuffer();
117 | sb.append("[test ]" + f1entry(testTrees, numberer, false) + "\n"); // the test set is always evaluated
118 | if (opts.ef1ontrain) {
119 | scorer.reset(); // reset the scorer before evaluating another data set
120 | sb.append("[train]" + f1entry(trainTrees, numberer, true) + "\n");
121 | }
122 | if (opts.ef1ondev) {
123 | scorer.reset();
124 | sb.append("[dev ]" + f1entry(devTrees, numberer, false) + "\n");
125 | }
126 | logger.info("[summary]\n" + sb.toString() + "\n");
127 | // kill threads
128 | grammar.shutdown();
129 | lexicon.shutdown();
130 | mparser.shutdown();
131 | }
132 |
133 |
134 | public static String f1entry(StateTreeList trees, Numberer numberer, boolean istrain) {
135 | if (opts.pf1) {
136 | return parallelFscore(opts, mparser, trees, numberer, istrain);
137 | } else {
138 | return serialFscore(opts, pcfgParser, trees, numberer, istrain);
139 | }
140 | }
141 |
142 |
143 | public static String parallelFscore(Options opts, ThreadPool mparser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) { // parses the trees through the thread pool and returns the scorer's summary
144 | Tree goldTree = null;
145 | Tree parsedTree = null;
146 | int nUnparsable = 0, cnt = 0, idx = 0; // cnt is never read in this method
147 | List> trees = new ArrayList<>(stateTreeList.size());
148 | filterTrees(opts, stateTreeList, trees, numberer, istrain); // presumably fills trees with the trees selected for evaluation - filterTrees is defined elsewhere
149 |
150 | for (Tree tree : trees) {
151 | mparser.execute(tree); // submit the tree for parsing
152 | while (mparser.hasNext()) { // collect whatever results are already available
153 | goldTree = trees.get(idx); // NOTE(review): pairs the idx-th result with the idx-th gold tree; assumes results arrive in submission order - confirm in ThreadPool
154 | parsedTree = (Tree) mparser.getNext();
155 | if (!saveTree(goldTree, parsedTree, numberer, idx)) {
156 | nUnparsable++;
157 | }
158 | idx++;
159 | }
160 | }
161 | while (!mparser.isDone()) { // keep polling until the pool reports it is done, collecting the remaining results
162 | while (mparser.hasNext()) {
163 | goldTree = trees.get(idx);
164 | parsedTree = (Tree) mparser.getNext();
165 | if (!saveTree(goldTree, parsedTree, numberer, idx)) {
166 | nUnparsable++;
167 | }
168 | idx++;
169 | }
170 | }
171 | mparser.reset();
172 | String summary = scorer.display();
173 | logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
174 | logger.trace(summary + "\n\n");
175 | return summary;
176 | }
177 |
178 |
179 | public static String serialFscore(Options opts, PCFGParser, ?> mrParser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {
180 |
181 | int nUnparsable = 0, idx = 0;
182 |
183 | List> trees = new ArrayList<>(stateTreeList.size());
184 | filterTrees(opts, stateTreeList, trees, numberer, istrain);
185 |
186 | for (Tree tree : trees) {
187 | Tree parsedTree = mrParser.parse(tree);
188 | if (!saveTree(tree, parsedTree, numberer, idx)) {
189 | nUnparsable++;
190 | }
191 | idx++; // index the State tree
192 | }
193 | String summary = scorer.display();
194 | logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
195 | logger.trace(summary + "\n\n");
196 | return summary;
197 | }
198 |
199 |
200 | public static boolean saveTree(Tree tree, Tree parsedTree, Numberer numberer, int idx) {
201 | try {
202 | Tree goldTree = null;
203 | if (opts.ef1imwrite) {
204 | String treename = treeFile + "_gd_" + idx;
205 | goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
206 | FunUtil.saveTree2image(null, treename, goldTree, numberer);
207 | goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
208 | FunUtil.saveTree2image(null, treename + "_ua", goldTree, numberer);
209 |
210 | treename = treeFile + "_te_" + idx;
211 | FunUtil.saveTree2image(null, treename, parsedTree, numberer);
212 | parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
213 | FunUtil.saveTree2image(null, treename + "_ua", parsedTree, numberer);
214 | } else {
215 | goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
216 | goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
217 | parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
218 | }
219 | scorer.evaluate(parsedTree, goldTree);
220 | logger.trace(idx + "\tgold : " + goldTree + "\n");
221 | logger.trace(idx + "\tparsed: " + parsedTree + "\n");
222 | return true;
223 | } catch (Exception e) {
224 | e.printStackTrace();
225 | return false;
226 | }
227 | }
228 | }
229 |
--------------------------------------------------------------------------------
/src/main/java/edu/shanghaitech/ai/nlp/lveg/PCFGMaxRule.java:
--------------------------------------------------------------------------------
1 | package edu.shanghaitech.ai.nlp.lveg;
2 |
3 | import java.io.IOException;
4 | import java.nio.charset.StandardCharsets;
5 | import java.util.ArrayList;
6 | import java.util.Arrays;
7 | import java.util.HashSet;
8 | import java.util.List;
9 | import java.util.Map;
10 | import java.util.PriorityQueue;
11 |
12 | import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
13 | import edu.berkeley.nlp.syntax.Tree;
14 | import edu.shanghaitech.ai.nlp.data.StateTreeList;
15 | import edu.shanghaitech.ai.nlp.data.ObjectFileManager.GrammarFile;
16 | import edu.shanghaitech.ai.nlp.eval.EnglishPennTreebankParseEvaluator;
17 | import edu.shanghaitech.ai.nlp.lveg.impl.PCFGMaxRuleParser;
18 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianDistribution;
19 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
20 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGGrammar;
21 | import edu.shanghaitech.ai.nlp.lveg.model.LVeGLexicon;
22 | import edu.shanghaitech.ai.nlp.optimization.Optimizer;
23 | import edu.shanghaitech.ai.nlp.syntax.State;
24 | import edu.shanghaitech.ai.nlp.util.FunUtil;
25 | import edu.shanghaitech.ai.nlp.util.Numberer;
26 | import edu.shanghaitech.ai.nlp.util.OptionParser;
27 | import edu.shanghaitech.ai.nlp.util.ThreadPool;
28 |
29 | public class PCFGMaxRule extends LearnerConfig {
30 | /**
31 | *
32 | */
33 | private static final long serialVersionUID = 8232031232691463175L;
34 |
35 | protected static EnglishPennTreebankParseEvaluator.LabeledConstituentEval scorer;
36 | protected static StateTreeList trainTrees;
37 | protected static StateTreeList testTrees;
38 | protected static StateTreeList devTrees;
39 |
40 | protected static LVeGGrammar grammar;
41 | protected static LVeGLexicon lexicon;
42 |
43 | protected static PCFGMaxRuleParser, ?> pcfgMrParser;
44 | protected static ThreadPool mparser;
45 |
46 | protected static String treeFile;
47 | protected static Options opts;
48 |
49 | public static void main(String[] args) throws Exception {
50 | String fparams = args[0];
51 | try {
52 | args = readFile(fparams, StandardCharsets.UTF_8).split(",");
53 | } catch (IOException e) {
54 | e.printStackTrace();
55 | }
56 | OptionParser optionParser = new OptionParser(Options.class);
57 | opts = (Options) optionParser.parse(args, true);
58 | // configurations
59 | initialize(opts, true); // logger can only be used after the initialization
60 | logger.info("Calling with " + optionParser.getParsedOptions() + "\n");
61 |
62 | // loading data
63 | Numberer wrapper = new Numberer();
64 | Map trees = loadData(wrapper, opts);
65 | // training
66 | long startTime = System.currentTimeMillis();
67 | train(trees, wrapper);
68 | long endTime = System.currentTimeMillis();
69 | logger.trace("[total time consumed by LVeG tester] " + (endTime - startTime) / 1000.0 + "\n");
70 | }
71 |
72 |
73 | private static void train(Map trees, Numberer wrapper) throws Exception {
74 | trainTrees = trees.get(ID_TRAIN);
75 | testTrees = trees.get(ID_TEST);
76 | devTrees = trees.get(ID_DEV);
77 |
78 | treeFile = sublogroot + opts.imgprefix;
79 |
80 | Numberer numberer = wrapper.getGlobalNumberer(KEY_TAG_SET);
81 |
82 | /* to ease the parameters tuning */
83 | GaussianMixture.config(opts.maxnbig, opts.expzero, opts.maxmw, opts.ncomponent,
84 | opts.nwratio, opts.riserate, opts.rtratio, opts.hardcut, random, mogPool);
85 | GaussianDistribution.config(opts.maxmu, opts.maxvar, opts.dim, opts.nmratio, opts.nvratio, random, gaussPool);
86 | Optimizer.config(opts.choice, random, opts.maxsample, opts.bsize, opts.minmw, opts.sampling); // FIXME no errors, just alert you...
87 |
88 | // load grammar
89 | logger.trace("--->Loading grammars from \'" + subdatadir + opts.inGrammar + "\'...\n");
90 | GrammarFile gfile = (GrammarFile) GrammarFile.load(subdatadir + opts.inGrammar);
91 | grammar = gfile.getGrammar();
92 | lexicon = gfile.getLexicon();
93 |
94 | /*
95 | logger.trace(grammar);
96 | logger.trace(lexicon);
97 | System.exit(0);
98 | */
99 |
100 | lexicon.labelTrees(trainTrees); // FIXME no errors, just alert you to pay attention to it
101 | lexicon.labelTrees(testTrees); // save the search time cost by finding a specific tag-word
102 | lexicon.labelTrees(devTrees); // pair in in Lexicon.score(...)
103 |
104 | scorer = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval(
105 | new HashSet(Arrays.asList(new String[] { "ROOT", "PSEUDO" })),
106 | new HashSet(Arrays.asList(new String[] { "''", "``", ".", ":", "," })));
107 | pcfgMrParser = new PCFGMaxRuleParser, Tree>(grammar, lexicon, opts.maxslen,
108 | opts.ntcyker, opts.pcyker, opts.ef1prune, false);
109 | mparser = new ThreadPool(pcfgMrParser, opts.nttest);
110 |
111 |
112 | logger.info("\n---F1 CONFIG---\n[parallel: batch-" + opts.pbatch + ", grad-" +
113 | opts.pgrad + ", eval-" + opts.peval + ", test-" + opts.pf1 + "]\n\n");
114 |
115 | sorter = new PriorityQueue<>(opts.bsize + 5, wcomparator);
116 |
117 | StringBuffer sb = new StringBuffer();
118 | sb.append("[test ]" + f1entry(testTrees, numberer, false) + "\n");
119 | if (opts.ef1ontrain) {
120 | scorer.reset();
121 | sb.append("[train]" + f1entry(trainTrees, numberer, true) + "\n");
122 | }
123 | if (opts.ef1ondev) {
124 | scorer.reset();
125 | sb.append("[dev ]" + f1entry(devTrees, numberer, false) + "\n");
126 | }
127 | logger.info("[summary]\n" + sb.toString() + "\n");
128 | // kill threads
129 | grammar.shutdown();
130 | lexicon.shutdown();
131 | mparser.shutdown();
132 | }
133 |
134 |
135 | public static String f1entry(StateTreeList trees, Numberer numberer, boolean istrain) {
136 | if (opts.pf1) {
137 | return parallelFscore(opts, mparser, trees, numberer, istrain);
138 | } else {
139 | return serialFscore(opts, pcfgMrParser, trees, numberer, istrain);
140 | }
141 | }
142 |
143 |
144 | public static String parallelFscore(Options opts, ThreadPool mparser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {
145 | Tree goldTree = null;
146 | Tree parsedTree = null;
147 | int nUnparsable = 0, cnt = 0, idx = 0;
148 | List> trees = new ArrayList<>(stateTreeList.size());
149 | filterTrees(opts, stateTreeList, trees, numberer, istrain);
150 |
151 | for (Tree tree : trees) {
152 | mparser.execute(tree);
153 | while (mparser.hasNext()) {
154 | goldTree = trees.get(idx);
155 | parsedTree = (Tree) mparser.getNext();
156 | if (!saveTree(goldTree, parsedTree, numberer, idx)) {
157 | nUnparsable++;
158 | }
159 | idx++;
160 | }
161 | }
162 | while (!mparser.isDone()) {
163 | while (mparser.hasNext()) {
164 | goldTree = trees.get(idx);
165 | parsedTree = (Tree) mparser.getNext();
166 | if (!saveTree(goldTree, parsedTree, numberer, idx)) {
167 | nUnparsable++;
168 | }
169 | idx++;
170 | }
171 | }
172 | mparser.reset();
173 | String summary = scorer.display();
174 | logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
175 | logger.trace(summary + "\n\n");
176 | return summary;
177 | }
178 |
179 |
180 | public static String serialFscore(Options opts, PCFGMaxRuleParser, ?> mrParser, StateTreeList stateTreeList, Numberer numberer, boolean istrain) {
181 |
182 | int nUnparsable = 0, idx = 0;
183 |
184 | List> trees = new ArrayList<>(stateTreeList.size());
185 | filterTrees(opts, stateTreeList, trees, numberer, istrain);
186 |
187 | for (Tree tree : trees) {
188 | Tree parsedTree = mrParser.parse(tree);
189 | if (!saveTree(tree, parsedTree, numberer, idx)) {
190 | nUnparsable++;
191 | }
192 | idx++; // index the State tree
193 | }
194 | String summary = scorer.display();
195 | logger.trace("\n[max rule parser: " + nUnparsable + " unparsable sample(s) of " + stateTreeList.size() + "(" + trees.size() + ") samples]\n");
196 | logger.trace(summary + "\n\n");
197 | return summary;
198 | }
199 |
200 |
201 | public static boolean saveTree(Tree tree, Tree parsedTree, Numberer numberer, int idx) {
202 | try {
203 | Tree goldTree = null;
204 | if (opts.ef1imwrite) {
205 | String treename = treeFile + "_gd_" + idx;
206 | goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
207 | FunUtil.saveTree2image(null, treename, goldTree, numberer);
208 | goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
209 | FunUtil.saveTree2image(null, treename + "_ua", goldTree, numberer);
210 |
211 | treename = treeFile + "_te_" + idx;
212 | FunUtil.saveTree2image(null, treename, parsedTree, numberer);
213 | parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
214 | FunUtil.saveTree2image(null, treename + "_ua", parsedTree, numberer);
215 | } else {
216 | goldTree = StateTreeList.stateTreeToStringTree(tree, numberer);
217 | goldTree = TreeAnnotations.unAnnotateTree(goldTree, false);
218 | parsedTree = TreeAnnotations.unAnnotateTree(parsedTree, false);
219 | }
220 | scorer.evaluate(parsedTree, goldTree);
221 | logger.trace(idx + "\tgold : " + goldTree + "\n");
222 | logger.trace(idx + "\tparsed: " + parsedTree + "\n");
223 | return true;
224 | } catch (Exception e) {
225 | e.printStackTrace();
226 | return false;
227 | }
228 | }
229 | }
230 |
--------------------------------------------------------------------------------
/src/main/java/edu/shanghaitech/ai/nlp/lveg/impl/BinaryGrammarRule.java:
--------------------------------------------------------------------------------
1 | package edu.shanghaitech.ai.nlp.lveg.impl;
2 |
3 | import edu.shanghaitech.ai.nlp.lveg.model.GaussianMixture;
4 | import edu.shanghaitech.ai.nlp.lveg.model.GrammarRule;
5 |
6 | /**
7 | * @author Yanpeng Zhao
8 | *
9 | */
10 | public class BinaryGrammarRule extends GrammarRule implements Comparable