├── .gitignore
├── README.md
├── lexicon
├── pos.txt
├── rels.txt
└── words.txt
├── pathlstm.jar
├── pom.xml
├── scripts
├── parse.sh
└── parse_fnet.sh
└── src
└── main
└── java
├── dmonner
└── xlbp
│ ├── AdadeltaBasicWeightUpdater.java
│ ├── AdadeltaBatchWeightUpdater.java
│ ├── AdagradBatchWeightUpdater.java
│ ├── AdamBasicWeightUpdater.java
│ ├── BasicWeightUpdater.java
│ ├── BatchWeightUpdater.java
│ ├── Component.java
│ ├── DecayWeightUpdater.java
│ ├── DownstreamComponent.java
│ ├── ForceAdjacencyListWeightInitializer.java
│ ├── Function.java
│ ├── Input.java
│ ├── InputComponent.java
│ ├── InternalComponent.java
│ ├── MomentumWeightUpdater.java
│ ├── NadamBatchWeightUpdater.java
│ ├── Network.java
│ ├── NetworkCopier.java
│ ├── NetworkStringBuilder.java
│ ├── NormalWeightInitializer.java
│ ├── ResilientEligibilitiesWeightUpdater.java
│ ├── ResilientWeightUpdater.java
│ ├── Responsibilities.java
│ ├── SetWeightInitializer.java
│ ├── Target.java
│ ├── TargetComponent.java
│ ├── UniformWeightInitializer.java
│ ├── UpstreamComponent.java
│ ├── WeightInitializer.java
│ ├── WeightUpdater.java
│ ├── WeightUpdaterType.java
│ ├── compound
│ ├── AbstractCompound.java
│ ├── AbstractInternalCompound.java
│ ├── AbstractWeightedCompound.java
│ ├── Compound.java
│ ├── ConvolutionCompound.java
│ ├── DiagonalWeightBank.java
│ ├── DropoutInputCompound.java
│ ├── FunctionCompound.java
│ ├── IndirectWeightBank.java
│ ├── InputCompound.java
│ ├── InternalCompound.java
│ ├── LinearCompound.java
│ ├── LinearTargetCompound.java
│ ├── LogisticCompound.java
│ ├── MemoryCellCompound.java
│ ├── MemoryCompound.java
│ ├── MultiBindCompound.java
│ ├── PiCompound.java
│ ├── RectifiedLinearCompound.java
│ ├── SharedDiagonalWeightBank.java
│ ├── SimpleCompound.java
│ ├── SingletonCompound.java
│ ├── SumOfSquaresTargetCompound.java
│ ├── TanhCompound.java
│ ├── TanhTargetCompound.java
│ ├── TargetCompound.java
│ ├── WeightBank.java
│ ├── WeightedCompound.java
│ └── XEntropyTargetCompound.java
│ ├── connection
│ ├── AdjacencyListConnection.java
│ ├── AdjacencyMatrixConnection.java
│ ├── BiasConnection.java
│ ├── Connection.java
│ ├── ConnectionType.java
│ ├── DiagonalConnection.java
│ ├── ImmutableDiagonalConnection.java
│ ├── IndirectConnection.java
│ └── LayerConnection.java
│ ├── example
│ └── SequentialParity.java
│ ├── layer
│ ├── AbstractDownstreamLayer.java
│ ├── AbstractFanInLayer.java
│ ├── AbstractFanOutLayer.java
│ ├── AbstractFunctionLayer.java
│ ├── AbstractInternalLayer.java
│ ├── AbstractLayer.java
│ ├── AbstractTargetLayer.java
│ ├── AbstractUpstreamLayer.java
│ ├── BiasLayer.java
│ ├── CopyDestinationLayer.java
│ ├── CopySourceLayer.java
│ ├── DirectResponsibilityLayer.java
│ ├── DownstreamLayer.java
│ ├── DropoutPiLayer.java
│ ├── EndCapLayer.java
│ ├── FanOutLayer.java
│ ├── FunctionLayer.java
│ ├── InputLayer.java
│ ├── InternalLayer.java
│ ├── Layer.java
│ ├── LinearLayer.java
│ ├── LogisticLayer.java
│ ├── OffsetLayer.java
│ ├── PiLayer.java
│ ├── RectifiedLinearLayer.java
│ ├── RepulsionLayer.java
│ ├── ScaleLayer.java
│ ├── SigmaLayer.java
│ ├── SumOfSquaresTargetLayer.java
│ ├── TanhLayer.java
│ ├── TargetLayer.java
│ ├── UpstreamLayer.java
│ ├── WeightReceiverLayer.java
│ ├── WeightSenderLayer.java
│ ├── WeightedLayer.java
│ ├── XEntropyLogisticLayer.java
│ └── XEntropyTargetLayer.java
│ ├── stat
│ ├── AbstractStat.java
│ ├── BitDistStat.java
│ ├── BitStat.java
│ ├── ConnectionStat.java
│ ├── ErrorStat.java
│ ├── FractionStat.java
│ ├── MeanVarStat.java
│ ├── NetworkStat.java
│ ├── Optimizer.java
│ ├── SetStat.java
│ ├── Stat.java
│ ├── StepStat.java
│ ├── TargetSetStat.java
│ ├── TargetStat.java
│ ├── TestStat.java
│ └── TrialStat.java
│ ├── trial
│ ├── AbstractTrialStream.java
│ ├── HashTrialStream.java
│ ├── LayerCheck.java
│ ├── NeverBreaker.java
│ ├── PerfectBreaker.java
│ ├── Step.java
│ ├── StepRecord.java
│ ├── Trainer.java
│ ├── TrainingBreaker.java
│ ├── Trial.java
│ ├── TrialRecord.java
│ ├── TrialSet.java
│ ├── TrialStream.java
│ ├── TrialStreamAdapter.java
│ └── ValidationBreaker.java
│ └── util
│ ├── ArrayQueue.java
│ ├── CSVWriter.java
│ ├── IndexAwareHeap.java
│ ├── IndexAwareHeapNode.java
│ ├── MatrixTools.java
│ ├── NoiseGenerator.java
│ ├── NormalNoiseGenerator.java
│ ├── ReflectionTools.java
│ ├── SlidingMedian.java
│ └── TableWriter.java
├── se
└── lth
│ └── cs
│ └── srl
│ ├── CompletePipeline.java
│ ├── Parse.java
│ ├── SemanticRoleLabeler.java
│ ├── corpus
│ ├── ArgMap.java
│ ├── ConstituentBuilder.java
│ ├── CorefChain.java
│ ├── Predicate.java
│ ├── PredicateReference.java
│ ├── Sentence.java
│ ├── StringInText.java
│ ├── Word.java
│ └── Yield.java
│ ├── io
│ ├── ANNWriter.java
│ ├── AbstractCoNLL09Reader.java
│ ├── AllCoNLL09Reader.java
│ ├── CoNLL09Writer.java
│ ├── DepsOnlyCoNLL09Reader.java
│ ├── FrameNetXMLWriter.java
│ ├── SRLOnlyCoNLL09Reader.java
│ ├── SentenceReader.java
│ └── SentenceWriter.java
│ ├── languages
│ ├── AbstractDummyLanguage.java
│ ├── Chinese.java
│ ├── Czech.java
│ ├── English.java
│ ├── German.java
│ ├── Language.java
│ ├── NullLanguage.java
│ └── Spanish.java
│ ├── options
│ ├── CompletePipelineCMDLineOptions.java
│ ├── FullPipelineOptions.java
│ ├── Options.java
│ └── ParseOptions.java
│ ├── pipeline
│ ├── AbstractStep.java
│ ├── ArgumentClassifier.java
│ ├── ArgumentIdentifier.java
│ ├── ArgumentStep.java
│ ├── LBJavaArgumentClassifier.java
│ ├── Pipeline.java
│ ├── PipelineStep.java
│ ├── PredicateDisambiguator.java
│ ├── PredicateIdentifier.java
│ ├── Reranker.java
│ └── Step.java
│ ├── preprocessor
│ ├── CMDLineTokenizer.java
│ ├── IllinoisPreprocessor.java
│ ├── PipelinedPreprocessor.java
│ ├── Preprocessor.java
│ ├── SimpleChineseLemmatizer.java
│ ├── StanfordPreprocessor.java
│ └── tokenization
│ │ ├── OpenNLPToolsTokenizerWrapper.java
│ │ ├── StanfordPTBTokenizer.java
│ │ ├── Tokenizer.java
│ │ ├── WhiteSpaceTokenizer.java
│ │ └── exner
│ │ ├── SwedishTokenizer.java
│ │ └── Tokenizer.java
│ └── util
│ ├── BohnetHelper.java
│ ├── BrownCluster.java
│ ├── ChineseDesegmenter.java
│ ├── DasFilter.java
│ ├── FileExistenceVerifier.java
│ ├── Relation.java
│ ├── Sentence2RDF.java
│ ├── SentenceAnnotation.java
│ ├── StandOffAnnotation.java
│ ├── TurboParser.java
│ ├── Util.java
│ └── WordEmbedding.java
└── uk
└── ac
└── ed
└── inf
└── srl
├── features
├── AnySetFeature.java
├── ArgDependentAttrFeature.java
├── ArgDependentBrown.java
├── ArgDependentEmbedding.java
├── ArgDependentFeatsFeature.java
├── AttrFeature.java
├── BrownPathFeature.java
├── ChildSetFeature.java
├── ContinuousArgDependentAttrFeature.java
├── ContinuousAttrFeature.java
├── ContinuousFeature.java
├── ContinuousSetFeature.java
├── DepSubCatFeature.java
├── DependencyCPathEmbedding.java
├── DependencyIPathEmbedding.java
├── DependencyPathEmbedding.java
├── DistanceFeature.java
├── DumpFeature.java
├── EmbeddingPath.java
├── FeatsFeature.java
├── Feature.java
├── FeatureFile.java
├── FeatureGenerator.java
├── FeatureName.java
├── FeatureSet.java
├── NumFeature.java
├── PBLabelFeature.java
├── PathFeature.java
├── PathItemSetFeature.java
├── PathLengthFeature.java
├── PositionFeature.java
├── PredDependentAttrFeature.java
├── PredDependentBrown.java
├── PredDependentEmbedding.java
├── PredDependentFeatsFeature.java
├── QContinuousSetFeature.java
├── QContinuousSingleFeature.java
├── QDoubleChildSetFeature.java
├── QSetSetFeature.java
├── QSingleSetFeature.java
├── QSingleSingleFeature.java
├── QuadraticFeature.java
├── SameSubTreeFeature.java
├── SetFeature.java
├── SingleFeature.java
├── SpanLengthFeature.java
├── SubCatSizeFeature.java
├── TargetWord.java
└── WordExtractor.java
├── lstm
├── DataConverter.java
├── DataReader.java
├── EmbeddingNetwork.java
├── NetworkOptions.java
├── NetworkRunner.java
└── layer
│ └── EmbeddingLayer.java
└── ml
├── LearningProblem.java
├── Model.java
└── liblinear
├── Label.java
├── LibLinearLearningProblem.java
├── LibLinearModel.java
└── WeightVector.java
/.gitignore:
--------------------------------------------------------------------------------
1 | /.idea/
2 | *.iml
3 | /target/
4 | /models/
5 | /lib/
6 |
--------------------------------------------------------------------------------
/lexicon/pos.txt:
--------------------------------------------------------------------------------
1 | 1 PDT
2 | 2 CC
3 | 3 NNP
4 | 4 ,
5 | 5 WP$
6 | 6 VBN
7 | 7 WP
8 | 8 RBR
9 | 9 CD
10 | 10 RP
11 | 11 JJ
12 | 12 PRP
13 | 13 TO
14 | 14 HYPH
15 | 15 EX
16 | 16 WRB
17 | 17 RB
18 | 18 WDT
19 | 19 VBP
20 | 20 JJR
21 | 21 VBZ
22 | 22 PRF
23 | 23 NNPS
24 | 24 (
25 | 25 UH
26 | 26 POS
27 | 27 $
28 | 28 ``
29 | 29 :
30 | 30 JJS
31 | 31 LS
32 | 32 VB
33 | 33 .
34 | 34 MD
35 | 35 NN
36 | 36 NNS
37 | 37 DT
38 | 38 VBD
39 | 39 #
40 | 40 ''
41 | 41 RBS
42 | 42 IN
43 | 43 SYM
44 | 44 )
45 | 45 PRP$
46 | 46 VBG
47 | 47 -1
48 |
--------------------------------------------------------------------------------
/lexicon/rels.txt:
--------------------------------------------------------------------------------
1 | 1 DIR-PRD^
2 | 2 GAP-PMOD^
3 | 3 PRD-TMP^
4 | 4 DIR-GAP^
5 | 5 DIR-OPRD^
6 | 6 SBJv
7 | 7 GAP-OPRD^
8 | 8 GAP-TMP^
9 | 9 OBJv
10 | 10 GAP-PRDv
11 | 11 GAP-LGSv
12 | 12 PRT^
13 | 13 PMODv
14 | 14 SUFFIX^
15 | 15 GAP-LOC-PRD^
16 | 16 GAP-PRD^
17 | 17 VOC^
18 | 18 DEP^
19 | 19 PRP^
20 | 20 TMPv
21 | 21 NMOD^
22 | 22 LOC-OPRDv
23 | 23 LOC-PRD^
24 | 24 DTVv
25 | 25 GAP-LOCv
26 | 26 P^
27 | 27 GAP-VC^
28 | 28 MNR-PRD^
29 | 29 SUB^
30 | 30 EXTRv
31 | 31 EXT^
32 | 32 GAP-PMODv
33 | 33 GAP-LOC^
34 | 34 PMOD^
35 | 35 DEP-GAPv
36 | 36 GAP-LGS^
37 | 37 PRD-PRPv
38 | 38 PRNv
39 | 39 AMODv
40 | 40 OPRDv
41 | 41 HYPHv
42 | 42 LOCv
43 | 43 BNF^
44 | 44 LGSv
45 | 45 PRD^
46 | 46 IMv
47 | 47 GAP-OPRDv
48 | 48 MNR-TMP^
49 | 49 HMOD^
50 | 50 ADV^
51 | 51 DEPv
52 | 52 LOC^
53 | 53 ADV-GAPv
54 | 54 PRDv
55 | 55 GAP-SBJv
56 | 56 Pv
57 | 57 APPOv
58 | 58 LOC-OPRD^
59 | 59 POSTHONv
60 | 60 DTV^
61 | 61 PRTv
62 | 62 COORDv
63 | 63 OBJ^
64 | 64 HYPH^
65 | 65 PUT^
66 | 66 GAP-VCv
67 | 67 EXTR^
68 | 68 GAP-TMPv
69 | 69 GAP-OBJ^
70 | 70 VOCv
71 | 71 NMODv
72 | 72 APPO^
73 | 73 DIR^
74 | 74 ADV-GAP^
75 | 75 EXT-GAPv
76 | 76 VC^
77 | 77 PRN^
78 | 78 VCv
79 | 79 DEP-GAP^
80 | 80 LOC-TMP^
81 | 81 TITLEv
82 | 82 NAME^
83 | 83 GAP-PRP^
84 | 84 EXT-GAP^
85 | 85 MNRv
86 | 86 LOC-PRDv
87 | 87 PUTv
88 | 88 SUBv
89 | 89 EXTv
90 | 90 NAMEv
91 | 91 PRPv
92 | 92 DIRv
93 | 93 TMP^
94 | 94 TITLE^
95 | 95 PRD-PRP^
96 | 96 GAP-OBJv
97 | 97 OPRD^
98 | 98 POSTHON^
99 | 99 GAP-NMOD^
100 | 100 IM^
101 | 101 MNR^
102 | 102 PRD-TMPv
103 | 103 SUFFIXv
104 | 104 AMOD^
105 | 105 HMODv
106 | 106 ADVv
107 | 107 BNFv
108 | 108 GAP-SBJ^
109 | 109 LGS^
110 | 110 CONJ^
111 | 111 SBJ^
112 | 112 ROOTv
113 | 113 DIR-GAPv
114 | 114 COORD^
115 | 115 ROOT^
116 | 116 GAP-MNR^
117 | 117 CONJv
118 |
--------------------------------------------------------------------------------
/pathlstm.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microth/PathLSTM/d09533e6f3b74505ee86376bdf96f2851bd8428d/pathlstm.jar
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | uk.ac.ed.inf
8 | PathLSTM
9 | 1.0.0
10 |
11 |
12 |
13 | CogCompSoftware
14 | CogComp software repository
15 | http://cogcomp.cs.illinois.edu/m2repo/
16 |
17 |
18 |
19 |
20 |
21 | com.googlecode.mate-tools
22 | anna
23 | 3.5
24 |
25 |
26 | edu.stanford.nlp
27 | stanford-corenlp
28 | 3.6.0
29 |
30 |
31 | org.apache.opennlp
32 | opennlp-tools
33 | 1.6.0
34 |
35 |
36 | edu.illinois.cs.cogcomp
37 | LBJava
38 | 1.2.26
39 |
40 |
41 | net.sf.trove4j
42 | trove4j
43 | 3.0.3
44 |
45 |
46 |
47 |
48 |
49 |
50 | org.apache.maven.plugins
51 | maven-compiler-plugin
52 | 3.6.0
53 |
54 | 1.8
55 | 1.8
56 |
57 |
58 |
59 | org.apache.maven.plugins
60 | maven-source-plugin
61 | 2.1.2
62 |
63 |
64 | attach-sources
65 |
66 | jar
67 |
68 |
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/scripts/parse.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the complete English PathLSTM pipeline (tokenization, lemmatization,
# POS tagging, dependency parsing and semantic role labeling) on one file.
#
# Usage: scripts/parse.sh INPUT_FILE
#
# Please download these models and adjust their locations accordingly.
# Make sure the compiled class files (run "mvn compile") are located in
# target/classes/.
set -euo pipefail

if [ "$#" -lt 1 ]; then
    echo "Usage: $0 INPUT_FILE" >&2
    exit 1
fi

LEMMA_MODEL=models/lemma-eng.model
POS_MODEL=models/tagger-eng.model
PARSER_MODEL=models/parse-eng.model
SRL_MODEL=models/srl-ACL2016-eng.model

RERANKER="-reranker -externalNNs"

# Stanford CoreNLP (WSJTokenizer) needed for tokenization
TOKENIZE="-tokenize"
STANFORD=lib/stanford-corenlp-3.7.0.jar

# java 1.8+ is required; override with e.g. JAVA=/opt/jdk8/bin/java
JAVA=${JAVA:-java}
# maximum heap size; override with e.g. JAVA_HEAP=16g
JAVA_HEAP=${JAVA_HEAP:-60g}

# RERANKER and TOKENIZE are deliberately unquoted so that each one expands
# into separate command-line options; everything else is quoted so that
# paths containing spaces survive.
# shellcheck disable=SC2086
exec "$JAVA" "-Xmx$JAVA_HEAP" -cp "lib/anna-3.3.jar:$STANFORD:target/classes/" \
    se.lth.cs.srl.CompletePipeline eng \
    -lemma "$LEMMA_MODEL" -parser "$PARSER_MODEL" -tagger "$POS_MODEL" -srl "$SRL_MODEL" \
    $RERANKER $TOKENIZE -test "$1"
20 |
--------------------------------------------------------------------------------
/scripts/parse_fnet.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the FrameNet PathLSTM pipeline (Stanford preprocessing + SRL) on one
# file; the result is written to out.conll.
#
# Usage: scripts/parse_fnet.sh INPUT_FILE
#
# Please download these dependencies and adjust their locations accordingly.
# Make sure the compiled class files (run "mvn compile") are located in
# target/classes/ or adjust the class path to include a precompiled .jar.
set -euo pipefail

if [ "$#" -lt 1 ]; then
    echo "Usage: $0 INPUT_FILE" >&2
    exit 1
fi

SRL_MODEL=models/srl-ICCG16-stanford-eng.model
FN_DATA=framenet/fndata-1.7/

#RERANKER="-aibeam 7 -acbeam 3 -alfa 0.75 -reranker -externalNNs -globalFeats"
RERANKER="-reranker -externalNNs -globalFeats"

# java 1.8+ is required; override with e.g. JAVA=/opt/jdk8/bin/java
JAVA=${JAVA:-java}
# maximum heap size; override with e.g. JAVA_HEAP=16g
JAVA_HEAP=${JAVA_HEAP:-60g}

CP="target/classes/:lib/anna-3.3.jar:lib/stanford-corenlp-3.8.0.jar:lib/stanford-corenlp-3.8.0-models.jar"

# RERANKER is deliberately unquoted so that it expands into separate options.
# shellcheck disable=SC2086
exec "$JAVA" "-Xmx$JAVA_HEAP" -cp "$CP" se.lth.cs.srl.CompletePipeline fnet \
    -test "$1" -srl "$SRL_MODEL" $RERANKER -tokenize -framenet "$FN_DATA" -stanford -out out.conll
14 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/AdadeltaBasicWeightUpdater.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import dmonner.xlbp.connection.Connection;
4 |
5 | public class AdadeltaBasicWeightUpdater implements WeightUpdater
6 | {
7 | public static final long serialVersionUID = 1L;
8 |
9 | private final float beta;
10 | private final float eps;
11 | private final Connection parent;
12 | private double[][] squared_gradients;
13 | private double[][] squared_deltas;
14 |
15 | public AdadeltaBasicWeightUpdater(final Connection parent, final float beta, final float eps)
16 | {
17 | this.parent = parent;
18 | this.beta = beta;
19 | this.eps = eps;
20 | }
21 |
22 | @Override
23 | public Connection getConnection()
24 | {
25 | return parent;
26 | }
27 |
28 | @Override
29 | public float getUpdate(final int i, final float dw)
30 | {
31 | squared_gradients[i][0] = beta * squared_gradients[i][0] + (1-beta) * (dw*dw);
32 | float retval = (float)(Math.sqrt(squared_deltas[i][0] + eps) / Math.sqrt(squared_gradients[i][0] + eps) * dw);
33 | squared_deltas[i][0] = beta * squared_deltas[i][0] + (1-beta) * (retval*retval);
34 | return retval;
35 | }
36 |
37 | @Override
38 | public float getUpdate(final int j, final int i, final float dw)
39 | {
40 | squared_gradients[i][j] = beta * squared_gradients[i][j] + (1-beta) * (dw*dw);
41 | float retval = (float)(Math.sqrt(squared_deltas[i][j] + eps) / Math.sqrt(squared_gradients[i][j] + eps) * dw);
42 | squared_deltas[i][j] = beta * squared_deltas[i][j] + (1-beta) * (retval*retval);
43 | return retval;
44 | }
45 |
46 | @Override
47 | public void initialize(final int size)
48 | {
49 | // put 1D array in the 2nd dimension, for array access efficiency
50 | initialize(1, size);
51 | }
52 |
53 | @Override
54 | public void initialize(final int to, final int from)
55 | {
56 | squared_gradients = new double[from][to];
57 | squared_deltas = new double[from][to];
58 | }
59 |
60 | @Override
61 | public void processBatch()
62 | {
63 | }
64 |
65 | @Override
66 | public void toString(final NetworkStringBuilder sb)
67 | {
68 | if(sb.showLearningRates())
69 | sb.appendln("TODO");
70 | }
71 |
72 | @Override
73 | public void updateFromBiases(final float[] d)
74 | {
75 | }
76 |
77 | @Override
78 | public void updateFromEligibilities(final float[][] e, final float[] d)
79 | {
80 | }
81 |
82 | @Override
83 | public void updateFromInputs(final float[] in, final float[] d)
84 | {
85 | }
86 |
87 | @Override
88 | public void updateFromVector(final float[] v, final float[] d)
89 | {
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/AdamBasicWeightUpdater.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import dmonner.xlbp.connection.Connection;
4 |
5 | public class AdamBasicWeightUpdater implements WeightUpdater
6 | {
7 | public static final long serialVersionUID = 1L;
8 |
9 | private final float beta;
10 | private final float beta2;
11 | private final float alpha;
12 | private final float eps;
13 | private final Connection parent;
14 | private double[][] squared_deltas;
15 | private double[][] m;
16 | private double[][] v;
17 | public static int t;
18 |
19 | private double[] betas;
20 | private double[] betas2;
21 |
22 | public AdamBasicWeightUpdater(final Connection parent, final float beta, final float beta2, final float eps, final float alpha)
23 | {
24 | this.parent = parent;
25 | this.beta = beta;
26 | this.beta2 = beta2;
27 | this.alpha = alpha;
28 | this.eps = eps;
29 | t = 1;
30 |
31 | }
32 |
33 | @Override
34 | public Connection getConnection()
35 | {
36 | return parent;
37 | }
38 |
39 | @Override
40 | public float getUpdate(final int i, final float dw)
41 | {
42 | m[i][0] = beta * m[i][0] + (1-beta) * dw;
43 | v[i][0] = beta2* v[i][0] + (1-beta2) * (dw*dw);
44 |
45 | return (float)(alpha * (m[i][0]/(double)(1-Math.pow(beta,t))) / (Math.sqrt(v[i][0]/(double)(1-Math.pow(beta2,t)))+eps));
46 | }
47 |
48 | @Override
49 | public float getUpdate(final int j, final int i, final float dw)
50 | {
51 | //if(i==3 && j==3) System.err.println(parent.getName() + "\t" + dw);
52 | m[i][j] = beta * m[i][j] + (1-beta) * dw;
53 | v[i][j] = beta2* v[i][j] + (1-beta2) * (dw*dw);
54 | return (float)(alpha * (m[i][j]/(double)(1-Math.pow(beta,t))) / (Math.sqrt(v[i][j]/(double)(1-Math.pow(beta2,t)))+eps));
55 | }
56 |
57 | @Override
58 | public void initialize(final int size)
59 | {
60 | // put 1D array in the 2nd dimension, for array access efficiency
61 | initialize(1, size);
62 | }
63 |
64 | @Override
65 | public void initialize(final int to, final int from)
66 | {
67 | m = new double[from][to];
68 | v = new double[from][to];
69 | squared_deltas = new double[from][to];
70 | }
71 |
72 | @Override
73 | public void processBatch()
74 | {
75 | }
76 |
77 | @Override
78 | public void toString(final NetworkStringBuilder sb)
79 | {
80 | if(sb.showLearningRates())
81 | sb.appendln("TODO");
82 | }
83 |
84 | @Override
85 | public void updateFromBiases(final float[] d)
86 | {
87 | }
88 |
89 | @Override
90 | public void updateFromEligibilities(final float[][] e, final float[] d)
91 | {
92 | }
93 |
94 | @Override
95 | public void updateFromInputs(final float[] in, final float[] d)
96 | {
97 | }
98 |
99 | @Override
100 | public void updateFromVector(final float[] v, final float[] d)
101 | {
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/BasicWeightUpdater.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import dmonner.xlbp.connection.Connection;
4 |
5 | public class BasicWeightUpdater implements WeightUpdater
6 | {
7 | public static final long serialVersionUID = 1L;
8 |
9 | public static float HALF = 1.0F;
10 |
11 | private final Connection parent;
12 | private final float a;
13 |
14 | public BasicWeightUpdater(final Connection parent)
15 | {
16 | this(parent, 0.1F);
17 | }
18 |
19 | public BasicWeightUpdater(final Connection parent, final float a)
20 | {
21 | this.parent = parent;
22 | this.a = a;
23 | }
24 |
25 | @Override
26 | public Connection getConnection()
27 | {
28 | return parent;
29 | }
30 |
31 | @Override
32 | public float getUpdate(final int i, final float dw)
33 | {
34 | return (a/HALF) * dw;
35 | }
36 |
37 | @Override
38 | public float getUpdate(final int j, final int i, final float dw)
39 | {
40 | return (a/HALF) * dw;
41 | }
42 |
43 | @Override
44 | public void initialize(final int size)
45 | {
46 | }
47 |
48 | @Override
49 | public void initialize(final int to, final int from)
50 | {
51 | }
52 |
53 | @Override
54 | public void processBatch()
55 | {
56 | }
57 |
58 | @Override
59 | public void toString(final NetworkStringBuilder sb)
60 | {
61 | if(sb.showLearningRates())
62 | sb.appendln("Learning Rate:" + a);
63 | }
64 |
65 | @Override
66 | public void updateFromBiases(final float[] d)
67 | {
68 | }
69 |
70 | @Override
71 | public void updateFromEligibilities(final float[][] e, final float[] d)
72 | {
73 | }
74 |
75 | @Override
76 | public void updateFromInputs(final float[] in, final float[] d)
77 | {
78 | }
79 |
80 | @Override
81 | public void updateFromVector(final float[] v, final float[] d)
82 | {
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/Component.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.io.Serializable;
4 |
5 | public interface Component extends Serializable, Comparable
6 | {
7 | public void activateTest();
8 |
9 | public void activateTrain();
10 |
11 | public void build();
12 |
13 | public void clear();
14 |
15 | public void clearActivations();
16 |
17 | public void clearEligibilities();
18 |
19 | public void clearResponsibilities();
20 |
21 | public Component copy(NetworkCopier copier);
22 |
23 | public Component copy(String nameSuffix);
24 |
25 | public void copyConnectivityFrom(Component comp, NetworkCopier copier);
26 |
27 | public String getName();
28 |
29 | public boolean isBuilt();
30 |
31 | public int nWeights();
32 |
33 | /**
34 | * To be called once after network is set up. For individual layers, checks that they have all the
35 | * necessary inputs and outputs. For fan-in or fan-out layers, checks to see if they only have one
36 | * input/output in the direction that should have multiples; if so, bows itself out and connects
37 | * input to output directly. For higher level components, keeps track of this process and reworks
38 | * internal pointers.
39 | *
40 | * @return false iff this component has removed itself from the network; true otherwise.
41 | */
42 | public boolean optimize();
43 |
44 | public void processBatch();
45 |
46 | public void setWeightInitializer(WeightInitializer win);
47 |
48 | public void setWeightUpdaterType(WeightUpdaterType wut);
49 |
50 | /**
51 | * @return The name of this component.
52 | */
53 | @Override
54 | public String toString();
55 |
56 | public void toString(NetworkStringBuilder sb);
57 |
58 | public String toString(String show);
59 |
60 | public void unbuild();
61 |
62 | public void updateEligibilities();
63 |
64 | public void updateResponsibilities();
65 |
66 | public void updateWeights();
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/DecayWeightUpdater.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import dmonner.xlbp.connection.Connection;
4 |
5 | public class DecayWeightUpdater implements WeightUpdater
6 | {
7 | public static final long serialVersionUID = 1L;
8 |
9 | private final Connection parent;
10 | private final float a;
11 | private final float b;
12 |
13 | public DecayWeightUpdater(final Connection parent)
14 | {
15 | this(parent, 0.1F, 0.001F);
16 | }
17 |
18 | public DecayWeightUpdater(final Connection parent, final float a, final float b)
19 | {
20 | this.parent = parent;
21 | this.a = a;
22 | this.b = b;
23 | }
24 |
25 | @Override
26 | public Connection getConnection()
27 | {
28 | return parent;
29 | }
30 |
31 | @Override
32 | public float getUpdate(final int j, final float dw)
33 | {
34 | return a * (dw - b * parent.getWeight(j, 0));
35 | }
36 |
37 | @Override
38 | public float getUpdate(final int j, final int i, final float dw)
39 | {
40 | return a * (dw - b * parent.getWeight(j, i));
41 | }
42 |
43 | @Override
44 | public void initialize(final int size)
45 | {
46 | }
47 |
48 | @Override
49 | public void initialize(final int to, final int from)
50 | {
51 | }
52 |
53 | @Override
54 | public void processBatch()
55 | {
56 | }
57 |
58 | @Override
59 | public void toString(final NetworkStringBuilder sb)
60 | {
61 | if(sb.showLearningRates())
62 | {
63 | sb.appendln("Learning Rate:" + a);
64 | sb.appendln("Weight Decay Rate:" + b);
65 | }
66 | }
67 |
68 | @Override
69 | public void updateFromBiases(final float[] d)
70 | {
71 | }
72 |
73 | @Override
74 | public void updateFromEligibilities(final float[][] e, final float[] d)
75 | {
76 | }
77 |
78 | @Override
79 | public void updateFromInputs(final float[] in, final float[] d)
80 | {
81 | }
82 |
83 | @Override
84 | public void updateFromVector(final float[] v, final float[] d)
85 | {
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/DownstreamComponent.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import dmonner.xlbp.layer.DownstreamLayer;
4 |
5 | public interface DownstreamComponent extends Component
6 | {
7 | public void addUpstream(final UpstreamComponent upstream);
8 |
9 | public DownstreamLayer asDownstreamLayer();
10 |
11 | public boolean connectedUpstream(final UpstreamComponent upstream);
12 |
13 | @Override
14 | public DownstreamComponent copy(NetworkCopier copier);
15 |
16 | @Override
17 | public DownstreamComponent copy(String nameSuffix);
18 |
19 | public int getIndexInUpstream();
20 |
21 | public int getIndexInUpstream(final int index);
22 |
23 | public UpstreamComponent getUpstream();
24 |
25 | public UpstreamComponent getUpstream(final int index);
26 |
27 | public int indexOfUpstream(final UpstreamComponent upstream);
28 |
29 | public int nUpstream();
30 |
31 | public void removeUpstream(final int index);
32 |
33 | public void removeUpstream(final UpstreamComponent upstream);
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/ForceAdjacencyListWeightInitializer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | public class ForceAdjacencyListWeightInitializer extends UniformWeightInitializer
4 | {
5 | public static final long serialVersionUID = 1L;
6 |
7 | public ForceAdjacencyListWeightInitializer()
8 | {
9 | super();
10 | }
11 |
12 | public ForceAdjacencyListWeightInitializer(final float p)
13 | {
14 | super(p);
15 | }
16 |
17 | public ForceAdjacencyListWeightInitializer(final float p, final float min, final float max)
18 | {
19 | super(p, min, max);
20 | }
21 |
22 | @Override
23 | public boolean fullConnectivity()
24 | {
25 | return false;
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/Input.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.util.Map;
4 |
5 | import dmonner.xlbp.layer.InputLayer;
6 |
7 | public class Input
8 | {
9 | private final InputLayer layer;
10 | //private final float[] value;
11 | //private final int[] binValue;
12 |
13 | private final Map inputs;
14 |
15 | public Input(final InputLayer layer, Map values) //final float[] value, final int[] binValue)
16 | {
17 | this.layer = layer;
18 | //this.value = value;
19 | //this.binValue = binValue;
20 | inputs = values;
21 |
22 | //if(value!=null) {
23 | // if(value.length != layer.size())
24 | // throw new IllegalArgumentException("Incorrect Input Size; expected " + layer.size() + " for "
25 | // + layer.getName() + ", got " + value.length);
26 | //} else {
27 | for(int i : values.keySet()) {
28 | if(i>layer.size())
29 | throw new IllegalArgumentException("Incorrect Input Size; expected " + layer.size() + " for "
30 | + layer.getName() + ", got " + i);
31 | }
32 | //}
33 | }
34 |
35 | public void apply()
36 | {
37 | layer.setInput(inputs);
38 | }
39 |
40 | @Override
41 | public boolean equals(final Object other)
42 | {
43 | if(other instanceof Input)
44 | {
45 | final Input that = (Input) other;
46 | return that.layer == this.layer;
47 | }
48 |
49 | return false;
50 | }
51 |
52 | public InputLayer getLayer()
53 | {
54 | return layer;
55 | }
56 |
57 | /*public float[] getValue()
58 | {
59 | return value;
60 | }
61 |
62 | public int[] getBinValue()
63 | {
64 | return binValue;
65 | }*/
66 |
67 | public Map getValue() {
68 | return inputs;
69 | }
70 |
71 | @Override
72 | public int hashCode()
73 | {
74 | return layer.hashCode();
75 | }
76 |
77 | @Override
78 | public String toString()
79 | {
80 | return layer.getName() + ": " + inputs.toString(); // MatrixTools.toString(value);
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/InputComponent.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.util.Map;
4 |
5 | public interface InputComponent extends Component
6 | {
7 | @Override
8 | public InputComponent copy(String nameSuffix);
9 |
10 | @Override
11 | public InputComponent copy(NetworkCopier copier);
12 |
13 | //public void setInput(final float[] activations);
14 | public void setInput(final Map activations);
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/InternalComponent.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | public interface InternalComponent extends UpstreamComponent, DownstreamComponent
4 | {
5 | @Override
6 | public InternalComponent copy(String nameSuffix);
7 |
8 | @Override
9 | public InternalComponent copy(NetworkCopier copier);
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/NormalWeightInitializer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.util.Random;
4 |
5 | /**
6 | * Weight initializer that draws weights from a normal (Gaussian) distribution and keeps each
7 | * possible connection with probability {@code p}.
8 | *
9 | * <p>{@code [min, max]} is treated as the one-standard-deviation band around the mean: samples have
10 | * mean {@code (min + max) / 2} and standard deviation {@code (max - min) / 2}. Samples are not
11 | * clipped, so some weights will fall outside {@code [min, max]}.
12 | */
13 | public class NormalWeightInitializer implements WeightInitializer
14 | {
15 | public static final long serialVersionUID = 1L;
16 |
17 | /** Source of randomness for both connectivity decisions and weight values. */
18 | public final Random rand;
19 | /** Probability, in [0, 1], that a given connection is created. */
20 | public final float p;
21 | /** Lower edge of the one-standard-deviation band. */
22 | public final float min;
23 | /** Upper edge of the one-standard-deviation band. */
24 | public final float max;
25 |
26 | /** Full connectivity, band [-0.1, +0.1], {@code Random} seeded with 1. */
27 | public NormalWeightInitializer()
28 | {
29 | this(new Random(1L), 1F);
30 | }
31 |
32 | /** Connection probability {@code p}, band [-0.1, +0.1], {@code Random} seeded with 1. */
33 | public NormalWeightInitializer(final float p)
34 | {
35 | this(new Random(1L), p, -0.1F, +0.1F);
36 | }
37 |
38 | /** Connection probability {@code p}, band [min, max], {@code Random} seeded with 1. */
39 | public NormalWeightInitializer(final float p, final float min, final float max)
40 | {
41 | this(new Random(1L), p, min, max);
42 | }
43 |
44 | /** Full connectivity and band [-0.1, +0.1] using the supplied random source. */
45 | public NormalWeightInitializer(final Random rand)
46 | {
47 | this(rand, 1F);
48 | }
49 |
50 | /** Connection probability {@code p} and band [-0.1, +0.1] using the supplied random source. */
51 | public NormalWeightInitializer(final Random rand, final float p)
52 | {
53 | this(rand, p, -0.1F, +0.1F);
54 | }
55 |
56 | /**
57 | * @param rand random source shared by {@link #newWeight} and {@link #randomWeight}
58 | * @param p probability that each connection exists
59 | * @param min lower edge of the one-standard-deviation band
60 | * @param max upper edge of the one-standard-deviation band
61 | * @throws IllegalArgumentException if {@code p} is outside [0, 1] or NaN
62 | */
63 | public NormalWeightInitializer(final Random rand, final float p, final float min, final float max)
64 | {
65 | // Written as a negated range check so that NaN (which fails every comparison) is rejected too.
66 | if(!(p >= 0F && p <= 1F))
67 | throw new IllegalArgumentException("p must be in [0, 1].");
68 |
69 | this.rand = rand;
70 | this.p = p;
71 | this.min = min;
72 | this.max = max;
73 | }
74 |
75 | @Override
76 | public boolean fullConnectivity()
77 | {
78 | return p == 1F;
79 | }
80 |
81 | /** Keeps the (j, i) connection with probability {@code p}; consumes one draw from {@code rand}. */
82 | @Override
83 | public boolean newWeight(final int j, final int i)
84 | {
85 | return rand.nextFloat() < p;
86 | }
87 |
88 | /** Draws a weight from N((min + max) / 2, ((max - min) / 2)^2); {@code j} and {@code i} are ignored. */
89 | @Override
90 | public float randomWeight(final int j, final int i)
91 | {
92 | // Centre on the midpoint of [min, max]. Scaling by (max - min) and offsetting by min would
93 | // centre the distribution on min (e.g. -0.1 by default) and bias every weight toward it.
94 | final float mean = (min + max) / 2F;
95 | final float stddev = (max - min) / 2F;
96 | return (float) rand.nextGaussian() * stddev + mean;
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/SetWeightInitializer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.util.Random;
4 |
5 | public class SetWeightInitializer implements WeightInitializer
6 | {
7 | private static final long serialVersionUID = 1L;
8 |
9 | private final float[][] w;
10 | private final boolean full;
11 |
12 | public SetWeightInitializer(final float[][] w)
13 | {
14 | this.w = w;
15 | this.full = checkFull();
16 | }
17 |
18 | public SetWeightInitializer(final float[][] w, final boolean showFull)
19 | {
20 | this.w = w;
21 | this.full = showFull;
22 | }
23 |
24 | public SetWeightInitializer(int x, int y, int copy, float[][] matrix) {
25 |
26 | this.w = new float[x][y];
27 | //System.err.println(x + "\t" + y);
28 | //System.err.println(matrix.length + "\t" + matrix[0].length);
29 |
30 | for(int i=0; i 1F)
42 | throw new IllegalArgumentException("p must be in [0, 1].");
43 |
44 | this.rand = rand;
45 | this.p = p;
46 | this.min = min;
47 | this.max = max;
48 | }
49 |
50 | @Override
51 | public boolean fullConnectivity()
52 | {
53 | return p == 1F;
54 | }
55 |
56 | @Override
57 | public boolean newWeight(final int j, final int i)
58 | {
59 | return rand.nextFloat() < p;
60 | }
61 |
62 | @Override
63 | public float randomWeight(final int j, final int i)
64 | {
65 | return rand.nextFloat() * (max - min) + min;
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/UpstreamComponent.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import dmonner.xlbp.layer.UpstreamLayer;
4 |
5 | /**
6 | * A component that feeds one or more {@link DownstreamComponent}s. Downstream neighbours can be
7 | * looked up, located, and removed either by reference or by position. (Per-method descriptions
8 | * below are inferred from the method names; confirm against implementations.)
9 | */
10 | public interface UpstreamComponent extends Component
11 | {
12 | /** Adds {@code downstream} to this component's downstream neighbours. */
13 | void addDownstream(final DownstreamComponent downstream);
14 |
15 | /** Removes the given downstream neighbour. */
16 | void removeDownstream(final DownstreamComponent downstream);
17 |
18 | /** Removes the downstream neighbour at position {@code index}. */
19 | void removeDownstream(final int index);
20 |
21 | /** @return true if {@code downstream} is connected downstream of this component */
22 | boolean connectedDownstream(final DownstreamComponent downstream);
23 |
24 | /** @return the position of {@code downstream} among this component's downstream neighbours */
25 | int indexOfDownstream(final DownstreamComponent downstream);
26 |
27 | /** @return how many downstream neighbours this component has */
28 | int nDownstream();
29 |
30 | /** @return the downstream neighbour; presumably the first/only one (TODO confirm) */
31 | DownstreamComponent getDownstream();
32 |
33 | /** @return the downstream neighbour at position {@code index} */
34 | DownstreamComponent getDownstream(final int index);
35 |
36 | /**
37 | * @return this component's index within its downstream neighbour; presumably the first/only
38 | * one (TODO confirm)
39 | */
40 | int getIndexInDownstream();
41 |
42 | /** @return this component's index within the downstream neighbour at position {@code index} */
43 | int getIndexInDownstream(final int index);
44 |
45 | /** @return this component presented as an {@link UpstreamLayer} */
46 | UpstreamLayer asUpstreamLayer();
47 |
48 | @Override
49 | UpstreamComponent copy(String nameSuffix);
50 |
51 | @Override
52 | UpstreamComponent copy(NetworkCopier copier);
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/WeightInitializer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 | * Strategy that decides which weights of a connection exist and what they start out as.
7 | * Indices are presumably {@code j} = receiving unit and {@code i} = sending unit (TODO confirm).
8 | */
9 | public interface WeightInitializer extends Serializable
10 | {
11 | /** @return true if every possible weight should be created */
12 | boolean fullConnectivity();
13 |
14 | /** @return true if the weight between units {@code j} and {@code i} should exist */
15 | boolean newWeight(int j, int i);
16 |
17 | /** @return the initial value for the weight between units {@code j} and {@code i} */
18 | float randomWeight(int j, int i);
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/WeightUpdater.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp;
2 |
3 | import java.io.Serializable;
4 |
5 | import dmonner.xlbp.connection.Connection;
6 |
7 | /**
8 | * Computes weight changes for one {@link Connection}. The per-method descriptions below are
9 | * inferred from method and parameter names; confirm against the implementations.
10 | */
11 | public interface WeightUpdater extends Serializable
12 | {
13 | /** @return the connection whose weights this updater manages */
14 | Connection getConnection();
15 |
16 | /** Prepares per-weight state for a one-dimensional weight vector of length {@code size}. */
17 | void initialize(int size);
18 |
19 | /** Prepares per-weight state for a {@code to} x {@code from} weight matrix. */
20 | void initialize(int to, int from);
21 |
22 | /** @return the change to apply to weight {@code i}, given raw gradient {@code dw} */
23 | float getUpdate(int i, float dw);
24 |
25 | /** @return the change to apply to weight ({@code j}, {@code i}), given raw gradient {@code dw} */
26 | float getUpdate(int j, int i, float dw);
27 |
28 | /** Handles the end of a batch (e.g. applying accumulated changes). */
29 | void processBatch();
30 |
31 | /** Updates bias weights from responsibilities {@code d}. */
32 | void updateFromBiases(float[] d);
33 |
34 | /** Updates weights from eligibilities {@code e} and responsibilities {@code d}. */
35 | void updateFromEligibilities(float[][] e, float[] d);
36 |
37 | /** Updates weights from input activations {@code in} and responsibilities {@code d}. */
38 | void updateFromInputs(float[] in, float[] d);
39 |
40 | /** Updates weights from vector {@code v} and responsibilities {@code d}. */
41 | void updateFromVector(float[] v, float[] d);
42 |
43 | /** Appends a description of this updater to {@code sb}. */
44 | void toString(NetworkStringBuilder sb);
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/compound/Compound.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.compound;
2 |
3 | import dmonner.xlbp.Component;
4 | import dmonner.xlbp.NetworkCopier;
5 | import dmonner.xlbp.UpstreamComponent;
6 | import dmonner.xlbp.layer.UpstreamLayer;
7 |
8 | /**
9 | * An {@link UpstreamComponent} assembled from other components that exposes one or more
10 | * {@link UpstreamLayer}s as its outputs.
11 | */
12 | public interface Compound extends UpstreamComponent
13 | {
14 | /** @return the components this compound is built from */
15 | Component[] getComponents();
16 |
17 | /** @return the number of output layers */
18 | int nOutputs();
19 |
20 | /** @return the output layer; presumably the first/only one (TODO confirm) */
21 | UpstreamLayer getOutput();
22 |
23 | /** @return the output layer at position {@code index} */
24 | UpstreamLayer getOutput(int index);
25 |
26 | @Override
27 | Compound copy(NetworkCopier copier);
28 |
29 | @Override
30 | Compound copy(String nameSuffix);
31 | }
32 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/compound/DiagonalWeightBank.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.compound;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 | import dmonner.xlbp.WeightInitializer;
5 | import dmonner.xlbp.WeightUpdaterType;
6 | import dmonner.xlbp.connection.DiagonalConnection;
7 | import dmonner.xlbp.connection.LayerConnection;
8 | import dmonner.xlbp.layer.DownstreamLayer;
9 | import dmonner.xlbp.layer.UpstreamLayer;
10 |
11 | /**
12 | * Weight bank backed by a {@link DiagonalConnection}. The {@code fullOnly} flag (initially
13 | * {@code true}) is passed to the connection when it is made and whenever it is changed.
14 | */
15 | public class DiagonalWeightBank extends WeightBank
16 | {
17 | private static final long serialVersionUID = 1L;
18 |
19 | /** Value handed to {@link DiagonalConnection#setFullOnly(boolean)}. */
20 | private boolean fullOnly;
21 |
22 | /**
23 | * Copy constructor used by {@link NetworkCopier}. Carries over the source bank's
24 | * {@code fullOnly} setting instead of resetting it to {@code true}.
25 | */
26 | public DiagonalWeightBank(final DiagonalWeightBank that, final NetworkCopier copier)
27 | {
28 | super(that, copier);
29 | this.fullOnly = that.fullOnly;
30 | }
31 |
32 | public DiagonalWeightBank(final String name, final UpstreamLayer upstream,
33 | final DownstreamLayer downstream, final WeightInitializer win, final WeightUpdaterType wut)
34 | {
35 | super(name, upstream, downstream, win, wut);
36 | this.fullOnly = true;
37 | }
38 |
39 | @Override
40 | public DiagonalWeightBank copy(final NetworkCopier copier)
41 | {
42 | return new DiagonalWeightBank(this, copier);
43 | }
44 |
45 | /** Copies this bank with a fresh {@link NetworkCopier}, calling its build() before returning. */
46 | @Override
47 | public DiagonalWeightBank copy(final String nameSuffix)
48 | {
49 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
50 | final DiagonalWeightBank copy = copy(copier);
51 | copier.build();
52 | return copy;
53 | }
54 |
55 | @Override
56 | public DiagonalConnection getConnection()
57 | {
58 | return (DiagonalConnection) super.getConnection();
59 | }
60 |
61 | /** Creates the diagonal connection and applies the current {@code fullOnly} setting to it. */
62 | @Override
63 | protected LayerConnection makeConnection()
64 | {
65 | // NOTE(review): if WeightBank's constructors call makeConnection(), this runs before the
66 | // fullOnly assignments in the constructors above and sees false -- confirm in WeightBank.
67 | final DiagonalConnection conn = new DiagonalConnection(getName(), getWeightInput(),
68 | getWeightOutput());
69 | conn.setFullOnly(fullOnly);
70 | return conn;
71 | }
72 |
73 | /** Updates the flag and pushes it to the current connection. */
74 | public void setFullOnly(final boolean fullOnly)
75 | {
76 | this.fullOnly = fullOnly;
77 | getConnection().setFullOnly(fullOnly);
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/compound/DropoutInputCompound.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.compound;
2 |
3 | import java.util.Random;
4 |
5 | import dmonner.xlbp.NetworkCopier;
6 |
7 | public class DropoutInputCompound extends InputCompound {
8 |
9 | private static final long serialVersionUID = 1574221823303987773L;
10 | final double dropout;
11 | final Random r;
12 |
13 | public DropoutInputCompound(InputCompound that, NetworkCopier copier) {
14 | super(that, copier);
15 | dropout = (that instanceof DropoutInputCompound) ? ((DropoutInputCompound) that).dropout : 0.0; // keep the source's rate; a plain InputCompound has none
16 | r = new Random();
17 | }
18 |
19 | public DropoutInputCompound(String string, double d, int inputlength) { // d is stored as the dropout rate
20 | super(string, inputlength);
21 | dropout = d;
22 | r = new Random();
23 | }
24 |
25 | public void setInput(final float[] activations) {
26 | for(int i=0; idropoutrate);
36 | }
37 |
38 | @Override
39 | public void activateTest()
40 | {
41 | System.arraycopy(upstream[0].getActivations(), 0, y, 0, size);
42 |
43 | for(int k = 1; k < nUpstream; k++)
44 | MatrixTools.multiply(upstream[k].getActivations(), y, size);
45 |
46 | for(int j=0; j0.0?x[i][j]:0.0);
40 | //if(j==x[0].length-1) System.out.println(); else System.out.print("\t");
41 | return x[j]>0.0?x[j]:0.0F;
42 | }
43 |
44 | /** ReLU derivative read from the stored output: 1 where the unit is positive, 0 otherwise. */
45 | @Override
46 | public float fprime(final int j)
47 | {
48 | return y[j] > 0.0 ? 1.0F : 0.0F;
49 | }
50 |
51 | /** Deterministic pass: keeps every unit but scales it by (1 - dropoutrate). */
52 | @Override
53 | public void activateTest()
54 | {
55 | for(int unit = 0; unit < size; unit++)
56 | y[unit] = (1 - dropoutrate) * f(unit);
57 | }
58 |
59 | /** Stochastic pass: a unit is kept when a uniform draw exceeds dropoutrate, else zeroed. */
60 | @Override
61 | public void activateTrain()
62 | {
63 | for(int unit = 0; unit < size; unit++)
64 | {
65 | final boolean kept = ThreadLocalRandom.current().nextFloat() > dropoutrate;
66 | y[unit] = kept ? f(unit) : 0.0F;
67 | fprime[unit] = kept ? fprime(unit) : 0.0F; // fprime(unit) reads y[unit], so it goes after the write
68 | }
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/ScaleLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 |
5 | /**
6 | * Function layer whose activation is the net input times a constant:
7 | * {@code f(j) = factor * x[j]}, with constant derivative {@code factor}.
8 | */
9 | public class ScaleLayer extends AbstractFunctionLayer
10 | {
11 | private static final long serialVersionUID = 1L;
12 |
13 | /** Multiplier applied to every unit's net input. */
14 | private final float factor;
15 |
16 | public ScaleLayer(final String name, final int size, final float factor)
17 | {
18 | super(name, size);
19 | this.factor = factor;
20 | }
21 |
22 | /** Copy constructor used by {@link NetworkCopier}; keeps the source's factor. */
23 | public ScaleLayer(final ScaleLayer that, final NetworkCopier copier)
24 | {
25 | super(that, copier);
26 | this.factor = that.factor;
27 | }
28 |
29 | @Override
30 | public ScaleLayer copy(final NetworkCopier copier)
31 | {
32 | return new ScaleLayer(this, copier);
33 | }
34 |
35 | @Override
36 | public ScaleLayer copy(final String nameSuffix)
37 | {
38 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
39 | return copy(copier);
40 | }
41 |
42 | @Override
43 | public float f(final int j)
44 | {
45 | return factor * x[j];
46 | }
47 |
48 | @Override
49 | public float fprime(final int j)
50 | {
51 | return factor;
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/SigmaLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 | import dmonner.xlbp.Responsibilities;
5 |
6 | /**
7 | * Fan-in layer whose activation is the element-wise sum of its upstream layers' activations.
8 | * Every upstream layer is aliased to this layer's responsibilities, so no upstream
9 | * responsibilities are computed explicitly.
10 | */
11 | public class SigmaLayer extends AbstractFanInLayer
12 | {
13 | private static final long serialVersionUID = 1L;
14 |
15 | public SigmaLayer(final String name, final int size)
16 | {
17 | super(name, size);
18 | }
19 |
20 | /** Copy constructor used by {@link NetworkCopier}. */
21 | public SigmaLayer(final SigmaLayer that, final NetworkCopier copier)
22 | {
23 | super(that, copier);
24 | }
25 |
26 | @Override
27 | public SigmaLayer copy(final NetworkCopier copier)
28 | {
29 | return new SigmaLayer(this, copier);
30 | }
31 |
32 | @Override
33 | public SigmaLayer copy(final String nameSuffix)
34 | {
35 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
36 | return copy(copier);
37 | }
38 |
39 | /** Copies upstream[0]'s activations into y, then adds the remaining upstreams in order. */
40 | @Override
41 | public void activateTest()
42 | {
43 | final float[] first = upstream[0].getActivations();
44 | System.arraycopy(first, 0, y, 0, size);
45 |
46 | for(int k = 1; k < nUpstream; k++)
47 | {
48 | final float[] addend = upstream[k].getActivations();
49 | for(int unit = 0; unit < size; unit++)
50 | y[unit] += addend[unit];
51 | }
52 | }
53 |
54 | /** Training uses exactly the test-time forward pass. */
55 | @Override
56 | public void activateTrain()
57 | {
58 | activateTest();
59 | }
60 |
61 | /** Builds each upstream layer and aliases it to this layer's responsibilities. */
62 | @Override
63 | public void build()
64 | {
65 | if(built)
66 | return;
67 |
68 | super.build();
69 |
70 | y = new float[size];
71 | d = new Responsibilities(size);
72 |
73 | for(int k = 0; k < nUpstream; k++)
74 | {
75 | upstream[k].build();
76 | upstream[k].aliasResponsibilities(myIndexInUpstream[k], d);
77 | }
78 |
79 | built = true;
80 | }
81 |
82 | /** Aliases this layer, and every upstream layer, to {@code resp}. */
83 | @Override
84 | public void aliasResponsibilities(final int index, final Responsibilities resp)
85 | {
86 | super.aliasResponsibilities(index, resp);
87 | for(int k = 0; k < nUpstream; k++)
88 | upstream[k].aliasResponsibilities(myIndexInUpstream[k], resp);
89 | }
90 |
91 | /** When a downstream copy layer is attached, the downstream request happens in this phase. */
92 | @Override
93 | public void updateEligibilities()
94 | {
95 | if(downstreamCopyLayer != null)
96 | requestFromDownstream();
97 | }
98 |
99 | /** When no downstream copy layer is attached, the downstream request happens in this phase. */
100 | @Override
101 | public void updateResponsibilities()
102 | {
103 | if(downstreamCopyLayer == null)
104 | requestFromDownstream();
105 | }
106 |
107 | /** Nothing to do -- upstream ds are already aliased to this layer's d. */
108 | @Override
109 | public void updateUpstreamResponsibilities(final int index)
110 | {
111 | }
112 |
113 | /** Calls {@code downstream.updateUpstreamResponsibilities} with this layer's index there. */
114 | private void requestFromDownstream()
115 | {
116 | downstream.updateUpstreamResponsibilities(myIndexInDownstream);
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/SumOfSquaresTargetLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 |
5 | /**
6 | * Sum-of-squares target layer. Its responsibilities come from {@code d.target(t, y, w)} when a
7 | * target is present and are cleared otherwise.
8 | */
9 | public class SumOfSquaresTargetLayer extends AbstractTargetLayer
10 | {
11 | private static final long serialVersionUID = 1L;
12 |
13 | /** Copy constructor used by {@link NetworkCopier}. */
14 | public SumOfSquaresTargetLayer(final SumOfSquaresTargetLayer that, final NetworkCopier copier)
15 | {
16 | super(that, copier);
17 | }
18 |
19 | public SumOfSquaresTargetLayer(final String name, final int size)
20 | {
21 | super(name, size);
22 | }
23 |
24 | @Override
25 | public SumOfSquaresTargetLayer copy(final NetworkCopier copier)
26 | {
27 | return new SumOfSquaresTargetLayer(this, copier);
28 | }
29 |
30 | @Override
31 | public SumOfSquaresTargetLayer copy(final String nameSuffix)
32 | {
33 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
34 | return copy(copier);
35 | }
36 |
37 | /** Sets d from the current target (or clears it when there is none), then defers to super. */
38 | @Override
39 | public void updateResponsibilities()
40 | {
41 | if(t != null)
42 | d.target(t, y, w); // t, y, w presumably target, output, per-unit weighting (TODO confirm)
43 | else
44 | d.clear();
45 |
46 | super.updateResponsibilities();
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/TanhLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 |
5 | /** Function layer applying the hyperbolic tangent to each unit's net input. */
6 | public class TanhLayer extends AbstractFunctionLayer
7 | {
8 | private static final long serialVersionUID = 1L;
9 |
10 | /** Copy constructor used by {@link NetworkCopier}. */
11 | public TanhLayer(final TanhLayer that, final NetworkCopier copier)
12 | {
13 | super(that, copier);
14 | }
15 |
16 | public TanhLayer(final String name, final int size)
17 | {
18 | super(name, size);
19 | }
20 |
21 | @Override
22 | public TanhLayer copy(final NetworkCopier copier)
23 | {
24 | return new TanhLayer(this, copier);
25 | }
26 |
27 | @Override
28 | public TanhLayer copy(final String nameSuffix)
29 | {
30 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
31 | return copy(copier);
32 | }
33 |
34 | @Override
35 | public float f(final int j)
36 | {
37 | return (float) Math.tanh(x[j]);
38 | }
39 |
40 | /** tanh'(x) = 1 - tanh(x)^2, taken from y[j] (assumes y[j] already holds f(j)). */
41 | @Override
42 | public float fprime(final int j)
43 | {
44 | final float out = y[j];
45 | return 1F - out * out;
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/TargetLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 | import dmonner.xlbp.TargetComponent;
5 |
6 | /**
7 | * A {@link DownstreamLayer} that is also a {@link TargetComponent}; only the copy methods'
8 | * return types are narrowed here.
9 | */
10 | public interface TargetLayer extends DownstreamLayer, TargetComponent
11 | {
12 | @Override
13 | TargetLayer copy(String nameSuffix);
14 |
15 | @Override
16 | TargetLayer copy(NetworkCopier copier);
17 | }
18 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/UpstreamLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 | import dmonner.xlbp.UpstreamComponent;
5 |
6 | /**
7 | * A {@link Layer} that can feed downstream components and can have a {@link CopySourceLayer}
8 | * attached downstream.
9 | */
10 | public interface UpstreamLayer extends Layer, UpstreamComponent
11 | {
12 | /** Attaches {@code copySource} as this layer's downstream copy layer. */
13 | void addDownstreamCopyLayer(final CopySourceLayer copySource);
14 |
15 | /** @return the attached downstream copy layer; presumably null when none (TODO confirm) */
16 | CopySourceLayer getDownstreamCopyLayer();
17 |
18 | @Override
19 | UpstreamLayer copy(NetworkCopier copier);
20 |
21 | @Override
22 | UpstreamLayer copy(String nameSuffix);
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/WeightedLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 |
5 | /**
6 | * An {@link UpstreamLayer} specialization that re-declares {@link #addDownstreamCopyLayer} and
7 | * narrows both copy methods to return a {@code WeightedLayer}.
8 | */
9 | public interface WeightedLayer extends UpstreamLayer
10 | {
11 | @Override
12 | void addDownstreamCopyLayer(final CopySourceLayer copySource);
13 |
14 | @Override
15 | WeightedLayer copy(String nameSuffix);
16 |
17 | @Override
18 | WeightedLayer copy(NetworkCopier copier);
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/XEntropyLogisticLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 | import dmonner.xlbp.Responsibilities;
5 |
6 | /**
7 | * Logistic layer (used with cross-entropy error, per its name) whose single upstream layer is
8 | * aliased to this layer's responsibilities, so
9 | * {@link #updateUpstreamResponsibilities(int)} has nothing to compute.
10 | */
11 | public class XEntropyLogisticLayer extends LogisticLayer
12 | {
13 | private static final long serialVersionUID = 1L;
14 |
15 | /** Copy constructor used by {@link NetworkCopier}. */
16 | public XEntropyLogisticLayer(final XEntropyLogisticLayer that, final NetworkCopier copier)
17 | {
18 | super(that, copier);
19 | }
20 |
21 | public XEntropyLogisticLayer(final String name, final int size)
22 | {
23 | super(name, size);
24 | }
25 |
26 | @Override
27 | public XEntropyLogisticLayer copy(final NetworkCopier copier)
28 | {
29 | return new XEntropyLogisticLayer(this, copier);
30 | }
31 |
32 | @Override
33 | public XEntropyLogisticLayer copy(final String nameSuffix)
34 | {
35 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
36 | return copy(copier);
37 | }
38 |
39 | /** Builds the upstream layer and aliases it to this layer's responsibilities. */
40 | @Override
41 | public void build()
42 | {
43 | if(built)
44 | return;
45 |
46 | super.build();
47 | upstream.build();
48 | upstream.aliasResponsibilities(myIndexInUpstream, d);
49 | built = true;
50 | }
51 |
52 | /** Aliases this layer and its upstream layer to {@code resp}. */
53 | @Override
54 | public void aliasResponsibilities(final int index, final Responsibilities resp)
55 | {
56 | super.aliasResponsibilities(index, resp);
57 | upstream.aliasResponsibilities(myIndexInUpstream, resp);
58 | }
59 |
60 | /** Nothing to do -- upstream ds are already aliased to this layer's d. */
61 | @Override
62 | public void updateUpstreamResponsibilities(final int upstreamIndex)
63 | {
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/layer/XEntropyTargetLayer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.layer;
2 |
3 | import dmonner.xlbp.NetworkCopier;
4 |
5 | public class XEntropyTargetLayer extends AbstractTargetLayer
6 | {
7 | private static final long serialVersionUID = 1L;
8 |
9 | /** Creates a cross-entropy target layer named {@code name} with {@code size} units. */
10 | public XEntropyTargetLayer(final String name, final int size)
11 | {
12 | super(name, size);
13 | }
14 |
15 | /** Copy constructor used by {@link NetworkCopier}. */
16 | public XEntropyTargetLayer(final XEntropyTargetLayer that, final NetworkCopier copier)
17 | {
18 | super(that, copier);
19 | }
20 |
21 | @Override
22 | public XEntropyTargetLayer copy(final NetworkCopier copier)
23 | {
24 | return new XEntropyTargetLayer(this, copier);
25 | }
26 |
27 | @Override
28 | public XEntropyTargetLayer copy(final String nameSuffix)
29 | {
30 | final NetworkCopier copier = new NetworkCopier(nameSuffix);
31 | return copy(copier);
32 | }
33 |
34 | @Override
35 | public void updateResponsibilities()
36 | {
37 | if(t != null) d.target(t, y, w); // with a target, derive d from t, y and w
38 | else d.clear(); // without one, clear d
39 | super.updateResponsibilities();
40 | }
41 |
42 | public void weightResult(float[] weight) {
43 | for(int i=0; i map)
12 | {
13 | addTo("", map);
14 | }
15 |
16 | /** Writes this stat's CSV column headers with an empty prefix. */
17 | @Override
18 | public void saveHeader(final CSVWriter out) throws IOException { saveHeader("", out); }
19 |
20 | /** Renders this stat with an empty prefix. */
21 | @Override
22 | public String toString()
23 | {
24 | final String noPrefix = "";
25 | return toString(noPrefix);
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/stat/ConnectionStat.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.stat;
2 |
3 | import java.io.IOException;
4 | import java.util.Map;
5 |
6 | import dmonner.xlbp.connection.Connection;
7 | import dmonner.xlbp.util.CSVWriter;
8 |
9 | /**
10 | * Statistics for a {@link Connection}: mean and variance of the absolute weight values
11 | * ("Weights") and the fraction of possible weights that exist ("Connections").
12 | */
13 | public class ConnectionStat extends AbstractStat
14 | {
15 | private final MeanVarStat weights;
16 | private final FractionStat connections;
17 |
18 | /** Creates an empty stat with no observations. */
19 | public ConnectionStat()
20 | {
21 | weights = new MeanVarStat("Weights");
22 | connections = new FractionStat("Connections");
23 | }
24 |
25 | /** Records |w| for every entry of {@code conn.toMatrix()} and the weight counts, then analyzes. */
26 | public ConnectionStat(final Connection conn)
27 | {
28 | this();
29 |
30 | for(final float[] row : conn.toMatrix())
31 | for(final float weight : row)
32 | weights.add(Math.abs(weight));
33 |
34 | connections.add(conn.nWeights(), conn.nWeightsPossible());
35 |
36 | analyze();
37 | }
38 |
39 | /** Merges another ConnectionStat's raw data; derived values change only on the next analyze(). */
40 | public void add(final ConnectionStat that)
41 | {
42 | weights.add(that.weights);
43 | connections.add(that.connections);
44 | }
45 |
46 | @Override
47 | public void add(final Stat that)
48 | {
49 | if(!(that instanceof ConnectionStat))
50 | throw new IllegalArgumentException("Can only add in ConnectionStats.");
51 |
52 | add((ConnectionStat) that);
53 | }
54 |
55 | @Override
56 | public void addTo(final String prefix, final Map map)
57 | {
58 | weights.addTo(prefix, map);
59 | connections.addTo(prefix, map);
60 | }
61 |
62 | @Override
63 | public void analyze()
64 | {
65 | weights.analyze();
66 | connections.analyze();
67 | }
68 |
69 | @Override
70 | public void clear()
71 | {
72 | weights.clear();
73 | connections.clear();
74 | }
75 |
76 | public MeanVarStat getWeights()
77 | {
78 | return weights;
79 | }
80 |
81 | public FractionStat getConnections()
82 | {
83 | return connections;
84 | }
85 |
86 | @Override
87 | public void saveData(final CSVWriter out) throws IOException
88 | {
89 | weights.saveData(out);
90 | connections.saveData(out);
91 | }
92 |
93 | @Override
94 | public void saveHeader(final String prefix, final CSVWriter out) throws IOException
95 | {
96 | weights.saveHeader(prefix, out);
97 | connections.saveHeader(prefix, out);
98 | }
99 |
100 | @Override
101 | public String toString(final String prefix)
102 | {
103 | // Weights first, then Connections -- the same order as the CSV columns.
104 | return weights.toString(prefix) + connections.toString(prefix);
105 | }
106 | }
107 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/stat/MeanVarStat.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.stat;
2 |
3 | import java.io.IOException;
4 | import java.util.Map;
5 |
6 | import dmonner.xlbp.util.CSVWriter;
7 |
8 | /**
9 | * Mean and population variance of a stream of observations.
10 | *
11 | * <p>Raw sums are accumulated by {@link #add(float)} and {@link #add(MeanVarStat)}; the values
12 | * reported by {@link #getMean()} and {@link #getVar()} are only refreshed by {@link #analyze()}.
13 | * Sums are kept in double precision and the count in a {@code long}: a {@code float} counter
14 | * stops increasing after 2^24 observations, which silently skews both statistics.
15 | */
16 | public class MeanVarStat extends AbstractStat
17 | {
18 | private final String name;
19 | private double sum;
20 | private double sumsq;
21 | private long n;
22 | private float mean;
23 | private float var;
24 |
25 | /** Copy constructor: copies the raw sums and the last analyzed mean and variance. */
26 | public MeanVarStat(final MeanVarStat that)
27 | {
28 | this.name = that.name;
29 | this.sum = that.sum;
30 | this.sumsq = that.sumsq;
31 | this.n = that.n;
32 | this.mean = that.mean;
33 | this.var = that.var;
34 | }
35 |
36 | /** @param name label used (after any prefix) for the "Mean" and "Var" output fields */
37 | public MeanVarStat(final String name)
38 | {
39 | this.name = name;
40 | }
41 |
42 | /** Records one observation. */
43 | public void add(final float obs)
44 | {
45 | sum += obs;
46 | sumsq += (double) obs * obs;
47 | n++;
48 | }
49 |
50 | /** Merges another stat's raw sums into this one. */
51 | public void add(final MeanVarStat that)
52 | {
53 | sum += that.sum;
54 | sumsq += that.sumsq;
55 | n += that.n;
56 | }
57 |
58 | @Override
59 | public void add(final Stat that)
60 | {
61 | if(that instanceof MeanVarStat)
62 | add((MeanVarStat) that);
63 | else
64 | throw new IllegalArgumentException("Can only add in other MeanVarStats.");
65 | }
66 |
67 | @Override
68 | public void addTo(final String prefix, final Map map)
69 | {
70 | map.put(prefix + name + "Mean", mean);
71 | map.put(prefix + name + "Var", var);
72 | }
73 |
74 | /**
75 | * Recomputes mean and population variance from the raw sums. With no observations both are
76 | * set to 0 (the same values {@link #clear()} leaves) instead of NaN from 0 / 0.
77 | */
78 | @Override
79 | public void analyze()
80 | {
81 | if(n == 0)
82 | {
83 | mean = 0F;
84 | var = 0F;
85 | return;
86 | }
87 |
88 | final double m = sum / n;
89 | // Var = E[x^2] - E[x]^2. Rounding can push this slightly below zero when the true variance
90 | // is ~0, so clamp it.
91 | final double v = sumsq / n - m * m;
92 | mean = (float) m;
93 | var = (float) Math.max(0.0, v);
94 | }
95 |
96 | @Override
97 | public void clear()
98 | {
99 | sum = 0;
100 | sumsq = 0;
101 | n = 0;
102 | mean = 0;
103 | var = 0;
104 | }
105 |
106 | /** @return the mean as of the last {@link #analyze()} */
107 | public float getMean()
108 | {
109 | return mean;
110 | }
111 |
112 | /** @return the population variance as of the last {@link #analyze()} */
113 | public float getVar()
114 | {
115 | return var;
116 | }
117 |
118 | @Override
119 | public void saveData(final CSVWriter out) throws IOException
120 | {
121 | out.appendField(mean);
122 | out.appendField(var);
123 | }
124 |
125 | @Override
126 | public void saveHeader(final String prefix, final CSVWriter out) throws IOException
127 | {
128 | out.appendHeader(prefix + name + "Mean");
129 | out.appendHeader(prefix + name + "Var");
130 | }
131 |
132 | @Override
133 | public String toString(final String prefix)
134 | {
135 | final StringBuilder sb = new StringBuilder();
136 |
137 | final String prename = prefix + name;
138 |
139 | sb.append(prename);
140 | sb.append("Mean = ");
141 | sb.append(mean);
142 | sb.append("\n");
143 |
144 | sb.append(prename);
145 | sb.append("Variance = ");
146 | sb.append(var);
147 | sb.append("\n");
148 |
149 | return sb.toString();
150 | }
151 | }
152 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/stat/Optimizer.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.stat;
2 |
3 | public abstract class Optimizer
4 | {
5 | public static enum Type
6 | {
7 | BIT_ACCURACY, TARGET_ACCURACY, TRIAL_ACCURACY, STEP_ACCURACY, SSE
8 | }
9 |
10 | public static Optimizer defaultOptimizer = get(Type.SSE);
11 |
12 | public static Optimizer get(final Type type)
13 | {
14 | if(type == Type.BIT_ACCURACY)
15 | return new Optimizer()
16 | {
17 | @Override
18 | public boolean betterThan(final SetStat newest, final SetStat best)
19 | {
20 | if(newest == null)
21 | return false;
22 | if(best == null)
23 | return true;
24 | return newest.getTargetStats().getBits().getAccuracy() >= best.getTargetStats().getBits()
25 | .getAccuracy();
26 | }
27 | };
28 | else if(type == Type.TARGET_ACCURACY)
29 | return new Optimizer()
30 | {
31 | @Override
32 | public boolean betterThan(final SetStat newest, final SetStat best)
33 | {
34 | if(newest == null)
35 | return false;
36 | if(best == null)
37 | return true;
38 | return newest.getTargetStats().getCorrect().getFraction() >= best.getTargetStats()
39 | .getCorrect().getFraction();
40 | }
41 | };
42 | else if(type == Type.TRIAL_ACCURACY)
43 | return new Optimizer()
44 | {
45 | @Override
46 | public boolean betterThan(final SetStat newest, final SetStat best)
47 | {
48 | if(newest == null)
49 | return false;
50 | if(best == null)
51 | return true;
52 | return newest.getTrialStats().getFraction() >= best.getTrialStats().getFraction();
53 | }
54 | };
55 | else if(type == Type.STEP_ACCURACY)
56 | return new Optimizer()
57 | {
58 | @Override
59 | public boolean betterThan(final SetStat newest, final SetStat best)
60 | {
61 | if(newest == null)
62 | return false;
63 | if(best == null)
64 | return true;
65 | return newest.getStepStats().getFraction() >= best.getStepStats().getFraction();
66 | }
67 | };
68 | else if(type == Type.SSE)
69 | return new Optimizer()
70 | {
71 | @Override
72 | public boolean betterThan(final SetStat newest, final SetStat best)
73 | {
74 | if(newest == null)
75 | return false;
76 | if(best == null)
77 | return true;
78 | return newest.getTargetStats().getError().getSSE() <= best.getTargetStats().getError()
79 | .getSSE();
80 | }
81 | };
82 | else
83 | throw new IllegalStateException("Unhandled Type: " + type);
84 | }
85 |
86 | public abstract boolean betterThan(SetStat newest, SetStat best);
87 | }
88 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/stat/Stat.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.stat;
2 |
3 | import java.io.IOException;
4 | import java.util.Map;
5 |
6 | import dmonner.xlbp.util.CSVWriter;
7 |
8 | /**
9 | * A statistic that accumulates raw data, derives summary values in {@link #analyze()}, and can
10 | * export them to a map, a CSV writer, or a string. Most exports come in a prefixed form so that
11 | * nested stats can namespace their field names.
12 | */
13 | public interface Stat
14 | {
15 | /** Merges another stat's raw data into this one. */
16 | void add(final Stat that);
17 |
18 | /** Recomputes derived values from the accumulated data. */
19 | void analyze();
20 |
21 | /** Resets accumulated and derived values. */
22 | void clear();
23 |
24 | /** Stores this stat's values in {@code map} without a key prefix. */
25 | void addTo(final Map map);
26 |
27 | /** Stores this stat's values in {@code map}, each key starting with {@code prefix}. */
28 | void addTo(final String prefix, final Map map);
29 |
30 | /** Writes this stat's values as CSV fields. */
31 | void saveData(final CSVWriter out) throws IOException;
32 |
33 | /** Writes this stat's CSV column headers without a prefix. */
34 | void saveHeader(final CSVWriter out) throws IOException;
35 |
36 | /** Writes this stat's CSV column headers, each starting with {@code prefix}. */
37 | void saveHeader(final String prefix, final CSVWriter out) throws IOException;
38 |
39 | /** @return a human-readable rendering whose labels start with {@code prefix} */
40 | String toString(final String prefix);
41 | }
42 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/stat/StepStat.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.stat;
2 |
3 | import java.io.IOException;
4 | import java.util.Map;
5 |
6 | import dmonner.xlbp.Target;
7 | import dmonner.xlbp.trial.Step;
8 | import dmonner.xlbp.util.CSVWriter;
9 |
10 | /**
11 | * Statistics for a single {@link Step}: one {@link TargetStat} per target in the step, plus a
12 | * "Step" fraction recording whether every target in the step was correct.
13 | *
14 | * <p>All data is read from the step itself; merging other stats in is not supported.
15 | */
16 | public class StepStat extends AbstractStat
17 | {
18 | private final Step step;
19 | private final TargetSetStat targets;
20 | private final FractionStat correct;
21 |
22 | /** Builds and immediately analyzes the statistics for {@code step}. */
23 | public StepStat(final Step step)
24 | {
25 | this.step = step;
26 | this.targets = new TargetSetStat();
27 | this.correct = new FractionStat("Step");
28 |
29 | analyze();
30 | }
31 |
32 | /** Always throws: a StepStat can only be populated from its Step. */
33 | @Override
34 | public void add(final Stat that)
35 | {
36 | throw new IllegalArgumentException("Can only get data directly from a Step.");
37 | }
38 |
39 | @Override
40 | public void addTo(final String prefix, final Map map)
41 | {
42 | targets.addTo(prefix, map);
43 | correct.addTo(prefix, map);
44 | }
45 |
46 | /**
47 | * Measures every target of the step, then records the step as 1/1 when all targets were
48 | * correct, 0/1 when any was not, or 0/0 when the step has no targets.
49 | *
50 | * <p>NOTE(review): each call appends to {@code targets} and {@code correct}; since the
51 | * constructor already calls this, a second call would count the step twice -- confirm callers
52 | * never re-analyze a StepStat.
53 | */
54 | @Override
55 | public void analyze()
56 | {
57 | for(final Target target : step.getTargets())
58 | targets.add(measure(target));
59 |
60 | targets.analyze();
61 |
62 | final boolean hasTargets = targets.size() > 0;
63 | final int possible = hasTargets ? 1 : 0;
64 | final boolean allCorrect = targets.getCorrect().getFraction() == 1F;
65 | correct.add(allCorrect ? possible : 0, possible);
66 | correct.analyze();
67 | }
68 |
69 | /** Builds, compares, and analyzes a TargetStat for {@code target}'s layer and value. */
70 | private static TargetStat measure(final Target target)
71 | {
72 | final TargetStat stat = new TargetStat(target.getLayer());
73 | stat.compare(target.getValue());
74 | stat.analyze();
75 | return stat;
76 | }
77 |
78 | @Override
79 | public void clear()
80 | {
81 | targets.clear();
82 | correct.clear();
83 | }
84 |
85 | public Step getStep()
86 | {
87 | return step;
88 | }
89 |
90 | public TargetSetStat getTargets()
91 | {
92 | return targets;
93 | }
94 |
95 | public FractionStat getCorrect()
96 | {
97 | return correct;
98 | }
99 |
100 | @Override
101 | public void saveData(final CSVWriter out) throws IOException
102 | {
103 | correct.saveData(out);
104 | targets.saveData(out);
105 | }
106 |
107 | @Override
108 | public void saveHeader(final String prefix, final CSVWriter out) throws IOException
109 | {
110 | correct.saveHeader(prefix, out);
111 | targets.saveHeader(prefix, out);
112 | }
113 |
114 | @Override
115 | public String toString(final String prefix)
116 | {
117 | // Step fraction first, then the per-target stats -- the same order as the CSV columns.
118 | return correct.toString(prefix) + targets.toString(prefix);
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/LayerCheck.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.layer.Layer;
4 |
5 | public class LayerCheck
6 | {
7 | private final String name;
8 | private final Layer layer;
9 | private final float[] eval;
10 |
11 | public LayerCheck(final Layer layer, final float[] eval)
12 | {
13 | this(layer.getName(), layer, eval);
14 | }
15 |
16 | public LayerCheck(final String name, final Layer layer, final float[] eval)
17 | {
18 | this.name = name;
19 | this.layer = layer;
20 | this.eval = eval;
21 | }
22 |
23 | public float[] getEval()
24 | {
25 | return eval;
26 | }
27 |
28 | public Layer getLayer()
29 | {
30 | return layer;
31 | }
32 |
33 | public String getName()
34 | {
35 | return name;
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/NeverBreaker.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.stat.TestStat;
4 |
5 | public class NeverBreaker implements TrainingBreaker
6 | {
7 | @Override
8 | public boolean isBreakTime(final TestStat stat)
9 | {
10 | return false;
11 | }
12 |
13 | @Override
14 | public void reset()
15 | {
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/PerfectBreaker.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.stat.TestStat;
4 |
5 | public class PerfectBreaker implements TrainingBreaker
6 | {
7 | @Override
8 | public boolean isBreakTime(final TestStat stat)
9 | {
10 | return stat.getLastTrain().getTrialStats().getFraction() == 1F;
11 | }
12 |
13 | @Override
14 | public void reset()
15 | {
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/StepRecord.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import java.util.HashMap;
4 | import java.util.Map;
5 | import java.util.Map.Entry;
6 | import java.util.Set;
7 |
8 | import dmonner.xlbp.layer.Layer;
9 |
10 | public class StepRecord
11 | {
12 | private final Step step;
13 | private final Map recordings;
14 |
15 | public StepRecord(final Step step)
16 | {
17 | this.step = step;
18 | this.recordings = new HashMap<>();
19 |
20 | for(final Layer layer : step.getRecordLayers())
21 | recordings.put(layer, layer.getActivations().clone());
22 | }
23 |
24 | public float[] getRecording(final Layer layer)
25 | {
26 | return recordings.get(layer);
27 | }
28 |
29 | public Set> getRecordings()
30 | {
31 | return recordings.entrySet();
32 | }
33 |
34 | public Step getStep()
35 | {
36 | return step;
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/TrainingBreaker.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.stat.TestStat;
4 |
5 | public interface TrainingBreaker
6 | {
7 | public boolean isBreakTime(final TestStat stat);
8 |
9 | public void reset();
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/TrialRecord.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | public class TrialRecord
7 | {
8 | private final Trial trial;
9 | private final List recordings;
10 |
11 | public TrialRecord(final Trial trial)
12 | {
13 | this.trial = trial;
14 | this.recordings = new ArrayList<>(trial.size());
15 |
16 | for(final Step step : trial.getSteps())
17 | recordings.add(step.getLastRecording());
18 | }
19 |
20 | public List getRecordings()
21 | {
22 | return recordings;
23 | }
24 |
25 | public Trial getTrial()
26 | {
27 | return trial;
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/TrialStream.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.Network;
4 |
5 | public interface TrialStream
6 | {
7 | public Network getMetaNetwork();
8 |
9 | public String getName();
10 |
11 | public Trial nextTestTrial();
12 |
13 | public Trial nextTrainTrial();
14 |
15 | public Trial nextValidationTrial();
16 |
17 | public int nFolds();
18 |
19 | public int nTestFolds();
20 |
21 | public int nTestTrials();
22 |
23 | public int nTrainFolds();
24 |
25 | public int nTrainTrials();
26 |
27 | public int nValidationFolds();
28 |
29 | public int nValidationTrials();
30 |
31 | public void setFold(int fold);
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/TrialStreamAdapter.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.Network;
4 |
5 | public class TrialStreamAdapter extends AbstractTrialStream
6 | {
7 | public TrialStreamAdapter(final String name, final Network net)
8 | {
9 | super(name, net);
10 | }
11 |
12 | public TrialStreamAdapter(final String name, final Network net, final int train, final int test,
13 | final int valid)
14 | {
15 | super(name, net, train, test, valid);
16 | }
17 |
18 | public TrialStreamAdapter(final String name, final Network net, final String split)
19 | {
20 | super(name, net, split);
21 | }
22 |
23 | @Override
24 | public Trial nextTestTrial()
25 | {
26 | return null;
27 | }
28 |
29 | @Override
30 | public Trial nextTrainTrial()
31 | {
32 | return null;
33 | }
34 |
35 | @Override
36 | public Trial nextValidationTrial()
37 | {
38 | return null;
39 | }
40 |
41 | @Override
42 | public int nTestTrials()
43 | {
44 | return 0;
45 | }
46 |
47 | @Override
48 | public int nTrainTrials()
49 | {
50 | return 0;
51 | }
52 |
53 | @Override
54 | public int nValidationTrials()
55 | {
56 | return 0;
57 | }
58 |
59 | @Override
60 | public void setFold(final int fold)
61 | {
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/trial/ValidationBreaker.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.trial;
2 |
3 | import dmonner.xlbp.stat.Optimizer;
4 | import dmonner.xlbp.stat.SetStat;
5 | import dmonner.xlbp.stat.TestStat;
6 |
7 | public class ValidationBreaker implements TrainingBreaker
8 | {
9 | private final int maxEpochsWithoutImprovement;
10 | private final Optimizer optimizer;
11 |
12 | private int epochsWithoutImprovement;
13 | private SetStat bestValid;
14 |
15 | public ValidationBreaker()
16 | {
17 | this(10);
18 | }
19 |
20 | public ValidationBreaker(final int maxEpochsWithoutImprovement)
21 | {
22 | this(maxEpochsWithoutImprovement, Optimizer.defaultOptimizer);
23 | }
24 |
25 | public ValidationBreaker(final int maxEpochsWithoutImprovement, final Optimizer optimizer)
26 | {
27 | this.maxEpochsWithoutImprovement = maxEpochsWithoutImprovement;
28 | this.optimizer = optimizer;
29 | }
30 |
31 | @Override
32 | public boolean isBreakTime(final TestStat stat)
33 | {
34 | final SetStat newest = stat.getBestValid();
35 |
36 | if(optimizer.betterThan(newest, bestValid))
37 | {
38 | bestValid = newest;
39 | epochsWithoutImprovement = 0;
40 | }
41 | else
42 | {
43 | epochsWithoutImprovement++;
44 | }
45 |
46 | return epochsWithoutImprovement >= maxEpochsWithoutImprovement;
47 | }
48 |
49 | @Override
50 | public void reset()
51 | {
52 | bestValid = null;
53 | epochsWithoutImprovement = 0;
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/util/ArrayQueue.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.util;
2 |
/**
 * Fixed-capacity FIFO queue backed by a circular array.
 *
 * <p>{@link #fill(Object[])} and the array constructor adopt the caller's array directly (no copy)
 * and treat it as a full queue. Popped slots are not cleared, so the caller's array is never
 * modified by the queue.
 *
 * @param <T> element type
 */
public class ArrayQueue<T>
{
    private T[] q;
    private int head, size;
    private int capacity;

    /** Creates a queue with capacity 0 (it stays full and empty until {@link #fill} is called). */
    public ArrayQueue()
    {
        this(0);
    }

    /** Creates an empty queue able to hold {@code capacity} elements. */
    @SuppressWarnings("unchecked") // the Object[] never escapes as a T[] to callers
    public ArrayQueue(final int capacity)
    {
        this.q = (T[]) new Object[capacity];
        this.head = 0;
        this.size = 0;
        this.capacity = capacity;
    }

    /** Creates a full queue over {@code full}; see {@link #fill(Object[])}. */
    public ArrayQueue(final T[] full)
    {
        fill(full);
    }

    public int capacity()
    {
        return capacity;
    }

    /** Empties the queue without touching the backing array's contents. */
    public void clear()
    {
        head = 0;
        size = 0;
    }

    /** Adopts {@code full} (without copying) as the backing array, making the queue full. */
    public void fill(final T[] full)
    {
        q = full;
        head = 0;
        size = q.length;
        capacity = q.length;
    }

    public boolean isEmpty()
    {
        return size == 0;
    }

    public boolean isFull()
    {
        return size >= capacity;
    }

    /**
     * Returns the oldest element without removing it.
     *
     * @throws IllegalStateException if the queue is empty
     */
    public T peek()
    {
        if(isEmpty())
            throw new IllegalStateException("Cannot peek -- queue is empty.");

        return q[head];
    }

    /**
     * Returns the element {@code idx} positions behind the head (0 = oldest) without removing it.
     *
     * @throws IllegalStateException if {@code idx} is not in [0, size)
     */
    public T peek(final int idx)
    {
        if(idx < 0 || idx >= size)
            throw new IllegalStateException("Cannot peek at " + idx + " -- not enough elements.");

        return q[(head + idx) % capacity];
    }

    /**
     * Removes and returns the oldest element.
     *
     * @throws IllegalStateException if the queue is empty
     */
    public T pop()
    {
        if(isEmpty())
            throw new IllegalStateException("Cannot pop -- queue is empty.");

        final T rv = q[head];
        head = (head + 1) % capacity;
        size--;
        return rv;
    }

    /**
     * Discards the {@code n} oldest elements.
     *
     * @throws IllegalArgumentException if {@code n} is negative
     * @throws IllegalStateException if fewer than {@code n} elements are queued
     */
    public void popN(final int n)
    {
        // A negative n would move head backwards (possibly negative) and grow size past capacity.
        if(n < 0)
            throw new IllegalArgumentException("Cannot pop a negative number of elements: " + n);
        if(size < n)
            throw new IllegalStateException("Cannot pop " + n + " -- not enough elements.");
        // Nothing to do; also avoids "% 0" on a zero-capacity queue.
        if(n == 0)
            return;

        head = (head + n) % capacity;
        size -= n;
    }

    /**
     * Appends {@code elem} at the tail.
     *
     * @throws IllegalStateException if the queue is full
     */
    public void push(final T elem)
    {
        if(isFull())
            throw new IllegalStateException("Cannot push -- queue is full.");

        final int next = (head + size) % capacity;
        q[next] = elem;
        size++;
    }

    public int size()
    {
        return size;
    }
}
108 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/util/IndexAwareHeapNode.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.util;
2 |
3 | import java.util.Comparator;
4 |
5 | public class IndexAwareHeapNode> implements
6 | Comparable>
7 | {
8 | private IndexAwareHeap heap;
9 | private int index;
10 | public final E element;
11 |
12 | public IndexAwareHeapNode(final E elem)
13 | {
14 | this.heap = null;
15 | this.element = elem;
16 | this.index = -1;
17 | }
18 |
19 | @Override
20 | public int compareTo(final IndexAwareHeapNode that)
21 | {
22 | final Comparator comp = heap.getComparator();
23 | // if not comparator was provided, use the default
24 | if(comp == null)
25 | return this.element.compareTo(that.element);
26 | // otherwise use the provided comparator
27 | else
28 | return comp.compare(this.element, that.element);
29 | }
30 |
31 | public IndexAwareHeap getHeap()
32 | {
33 | return heap;
34 | }
35 |
36 | public int getIndex()
37 | {
38 | return index;
39 | }
40 |
41 | public void remove()
42 | {
43 | heap.remove(index);
44 | }
45 |
46 | public void set(final IndexAwareHeap heap, final int index)
47 | {
48 | this.heap = heap;
49 | this.index = index;
50 | }
51 |
52 | @Override
53 | public String toString()
54 | {
55 | return "[" + index + ": " + element.toString() + "]";
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/util/NoiseGenerator.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.util;
2 |
/** Source of successive random noise values, one float per call. */
public interface NoiseGenerator
{
    /** Draws the next noise value. */
    float next();
}
7 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/util/NormalNoiseGenerator.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.util;
2 |
3 | import java.util.Random;
4 |
5 | public class NormalNoiseGenerator implements NoiseGenerator
6 | {
7 | private final Random random;
8 | private final float mean;
9 | private final float sd;
10 | private float cached;
11 |
12 | public NormalNoiseGenerator(final Random random, final float mean, final float sd)
13 | {
14 | this.random = random;
15 | this.mean = mean;
16 | this.sd = sd;
17 | }
18 |
19 | @Override
20 | public float next()
21 | {
22 | // Return the cached variable if there is one.
23 | if(!Float.isNaN(cached))
24 | {
25 | final float rv = mean + cached * sd;
26 | cached = Float.NaN;
27 | return rv;
28 | }
29 |
30 | // Otherwise, generate two new normal variates.
31 | float v1, v2, s;
32 | do
33 | {
34 | v1 = 2 * random.nextFloat() - 1;
35 | v2 = 2 * random.nextFloat() - 1;
36 | s = v1 * v1 + v2 * v2;
37 | }
38 | while(s >= 1);
39 |
40 | // This method generates two normal random variates. Saved one for next
41 | // time, and return the other.
42 | final float factor = (float) Math.sqrt(-2 * Math.log(s) / s);
43 | cached = factor * v1;
44 | return mean + factor * v2 * sd;
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/src/main/java/dmonner/xlbp/util/TableWriter.java:
--------------------------------------------------------------------------------
1 | package dmonner.xlbp.util;
2 |
3 | import java.io.IOException;
4 |
5 | public class TableWriter extends CSVWriter
6 | {
7 | public TableWriter(final String filename) throws IOException
8 | {
9 | this(filename, false);
10 | }
11 |
12 | public TableWriter(final String filename, final boolean append) throws IOException
13 | {
14 | super(filename, append, "\t", "\n", "");
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/Parse.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl;
2 |
3 | import java.util.zip.ZipFile;
4 |
5 | import se.lth.cs.srl.corpus.Sentence;
6 | import se.lth.cs.srl.io.AllCoNLL09Reader;
7 | import se.lth.cs.srl.io.CoNLL09Writer;
8 | import se.lth.cs.srl.io.DepsOnlyCoNLL09Reader;
9 | import se.lth.cs.srl.io.FrameNetXMLWriter;
10 | import se.lth.cs.srl.io.SRLOnlyCoNLL09Reader;
11 | import se.lth.cs.srl.io.SentenceReader;
12 | import se.lth.cs.srl.io.SentenceWriter;
13 | import se.lth.cs.srl.options.ParseOptions;
14 | import se.lth.cs.srl.pipeline.Pipeline;
15 | import se.lth.cs.srl.pipeline.Reranker;
16 | import se.lth.cs.srl.pipeline.Step;
17 | import se.lth.cs.srl.util.Util;
18 |
19 | public class Parse {
20 | public static ParseOptions parseOptions;
21 |
22 | public static void main(String[] args) throws Exception {
23 | long startTime = System.currentTimeMillis();
24 | parseOptions = new ParseOptions(args);
25 |
26 | SemanticRoleLabeler srl;
27 |
28 | if (parseOptions.useReranker) {
29 | srl = new Reranker(parseOptions);
30 | // srl =
31 | // Reranker.fromZipFile(zipFile,parseOptions.skipPI,parseOptions.global_alfa,parseOptions.global_aiBeam,parseOptions.global_acBeam);
32 | } else {
33 | ZipFile zipFile = new ZipFile(parseOptions.modelFile);
34 |
35 | srl = parseOptions.skipAI ? Pipeline.fromZipFile(zipFile,
36 | new Step[] { Step.ac })
37 | : parseOptions.skipPD ? Pipeline.fromZipFile(zipFile,
38 | new Step[] { Step.ai, Step.ac })
39 | : parseOptions.skipPI ? Pipeline.fromZipFile(zipFile,
40 | new Step[] { Step.pd, Step.ai, Step.ac})
41 | : Pipeline.fromZipFile(zipFile);
42 | zipFile.close();
43 | }
44 |
45 | SentenceWriter writer = null;
46 | if (parseOptions.printXML)
47 | writer = new FrameNetXMLWriter(parseOptions.output);
48 | else
49 | writer = new CoNLL09Writer(parseOptions.output);
50 |
51 | SentenceReader reader = parseOptions.skipAI ? new AllCoNLL09Reader(
52 | parseOptions.inputCorpus) : parseOptions.skipPI ? new SRLOnlyCoNLL09Reader(
53 | parseOptions.inputCorpus) : new DepsOnlyCoNLL09Reader(
54 | parseOptions.inputCorpus);
55 | int senCount = 0;
56 | for (Sentence s : reader) {
57 | senCount++;
58 | if (senCount % 100 == 0)
59 | System.out.println("Parsing sentence " + senCount);
60 | srl.parseSentence(s);
61 | if (parseOptions.writeCoref)
62 | writer.specialwrite(s);
63 | else
64 | writer.write(s);
65 | }
66 | writer.close();
67 | reader.close();
68 | long totalTime = System.currentTimeMillis() - startTime;
69 | System.out.println("Done.");
70 | System.out.println(srl.getStatus());
71 | System.out.println();
72 | System.out.println("Total execution time: "
73 | + Util.insertCommas(totalTime) + "ms");
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/SemanticRoleLabeler.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl;
2 |
3 | import java.util.Date;
4 |
5 | import se.lth.cs.srl.corpus.Sentence;
6 | import se.lth.cs.srl.util.Util;
7 |
8 | public abstract class SemanticRoleLabeler {
9 |
10 | public void parseSentence(Sentence s) {
11 | long startTime = System.currentTimeMillis();
12 | parse(s);
13 | parsingTime += System.currentTimeMillis() - startTime;
14 | senCount++;
15 | predCount += s.getPredicates().size();
16 | }
17 |
18 | protected abstract void parse(Sentence s);
19 |
20 | public String getStatus() {
21 | StringBuilder ret = new StringBuilder(
22 | "Semantic role labeler started at " + startDate + "\n");
23 | ret.append("Time spent loading SRL models (ms)\t\t"
24 | + Util.insertCommas(loadingTime) + "\n");
25 | ret.append("Time spent parsing semantic roles (ms)\t\t"
26 | + Util.insertCommas(parsingTime) + "\n");
27 | ret.append("\n");
28 | ret.append("Number of sentences\t" + Util.insertCommas(senCount) + "\n");
29 | ret.append("Number of predicates\t" + Util.insertCommas(predCount)
30 | + "\n");
31 | ret.append("SRL speed (ms/sen)\t" + ((double) parsingTime / senCount)
32 | + "\n");
33 | ret.append(getSubStatus());
34 | return ret.toString();
35 | }
36 |
37 | protected abstract String getSubStatus();
38 |
39 | public long loadingTime = 0;
40 | public long parsingTime = 0;
41 | public int senCount = 0;
42 | public int predCount = 0;
43 | public final Date startDate = new Date();
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/corpus/ConstituentBuilder.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.corpus;
2 |
3 | import java.util.TreeSet;
4 |
5 | public class ConstituentBuilder {
6 |
7 | Sentence sen;
8 | Word head;
9 |
10 | public ConstituentBuilder(Sentence s, Word w) {
11 | sen = s;
12 | head = w;
13 | }
14 |
15 | public String toString() {
16 | TreeSet children = new TreeSet<>();
17 | children.add(head.getIdx());
18 | TreeSet processed = new TreeSet<>();
19 | while (!children.isEmpty()) {
20 | Word c = sen.get(children.pollFirst());
21 | if (!processed.contains(c.getIdx())) {
22 | processed.add(c.getIdx());
23 | for (Word cc : c.getChildren()) {
24 | children.add(cc.getIdx());
25 | }
26 | }
27 | }
28 |
29 | StringBuilder sb = new StringBuilder();
30 | String[] wordformArray = sen.getFormArray();
31 | for (int i = processed.first(); i <= processed.last(); i++) {
32 | sb.append(wordformArray[i]);
33 | sb.append(" ");
34 | }
35 | return sb.toString().trim();
36 | }
37 | }
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/corpus/CorefChain.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.corpus;
2 |
3 | import java.util.LinkedList;
4 |
/**
 * Linked list (presumably of coreferent mentions) tagged with a numeric chain id.
 * NOTE(review): the LinkedList type argument is missing in this copy of the source (raw type);
 * restore the intended element type.
 */
public class CorefChain extends LinkedList {

    int id;

    public CorefChain(int id) {
        this.id = id;
    }

    /** Returns the chain id as a decimal string. */
    public String getId() {
        // Integer.toString instead of the deprecated boxing constructor new Integer(int).
        return Integer.toString(id);
    }

}
18 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/corpus/StringInText.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.corpus;
2 |
/** A string together with its begin and end positions in some enclosing text. */
public class StringInText {
    protected String s;
    protected int beginPos;
    protected int endPos;

    public StringInText(String s, int beginPos, int endPos) {
        this.endPos = endPos;
        this.beginPos = beginPos;
        this.s = s;
    }

    /** Returns the string itself. */
    public String word() {
        return this.s;
    }

    /** Returns the begin position, exactly as supplied to the constructor. */
    public int begin() {
        return this.beginPos;
    }

    /** Returns the end position, exactly as supplied to the constructor. */
    public int end() {
        return this.endPos;
    }
}
26 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/AbstractCoNLL09Reader.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.File;
5 | import java.io.FileInputStream;
6 | import java.io.IOException;
7 | import java.io.InputStreamReader;
8 | import java.nio.charset.Charset;
9 | import java.util.ArrayList;
10 | import java.util.Iterator;
11 | import java.util.List;
12 | import java.util.regex.Pattern;
13 |
14 | //import se.lth.cs.srl.corpus.Corpus;
15 | import se.lth.cs.srl.corpus.Sentence;
16 |
17 | public abstract class AbstractCoNLL09Reader implements SentenceReader {
18 |
19 | protected static final Pattern NEWLINE_PATTERN = Pattern.compile("\n");
20 |
21 | protected BufferedReader in;
22 | protected Sentence nextSen;
23 | // protected Corpus c;
24 | private File file;
25 |
26 | public AbstractCoNLL09Reader(File file) {
27 | this.file = file;
28 | open();
29 | }
30 |
31 | private void restart() {
32 | try {
33 | in.close();
34 | open();
35 | } catch (IOException e) {
36 | // TODO Auto-generated catch block
37 | e.printStackTrace();
38 | }
39 | }
40 |
41 | private void open() {
42 | System.err.println("Opening reader for " + file + "...");
43 | try {
44 | in = new BufferedReader(new InputStreamReader(new FileInputStream(
45 | file), Charset.forName("UTF-8")));
46 | // in = new BufferedReader(new FileReader(file));
47 | readNextSentence();
48 | } catch (IOException e) {
49 | System.out.println("Failed: " + e.toString());
50 | System.exit(1);
51 | }
52 | }
53 |
54 | protected abstract void readNextSentence() throws IOException;
55 |
56 | private Sentence getSentence() {
57 | Sentence ret = nextSen;
58 | try {
59 | readNextSentence();
60 | } catch (IOException e) {
61 | System.out.println("Failed to read from corpus file... exiting.");
62 | System.exit(1);
63 | }
64 | return ret;
65 | }
66 |
67 | @Override
68 | public List readAll() {
69 | ArrayList ret = new ArrayList<>();
70 | for (Sentence s : this)
71 | ret.add(s);
72 | ret.trimToSize();
73 | return ret;
74 | }
75 |
76 | @Override
77 | public Iterator iterator() {
78 | if (nextSen == null)
79 | restart();
80 | return new SentenceIterator();
81 | }
82 |
83 | @Override
84 | public void close() {
85 | try {
86 | in.close();
87 | } catch (IOException e) {
88 | // TODO Auto-generated catch block
89 | e.printStackTrace();
90 | }
91 | }
92 |
93 | private class SentenceIterator implements Iterator {
94 | @Override
95 | public boolean hasNext() {
96 | return nextSen != null;
97 | }
98 |
99 | @Override
100 | public Sentence next() {
101 | return getSentence();
102 | }
103 |
104 | @Override
105 | public void remove() {
106 | throw new UnsupportedOperationException("Not implemented");
107 | }
108 |
109 | }
110 |
111 | }
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/AllCoNLL09Reader.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | import java.io.File;
4 | import java.io.IOException;
5 |
6 | import se.lth.cs.srl.corpus.Sentence;
7 |
8 | public class AllCoNLL09Reader extends AbstractCoNLL09Reader {
9 |
10 | public AllCoNLL09Reader(File file) {
11 | super(file);
12 | }
13 |
14 | protected void readNextSentence() throws IOException {
15 | String str;
16 | Sentence sen = null;
17 | StringBuilder senBuffer = new StringBuilder();
18 | while ((str = in.readLine()) != null) {
19 | if (!str.trim().equals("")) {
20 | senBuffer.append(str).append("\n");
21 | } else {
22 | sen=Sentence.newSentence((NEWLINE_PATTERN.split(senBuffer.toString())));
23 | break;
24 | }
25 | }
26 | if (sen == null) {
27 | nextSen = null;
28 | in.close();
29 | } else {
30 | nextSen = sen;
31 | }
32 | }
33 |
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/CoNLL09Writer.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | import java.io.BufferedWriter;
4 | import java.io.File;
5 | import java.io.FileOutputStream;
6 | import java.io.IOException;
7 | import java.io.OutputStreamWriter;
8 | import java.nio.charset.Charset;
9 |
10 | import se.lth.cs.srl.corpus.Sentence;
11 |
12 | public class CoNLL09Writer implements SentenceWriter {
13 |
14 | private BufferedWriter out;
15 |
16 | public CoNLL09Writer(File filename) {
17 | System.out.println("Writing corpus to " + filename + "...");
18 | try {
19 | out = new BufferedWriter(new OutputStreamWriter(
20 | new FileOutputStream(filename), Charset.forName("UTF-8")));
21 | // out = new BufferedWriter(new FileWriter(filename));
22 | } catch (IOException e) {
23 | System.out.println("Failed while opening writer...\n"
24 | + e.toString());
25 | System.exit(1);
26 | }
27 | }
28 |
29 | public void write(Sentence s) {
30 | try {
31 | out.write(s.toString() + "\n\n");
32 | } catch (IOException e) {
33 | e.printStackTrace();
34 | System.out.println("Failed to write sentance.");
35 | System.exit(1);
36 | }
37 | }
38 |
39 | @Override
40 | public void specialwrite(Sentence s) {
41 | try {
42 | out.write(s.toSpecialString() + "\n\n");
43 | } catch (IOException e) {
44 | e.printStackTrace();
45 | System.out.println("Failed to write sentance.");
46 | System.exit(1);
47 | }
48 | }
49 |
50 | public void close() {
51 | try {
52 | out.close();
53 | } catch (IOException e) {
54 | e.printStackTrace();
55 | System.out.println("Failed to close writer.");
56 | System.exit(1);
57 | }
58 | }
59 |
60 | }
61 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/DepsOnlyCoNLL09Reader.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | import java.io.File;
4 | import java.io.IOException;
5 |
6 | import se.lth.cs.srl.corpus.Sentence;
7 |
8 | public class DepsOnlyCoNLL09Reader extends AbstractCoNLL09Reader {
9 |
10 | public DepsOnlyCoNLL09Reader(File file) {
11 | super(file);
12 | }
13 |
14 | @Override
15 | protected void readNextSentence() throws IOException {
16 | String str;
17 | Sentence sen = null;
18 | StringBuilder senBuffer = new StringBuilder();
19 | while ((str = in.readLine()) != null) {
20 | if (!str.trim().equals("")) {
21 | senBuffer.append(str).append("\n");
22 | } else {
23 | sen = Sentence.newDepsOnlySentence(NEWLINE_PATTERN
24 | .split(senBuffer.toString()));
25 | // System.err.println("Processing sentence ...");
26 | // System.err.println(sen);
27 | break;
28 | }
29 | }
30 | if (sen == null) {
31 | nextSen = null;
32 | in.close();
33 | } else {
34 | nextSen = sen;
35 | }
36 | }
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/SRLOnlyCoNLL09Reader.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | import java.io.File;
4 | import java.io.IOException;
5 |
6 | import se.lth.cs.srl.corpus.Sentence;
7 |
8 | public class SRLOnlyCoNLL09Reader extends AbstractCoNLL09Reader {
9 |
10 | public SRLOnlyCoNLL09Reader(File file) {
11 | super(file);
12 | }
13 |
14 | @Override
15 | protected void readNextSentence() throws IOException {
16 | String str;
17 | Sentence sen = null;
18 | StringBuilder senBuffer = new StringBuilder();
19 | while ((str = in.readLine()) != null) {
20 | if (!str.trim().equals("")) {
21 | senBuffer.append(str).append("\n");
22 | } else {
23 | if (!senBuffer.toString().startsWith("_"))
24 | sen = Sentence.newSRLOnlySentence((NEWLINE_PATTERN
25 | .split(senBuffer.toString())));
26 | break;
27 | }
28 | }
29 | if (sen == null) {
30 | nextSen = null;
31 | in.close();
32 | } else {
33 | nextSen = sen;
34 | }
35 |
36 | }
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/SentenceReader.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | import java.util.List;
4 |
5 | import se.lth.cs.srl.corpus.Sentence;
6 |
7 | public interface SentenceReader extends Iterable {
8 |
9 | List readAll();
10 |
11 | public void close();
12 |
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/io/SentenceWriter.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.io;
2 |
3 | //import se.lth.cs.srl.corpus.Corpus;
4 | import se.lth.cs.srl.corpus.Sentence;
5 |
6 | public interface SentenceWriter {
7 |
8 | public void write(Sentence s);
9 |
10 | public void close();
11 |
12 | public void specialwrite(Sentence s);
13 |
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/languages/AbstractDummyLanguage.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.languages;
2 |
3 | import java.util.Map;
4 |
5 | import se.lth.cs.srl.corpus.Predicate;
6 | import se.lth.cs.srl.corpus.Word;
7 | import se.lth.cs.srl.options.FullPipelineOptions;
8 |
9 | /**
10 | * this class is just so we can run the pipeline without the srl stuff i.e., to
11 | * run the http interface of the anna pipeline without srl.
12 | *
13 | * @author anders
14 | *
15 | */
16 | public abstract class AbstractDummyLanguage extends Language {
17 |
18 | @Override
19 | public String getDefaultSense(Predicate pred) {
20 | throw new Error("!");
21 | }
22 |
23 | @Override
24 | public String getCoreArgumentLabelSequence(Predicate pred,
25 | Map proposition) {
26 | throw new Error("!");
27 | }
28 |
29 | @Override
30 | public String getLexiconURL(Predicate pred) {
31 | throw new Error("!");
32 | }
33 |
34 | @Override
35 | public String verifyLanguageSpecificModelFiles(FullPipelineOptions options) {
36 | return null;
37 | }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/languages/Chinese.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.languages;
2 |
3 | import is2.lemmatizer.Lemmatizer;
4 |
5 | import java.io.File;
6 | import java.io.IOException;
7 | import java.util.Map;
8 | import java.util.regex.Pattern;
9 |
10 | import se.lth.cs.srl.corpus.Predicate;
11 | import se.lth.cs.srl.corpus.Sentence;
12 | import se.lth.cs.srl.corpus.Word;
13 | import se.lth.cs.srl.options.FullPipelineOptions;
14 | import se.lth.cs.srl.preprocessor.SimpleChineseLemmatizer;
15 | //import se.lth.cs.srl.preprocessor.tokenization.StanfordChineseSegmenterWrapper;
16 | import se.lth.cs.srl.preprocessor.tokenization.Tokenizer;
17 | import se.lth.cs.srl.util.FileExistenceVerifier;
18 |
19 | public class Chinese extends Language {
20 |
21 | private static Pattern CALSPattern=Pattern.compile("^A0|A1|A2|A3|A4$");
22 | @Override
23 | public String getCoreArgumentLabelSequence(Predicate pred,Map proposition) {
24 | Sentence sen=pred.getMySentence();
25 | StringBuilder ret=new StringBuilder();
26 | for(int i=1,size=sen.size();i the input corpus. assumed to be tokenized like CoNLL 09 data\n"
24 | + "-out the file to write output to (default out.txt)\n"
25 | + "-nopi skips the predicate identification\n"
26 | + "-tokenize implies the input is unsegmented, with one sentence per line, i.e. _not_ CoNLL09 format";
27 | }
28 |
29 | @Override
30 | protected int trySubParseArg(String[] args, int ai) {
31 | if (args[ai].equals("-out")) {
32 | ai++;
33 | output = new File(args[ai]);
34 | ai++;
35 | } else if (args[ai].equals("-test")) {
36 | ai++;
37 | input = new File(args[ai]);
38 | ai++;
39 | } else if (args[ai].equals("-nopi")) {
40 | ai++;
41 | skipPI = true;
42 | } else if (args[ai].equals("-noai")) {
43 | ai++;
44 | skipAI = true;
45 | } else if (args[ai].equals("-desegment")) { // Not printed out in the
46 | // help
47 | // (getSubUsageOptions()),
48 | // don't think it needs to.
49 | // This is only
50 | // experimental.
51 | ai++;
52 | desegment = true;
53 | skipPI = false; // This won't be regarded anyway. It's not
54 | // applicable when the initial segmentation is lost.
55 | } else if (args[ai].equals("-tokenize")) {
56 | ai++;
57 | super.loadPreprocessorWithTokenizer = true;
58 | skipPI = false; // Same as above
59 | desegment = false;
60 | } else if (args[ai].equals("-hybrid")) {
61 | hybrid = true;
62 | ai++;
63 | } else if (args[ai].equals("-external")) {
64 | external = true;
65 | ai++;
66 | }
67 | return ai;
68 | }
69 |
70 | @Override
71 | protected Class> getIntendedEntryClass() {
72 | return CompletePipeline.class;
73 | }
74 |
75 | }
76 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/pipeline/AbstractStep.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.pipeline;
2 |
3 | import java.io.IOException;
4 | import java.io.ObjectInputStream;
5 | import java.util.HashMap;
6 | import java.util.Map;
7 | import java.util.zip.ZipFile;
8 |
9 | import se.lth.cs.srl.corpus.Sentence;
10 | import uk.ac.ed.inf.srl.features.FeatureSet;
11 | import uk.ac.ed.inf.srl.ml.Model;
12 |
13 | public abstract class AbstractStep implements PipelineStep {
14 |
15 | public static final Integer POSITIVE = 1;
16 | public static final Integer NEGATIVE = 0;
17 |
18 | protected FeatureSet featureSet;
19 | protected Map models;
20 |
21 | public AbstractStep(FeatureSet fs) {
22 | this.featureSet = fs;
23 | }
24 |
25 | public abstract void extractInstances(Sentence s);
26 |
27 | public abstract void parse(Sentence s);
28 |
29 | protected abstract String getModelFileName();
30 |
31 | @Override
32 | public void readModels(ZipFile zipFile) throws IOException,
33 | ClassNotFoundException {
34 | models = new HashMap<>();
35 | readModels(zipFile, models, getModelFileName());
36 | }
37 |
38 |
39 | static void readModels(ZipFile zipFile, Map models,
40 | String filename) throws IOException, ClassNotFoundException {
41 | ObjectInputStream ois = new ObjectInputStream(
42 | zipFile.getInputStream(zipFile.getEntry(filename)));
43 | int numberOfModels = ois.readInt();
44 | for (int i = 0; i < numberOfModels; ++i) {
45 | String POSPrefix = (String) ois.readObject();
46 | Model m = (Model) ois.readObject();
47 | models.put(POSPrefix, m);
48 | }
49 | }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/pipeline/PipelineStep.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.pipeline;
2 |
3 | import java.io.IOException;
4 | import java.util.zip.ZipFile;
5 | import java.util.zip.ZipOutputStream;
6 |
7 | import se.lth.cs.srl.corpus.Sentence;
8 |
9 | public interface PipelineStep {
10 |
11 | public void prepareLearning();
12 |
13 | public void extractInstances(Sentence s);
14 |
15 | public void done();
16 |
17 | public void train();
18 |
19 | public void writeModels(ZipOutputStream zos) throws IOException;
20 |
21 | public void readModels(ZipFile zipFile) throws IOException,
22 | ClassNotFoundException;
23 |
24 | public void parse(Sentence s);
25 |
26 | public void prepareLearning(int i);
27 |
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/pipeline/Step.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.pipeline;
2 |
/**
 * Identifiers of the pipeline stages, in declaration order.
 */
public enum Step {
    /** Predicate identification (cf. the {@code -nopi} option). */
    pi,
    /** Presumably predicate disambiguation; confirm. */
    pd,
    /** Presumably argument identification (cf. the {@code -noai} option); confirm. */
    ai,
    /** Presumably argument classification; confirm. */
    ac
    // po, ao: disabled in the original source
}
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/preprocessor/CMDLineTokenizer.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.preprocessor;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.File;
5 | import java.io.IOException;
6 | import java.io.InputStreamReader;
7 |
8 | import se.lth.cs.srl.languages.Language;
9 | import se.lth.cs.srl.languages.Language.L;
10 | import se.lth.cs.srl.preprocessor.tokenization.Tokenizer;
11 |
12 | public class CMDLineTokenizer {
13 |
14 | public static void main(String[] args) throws IOException {
15 | String l = args[0];
16 | File modelFile = args.length > 1 ? new File(args[1]) : null;
17 | L lang = null;
18 | try {
19 | lang = L.valueOf(l);
20 | } catch (Exception e) {
21 | System.err.println("Unknown language " + l + ", aborting.");
22 | System.exit(1);
23 | }
24 | Language.setLanguage(lang);
25 | Tokenizer tokenizer = Language.getLanguage().getTokenizer(modelFile);
26 | BufferedReader reader = new BufferedReader(new InputStreamReader(
27 | System.in, "UTF8"));
28 | String line;
29 | int senCount = 0;
30 | while ((line = reader.readLine()) != null) {
31 | senCount++;
32 | String[] tokens = tokenizer.tokenize(line);
33 | for (int i = 1; i < tokens.length; ++i) {
34 | StringBuilder sb = new StringBuilder();
35 | sb.append(senCount).append('_').append(i).append('\t')
36 | .append(tokens[i]).append(COLUMNS);
37 | System.out.println(sb.toString());
38 | }
39 | System.out.println();
40 | }
41 | }
42 |
43 | static final String COLUMNS = "\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_";
44 |
45 | public static void usage() {
46 | System.err
47 | .println("Reads untokenized text on STDIN (one sentence per line), and writes it out in CoNLL09 format to STDOUT");
48 | System.err.println("Usage: java -cp ... "
49 | + CMDLineTokenizer.class.getCanonicalName()
50 | + " [model-file]");
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/preprocessor/Preprocessor.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.preprocessor;
2 |
3 | import se.lth.cs.srl.corpus.StringInText;
4 | import se.lth.cs.srl.preprocessor.tokenization.Tokenizer;
5 | import is2.data.SentenceData09;
6 |
7 | public abstract class Preprocessor {
8 | protected Tokenizer tokenizer;
9 | public long tokenizeTime = 0;
10 | public long lemmatizeTime = 0;
11 | public long dpTime = 0;
12 |
13 | public abstract boolean hasParser();
14 |
15 | public abstract StringBuilder getStatus();
16 |
17 | protected abstract SentenceData09 preprocess(SentenceData09 sentence);
18 |
19 | public SentenceData09 preprocess(String[] forms) {
20 | SentenceData09 instance = new SentenceData09();
21 | instance.init(forms);
22 | return preprocess(instance);
23 | }
24 |
25 | public String[] tokenize(String sentence) {
26 | synchronized (tokenizer) {
27 | long start = System.currentTimeMillis();
28 | String[] words = tokenizer.tokenize(sentence);
29 | tokenizeTime += (System.currentTimeMillis() - start);
30 | return words;
31 | }
32 | }
33 |
34 | public StringInText[] tokenizeplus(String sentence) {
35 | synchronized (tokenizer) {
36 | long start = System.currentTimeMillis();
37 | StringInText[] words = tokenizer.tokenizeplus(sentence);
38 | tokenizeTime += (System.currentTimeMillis() - start);
39 | return words;
40 | }
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/preprocessor/SimpleChineseLemmatizer.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.preprocessor;
2 |
3 | import is2.data.SentenceData09;
4 | import is2.lemmatizer.Lemmatizer;
5 |
6 | public class SimpleChineseLemmatizer extends Lemmatizer {
7 |
8 | // @Override
9 | // public String[] getLemmas(String[] forms) {
10 | // if(true)
11 | // throw new
12 | // Error("This method should not be trusted. Fix the root token in accordance with is2.lemmatizer.Lemmatizer before using it.");
13 | // String[] ret=new String[forms.length]; //TODO, make sure to deal with the
14 | // root token properly.
15 | // ret[0]="";
16 | // for(int i=1;i";
32 | withRoot[0] = is2.io.CONLLReader09.ROOT;
33 | System.arraycopy(tokens, 0, withRoot, 1, tokens.length);
34 | return withRoot;
35 | }
36 |
37 | public static OpenNLPToolsTokenizerWrapper loadOpenNLPTokenizer(
38 | File modelFile) throws IOException {
39 | BufferedInputStream modelIn = new BufferedInputStream(
40 | new FileInputStream(modelFile.toString()));
41 | opennlp.tools.tokenize.Tokenizer tokenizer = new TokenizerME(
42 | new TokenizerModel(modelIn));
43 | return new OpenNLPToolsTokenizerWrapper(tokenizer);
44 | }
45 |
46 | @Override
47 | public StringInText[] tokenizeplus(String sentence) {
48 | Reader r = new StringReader(sentence);
49 | List l = new ArrayList<>();
50 | for (String s : tokenize(sentence)) {
51 | Word w = new Word(s);
52 | l.add(new StringInText(w.word(), w.beginPosition() + startpos, w
53 | .endPosition() + startpos));
54 | }
55 | StringInText[] tok = new StringInText[l.size()];
56 | // tok[0]=new StringInText(is2.io.CONLLReader09.ROOT,0,0);
57 | int i = 0;
58 | for (StringInText s : l)
59 | tok[i++] = s;
60 |
61 | startpos += (1 + sentence.length());
62 |
63 | return tok;
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/preprocessor/tokenization/Tokenizer.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.preprocessor.tokenization;
2 |
3 | import se.lth.cs.srl.corpus.StringInText;
4 |
5 | public interface Tokenizer {
6 |
7 | /**
8 | * Tokenize a sentence. The returned array contains a root-token
9 | *
10 | * @param sentence
11 | * The sentence to tokenize
12 | * @return a root token, followed by the forms
13 | */
14 | public abstract String[] tokenize(String sentence);
15 |
16 | public abstract StringInText[] tokenizeplus(String sentence);
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/preprocessor/tokenization/WhiteSpaceTokenizer.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.preprocessor.tokenization;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 | import java.util.StringTokenizer;
6 |
7 | import se.lth.cs.srl.corpus.StringInText;
8 |
9 | public class WhiteSpaceTokenizer implements Tokenizer {
10 |
11 | @Override
12 | public String[] tokenize(String sentence) {
13 | StringTokenizer tokenizer = new StringTokenizer(sentence);
14 | String[] tokens = new String[tokenizer.countTokens() + 1];
15 | int r = 0;
16 | tokens[r++] = is2.io.CONLLReader09.ROOT;
17 | while (tokenizer.hasMoreTokens())
18 | tokens[r++] = tokenizer.nextToken();
19 | return tokens;
20 | }
21 |
22 | public static void main(String[] args) {
23 | String t1 = "En gul bil körde hundratusen mil.";
24 | String t2 = "Leonardos fullständiga namn var Leonardo di ser Piero da Vinci.";
25 | String t3 = "Genom skattereformen införs individuell beskattning (särbeskattning) av arbetsinkomster.";
26 | String t4 = "\"Oh, no,\" she's saying, \"our $400 blender can't handle something this hard!\"";
27 | String[] tests = { t1, t2, t3, t4 };
28 | for (String test : tests) {
29 | WhiteSpaceTokenizer tokenizer = new WhiteSpaceTokenizer();
30 | String[] tokens = tokenizer.tokenize(test);
31 | for (String token : tokens)
32 | System.out.println(token);
33 | System.out.println();
34 | }
35 | }
36 |
37 | @Override
38 | public StringInText[] tokenizeplus(String sentence) {
39 | int offset = 0;
40 | StringTokenizer tokenizer = new StringTokenizer(sentence);
41 | List l = new ArrayList<>();
42 |
43 | while (tokenizer.hasMoreTokens()) {
44 | String s = tokenizer.nextToken();
45 | l.add(new StringInText(s, offset, offset + s.length()));
46 | offset += (1 + s.length());
47 | }
48 | StringInText[] tok = new StringInText[l.size() + 1];
49 | tok[0] = new StringInText(is2.io.CONLLReader09.ROOT, 0, 0);
50 | int i = 1;
51 | for (StringInText s : l)
52 | tok[i++] = s;
53 |
54 | return tok;
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/preprocessor/tokenization/exner/SwedishTokenizer.java:
--------------------------------------------------------------------------------
1 | /**
2 | * SweNLP is a framework for performing parallel processing of text.
3 | * Copyright © 2011 Peter Exner
4 | *
5 | * This file is part of SweNLP.
6 | *
7 | * SweNLP is free software: you can redistribute it and/or modify
8 | * it under the terms of the GNU General Public License as published by
9 | * the Free Software Foundation, either version 3 of the License, or
10 | * (at your option) any later version.
11 | *
12 | * SweNLP is distributed in the hope that it will be useful,
13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | * GNU General Public License for more details.
16 | *
17 | * You should have received a copy of the GNU General Public License
18 | * along with SweNLP. If not, see <http://www.gnu.org/licenses/>.
19 | */
20 |
21 | package se.lth.cs.srl.preprocessor.tokenization.exner;
22 |
23 | import java.io.IOException;
24 | import java.io.StringReader;
25 | import java.io.UnsupportedEncodingException;
26 | import java.nio.charset.Charset;
27 | import java.util.ArrayList;
28 |
29 | public class SwedishTokenizer {
30 | public ArrayList tokenize(String sentence, Charset charset) {
31 | ArrayList tokens = new ArrayList<>();
32 |
33 | try {
34 | sentence = preProcessSentence(sentence);
35 |
36 | Tokenizer swedishTokenizer = new Tokenizer(new StringReader(
37 | new String(sentence.getBytes(charset),
38 | Charset.forName("UTF-8"))));
39 |
40 | while (swedishTokenizer.getNextToken() >= 0) {
41 | tokens.add(swedishTokenizer.yytext());
42 | }
43 |
44 | } catch (UnsupportedEncodingException e) {
45 | // TODO Auto-generated catch block
46 | e.printStackTrace();
47 | } catch (IOException e) {
48 | // TODO Auto-generated catch block
49 | e.printStackTrace();
50 | }
51 |
52 | return tokens;
53 | }
54 |
55 | private static String preProcessSentence(String sentence) {
56 | // Done to make the flex rules match
57 |
58 | sentence = sentence.replaceAll(" \\.", "\\.");
59 | sentence = sentence.replaceAll(" \\)", "\\)");
60 |
61 | return sentence;
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/util/BohnetHelper.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.util;
2 |
3 | import is2.lemmatizer.Lemmatizer;
4 | import is2.parser.Parser;
5 | import is2.tag.Tagger;
6 |
7 | import java.io.File;
8 | import java.io.FileNotFoundException;
9 | import java.io.IOException;
10 |
11 | import se.lth.cs.srl.options.FullPipelineOptions;
12 | import se.lth.cs.srl.options.Options;
13 |
14 | public class BohnetHelper {
15 |
16 | public static Lemmatizer getLemmatizer(File modelFile)
17 | throws FileNotFoundException, IOException {
18 | String[] argsL = { "-model", modelFile.toString() };
19 | return new Lemmatizer(modelFile.toString(), false);
20 |
21 | // new is2.lemmatizer.Options(argsL));
22 | }
23 |
24 | public static Tagger getTagger(File modelFile) {
25 | String[] argsT = { "-model", modelFile.toString() };
26 | return new Tagger(modelFile.toString());
27 | // new is2.tag.Options(argsT));
28 | }
29 |
30 | public static is2.mtag.Tagger getMTagger(File modelFile) throws IOException {
31 | String[] argsMT = { "-model", modelFile.toString() };
32 | return new is2.mtag.Tagger(modelFile.toString());
33 | // new is2.mtag.Options(argsMT));
34 | }
35 |
36 | public static Parser getParser(File modelFile) {
37 | String[] argsDP = {
38 | "-model",
39 | modelFile.toString(),
40 | "-cores",
41 | Integer.toString(Math.min(Options.cores,
42 | FullPipelineOptions.cores)) };
43 | return new Parser(modelFile.toString());
44 | // new is2.parser.Options(argsDP));
45 | }
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/util/ChineseDesegmenter.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.util;
2 |
3 | import java.io.File;
4 | import java.io.FileNotFoundException;
5 | import java.io.FileOutputStream;
6 | import java.io.PrintStream;
7 |
8 | import se.lth.cs.srl.corpus.Sentence;
9 | import se.lth.cs.srl.io.DepsOnlyCoNLL09Reader;
10 | import se.lth.cs.srl.io.SentenceReader;
11 |
12 | public class ChineseDesegmenter {
13 |
14 | public static String desegment(String[] forms) {
15 | StringBuilder ret = new StringBuilder();
16 | for (int i = 1; i < forms.length; ++i)
17 | ret.append(forms[i]);
18 | return ret.toString();
19 | }
20 |
21 | public static void main(String[] args) throws FileNotFoundException {
22 |
23 | String inputFile =
24 | // "/home/anders/corpora/conll09/CoNLL2009-ST-Chinese-train.txt";
25 | "/home/anders/corpora/conll09/chi/CoNLL2009-ST-evaluation-Chinese.txt";
26 | String outputFile = "chi-desegmented.out";
27 | boolean separateLines = true;
28 | if (args.length > 0)
29 | inputFile = args[0];
30 | if (args.length > 1)
31 | outputFile = args[1];
32 | if (args.length > 2)
33 | separateLines = Boolean.parseBoolean(args[2]); // Whether to print
34 | // newlines between
35 | // sentences.
36 | File input = new File(inputFile);
37 | File output = new File(outputFile);
38 | SentenceReader reader = new DepsOnlyCoNLL09Reader(input);
39 | PrintStream out = new PrintStream(new FileOutputStream(output));
40 | for (Sentence s : reader) {
41 | String desegmented = desegment(s.getFormArray());
42 | if (separateLines)
43 | out.println(desegmented);
44 | else
45 | out.print(desegmented);
46 | }
47 | }
48 |
49 | }
50 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/util/FileExistenceVerifier.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.util;
2 |
3 | import java.io.File;
4 |
5 | import se.lth.cs.srl.languages.Language;
6 | import se.lth.cs.srl.options.FullPipelineOptions;
7 |
8 | public class FileExistenceVerifier {
9 |
10 | /**
11 | * Checks if a file can be exists, and if it can be read.
12 | *
13 | * @param files
14 | * vararg number of files to check
15 | * @return null if all is good, otherwise a String containing the error
16 | * message.
17 | */
18 | public static String verifyFiles(File... files) {
19 | StringBuilder sb = new StringBuilder();
20 | for (File f : files) {
21 | if (f == null || !f.exists()) {
22 | sb.append("File " + f + " does not exist.\n");
23 | }
24 | if (f == null || !f.canRead()) {
25 | sb.append("File " + f + " can not be read.\n");
26 | }
27 | }
28 | if (sb.length() == 0)
29 | return null;
30 | else
31 | return sb.toString();
32 | }
33 |
34 | public static String verifyCompletePipelineAlwaysNecessaryFiles(
35 | FullPipelineOptions options) {
36 | return verifyFiles(/* options.tagger, options.parser,*/ options.srl);
37 | }
38 |
39 | public static String verifyCompletePipelineAllNecessaryModelFiles(
40 | FullPipelineOptions options) {
41 | String error1 = verifyCompletePipelineAlwaysNecessaryFiles(options);
42 | String error2 = Language.getLanguage()
43 | .verifyLanguageSpecificModelFiles(options);
44 | if (error1 != null) {
45 | if (error2 != null) {
46 | return error1 + error2;
47 | } else {
48 | return error1;
49 | }
50 | } else if (error2 != null) {
51 | return error2;
52 | } else {
53 | return null;
54 | }
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/util/Relation.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.util;
2 |
/**
 * A labelled dependency arc: {@code head} -> {@code dependent} with
 * {@code label} (the ints are presumably token indices; confirm with callers).
 */
public class Relation {
    /** Head of the arc. */
    public int head;
    /** Dependent of the arc. */
    public int dependent;
    /** Dependency label. */
    public String label;

    public Relation(int headIndex, int dependentIndex, String relLabel) {
        label = relLabel;
        dependent = dependentIndex;
        head = headIndex;
    }
}
14 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/util/TurboParser.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.util;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.FileReader;
5 | import java.io.IOException;
6 |
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 |
10 | public class TurboParser {
11 |
12 | BufferedReader br;
13 |
14 | public TurboParser(String file) {
15 | try {
16 | br = new BufferedReader(new FileReader(file));
17 |
18 | } catch (IOException e) {
19 | e.printStackTrace();
20 | System.exit(1);
21 | }
22 | }
23 |
24 | public void overwriteParse(Sentence s) {
25 | try {
26 | // skip ROOT (i==0);
27 | for (int i = 1; i < s.size(); i++) {
28 | Word w = s.get(i);
29 | String line = br.readLine();
30 | // if current line is blank (end of last sentence), read next
31 | // line
32 | if (line.equals(""))
33 | line = br.readLine();
34 |
35 | String[] parts = line.split("\t");
36 | // sanity check
37 | if (!parts[1].toLowerCase().equals(w.getForm().toLowerCase())) {
38 | System.err
39 | .println("WARNING: different normalization applied? ("
40 | + parts[1] + " vs. " + w.getForm() + ")");
41 | w.setLemma(w.getForm().replaceAll("[0-9]", "D"));
42 | }
43 |
44 | // CoNLL-X
45 | /**/w.setPOS(parts[3]);
46 | w.setHeadId(Integer.parseInt(parts[6]));
47 | w.setDeprel(parts[7]);/**/
48 |
49 | // CoNLL-09
50 | /**
51 | * w.setPOS(parts[4]); w.setHeadId(Integer.parseInt(parts[8]));
52 | * w.setDeprel(parts[10]);/
53 | **/
54 |
55 | }
56 | s.buildDependencyTree();
57 |
58 | } catch (IOException e) {
59 | e.printStackTrace();
60 | System.exit(1);
61 | }
62 | }
63 |
64 | }
65 |
--------------------------------------------------------------------------------
/src/main/java/se/lth/cs/srl/util/Util.java:
--------------------------------------------------------------------------------
1 | package se.lth.cs.srl.util;
2 |
/** Small string-formatting helpers. */
public class Util {

    /**
     * Formats {@code l} in decimal with a comma between groups of three
     * digits, e.g. {@code 1234567 -> "1,234,567"} and
     * {@code -123456 -> "-123,456"}.
     *
     * @param l any long, including negatives and {@link Long#MIN_VALUE}
     * @return the grouped decimal representation
     */
    public static String insertCommas(long l) {
        String digits = Long.toString(l);
        // Group only the digits. The old reverse-and-insert version counted the
        // minus sign as a digit and produced e.g. "-,123" for -123.
        String sign = "";
        if (digits.startsWith("-")) {
            sign = "-";
            digits = digits.substring(1);
        }
        int lead = digits.length() % 3;
        if (lead == 0) {
            lead = 3;
        }
        StringBuilder ret = new StringBuilder(
                sign.length() + digits.length() + digits.length() / 3);
        ret.append(sign).append(digits, 0, lead);
        for (int i = lead; i < digits.length(); i += 3) {
            ret.append(',').append(digits, i, i + 3);
        }
        return ret.toString();
    }

}
16 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/AnySetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import se.lth.cs.srl.corpus.Word.WordData;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.SetFeature;
9 |
10 | public class AnySetFeature extends SetFeature {
11 | private static final long serialVersionUID = 1L;
12 |
13 | WordData attr;
14 |
15 | protected AnySetFeature(FeatureName name, WordData attr,
16 | boolean usedForPredicateIdentification, String POSPrefix) {
17 | super(name, true, false, POSPrefix);
18 | this.attr = attr;
19 | }
20 |
21 | @Override
22 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
23 | return makeFeatureStrings(s.get(argIndex));
24 | }
25 |
26 | @Override
27 | public String[] getFeatureStrings(Predicate pred, Word arg) {
28 | return makeFeatureStrings(arg);
29 | }
30 |
31 | private String[] makeFeatureStrings(Word w) {
32 | String[] ret = new String[w.getSpan().size()];
33 | int i = 0;
34 | for (Word child : w.getChildren())
35 | ret[i++] = child.getAttr(attr);
36 | return ret;
37 | }
38 |
39 | @Override
40 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
41 | for (Word child : s)
42 | addMap(child.getAttr(attr));
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ArgDependentAttrFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import se.lth.cs.srl.corpus.Word.WordData;
10 | import uk.ac.ed.inf.srl.features.AttrFeature;
11 | import uk.ac.ed.inf.srl.features.FeatureName;
12 | import uk.ac.ed.inf.srl.features.TargetWord;
13 |
14 | public class ArgDependentAttrFeature extends AttrFeature {
15 | private static final long serialVersionUID = 1L;
16 |
17 | protected ArgDependentAttrFeature(FeatureName name, WordData attr,
18 | TargetWord tw, String POSPrefix) {
19 | super(name, attr, tw, true, false, POSPrefix);
20 | }
21 |
22 | @Override
23 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
24 | for (Predicate p : s.getPredicates()) {
25 | if (doExtractFeatures(p))
26 | for (Word arg : p.getArgMap().keySet()) {
27 | for (Word w : arg.getSpan())
28 | addMap(w.getAttr(attr));
29 | }
30 | }
31 | }
32 |
33 | @Override
34 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
35 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
36 | if (w == null)
37 | return null;
38 | else
39 | return w.getAttr(attr);
40 | }
41 |
42 | @Override
43 | public String getFeatureString(Predicate pred, Word arg) {
44 | Word w = wordExtractor.getWord(pred, arg);
45 | if (w == null)
46 | return null;
47 | else
48 | return w.getAttr(attr);
49 | }
50 |
51 | @Override
52 | public void addFeatures(Sentence s, Collection indices,
53 | Map nonbinFeats, int predIndex, int argIndex,
54 | Integer offset, boolean allWords) {
55 | addFeatures(indices, getFeatureString(s, predIndex, argIndex), offset,
56 | allWords);
57 |
58 | }
59 |
60 | @Override
61 | public void addFeatures(Collection indices,
62 | Map nonbinFeats, Predicate pred, Word arg,
63 | Integer offset, boolean allWords) {
64 | addFeatures(indices, getFeatureString(pred, arg), offset, allWords);
65 | }
66 |
67 | private void addFeatures(Collection indices, String featureString,
68 | Integer offset, boolean allWords) {
69 | if (featureString == null)
70 | return;
71 | Integer i = indexOf(featureString);
72 | if (i != -1 && (allWords || i < predMaxIndex))
73 | indices.add(i + offset);
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ArgDependentBrown.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.ArgDependentAttrFeature;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.TargetWord;
9 | import se.lth.cs.srl.util.BrownCluster;
10 | import se.lth.cs.srl.util.BrownCluster.ClusterVal;
11 |
12 | public class ArgDependentBrown extends ArgDependentAttrFeature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | private BrownCluster bc;
16 | private ClusterVal cv;
17 |
18 | protected ArgDependentBrown(FeatureName name, TargetWord tw,
19 | String POSPrefix, BrownCluster bc, ClusterVal cv) {
20 | super(name, null, tw, POSPrefix);
21 | this.bc = bc;
22 | this.cv = cv;
23 | }
24 |
25 | @Override
26 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
27 | for (Predicate p : s.getPredicates()) {
28 | if (doExtractFeatures(p))
29 | for (Word arg : p.getArgMap().keySet()) {
30 | Word w = wordExtractor.getWord(null, arg);
31 | if (w != null)
32 | addMap(bc.getValue(w.getForm(), cv));
33 | }
34 | }
35 | }
36 |
37 | @Override
38 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
39 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
40 | if (w == null)
41 | return null;
42 | else
43 | return bc.getValue(w.getForm(), cv);
44 |
45 | }
46 |
47 | @Override
48 | public String getFeatureString(Predicate pred, Word arg) {
49 | Word w = wordExtractor.getWord(pred, arg);
50 | if (w == null)
51 | return null;
52 | else
53 | return bc.getValue(w.getForm(), cv);
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ArgDependentFeatsFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.FeatsFeature;
10 | import uk.ac.ed.inf.srl.features.FeatureName;
11 | import uk.ac.ed.inf.srl.features.TargetWord;
12 | import se.lth.cs.srl.languages.Language;
13 |
14 | public class ArgDependentFeatsFeature extends FeatsFeature {
15 | private static final long serialVersionUID = 1L;
16 |
17 | protected ArgDependentFeatsFeature(FeatureName name, TargetWord tw,
18 | String POSPrefix) {
19 | super(name, tw, true, false, POSPrefix);
20 | }
21 |
22 | @Override
23 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
24 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
25 | if (w == null)
26 | return null;
27 | else
28 | return Language.getLanguage().getFeatSplitPattern()
29 | .split(w.getFeats());
30 | }
31 |
32 | @Override
33 | public String[] getFeatureStrings(Predicate pred, Word arg) {
34 | Word w = wordExtractor.getWord(pred, arg);
35 | if (w == null)
36 | return null;
37 | else
38 | return Language.getLanguage().getFeatSplitPattern()
39 | .split(w.getFeats());
40 | }
41 |
42 | @Override
43 | public void addFeatures(Sentence s, Collection indices,
44 | Map nonbinFeats, int predIndex, int argIndex,
45 | Integer offset, boolean allWords) {
46 | addFeatures(indices, getFeatureStrings(s, predIndex, argIndex), offset,
47 | allWords);
48 |
49 | }
50 |
51 | @Override
52 | public void addFeatures(Collection indices,
53 | Map nonbinFeats, Predicate pred, Word arg,
54 | Integer offset, boolean allWords) {
55 | addFeatures(indices, getFeatureStrings(pred, arg), offset, allWords);
56 | }
57 |
58 | private void addFeatures(Collection indices, String[] values,
59 | Integer offset, boolean allWords) {
60 | if (values == null)
61 | return;
62 | for (String v : values) {
63 | Integer i = indexOf(v);
64 | if (i != -1 && (allWords || i < predMaxIndex))
65 | indices.add(i + offset);
66 | }
67 | }
68 |
69 | @Override
70 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
71 | for (Predicate p : s.getPredicates()) {
72 | if (doExtractFeatures(p))
73 | for (Word arg : p.getArgMap().keySet()) {
74 | Word w = wordExtractor.getWord(null, arg);
75 | if (w == null)
76 | continue;
77 | for (String v : Language.getLanguage()
78 | .getFeatSplitPattern().split(w.getFeats()))
79 | addMap(v);
80 | }
81 | }
82 |
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/AttrFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Word.WordData;
4 | import uk.ac.ed.inf.srl.features.FeatureName;
5 | import uk.ac.ed.inf.srl.features.SingleFeature;
6 | import uk.ac.ed.inf.srl.features.TargetWord;
7 | import uk.ac.ed.inf.srl.features.WordExtractor;
8 |
9 | public abstract class AttrFeature extends SingleFeature {
10 | private static final long serialVersionUID = 1;
11 |
12 | protected WordData attr;
13 | protected WordExtractor wordExtractor;
14 |
15 | protected AttrFeature(FeatureName name, WordData attr, TargetWord tw,
16 | boolean includeArgs, boolean usedForPredicateIdentification,
17 | String POSPrefix) {
18 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix);
19 | this.attr = attr;
20 | this.wordExtractor = WordExtractor.getExtractor(tw);
21 | }
22 |
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ChildSetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import se.lth.cs.srl.corpus.Word.WordData;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.SetFeature;
9 |
10 | public class ChildSetFeature extends SetFeature {
11 | private static final long serialVersionUID = 1L;
12 |
13 | WordData attr;
14 |
15 | protected ChildSetFeature(FeatureName name, WordData attr,
16 | boolean usedForPredicateIdentification, String POSPrefix) {
17 | super(name, false, usedForPredicateIdentification, POSPrefix);
18 | this.attr = attr;
19 | }
20 |
21 | @Override
22 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
23 | return makeFeatureStrings(s.get(predIndex));
24 | }
25 |
26 | @Override
27 | public String[] getFeatureStrings(Predicate pred, Word arg) {
28 | return makeFeatureStrings(pred);
29 | }
30 |
31 | private String[] makeFeatureStrings(Word pred) {
32 | String[] ret = new String[pred.getChildren().size()];
33 | int i = 0;
34 | for (Word child : pred.getChildren())
35 | ret[i++] = child.getAttr(attr);
36 | return ret;
37 | }
38 |
39 | @Override
40 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
41 | if (allWords) {
42 | for (int i = 1, size = s.size(); i < size; ++i) {
43 | if (doExtractFeatures(s.get(i)))
44 | for (Word child : s.get(i).getChildren()) {
45 | addMap(child.getAttr(attr));
46 | }
47 | }
48 | } else {
49 | for (Predicate pred : s.getPredicates()) {
50 | if (doExtractFeatures(pred))
51 | for (Word child : pred.getChildren())
52 | addMap(child.getAttr(attr));
53 | }
54 | }
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ContinuousArgDependentAttrFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import se.lth.cs.srl.corpus.Word.WordData;
10 | import uk.ac.ed.inf.srl.features.ContinuousAttrFeature;
11 | import uk.ac.ed.inf.srl.features.FeatureName;
12 | import uk.ac.ed.inf.srl.features.TargetWord;
13 |
14 | public class ContinuousArgDependentAttrFeature extends ContinuousAttrFeature {
15 | private static final long serialVersionUID = 1L;
16 |
17 | protected ContinuousArgDependentAttrFeature(FeatureName name,
18 | WordData attr, TargetWord tw, String POSPrefix) {
19 | super(name, attr, tw, true, false, POSPrefix);
20 | }
21 |
22 | @Override
23 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
24 | // for(Predicate p:s.getPredicates()){
25 | // if(doExtractFeatures(p))
26 | // for(Word arg:p.getArgMap().keySet()){
27 | // Word w=wordExtractor.getWord(null, arg);
28 | // if(w!=null)
29 | // addMap(w.getAttr(attr));
30 | // }
31 | // }
32 | }
33 |
34 | @Override
35 | public Double getFeatureValue(Sentence s, int predIndex, int argIndex) {
36 | return 0.0;
37 | }
38 |
39 | @Override
40 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
41 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
42 | if (w == null)
43 | return null;
44 | else
45 | return w.getAttr(attr);
46 | }
47 |
48 | @Override
49 | public Double getFeatureValue(Predicate pred, Word arg) {
50 | return 0.0;
51 | }
52 |
53 | @Override
54 | public String getFeatureString(Predicate pred, Word arg) {
55 | Word w = wordExtractor.getWord(pred, arg);
56 | if (w == null)
57 | return null;
58 | else
59 | return w.getAttr(attr);
60 | }
61 |
62 | @Override
63 | public void addFeatures(Sentence s, Collection indices,
64 | Map nonbinFeats, int predIndex, int argIndex,
65 | Integer offset, boolean allWords) {
66 | addFeatures(nonbinFeats, getFeatureString(s, predIndex, argIndex),
67 | getFeatureValue(s, predIndex, argIndex), offset, allWords);
68 |
69 | }
70 |
71 | @Override
72 | public void addFeatures(Collection indices,
73 | Map nonbinFeats, Predicate pred, Word arg,
74 | Integer offset, boolean allWords) {
75 | addFeatures(nonbinFeats, getFeatureString(pred, arg),
76 | getFeatureValue(pred, arg), offset, allWords);
77 | }
78 |
79 | private void addFeatures(Map nonbinFeats,
80 | String featureString, Double val, Integer offset, boolean allWords) {
81 | if (featureString == null)
82 | return;
83 | Integer i = indexOf(featureString);
84 | if (i != -1 && (allWords || i < predMaxIndex))
85 | nonbinFeats.put(i + offset, val);
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ContinuousAttrFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Word.WordData;
4 | import uk.ac.ed.inf.srl.features.ContinuousFeature;
5 | import uk.ac.ed.inf.srl.features.FeatureName;
6 | import uk.ac.ed.inf.srl.features.TargetWord;
7 | import uk.ac.ed.inf.srl.features.WordExtractor;
8 |
9 | public abstract class ContinuousAttrFeature extends ContinuousFeature {
10 | private static final long serialVersionUID = 1;
11 |
12 | protected WordData attr;
13 | protected WordExtractor wordExtractor;
14 |
15 | protected ContinuousAttrFeature(FeatureName name, WordData attr,
16 | TargetWord tw, boolean includeArgs,
17 | boolean usedForPredicateIdentification, String POSPrefix) {
18 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix);
19 | this.attr = attr;
20 | this.wordExtractor = WordExtractor.getExtractor(tw);
21 | }
22 |
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ContinuousFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.Feature;
10 | import uk.ac.ed.inf.srl.features.FeatureName;
11 |
12 | public abstract class ContinuousFeature extends Feature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | protected ContinuousFeature(FeatureName name, boolean includeArgs,
16 | boolean usedForPredicateIdentification, String POSPrefix) {
17 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix);
18 | indexcounter = 2;
19 | }
20 |
21 | @Override
22 | public void addFeatures(Sentence s, Collection indices,
23 | Map nonbinFeats, int predIndex, int argIndex,
24 | Integer offset, boolean allWords) {
25 | // Integer i=indexOf(getFeatureString(s,predIndex,argIndex));
26 | Double d = getFeatureValue(s, predIndex, argIndex);
27 |
28 | // if(i!=-1 && (allWords || i indices,
35 | Map nonbinFeats, Predicate pred, Word arg,
36 | Integer offset, boolean allWords) {
37 | // Integer i=indexOf(getFeatureString(pred,arg));
38 | Double d = getFeatureValue(pred, arg);
39 |
40 | // if(i!=-1 && (allWords || i children) {
33 | switch (children.size()) {
34 | case 0:
35 | return " "; // This is the String corresponding to 0 children. Yet
36 | // this should not be ignored as a feature which it
37 | // would by the addMap() method if it would return the
38 | // empty string or null. Not really sure if this is
39 | // optimal, or this feature should be ignored.
40 | case 1:
41 | return children.iterator().next().getDeprel();
42 | default:
43 | Word[] sortedChildren = children.toArray(new Word[0]);
44 | Arrays.sort(sortedChildren, s.wordComparator);
45 | StringBuilder ret = new StringBuilder(sortedChildren[0].getDeprel());
46 | for (int i = 1, size = sortedChildren.length; i < size; ++i)
47 | ret.append(SEPARATOR).append(sortedChildren[i].getDeprel());
48 | return ret.toString();
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/DependencyCPathEmbedding.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.DependencyPathEmbedding;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.TargetWord;
9 | import uk.ac.ed.inf.srl.lstm.DataConverter;
10 | import uk.ac.ed.inf.srl.lstm.EmbeddingNetwork;
11 |
12 | public class DependencyCPathEmbedding extends DependencyPathEmbedding {
13 | private static final long serialVersionUID = 1L;
14 |
15 | protected DependencyCPathEmbedding(FeatureName name, TargetWord tw,
16 | String POSPrefix, boolean comp, EmbeddingNetwork net, DataConverter dc,
17 | int dim) {
18 | super(name, tw, POSPrefix, comp, net, dc, dim);
19 |
20 | for(int i=0; i values = new ArrayList<>();
24 | values.addAll(f.indices.keySet());
25 | Collections.sort(values);
26 | for (String value : values) {
27 | System.out.println(value + " - " + f.indexOf(value));
28 | }
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/FeatsFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import uk.ac.ed.inf.srl.features.FeatureName;
4 | import uk.ac.ed.inf.srl.features.SetFeature;
5 | import uk.ac.ed.inf.srl.features.TargetWord;
6 | import uk.ac.ed.inf.srl.features.WordExtractor;
7 |
8 | public abstract class FeatsFeature extends SetFeature {
9 | private static final long serialVersionUID = 1L;
10 |
11 | protected WordExtractor wordExtractor;
12 |
13 | protected FeatsFeature(FeatureName name, TargetWord tw,
14 | boolean includeArgs, boolean usedForPredicateIdentification,
15 | String POSPrefix) {
16 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix);
17 | wordExtractor = WordExtractor.getExtractor(tw);
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/FeatureSet.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Arrays;
4 | import java.util.HashMap;
5 | import java.util.List;
6 | import java.util.Map;
7 |
8 | import uk.ac.ed.inf.srl.features.Feature;
9 |
10 | /**
11 | * A feature set is basically a Map>, where String is a
12 | * prefix of a POS-tag, and List is the features used by the classifier
13 | * for that prefix.
14 | *
15 | * @author anders bjorkelund
16 | *
17 | */
18 |
19 | public class FeatureSet extends HashMap> {
20 | private static final long serialVersionUID = 1L;
21 | /**
22 | * The prefixes of this Map, sorted in reverse order, i.e. longer prefixes
23 | * go before shorter. And the empty string comes last.
24 | */
25 | public final String[] POSPrefixes;
26 |
27 | /**
28 | * Constructs a featureset based on the argument. The POSPrefixes are
29 | * extracted are sorted in reverse order on construction.
30 | *
31 | * @param featureSet
32 | * the featureset
33 | */
34 | public FeatureSet(Map> featureSet) {
35 | super(featureSet);
36 | POSPrefixes = this.keySet().toArray(new String[0]);
37 | // Arrays.sort(POSPrefixes,Collections.reverseOrder());
38 | Arrays.sort(POSPrefixes);
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/NumFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
/**
 * Coarse binning of integer quantities (such as path lengths) into string
 * labels. Values of magnitude 5..9, 10..19 and 20 or more collapse to +/-5,
 * +/-10 and +/-20 respectively; values in -4..4 are returned unchanged.
 */
public class NumFeature {
	/** Bucket edges, checked from the widest bucket down. */
	private static final int[] EDGES = { 20, 10, 5 };

	/**
	 * Maps {@code i} to its bucket label.
	 *
	 * @param i
	 *            the value to bin
	 * @return the decimal label of the bucket {@code i} falls into
	 */
	public static String bin(int i) {
		for (int edge : EDGES)
			if (i <= -edge)
				return Integer.toString(-edge);
		for (int edge : EDGES)
			if (i >= edge)
				return Integer.toString(edge);
		return Integer.toString(i);
	}
}
22 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PBLabelFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.FeatureName;
10 | import uk.ac.ed.inf.srl.features.SingleFeature;
11 | import uk.ac.ed.inf.srl.features.TargetWord;
12 |
13 | public class PBLabelFeature extends SingleFeature {
14 | private static final long serialVersionUID = 1L;
15 |
16 | public PBLabelFeature(FeatureName name, TargetWord tw,
17 | boolean usedForPredicateIdentification, String POSPrefix) {
18 | super(name, false, usedForPredicateIdentification, POSPrefix);
19 | }
20 |
21 | @Override
22 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
23 | for (Predicate p : s.getPredicates()) {
24 | if (doExtractFeatures(p))
25 | for (Word arg : p.getArgMap().keySet()) {
26 | addMap(p.getArgMap().get(arg));
27 | }
28 | }
29 | }
30 |
31 | @Override
32 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
33 | return ((Predicate) s.get(predIndex)).getArgMap().get(/*
34 | * wordExtractor.getWord
35 | * (s, predIndex,
36 | * argIndex)
37 | */s.get(argIndex));
38 | }
39 |
40 | @Override
41 | public String getFeatureString(Predicate pred, Word arg) {
42 | return pred.getArgMap().get(/* wordExtractor.getWord(pred, */arg);
43 | }
44 |
45 | @Override
46 | public void addFeatures(Sentence s, Collection indices,
47 | Map nonbinFeats, int predIndex, int argIndex,
48 | Integer offset, boolean allWords) {
49 | addFeatures(indices, getFeatureString(s, predIndex, argIndex), offset,
50 | allWords);
51 | }
52 |
53 | @Override
54 | public void addFeatures(Collection indices,
55 | Map nonbinFeats, Predicate pred, Word arg,
56 | Integer offset, boolean allWords) {
57 | addFeatures(indices, getFeatureString(pred, arg), offset, allWords);
58 | }
59 |
60 | private void addFeatures(Collection indices, String featureString,
61 | Integer offset, boolean allWords) {
62 | if (featureString == null) {
63 | return;
64 | }
65 | Integer i = indexOf(featureString);
66 | if (i != -1 && (allWords || i < predMaxIndex))
67 | indices.add(i + offset);
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PathFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.List;
4 |
5 | import se.lth.cs.srl.corpus.Predicate;
6 | import se.lth.cs.srl.corpus.Sentence;
7 | import se.lth.cs.srl.corpus.Word;
8 | import se.lth.cs.srl.corpus.Word.WordData;
9 | import uk.ac.ed.inf.srl.features.FeatureName;
10 | import uk.ac.ed.inf.srl.features.SingleFeature;
11 |
12 | public class PathFeature extends SingleFeature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | private static final String UP = "0";
16 | private static final String DOWN = "1";
17 |
18 | private boolean consider_deptree;
19 | private WordData attr;
20 |
21 | protected PathFeature(FeatureName name, WordData attr,
22 | boolean consider_deptree, String POSPrefix) {
23 | super(name, true, false, POSPrefix);
24 | this.attr = attr;
25 | this.consider_deptree = consider_deptree;
26 | }
27 |
28 | @Override
29 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
30 | for (Predicate pred : s.getPredicates()) {
31 | if (doExtractFeatures(pred))
32 | for (Word arg : pred.getArgMap().keySet()) {
33 | addMap(getFeatureString(pred, arg));
34 | }
35 | }
36 | }
37 |
38 | @Override
39 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
40 | return makeFeatureString(s.get(predIndex), s.get(argIndex));
41 | }
42 |
43 | @Override
44 | public String getFeatureString(Predicate pred, Word arg) {
45 | return makeFeatureString(pred, arg);
46 | }
47 |
48 | public String makeFeatureString(Word pred, Word arg) {
49 | StringBuilder ret = new StringBuilder();
50 | if (consider_deptree) {
51 | return ret.append(makeDepBasedFeatureString(pred, arg)).toString();
52 | }
53 |
54 | ret.append("NODEP");
55 | boolean up = true;
56 | if (pred.getIdx() < arg.getIdx())
57 | up = false;
58 |
59 | if (Math.abs(pred.getIdx() - arg.getIdx()) == 0)
60 | return " ";
61 |
62 | Sentence s = pred.getMySentence();
63 | for (int i = up ? arg.getIdx() : pred.getIdx(); i < (up ? pred.getIdx()
64 | : arg.getIdx()); i++) {
65 | ret.append(s.get(i).getAttr(attr));
66 | ret.append(up ? UP : DOWN);
67 | }
68 | return ret.toString();
69 | }
70 |
71 | public String makeDepBasedFeatureString(Word pred, Word arg) {
72 | boolean up = true;
73 | List path = Word.findPath(pred, arg);
74 | if (path.size() == 0)
75 | return " ";
76 | StringBuilder ret = new StringBuilder();
77 | for (int i = 0; i < (path.size() - 1); ++i) {
78 | Word w = path.get(i);
79 | ret.append(w.getAttr(attr));
80 | if (up) {
81 | if (w.getHead() == path.get(i + 1)) { // Arrow up
82 | ret.append(UP);
83 | } else { // Arrow down
84 | ret.append(DOWN);
85 | up = false;
86 | }
87 | } else {
88 | ret.append(DOWN);
89 | }
90 | }
91 | return ret.toString();
92 |
93 | }
94 |
95 | }
96 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PathLengthFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.List;
4 |
5 | import se.lth.cs.srl.corpus.Predicate;
6 | import se.lth.cs.srl.corpus.Sentence;
7 | import se.lth.cs.srl.corpus.Word;
8 | import se.lth.cs.srl.corpus.Word.WordData;
9 | import uk.ac.ed.inf.srl.features.FeatureName;
10 | import uk.ac.ed.inf.srl.features.NumFeature;
11 | import uk.ac.ed.inf.srl.features.SingleFeature;
12 |
13 | public class PathLengthFeature extends SingleFeature {
14 | private static final long serialVersionUID = 1L;
15 |
16 | protected PathLengthFeature(FeatureName name, WordData attr,
17 | boolean consider_deptree, String POSPrefix) {
18 | super(name, true, false, POSPrefix);
19 | }
20 |
21 | @Override
22 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
23 | for (Predicate pred : s.getPredicates()) {
24 | if (doExtractFeatures(pred))
25 | for (Word arg : pred.getArgMap().keySet()) {
26 | addMap(getFeatureString(pred, arg));
27 | }
28 | }
29 | }
30 |
31 | @Override
32 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
33 | return makeFeatureString(s.get(predIndex), s.get(argIndex));
34 | }
35 |
36 | @Override
37 | public String getFeatureString(Predicate pred, Word arg) {
38 | return makeFeatureString(pred, arg);
39 | }
40 |
41 | public String makeFeatureString(Word pred, Word arg) {
42 | List path = Word.findPath(pred, arg);
43 | return "PathLength" + NumFeature.bin(path.size());
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PositionFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.FeatureName;
10 | import uk.ac.ed.inf.srl.features.SingleFeature;
11 |
12 | public class PositionFeature extends SingleFeature {
13 | private static final long serialVersionUID = 1L;
14 | /*
15 | * Position is the position of the argument wrt the predicate. I.e. if the
16 | * predicate is at position 2, and the argument at position 4, their
17 | * relation is AFTER
18 | */
19 | public static final String BEFORE = "B";
20 | public static final String ON = "O";
21 | public static final String AFTER = "A";
22 |
23 | protected PositionFeature(String POSPrefix) {
24 | super(FeatureName.Position, true, false, POSPrefix);
25 | indices.put(BEFORE, Integer.valueOf(1));
26 | indices.put(ON, Integer.valueOf(2));
27 | indices.put(AFTER, Integer.valueOf(3));
28 | indexcounter = 4;
29 | }
30 |
31 | @Override
32 | public void addFeatures(Sentence s, Collection indices,
33 | Map nonbinFeats, int predIndex, int argIndex,
34 | Integer offset, boolean allWords) {
35 | indices.add(indexOf(getFeatureString(s, predIndex, argIndex)) + offset);
36 | }
37 |
38 | @Override
39 | public void addFeatures(Collection indices,
40 | Map nonbinFeats, Predicate pred, Word arg,
41 | Integer offset, boolean allWords) {
42 | indices.add(indexOf(getFeatureString(pred, arg)) + offset);
43 | }
44 |
45 | @Override
46 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
47 | // Do nothing, the map is constructed in the constructor.
48 | }
49 |
50 | @Override
51 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
52 | if (predIndex == argIndex)
53 | return ON;
54 | else if (predIndex < argIndex)
55 | return AFTER;
56 | else
57 | return BEFORE;
58 | }
59 |
60 | @Override
61 | public String getFeatureString(Predicate pred, Word arg) {
62 | // int cmp=pred.compareTo(arg);
63 | int cmp = pred.getMySentence().wordComparator.compare(pred, arg);
64 | if (cmp < 0) {
65 | return AFTER;
66 | } else if (cmp == 0) {
67 | return ON;
68 | } else {
69 | return BEFORE;
70 | }
71 | }
72 |
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PredDependentAttrFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import se.lth.cs.srl.corpus.Word.WordData;
7 | import uk.ac.ed.inf.srl.features.AttrFeature;
8 | import uk.ac.ed.inf.srl.features.FeatureName;
9 | import uk.ac.ed.inf.srl.features.TargetWord;
10 |
11 | public class PredDependentAttrFeature extends AttrFeature {
12 | private static final long serialVersionUID = 1L;
13 |
14 | protected PredDependentAttrFeature(FeatureName name, WordData attr,
15 | TargetWord tw, boolean usedForPredicateIdentification,
16 | String POSPrefix) {
17 | super(name, attr, tw, false, usedForPredicateIdentification, POSPrefix);
18 | }
19 |
20 | @Override
21 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
22 | if (wordExtractor.getWord(s, predIndex, argIndex) == null)
23 | return null;
24 | return wordExtractor.getWord(s, predIndex, argIndex).getAttr(attr);
25 | }
26 |
27 | @Override
28 | public String getFeatureString(Predicate pred, Word arg) {
29 | if (wordExtractor.getWord(pred, arg) == null)
30 | return null;
31 | return wordExtractor.getWord(pred, arg).getAttr(attr);
32 | }
33 |
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PredDependentBrown.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.FeatureName;
7 | import uk.ac.ed.inf.srl.features.PredDependentAttrFeature;
8 | import uk.ac.ed.inf.srl.features.TargetWord;
9 | import se.lth.cs.srl.util.BrownCluster;
10 | import se.lth.cs.srl.util.BrownCluster.ClusterVal;
11 |
12 | public class PredDependentBrown extends PredDependentAttrFeature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | private BrownCluster bc;
16 | private ClusterVal cv;
17 |
18 | protected PredDependentBrown(FeatureName name, TargetWord tw,
19 | boolean includeAllWords, String POSPrefix, BrownCluster bc,
20 | ClusterVal cv) {
21 | super(name, null, tw, includeAllWords, POSPrefix);
22 | this.bc = bc;
23 | this.cv = cv;
24 | }
25 |
26 | @Override
27 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
28 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
29 | if (w.isBOS())
30 | return "ROOT";
31 | return bc.getValue(w.getForm(), cv);
32 | }
33 |
34 | @Override
35 | public String getFeatureString(Predicate pred, Word arg) {
36 | Word w = wordExtractor.getWord(pred, arg);
37 | if (w.isBOS())
38 | return "ROOT";
39 | return bc.getValue(w.getForm(), cv);
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PredDependentEmbedding.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.ContinuousFeature;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.TargetWord;
9 | import uk.ac.ed.inf.srl.features.WordExtractor;
10 | import se.lth.cs.srl.util.WordEmbedding;
11 |
12 | public class PredDependentEmbedding extends ContinuousFeature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | private WordExtractor wordExtractor;
16 | private WordEmbedding bc;
17 | private int dim;
18 | private boolean tokenembedding;
19 |
20 | protected PredDependentEmbedding(FeatureName name, TargetWord tw,
21 | boolean includeAllWords, String POSPrefix, WordEmbedding bc,
22 | int dim, boolean token) {
23 | // super(name, null, tw, includeAllWords, POSPrefix);
24 | super(name, false, true, POSPrefix);
25 | indices.put("WEPRED" + dim, 1);
26 |
27 | this.wordExtractor = WordExtractor.getExtractor(tw);
28 | this.bc = bc;
29 | this.dim = dim;
30 | this.tokenembedding = token;
31 | }
32 |
33 | @Override
34 | public Double getFeatureValue(Sentence s, int predIndex, int argIndex) {
35 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
36 | if (w == null)
37 | return 0.0;
38 | else
39 | return this.getValue(w, dim);
40 | // return w.getRep(dim);
41 | }
42 |
43 | @Override
44 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
45 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
46 | return "WEPRED" + dim;
47 | }
48 |
49 | @Override
50 | public Double getFeatureValue(Predicate pred, Word arg) {
51 | Word w = wordExtractor.getWord(pred, arg);
52 | if (w == null)
53 | return 0.0;
54 | else {
55 | // return w.getRep(dim);
56 | return this.getValue(w, dim);
57 | }
58 | }
59 |
60 | @Override
61 | public String getFeatureString(Predicate pred, Word arg) {
62 | return "WEPRED" + dim;
63 | }
64 |
65 | private Double getValue(Word w, int dim) {
66 | if (tokenembedding)
67 | return w.getRep(dim);
68 | else
69 | return bc.getValue(w.getForm(), dim);
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/PredDependentFeatsFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.FeatsFeature;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.TargetWord;
9 | import se.lth.cs.srl.languages.Language;
10 |
11 | public class PredDependentFeatsFeature extends FeatsFeature {
12 | private static final long serialVersionUID = 1L;
13 |
14 | protected PredDependentFeatsFeature(FeatureName name, TargetWord tw,
15 | boolean usedForPredicateIdentification, String POSPrefix) {
16 | super(name, tw, false, usedForPredicateIdentification, POSPrefix);
17 | }
18 |
19 | @Override
20 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
21 | Word w = wordExtractor.getWord(s, predIndex, argIndex);
22 | return Language.getLanguage().getFeatSplitPattern().split(w.getFeats());
23 | }
24 |
25 | @Override
26 | public String[] getFeatureStrings(Predicate pred, Word arg) {
27 | Word w = wordExtractor.getWord(pred, arg);
28 | return Language.getLanguage().getFeatSplitPattern().split(w.getFeats());
29 | }
30 |
31 | @Override
32 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
33 | if (allWords) {
34 | for (int i = 1, size = s.size(); i < size; ++i) {
35 | if (doExtractFeatures(s.get(i)))
36 | for (String v : getFeatureStrings(s, i, -1))
37 | addMap(v);
38 | }
39 | } else {
40 | for (Predicate pred : s.getPredicates()) {
41 | if (doExtractFeatures(pred))
42 | for (String v : getFeatureStrings(pred, null))
43 | addMap(v);
44 | }
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/QContinuousSetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.ContinuousFeature;
10 | import uk.ac.ed.inf.srl.features.FeatureGenerator;
11 | import uk.ac.ed.inf.srl.features.QuadraticFeature;
12 | import uk.ac.ed.inf.srl.features.SetFeature;
13 |
14 | public class QContinuousSetFeature extends SetFeature implements
15 | QuadraticFeature {
16 | private static final long serialVersionUID = 1L;
17 |
18 | private ContinuousFeature f1;
19 | private SetFeature f2;
20 |
21 | protected QContinuousSetFeature(ContinuousFeature f1, SetFeature f2,
22 | boolean usedForPredicateIdentification, String POSPrefix) {
23 | super(f1.name, f1.includeArgs || f2.includeArgs,
24 | usedForPredicateIdentification, POSPrefix);
25 | this.f1 = f1;
26 | this.f2 = f2;
27 | }
28 |
29 | @Override
30 | public void addFeatures(Sentence s, Collection indices,
31 | Map nonbinFeats, int predIndex, int argIndex,
32 | Integer offset, boolean allWords) {
33 | for (String v : getFeatureStrings(s, predIndex, argIndex)) {
34 | Integer i = indexOf(v);
35 | Double d = f1.getFeatureValue(s, predIndex, argIndex);
36 | if (i != 1 && (allWords || i < predMaxIndex))
37 | nonbinFeats.put(i + offset, d);
38 | }
39 | }
40 |
41 | @Override
42 | public void addFeatures(Collection indices,
43 | Map nonbinFeats, Predicate pred, Word arg,
44 | Integer offset, boolean allWords) {
45 | for (String v : getFeatureStrings(pred, arg)) {
46 | Integer i = indexOf(v);
47 | Double d = f1.getFeatureValue(pred, arg);
48 | if (i != -1 && (allWords || i < predMaxIndex))
49 | nonbinFeats.put(i + offset, d);
50 | }
51 | }
52 |
53 | @Override
54 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
55 | String f1val = f1.getFeatureString(s, predIndex, argIndex);
56 | String[] f2vals = f2.getFeatureStrings(s, predIndex, argIndex);
57 | makeFeatureStrings(f1val, f2vals);
58 | return f2vals;
59 | }
60 |
61 | @Override
62 | public String[] getFeatureStrings(Predicate pred, Word arg) {
63 | String f1val = f1.getFeatureString(pred, arg);
64 | String[] f2vals = f2.getFeatureStrings(pred, arg);
65 | if (f2vals != null) {
66 | makeFeatureStrings(f1val, f2vals);
67 | return f2vals;
68 | } else {
69 | return new String[] { "" };
70 | // return new String[0];
71 | }
72 | }
73 |
74 | private void makeFeatureStrings(String f1val, String[] f2vals) {
75 | for (int i = 0, length = f2vals.length; i < length; ++i)
76 | f2vals[i] += VALUE_SEPARATOR + f1val;
77 | }
78 |
79 | public String getName() {
80 | return FeatureGenerator.getCanonicalName(f1.name, f2.name);
81 | }
82 |
83 | }
84 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/QDoubleChildSetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Set;
4 |
5 | import se.lth.cs.srl.corpus.Predicate;
6 | import se.lth.cs.srl.corpus.Sentence;
7 | import se.lth.cs.srl.corpus.Word;
8 | import uk.ac.ed.inf.srl.features.ChildSetFeature;
9 | import uk.ac.ed.inf.srl.features.FeatureGenerator;
10 | import uk.ac.ed.inf.srl.features.QuadraticFeature;
11 | import uk.ac.ed.inf.srl.features.SetFeature;
12 |
13 | public class QDoubleChildSetFeature extends SetFeature implements
14 | QuadraticFeature {
15 | private static final long serialVersionUID = 1L;
16 |
17 | private ChildSetFeature f1, f2;
18 |
19 | protected QDoubleChildSetFeature(ChildSetFeature f1, ChildSetFeature f2,
20 | boolean usedForPredicateIdentification, String POSPrefix) {
21 | super(f1.name, f1.includeArgs && f2.includeArgs,
22 | usedForPredicateIdentification, POSPrefix); // The boolean
23 | // should always
24 | // evaluate to
25 | // false, seeing as
26 | // ChildSetFeatures
27 | // are always
28 | // focused on the
29 | // pred
30 | this.f1 = f1;
31 | this.f2 = f2;
32 | }
33 |
34 | @Override
35 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
36 | return makeFeatureStrings(s.get(predIndex).getChildren());
37 | }
38 |
39 | @Override
40 | public String[] getFeatureStrings(Predicate pred, Word arg) {
41 | return makeFeatureStrings(pred.getChildren());
42 | }
43 |
44 | private String[] makeFeatureStrings(Set children) {
45 | String[] ret = new String[children.size()];
46 | int i = 0;
47 | for (Word child : children)
48 | ret[i++] = child.getAttr(f1.attr) + VALUE_SEPARATOR
49 | + child.getAttr(f2.attr);
50 | return ret;
51 | }
52 |
53 | @Override
54 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
55 | if (includeArgs) {
56 | throw new Error("You are wrong here.");
57 | } else {
58 | if (allWords) {
59 | for (int i = 1, size = s.size(); i < size; ++i) {
60 | if (doExtractFeatures(s.get(i)))
61 | for (String v : getFeatureStrings(s, i, -1))
62 | addMap(v);
63 | }
64 | } else {
65 | for (Predicate pred : s.getPredicates()) {
66 | if (doExtractFeatures(pred))
67 | for (String v : getFeatureStrings(pred, null))
68 | addMap(v);
69 | }
70 | }
71 | }
72 | }
73 |
74 | public String getName() {
75 | return FeatureGenerator.getCanonicalName(f1.name, f2.name);
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/QSetSetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.FeatureGenerator;
7 | import uk.ac.ed.inf.srl.features.QuadraticFeature;
8 | import uk.ac.ed.inf.srl.features.SetFeature;
9 |
10 | public class QSetSetFeature extends SetFeature implements QuadraticFeature {
11 | private static final long serialVersionUID = 1L;
12 |
13 | private SetFeature f1;
14 | private SetFeature f2;
15 |
16 | protected QSetSetFeature(SetFeature f1, SetFeature f2,
17 | boolean usedForPredicateIdentification, String POSPrefix) {
18 | super(f1.name, f1.includeArgs || f2.includeArgs,
19 | usedForPredicateIdentification, POSPrefix);
20 | this.f1 = f1;
21 | this.f2 = f2;
22 | }
23 |
24 | @Override
25 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
26 | String[] f1vals = f1.getFeatureStrings(s, predIndex, argIndex);
27 | String[] f2vals = f2.getFeatureStrings(s, predIndex, argIndex);
28 | makeFeatureStrings(f1vals, f2vals);
29 | return f2vals;
30 | }
31 |
32 | @Override
33 | public String[] getFeatureStrings(Predicate pred, Word arg) {
34 | String[] f1vals = f1.getFeatureStrings(pred, arg);
35 | String[] f2vals = f2.getFeatureStrings(pred, arg);
36 | if (f2vals != null) {
37 | makeFeatureStrings(f1vals, f2vals);
38 | return f2vals;
39 | } else {
40 | return new String[] { "" };
41 | // return new String[0];
42 | }
43 | }
44 |
45 | private void makeFeatureStrings(String[] f1vals, String[] f2vals) {
46 |
47 | if (f1vals.length != f2vals.length) {
48 | System.err
49 | .println("CHECK YOUR IMPLEMENTATION! Trying to combine two set features of different lengths");
50 | System.exit(1);
51 | }
52 |
53 | for (int i = 0, length = f2vals.length; i < length; ++i)
54 | f2vals[i] += VALUE_SEPARATOR + f1vals[i];
55 | }
56 |
57 | public String getName() {
58 | return FeatureGenerator.getCanonicalName(f1.name, f2.name);
59 | }
60 |
61 | }
62 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/QSingleSetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.FeatureGenerator;
7 | import uk.ac.ed.inf.srl.features.QuadraticFeature;
8 | import uk.ac.ed.inf.srl.features.SetFeature;
9 | import uk.ac.ed.inf.srl.features.SingleFeature;
10 |
11 | public class QSingleSetFeature extends SetFeature implements QuadraticFeature {
12 | private static final long serialVersionUID = 1L;
13 |
14 | private SingleFeature f1;
15 | private SetFeature f2;
16 |
17 | protected QSingleSetFeature(SingleFeature f1, SetFeature f2,
18 | boolean usedForPredicateIdentification, String POSPrefix) {
19 | super(f1.name, f1.includeArgs || f2.includeArgs,
20 | usedForPredicateIdentification, POSPrefix);
21 | this.f1 = f1;
22 | this.f2 = f2;
23 | }
24 |
25 | @Override
26 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
27 | String f1val = f1.getFeatureString(s, predIndex, argIndex);
28 | String[] f2vals = f2.getFeatureStrings(s, predIndex, argIndex);
29 | makeFeatureStrings(f1val, f2vals);
30 | return f2vals;
31 | }
32 |
33 | @Override
34 | public String[] getFeatureStrings(Predicate pred, Word arg) {
35 | String f1val = f1.getFeatureString(pred, arg);
36 | String[] f2vals = f2.getFeatureStrings(pred, arg);
37 | if (f2vals != null) {
38 | makeFeatureStrings(f1val, f2vals);
39 | return f2vals;
40 | } else {
41 | return new String[] { "" };
42 | // return new String[0];
43 | }
44 | }
45 |
46 | private void makeFeatureStrings(String f1val, String[] f2vals) {
47 | for (int i = 0, length = f2vals.length; i < length; ++i)
48 | f2vals[i] += VALUE_SEPARATOR + f1val;
49 | }
50 |
51 | public String getName() {
52 | return FeatureGenerator.getCanonicalName(f1.name, f2.name);
53 | }
54 |
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/QSingleSingleFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import uk.ac.ed.inf.srl.features.FeatureGenerator;
7 | import uk.ac.ed.inf.srl.features.QuadraticFeature;
8 | import uk.ac.ed.inf.srl.features.SingleFeature;
9 |
10 | public class QSingleSingleFeature extends SingleFeature implements
11 | QuadraticFeature {
12 | private static final long serialVersionUID = 1L;
13 |
14 | private SingleFeature f1, f2;
15 |
16 | protected QSingleSingleFeature(SingleFeature f1, SingleFeature f2,
17 | boolean usedForPredicateIdentification, String POSPrefix) {
18 | super(f1.name, f1.includeArgs || f2.includeArgs,
19 | usedForPredicateIdentification, POSPrefix);
20 | this.f1 = f1;
21 | this.f2 = f2;
22 | }
23 |
24 | @Override
25 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
26 | if (includeArgs) {
27 | for (Predicate pred : s.getPredicates()) {
28 | if (doExtractFeatures(pred))
29 | for (Word arg : pred.getArgMap().keySet()) {
30 | addMap(getFeatureString(pred, arg));
31 | }
32 | }
33 | } else {
34 | super.performFeatureExtraction(s, allWords);
35 | }
36 | }
37 |
38 | @Override
39 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
40 | return f1.getFeatureString(s, predIndex, argIndex) + VALUE_SEPARATOR
41 | + f2.getFeatureString(s, predIndex, argIndex);
42 | }
43 |
44 | @Override
45 | public String getFeatureString(Predicate pred, Word arg) {
46 | return f1.getFeatureString(pred, arg) + VALUE_SEPARATOR
47 | + f2.getFeatureString(pred, arg);
48 | }
49 |
50 | public String getName() {
51 | return FeatureGenerator.getCanonicalName(f1.name, f2.name);
52 | }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/QuadraticFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
/**
 * Tag interface for conjunction ("quadratic") features built from two
 * component features. It also supplies the separator placed between the
 * component values in a combined feature string.
 */
public interface QuadraticFeature {

    /** Separator between the two component values of a combined feature string. */
    String VALUE_SEPARATOR = " + ";

}
8 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/SameSubTreeFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.List;
4 |
5 | import se.lth.cs.srl.corpus.Predicate;
6 | import se.lth.cs.srl.corpus.Sentence;
7 | import se.lth.cs.srl.corpus.Word;
8 | import se.lth.cs.srl.corpus.Word.WordData;
9 | import uk.ac.ed.inf.srl.features.FeatureName;
10 | import uk.ac.ed.inf.srl.features.SingleFeature;
11 |
12 | public class SameSubTreeFeature extends SingleFeature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | boolean parent;
16 |
17 | protected SameSubTreeFeature(FeatureName name, WordData attr,
18 | boolean consider_parent, String POSPrefix) {
19 | super(name, true, false, POSPrefix);
20 | this.parent = consider_parent;
21 | }
22 |
23 | @Override
24 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
25 | for (Predicate pred : s.getPredicates()) {
26 | if (doExtractFeatures(pred))
27 | for (Word arg : pred.getArgMap().keySet()) {
28 | addMap(getFeatureString(pred, arg));
29 | }
30 | }
31 | }
32 |
33 | @Override
34 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
35 | return makeFeatureString(s.get(predIndex), s.get(argIndex));
36 | }
37 |
38 | @Override
39 | public String getFeatureString(Predicate pred, Word arg) {
40 | return makeFeatureString(pred, arg);
41 | }
42 |
43 | public String makeFeatureString(Word pred, Word arg) {
44 | List path = Word.findPath(parent ? pred.getHead() : pred, arg);
45 | boolean inTree = true;
46 | for (int i = 0; i < (path.size() - 1); ++i) {
47 | Word w = path.get(i);
48 | // path goes up instead of down
49 | if (w.getHead() == path.get(i + 1)) {
50 | inTree = false;
51 | break;
52 | }
53 | }
54 |
55 | return (parent ? "Parent" : "") + "SubTree" + (inTree ? 1 : 0);
56 | }
57 |
58 | }
59 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/SetFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.Feature;
10 | import uk.ac.ed.inf.srl.features.FeatureName;
11 |
12 | public abstract class SetFeature extends Feature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | protected SetFeature(FeatureName name, boolean includeArgs,
16 | boolean usedForPredicateIdentification, String POSPrefix) {
17 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix);
18 | }
19 |
20 | public abstract String[] getFeatureStrings(Sentence s, int predIndex,
21 | int argIndex);
22 |
23 | public abstract String[] getFeatureStrings(Predicate pred, Word arg);
24 |
25 | @Override
26 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
27 | if (includeArgs) {
28 | for (Predicate pred : s.getPredicates()) {
29 | if (doExtractFeatures(pred))
30 | for (Word arg : pred.getArgMap().keySet()) {
31 | for (String v : getFeatureStrings(pred, arg))
32 | addMap(v);
33 | }
34 | }
35 | } else {
36 | if (allWords) {
37 | for (int i = 1, size = s.size(); i < size; ++i) {
38 | if (doExtractFeatures(s.get(i)))
39 | for (String v : getFeatureStrings(s, i, -1))
40 | addMap(v);
41 | }
42 | } else {
43 | for (Predicate pred : s.getPredicates()) {
44 | if (doExtractFeatures(pred))
45 | for (String v : getFeatureStrings(pred, null))
46 | addMap(v);
47 | }
48 | }
49 | }
50 | }
51 |
52 | @Override
53 | public void addFeatures(Sentence s, Collection indices,
54 | Map nonbinFeats, int predIndex, int argIndex,
55 | Integer offset, boolean allWords) {
56 | for (String v : getFeatureStrings(s, predIndex, argIndex)) {
57 | Integer i = indexOf(v);
58 | if (i != -1 && (allWords || i < predMaxIndex))
59 | indices.add(i + offset);
60 | }
61 | }
62 |
63 | @Override
64 | public void addFeatures(Collection indices,
65 | Map nonbinFeats, Predicate pred, Word arg,
66 | Integer offset, boolean allWords) {
67 | for (String v : getFeatureStrings(pred, arg)) {
68 | Integer i = indexOf(v);
69 | if (i != -1 && (allWords || i < predMaxIndex))
70 | indices.add(i + offset);
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/SingleFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | import se.lth.cs.srl.corpus.Predicate;
7 | import se.lth.cs.srl.corpus.Sentence;
8 | import se.lth.cs.srl.corpus.Word;
9 | import uk.ac.ed.inf.srl.features.Feature;
10 | import uk.ac.ed.inf.srl.features.FeatureName;
11 |
12 | public abstract class SingleFeature extends Feature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | protected SingleFeature(FeatureName name, boolean includeArgs,
16 | boolean usedForPredicateIdentification, String POSPrefix) {
17 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix);
18 | }
19 |
20 | @Override
21 | public void addFeatures(Sentence s, Collection indices,
22 | Map nonbinFeats, int predIndex, int argIndex,
23 | Integer offset, boolean allWords) {
24 | Integer i = indexOf(getFeatureString(s, predIndex, argIndex));
25 | if (i != -1 && (allWords || i < predMaxIndex))
26 | indices.add(i + offset);
27 | }
28 |
29 | @Override
30 | public void addFeatures(Collection indices,
31 | Map nonbinFeats, Predicate pred, Word arg,
32 | Integer offset, boolean allWords) {
33 | Integer i = indexOf(getFeatureString(pred, arg));
34 | if (i != -1 && (allWords || i < predMaxIndex))
35 | indices.add(i + offset);
36 | }
37 |
38 | /**
39 | * This method works with features that have includeArgs==false. It extracts
40 | * either from all words (if boolean allWords is true), or from the
41 | * predicates only (if false)
42 | */
43 | @Override
44 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
45 | if (includeArgs) {
46 | throw new Error("You are wrong here, check your implementation.");
47 | } else {
48 | if (allWords) {
49 | for (int i = 1, size = s.size(); i < size; ++i)
50 | if (doExtractFeatures(s.get(i)))
51 | addMap(getFeatureString(s, i, -1));
52 | } else {
53 | for (Predicate pred : s.getPredicates())
54 | if (doExtractFeatures(pred))
55 | addMap(getFeatureString(pred, null));
56 | }
57 | }
58 | }
59 |
60 | public abstract String getFeatureString(Sentence s, int predIndex,
61 | int argIndex);
62 |
63 | public abstract String getFeatureString(Predicate pred, Word arg);
64 | }
65 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/SpanLengthFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import se.lth.cs.srl.corpus.Predicate;
4 | import se.lth.cs.srl.corpus.Sentence;
5 | import se.lth.cs.srl.corpus.Word;
6 | import se.lth.cs.srl.corpus.Word.WordData;
7 | import uk.ac.ed.inf.srl.features.FeatureName;
8 | import uk.ac.ed.inf.srl.features.NumFeature;
9 | import uk.ac.ed.inf.srl.features.SingleFeature;
10 |
11 | public class SpanLengthFeature extends SingleFeature {
12 | private static final long serialVersionUID = 1L;
13 |
14 | protected SpanLengthFeature(FeatureName name, WordData attr,
15 | boolean consider_deptree, String POSPrefix) {
16 | super(name, true, false, POSPrefix);
17 | }
18 |
19 | @Override
20 | protected void performFeatureExtraction(Sentence s, boolean allWords) {
21 | for (Predicate pred : s.getPredicates()) {
22 | if (doExtractFeatures(pred))
23 | for (Word arg : pred.getArgMap().keySet()) {
24 | addMap(getFeatureString(pred, arg));
25 | }
26 | }
27 | }
28 |
29 | @Override
30 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
31 | return makeFeatureString(s.get(predIndex), s.get(argIndex));
32 | }
33 |
34 | @Override
35 | public String getFeatureString(Predicate pred, Word arg) {
36 | return makeFeatureString(pred, arg);
37 | }
38 |
39 | public String makeFeatureString(Word pred, Word arg) {
40 | // if(arg.getSpan().size()>1) {
41 | // for(Word w : arg.getSpan())
42 | // System.err.print(w.getForm()+ " ");
43 | // System.err.println("-> " +NumFeature.bin(arg.getSpan().size()));
44 | // }
45 | return "SpanLength" + NumFeature.bin(arg.getSpan().size());
46 | }
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/SubCatSizeFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
3 | import java.util.Set;
4 |
5 | import se.lth.cs.srl.corpus.Predicate;
6 | import se.lth.cs.srl.corpus.Sentence;
7 | import se.lth.cs.srl.corpus.Word;
8 | import uk.ac.ed.inf.srl.features.FeatureName;
9 | import uk.ac.ed.inf.srl.features.NumFeature;
10 | import uk.ac.ed.inf.srl.features.SingleFeature;
11 |
12 | public class SubCatSizeFeature extends SingleFeature {
13 | private static final long serialVersionUID = 1L;
14 |
15 | SubCatSizeFeature(boolean usedForPredicateIdentification, String POSPrefix) {
16 | super(FeatureName.SubCatSize, false, usedForPredicateIdentification,
17 | POSPrefix);
18 | }
19 |
20 | @Override
21 | public String getFeatureString(Sentence s, int predIndex, int argIndex) {
22 | return makeFeatureString(s, s.get(predIndex).getChildren());
23 | }
24 |
25 | @Override
26 | public String getFeatureString(Predicate pred, Word arg) {
27 | return makeFeatureString(pred.getMySentence(), pred.getChildren());
28 | }
29 |
30 | private String makeFeatureString(Sentence s, Set children) {
31 | return "SUBCAT" + NumFeature.bin(children.size());
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/TargetWord.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features;
2 |
/**
 * Names word positions relative to a (potential) predicate/argument pair.
 * Presumably feature extractors use these to pick the word they read from;
 * confirm against callers. Declaration order (and therefore ordinals) is part
 * of the contract and must not change.
 */
public enum TargetWord {
    /** The (potential) predicate. */
    Pred,
    /** The parent of the (potential) predicate. */
    PredParent,
    /** The (potential) argument. */
    Arg,
    /** The leftmost dependent of the (potential) argument. */
    LeftDep,
    /** The rightmost dependent of the (potential) argument. */
    RightDep,
    /** The left sibling of the (potential) argument. */
    LeftSibling,
    /** The right sibling of the (potential) argument. */
    RightSibling,
    /** Undocumented originally; presumably the predicate's subject — TODO confirm. */
    PredSubj,
    /** Undocumented originally; presumably the first word of the sentence or span — TODO confirm. */
    FirstWord,
    /** Undocumented originally; presumably the last word of the sentence or span — TODO confirm. */
    LastWord,
    /** Undocumented originally; presumably the second word of the sentence or span — TODO confirm. */
    SecondWord
}
16 |
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/ml/LearningProblem.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.ml;
2 |
3 | import java.util.Collection;
4 | import java.util.Map;
5 |
6 | public interface LearningProblem {
7 |
8 | public void addInstance(int label, Collection indices,
9 | Map nonbinFeats);
10 |
11 | public void done();
12 |
13 | public Model train();
14 |
15 | }
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/ml/Model.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.ml;
2 |
3 | import java.io.Serializable;
4 | import java.util.Collection;
5 | import java.util.List;
6 | import java.util.Map;
7 |
8 | import uk.ac.ed.inf.srl.ml.liblinear.Label;
9 |
10 | public interface Model extends Serializable {
11 |
12 | public List