├── .gitignore ├── README.md ├── lexicon ├── pos.txt ├── rels.txt └── words.txt ├── pathlstm.jar ├── pom.xml ├── scripts ├── parse.sh └── parse_fnet.sh └── src └── main └── java ├── dmonner └── xlbp │ ├── AdadeltaBasicWeightUpdater.java │ ├── AdadeltaBatchWeightUpdater.java │ ├── AdagradBatchWeightUpdater.java │ ├── AdamBasicWeightUpdater.java │ ├── BasicWeightUpdater.java │ ├── BatchWeightUpdater.java │ ├── Component.java │ ├── DecayWeightUpdater.java │ ├── DownstreamComponent.java │ ├── ForceAdjacencyListWeightInitializer.java │ ├── Function.java │ ├── Input.java │ ├── InputComponent.java │ ├── InternalComponent.java │ ├── MomentumWeightUpdater.java │ ├── NadamBatchWeightUpdater.java │ ├── Network.java │ ├── NetworkCopier.java │ ├── NetworkStringBuilder.java │ ├── NormalWeightInitializer.java │ ├── ResilientEligibilitiesWeightUpdater.java │ ├── ResilientWeightUpdater.java │ ├── Responsibilities.java │ ├── SetWeightInitializer.java │ ├── Target.java │ ├── TargetComponent.java │ ├── UniformWeightInitializer.java │ ├── UpstreamComponent.java │ ├── WeightInitializer.java │ ├── WeightUpdater.java │ ├── WeightUpdaterType.java │ ├── compound │ ├── AbstractCompound.java │ ├── AbstractInternalCompound.java │ ├── AbstractWeightedCompound.java │ ├── Compound.java │ ├── ConvolutionCompound.java │ ├── DiagonalWeightBank.java │ ├── DropoutInputCompound.java │ ├── FunctionCompound.java │ ├── IndirectWeightBank.java │ ├── InputCompound.java │ ├── InternalCompound.java │ ├── LinearCompound.java │ ├── LinearTargetCompound.java │ ├── LogisticCompound.java │ ├── MemoryCellCompound.java │ ├── MemoryCompound.java │ ├── MultiBindCompound.java │ ├── PiCompound.java │ ├── RectifiedLinearCompound.java │ ├── SharedDiagonalWeightBank.java │ ├── SimpleCompound.java │ ├── SingletonCompound.java │ ├── SumOfSquaresTargetCompound.java │ ├── TanhCompound.java │ ├── TanhTargetCompound.java │ ├── TargetCompound.java │ ├── WeightBank.java │ ├── WeightedCompound.java │ └── XEntropyTargetCompound.java 
│ ├── connection │ ├── AdjacencyListConnection.java │ ├── AdjacencyMatrixConnection.java │ ├── BiasConnection.java │ ├── Connection.java │ ├── ConnectionType.java │ ├── DiagonalConnection.java │ ├── ImmutableDiagonalConnection.java │ ├── IndirectConnection.java │ └── LayerConnection.java │ ├── example │ └── SequentialParity.java │ ├── layer │ ├── AbstractDownstreamLayer.java │ ├── AbstractFanInLayer.java │ ├── AbstractFanOutLayer.java │ ├── AbstractFunctionLayer.java │ ├── AbstractInternalLayer.java │ ├── AbstractLayer.java │ ├── AbstractTargetLayer.java │ ├── AbstractUpstreamLayer.java │ ├── BiasLayer.java │ ├── CopyDestinationLayer.java │ ├── CopySourceLayer.java │ ├── DirectResponsibilityLayer.java │ ├── DownstreamLayer.java │ ├── DropoutPiLayer.java │ ├── EndCapLayer.java │ ├── FanOutLayer.java │ ├── FunctionLayer.java │ ├── InputLayer.java │ ├── InternalLayer.java │ ├── Layer.java │ ├── LinearLayer.java │ ├── LogisticLayer.java │ ├── OffsetLayer.java │ ├── PiLayer.java │ ├── RectifiedLinearLayer.java │ ├── RepulsionLayer.java │ ├── ScaleLayer.java │ ├── SigmaLayer.java │ ├── SumOfSquaresTargetLayer.java │ ├── TanhLayer.java │ ├── TargetLayer.java │ ├── UpstreamLayer.java │ ├── WeightReceiverLayer.java │ ├── WeightSenderLayer.java │ ├── WeightedLayer.java │ ├── XEntropyLogisticLayer.java │ └── XEntropyTargetLayer.java │ ├── stat │ ├── AbstractStat.java │ ├── BitDistStat.java │ ├── BitStat.java │ ├── ConnectionStat.java │ ├── ErrorStat.java │ ├── FractionStat.java │ ├── MeanVarStat.java │ ├── NetworkStat.java │ ├── Optimizer.java │ ├── SetStat.java │ ├── Stat.java │ ├── StepStat.java │ ├── TargetSetStat.java │ ├── TargetStat.java │ ├── TestStat.java │ └── TrialStat.java │ ├── trial │ ├── AbstractTrialStream.java │ ├── HashTrialStream.java │ ├── LayerCheck.java │ ├── NeverBreaker.java │ ├── PerfectBreaker.java │ ├── Step.java │ ├── StepRecord.java │ ├── Trainer.java │ ├── TrainingBreaker.java │ ├── Trial.java │ ├── TrialRecord.java │ ├── TrialSet.java │ ├── 
TrialStream.java │ ├── TrialStreamAdapter.java │ └── ValidationBreaker.java │ └── util │ ├── ArrayQueue.java │ ├── CSVWriter.java │ ├── IndexAwareHeap.java │ ├── IndexAwareHeapNode.java │ ├── MatrixTools.java │ ├── NoiseGenerator.java │ ├── NormalNoiseGenerator.java │ ├── ReflectionTools.java │ ├── SlidingMedian.java │ └── TableWriter.java ├── se └── lth │ └── cs │ └── srl │ ├── CompletePipeline.java │ ├── Parse.java │ ├── SemanticRoleLabeler.java │ ├── corpus │ ├── ArgMap.java │ ├── ConstituentBuilder.java │ ├── CorefChain.java │ ├── Predicate.java │ ├── PredicateReference.java │ ├── Sentence.java │ ├── StringInText.java │ ├── Word.java │ └── Yield.java │ ├── io │ ├── ANNWriter.java │ ├── AbstractCoNLL09Reader.java │ ├── AllCoNLL09Reader.java │ ├── CoNLL09Writer.java │ ├── DepsOnlyCoNLL09Reader.java │ ├── FrameNetXMLWriter.java │ ├── SRLOnlyCoNLL09Reader.java │ ├── SentenceReader.java │ └── SentenceWriter.java │ ├── languages │ ├── AbstractDummyLanguage.java │ ├── Chinese.java │ ├── Czech.java │ ├── English.java │ ├── German.java │ ├── Language.java │ ├── NullLanguage.java │ └── Spanish.java │ ├── options │ ├── CompletePipelineCMDLineOptions.java │ ├── FullPipelineOptions.java │ ├── Options.java │ └── ParseOptions.java │ ├── pipeline │ ├── AbstractStep.java │ ├── ArgumentClassifier.java │ ├── ArgumentIdentifier.java │ ├── ArgumentStep.java │ ├── LBJavaArgumentClassifier.java │ ├── Pipeline.java │ ├── PipelineStep.java │ ├── PredicateDisambiguator.java │ ├── PredicateIdentifier.java │ ├── Reranker.java │ └── Step.java │ ├── preprocessor │ ├── CMDLineTokenizer.java │ ├── IllinoisPreprocessor.java │ ├── PipelinedPreprocessor.java │ ├── Preprocessor.java │ ├── SimpleChineseLemmatizer.java │ ├── StanfordPreprocessor.java │ └── tokenization │ │ ├── OpenNLPToolsTokenizerWrapper.java │ │ ├── StanfordPTBTokenizer.java │ │ ├── Tokenizer.java │ │ ├── WhiteSpaceTokenizer.java │ │ └── exner │ │ ├── SwedishTokenizer.java │ │ └── Tokenizer.java │ └── util │ ├── BohnetHelper.java 
│ ├── BrownCluster.java │ ├── ChineseDesegmenter.java │ ├── DasFilter.java │ ├── FileExistenceVerifier.java │ ├── Relation.java │ ├── Sentence2RDF.java │ ├── SentenceAnnotation.java │ ├── StandOffAnnotation.java │ ├── TurboParser.java │ ├── Util.java │ └── WordEmbedding.java └── uk └── ac └── ed └── inf └── srl ├── features ├── AnySetFeature.java ├── ArgDependentAttrFeature.java ├── ArgDependentBrown.java ├── ArgDependentEmbedding.java ├── ArgDependentFeatsFeature.java ├── AttrFeature.java ├── BrownPathFeature.java ├── ChildSetFeature.java ├── ContinuousArgDependentAttrFeature.java ├── ContinuousAttrFeature.java ├── ContinuousFeature.java ├── ContinuousSetFeature.java ├── DepSubCatFeature.java ├── DependencyCPathEmbedding.java ├── DependencyIPathEmbedding.java ├── DependencyPathEmbedding.java ├── DistanceFeature.java ├── DumpFeature.java ├── EmbeddingPath.java ├── FeatsFeature.java ├── Feature.java ├── FeatureFile.java ├── FeatureGenerator.java ├── FeatureName.java ├── FeatureSet.java ├── NumFeature.java ├── PBLabelFeature.java ├── PathFeature.java ├── PathItemSetFeature.java ├── PathLengthFeature.java ├── PositionFeature.java ├── PredDependentAttrFeature.java ├── PredDependentBrown.java ├── PredDependentEmbedding.java ├── PredDependentFeatsFeature.java ├── QContinuousSetFeature.java ├── QContinuousSingleFeature.java ├── QDoubleChildSetFeature.java ├── QSetSetFeature.java ├── QSingleSetFeature.java ├── QSingleSingleFeature.java ├── QuadraticFeature.java ├── SameSubTreeFeature.java ├── SetFeature.java ├── SingleFeature.java ├── SpanLengthFeature.java ├── SubCatSizeFeature.java ├── TargetWord.java └── WordExtractor.java ├── lstm ├── DataConverter.java ├── DataReader.java ├── EmbeddingNetwork.java ├── NetworkOptions.java ├── NetworkRunner.java └── layer │ └── EmbeddingLayer.java └── ml ├── LearningProblem.java ├── Model.java └── liblinear ├── Label.java ├── LibLinearLearningProblem.java ├── LibLinearModel.java └── WeightVector.java /.gitignore: 
-------------------------------------------------------------------------------- 1 | /.idea/ 2 | *.iml 3 | /target/ 4 | /models/ 5 | /lib/ 6 | -------------------------------------------------------------------------------- /lexicon/pos.txt: -------------------------------------------------------------------------------- 1 | 1 PDT 2 | 2 CC 3 | 3 NNP 4 | 4 , 5 | 5 WP$ 6 | 6 VBN 7 | 7 WP 8 | 8 RBR 9 | 9 CD 10 | 10 RP 11 | 11 JJ 12 | 12 PRP 13 | 13 TO 14 | 14 HYPH 15 | 15 EX 16 | 16 WRB 17 | 17 RB 18 | 18 WDT 19 | 19 VBP 20 | 20 JJR 21 | 21 VBZ 22 | 22 PRF 23 | 23 NNPS 24 | 24 ( 25 | 25 UH 26 | 26 POS 27 | 27 $ 28 | 28 `` 29 | 29 : 30 | 30 JJS 31 | 31 LS 32 | 32 VB 33 | 33 . 34 | 34 MD 35 | 35 NN 36 | 36 NNS 37 | 37 DT 38 | 38 VBD 39 | 39 # 40 | 40 '' 41 | 41 RBS 42 | 42 IN 43 | 43 SYM 44 | 44 ) 45 | 45 PRP$ 46 | 46 VBG 47 | 47 -1 48 | -------------------------------------------------------------------------------- /lexicon/rels.txt: -------------------------------------------------------------------------------- 1 | 1 DIR-PRD^ 2 | 2 GAP-PMOD^ 3 | 3 PRD-TMP^ 4 | 4 DIR-GAP^ 5 | 5 DIR-OPRD^ 6 | 6 SBJv 7 | 7 GAP-OPRD^ 8 | 8 GAP-TMP^ 9 | 9 OBJv 10 | 10 GAP-PRDv 11 | 11 GAP-LGSv 12 | 12 PRT^ 13 | 13 PMODv 14 | 14 SUFFIX^ 15 | 15 GAP-LOC-PRD^ 16 | 16 GAP-PRD^ 17 | 17 VOC^ 18 | 18 DEP^ 19 | 19 PRP^ 20 | 20 TMPv 21 | 21 NMOD^ 22 | 22 LOC-OPRDv 23 | 23 LOC-PRD^ 24 | 24 DTVv 25 | 25 GAP-LOCv 26 | 26 P^ 27 | 27 GAP-VC^ 28 | 28 MNR-PRD^ 29 | 29 SUB^ 30 | 30 EXTRv 31 | 31 EXT^ 32 | 32 GAP-PMODv 33 | 33 GAP-LOC^ 34 | 34 PMOD^ 35 | 35 DEP-GAPv 36 | 36 GAP-LGS^ 37 | 37 PRD-PRPv 38 | 38 PRNv 39 | 39 AMODv 40 | 40 OPRDv 41 | 41 HYPHv 42 | 42 LOCv 43 | 43 BNF^ 44 | 44 LGSv 45 | 45 PRD^ 46 | 46 IMv 47 | 47 GAP-OPRDv 48 | 48 MNR-TMP^ 49 | 49 HMOD^ 50 | 50 ADV^ 51 | 51 DEPv 52 | 52 LOC^ 53 | 53 ADV-GAPv 54 | 54 PRDv 55 | 55 GAP-SBJv 56 | 56 Pv 57 | 57 APPOv 58 | 58 LOC-OPRD^ 59 | 59 POSTHONv 60 | 60 DTV^ 61 | 61 PRTv 62 | 62 COORDv 63 | 63 OBJ^ 64 | 64 HYPH^ 65 | 65 PUT^ 66 | 66 GAP-VCv 67 
| 67 EXTR^ 68 | 68 GAP-TMPv 69 | 69 GAP-OBJ^ 70 | 70 VOCv 71 | 71 NMODv 72 | 72 APPO^ 73 | 73 DIR^ 74 | 74 ADV-GAP^ 75 | 75 EXT-GAPv 76 | 76 VC^ 77 | 77 PRN^ 78 | 78 VCv 79 | 79 DEP-GAP^ 80 | 80 LOC-TMP^ 81 | 81 TITLEv 82 | 82 NAME^ 83 | 83 GAP-PRP^ 84 | 84 EXT-GAP^ 85 | 85 MNRv 86 | 86 LOC-PRDv 87 | 87 PUTv 88 | 88 SUBv 89 | 89 EXTv 90 | 90 NAMEv 91 | 91 PRPv 92 | 92 DIRv 93 | 93 TMP^ 94 | 94 TITLE^ 95 | 95 PRD-PRP^ 96 | 96 GAP-OBJv 97 | 97 OPRD^ 98 | 98 POSTHON^ 99 | 99 GAP-NMOD^ 100 | 100 IM^ 101 | 101 MNR^ 102 | 102 PRD-TMPv 103 | 103 SUFFIXv 104 | 104 AMOD^ 105 | 105 HMODv 106 | 106 ADVv 107 | 107 BNFv 108 | 108 GAP-SBJ^ 109 | 109 LGS^ 110 | 110 CONJ^ 111 | 111 SBJ^ 112 | 112 ROOTv 113 | 113 DIR-GAPv 114 | 114 COORD^ 115 | 115 ROOT^ 116 | 116 GAP-MNR^ 117 | 117 CONJv 118 | -------------------------------------------------------------------------------- /pathlstm.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microth/PathLSTM/d09533e6f3b74505ee86376bdf96f2851bd8428d/pathlstm.jar -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | uk.ac.ed.inf 8 | PathLSTM 9 | 1.0.0 10 | 11 | 12 | 13 | CogCompSoftware 14 | CogComp software repository 15 | http://cogcomp.cs.illinois.edu/m2repo/ 16 | 17 | 18 | 19 | 20 | 21 | com.googlecode.mate-tools 22 | anna 23 | 3.5 24 | 25 | 26 | edu.stanford.nlp 27 | stanford-corenlp 28 | 3.6.0 29 | 30 | 31 | org.apache.opennlp 32 | opennlp-tools 33 | 1.6.0 34 | 35 | 36 | edu.illinois.cs.cogcomp 37 | LBJava 38 | 1.2.26 39 | 40 | 41 | net.sf.trove4j 42 | trove4j 43 | 3.0.3 44 | 45 | 46 | 47 | 48 | 49 | 50 | org.apache.maven.plugins 51 | maven-compiler-plugin 52 | 3.6.0 53 | 54 | 1.8 55 | 1.8 56 | 57 | 58 | 59 | org.apache.maven.plugins 60 | maven-source-plugin 61 | 2.1.2 62 | 63 | 64 | attach-sources 65 | 66 | jar 67 | 68 | 69 | 
70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /scripts/parse.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # please download these models and adjust their locations accordingly 3 | LEMMA_MODEL=models/lemma-eng.model 4 | POS_MODEL=models/tagger-eng.model 5 | PARSER_MODEL=models/parse-eng.model 6 | SRL_MODEL=models/srl-ACL2016-eng.model 7 | 8 | RERANKER="-reranker -externalNNs" 9 | 10 | # Stanford CoreNLP (WSJTokenizer) needed for tokenization 11 | TOKENIZE="-tokenize" 12 | STANFORD=lib/stanford-corenlp-3.7.0.jar 13 | 14 | # java 1.8+ is required 15 | JAVA=java 16 | 17 | # parse the input file given as $1 (quoted so that a path containing spaces or glob characters stays one argument) 18 | $JAVA -Xmx60g -cp lib/anna-3.3.jar:$STANFORD:target/classes/ se.lth.cs.srl.CompletePipeline eng -lemma $LEMMA_MODEL -parser $PARSER_MODEL -tagger $POS_MODEL -srl $SRL_MODEL $RERANKER $TOKENIZE -test "$1" 19 | # note: make sure that the compiled class files (run "mvn compile") are located in target/classes/ 20 | -------------------------------------------------------------------------------- /scripts/parse_fnet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # please download these dependencies and adjust their locations accordingly 3 | SRL_MODEL=models/srl-ICCG16-stanford-eng.model 4 | FN_DATA=framenet/fndata-1.7/ 5 | 6 | #RERANKER="-aibeam 7 -acbeam 3 -alfa 0.75 -reranker -externalNNs -globalFeats" 7 | RERANKER="-reranker -externalNNs -globalFeats" 8 | 9 | # java 1.8+ is required 10 | JAVA=java 11 | # parse the input file given as $1 (quoted so that a path containing spaces or glob characters stays one argument) 12 | $JAVA -Xmx60g -cp target/classes/:lib/anna-3.3.jar:lib/stanford-corenlp-3.8.0.jar:lib/stanford-corenlp-3.8.0-models.jar se.lth.cs.srl.CompletePipeline fnet -test "$1" -srl $SRL_MODEL $RERANKER -tokenize -framenet $FN_DATA -stanford -out out.conll 13 | # note: make sure that the compiled class files (run "mvn compile") are located in target/classes/ or adjust class path to include precompiled .jar 14 |
-------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/AdadeltaBasicWeightUpdater.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import dmonner.xlbp.connection.Connection; 4 | /** Adadelta-style weight updater: keeps, per weight, a decayed average of squared gradients and of squared updates (decay beta, stabiliser eps) and rescales each raw update dw by the ratio of their square roots. */ 5 | public class AdadeltaBasicWeightUpdater implements WeightUpdater 6 | { 7 | public static final long serialVersionUID = 1L; 8 | 9 | private final float beta; 10 | private final float eps; 11 | private final Connection parent; 12 | private double[][] squared_gradients; 13 | private double[][] squared_deltas; 14 | 15 | public AdadeltaBasicWeightUpdater(final Connection parent, final float beta, final float eps) 16 | { 17 | this.parent = parent; 18 | this.beta = beta; 19 | this.eps = eps; 20 | } 21 | 22 | @Override 23 | public Connection getConnection() 24 | { 25 | return parent; 26 | } 27 | /** 1D case: updates the running averages stored at [i][0] and returns dw scaled by sqrt(squared_deltas + eps) / sqrt(squared_gradients + eps). */ 28 | @Override 29 | public float getUpdate(final int i, final float dw) 30 | { 31 | squared_gradients[i][0] = beta * squared_gradients[i][0] + (1-beta) * (dw*dw); 32 | float retval = (float)(Math.sqrt(squared_deltas[i][0] + eps) / Math.sqrt(squared_gradients[i][0] + eps) * dw); 33 | squared_deltas[i][0] = beta * squared_deltas[i][0] + (1-beta) * (retval*retval); 34 | return retval; 35 | } 36 | /** 2D case: same computation as the 1D variant, using the running averages stored at [i][j]. */ 37 | @Override 38 | public float getUpdate(final int j, final int i, final float dw) 39 | { 40 | squared_gradients[i][j] = beta * squared_gradients[i][j] + (1-beta) * (dw*dw); 41 | float retval = (float)(Math.sqrt(squared_deltas[i][j] + eps) / Math.sqrt(squared_gradients[i][j] + eps) * dw); 42 | squared_deltas[i][j] = beta * squared_deltas[i][j] + (1-beta) * (retval*retval); 43 | return retval; 44 | } 45 | 46 | @Override 47 | public void initialize(final int size) 48 | { 49 | // 1D case: initialize(to=1, from=size) allocates [size][1], which getUpdate(i, dw) reads as [i][0] 50 | initialize(1, size); 51 | } 52 | 53 | @Override 54 | public void initialize(final int to, final int from) 55 | { 56 | squared_gradients = new 
double[from][to]; 57 | squared_deltas = new double[from][to]; 58 | } 59 | 60 | @Override 61 | public void processBatch() 62 | { 63 | } 64 | 65 | @Override 66 | public void toString(final NetworkStringBuilder sb) 67 | { 68 | if(sb.showLearningRates()) 69 | sb.appendln("TODO"); 70 | } 71 | 72 | @Override 73 | public void updateFromBiases(final float[] d) 74 | { 75 | } 76 | 77 | @Override 78 | public void updateFromEligibilities(final float[][] e, final float[] d) 79 | { 80 | } 81 | 82 | @Override 83 | public void updateFromInputs(final float[] in, final float[] d) 84 | { 85 | } 86 | 87 | @Override 88 | public void updateFromVector(final float[] v, final float[] d) 89 | { 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/AdamBasicWeightUpdater.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import dmonner.xlbp.connection.Connection; 4 | /** Adam-style weight updater: keeps, per weight, a decayed first moment m (decay beta) and second moment v (decay beta2) of the raw update dw and returns alpha * mHat / (sqrt(vHat) + eps), where mHat and vHat are the moments divided by (1 - beta^t) and (1 - beta2^t). */ 5 | public class AdamBasicWeightUpdater implements WeightUpdater 6 | { 7 | public static final long serialVersionUID = 1L; 8 | 9 | private final float beta; 10 | private final float beta2; 11 | private final float alpha; 12 | private final float eps; 13 | private final Connection parent; 14 | private double[][] squared_deltas; /* never read by this class; declaration kept so the serialized field layout is unchanged */ 15 | private double[][] m; 16 | private double[][] v; 17 | public static int t; /* step counter used for bias correction; static, so shared by all instances, and reset to 1 by every constructor call. It is not advanced in this class - presumably the training loop increments it (confirm against caller). */ 18 | 19 | private double[] betas; /* never assigned or read in this class */ 20 | private double[] betas2; /* never assigned or read in this class */ 21 | 22 | public AdamBasicWeightUpdater(final Connection parent, final float beta, final float beta2, final float eps, final float alpha) 23 | { 24 | this.parent = parent; 25 | this.beta = beta; 26 | this.beta2 = beta2; 27 | this.alpha = alpha; 28 | this.eps = eps; 29 | t = 1; 30 | 31 | } 32 | 33 | @Override 34 | public Connection getConnection() 35 | { 36 | return parent; 37 | } 38 | /** 1D case: updates the moments stored at [i][0] and returns the bias-corrected step for raw update dw. */ 39 | @Override 40 | public float getUpdate(final int i, final float dw) 41 | { 42 | m[i][0] = beta * m[i][0] + (1-beta) * dw; 43 | v[i][0] = 
beta2* v[i][0] + (1-beta2) * (dw*dw); 44 | 45 | return (float)(alpha * (m[i][0]/(double)(1-Math.pow(beta,t))) / (Math.sqrt(v[i][0]/(double)(1-Math.pow(beta2,t)))+eps)); 46 | } 47 | /** 2D case: same computation as the 1D variant, using the moments stored at [i][j]. */ 48 | @Override 49 | public float getUpdate(final int j, final int i, final float dw) 50 | { 51 | //if(i==3 && j==3) System.err.println(parent.getName() + "\t" + dw); 52 | m[i][j] = beta * m[i][j] + (1-beta) * dw; 53 | v[i][j] = beta2* v[i][j] + (1-beta2) * (dw*dw); 54 | return (float)(alpha * (m[i][j]/(double)(1-Math.pow(beta,t))) / (Math.sqrt(v[i][j]/(double)(1-Math.pow(beta2,t)))+eps)); 55 | } 56 | 57 | @Override 58 | public void initialize(final int size) 59 | { 60 | // 1D case: initialize(to=1, from=size) allocates [size][1], which getUpdate(i, dw) reads as [i][0] 61 | initialize(1, size); 62 | } 63 | 64 | @Override 65 | public void initialize(final int to, final int from) 66 | { 67 | m = new double[from][to]; 68 | v = new double[from][to]; 69 | /* squared_deltas is never read by this class, so the third [from][to] matrix is no longer allocated here */ 70 | } 71 | 72 | @Override 73 | public void processBatch() 74 | { 75 | } 76 | 77 | @Override 78 | public void toString(final NetworkStringBuilder sb) 79 | { 80 | if(sb.showLearningRates()) 81 | sb.appendln("TODO"); 82 | } 83 | 84 | @Override 85 | public void updateFromBiases(final float[] d) 86 | { 87 | } 88 | 89 | @Override 90 | public void updateFromEligibilities(final float[][] e, final float[] d) 91 | { 92 | } 93 | 94 | @Override 95 | public void updateFromInputs(final float[] in, final float[] d) 96 | { 97 | } 98 | 99 | @Override 100 | public void updateFromVector(final float[] v, final float[] d) 101 | { 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/BasicWeightUpdater.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import dmonner.xlbp.connection.Connection; 4 | 5 | public class BasicWeightUpdater implements WeightUpdater 6 | { 7 | public static final long serialVersionUID = 
1L; 8 | 9 | public static float HALF = 1.0F; 10 | 11 | private final Connection parent; 12 | private final float a; 13 | 14 | public BasicWeightUpdater(final Connection parent) 15 | { 16 | this(parent, 0.1F); 17 | } 18 | 19 | public BasicWeightUpdater(final Connection parent, final float a) 20 | { 21 | this.parent = parent; 22 | this.a = a; 23 | } 24 | 25 | @Override 26 | public Connection getConnection() 27 | { 28 | return parent; 29 | } 30 | 31 | @Override 32 | public float getUpdate(final int i, final float dw) 33 | { 34 | return (a/HALF) * dw; 35 | } 36 | 37 | @Override 38 | public float getUpdate(final int j, final int i, final float dw) 39 | { 40 | return (a/HALF) * dw; 41 | } 42 | 43 | @Override 44 | public void initialize(final int size) 45 | { 46 | } 47 | 48 | @Override 49 | public void initialize(final int to, final int from) 50 | { 51 | } 52 | 53 | @Override 54 | public void processBatch() 55 | { 56 | } 57 | 58 | @Override 59 | public void toString(final NetworkStringBuilder sb) 60 | { 61 | if(sb.showLearningRates()) 62 | sb.appendln("Learning Rate:" + a); 63 | } 64 | 65 | @Override 66 | public void updateFromBiases(final float[] d) 67 | { 68 | } 69 | 70 | @Override 71 | public void updateFromEligibilities(final float[][] e, final float[] d) 72 | { 73 | } 74 | 75 | @Override 76 | public void updateFromInputs(final float[] in, final float[] d) 77 | { 78 | } 79 | 80 | @Override 81 | public void updateFromVector(final float[] v, final float[] d) 82 | { 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/Component.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import java.io.Serializable; 4 | 5 | public interface Component extends Serializable, Comparable 6 | { 7 | public void activateTest(); 8 | 9 | public void activateTrain(); 10 | 11 | public void build(); 12 | 13 | public void clear(); 14 | 15 | public void 
clearActivations(); 16 | 17 | public void clearEligibilities(); 18 | 19 | public void clearResponsibilities(); 20 | 21 | public Component copy(NetworkCopier copier); 22 | 23 | public Component copy(String nameSuffix); 24 | 25 | public void copyConnectivityFrom(Component comp, NetworkCopier copier); 26 | 27 | public String getName(); 28 | 29 | public boolean isBuilt(); 30 | 31 | public int nWeights(); 32 | 33 | /** 34 | * To be called once after network is set up. For individual layers, checks that they have all the 35 | * necessary inputs and outputs. For fan-in or fan-out layers, checks to see if they only have one 36 | * input/output in the direction that should have multiples; if so, bows itself out and connects 37 | * input to output directly. For higher level components, keeps track of this process and reworks 38 | * internal pointers. 39 | * 40 | * @return false iff this component has removed itself from the network; true otherwise. 41 | */ 42 | public boolean optimize(); 43 | 44 | public void processBatch(); 45 | 46 | public void setWeightInitializer(WeightInitializer win); 47 | 48 | public void setWeightUpdaterType(WeightUpdaterType wut); 49 | 50 | /** 51 | * @return The name of this component. 
52 | */ 53 | @Override 54 | public String toString(); 55 | 56 | public void toString(NetworkStringBuilder sb); 57 | 58 | public String toString(String show); 59 | 60 | public void unbuild(); 61 | 62 | public void updateEligibilities(); 63 | 64 | public void updateResponsibilities(); 65 | 66 | public void updateWeights(); 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/DecayWeightUpdater.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import dmonner.xlbp.connection.Connection; 4 | 5 | public class DecayWeightUpdater implements WeightUpdater 6 | { 7 | public static final long serialVersionUID = 1L; 8 | 9 | private final Connection parent; 10 | private final float a; 11 | private final float b; 12 | 13 | public DecayWeightUpdater(final Connection parent) 14 | { 15 | this(parent, 0.1F, 0.001F); 16 | } 17 | 18 | public DecayWeightUpdater(final Connection parent, final float a, final float b) 19 | { 20 | this.parent = parent; 21 | this.a = a; 22 | this.b = b; 23 | } 24 | 25 | @Override 26 | public Connection getConnection() 27 | { 28 | return parent; 29 | } 30 | 31 | @Override 32 | public float getUpdate(final int j, final float dw) 33 | { 34 | return a * (dw - b * parent.getWeight(j, 0)); 35 | } 36 | 37 | @Override 38 | public float getUpdate(final int j, final int i, final float dw) 39 | { 40 | return a * (dw - b * parent.getWeight(j, i)); 41 | } 42 | 43 | @Override 44 | public void initialize(final int size) 45 | { 46 | } 47 | 48 | @Override 49 | public void initialize(final int to, final int from) 50 | { 51 | } 52 | 53 | @Override 54 | public void processBatch() 55 | { 56 | } 57 | 58 | @Override 59 | public void toString(final NetworkStringBuilder sb) 60 | { 61 | if(sb.showLearningRates()) 62 | { 63 | sb.appendln("Learning Rate:" + a); 64 | sb.appendln("Weight Decay Rate:" + b); 65 | } 66 | } 67 | 68 | @Override 69 | public 
void updateFromBiases(final float[] d) 70 | { 71 | } 72 | 73 | @Override 74 | public void updateFromEligibilities(final float[][] e, final float[] d) 75 | { 76 | } 77 | 78 | @Override 79 | public void updateFromInputs(final float[] in, final float[] d) 80 | { 81 | } 82 | 83 | @Override 84 | public void updateFromVector(final float[] v, final float[] d) 85 | { 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/DownstreamComponent.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import dmonner.xlbp.layer.DownstreamLayer; 4 | 5 | public interface DownstreamComponent extends Component 6 | { 7 | public void addUpstream(final UpstreamComponent upstream); 8 | 9 | public DownstreamLayer asDownstreamLayer(); 10 | 11 | public boolean connectedUpstream(final UpstreamComponent upstream); 12 | 13 | @Override 14 | public DownstreamComponent copy(NetworkCopier copier); 15 | 16 | @Override 17 | public DownstreamComponent copy(String nameSuffix); 18 | 19 | public int getIndexInUpstream(); 20 | 21 | public int getIndexInUpstream(final int index); 22 | 23 | public UpstreamComponent getUpstream(); 24 | 25 | public UpstreamComponent getUpstream(final int index); 26 | 27 | public int indexOfUpstream(final UpstreamComponent upstream); 28 | 29 | public int nUpstream(); 30 | 31 | public void removeUpstream(final int index); 32 | 33 | public void removeUpstream(final UpstreamComponent upstream); 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/ForceAdjacencyListWeightInitializer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | public class ForceAdjacencyListWeightInitializer extends UniformWeightInitializer 4 | { 5 | public static final long serialVersionUID = 1L; 6 | 7 | public 
ForceAdjacencyListWeightInitializer() 8 | { 9 | super(); 10 | } 11 | 12 | public ForceAdjacencyListWeightInitializer(final float p) 13 | { 14 | super(p); 15 | } 16 | 17 | public ForceAdjacencyListWeightInitializer(final float p, final float min, final float max) 18 | { 19 | super(p, min, max); 20 | } 21 | 22 | @Override 23 | public boolean fullConnectivity() 24 | { 25 | return false; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/Input.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import java.util.Map; 4 | 5 | import dmonner.xlbp.layer.InputLayer; 6 | 7 | public class Input 8 | { 9 | private final InputLayer layer; 10 | //private final float[] value; 11 | //private final int[] binValue; 12 | 13 | private final Map inputs; 14 | 15 | public Input(final InputLayer layer, Map values) //final float[] value, final int[] binValue) 16 | { 17 | this.layer = layer; 18 | //this.value = value; 19 | //this.binValue = binValue; 20 | inputs = values; 21 | 22 | //if(value!=null) { 23 | // if(value.length != layer.size()) 24 | // throw new IllegalArgumentException("Incorrect Input Size; expected " + layer.size() + " for " 25 | // + layer.getName() + ", got " + value.length); 26 | //} else { 27 | for(int i : values.keySet()) { 28 | if(i>layer.size()) 29 | throw new IllegalArgumentException("Incorrect Input Size; expected " + layer.size() + " for " 30 | + layer.getName() + ", got " + i); 31 | } 32 | //} 33 | } 34 | 35 | public void apply() 36 | { 37 | layer.setInput(inputs); 38 | } 39 | 40 | @Override 41 | public boolean equals(final Object other) 42 | { 43 | if(other instanceof Input) 44 | { 45 | final Input that = (Input) other; 46 | return that.layer == this.layer; 47 | } 48 | 49 | return false; 50 | } 51 | 52 | public InputLayer getLayer() 53 | { 54 | return layer; 55 | } 56 | 57 | /*public float[] getValue() 58 | { 59 | 
return value; 60 | } 61 | 62 | public int[] getBinValue() 63 | { 64 | return binValue; 65 | }*/ 66 | 67 | public Map getValue() { 68 | return inputs; 69 | } 70 | 71 | @Override 72 | public int hashCode() 73 | { 74 | return layer.hashCode(); 75 | } 76 | 77 | @Override 78 | public String toString() 79 | { 80 | return layer.getName() + ": " + inputs.toString(); // MatrixTools.toString(value); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/InputComponent.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import java.util.Map; 4 | 5 | public interface InputComponent extends Component 6 | { 7 | @Override 8 | public InputComponent copy(String nameSuffix); 9 | 10 | @Override 11 | public InputComponent copy(NetworkCopier copier); 12 | 13 | //public void setInput(final float[] activations); 14 | public void setInput(final Map activations); 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/InternalComponent.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | public interface InternalComponent extends UpstreamComponent, DownstreamComponent 4 | { 5 | @Override 6 | public InternalComponent copy(String nameSuffix); 7 | 8 | @Override 9 | public InternalComponent copy(NetworkCopier copier); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/NormalWeightInitializer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import java.util.Random; 4 | 5 | public class NormalWeightInitializer implements WeightInitializer 6 | { 7 | public static final long serialVersionUID = 1L; 8 | 9 | public final Random rand; 10 | public final float p; 11 | public final float min; 12 
| public final float max; 13 | 14 | public NormalWeightInitializer() 15 | { 16 | this(new Random(1L), 1F); 17 | } 18 | 19 | public NormalWeightInitializer(final float p) 20 | { 21 | this(new Random(1L), p, -0.1F, +0.1F); 22 | } 23 | 24 | public NormalWeightInitializer(final float p, final float min, final float max) 25 | { 26 | this(new Random(1L), p, min, max); 27 | } 28 | 29 | public NormalWeightInitializer(final Random rand) 30 | { 31 | this(rand, 1F); 32 | } 33 | 34 | public NormalWeightInitializer(final Random rand, final float p) 35 | { 36 | this(rand, p, -0.1F, +0.1F); 37 | } 38 | 39 | public NormalWeightInitializer(final Random rand, final float p, final float min, final float max) 40 | { 41 | if(p < 0F || p > 1F) 42 | throw new IllegalArgumentException("p must be in [0, 1]."); 43 | 44 | this.rand = rand; 45 | this.p = p; 46 | this.min = min; 47 | this.max = max; 48 | } 49 | 50 | @Override 51 | public boolean fullConnectivity() 52 | { 53 | return p == 1F; 54 | } 55 | 56 | @Override 57 | public boolean newWeight(final int j, final int i) 58 | { 59 | return rand.nextFloat() < p; 60 | } 61 | 62 | @Override 63 | public float randomWeight(final int j, final int i) 64 | { 65 | return (float)rand.nextGaussian() * (max - min) + min; 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/SetWeightInitializer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import java.util.Random; 4 | 5 | public class SetWeightInitializer implements WeightInitializer 6 | { 7 | private static final long serialVersionUID = 1L; 8 | 9 | private final float[][] w; 10 | private final boolean full; 11 | 12 | public SetWeightInitializer(final float[][] w) 13 | { 14 | this.w = w; 15 | this.full = checkFull(); 16 | } 17 | 18 | public SetWeightInitializer(final float[][] w, final boolean showFull) 19 | { 20 | this.w = w; 21 | this.full = showFull; 22 
| } 23 | 24 | public SetWeightInitializer(int x, int y, int copy, float[][] matrix) { 25 | 26 | this.w = new float[x][y]; 27 | //System.err.println(x + "\t" + y); 28 | //System.err.println(matrix.length + "\t" + matrix[0].length); 29 | 30 | for(int i=0; i 1F) 42 | throw new IllegalArgumentException("p must be in [0, 1]."); 43 | 44 | this.rand = rand; 45 | this.p = p; 46 | this.min = min; 47 | this.max = max; 48 | } 49 | 50 | @Override 51 | public boolean fullConnectivity() 52 | { 53 | return p == 1F; 54 | } 55 | 56 | @Override 57 | public boolean newWeight(final int j, final int i) 58 | { 59 | return rand.nextFloat() < p; 60 | } 61 | 62 | @Override 63 | public float randomWeight(final int j, final int i) 64 | { 65 | return rand.nextFloat() * (max - min) + min; 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/UpstreamComponent.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp; 2 | 3 | import dmonner.xlbp.layer.UpstreamLayer; 4 | 5 | public interface UpstreamComponent extends Component 6 | { 7 | public void addDownstream(final DownstreamComponent downstream); 8 | 9 | public UpstreamLayer asUpstreamLayer(); 10 | 11 | public boolean connectedDownstream(final DownstreamComponent downstream); 12 | 13 | @Override 14 | public UpstreamComponent copy(String nameSuffix); 15 | 16 | @Override 17 | public UpstreamComponent copy(NetworkCopier copier); 18 | 19 | public DownstreamComponent getDownstream(); 20 | 21 | public DownstreamComponent getDownstream(final int index); 22 | 23 | public int getIndexInDownstream(); 24 | 25 | public int getIndexInDownstream(final int index); 26 | 27 | public int indexOfDownstream(final DownstreamComponent downstream); 28 | 29 | public int nDownstream(); 30 | 31 | public void removeDownstream(final DownstreamComponent downstream); 32 | 33 | public void removeDownstream(final int index); 34 | } 35 | 
/**
 * Strategy consulted when a connection's weights are first created: it decides
 * whether a given weight exists at all and, if so, what its starting value is.
 *
 * <p>The interface is {@link Serializable}, so implementations must be
 * serializable as well.
 *
 * <p>NOTE(review): the meaning of the indices is not defined here; presumably
 * {@code j} is the receiving ("to") unit and {@code i} the sending ("from")
 * unit -- confirm against the connection classes that call this interface.
 */
public interface WeightInitializer extends Serializable
{
	/**
	 * Reports whether this initializer creates every possible weight.
	 *
	 * @return {@code true} if no weight is ever left out
	 */
	boolean fullConnectivity();

	/**
	 * Decides whether the weight at position ({@code j}, {@code i}) should exist.
	 *
	 * @param j first index of the weight
	 * @param i second index of the weight
	 * @return {@code true} if the weight should be created
	 */
	boolean newWeight(int j, int i);

	/**
	 * Supplies the initial value for the weight at position ({@code j}, {@code i}).
	 *
	 * @param j first index of the weight
	 * @param i second index of the weight
	 * @return the starting value of the weight
	 */
	float randomWeight(int j, int i);
}
Compound copy(NetworkCopier copier); 12 | 13 | @Override 14 | public Compound copy(String nameSuffix); 15 | 16 | public Component[] getComponents(); 17 | 18 | public UpstreamLayer getOutput(); 19 | 20 | public UpstreamLayer getOutput(int index); 21 | 22 | public int nOutputs(); 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/compound/DiagonalWeightBank.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.compound; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | import dmonner.xlbp.WeightInitializer; 5 | import dmonner.xlbp.WeightUpdaterType; 6 | import dmonner.xlbp.connection.DiagonalConnection; 7 | import dmonner.xlbp.connection.LayerConnection; 8 | import dmonner.xlbp.layer.DownstreamLayer; 9 | import dmonner.xlbp.layer.UpstreamLayer; 10 | 11 | public class DiagonalWeightBank extends WeightBank 12 | { 13 | private static final long serialVersionUID = 1L; 14 | private boolean fullOnly; 15 | 16 | public DiagonalWeightBank(final DiagonalWeightBank that, final NetworkCopier copier) 17 | { 18 | super(that, copier); 19 | this.fullOnly = true; 20 | } 21 | 22 | public DiagonalWeightBank(final String name, final UpstreamLayer upstream, 23 | final DownstreamLayer downstream, final WeightInitializer win, final WeightUpdaterType wut) 24 | { 25 | super(name, upstream, downstream, win, wut); 26 | this.fullOnly = true; 27 | } 28 | 29 | @Override 30 | public DiagonalWeightBank copy(final NetworkCopier copier) 31 | { 32 | return new DiagonalWeightBank(this, copier); 33 | } 34 | 35 | @Override 36 | public DiagonalWeightBank copy(final String nameSuffix) 37 | { 38 | final NetworkCopier copier = new NetworkCopier(nameSuffix); 39 | final DiagonalWeightBank copy = copy(copier); 40 | copier.build(); 41 | return copy; 42 | } 43 | 44 | @Override 45 | public DiagonalConnection getConnection() 46 | { 47 | return (DiagonalConnection) super.getConnection(); 48 | 
} 49 | 50 | @Override 51 | protected LayerConnection makeConnection() 52 | { 53 | final DiagonalConnection conn = new DiagonalConnection(getName(), getWeightInput(), 54 | getWeightOutput()); 55 | conn.setFullOnly(fullOnly); 56 | return conn; 57 | } 58 | 59 | public void setFullOnly(final boolean fullOnly) 60 | { 61 | this.fullOnly = fullOnly; 62 | getConnection().setFullOnly(fullOnly); 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/compound/DropoutInputCompound.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.compound; 2 | 3 | import java.util.Random; 4 | 5 | import dmonner.xlbp.NetworkCopier; 6 | 7 | public class DropoutInputCompound extends InputCompound { 8 | 9 | private static final long serialVersionUID = 1574221823303987773L; 10 | final double dropout; 11 | final Random r; 12 | 13 | public DropoutInputCompound(InputCompound that, NetworkCopier copier) { 14 | super(that, copier); 15 | dropout = 0.0; 16 | r = new Random(); 17 | } 18 | 19 | public DropoutInputCompound(String string, double d, int inputlength) { 20 | super(string, inputlength); 21 | dropout = d; 22 | r = new Random(); 23 | } 24 | 25 | public void setInput(final float[] activations) { 26 | for(int i=0; idropoutrate); 36 | } 37 | 38 | @Override 39 | public void activateTest() 40 | { 41 | System.arraycopy(upstream[0].getActivations(), 0, y, 0, size); 42 | 43 | for(int k = 1; k < nUpstream; k++) 44 | MatrixTools.multiply(upstream[k].getActivations(), y, size); 45 | 46 | for(int j=0; j0.0?x[i][j]:0.0); 40 | //if(j==x[0].length-1) System.out.println(); else System.out.print("\t"); 41 | return x[j]>0.0?x[j]:0.0F; 42 | } 43 | 44 | @Override 45 | public float fprime(final int j) 46 | { 47 | return y[j]>0.0?1.0F:0.0F; 48 | } 49 | 50 | @Override 51 | public void activateTest() 52 | { 53 | for(int j = 0; j < size; j++) 54 | y[j] = (1-dropoutrate) * f(j); 55 | } 56 | 
57 | @Override 58 | public void activateTrain() 59 | { 60 | for(int j = 0; j < size; j++) { 61 | if(ThreadLocalRandom.current().nextFloat()>dropoutrate) { 62 | y[j] = f(j); 63 | fprime[j] = fprime(j); 64 | } else { 65 | y[j] = 0.0F; 66 | fprime[j] = 0.0F; 67 | } 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/ScaleLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | 5 | public class ScaleLayer extends AbstractFunctionLayer 6 | { 7 | private static final long serialVersionUID = 1L; 8 | 9 | private final float factor; 10 | 11 | public ScaleLayer(final ScaleLayer that, final NetworkCopier copier) 12 | { 13 | super(that, copier); 14 | this.factor = that.factor; 15 | } 16 | 17 | public ScaleLayer(final String name, final int size, final float factor) 18 | { 19 | super(name, size); 20 | this.factor = factor; 21 | } 22 | 23 | @Override 24 | public ScaleLayer copy(final NetworkCopier copier) 25 | { 26 | return new ScaleLayer(this, copier); 27 | } 28 | 29 | @Override 30 | public ScaleLayer copy(final String nameSuffix) 31 | { 32 | return copy(new NetworkCopier(nameSuffix)); 33 | } 34 | 35 | @Override 36 | public float f(final int j) 37 | { 38 | return x[j] * factor; 39 | } 40 | 41 | @Override 42 | public float fprime(final int j) 43 | { 44 | return factor; 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/SigmaLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | import dmonner.xlbp.Responsibilities; 5 | 6 | public class SigmaLayer extends AbstractFanInLayer 7 | { 8 | private static final long serialVersionUID = 1L; 9 | 10 | public SigmaLayer(final SigmaLayer that, 
final NetworkCopier copier) 11 | { 12 | super(that, copier); 13 | } 14 | 15 | public SigmaLayer(final String name, final int size) 16 | { 17 | super(name, size); 18 | } 19 | 20 | @Override 21 | public void activateTest() 22 | { 23 | System.arraycopy(upstream[0].getActivations(), 0, y, 0, size); 24 | 25 | for(int k = 1; k < nUpstream; k++) 26 | { 27 | final float[] yk = upstream[k].getActivations(); 28 | for(int j = 0; j < size; j++) 29 | y[j] += yk[j]; 30 | } 31 | } 32 | 33 | @Override 34 | public void activateTrain() 35 | { 36 | activateTest(); 37 | } 38 | 39 | @Override 40 | public void aliasResponsibilities(final int index, final Responsibilities resp) 41 | { 42 | super.aliasResponsibilities(index, resp); 43 | for(int i = 0; i < nUpstream; i++) 44 | upstream[i].aliasResponsibilities(myIndexInUpstream[i], resp); 45 | } 46 | 47 | @Override 48 | public void build() 49 | { 50 | if(!built) 51 | { 52 | super.build(); 53 | 54 | y = new float[size]; 55 | d = new Responsibilities(size); 56 | 57 | for(int i = 0; i < nUpstream; i++) 58 | { 59 | upstream[i].build(); 60 | upstream[i].aliasResponsibilities(myIndexInUpstream[i], d); 61 | } 62 | 63 | built = true; 64 | } 65 | } 66 | 67 | @Override 68 | public SigmaLayer copy(final NetworkCopier copier) 69 | { 70 | return new SigmaLayer(this, copier); 71 | } 72 | 73 | @Override 74 | public SigmaLayer copy(final String nameSuffix) 75 | { 76 | return copy(new NetworkCopier(nameSuffix)); 77 | } 78 | 79 | @Override 80 | public void updateEligibilities() 81 | { 82 | if(downstreamCopyLayer != null) 83 | downstream.updateUpstreamResponsibilities(myIndexInDownstream); 84 | } 85 | 86 | @Override 87 | public void updateResponsibilities() 88 | { 89 | if(downstreamCopyLayer == null) 90 | downstream.updateUpstreamResponsibilities(myIndexInDownstream); 91 | } 92 | 93 | @Override 94 | public void updateUpstreamResponsibilities(final int index) 95 | { 96 | // Nothing to do -- upstream ds are already aliased to this layer's d. 
97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/SumOfSquaresTargetLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | 5 | public class SumOfSquaresTargetLayer extends AbstractTargetLayer 6 | { 7 | private static final long serialVersionUID = 1L; 8 | 9 | public SumOfSquaresTargetLayer(final String name, final int size) 10 | { 11 | super(name, size); 12 | } 13 | 14 | public SumOfSquaresTargetLayer(final SumOfSquaresTargetLayer that, final NetworkCopier copier) 15 | { 16 | super(that, copier); 17 | } 18 | 19 | @Override 20 | public SumOfSquaresTargetLayer copy(final NetworkCopier copier) 21 | { 22 | return new SumOfSquaresTargetLayer(this, copier); 23 | } 24 | 25 | @Override 26 | public SumOfSquaresTargetLayer copy(final String nameSuffix) 27 | { 28 | return copy(new NetworkCopier(nameSuffix)); 29 | } 30 | 31 | @Override 32 | public void updateResponsibilities() 33 | { 34 | if(t == null) 35 | d.clear(); 36 | else 37 | d.target(t, y, w); 38 | 39 | super.updateResponsibilities(); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/TanhLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | 5 | public class TanhLayer extends AbstractFunctionLayer 6 | { 7 | private static final long serialVersionUID = 1L; 8 | 9 | public TanhLayer(final String name, final int size) 10 | { 11 | super(name, size); 12 | } 13 | 14 | public TanhLayer(final TanhLayer that, final NetworkCopier copier) 15 | { 16 | super(that, copier); 17 | } 18 | 19 | @Override 20 | public TanhLayer copy(final NetworkCopier copier) 21 | { 22 | return new TanhLayer(this, copier); 23 | } 24 | 25 | @Override 26 | 
public TanhLayer copy(final String nameSuffix) 27 | { 28 | return copy(new NetworkCopier(nameSuffix)); 29 | } 30 | 31 | @Override 32 | public float f(final int j) 33 | { 34 | return (float) Math.tanh(x[j]); 35 | } 36 | 37 | @Override 38 | public float fprime(final int j) 39 | { 40 | return 1F - (y[j] * y[j]); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/TargetLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | import dmonner.xlbp.TargetComponent; 5 | 6 | public interface TargetLayer extends DownstreamLayer, TargetComponent 7 | { 8 | @Override 9 | public TargetLayer copy(NetworkCopier copier); 10 | 11 | @Override 12 | public TargetLayer copy(String nameSuffix); 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/UpstreamLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | import dmonner.xlbp.UpstreamComponent; 5 | 6 | public interface UpstreamLayer extends Layer, UpstreamComponent 7 | { 8 | public void addDownstreamCopyLayer(final CopySourceLayer copySource); 9 | 10 | @Override 11 | public UpstreamLayer copy(String nameSuffix); 12 | 13 | @Override 14 | public UpstreamLayer copy(NetworkCopier copier); 15 | 16 | public CopySourceLayer getDownstreamCopyLayer(); 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/WeightedLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | 5 | public interface WeightedLayer extends UpstreamLayer 6 | { 7 | @Override 8 | public void 
addDownstreamCopyLayer(final CopySourceLayer copySource); 9 | 10 | @Override 11 | public WeightedLayer copy(NetworkCopier copier); 12 | 13 | @Override 14 | public WeightedLayer copy(String nameSuffix); 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/XEntropyLogisticLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | import dmonner.xlbp.Responsibilities; 5 | 6 | public class XEntropyLogisticLayer extends LogisticLayer 7 | { 8 | private static final long serialVersionUID = 1L; 9 | 10 | public XEntropyLogisticLayer(final String name, final int size) 11 | { 12 | super(name, size); 13 | } 14 | 15 | public XEntropyLogisticLayer(final XEntropyLogisticLayer that, final NetworkCopier copier) 16 | { 17 | super(that, copier); 18 | } 19 | 20 | @Override 21 | public void aliasResponsibilities(final int index, final Responsibilities resp) 22 | { 23 | super.aliasResponsibilities(index, resp); 24 | upstream.aliasResponsibilities(myIndexInUpstream, resp); 25 | } 26 | 27 | @Override 28 | public void build() 29 | { 30 | if(!built) 31 | { 32 | super.build(); 33 | 34 | upstream.build(); 35 | upstream.aliasResponsibilities(myIndexInUpstream, d); 36 | 37 | built = true; 38 | } 39 | } 40 | 41 | @Override 42 | public XEntropyLogisticLayer copy(final NetworkCopier copier) 43 | { 44 | return new XEntropyLogisticLayer(this, copier); 45 | } 46 | 47 | @Override 48 | public XEntropyLogisticLayer copy(final String nameSuffix) 49 | { 50 | return copy(new NetworkCopier(nameSuffix)); 51 | } 52 | 53 | @Override 54 | public void updateUpstreamResponsibilities(final int upstreamIndex) 55 | { 56 | // Nothing to do -- upstream ds are already aliased to this layer's d. 
57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/layer/XEntropyTargetLayer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.layer; 2 | 3 | import dmonner.xlbp.NetworkCopier; 4 | 5 | public class XEntropyTargetLayer extends AbstractTargetLayer 6 | { 7 | private static final long serialVersionUID = 1L; 8 | 9 | public XEntropyTargetLayer(final String name, final int size) 10 | { 11 | super(name, size); 12 | } 13 | 14 | public XEntropyTargetLayer(final XEntropyTargetLayer that, final NetworkCopier copier) 15 | { 16 | super(that, copier); 17 | } 18 | 19 | @Override 20 | public XEntropyTargetLayer copy(final NetworkCopier copier) 21 | { 22 | return new XEntropyTargetLayer(this, copier); 23 | } 24 | 25 | @Override 26 | public XEntropyTargetLayer copy(final String nameSuffix) 27 | { 28 | return copy(new NetworkCopier(nameSuffix)); 29 | } 30 | 31 | @Override 32 | public void updateResponsibilities() 33 | { 34 | if(t == null) 35 | d.clear(); 36 | else 37 | d.target(t, y, w); 38 | 39 | super.updateResponsibilities(); 40 | } 41 | 42 | public void weightResult(float[] weight) { 43 | for(int i=0; i map) 12 | { 13 | addTo("", map); 14 | } 15 | 16 | @Override 17 | public void saveHeader(final CSVWriter out) throws IOException 18 | { 19 | saveHeader("", out); 20 | } 21 | 22 | @Override 23 | public String toString() 24 | { 25 | return toString(""); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/stat/ConnectionStat.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.stat; 2 | 3 | import java.io.IOException; 4 | import java.util.Map; 5 | 6 | import dmonner.xlbp.connection.Connection; 7 | import dmonner.xlbp.util.CSVWriter; 8 | 9 | public class ConnectionStat extends AbstractStat 10 | { 11 | private final 
MeanVarStat weights; 12 | private final FractionStat connections; 13 | 14 | public ConnectionStat() 15 | { 16 | weights = new MeanVarStat("Weights"); 17 | connections = new FractionStat("Connections"); 18 | } 19 | 20 | public ConnectionStat(final Connection conn) 21 | { 22 | this(); 23 | 24 | final float[][] w = conn.toMatrix(); 25 | for(final float[] row : w) 26 | for(final float wt : row) 27 | weights.add(Math.abs(wt)); 28 | 29 | connections.add(conn.nWeights(), conn.nWeightsPossible()); 30 | 31 | analyze(); 32 | } 33 | 34 | public void add(final ConnectionStat that) 35 | { 36 | weights.add(that.weights); 37 | connections.add(that.connections); 38 | } 39 | 40 | @Override 41 | public void add(final Stat that) 42 | { 43 | if(that instanceof ConnectionStat) 44 | add((ConnectionStat) that); 45 | else 46 | throw new IllegalArgumentException("Can only add in ConnectionStats."); 47 | } 48 | 49 | @Override 50 | public void addTo(final String prefix, final Map map) 51 | { 52 | weights.addTo(prefix, map); 53 | connections.addTo(prefix, map); 54 | } 55 | 56 | @Override 57 | public void analyze() 58 | { 59 | weights.analyze(); 60 | connections.analyze(); 61 | } 62 | 63 | @Override 64 | public void clear() 65 | { 66 | weights.clear(); 67 | connections.clear(); 68 | } 69 | 70 | public FractionStat getConnections() 71 | { 72 | return connections; 73 | } 74 | 75 | public MeanVarStat getWeights() 76 | { 77 | return weights; 78 | } 79 | 80 | @Override 81 | public void saveData(final CSVWriter out) throws IOException 82 | { 83 | weights.saveData(out); 84 | connections.saveData(out); 85 | } 86 | 87 | @Override 88 | public void saveHeader(final String prefix, final CSVWriter out) throws IOException 89 | { 90 | weights.saveHeader(prefix, out); 91 | connections.saveHeader(prefix, out); 92 | } 93 | 94 | @Override 95 | public String toString(final String prefix) 96 | { 97 | final StringBuffer sb = new StringBuffer(); 98 | 99 | sb.append(weights.toString(prefix)); 100 | 
sb.append(connections.toString(prefix)); 101 | 102 | return sb.toString(); 103 | } 104 | 105 | } 106 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/stat/MeanVarStat.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.stat; 2 | 3 | import java.io.IOException; 4 | import java.util.Map; 5 | 6 | import dmonner.xlbp.util.CSVWriter; 7 | 8 | public class MeanVarStat extends AbstractStat 9 | { 10 | private final String name; 11 | private float sum; 12 | private float sumsq; 13 | private float n; 14 | private float mean; 15 | private float var; 16 | 17 | public MeanVarStat(final MeanVarStat that) 18 | { 19 | this.name = that.name; 20 | this.sum = that.sum; 21 | this.sumsq = that.sumsq; 22 | this.n = that.n; 23 | this.mean = that.mean; 24 | this.var = that.var; 25 | } 26 | 27 | public MeanVarStat(final String name) 28 | { 29 | this.name = name; 30 | } 31 | 32 | public void add(final float obs) 33 | { 34 | sum += obs; 35 | sumsq += obs * obs; 36 | n++; 37 | } 38 | 39 | public void add(final MeanVarStat that) 40 | { 41 | sum += that.sum; 42 | sumsq += that.sumsq; 43 | n += that.n; 44 | } 45 | 46 | @Override 47 | public void add(final Stat that) 48 | { 49 | if(that instanceof MeanVarStat) 50 | add((MeanVarStat) that); 51 | else 52 | throw new IllegalArgumentException("Can only add in other MeanVarStats."); 53 | } 54 | 55 | @Override 56 | public void addTo(final String prefix, final Map map) 57 | { 58 | map.put(prefix + name + "Mean", mean); 59 | map.put(prefix + name + "Var", var); 60 | } 61 | 62 | @Override 63 | public void analyze() 64 | { 65 | mean = sum / n; 66 | var = (sumsq - 2 * sum * mean) / n + mean * mean; 67 | } 68 | 69 | @Override 70 | public void clear() 71 | { 72 | sum = 0; 73 | sumsq = 0; 74 | n = 0; 75 | mean = 0; 76 | var = 0; 77 | } 78 | 79 | public float getMean() 80 | { 81 | return mean; 82 | } 83 | 84 | public float getVar() 85 | { 86 
| return var; 87 | } 88 | 89 | @Override 90 | public void saveData(final CSVWriter out) throws IOException 91 | { 92 | out.appendField(mean); 93 | out.appendField(var); 94 | } 95 | 96 | @Override 97 | public void saveHeader(final String prefix, final CSVWriter out) throws IOException 98 | { 99 | out.appendHeader(prefix + name + "Mean"); 100 | out.appendHeader(prefix + name + "Var"); 101 | } 102 | 103 | @Override 104 | public String toString(final String prefix) 105 | { 106 | final StringBuilder sb = new StringBuilder(); 107 | 108 | final String prename = prefix + name; 109 | 110 | sb.append(prename); 111 | sb.append("Mean = "); 112 | sb.append(mean); 113 | sb.append("\n"); 114 | 115 | sb.append(prename); 116 | sb.append("Variance = "); 117 | sb.append(var); 118 | sb.append("\n"); 119 | 120 | return sb.toString(); 121 | } 122 | 123 | } 124 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/stat/Optimizer.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.stat; 2 | 3 | public abstract class Optimizer 4 | { 5 | public static enum Type 6 | { 7 | BIT_ACCURACY, TARGET_ACCURACY, TRIAL_ACCURACY, STEP_ACCURACY, SSE 8 | } 9 | 10 | public static Optimizer defaultOptimizer = get(Type.SSE); 11 | 12 | public static Optimizer get(final Type type) 13 | { 14 | if(type == Type.BIT_ACCURACY) 15 | return new Optimizer() 16 | { 17 | @Override 18 | public boolean betterThan(final SetStat newest, final SetStat best) 19 | { 20 | if(newest == null) 21 | return false; 22 | if(best == null) 23 | return true; 24 | return newest.getTargetStats().getBits().getAccuracy() >= best.getTargetStats().getBits() 25 | .getAccuracy(); 26 | } 27 | }; 28 | else if(type == Type.TARGET_ACCURACY) 29 | return new Optimizer() 30 | { 31 | @Override 32 | public boolean betterThan(final SetStat newest, final SetStat best) 33 | { 34 | if(newest == null) 35 | return false; 36 | if(best == null) 37 
| return true; 38 | return newest.getTargetStats().getCorrect().getFraction() >= best.getTargetStats() 39 | .getCorrect().getFraction(); 40 | } 41 | }; 42 | else if(type == Type.TRIAL_ACCURACY) 43 | return new Optimizer() 44 | { 45 | @Override 46 | public boolean betterThan(final SetStat newest, final SetStat best) 47 | { 48 | if(newest == null) 49 | return false; 50 | if(best == null) 51 | return true; 52 | return newest.getTrialStats().getFraction() >= best.getTrialStats().getFraction(); 53 | } 54 | }; 55 | else if(type == Type.STEP_ACCURACY) 56 | return new Optimizer() 57 | { 58 | @Override 59 | public boolean betterThan(final SetStat newest, final SetStat best) 60 | { 61 | if(newest == null) 62 | return false; 63 | if(best == null) 64 | return true; 65 | return newest.getStepStats().getFraction() >= best.getStepStats().getFraction(); 66 | } 67 | }; 68 | else if(type == Type.SSE) 69 | return new Optimizer() 70 | { 71 | @Override 72 | public boolean betterThan(final SetStat newest, final SetStat best) 73 | { 74 | if(newest == null) 75 | return false; 76 | if(best == null) 77 | return true; 78 | return newest.getTargetStats().getError().getSSE() <= best.getTargetStats().getError() 79 | .getSSE(); 80 | } 81 | }; 82 | else 83 | throw new IllegalStateException("Unhandled Type: " + type); 84 | } 85 | 86 | public abstract boolean betterThan(SetStat newest, SetStat best); 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/stat/Stat.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.stat; 2 | 3 | import java.io.IOException; 4 | import java.util.Map; 5 | 6 | import dmonner.xlbp.util.CSVWriter; 7 | 8 | public interface Stat 9 | { 10 | public abstract void add(final Stat that); 11 | 12 | public abstract void addTo(final Map map); 13 | 14 | public abstract void addTo(final String prefix, final Map map); 15 | 16 | public abstract void analyze(); 
17 | 18 | public abstract void clear(); 19 | 20 | public abstract void saveData(final CSVWriter out) throws IOException; 21 | 22 | public abstract void saveHeader(final CSVWriter out) throws IOException; 23 | 24 | public abstract void saveHeader(final String prefix, final CSVWriter out) throws IOException; 25 | 26 | public abstract String toString(final String prefix); 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/stat/StepStat.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.stat; 2 | 3 | import java.io.IOException; 4 | import java.util.Map; 5 | 6 | import dmonner.xlbp.Target; 7 | import dmonner.xlbp.trial.Step; 8 | import dmonner.xlbp.util.CSVWriter; 9 | 10 | public class StepStat extends AbstractStat 11 | { 12 | private final Step step; 13 | private final TargetSetStat targets; 14 | private final FractionStat correct; 15 | 16 | public StepStat(final Step step) 17 | { 18 | this.step = step; 19 | this.targets = new TargetSetStat(); 20 | this.correct = new FractionStat("Step"); 21 | 22 | analyze(); 23 | } 24 | 25 | @Override 26 | public void add(final Stat that) 27 | { 28 | throw new IllegalArgumentException("Can only get data directly from a Step."); 29 | } 30 | 31 | @Override 32 | public void addTo(final String prefix, final Map map) 33 | { 34 | targets.addTo(prefix, map); 35 | correct.addTo(prefix, map); 36 | } 37 | 38 | @Override 39 | public void analyze() 40 | { 41 | for(final Target target : step.getTargets()) 42 | { 43 | final TargetStat stat = new TargetStat(target.getLayer()); 44 | stat.compare(target.getValue()); 45 | stat.analyze(); 46 | targets.add(stat); 47 | } 48 | 49 | targets.analyze(); 50 | final int possible = targets.size() > 0 ? 1 : 0; 51 | final int actual = targets.getCorrect().getFraction() == 1F ? 
possible : 0; 52 | correct.add(actual, possible); 53 | correct.analyze(); 54 | } 55 | 56 | @Override 57 | public void clear() 58 | { 59 | targets.clear(); 60 | correct.clear(); 61 | } 62 | 63 | public FractionStat getCorrect() 64 | { 65 | return correct; 66 | } 67 | 68 | public Step getStep() 69 | { 70 | return step; 71 | } 72 | 73 | public TargetSetStat getTargets() 74 | { 75 | return targets; 76 | } 77 | 78 | @Override 79 | public void saveData(final CSVWriter out) throws IOException 80 | { 81 | correct.saveData(out); 82 | targets.saveData(out); 83 | } 84 | 85 | @Override 86 | public void saveHeader(final String prefix, final CSVWriter out) throws IOException 87 | { 88 | correct.saveHeader(prefix, out); 89 | targets.saveHeader(prefix, out); 90 | } 91 | 92 | @Override 93 | public String toString(final String prefix) 94 | { 95 | final StringBuffer sb = new StringBuffer(); 96 | 97 | sb.append(correct.toString(prefix)); 98 | sb.append(targets.toString(prefix)); 99 | 100 | return sb.toString(); 101 | } 102 | 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/LayerCheck.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.layer.Layer; 4 | 5 | public class LayerCheck 6 | { 7 | private final String name; 8 | private final Layer layer; 9 | private final float[] eval; 10 | 11 | public LayerCheck(final Layer layer, final float[] eval) 12 | { 13 | this(layer.getName(), layer, eval); 14 | } 15 | 16 | public LayerCheck(final String name, final Layer layer, final float[] eval) 17 | { 18 | this.name = name; 19 | this.layer = layer; 20 | this.eval = eval; 21 | } 22 | 23 | public float[] getEval() 24 | { 25 | return eval; 26 | } 27 | 28 | public Layer getLayer() 29 | { 30 | return layer; 31 | } 32 | 33 | public String getName() 34 | { 35 | return name; 36 | } 37 | } 38 | 
-------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/NeverBreaker.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.stat.TestStat; 4 | 5 | public class NeverBreaker implements TrainingBreaker 6 | { 7 | @Override 8 | public boolean isBreakTime(final TestStat stat) 9 | { 10 | return false; 11 | } 12 | 13 | @Override 14 | public void reset() 15 | { 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/PerfectBreaker.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.stat.TestStat; 4 | 5 | public class PerfectBreaker implements TrainingBreaker 6 | { 7 | @Override 8 | public boolean isBreakTime(final TestStat stat) 9 | { 10 | return stat.getLastTrain().getTrialStats().getFraction() == 1F; 11 | } 12 | 13 | @Override 14 | public void reset() 15 | { 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/StepRecord.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.Map.Entry; 6 | import java.util.Set; 7 | 8 | import dmonner.xlbp.layer.Layer; 9 | 10 | public class StepRecord 11 | { 12 | private final Step step; 13 | private final Map recordings; 14 | 15 | public StepRecord(final Step step) 16 | { 17 | this.step = step; 18 | this.recordings = new HashMap<>(); 19 | 20 | for(final Layer layer : step.getRecordLayers()) 21 | recordings.put(layer, layer.getActivations().clone()); 22 | } 23 | 24 | public float[] getRecording(final Layer layer) 25 | { 26 | return recordings.get(layer); 27 | } 28 | 29 | public Set> getRecordings() 
30 | { 31 | return recordings.entrySet(); 32 | } 33 | 34 | public Step getStep() 35 | { 36 | return step; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/TrainingBreaker.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.stat.TestStat; 4 | 5 | public interface TrainingBreaker 6 | { 7 | public boolean isBreakTime(final TestStat stat); 8 | 9 | public void reset(); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/TrialRecord.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class TrialRecord 7 | { 8 | private final Trial trial; 9 | private final List recordings; 10 | 11 | public TrialRecord(final Trial trial) 12 | { 13 | this.trial = trial; 14 | this.recordings = new ArrayList<>(trial.size()); 15 | 16 | for(final Step step : trial.getSteps()) 17 | recordings.add(step.getLastRecording()); 18 | } 19 | 20 | public List getRecordings() 21 | { 22 | return recordings; 23 | } 24 | 25 | public Trial getTrial() 26 | { 27 | return trial; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/TrialStream.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.Network; 4 | 5 | public interface TrialStream 6 | { 7 | public Network getMetaNetwork(); 8 | 9 | public String getName(); 10 | 11 | public Trial nextTestTrial(); 12 | 13 | public Trial nextTrainTrial(); 14 | 15 | public Trial nextValidationTrial(); 16 | 17 | public int nFolds(); 18 | 19 | public int nTestFolds(); 20 | 21 | public int nTestTrials(); 22 | 23 | 
public int nTrainFolds(); 24 | 25 | public int nTrainTrials(); 26 | 27 | public int nValidationFolds(); 28 | 29 | public int nValidationTrials(); 30 | 31 | public void setFold(int fold); 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/TrialStreamAdapter.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.Network; 4 | 5 | public class TrialStreamAdapter extends AbstractTrialStream 6 | { 7 | public TrialStreamAdapter(final String name, final Network net) 8 | { 9 | super(name, net); 10 | } 11 | 12 | public TrialStreamAdapter(final String name, final Network net, final int train, final int test, 13 | final int valid) 14 | { 15 | super(name, net, train, test, valid); 16 | } 17 | 18 | public TrialStreamAdapter(final String name, final Network net, final String split) 19 | { 20 | super(name, net, split); 21 | } 22 | 23 | @Override 24 | public Trial nextTestTrial() 25 | { 26 | return null; 27 | } 28 | 29 | @Override 30 | public Trial nextTrainTrial() 31 | { 32 | return null; 33 | } 34 | 35 | @Override 36 | public Trial nextValidationTrial() 37 | { 38 | return null; 39 | } 40 | 41 | @Override 42 | public int nTestTrials() 43 | { 44 | return 0; 45 | } 46 | 47 | @Override 48 | public int nTrainTrials() 49 | { 50 | return 0; 51 | } 52 | 53 | @Override 54 | public int nValidationTrials() 55 | { 56 | return 0; 57 | } 58 | 59 | @Override 60 | public void setFold(final int fold) 61 | { 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/trial/ValidationBreaker.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.trial; 2 | 3 | import dmonner.xlbp.stat.Optimizer; 4 | import dmonner.xlbp.stat.SetStat; 5 | import dmonner.xlbp.stat.TestStat; 6 | 7 | public class ValidationBreaker 
/**
 * Fixed-capacity FIFO queue backed by a circular array.
 *
 * <p>
 * Elements are pushed at the tail and popped from the head. The queue never grows: pushing onto a
 * full queue throws {@link IllegalStateException}. This class is not synchronized.
 *
 * @param <T>
 *          the element type
 */
public class ArrayQueue<T>
{
	private T[] q;
	private int head, size;
	private int capacity;

	/**
	 * Creates a queue with capacity zero; it is simultaneously empty and full until
	 * {@link #fill(Object[])} is called.
	 */
	public ArrayQueue()
	{
		this(0);
	}

	/**
	 * Creates an empty queue that can hold up to {@code capacity} elements.
	 *
	 * @param capacity
	 *          the maximum number of elements
	 */
	@SuppressWarnings("unchecked")
	public ArrayQueue(final int capacity)
	{
		// Safe: the array never escapes as anything other than T[] and only Ts are stored in it.
		this.q = (T[]) new Object[capacity];
		this.head = 0;
		this.size = 0;
		this.capacity = capacity;
	}

	/**
	 * Creates a full queue that uses {@code full} directly as its backing array.
	 *
	 * @param full
	 *          the elements, head first; the array is not copied
	 */
	public ArrayQueue(final T[] full)
	{
		fill(full);
	}

	/**
	 * @return the maximum number of elements this queue can hold
	 */
	public int capacity()
	{
		return capacity;
	}

	/**
	 * Empties the queue. The backing array is left untouched.
	 */
	public void clear()
	{
		head = 0;
		size = 0;
	}

	/**
	 * Replaces the contents of the queue with {@code full}, which becomes the backing array (no
	 * copy is made). Afterwards the queue is full and its capacity is {@code full.length}.
	 *
	 * @param full
	 *          the new contents, head first
	 */
	public void fill(final T[] full)
	{
		q = full;
		head = 0;
		size = q.length;
		capacity = q.length;
	}

	/**
	 * @return {@code true} if the queue holds no elements
	 */
	public boolean isEmpty()
	{
		return size == 0;
	}

	/**
	 * @return {@code true} if no further element can be pushed
	 */
	public boolean isFull()
	{
		return size >= capacity;
	}

	/**
	 * Returns the head element without removing it.
	 *
	 * @return the oldest element
	 * @throws IllegalStateException
	 *           if the queue is empty
	 */
	public T peek()
	{
		if(isEmpty())
			throw new IllegalStateException("Cannot peek -- queue is empty.");

		return q[head];
	}

	/**
	 * Returns the element {@code idx} positions behind the head without removing anything.
	 *
	 * @param idx
	 *          zero-based offset from the head
	 * @return the element at that offset
	 * @throws IllegalStateException
	 *           if {@code idx} is negative or not smaller than {@link #size()}
	 */
	public T peek(final int idx)
	{
		if(idx < 0 || idx >= size)
			throw new IllegalStateException("Cannot peek at " + idx + " -- not enough elements.");

		// size > 0 here, hence capacity > 0 and the modulo is well defined.
		return q[(head + idx) % capacity];
	}

	/**
	 * Removes and returns the head element.
	 *
	 * @return the oldest element
	 * @throws IllegalStateException
	 *           if the queue is empty
	 */
	public T pop()
	{
		if(isEmpty())
			throw new IllegalStateException("Cannot pop -- queue is empty.");

		final T rv = q[head];
		head = (head + 1) % capacity;
		size--;
		return rv;
	}

	/**
	 * Discards the {@code n} oldest elements.
	 *
	 * @param n
	 *          the number of elements to drop; must be between 0 and {@link #size()}
	 * @throws IllegalArgumentException
	 *           if {@code n} is negative
	 * @throws IllegalStateException
	 *           if fewer than {@code n} elements are queued
	 */
	public void popN(final int n)
	{
		// A negative count would silently grow `size` and could drive `head` negative.
		if(n < 0)
			throw new IllegalArgumentException("Cannot pop " + n + " -- count must be non-negative.");

		if(size < n)
			throw new IllegalStateException("Cannot pop " + n + " -- not enough elements.");

		// Nothing to do; also avoids `% 0` on a zero-capacity queue.
		if(n == 0)
			return;

		head = (head + n) % capacity;
		size -= n;
	}

	/**
	 * Appends an element at the tail.
	 *
	 * @param elem
	 *          the element to add
	 * @throws IllegalStateException
	 *           if the queue is full
	 */
	public void push(final T elem)
	{
		if(isFull())
			throw new IllegalStateException("Cannot push -- queue is full.");

		final int next = (head + size) % capacity;
		q[next] = elem;
		size++;
	}

	/**
	 * @return the number of elements currently queued
	 */
	public int size()
	{
		return size;
	}
}
| } 18 | 19 | @Override 20 | public int compareTo(final IndexAwareHeapNode that) 21 | { 22 | final Comparator comp = heap.getComparator(); 23 | // if not comparator was provided, use the default 24 | if(comp == null) 25 | return this.element.compareTo(that.element); 26 | // otherwise use the provided comparator 27 | else 28 | return comp.compare(this.element, that.element); 29 | } 30 | 31 | public IndexAwareHeap getHeap() 32 | { 33 | return heap; 34 | } 35 | 36 | public int getIndex() 37 | { 38 | return index; 39 | } 40 | 41 | public void remove() 42 | { 43 | heap.remove(index); 44 | } 45 | 46 | public void set(final IndexAwareHeap heap, final int index) 47 | { 48 | this.heap = heap; 49 | this.index = index; 50 | } 51 | 52 | @Override 53 | public String toString() 54 | { 55 | return "[" + index + ": " + element.toString() + "]"; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/util/NoiseGenerator.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.util; 2 | 3 | public interface NoiseGenerator 4 | { 5 | public float next(); 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/util/NormalNoiseGenerator.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.util; 2 | 3 | import java.util.Random; 4 | 5 | public class NormalNoiseGenerator implements NoiseGenerator 6 | { 7 | private final Random random; 8 | private final float mean; 9 | private final float sd; 10 | private float cached; 11 | 12 | public NormalNoiseGenerator(final Random random, final float mean, final float sd) 13 | { 14 | this.random = random; 15 | this.mean = mean; 16 | this.sd = sd; 17 | } 18 | 19 | @Override 20 | public float next() 21 | { 22 | // Return the cached variable if there is one. 
23 | if(!Float.isNaN(cached)) 24 | { 25 | final float rv = mean + cached * sd; 26 | cached = Float.NaN; 27 | return rv; 28 | } 29 | 30 | // Otherwise, generate two new normal variates. 31 | float v1, v2, s; 32 | do 33 | { 34 | v1 = 2 * random.nextFloat() - 1; 35 | v2 = 2 * random.nextFloat() - 1; 36 | s = v1 * v1 + v2 * v2; 37 | } 38 | while(s >= 1); 39 | 40 | // This method generates two normal random variates. Saved one for next 41 | // time, and return the other. 42 | final float factor = (float) Math.sqrt(-2 * Math.log(s) / s); 43 | cached = factor * v1; 44 | return mean + factor * v2 * sd; 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/dmonner/xlbp/util/TableWriter.java: -------------------------------------------------------------------------------- 1 | package dmonner.xlbp.util; 2 | 3 | import java.io.IOException; 4 | 5 | public class TableWriter extends CSVWriter 6 | { 7 | public TableWriter(final String filename) throws IOException 8 | { 9 | this(filename, false); 10 | } 11 | 12 | public TableWriter(final String filename, final boolean append) throws IOException 13 | { 14 | super(filename, append, "\t", "\n", ""); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/Parse.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl; 2 | 3 | import java.util.zip.ZipFile; 4 | 5 | import se.lth.cs.srl.corpus.Sentence; 6 | import se.lth.cs.srl.io.AllCoNLL09Reader; 7 | import se.lth.cs.srl.io.CoNLL09Writer; 8 | import se.lth.cs.srl.io.DepsOnlyCoNLL09Reader; 9 | import se.lth.cs.srl.io.FrameNetXMLWriter; 10 | import se.lth.cs.srl.io.SRLOnlyCoNLL09Reader; 11 | import se.lth.cs.srl.io.SentenceReader; 12 | import se.lth.cs.srl.io.SentenceWriter; 13 | import se.lth.cs.srl.options.ParseOptions; 14 | import se.lth.cs.srl.pipeline.Pipeline; 15 | import 
se.lth.cs.srl.pipeline.Reranker; 16 | import se.lth.cs.srl.pipeline.Step; 17 | import se.lth.cs.srl.util.Util; 18 | 19 | public class Parse { 20 | public static ParseOptions parseOptions; 21 | 22 | public static void main(String[] args) throws Exception { 23 | long startTime = System.currentTimeMillis(); 24 | parseOptions = new ParseOptions(args); 25 | 26 | SemanticRoleLabeler srl; 27 | 28 | if (parseOptions.useReranker) { 29 | srl = new Reranker(parseOptions); 30 | // srl = 31 | // Reranker.fromZipFile(zipFile,parseOptions.skipPI,parseOptions.global_alfa,parseOptions.global_aiBeam,parseOptions.global_acBeam); 32 | } else { 33 | ZipFile zipFile = new ZipFile(parseOptions.modelFile); 34 | 35 | srl = parseOptions.skipAI ? Pipeline.fromZipFile(zipFile, 36 | new Step[] { Step.ac }) 37 | : parseOptions.skipPD ? Pipeline.fromZipFile(zipFile, 38 | new Step[] { Step.ai, Step.ac }) 39 | : parseOptions.skipPI ? Pipeline.fromZipFile(zipFile, 40 | new Step[] { Step.pd, Step.ai, Step.ac}) 41 | : Pipeline.fromZipFile(zipFile); 42 | zipFile.close(); 43 | } 44 | 45 | SentenceWriter writer = null; 46 | if (parseOptions.printXML) 47 | writer = new FrameNetXMLWriter(parseOptions.output); 48 | else 49 | writer = new CoNLL09Writer(parseOptions.output); 50 | 51 | SentenceReader reader = parseOptions.skipAI ? new AllCoNLL09Reader( 52 | parseOptions.inputCorpus) : parseOptions.skipPI ? 
new SRLOnlyCoNLL09Reader( 53 | parseOptions.inputCorpus) : new DepsOnlyCoNLL09Reader( 54 | parseOptions.inputCorpus); 55 | int senCount = 0; 56 | for (Sentence s : reader) { 57 | senCount++; 58 | if (senCount % 100 == 0) 59 | System.out.println("Parsing sentence " + senCount); 60 | srl.parseSentence(s); 61 | if (parseOptions.writeCoref) 62 | writer.specialwrite(s); 63 | else 64 | writer.write(s); 65 | } 66 | writer.close(); 67 | reader.close(); 68 | long totalTime = System.currentTimeMillis() - startTime; 69 | System.out.println("Done."); 70 | System.out.println(srl.getStatus()); 71 | System.out.println(); 72 | System.out.println("Total execution time: " 73 | + Util.insertCommas(totalTime) + "ms"); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/SemanticRoleLabeler.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl; 2 | 3 | import java.util.Date; 4 | 5 | import se.lth.cs.srl.corpus.Sentence; 6 | import se.lth.cs.srl.util.Util; 7 | 8 | public abstract class SemanticRoleLabeler { 9 | 10 | public void parseSentence(Sentence s) { 11 | long startTime = System.currentTimeMillis(); 12 | parse(s); 13 | parsingTime += System.currentTimeMillis() - startTime; 14 | senCount++; 15 | predCount += s.getPredicates().size(); 16 | } 17 | 18 | protected abstract void parse(Sentence s); 19 | 20 | public String getStatus() { 21 | StringBuilder ret = new StringBuilder( 22 | "Semantic role labeler started at " + startDate + "\n"); 23 | ret.append("Time spent loading SRL models (ms)\t\t" 24 | + Util.insertCommas(loadingTime) + "\n"); 25 | ret.append("Time spent parsing semantic roles (ms)\t\t" 26 | + Util.insertCommas(parsingTime) + "\n"); 27 | ret.append("\n"); 28 | ret.append("Number of sentences\t" + Util.insertCommas(senCount) + "\n"); 29 | ret.append("Number of predicates\t" + Util.insertCommas(predCount) 30 | + "\n"); 31 | ret.append("SRL 
speed (ms/sen)\t" + ((double) parsingTime / senCount) 32 | + "\n"); 33 | ret.append(getSubStatus()); 34 | return ret.toString(); 35 | } 36 | 37 | protected abstract String getSubStatus(); 38 | 39 | public long loadingTime = 0; 40 | public long parsingTime = 0; 41 | public int senCount = 0; 42 | public int predCount = 0; 43 | public final Date startDate = new Date(); 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/corpus/ConstituentBuilder.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.corpus; 2 | 3 | import java.util.TreeSet; 4 | 5 | public class ConstituentBuilder { 6 | 7 | Sentence sen; 8 | Word head; 9 | 10 | public ConstituentBuilder(Sentence s, Word w) { 11 | sen = s; 12 | head = w; 13 | } 14 | 15 | public String toString() { 16 | TreeSet children = new TreeSet<>(); 17 | children.add(head.getIdx()); 18 | TreeSet processed = new TreeSet<>(); 19 | while (!children.isEmpty()) { 20 | Word c = sen.get(children.pollFirst()); 21 | if (!processed.contains(c.getIdx())) { 22 | processed.add(c.getIdx()); 23 | for (Word cc : c.getChildren()) { 24 | children.add(cc.getIdx()); 25 | } 26 | } 27 | } 28 | 29 | StringBuilder sb = new StringBuilder(); 30 | String[] wordformArray = sen.getFormArray(); 31 | for (int i = processed.first(); i <= processed.last(); i++) { 32 | sb.append(wordformArray[i]); 33 | sb.append(" "); 34 | } 35 | return sb.toString().trim(); 36 | } 37 | } -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/corpus/CorefChain.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.corpus; 2 | 3 | import java.util.LinkedList; 4 | 5 | public class CorefChain extends LinkedList { 6 | 7 | int id; 8 | 9 | public CorefChain(int id) { 10 | this.id = id; 11 | } 12 | 13 | public String getId() { 14 | return new 
Integer(id).toString(); 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/corpus/StringInText.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.corpus; 2 | 3 | public class StringInText { 4 | protected String s; 5 | protected int beginPos; 6 | protected int endPos; 7 | 8 | public StringInText(String s, int beginPos, int endPos) { 9 | this.s = s; 10 | this.beginPos = beginPos; 11 | this.endPos = endPos; 12 | } 13 | 14 | public String word() { 15 | return s; 16 | } 17 | 18 | public int begin() { 19 | return beginPos; 20 | } 21 | 22 | public int end() { 23 | return endPos; 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/AbstractCoNLL09Reader.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileInputStream; 6 | import java.io.IOException; 7 | import java.io.InputStreamReader; 8 | import java.nio.charset.Charset; 9 | import java.util.ArrayList; 10 | import java.util.Iterator; 11 | import java.util.List; 12 | import java.util.regex.Pattern; 13 | 14 | //import se.lth.cs.srl.corpus.Corpus; 15 | import se.lth.cs.srl.corpus.Sentence; 16 | 17 | public abstract class AbstractCoNLL09Reader implements SentenceReader { 18 | 19 | protected static final Pattern NEWLINE_PATTERN = Pattern.compile("\n"); 20 | 21 | protected BufferedReader in; 22 | protected Sentence nextSen; 23 | // protected Corpus c; 24 | private File file; 25 | 26 | public AbstractCoNLL09Reader(File file) { 27 | this.file = file; 28 | open(); 29 | } 30 | 31 | private void restart() { 32 | try { 33 | in.close(); 34 | open(); 35 | } catch (IOException e) { 36 | // TODO Auto-generated catch block 37 | e.printStackTrace(); 38 | } 39 | } 40 | 41 
| private void open() { 42 | System.err.println("Opening reader for " + file + "..."); 43 | try { 44 | in = new BufferedReader(new InputStreamReader(new FileInputStream( 45 | file), Charset.forName("UTF-8"))); 46 | // in = new BufferedReader(new FileReader(file)); 47 | readNextSentence(); 48 | } catch (IOException e) { 49 | System.out.println("Failed: " + e.toString()); 50 | System.exit(1); 51 | } 52 | } 53 | 54 | protected abstract void readNextSentence() throws IOException; 55 | 56 | private Sentence getSentence() { 57 | Sentence ret = nextSen; 58 | try { 59 | readNextSentence(); 60 | } catch (IOException e) { 61 | System.out.println("Failed to read from corpus file... exiting."); 62 | System.exit(1); 63 | } 64 | return ret; 65 | } 66 | 67 | @Override 68 | public List readAll() { 69 | ArrayList ret = new ArrayList<>(); 70 | for (Sentence s : this) 71 | ret.add(s); 72 | ret.trimToSize(); 73 | return ret; 74 | } 75 | 76 | @Override 77 | public Iterator iterator() { 78 | if (nextSen == null) 79 | restart(); 80 | return new SentenceIterator(); 81 | } 82 | 83 | @Override 84 | public void close() { 85 | try { 86 | in.close(); 87 | } catch (IOException e) { 88 | // TODO Auto-generated catch block 89 | e.printStackTrace(); 90 | } 91 | } 92 | 93 | private class SentenceIterator implements Iterator { 94 | @Override 95 | public boolean hasNext() { 96 | return nextSen != null; 97 | } 98 | 99 | @Override 100 | public Sentence next() { 101 | return getSentence(); 102 | } 103 | 104 | @Override 105 | public void remove() { 106 | throw new UnsupportedOperationException("Not implemented"); 107 | } 108 | 109 | } 110 | 111 | } -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/AllCoNLL09Reader.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | 6 | import se.lth.cs.srl.corpus.Sentence; 7 | 8 | 
public class AllCoNLL09Reader extends AbstractCoNLL09Reader { 9 | 10 | public AllCoNLL09Reader(File file) { 11 | super(file); 12 | } 13 | 14 | protected void readNextSentence() throws IOException { 15 | String str; 16 | Sentence sen = null; 17 | StringBuilder senBuffer = new StringBuilder(); 18 | while ((str = in.readLine()) != null) { 19 | if (!str.trim().equals("")) { 20 | senBuffer.append(str).append("\n"); 21 | } else { 22 | sen=Sentence.newSentence((NEWLINE_PATTERN.split(senBuffer.toString()))); 23 | break; 24 | } 25 | } 26 | if (sen == null) { 27 | nextSen = null; 28 | in.close(); 29 | } else { 30 | nextSen = sen; 31 | } 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/CoNLL09Writer.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | import java.io.BufferedWriter; 4 | import java.io.File; 5 | import java.io.FileOutputStream; 6 | import java.io.IOException; 7 | import java.io.OutputStreamWriter; 8 | import java.nio.charset.Charset; 9 | 10 | import se.lth.cs.srl.corpus.Sentence; 11 | 12 | public class CoNLL09Writer implements SentenceWriter { 13 | 14 | private BufferedWriter out; 15 | 16 | public CoNLL09Writer(File filename) { 17 | System.out.println("Writing corpus to " + filename + "..."); 18 | try { 19 | out = new BufferedWriter(new OutputStreamWriter( 20 | new FileOutputStream(filename), Charset.forName("UTF-8"))); 21 | // out = new BufferedWriter(new FileWriter(filename)); 22 | } catch (IOException e) { 23 | System.out.println("Failed while opening writer...\n" 24 | + e.toString()); 25 | System.exit(1); 26 | } 27 | } 28 | 29 | public void write(Sentence s) { 30 | try { 31 | out.write(s.toString() + "\n\n"); 32 | } catch (IOException e) { 33 | e.printStackTrace(); 34 | System.out.println("Failed to write sentance."); 35 | System.exit(1); 36 | } 37 | } 38 | 39 | @Override 40 | public void 
specialwrite(Sentence s) { 41 | try { 42 | out.write(s.toSpecialString() + "\n\n"); 43 | } catch (IOException e) { 44 | e.printStackTrace(); 45 | System.out.println("Failed to write sentance."); 46 | System.exit(1); 47 | } 48 | } 49 | 50 | public void close() { 51 | try { 52 | out.close(); 53 | } catch (IOException e) { 54 | e.printStackTrace(); 55 | System.out.println("Failed to close writer."); 56 | System.exit(1); 57 | } 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/DepsOnlyCoNLL09Reader.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | 6 | import se.lth.cs.srl.corpus.Sentence; 7 | 8 | public class DepsOnlyCoNLL09Reader extends AbstractCoNLL09Reader { 9 | 10 | public DepsOnlyCoNLL09Reader(File file) { 11 | super(file); 12 | } 13 | 14 | @Override 15 | protected void readNextSentence() throws IOException { 16 | String str; 17 | Sentence sen = null; 18 | StringBuilder senBuffer = new StringBuilder(); 19 | while ((str = in.readLine()) != null) { 20 | if (!str.trim().equals("")) { 21 | senBuffer.append(str).append("\n"); 22 | } else { 23 | sen = Sentence.newDepsOnlySentence(NEWLINE_PATTERN 24 | .split(senBuffer.toString())); 25 | // System.err.println("Processing sentence ..."); 26 | // System.err.println(sen); 27 | break; 28 | } 29 | } 30 | if (sen == null) { 31 | nextSen = null; 32 | in.close(); 33 | } else { 34 | nextSen = sen; 35 | } 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/SRLOnlyCoNLL09Reader.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | 6 | import se.lth.cs.srl.corpus.Sentence; 7 | 8 | public class 
SRLOnlyCoNLL09Reader extends AbstractCoNLL09Reader { 9 | 10 | public SRLOnlyCoNLL09Reader(File file) { 11 | super(file); 12 | } 13 | 14 | @Override 15 | protected void readNextSentence() throws IOException { 16 | String str; 17 | Sentence sen = null; 18 | StringBuilder senBuffer = new StringBuilder(); 19 | while ((str = in.readLine()) != null) { 20 | if (!str.trim().equals("")) { 21 | senBuffer.append(str).append("\n"); 22 | } else { 23 | if (!senBuffer.toString().startsWith("_")) 24 | sen = Sentence.newSRLOnlySentence((NEWLINE_PATTERN 25 | .split(senBuffer.toString()))); 26 | break; 27 | } 28 | } 29 | if (sen == null) { 30 | nextSen = null; 31 | in.close(); 32 | } else { 33 | nextSen = sen; 34 | } 35 | 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/SentenceReader.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | import java.util.List; 4 | 5 | import se.lth.cs.srl.corpus.Sentence; 6 | 7 | public interface SentenceReader extends Iterable { 8 | 9 | List readAll(); 10 | 11 | public void close(); 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/io/SentenceWriter.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.io; 2 | 3 | //import se.lth.cs.srl.corpus.Corpus; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | 6 | public interface SentenceWriter { 7 | 8 | public void write(Sentence s); 9 | 10 | public void close(); 11 | 12 | public void specialwrite(Sentence s); 13 | 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/languages/AbstractDummyLanguage.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.languages; 2 | 3 | import java.util.Map; 4 | 
5 | import se.lth.cs.srl.corpus.Predicate; 6 | import se.lth.cs.srl.corpus.Word; 7 | import se.lth.cs.srl.options.FullPipelineOptions; 8 | 9 | /** 10 | * this class is just so we can run the pipeline without the srl stuff i.e., to 11 | * run the http interface of the anna pipeline without srl. 12 | * 13 | * @author anders 14 | * 15 | */ 16 | public abstract class AbstractDummyLanguage extends Language { 17 | 18 | @Override 19 | public String getDefaultSense(Predicate pred) { 20 | throw new Error("!"); 21 | } 22 | 23 | @Override 24 | public String getCoreArgumentLabelSequence(Predicate pred, 25 | Map proposition) { 26 | throw new Error("!"); 27 | } 28 | 29 | @Override 30 | public String getLexiconURL(Predicate pred) { 31 | throw new Error("!"); 32 | } 33 | 34 | @Override 35 | public String verifyLanguageSpecificModelFiles(FullPipelineOptions options) { 36 | return null; 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/languages/Chinese.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.languages; 2 | 3 | import is2.lemmatizer.Lemmatizer; 4 | 5 | import java.io.File; 6 | import java.io.IOException; 7 | import java.util.Map; 8 | import java.util.regex.Pattern; 9 | 10 | import se.lth.cs.srl.corpus.Predicate; 11 | import se.lth.cs.srl.corpus.Sentence; 12 | import se.lth.cs.srl.corpus.Word; 13 | import se.lth.cs.srl.options.FullPipelineOptions; 14 | import se.lth.cs.srl.preprocessor.SimpleChineseLemmatizer; 15 | //import se.lth.cs.srl.preprocessor.tokenization.StanfordChineseSegmenterWrapper; 16 | import se.lth.cs.srl.preprocessor.tokenization.Tokenizer; 17 | import se.lth.cs.srl.util.FileExistenceVerifier; 18 | 19 | public class Chinese extends Language { 20 | 21 | private static Pattern CALSPattern=Pattern.compile("^A0|A1|A2|A3|A4$"); 22 | @Override 23 | public String getCoreArgumentLabelSequence(Predicate pred,Map 
proposition) { 24 | Sentence sen=pred.getMySentence(); 25 | StringBuilder ret=new StringBuilder(); 26 | for(int i=1,size=sen.size();i the input corpus. assumed to be tokenized like CoNLL 09 data\n" 24 | + "-out the file to write output to (default out.txt)\n" 25 | + "-nopi skips the predicate identification\n" 26 | + "-tokenize implies the input is unsegmented, with one sentence per line, i.e. _not_ CoNLL09 format"; 27 | } 28 | 29 | @Override 30 | protected int trySubParseArg(String[] args, int ai) { 31 | if (args[ai].equals("-out")) { 32 | ai++; 33 | output = new File(args[ai]); 34 | ai++; 35 | } else if (args[ai].equals("-test")) { 36 | ai++; 37 | input = new File(args[ai]); 38 | ai++; 39 | } else if (args[ai].equals("-nopi")) { 40 | ai++; 41 | skipPI = true; 42 | } else if (args[ai].equals("-noai")) { 43 | ai++; 44 | skipAI = true; 45 | } else if (args[ai].equals("-desegment")) { // Not printed out in the 46 | // help 47 | // (getSubUsageOptions()), 48 | // don't think it needs to. 49 | // This is only 50 | // experimental. 51 | ai++; 52 | desegment = true; 53 | skipPI = false; // This won't be regarded anyway. It's not 54 | // applicable when the initial segmentation is lost. 
55 | } else if (args[ai].equals("-tokenize")) { 56 | ai++; 57 | super.loadPreprocessorWithTokenizer = true; 58 | skipPI = false; // Same as above 59 | desegment = false; 60 | } else if (args[ai].equals("-hybrid")) { 61 | hybrid = true; 62 | ai++; 63 | } else if (args[ai].equals("-external")) { 64 | external = true; 65 | ai++; 66 | } 67 | return ai; 68 | } 69 | 70 | @Override 71 | protected Class getIntendedEntryClass() { 72 | return CompletePipeline.class; 73 | } 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/pipeline/AbstractStep.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.pipeline; 2 | 3 | import java.io.IOException; 4 | import java.io.ObjectInputStream; 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | import java.util.zip.ZipFile; 8 | 9 | import se.lth.cs.srl.corpus.Sentence; 10 | import uk.ac.ed.inf.srl.features.FeatureSet; 11 | import uk.ac.ed.inf.srl.ml.Model; 12 | 13 | public abstract class AbstractStep implements PipelineStep { 14 | 15 | public static final Integer POSITIVE = 1; 16 | public static final Integer NEGATIVE = 0; 17 | 18 | protected FeatureSet featureSet; 19 | protected Map models; 20 | 21 | public AbstractStep(FeatureSet fs) { 22 | this.featureSet = fs; 23 | } 24 | 25 | public abstract void extractInstances(Sentence s); 26 | 27 | public abstract void parse(Sentence s); 28 | 29 | protected abstract String getModelFileName(); 30 | 31 | @Override 32 | public void readModels(ZipFile zipFile) throws IOException, 33 | ClassNotFoundException { 34 | models = new HashMap<>(); 35 | readModels(zipFile, models, getModelFileName()); 36 | } 37 | 38 | 39 | static void readModels(ZipFile zipFile, Map models, 40 | String filename) throws IOException, ClassNotFoundException { 41 | ObjectInputStream ois = new ObjectInputStream( 42 | zipFile.getInputStream(zipFile.getEntry(filename))); 43 | int 
numberOfModels = ois.readInt(); 44 | for (int i = 0; i < numberOfModels; ++i) { 45 | String POSPrefix = (String) ois.readObject(); 46 | Model m = (Model) ois.readObject(); 47 | models.put(POSPrefix, m); 48 | } 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/pipeline/PipelineStep.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.pipeline; 2 | 3 | import java.io.IOException; 4 | import java.util.zip.ZipFile; 5 | import java.util.zip.ZipOutputStream; 6 | 7 | import se.lth.cs.srl.corpus.Sentence; 8 | 9 | public interface PipelineStep { 10 | 11 | public void prepareLearning(); 12 | 13 | public void extractInstances(Sentence s); 14 | 15 | public void done(); 16 | 17 | public void train(); 18 | 19 | public void writeModels(ZipOutputStream zos) throws IOException; 20 | 21 | public void readModels(ZipFile zipFile) throws IOException, 22 | ClassNotFoundException; 23 | 24 | public void parse(Sentence s); 25 | 26 | public void prepareLearning(int i); 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/pipeline/Step.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.pipeline; 2 | 3 | public enum Step { 4 | pi, pd, ai, ac, /* po, ao */ 5 | } -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/preprocessor/CMDLineTokenizer.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.preprocessor; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.IOException; 6 | import java.io.InputStreamReader; 7 | 8 | import se.lth.cs.srl.languages.Language; 9 | import se.lth.cs.srl.languages.Language.L; 10 | import se.lth.cs.srl.preprocessor.tokenization.Tokenizer; 11 | 12 
| public class CMDLineTokenizer { 13 | 14 | public static void main(String[] args) throws IOException { 15 | String l = args[0]; 16 | File modelFile = args.length > 1 ? new File(args[1]) : null; 17 | L lang = null; 18 | try { 19 | lang = L.valueOf(l); 20 | } catch (Exception e) { 21 | System.err.println("Unknown language " + l + ", aborting."); 22 | System.exit(1); 23 | } 24 | Language.setLanguage(lang); 25 | Tokenizer tokenizer = Language.getLanguage().getTokenizer(modelFile); 26 | BufferedReader reader = new BufferedReader(new InputStreamReader( 27 | System.in, "UTF8")); 28 | String line; 29 | int senCount = 0; 30 | while ((line = reader.readLine()) != null) { 31 | senCount++; 32 | String[] tokens = tokenizer.tokenize(line); 33 | for (int i = 1; i < tokens.length; ++i) { 34 | StringBuilder sb = new StringBuilder(); 35 | sb.append(senCount).append('_').append(i).append('\t') 36 | .append(tokens[i]).append(COLUMNS); 37 | System.out.println(sb.toString()); 38 | } 39 | System.out.println(); 40 | } 41 | } 42 | 43 | static final String COLUMNS = "\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_\t_"; 44 | 45 | public static void usage() { 46 | System.err 47 | .println("Reads untokenized text on STDIN (one sentence per line), and writes it out in CoNLL09 format to STDOUT"); 48 | System.err.println("Usage: java -cp ... 
" 49 | + CMDLineTokenizer.class.getCanonicalName() 50 | + " [model-file]"); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/preprocessor/Preprocessor.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.preprocessor; 2 | 3 | import se.lth.cs.srl.corpus.StringInText; 4 | import se.lth.cs.srl.preprocessor.tokenization.Tokenizer; 5 | import is2.data.SentenceData09; 6 | 7 | public abstract class Preprocessor { 8 | protected Tokenizer tokenizer; 9 | public long tokenizeTime = 0; 10 | public long lemmatizeTime = 0; 11 | public long dpTime = 0; 12 | 13 | public abstract boolean hasParser(); 14 | 15 | public abstract StringBuilder getStatus(); 16 | 17 | protected abstract SentenceData09 preprocess(SentenceData09 sentence); 18 | 19 | public SentenceData09 preprocess(String[] forms) { 20 | SentenceData09 instance = new SentenceData09(); 21 | instance.init(forms); 22 | return preprocess(instance); 23 | } 24 | 25 | public String[] tokenize(String sentence) { 26 | synchronized (tokenizer) { 27 | long start = System.currentTimeMillis(); 28 | String[] words = tokenizer.tokenize(sentence); 29 | tokenizeTime += (System.currentTimeMillis() - start); 30 | return words; 31 | } 32 | } 33 | 34 | public StringInText[] tokenizeplus(String sentence) { 35 | synchronized (tokenizer) { 36 | long start = System.currentTimeMillis(); 37 | StringInText[] words = tokenizer.tokenizeplus(sentence); 38 | tokenizeTime += (System.currentTimeMillis() - start); 39 | return words; 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/preprocessor/SimpleChineseLemmatizer.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.preprocessor; 2 | 3 | import is2.data.SentenceData09; 4 | import is2.lemmatizer.Lemmatizer; 5 | 6 | public 
class SimpleChineseLemmatizer extends Lemmatizer { 7 | 8 | // @Override 9 | // public String[] getLemmas(String[] forms) { 10 | // if(true) 11 | // throw new 12 | // Error("This method should not be trusted. Fix the root token in accordance with is2.lemmatizer.Lemmatizer before using it."); 13 | // String[] ret=new String[forms.length]; //TODO, make sure to deal with the 14 | // root token properly. 15 | // ret[0]=""; 16 | // for(int i=1;i l = new ArrayList<>(); 50 | for (String s : tokenize(sentence)) { 51 | Word w = new Word(s); 52 | l.add(new StringInText(w.word(), w.beginPosition() + startpos, w 53 | .endPosition() + startpos)); 54 | } 55 | StringInText[] tok = new StringInText[l.size()]; 56 | // tok[0]=new StringInText(is2.io.CONLLReader09.ROOT,0,0); 57 | int i = 0; 58 | for (StringInText s : l) 59 | tok[i++] = s; 60 | 61 | startpos += (1 + sentence.length()); 62 | 63 | return tok; 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/preprocessor/tokenization/Tokenizer.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.preprocessor.tokenization; 2 | 3 | import se.lth.cs.srl.corpus.StringInText; 4 | 5 | public interface Tokenizer { 6 | 7 | /** 8 | * Tokenize a sentence. 
The returned array contains a root-token 9 | * 10 | * @param sentence 11 | * The sentence to tokenize 12 | * @return a root token, followed by the forms 13 | */ 14 | public abstract String[] tokenize(String sentence); 15 | 16 | public abstract StringInText[] tokenizeplus(String sentence); 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/preprocessor/tokenization/WhiteSpaceTokenizer.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.preprocessor.tokenization; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | import java.util.StringTokenizer; 6 | 7 | import se.lth.cs.srl.corpus.StringInText; 8 | 9 | public class WhiteSpaceTokenizer implements Tokenizer { 10 | 11 | @Override 12 | public String[] tokenize(String sentence) { 13 | StringTokenizer tokenizer = new StringTokenizer(sentence); 14 | String[] tokens = new String[tokenizer.countTokens() + 1]; 15 | int r = 0; 16 | tokens[r++] = is2.io.CONLLReader09.ROOT; 17 | while (tokenizer.hasMoreTokens()) 18 | tokens[r++] = tokenizer.nextToken(); 19 | return tokens; 20 | } 21 | 22 | public static void main(String[] args) { 23 | String t1 = "En gul bil körde hundratusen mil."; 24 | String t2 = "Leonardos fullständiga namn var Leonardo di ser Piero da Vinci."; 25 | String t3 = "Genom skattereformen införs individuell beskattning (särbeskattning) av arbetsinkomster."; 26 | String t4 = "\"Oh, no,\" she's saying, \"our $400 blender can't handle something this hard!\""; 27 | String[] tests = { t1, t2, t3, t4 }; 28 | for (String test : tests) { 29 | WhiteSpaceTokenizer tokenizer = new WhiteSpaceTokenizer(); 30 | String[] tokens = tokenizer.tokenize(test); 31 | for (String token : tokens) 32 | System.out.println(token); 33 | System.out.println(); 34 | } 35 | } 36 | 37 | @Override 38 | public StringInText[] tokenizeplus(String sentence) { 39 | int offset = 0; 40 | StringTokenizer 
tokenizer = new StringTokenizer(sentence); 41 | List l = new ArrayList<>(); 42 | 43 | while (tokenizer.hasMoreTokens()) { 44 | String s = tokenizer.nextToken(); 45 | l.add(new StringInText(s, offset, offset + s.length())); 46 | offset += (1 + s.length()); 47 | } 48 | StringInText[] tok = new StringInText[l.size() + 1]; 49 | tok[0] = new StringInText(is2.io.CONLLReader09.ROOT, 0, 0); 50 | int i = 1; 51 | for (StringInText s : l) 52 | tok[i++] = s; 53 | 54 | return tok; 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/preprocessor/tokenization/exner/SwedishTokenizer.java: -------------------------------------------------------------------------------- 1 | /** 2 | * SweNLP is a framework for performing parallel processing of text. 3 | * Copyright � 2011 Peter Exner 4 | * 5 | * This file is part of SweNLP. 6 | * 7 | * SweNLP is free software: you can redistribute it and/or modify 8 | * it under the terms of the GNU General Public License as published by 9 | * the Free Software Foundation, either version 3 of the License, or 10 | * (at your option) any later version. 11 | * 12 | * SweNLP is distributed in the hope that it will be useful, 13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | * GNU General Public License for more details. 16 | * 17 | * You should have received a copy of the GNU General Public License 18 | * along with SweNLP. If not, see . 
19 | */ 20 | 21 | package se.lth.cs.srl.preprocessor.tokenization.exner; 22 | 23 | import java.io.IOException; 24 | import java.io.StringReader; 25 | import java.io.UnsupportedEncodingException; 26 | import java.nio.charset.Charset; 27 | import java.util.ArrayList; 28 | 29 | public class SwedishTokenizer { 30 | public ArrayList tokenize(String sentence, Charset charset) { 31 | ArrayList tokens = new ArrayList<>(); 32 | 33 | try { 34 | sentence = preProcessSentence(sentence); 35 | 36 | Tokenizer swedishTokenizer = new Tokenizer(new StringReader( 37 | new String(sentence.getBytes(charset), 38 | Charset.forName("UTF-8")))); 39 | 40 | while (swedishTokenizer.getNextToken() >= 0) { 41 | tokens.add(swedishTokenizer.yytext()); 42 | } 43 | 44 | } catch (UnsupportedEncodingException e) { 45 | // TODO Auto-generated catch block 46 | e.printStackTrace(); 47 | } catch (IOException e) { 48 | // TODO Auto-generated catch block 49 | e.printStackTrace(); 50 | } 51 | 52 | return tokens; 53 | } 54 | 55 | private static String preProcessSentence(String sentence) { 56 | // Done to make the flex rules match 57 | 58 | sentence = sentence.replaceAll(" \\.", "\\."); 59 | sentence = sentence.replaceAll(" \\)", "\\)"); 60 | 61 | return sentence; 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/util/BohnetHelper.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.util; 2 | 3 | import is2.lemmatizer.Lemmatizer; 4 | import is2.parser.Parser; 5 | import is2.tag.Tagger; 6 | 7 | import java.io.File; 8 | import java.io.FileNotFoundException; 9 | import java.io.IOException; 10 | 11 | import se.lth.cs.srl.options.FullPipelineOptions; 12 | import se.lth.cs.srl.options.Options; 13 | 14 | public class BohnetHelper { 15 | 16 | public static Lemmatizer getLemmatizer(File modelFile) 17 | throws FileNotFoundException, IOException { 18 | String[] argsL = { "-model", 
modelFile.toString() }; 19 | return new Lemmatizer(modelFile.toString(), false); 20 | 21 | // new is2.lemmatizer.Options(argsL)); 22 | } 23 | 24 | public static Tagger getTagger(File modelFile) { 25 | String[] argsT = { "-model", modelFile.toString() }; 26 | return new Tagger(modelFile.toString()); 27 | // new is2.tag.Options(argsT)); 28 | } 29 | 30 | public static is2.mtag.Tagger getMTagger(File modelFile) throws IOException { 31 | String[] argsMT = { "-model", modelFile.toString() }; 32 | return new is2.mtag.Tagger(modelFile.toString()); 33 | // new is2.mtag.Options(argsMT)); 34 | } 35 | 36 | public static Parser getParser(File modelFile) { 37 | String[] argsDP = { 38 | "-model", 39 | modelFile.toString(), 40 | "-cores", 41 | Integer.toString(Math.min(Options.cores, 42 | FullPipelineOptions.cores)) }; 43 | return new Parser(modelFile.toString()); 44 | // new is2.parser.Options(argsDP)); 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/util/ChineseDesegmenter.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.util; 2 | 3 | import java.io.File; 4 | import java.io.FileNotFoundException; 5 | import java.io.FileOutputStream; 6 | import java.io.PrintStream; 7 | 8 | import se.lth.cs.srl.corpus.Sentence; 9 | import se.lth.cs.srl.io.DepsOnlyCoNLL09Reader; 10 | import se.lth.cs.srl.io.SentenceReader; 11 | 12 | public class ChineseDesegmenter { 13 | 14 | public static String desegment(String[] forms) { 15 | StringBuilder ret = new StringBuilder(); 16 | for (int i = 1; i < forms.length; ++i) 17 | ret.append(forms[i]); 18 | return ret.toString(); 19 | } 20 | 21 | public static void main(String[] args) throws FileNotFoundException { 22 | 23 | String inputFile = 24 | // "/home/anders/corpora/conll09/CoNLL2009-ST-Chinese-train.txt"; 25 | "/home/anders/corpora/conll09/chi/CoNLL2009-ST-evaluation-Chinese.txt"; 26 | String outputFile = 
"chi-desegmented.out"; 27 | boolean separateLines = true; 28 | if (args.length > 0) 29 | inputFile = args[0]; 30 | if (args.length > 1) 31 | outputFile = args[1]; 32 | if (args.length > 2) 33 | separateLines = Boolean.parseBoolean(args[2]); // Whether to print 34 | // newlines between 35 | // sentences. 36 | File input = new File(inputFile); 37 | File output = new File(outputFile); 38 | SentenceReader reader = new DepsOnlyCoNLL09Reader(input); 39 | PrintStream out = new PrintStream(new FileOutputStream(output)); 40 | for (Sentence s : reader) { 41 | String desegmented = desegment(s.getFormArray()); 42 | if (separateLines) 43 | out.println(desegmented); 44 | else 45 | out.print(desegmented); 46 | } 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/util/FileExistenceVerifier.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.util; 2 | 3 | import java.io.File; 4 | 5 | import se.lth.cs.srl.languages.Language; 6 | import se.lth.cs.srl.options.FullPipelineOptions; 7 | 8 | public class FileExistenceVerifier { 9 | 10 | /** 11 | * Checks if a file can be exists, and if it can be read. 12 | * 13 | * @param files 14 | * vararg number of files to check 15 | * @return null if all is good, otherwise a String containing the error 16 | * message. 17 | */ 18 | public static String verifyFiles(File... 
files) { 19 | StringBuilder sb = new StringBuilder(); 20 | for (File f : files) { 21 | if (f == null || !f.exists()) { 22 | sb.append("File " + f + " does not exist.\n"); 23 | } 24 | if (f == null || !f.canRead()) { 25 | sb.append("File " + f + " can not be read.\n"); 26 | } 27 | } 28 | if (sb.length() == 0) 29 | return null; 30 | else 31 | return sb.toString(); 32 | } 33 | 34 | public static String verifyCompletePipelineAlwaysNecessaryFiles( 35 | FullPipelineOptions options) { 36 | return verifyFiles(/* options.tagger, options.parser,*/ options.srl); 37 | } 38 | 39 | public static String verifyCompletePipelineAllNecessaryModelFiles( 40 | FullPipelineOptions options) { 41 | String error1 = verifyCompletePipelineAlwaysNecessaryFiles(options); 42 | String error2 = Language.getLanguage() 43 | .verifyLanguageSpecificModelFiles(options); 44 | if (error1 != null) { 45 | if (error2 != null) { 46 | return error1 + error2; 47 | } else { 48 | return error1; 49 | } 50 | } else if (error2 != null) { 51 | return error2; 52 | } else { 53 | return null; 54 | } 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/util/Relation.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.util; 2 | 3 | public class Relation { 4 | public int head; 5 | public int dependent; 6 | public String label; 7 | 8 | public Relation(int head, int dependent, String label) { 9 | this.head = head; 10 | this.dependent = dependent; 11 | this.label = label; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/util/TurboParser.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.util; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.FileReader; 5 | import java.io.IOException; 6 | 7 | import se.lth.cs.srl.corpus.Sentence; 8 | 
import se.lth.cs.srl.corpus.Word; 9 | 10 | public class TurboParser { 11 | 12 | BufferedReader br; 13 | 14 | public TurboParser(String file) { 15 | try { 16 | br = new BufferedReader(new FileReader(file)); 17 | 18 | } catch (IOException e) { 19 | e.printStackTrace(); 20 | System.exit(1); 21 | } 22 | } 23 | 24 | public void overwriteParse(Sentence s) { 25 | try { 26 | // skip ROOT (i==0); 27 | for (int i = 1; i < s.size(); i++) { 28 | Word w = s.get(i); 29 | String line = br.readLine(); 30 | // if current line is blank (end of last sentence), read next 31 | // line 32 | if (line.equals("")) 33 | line = br.readLine(); 34 | 35 | String[] parts = line.split("\t"); 36 | // sanity check 37 | if (!parts[1].toLowerCase().equals(w.getForm().toLowerCase())) { 38 | System.err 39 | .println("WARNING: different normalization applied? (" 40 | + parts[1] + " vs. " + w.getForm() + ")"); 41 | w.setLemma(w.getForm().replaceAll("[0-9]", "D")); 42 | } 43 | 44 | // CoNLL-X 45 | /**/w.setPOS(parts[3]); 46 | w.setHeadId(Integer.parseInt(parts[6])); 47 | w.setDeprel(parts[7]);/**/ 48 | 49 | // CoNLL-09 50 | /** 51 | * w.setPOS(parts[4]); w.setHeadId(Integer.parseInt(parts[8])); 52 | * w.setDeprel(parts[10]);/ 53 | **/ 54 | 55 | } 56 | s.buildDependencyTree(); 57 | 58 | } catch (IOException e) { 59 | e.printStackTrace(); 60 | System.exit(1); 61 | } 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/se/lth/cs/srl/util/Util.java: -------------------------------------------------------------------------------- 1 | package se.lth.cs.srl.util; 2 | 3 | public class Util { 4 | 5 | public static String insertCommas(long l) { 6 | StringBuilder ret = new StringBuilder(Long.toString(l)); 7 | ret.reverse(); 8 | for (int i = 3; i < ret.length(); i += 4) { 9 | if (i + 1 <= ret.length()) 10 | ret.insert(i, ","); 11 | } 12 | return ret.reverse().toString(); 13 | } 14 | 15 | } 16 | 
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/AnySetFeature.java:
--------------------------------------------------------------------------------
package uk.ac.ed.inf.srl.features;

import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.corpus.Word.WordData;
import uk.ac.ed.inf.srl.features.FeatureName;
import uk.ac.ed.inf.srl.features.SetFeature;

/**
 * Set-valued feature over one attribute of the words below an argument.
 */
public class AnySetFeature extends SetFeature {
    private static final long serialVersionUID = 1L;

    WordData attr;

    /**
     * @param name
     *            the feature name
     * @param attr
     *            the word attribute to read
     * @param usedForPredicateIdentification
     *            not used; the superclass is always called with
     *            {@code true, false}
     * @param POSPrefix
     *            passed through to the superclass
     */
    protected AnySetFeature(FeatureName name, WordData attr,
            boolean usedForPredicateIdentification, String POSPrefix) {
        super(name, true, false, POSPrefix);
        this.attr = attr;
    }

    /** Feature strings for the word at {@code argIndex} in {@code s}. */
    @Override
    public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
        Word argument = s.get(argIndex);
        return makeFeatureStrings(argument);
    }

    /** Feature strings for {@code arg}; the predicate is not consulted. */
    @Override
    public String[] getFeatureStrings(Predicate pred, Word arg) {
        return makeFeatureStrings(arg);
    }

    /**
     * Collects the attribute value of every child of {@code w}.
     *
     * NOTE(review): the array is sized by the span of {@code w} but filled
     * from its children, so trailing entries stay null whenever the span is
     * larger than the child list - confirm whether the span was meant to be
     * iterated.
     */
    private String[] makeFeatureStrings(Word w) {
        String[] values = new String[w.getSpan().size()];
        int next = 0;
        for (Word dependent : w.getChildren()) {
            values[next] = dependent.getAttr(attr);
            next++;
        }
        return values;
    }

    /** Registers the attribute value of every word of the sentence. */
    @Override
    protected void performFeatureExtraction(Sentence s, boolean allWords) {
        for (Word word : s) {
            addMap(word.getAttr(attr));
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ArgDependentAttrFeature.java:
--------------------------------------------------------------------------------
1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import
se.lth.cs.srl.corpus.Word; 9 | import se.lth.cs.srl.corpus.Word.WordData; 10 | import uk.ac.ed.inf.srl.features.AttrFeature; 11 | import uk.ac.ed.inf.srl.features.FeatureName; 12 | import uk.ac.ed.inf.srl.features.TargetWord; 13 | 14 | public class ArgDependentAttrFeature extends AttrFeature { 15 | private static final long serialVersionUID = 1L; 16 | 17 | protected ArgDependentAttrFeature(FeatureName name, WordData attr, 18 | TargetWord tw, String POSPrefix) { 19 | super(name, attr, tw, true, false, POSPrefix); 20 | } 21 | 22 | @Override 23 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 24 | for (Predicate p : s.getPredicates()) { 25 | if (doExtractFeatures(p)) 26 | for (Word arg : p.getArgMap().keySet()) { 27 | for (Word w : arg.getSpan()) 28 | addMap(w.getAttr(attr)); 29 | } 30 | } 31 | } 32 | 33 | @Override 34 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 35 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 36 | if (w == null) 37 | return null; 38 | else 39 | return w.getAttr(attr); 40 | } 41 | 42 | @Override 43 | public String getFeatureString(Predicate pred, Word arg) { 44 | Word w = wordExtractor.getWord(pred, arg); 45 | if (w == null) 46 | return null; 47 | else 48 | return w.getAttr(attr); 49 | } 50 | 51 | @Override 52 | public void addFeatures(Sentence s, Collection indices, 53 | Map nonbinFeats, int predIndex, int argIndex, 54 | Integer offset, boolean allWords) { 55 | addFeatures(indices, getFeatureString(s, predIndex, argIndex), offset, 56 | allWords); 57 | 58 | } 59 | 60 | @Override 61 | public void addFeatures(Collection indices, 62 | Map nonbinFeats, Predicate pred, Word arg, 63 | Integer offset, boolean allWords) { 64 | addFeatures(indices, getFeatureString(pred, arg), offset, allWords); 65 | } 66 | 67 | private void addFeatures(Collection indices, String featureString, 68 | Integer offset, boolean allWords) { 69 | if (featureString == null) 70 | return; 71 | Integer i = 
indexOf(featureString); 72 | if (i != -1 && (allWords || i < predMaxIndex)) 73 | indices.add(i + offset); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/ArgDependentBrown.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.ArgDependentAttrFeature; 7 | import uk.ac.ed.inf.srl.features.FeatureName; 8 | import uk.ac.ed.inf.srl.features.TargetWord; 9 | import se.lth.cs.srl.util.BrownCluster; 10 | import se.lth.cs.srl.util.BrownCluster.ClusterVal; 11 | 12 | public class ArgDependentBrown extends ArgDependentAttrFeature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | private BrownCluster bc; 16 | private ClusterVal cv; 17 | 18 | protected ArgDependentBrown(FeatureName name, TargetWord tw, 19 | String POSPrefix, BrownCluster bc, ClusterVal cv) { 20 | super(name, null, tw, POSPrefix); 21 | this.bc = bc; 22 | this.cv = cv; 23 | } 24 | 25 | @Override 26 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 27 | for (Predicate p : s.getPredicates()) { 28 | if (doExtractFeatures(p)) 29 | for (Word arg : p.getArgMap().keySet()) { 30 | Word w = wordExtractor.getWord(null, arg); 31 | if (w != null) 32 | addMap(bc.getValue(w.getForm(), cv)); 33 | } 34 | } 35 | } 36 | 37 | @Override 38 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 39 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 40 | if (w == null) 41 | return null; 42 | else 43 | return bc.getValue(w.getForm(), cv); 44 | 45 | } 46 | 47 | @Override 48 | public String getFeatureString(Predicate pred, Word arg) { 49 | Word w = wordExtractor.getWord(pred, arg); 50 | if (w == null) 51 | return null; 52 | else 53 | return 
bc.getValue(w.getForm(), cv); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/ArgDependentFeatsFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import uk.ac.ed.inf.srl.features.FeatsFeature; 10 | import uk.ac.ed.inf.srl.features.FeatureName; 11 | import uk.ac.ed.inf.srl.features.TargetWord; 12 | import se.lth.cs.srl.languages.Language; 13 | 14 | public class ArgDependentFeatsFeature extends FeatsFeature { 15 | private static final long serialVersionUID = 1L; 16 | 17 | protected ArgDependentFeatsFeature(FeatureName name, TargetWord tw, 18 | String POSPrefix) { 19 | super(name, tw, true, false, POSPrefix); 20 | } 21 | 22 | @Override 23 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) { 24 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 25 | if (w == null) 26 | return null; 27 | else 28 | return Language.getLanguage().getFeatSplitPattern() 29 | .split(w.getFeats()); 30 | } 31 | 32 | @Override 33 | public String[] getFeatureStrings(Predicate pred, Word arg) { 34 | Word w = wordExtractor.getWord(pred, arg); 35 | if (w == null) 36 | return null; 37 | else 38 | return Language.getLanguage().getFeatSplitPattern() 39 | .split(w.getFeats()); 40 | } 41 | 42 | @Override 43 | public void addFeatures(Sentence s, Collection indices, 44 | Map nonbinFeats, int predIndex, int argIndex, 45 | Integer offset, boolean allWords) { 46 | addFeatures(indices, getFeatureStrings(s, predIndex, argIndex), offset, 47 | allWords); 48 | 49 | } 50 | 51 | @Override 52 | public void addFeatures(Collection indices, 53 | Map nonbinFeats, Predicate pred, Word arg, 54 | Integer offset, boolean 
allWords) { 55 | addFeatures(indices, getFeatureStrings(pred, arg), offset, allWords); 56 | } 57 | 58 | private void addFeatures(Collection indices, String[] values, 59 | Integer offset, boolean allWords) { 60 | if (values == null) 61 | return; 62 | for (String v : values) { 63 | Integer i = indexOf(v); 64 | if (i != -1 && (allWords || i < predMaxIndex)) 65 | indices.add(i + offset); 66 | } 67 | } 68 | 69 | @Override 70 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 71 | for (Predicate p : s.getPredicates()) { 72 | if (doExtractFeatures(p)) 73 | for (Word arg : p.getArgMap().keySet()) { 74 | Word w = wordExtractor.getWord(null, arg); 75 | if (w == null) 76 | continue; 77 | for (String v : Language.getLanguage() 78 | .getFeatSplitPattern().split(w.getFeats())) 79 | addMap(v); 80 | } 81 | } 82 | 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/AttrFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Word.WordData; 4 | import uk.ac.ed.inf.srl.features.FeatureName; 5 | import uk.ac.ed.inf.srl.features.SingleFeature; 6 | import uk.ac.ed.inf.srl.features.TargetWord; 7 | import uk.ac.ed.inf.srl.features.WordExtractor; 8 | 9 | public abstract class AttrFeature extends SingleFeature { 10 | private static final long serialVersionUID = 1; 11 | 12 | protected WordData attr; 13 | protected WordExtractor wordExtractor; 14 | 15 | protected AttrFeature(FeatureName name, WordData attr, TargetWord tw, 16 | boolean includeArgs, boolean usedForPredicateIdentification, 17 | String POSPrefix) { 18 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix); 19 | this.attr = attr; 20 | this.wordExtractor = WordExtractor.getExtractor(tw); 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- 
/src/main/java/uk/ac/ed/inf/srl/features/ChildSetFeature.java:
--------------------------------------------------------------------------------
package uk.ac.ed.inf.srl.features;

import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.corpus.Word.WordData;
import uk.ac.ed.inf.srl.features.FeatureName;
import uk.ac.ed.inf.srl.features.SetFeature;

/**
 * Set-valued feature made of one attribute of each child of the predicate.
 */
public class ChildSetFeature extends SetFeature {
    private static final long serialVersionUID = 1L;

    WordData attr;

    /**
     * @param name
     *            the feature name
     * @param attr
     *            the word attribute read from every child
     * @param usedForPredicateIdentification
     *            passed through to the superclass
     * @param POSPrefix
     *            passed through to the superclass
     */
    protected ChildSetFeature(FeatureName name, WordData attr,
            boolean usedForPredicateIdentification, String POSPrefix) {
        super(name, false, usedForPredicateIdentification, POSPrefix);
        this.attr = attr;
    }

    /** Feature strings for the word at {@code predIndex}; the argument is ignored. */
    @Override
    public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) {
        Word predicateWord = s.get(predIndex);
        return makeFeatureStrings(predicateWord);
    }

    /** Feature strings for {@code pred}; the argument is ignored. */
    @Override
    public String[] getFeatureStrings(Predicate pred, Word arg) {
        return makeFeatureStrings(pred);
    }

    /** Returns the attribute value of every child of {@code pred}, in iteration order. */
    private String[] makeFeatureStrings(Word pred) {
        String[] values = new String[pred.getChildren().size()];
        int next = 0;
        for (Word dependent : pred.getChildren()) {
            values[next] = dependent.getAttr(attr);
            next++;
        }
        return values;
    }

    /**
     * Registers the child attribute values of either every non-root word
     * ({@code allWords}) or every predicate of the sentence, restricted to
     * the words accepted by {@code doExtractFeatures}.
     */
    @Override
    protected void performFeatureExtraction(Sentence s, boolean allWords) {
        if (!allWords) {
            for (Predicate pred : s.getPredicates()) {
                if (doExtractFeatures(pred))
                    registerChildAttributes(pred);
            }
            return;
        }
        // Index 0 is skipped, as in the other sentence loops of this project.
        int size = s.size();
        for (int i = 1; i < size; ++i) {
            if (doExtractFeatures(s.get(i)))
                registerChildAttributes(s.get(i));
        }
    }

    /** Adds the attribute value of every child of {@code parent} to the feature map. */
    private void registerChildAttributes(Word parent) {
        for (Word dependent : parent.getChildren()) {
            addMap(dependent.getAttr(attr));
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/uk/ac/ed/inf/srl/features/ContinuousArgDependentAttrFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import se.lth.cs.srl.corpus.Word.WordData; 10 | import uk.ac.ed.inf.srl.features.ContinuousAttrFeature; 11 | import uk.ac.ed.inf.srl.features.FeatureName; 12 | import uk.ac.ed.inf.srl.features.TargetWord; 13 | 14 | public class ContinuousArgDependentAttrFeature extends ContinuousAttrFeature { 15 | private static final long serialVersionUID = 1L; 16 | 17 | protected ContinuousArgDependentAttrFeature(FeatureName name, 18 | WordData attr, TargetWord tw, String POSPrefix) { 19 | super(name, attr, tw, true, false, POSPrefix); 20 | } 21 | 22 | @Override 23 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 24 | // for(Predicate p:s.getPredicates()){ 25 | // if(doExtractFeatures(p)) 26 | // for(Word arg:p.getArgMap().keySet()){ 27 | // Word w=wordExtractor.getWord(null, arg); 28 | // if(w!=null) 29 | // addMap(w.getAttr(attr)); 30 | // } 31 | // } 32 | } 33 | 34 | @Override 35 | public Double getFeatureValue(Sentence s, int predIndex, int argIndex) { 36 | return 0.0; 37 | } 38 | 39 | @Override 40 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 41 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 42 | if (w == null) 43 | return null; 44 | else 45 | return w.getAttr(attr); 46 | } 47 | 48 | @Override 49 | public Double getFeatureValue(Predicate pred, Word arg) { 50 | return 0.0; 51 | } 52 | 53 | @Override 54 | public String getFeatureString(Predicate pred, Word arg) { 55 | Word w = wordExtractor.getWord(pred, arg); 56 | if (w == null) 57 | return null; 58 | else 59 | return w.getAttr(attr); 60 | } 61 | 62 | @Override 63 | public 
void addFeatures(Sentence s, Collection indices, 64 | Map nonbinFeats, int predIndex, int argIndex, 65 | Integer offset, boolean allWords) { 66 | addFeatures(nonbinFeats, getFeatureString(s, predIndex, argIndex), 67 | getFeatureValue(s, predIndex, argIndex), offset, allWords); 68 | 69 | } 70 | 71 | @Override 72 | public void addFeatures(Collection indices, 73 | Map nonbinFeats, Predicate pred, Word arg, 74 | Integer offset, boolean allWords) { 75 | addFeatures(nonbinFeats, getFeatureString(pred, arg), 76 | getFeatureValue(pred, arg), offset, allWords); 77 | } 78 | 79 | private void addFeatures(Map nonbinFeats, 80 | String featureString, Double val, Integer offset, boolean allWords) { 81 | if (featureString == null) 82 | return; 83 | Integer i = indexOf(featureString); 84 | if (i != -1 && (allWords || i < predMaxIndex)) 85 | nonbinFeats.put(i + offset, val); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/ContinuousAttrFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Word.WordData; 4 | import uk.ac.ed.inf.srl.features.ContinuousFeature; 5 | import uk.ac.ed.inf.srl.features.FeatureName; 6 | import uk.ac.ed.inf.srl.features.TargetWord; 7 | import uk.ac.ed.inf.srl.features.WordExtractor; 8 | 9 | public abstract class ContinuousAttrFeature extends ContinuousFeature { 10 | private static final long serialVersionUID = 1; 11 | 12 | protected WordData attr; 13 | protected WordExtractor wordExtractor; 14 | 15 | protected ContinuousAttrFeature(FeatureName name, WordData attr, 16 | TargetWord tw, boolean includeArgs, 17 | boolean usedForPredicateIdentification, String POSPrefix) { 18 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix); 19 | this.attr = attr; 20 | this.wordExtractor = WordExtractor.getExtractor(tw); 21 | } 22 | 23 | } 24 | 
-------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/ContinuousFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import uk.ac.ed.inf.srl.features.Feature; 10 | import uk.ac.ed.inf.srl.features.FeatureName; 11 | 12 | public abstract class ContinuousFeature extends Feature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | protected ContinuousFeature(FeatureName name, boolean includeArgs, 16 | boolean usedForPredicateIdentification, String POSPrefix) { 17 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix); 18 | indexcounter = 2; 19 | } 20 | 21 | @Override 22 | public void addFeatures(Sentence s, Collection indices, 23 | Map nonbinFeats, int predIndex, int argIndex, 24 | Integer offset, boolean allWords) { 25 | // Integer i=indexOf(getFeatureString(s,predIndex,argIndex)); 26 | Double d = getFeatureValue(s, predIndex, argIndex); 27 | 28 | // if(i!=-1 && (allWords || i indices, 35 | Map nonbinFeats, Predicate pred, Word arg, 36 | Integer offset, boolean allWords) { 37 | // Integer i=indexOf(getFeatureString(pred,arg)); 38 | Double d = getFeatureValue(pred, arg); 39 | 40 | // if(i!=-1 && (allWords || i children) { 33 | switch (children.size()) { 34 | case 0: 35 | return " "; // This is the String corresponding to 0 children. Yet 36 | // this should not be ignored as a feature which it 37 | // would by the addMap() method if it would return the 38 | // empty string or null. Not really sure if this is 39 | // optimal, or this feature should be ignored. 
40 | case 1: 41 | return children.iterator().next().getDeprel(); 42 | default: 43 | Word[] sortedChildren = children.toArray(new Word[0]); 44 | Arrays.sort(sortedChildren, s.wordComparator); 45 | StringBuilder ret = new StringBuilder(sortedChildren[0].getDeprel()); 46 | for (int i = 1, size = sortedChildren.length; i < size; ++i) 47 | ret.append(SEPARATOR).append(sortedChildren[i].getDeprel()); 48 | return ret.toString(); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/DependencyCPathEmbedding.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.DependencyPathEmbedding; 7 | import uk.ac.ed.inf.srl.features.FeatureName; 8 | import uk.ac.ed.inf.srl.features.TargetWord; 9 | import uk.ac.ed.inf.srl.lstm.DataConverter; 10 | import uk.ac.ed.inf.srl.lstm.EmbeddingNetwork; 11 | 12 | public class DependencyCPathEmbedding extends DependencyPathEmbedding { 13 | private static final long serialVersionUID = 1L; 14 | 15 | protected DependencyCPathEmbedding(FeatureName name, TargetWord tw, 16 | String POSPrefix, boolean comp, EmbeddingNetwork net, DataConverter dc, 17 | int dim) { 18 | super(name, tw, POSPrefix, comp, net, dc, dim); 19 | 20 | for(int i=0; i values = new ArrayList<>(); 24 | values.addAll(f.indices.keySet()); 25 | Collections.sort(values); 26 | for (String value : values) { 27 | System.out.println(value + " - " + f.indexOf(value)); 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/FeatsFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import 
uk.ac.ed.inf.srl.features.FeatureName; 4 | import uk.ac.ed.inf.srl.features.SetFeature; 5 | import uk.ac.ed.inf.srl.features.TargetWord; 6 | import uk.ac.ed.inf.srl.features.WordExtractor; 7 | 8 | public abstract class FeatsFeature extends SetFeature { 9 | private static final long serialVersionUID = 1L; 10 | 11 | protected WordExtractor wordExtractor; 12 | 13 | protected FeatsFeature(FeatureName name, TargetWord tw, 14 | boolean includeArgs, boolean usedForPredicateIdentification, 15 | String POSPrefix) { 16 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix); 17 | wordExtractor = WordExtractor.getExtractor(tw); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/FeatureSet.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Arrays; 4 | import java.util.HashMap; 5 | import java.util.List; 6 | import java.util.Map; 7 | 8 | import uk.ac.ed.inf.srl.features.Feature; 9 | 10 | /** 11 | * A feature set is basically a Map>, where String is a 12 | * prefix of a POS-tag, and List is the features used by the classifier 13 | * for that prefix. 14 | * 15 | * @author anders bjorkelund 16 | * 17 | */ 18 | 19 | public class FeatureSet extends HashMap> { 20 | private static final long serialVersionUID = 1L; 21 | /** 22 | * The prefixes of this Map, sorted in reverse order, i.e. longer prefixes 23 | * go before shorter. And the empty string comes last. 24 | */ 25 | public final String[] POSPrefixes; 26 | 27 | /** 28 | * Constructs a featureset based on the argument. The POSPrefixes are 29 | * extracted are sorted in reverse order on construction. 
/**
 * Helper that maps an integer onto a small set of string buckets.
 */
public class NumFeature {
    /** Bucket boundaries, checked from the largest magnitude downwards. */
    private static final int[] THRESHOLDS = { 20, 10, 5 };

    /**
     * Bins {@code i}: values whose magnitude reaches 20, 10 or 5 collapse to
     * the signed threshold ("20"/"-20", "10"/"-10", "5"/"-5"); values in
     * {@code -4..4} are returned as their own decimal string.
     *
     * @param i the number to bin
     * @return the bucket label
     */
    public static String bin(int i) {
        for (int threshold : THRESHOLDS) {
            if (i <= -threshold) {
                return Integer.toString(-threshold);
            }
            if (i >= threshold) {
                return Integer.toString(threshold);
            }
        }
        return Integer.toString(i);
    }
}
(doExtractFeatures(p)) 25 | for (Word arg : p.getArgMap().keySet()) { 26 | addMap(p.getArgMap().get(arg)); 27 | } 28 | } 29 | } 30 | 31 | @Override 32 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 33 | return ((Predicate) s.get(predIndex)).getArgMap().get(/* 34 | * wordExtractor.getWord 35 | * (s, predIndex, 36 | * argIndex) 37 | */s.get(argIndex)); 38 | } 39 | 40 | @Override 41 | public String getFeatureString(Predicate pred, Word arg) { 42 | return pred.getArgMap().get(/* wordExtractor.getWord(pred, */arg); 43 | } 44 | 45 | @Override 46 | public void addFeatures(Sentence s, Collection indices, 47 | Map nonbinFeats, int predIndex, int argIndex, 48 | Integer offset, boolean allWords) { 49 | addFeatures(indices, getFeatureString(s, predIndex, argIndex), offset, 50 | allWords); 51 | } 52 | 53 | @Override 54 | public void addFeatures(Collection indices, 55 | Map nonbinFeats, Predicate pred, Word arg, 56 | Integer offset, boolean allWords) { 57 | addFeatures(indices, getFeatureString(pred, arg), offset, allWords); 58 | } 59 | 60 | private void addFeatures(Collection indices, String featureString, 61 | Integer offset, boolean allWords) { 62 | if (featureString == null) { 63 | return; 64 | } 65 | Integer i = indexOf(featureString); 66 | if (i != -1 && (allWords || i < predMaxIndex)) 67 | indices.add(i + offset); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PathFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.List; 4 | 5 | import se.lth.cs.srl.corpus.Predicate; 6 | import se.lth.cs.srl.corpus.Sentence; 7 | import se.lth.cs.srl.corpus.Word; 8 | import se.lth.cs.srl.corpus.Word.WordData; 9 | import uk.ac.ed.inf.srl.features.FeatureName; 10 | import uk.ac.ed.inf.srl.features.SingleFeature; 11 | 12 | public class PathFeature extends 
SingleFeature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | private static final String UP = "0"; 16 | private static final String DOWN = "1"; 17 | 18 | private boolean consider_deptree; 19 | private WordData attr; 20 | 21 | protected PathFeature(FeatureName name, WordData attr, 22 | boolean consider_deptree, String POSPrefix) { 23 | super(name, true, false, POSPrefix); 24 | this.attr = attr; 25 | this.consider_deptree = consider_deptree; 26 | } 27 | 28 | @Override 29 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 30 | for (Predicate pred : s.getPredicates()) { 31 | if (doExtractFeatures(pred)) 32 | for (Word arg : pred.getArgMap().keySet()) { 33 | addMap(getFeatureString(pred, arg)); 34 | } 35 | } 36 | } 37 | 38 | @Override 39 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 40 | return makeFeatureString(s.get(predIndex), s.get(argIndex)); 41 | } 42 | 43 | @Override 44 | public String getFeatureString(Predicate pred, Word arg) { 45 | return makeFeatureString(pred, arg); 46 | } 47 | 48 | public String makeFeatureString(Word pred, Word arg) { 49 | StringBuilder ret = new StringBuilder(); 50 | if (consider_deptree) { 51 | return ret.append(makeDepBasedFeatureString(pred, arg)).toString(); 52 | } 53 | 54 | ret.append("NODEP"); 55 | boolean up = true; 56 | if (pred.getIdx() < arg.getIdx()) 57 | up = false; 58 | 59 | if (Math.abs(pred.getIdx() - arg.getIdx()) == 0) 60 | return " "; 61 | 62 | Sentence s = pred.getMySentence(); 63 | for (int i = up ? arg.getIdx() : pred.getIdx(); i < (up ? pred.getIdx() 64 | : arg.getIdx()); i++) { 65 | ret.append(s.get(i).getAttr(attr)); 66 | ret.append(up ? 
UP : DOWN); 67 | } 68 | return ret.toString(); 69 | } 70 | 71 | public String makeDepBasedFeatureString(Word pred, Word arg) { 72 | boolean up = true; 73 | List path = Word.findPath(pred, arg); 74 | if (path.size() == 0) 75 | return " "; 76 | StringBuilder ret = new StringBuilder(); 77 | for (int i = 0; i < (path.size() - 1); ++i) { 78 | Word w = path.get(i); 79 | ret.append(w.getAttr(attr)); 80 | if (up) { 81 | if (w.getHead() == path.get(i + 1)) { // Arrow up 82 | ret.append(UP); 83 | } else { // Arrow down 84 | ret.append(DOWN); 85 | up = false; 86 | } 87 | } else { 88 | ret.append(DOWN); 89 | } 90 | } 91 | return ret.toString(); 92 | 93 | } 94 | 95 | } 96 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PathLengthFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.List; 4 | 5 | import se.lth.cs.srl.corpus.Predicate; 6 | import se.lth.cs.srl.corpus.Sentence; 7 | import se.lth.cs.srl.corpus.Word; 8 | import se.lth.cs.srl.corpus.Word.WordData; 9 | import uk.ac.ed.inf.srl.features.FeatureName; 10 | import uk.ac.ed.inf.srl.features.NumFeature; 11 | import uk.ac.ed.inf.srl.features.SingleFeature; 12 | 13 | public class PathLengthFeature extends SingleFeature { 14 | private static final long serialVersionUID = 1L; 15 | 16 | protected PathLengthFeature(FeatureName name, WordData attr, 17 | boolean consider_deptree, String POSPrefix) { 18 | super(name, true, false, POSPrefix); 19 | } 20 | 21 | @Override 22 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 23 | for (Predicate pred : s.getPredicates()) { 24 | if (doExtractFeatures(pred)) 25 | for (Word arg : pred.getArgMap().keySet()) { 26 | addMap(getFeatureString(pred, arg)); 27 | } 28 | } 29 | } 30 | 31 | @Override 32 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 33 | return 
makeFeatureString(s.get(predIndex), s.get(argIndex)); 34 | } 35 | 36 | @Override 37 | public String getFeatureString(Predicate pred, Word arg) { 38 | return makeFeatureString(pred, arg); 39 | } 40 | 41 | public String makeFeatureString(Word pred, Word arg) { 42 | List path = Word.findPath(pred, arg); 43 | return "PathLength" + NumFeature.bin(path.size()); 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PositionFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import uk.ac.ed.inf.srl.features.FeatureName; 10 | import uk.ac.ed.inf.srl.features.SingleFeature; 11 | 12 | public class PositionFeature extends SingleFeature { 13 | private static final long serialVersionUID = 1L; 14 | /* 15 | * Position is the position of the argument wrt the predicate. I.e. 
if the 16 | * predicate is at position 2, and the argument at position 4, their 17 | * relation is AFTER 18 | */ 19 | public static final String BEFORE = "B"; 20 | public static final String ON = "O"; 21 | public static final String AFTER = "A"; 22 | 23 | protected PositionFeature(String POSPrefix) { 24 | super(FeatureName.Position, true, false, POSPrefix); 25 | indices.put(BEFORE, Integer.valueOf(1)); 26 | indices.put(ON, Integer.valueOf(2)); 27 | indices.put(AFTER, Integer.valueOf(3)); 28 | indexcounter = 4; 29 | } 30 | 31 | @Override 32 | public void addFeatures(Sentence s, Collection indices, 33 | Map nonbinFeats, int predIndex, int argIndex, 34 | Integer offset, boolean allWords) { 35 | indices.add(indexOf(getFeatureString(s, predIndex, argIndex)) + offset); 36 | } 37 | 38 | @Override 39 | public void addFeatures(Collection indices, 40 | Map nonbinFeats, Predicate pred, Word arg, 41 | Integer offset, boolean allWords) { 42 | indices.add(indexOf(getFeatureString(pred, arg)) + offset); 43 | } 44 | 45 | @Override 46 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 47 | // Do nothing, the map is constructed in the constructor. 
48 | } 49 | 50 | @Override 51 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 52 | if (predIndex == argIndex) 53 | return ON; 54 | else if (predIndex < argIndex) 55 | return AFTER; 56 | else 57 | return BEFORE; 58 | } 59 | 60 | @Override 61 | public String getFeatureString(Predicate pred, Word arg) { 62 | // int cmp=pred.compareTo(arg); 63 | int cmp = pred.getMySentence().wordComparator.compare(pred, arg); 64 | if (cmp < 0) { 65 | return AFTER; 66 | } else if (cmp == 0) { 67 | return ON; 68 | } else { 69 | return BEFORE; 70 | } 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PredDependentAttrFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import se.lth.cs.srl.corpus.Word.WordData; 7 | import uk.ac.ed.inf.srl.features.AttrFeature; 8 | import uk.ac.ed.inf.srl.features.FeatureName; 9 | import uk.ac.ed.inf.srl.features.TargetWord; 10 | 11 | public class PredDependentAttrFeature extends AttrFeature { 12 | private static final long serialVersionUID = 1L; 13 | 14 | protected PredDependentAttrFeature(FeatureName name, WordData attr, 15 | TargetWord tw, boolean usedForPredicateIdentification, 16 | String POSPrefix) { 17 | super(name, attr, tw, false, usedForPredicateIdentification, POSPrefix); 18 | } 19 | 20 | @Override 21 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 22 | if (wordExtractor.getWord(s, predIndex, argIndex) == null) 23 | return null; 24 | return wordExtractor.getWord(s, predIndex, argIndex).getAttr(attr); 25 | } 26 | 27 | @Override 28 | public String getFeatureString(Predicate pred, Word arg) { 29 | if (wordExtractor.getWord(pred, arg) == null) 30 | return null; 31 | return 
wordExtractor.getWord(pred, arg).getAttr(attr); 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PredDependentBrown.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.FeatureName; 7 | import uk.ac.ed.inf.srl.features.PredDependentAttrFeature; 8 | import uk.ac.ed.inf.srl.features.TargetWord; 9 | import se.lth.cs.srl.util.BrownCluster; 10 | import se.lth.cs.srl.util.BrownCluster.ClusterVal; 11 | 12 | public class PredDependentBrown extends PredDependentAttrFeature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | private BrownCluster bc; 16 | private ClusterVal cv; 17 | 18 | protected PredDependentBrown(FeatureName name, TargetWord tw, 19 | boolean includeAllWords, String POSPrefix, BrownCluster bc, 20 | ClusterVal cv) { 21 | super(name, null, tw, includeAllWords, POSPrefix); 22 | this.bc = bc; 23 | this.cv = cv; 24 | } 25 | 26 | @Override 27 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 28 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 29 | if (w.isBOS()) 30 | return "ROOT"; 31 | return bc.getValue(w.getForm(), cv); 32 | } 33 | 34 | @Override 35 | public String getFeatureString(Predicate pred, Word arg) { 36 | Word w = wordExtractor.getWord(pred, arg); 37 | if (w.isBOS()) 38 | return "ROOT"; 39 | return bc.getValue(w.getForm(), cv); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PredDependentEmbedding.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import 
se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.ContinuousFeature; 7 | import uk.ac.ed.inf.srl.features.FeatureName; 8 | import uk.ac.ed.inf.srl.features.TargetWord; 9 | import uk.ac.ed.inf.srl.features.WordExtractor; 10 | import se.lth.cs.srl.util.WordEmbedding; 11 | 12 | public class PredDependentEmbedding extends ContinuousFeature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | private WordExtractor wordExtractor; 16 | private WordEmbedding bc; 17 | private int dim; 18 | private boolean tokenembedding; 19 | 20 | protected PredDependentEmbedding(FeatureName name, TargetWord tw, 21 | boolean includeAllWords, String POSPrefix, WordEmbedding bc, 22 | int dim, boolean token) { 23 | // super(name, null, tw, includeAllWords, POSPrefix); 24 | super(name, false, true, POSPrefix); 25 | indices.put("WEPRED" + dim, 1); 26 | 27 | this.wordExtractor = WordExtractor.getExtractor(tw); 28 | this.bc = bc; 29 | this.dim = dim; 30 | this.tokenembedding = token; 31 | } 32 | 33 | @Override 34 | public Double getFeatureValue(Sentence s, int predIndex, int argIndex) { 35 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 36 | if (w == null) 37 | return 0.0; 38 | else 39 | return this.getValue(w, dim); 40 | // return w.getRep(dim); 41 | } 42 | 43 | @Override 44 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 45 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 46 | return "WEPRED" + dim; 47 | } 48 | 49 | @Override 50 | public Double getFeatureValue(Predicate pred, Word arg) { 51 | Word w = wordExtractor.getWord(pred, arg); 52 | if (w == null) 53 | return 0.0; 54 | else { 55 | // return w.getRep(dim); 56 | return this.getValue(w, dim); 57 | } 58 | } 59 | 60 | @Override 61 | public String getFeatureString(Predicate pred, Word arg) { 62 | return "WEPRED" + dim; 63 | } 64 | 65 | private Double getValue(Word w, int dim) { 66 | if (tokenembedding) 67 | return w.getRep(dim); 
68 | else 69 | return bc.getValue(w.getForm(), dim); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/PredDependentFeatsFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.FeatsFeature; 7 | import uk.ac.ed.inf.srl.features.FeatureName; 8 | import uk.ac.ed.inf.srl.features.TargetWord; 9 | import se.lth.cs.srl.languages.Language; 10 | 11 | public class PredDependentFeatsFeature extends FeatsFeature { 12 | private static final long serialVersionUID = 1L; 13 | 14 | protected PredDependentFeatsFeature(FeatureName name, TargetWord tw, 15 | boolean usedForPredicateIdentification, String POSPrefix) { 16 | super(name, tw, false, usedForPredicateIdentification, POSPrefix); 17 | } 18 | 19 | @Override 20 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) { 21 | Word w = wordExtractor.getWord(s, predIndex, argIndex); 22 | return Language.getLanguage().getFeatSplitPattern().split(w.getFeats()); 23 | } 24 | 25 | @Override 26 | public String[] getFeatureStrings(Predicate pred, Word arg) { 27 | Word w = wordExtractor.getWord(pred, arg); 28 | return Language.getLanguage().getFeatSplitPattern().split(w.getFeats()); 29 | } 30 | 31 | @Override 32 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 33 | if (allWords) { 34 | for (int i = 1, size = s.size(); i < size; ++i) { 35 | if (doExtractFeatures(s.get(i))) 36 | for (String v : getFeatureStrings(s, i, -1)) 37 | addMap(v); 38 | } 39 | } else { 40 | for (Predicate pred : s.getPredicates()) { 41 | if (doExtractFeatures(pred)) 42 | for (String v : getFeatureStrings(pred, null)) 43 | addMap(v); 44 | } 45 | } 46 | } 47 | } 48 | 
-------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/QContinuousSetFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import uk.ac.ed.inf.srl.features.ContinuousFeature; 10 | import uk.ac.ed.inf.srl.features.FeatureGenerator; 11 | import uk.ac.ed.inf.srl.features.QuadraticFeature; 12 | import uk.ac.ed.inf.srl.features.SetFeature; 13 | 14 | public class QContinuousSetFeature extends SetFeature implements 15 | QuadraticFeature { 16 | private static final long serialVersionUID = 1L; 17 | 18 | private ContinuousFeature f1; 19 | private SetFeature f2; 20 | 21 | protected QContinuousSetFeature(ContinuousFeature f1, SetFeature f2, 22 | boolean usedForPredicateIdentification, String POSPrefix) { 23 | super(f1.name, f1.includeArgs || f2.includeArgs, 24 | usedForPredicateIdentification, POSPrefix); 25 | this.f1 = f1; 26 | this.f2 = f2; 27 | } 28 | 29 | @Override 30 | public void addFeatures(Sentence s, Collection indices, 31 | Map nonbinFeats, int predIndex, int argIndex, 32 | Integer offset, boolean allWords) { 33 | for (String v : getFeatureStrings(s, predIndex, argIndex)) { 34 | Integer i = indexOf(v); 35 | Double d = f1.getFeatureValue(s, predIndex, argIndex); 36 | if (i != 1 && (allWords || i < predMaxIndex)) 37 | nonbinFeats.put(i + offset, d); 38 | } 39 | } 40 | 41 | @Override 42 | public void addFeatures(Collection indices, 43 | Map nonbinFeats, Predicate pred, Word arg, 44 | Integer offset, boolean allWords) { 45 | for (String v : getFeatureStrings(pred, arg)) { 46 | Integer i = indexOf(v); 47 | Double d = f1.getFeatureValue(pred, arg); 48 | if (i != -1 && (allWords || i < predMaxIndex)) 49 | nonbinFeats.put(i + offset, d); 50 
| } 51 | } 52 | 53 | @Override 54 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) { 55 | String f1val = f1.getFeatureString(s, predIndex, argIndex); 56 | String[] f2vals = f2.getFeatureStrings(s, predIndex, argIndex); 57 | makeFeatureStrings(f1val, f2vals); 58 | return f2vals; 59 | } 60 | 61 | @Override 62 | public String[] getFeatureStrings(Predicate pred, Word arg) { 63 | String f1val = f1.getFeatureString(pred, arg); 64 | String[] f2vals = f2.getFeatureStrings(pred, arg); 65 | if (f2vals != null) { 66 | makeFeatureStrings(f1val, f2vals); 67 | return f2vals; 68 | } else { 69 | return new String[] { "" }; 70 | // return new String[0]; 71 | } 72 | } 73 | 74 | private void makeFeatureStrings(String f1val, String[] f2vals) { 75 | for (int i = 0, length = f2vals.length; i < length; ++i) 76 | f2vals[i] += VALUE_SEPARATOR + f1val; 77 | } 78 | 79 | public String getName() { 80 | return FeatureGenerator.getCanonicalName(f1.name, f2.name); 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/QDoubleChildSetFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Set; 4 | 5 | import se.lth.cs.srl.corpus.Predicate; 6 | import se.lth.cs.srl.corpus.Sentence; 7 | import se.lth.cs.srl.corpus.Word; 8 | import uk.ac.ed.inf.srl.features.ChildSetFeature; 9 | import uk.ac.ed.inf.srl.features.FeatureGenerator; 10 | import uk.ac.ed.inf.srl.features.QuadraticFeature; 11 | import uk.ac.ed.inf.srl.features.SetFeature; 12 | 13 | public class QDoubleChildSetFeature extends SetFeature implements 14 | QuadraticFeature { 15 | private static final long serialVersionUID = 1L; 16 | 17 | private ChildSetFeature f1, f2; 18 | 19 | protected QDoubleChildSetFeature(ChildSetFeature f1, ChildSetFeature f2, 20 | boolean usedForPredicateIdentification, String POSPrefix) { 21 | 
super(f1.name, f1.includeArgs && f2.includeArgs, 22 | usedForPredicateIdentification, POSPrefix); // The boolean 23 | // should always 24 | // evaluate to 25 | // false, seeing as 26 | // ChildSetFeatures 27 | // are always 28 | // focused on the 29 | // pred 30 | this.f1 = f1; 31 | this.f2 = f2; 32 | } 33 | 34 | @Override 35 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) { 36 | return makeFeatureStrings(s.get(predIndex).getChildren()); 37 | } 38 | 39 | @Override 40 | public String[] getFeatureStrings(Predicate pred, Word arg) { 41 | return makeFeatureStrings(pred.getChildren()); 42 | } 43 | 44 | private String[] makeFeatureStrings(Set children) { 45 | String[] ret = new String[children.size()]; 46 | int i = 0; 47 | for (Word child : children) 48 | ret[i++] = child.getAttr(f1.attr) + VALUE_SEPARATOR 49 | + child.getAttr(f2.attr); 50 | return ret; 51 | } 52 | 53 | @Override 54 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 55 | if (includeArgs) { 56 | throw new Error("You are wrong here."); 57 | } else { 58 | if (allWords) { 59 | for (int i = 1, size = s.size(); i < size; ++i) { 60 | if (doExtractFeatures(s.get(i))) 61 | for (String v : getFeatureStrings(s, i, -1)) 62 | addMap(v); 63 | } 64 | } else { 65 | for (Predicate pred : s.getPredicates()) { 66 | if (doExtractFeatures(pred)) 67 | for (String v : getFeatureStrings(pred, null)) 68 | addMap(v); 69 | } 70 | } 71 | } 72 | } 73 | 74 | public String getName() { 75 | return FeatureGenerator.getCanonicalName(f1.name, f2.name); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/QSetSetFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import 
uk.ac.ed.inf.srl.features.FeatureGenerator; 7 | import uk.ac.ed.inf.srl.features.QuadraticFeature; 8 | import uk.ac.ed.inf.srl.features.SetFeature; 9 | 10 | public class QSetSetFeature extends SetFeature implements QuadraticFeature { 11 | private static final long serialVersionUID = 1L; 12 | 13 | private SetFeature f1; 14 | private SetFeature f2; 15 | 16 | protected QSetSetFeature(SetFeature f1, SetFeature f2, 17 | boolean usedForPredicateIdentification, String POSPrefix) { 18 | super(f1.name, f1.includeArgs || f2.includeArgs, 19 | usedForPredicateIdentification, POSPrefix); 20 | this.f1 = f1; 21 | this.f2 = f2; 22 | } 23 | 24 | @Override 25 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) { 26 | String[] f1vals = f1.getFeatureStrings(s, predIndex, argIndex); 27 | String[] f2vals = f2.getFeatureStrings(s, predIndex, argIndex); 28 | makeFeatureStrings(f1vals, f2vals); 29 | return f2vals; 30 | } 31 | 32 | @Override 33 | public String[] getFeatureStrings(Predicate pred, Word arg) { 34 | String[] f1vals = f1.getFeatureStrings(pred, arg); 35 | String[] f2vals = f2.getFeatureStrings(pred, arg); 36 | if (f2vals != null) { 37 | makeFeatureStrings(f1vals, f2vals); 38 | return f2vals; 39 | } else { 40 | return new String[] { "" }; 41 | // return new String[0]; 42 | } 43 | } 44 | 45 | private void makeFeatureStrings(String[] f1vals, String[] f2vals) { 46 | 47 | if (f1vals.length != f2vals.length) { 48 | System.err 49 | .println("CHECK YOUR IMPLEMENTATION! 
Trying to combine two set features of different lengths"); 50 | System.exit(1); 51 | } 52 | 53 | for (int i = 0, length = f2vals.length; i < length; ++i) 54 | f2vals[i] += VALUE_SEPARATOR + f1vals[i]; 55 | } 56 | 57 | public String getName() { 58 | return FeatureGenerator.getCanonicalName(f1.name, f2.name); 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/QSingleSetFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.FeatureGenerator; 7 | import uk.ac.ed.inf.srl.features.QuadraticFeature; 8 | import uk.ac.ed.inf.srl.features.SetFeature; 9 | import uk.ac.ed.inf.srl.features.SingleFeature; 10 | 11 | public class QSingleSetFeature extends SetFeature implements QuadraticFeature { 12 | private static final long serialVersionUID = 1L; 13 | 14 | private SingleFeature f1; 15 | private SetFeature f2; 16 | 17 | protected QSingleSetFeature(SingleFeature f1, SetFeature f2, 18 | boolean usedForPredicateIdentification, String POSPrefix) { 19 | super(f1.name, f1.includeArgs || f2.includeArgs, 20 | usedForPredicateIdentification, POSPrefix); 21 | this.f1 = f1; 22 | this.f2 = f2; 23 | } 24 | 25 | @Override 26 | public String[] getFeatureStrings(Sentence s, int predIndex, int argIndex) { 27 | String f1val = f1.getFeatureString(s, predIndex, argIndex); 28 | String[] f2vals = f2.getFeatureStrings(s, predIndex, argIndex); 29 | makeFeatureStrings(f1val, f2vals); 30 | return f2vals; 31 | } 32 | 33 | @Override 34 | public String[] getFeatureStrings(Predicate pred, Word arg) { 35 | String f1val = f1.getFeatureString(pred, arg); 36 | String[] f2vals = f2.getFeatureStrings(pred, arg); 37 | if (f2vals != null) { 38 | makeFeatureStrings(f1val, f2vals); 39 | 
return f2vals; 40 | } else { 41 | return new String[] { "" }; 42 | // return new String[0]; 43 | } 44 | } 45 | 46 | private void makeFeatureStrings(String f1val, String[] f2vals) { 47 | for (int i = 0, length = f2vals.length; i < length; ++i) 48 | f2vals[i] += VALUE_SEPARATOR + f1val; 49 | } 50 | 51 | public String getName() { 52 | return FeatureGenerator.getCanonicalName(f1.name, f2.name); 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/QSingleSingleFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import uk.ac.ed.inf.srl.features.FeatureGenerator; 7 | import uk.ac.ed.inf.srl.features.QuadraticFeature; 8 | import uk.ac.ed.inf.srl.features.SingleFeature; 9 | 10 | public class QSingleSingleFeature extends SingleFeature implements 11 | QuadraticFeature { 12 | private static final long serialVersionUID = 1L; 13 | 14 | private SingleFeature f1, f2; 15 | 16 | protected QSingleSingleFeature(SingleFeature f1, SingleFeature f2, 17 | boolean usedForPredicateIdentification, String POSPrefix) { 18 | super(f1.name, f1.includeArgs || f2.includeArgs, 19 | usedForPredicateIdentification, POSPrefix); 20 | this.f1 = f1; 21 | this.f2 = f2; 22 | } 23 | 24 | @Override 25 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 26 | if (includeArgs) { 27 | for (Predicate pred : s.getPredicates()) { 28 | if (doExtractFeatures(pred)) 29 | for (Word arg : pred.getArgMap().keySet()) { 30 | addMap(getFeatureString(pred, arg)); 31 | } 32 | } 33 | } else { 34 | super.performFeatureExtraction(s, allWords); 35 | } 36 | } 37 | 38 | @Override 39 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 40 | return f1.getFeatureString(s, predIndex, 
argIndex) + VALUE_SEPARATOR 41 | + f2.getFeatureString(s, predIndex, argIndex); 42 | } 43 | 44 | @Override 45 | public String getFeatureString(Predicate pred, Word arg) { 46 | return f1.getFeatureString(pred, arg) + VALUE_SEPARATOR 47 | + f2.getFeatureString(pred, arg); 48 | } 49 | 50 | public String getName() { 51 | return FeatureGenerator.getCanonicalName(f1.name, f2.name); 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/QuadraticFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | public interface QuadraticFeature { 4 | 5 | public static final String VALUE_SEPARATOR = " + "; 6 | 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/SameSubTreeFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.List; 4 | 5 | import se.lth.cs.srl.corpus.Predicate; 6 | import se.lth.cs.srl.corpus.Sentence; 7 | import se.lth.cs.srl.corpus.Word; 8 | import se.lth.cs.srl.corpus.Word.WordData; 9 | import uk.ac.ed.inf.srl.features.FeatureName; 10 | import uk.ac.ed.inf.srl.features.SingleFeature; 11 | 12 | public class SameSubTreeFeature extends SingleFeature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | boolean parent; 16 | 17 | protected SameSubTreeFeature(FeatureName name, WordData attr, 18 | boolean consider_parent, String POSPrefix) { 19 | super(name, true, false, POSPrefix); 20 | this.parent = consider_parent; 21 | } 22 | 23 | @Override 24 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 25 | for (Predicate pred : s.getPredicates()) { 26 | if (doExtractFeatures(pred)) 27 | for (Word arg : pred.getArgMap().keySet()) { 28 | addMap(getFeatureString(pred, arg)); 29 | } 30 | } 
31 | } 32 | 33 | @Override 34 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 35 | return makeFeatureString(s.get(predIndex), s.get(argIndex)); 36 | } 37 | 38 | @Override 39 | public String getFeatureString(Predicate pred, Word arg) { 40 | return makeFeatureString(pred, arg); 41 | } 42 | 43 | public String makeFeatureString(Word pred, Word arg) { 44 | List path = Word.findPath(parent ? pred.getHead() : pred, arg); 45 | boolean inTree = true; 46 | for (int i = 0; i < (path.size() - 1); ++i) { 47 | Word w = path.get(i); 48 | // path goes up instead of down 49 | if (w.getHead() == path.get(i + 1)) { 50 | inTree = false; 51 | break; 52 | } 53 | } 54 | 55 | return (parent ? "Parent" : "") + "SubTree" + (inTree ? 1 : 0); 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/SetFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import uk.ac.ed.inf.srl.features.Feature; 10 | import uk.ac.ed.inf.srl.features.FeatureName; 11 | 12 | public abstract class SetFeature extends Feature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | protected SetFeature(FeatureName name, boolean includeArgs, 16 | boolean usedForPredicateIdentification, String POSPrefix) { 17 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix); 18 | } 19 | 20 | public abstract String[] getFeatureStrings(Sentence s, int predIndex, 21 | int argIndex); 22 | 23 | public abstract String[] getFeatureStrings(Predicate pred, Word arg); 24 | 25 | @Override 26 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 27 | if (includeArgs) { 28 | for (Predicate pred : 
s.getPredicates()) { 29 | if (doExtractFeatures(pred)) 30 | for (Word arg : pred.getArgMap().keySet()) { 31 | for (String v : getFeatureStrings(pred, arg)) 32 | addMap(v); 33 | } 34 | } 35 | } else { 36 | if (allWords) { 37 | for (int i = 1, size = s.size(); i < size; ++i) { 38 | if (doExtractFeatures(s.get(i))) 39 | for (String v : getFeatureStrings(s, i, -1)) 40 | addMap(v); 41 | } 42 | } else { 43 | for (Predicate pred : s.getPredicates()) { 44 | if (doExtractFeatures(pred)) 45 | for (String v : getFeatureStrings(pred, null)) 46 | addMap(v); 47 | } 48 | } 49 | } 50 | } 51 | 52 | @Override 53 | public void addFeatures(Sentence s, Collection indices, 54 | Map nonbinFeats, int predIndex, int argIndex, 55 | Integer offset, boolean allWords) { 56 | for (String v : getFeatureStrings(s, predIndex, argIndex)) { 57 | Integer i = indexOf(v); 58 | if (i != -1 && (allWords || i < predMaxIndex)) 59 | indices.add(i + offset); 60 | } 61 | } 62 | 63 | @Override 64 | public void addFeatures(Collection indices, 65 | Map nonbinFeats, Predicate pred, Word arg, 66 | Integer offset, boolean allWords) { 67 | for (String v : getFeatureStrings(pred, arg)) { 68 | Integer i = indexOf(v); 69 | if (i != -1 && (allWords || i < predMaxIndex)) 70 | indices.add(i + offset); 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/SingleFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | import se.lth.cs.srl.corpus.Predicate; 7 | import se.lth.cs.srl.corpus.Sentence; 8 | import se.lth.cs.srl.corpus.Word; 9 | import uk.ac.ed.inf.srl.features.Feature; 10 | import uk.ac.ed.inf.srl.features.FeatureName; 11 | 12 | public abstract class SingleFeature extends Feature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | protected 
SingleFeature(FeatureName name, boolean includeArgs, 16 | boolean usedForPredicateIdentification, String POSPrefix) { 17 | super(name, includeArgs, usedForPredicateIdentification, POSPrefix); 18 | } 19 | 20 | @Override 21 | public void addFeatures(Sentence s, Collection indices, 22 | Map nonbinFeats, int predIndex, int argIndex, 23 | Integer offset, boolean allWords) { 24 | Integer i = indexOf(getFeatureString(s, predIndex, argIndex)); 25 | if (i != -1 && (allWords || i < predMaxIndex)) 26 | indices.add(i + offset); 27 | } 28 | 29 | @Override 30 | public void addFeatures(Collection indices, 31 | Map nonbinFeats, Predicate pred, Word arg, 32 | Integer offset, boolean allWords) { 33 | Integer i = indexOf(getFeatureString(pred, arg)); 34 | if (i != -1 && (allWords || i < predMaxIndex)) 35 | indices.add(i + offset); 36 | } 37 | 38 | /** 39 | * This method works with features that have includeArgs==false. It extracts 40 | * either from all words (if boolean allWords is true), or from the 41 | * predicates only (if false) 42 | */ 43 | @Override 44 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 45 | if (includeArgs) { 46 | throw new Error("You are wrong here, check your implementation."); 47 | } else { 48 | if (allWords) { 49 | for (int i = 1, size = s.size(); i < size; ++i) 50 | if (doExtractFeatures(s.get(i))) 51 | addMap(getFeatureString(s, i, -1)); 52 | } else { 53 | for (Predicate pred : s.getPredicates()) 54 | if (doExtractFeatures(pred)) 55 | addMap(getFeatureString(pred, null)); 56 | } 57 | } 58 | } 59 | 60 | public abstract String getFeatureString(Sentence s, int predIndex, 61 | int argIndex); 62 | 63 | public abstract String getFeatureString(Predicate pred, Word arg); 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/SpanLengthFeature.java: -------------------------------------------------------------------------------- 1 | package 
uk.ac.ed.inf.srl.features; 2 | 3 | import se.lth.cs.srl.corpus.Predicate; 4 | import se.lth.cs.srl.corpus.Sentence; 5 | import se.lth.cs.srl.corpus.Word; 6 | import se.lth.cs.srl.corpus.Word.WordData; 7 | import uk.ac.ed.inf.srl.features.FeatureName; 8 | import uk.ac.ed.inf.srl.features.NumFeature; 9 | import uk.ac.ed.inf.srl.features.SingleFeature; 10 | 11 | public class SpanLengthFeature extends SingleFeature { 12 | private static final long serialVersionUID = 1L; 13 | 14 | protected SpanLengthFeature(FeatureName name, WordData attr, 15 | boolean consider_deptree, String POSPrefix) { 16 | super(name, true, false, POSPrefix); 17 | } 18 | 19 | @Override 20 | protected void performFeatureExtraction(Sentence s, boolean allWords) { 21 | for (Predicate pred : s.getPredicates()) { 22 | if (doExtractFeatures(pred)) 23 | for (Word arg : pred.getArgMap().keySet()) { 24 | addMap(getFeatureString(pred, arg)); 25 | } 26 | } 27 | } 28 | 29 | @Override 30 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 31 | return makeFeatureString(s.get(predIndex), s.get(argIndex)); 32 | } 33 | 34 | @Override 35 | public String getFeatureString(Predicate pred, Word arg) { 36 | return makeFeatureString(pred, arg); 37 | } 38 | 39 | public String makeFeatureString(Word pred, Word arg) { 40 | // if(arg.getSpan().size()>1) { 41 | // for(Word w : arg.getSpan()) 42 | // System.err.print(w.getForm()+ " "); 43 | // System.err.println("-> " +NumFeature.bin(arg.getSpan().size())); 44 | // } 45 | return "SpanLength" + NumFeature.bin(arg.getSpan().size()); 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/SubCatSizeFeature.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | import java.util.Set; 4 | 5 | import se.lth.cs.srl.corpus.Predicate; 6 | import se.lth.cs.srl.corpus.Sentence; 7 | import 
se.lth.cs.srl.corpus.Word; 8 | import uk.ac.ed.inf.srl.features.FeatureName; 9 | import uk.ac.ed.inf.srl.features.NumFeature; 10 | import uk.ac.ed.inf.srl.features.SingleFeature; 11 | 12 | public class SubCatSizeFeature extends SingleFeature { 13 | private static final long serialVersionUID = 1L; 14 | 15 | SubCatSizeFeature(boolean usedForPredicateIdentification, String POSPrefix) { 16 | super(FeatureName.SubCatSize, false, usedForPredicateIdentification, 17 | POSPrefix); 18 | } 19 | 20 | @Override 21 | public String getFeatureString(Sentence s, int predIndex, int argIndex) { 22 | return makeFeatureString(s, s.get(predIndex).getChildren()); 23 | } 24 | 25 | @Override 26 | public String getFeatureString(Predicate pred, Word arg) { 27 | return makeFeatureString(pred.getMySentence(), pred.getChildren()); 28 | } 29 | 30 | private String makeFeatureString(Sentence s, Set children) { 31 | return "SUBCAT" + NumFeature.bin(children.size()); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/features/TargetWord.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.features; 2 | 3 | public enum TargetWord { 4 | Pred, // This is the (potential) predicate 5 | PredParent, // This is the parent of the (potential) predicate 6 | Arg, // This is the (potential) argument 7 | LeftDep, // This is the leftmost dependent of the (potential) argument 8 | RightDep, // This is the rightmost dependent of the (potential) argument 9 | LeftSibling, // This is the left sibling of the (potential) argument 10 | RightSibling, // This is the right sibling of the (potential) argument 11 | PredSubj, 12 | 13 | FirstWord, LastWord, SecondWord, 14 | 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/ml/LearningProblem.java: 
-------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.ml; 2 | 3 | import java.util.Collection; 4 | import java.util.Map; 5 | 6 | public interface LearningProblem { 7 | 8 | public void addInstance(int label, Collection indices, 9 | Map nonbinFeats); 10 | 11 | public void done(); 12 | 13 | public Model train(); 14 | 15 | } -------------------------------------------------------------------------------- /src/main/java/uk/ac/ed/inf/srl/ml/Model.java: -------------------------------------------------------------------------------- 1 | package uk.ac.ed.inf.srl.ml; 2 | 3 | import java.io.Serializable; 4 | import java.util.Collection; 5 | import java.util.List; 6 | import java.util.Map; 7 | 8 | import uk.ac.ed.inf.srl.ml.liblinear.Label; 9 | 10 | public interface Model extends Serializable { 11 | 12 | public List