├── run ├── ocr │ └── ocr.conf ├── ref │ └── ref.conf └── punc │ └── punc.conf ├── lib └── stanford-ner.jar ├── src ├── HOCRF │ ├── Loglikelihood.java │ ├── Feature.java │ ├── FeatureIndex.java │ ├── DataSet.java │ ├── FeatureType.java │ ├── Params.java │ ├── HighOrderCRF.java │ ├── DataSequence.java │ ├── SentenceFeatGenerator.java │ ├── Utility.java │ ├── LabelMap.java │ ├── Function.java │ ├── Viterbi.java │ ├── LogliComputer.java │ ├── Scorer.java │ └── FeatureGenerator.java ├── HOSemiCRF │ ├── Loglikelihood.java │ ├── Feature.java │ ├── FeatureIndex.java │ ├── DataSet.java │ ├── FeatureType.java │ ├── Params.java │ ├── HighOrderSemiCRF.java │ ├── Utility.java │ ├── SentenceObsGenerator.java │ ├── LabelMap.java │ ├── Function.java │ ├── Viterbi.java │ ├── DataSequence.java │ ├── LogliComputer.java │ └── Scorer.java ├── OCR │ ├── CharDetails.java │ ├── Features │ │ ├── FifthOrderTransition.java │ │ ├── ThirdOrderTransition.java │ │ ├── FirstOrderTransition.java │ │ ├── FourthOrderTransition.java │ │ ├── SecondOrderTransition.java │ │ └── Pixel.java │ └── OCR.java ├── Applications │ ├── RefFeatures │ │ ├── Edge.java │ │ ├── EdgeBag.java │ │ ├── FirstOrderTransition.java │ │ ├── ThirdOrderTransition.java │ │ ├── SecondOrderTransition.java │ │ ├── WordBag.java │ │ ├── EdgeWord.java │ │ ├── NextWordBag.java │ │ ├── EdgeWordBag.java │ │ ├── PreviousWordBag.java │ │ ├── EdgePreviousWord.java │ │ ├── EdgePreviousWordBag.java │ │ ├── WordKPositionAfterBag.java │ │ ├── WordKPositionBeforeBag.java │ │ └── LetterNGramsBag.java │ ├── PuncFeatures │ │ ├── Edge.java │ │ ├── EdgeBag.java │ │ ├── FirstOrderTransition.java │ │ ├── SecondOrderTransition.java │ │ ├── ThirdOrderTransition.java │ │ ├── EdgeWord.java │ │ ├── EdgeWordBag.java │ │ ├── EdgePreviousWord.java │ │ ├── EdgeTwoWord.java │ │ ├── EdgePreviousWordBag.java │ │ ├── EdgeTwoWordBag.java │ │ ├── FirstOrderTransitionWord.java │ │ ├── SecondOrderTransitionWord.java │ │ ├── ThirdOrderTransitionWord.java │ │ ├── 
WordPositionBag.java │ │ └── TwoWordPositionBag.java │ ├── PuncConverter.java │ ├── ReferenceTagger.java │ └── PunctuationPredictor.java └── Parallel │ ├── Schedulable.java │ ├── Timer.java │ ├── TaskThread.java │ └── Scheduler.java ├── LICENSE └── README.md /run/ocr/ocr.conf: -------------------------------------------------------------------------------- 1 | maxIters=100 2 | numthreads=4 -------------------------------------------------------------------------------- /run/ref/ref.conf: -------------------------------------------------------------------------------- 1 | maxIters=100 2 | numthreads=4 -------------------------------------------------------------------------------- /run/punc/punc.conf: -------------------------------------------------------------------------------- 1 | maxIters=100 2 | numthreads=4 3 | maxSegment=1 -------------------------------------------------------------------------------- /lib/stanford-ner.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nvcuong/HOSemiCRF/HEAD/lib/stanford-ner.jar -------------------------------------------------------------------------------- /src/HOCRF/Loglikelihood.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | /** 4 | * Loglikelihood class 5 | * @author Nguyen Viet Cuong 6 | */ 7 | public class Loglikelihood { 8 | 9 | double logli; // Loglikelihood value 10 | double derivatives[]; // Loglikelihood derivatives 11 | 12 | /** 13 | * Construct a loglikelihood with a given number of features. 
package OCR;

/**
 * Pixel matrix of a single handwritten character (a 16x8 image).
 * @author Nguyen Viet Cuong
 */
public class CharDetails {

    public static final int ROWS = 16; // Image height in pixels
    public static final int COLS = 8;  // Image width in pixels

    // Pixel values, indexed [row][column].
    // Fix: the previous version pre-allocated new int[ROWS][COLS] here, an
    // array that was always discarded when the constructor assigned p.
    int[][] pixels;

    /**
     * Construct character details from a pixel matrix.
     * NOTE(review): the matrix is stored without a defensive copy, so callers
     * must not mutate it afterwards — confirm this sharing is intended.
     * @param p Pixel matrix, expected to be ROWS x COLS
     */
    public CharDetails(int[][] p) {
        pixels = p;
    }

    /**
     * Return the pixel value at a given position.
     * @param r Row index
     * @param c Column index
     * @return Pixel value at (r, c)
     */
    public int getPixels(int r, int c) {
        return pixels[r][c];
    }

    @Override
    public String toString() {
        return "";
    }
}
package OCR.Features;

import java.util.*;
import HOCRF.*;

/**
 * First order transition features
 * @author Nguyen Viet Cuong
 */
public class FirstOrderTransition extends FeatureType {

    /**
     * Emit the first-order transition marker at a position.
     * "1E." is produced at every position that has a predecessor,
     * i.e. everywhere except position 0.
     */
    public ArrayList generateObsAt(DataSequence seq, int pos) {
        ArrayList observations = new ArrayList();
        boolean hasPredecessor = (pos >= 1);
        if (hasPredecessor) {
            observations.add("1E.");
        }
        return observations;
    }

    /**
     * @return Markov order of this feature type (1)
     */
    public int order() {
        return 1;
    }
}
package Applications.RefFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge features between two consecutive segments
 * @author Nguyen Viet Cuong
 */
public class Edge extends FeatureType {

    /**
     * Emit the edge marker "E." for every segment except the one
     * starting at position 0 (which has no preceding segment).
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList result = new ArrayList();
        boolean hasPreviousSegment = (segStart >= 1);
        if (hasPreviousSegment) {
            result.add("E.");
        }
        return result;
    }

    /**
     * @return Markov order of this feature type (1)
     */
    public int order() {
        return 1;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge features within a segment
 * @author Nguyen Viet Cuong
 */
public class EdgeBag extends FeatureType {

    /**
     * Emit one "EB." marker per internal transition of the segment,
     * i.e. (segEnd - segStart) markers, none for a length-1 segment.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList markers = new ArrayList();
        int internalTransitions = segEnd - segStart;
        for (int k = 0; k < internalTransitions; k++) {
            markers.add("EB.");
        }
        return markers;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Semi-Markov first order transition features
 * @author Nguyen Viet Cuong
 */
public class FirstOrderTransition extends FeatureType {

    /**
     * Emit "1E." for every segment that has at least one segment before it.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList obsList = new ArrayList();
        boolean notFirstSegment = (segStart >= 1);
        if (notFirstSegment) {
            obsList.add("1E.");
        }
        return obsList;
    }

    /**
     * @return Markov order of this feature type (1)
     */
    public int order() {
        return 1;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Semi-Markov second order transition features
 * @author Nguyen Viet Cuong
 */
public class SecondOrderTransition extends FeatureType {

    /**
     * Emit "2E." for every segment starting at position 2 or later,
     * i.e. whenever a length-2 label history can exist.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList obsList = new ArrayList();
        boolean historyAvailable = (segStart >= 2);
        if (historyAvailable) {
            obsList.add("2E.");
        }
        return obsList;
    }

    /**
     * @return Markov order of this feature type (2)
     */
    public int order() {
        return 2;
    }
}
package Applications.RefFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Current word features within a segment
 * @author Nguyen Viet Cuong
 */
public class WordBag extends FeatureType {

    /**
     * Emit "WB.&lt;word&gt;" for each token covered by the segment
     * [segStart, segEnd], in positional order.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList wordObs = new ArrayList();
        int pos = segStart;
        while (pos <= segEnd) {
            wordObs.add("WB." + seq.x(pos));
            pos++;
        }
        return wordObs;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package Applications.RefFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge features between consecutive segments and current words
 * @author Nguyen Viet Cuong
 */
public class EdgeWord extends FeatureType {

    /**
     * Emit "EW.&lt;word&gt;" for the first word of every segment that
     * has a preceding segment; the initial segment contributes nothing.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList features = new ArrayList();
        if (segStart > 0) {
            String firstWord = String.valueOf(seq.x(segStart));
            features.add("EW." + firstWord);
        }
        return features;
    }

    /**
     * @return Markov order of this feature type (1)
     */
    public int order() {
        return 1;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge and word features within a segment
 * @author Nguyen Viet Cuong
 */
public class EdgeWordBag extends FeatureType {

    /**
     * Emit "EWB.&lt;word&gt;" for each token of the segment except the first,
     * pairing every internal edge with the word it leads into.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList edgeWordObs = new ArrayList();
        int pos = segStart + 1;
        while (pos <= segEnd) {
            edgeWordObs.add("EWB." + seq.x(pos));
            pos++;
        }
        return edgeWordObs;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package Applications.RefFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Previous word features within a segment
 * @author Nguyen Viet Cuong
 */
public class PreviousWordBag extends FeatureType {

    /**
     * Emit "PWB.&lt;word&gt;" for the token immediately before each position
     * of the segment [segStart, segEnd].
     * NOTE(review): when segStart is 0 this queries seq.x(-1) — presumably
     * DataSequence returns a boundary token for out-of-range indices; confirm.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList prevWordObs = new ArrayList();
        int pos = segStart;
        while (pos <= segEnd) {
            prevWordObs.add("PWB." + seq.x(pos - 1));
            pos++;
        }
        return prevWordObs;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package Applications.RefFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge features between consecutive segments and previous words
 * @author Nguyen Viet Cuong
 */
public class EdgePreviousWord extends FeatureType {

    /**
     * Emit "EPW.&lt;word&gt;" for the last word of the preceding segment,
     * for every segment that has one.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList features = new ArrayList();
        if (segStart > 0) {
            String previousWord = String.valueOf(seq.x(segStart - 1));
            features.add("EPW." + previousWord);
        }
        return features;
    }

    /**
     * @return Markov order of this feature type (1)
     */
    public int order() {
        return 1;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge and previous word features within a segment
 * @author Nguyen Viet Cuong
 */
public class EdgePreviousWordBag extends FeatureType {

    /**
     * Emit "EPWB.&lt;word&gt;" for the word preceding each internal edge
     * of the segment, i.e. for positions segStart .. segEnd-1.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList edgePrevObs = new ArrayList();
        int pos = segStart + 1;
        while (pos <= segEnd) {
            edgePrevObs.add("EPWB." + seq.x(pos - 1));
            pos++;
        }
        return edgePrevObs;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Edge and two consecutive words features within a segment
 * @author Nguyen Viet Cuong
 */
public class EdgeTwoWordBag extends FeatureType {

    /**
     * Emit "ETWB.&lt;prev&gt;.&lt;curr&gt;" for every adjacent word pair
     * inside the segment [segStart, segEnd].
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList pairObs = new ArrayList();
        int pos = segStart + 1;
        while (pos <= segEnd) {
            String pair = "ETWB." + seq.x(pos - 1) + "." + seq.x(pos);
            pairObs.add(pair);
            pos++;
        }
        return pairObs;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package HOSemiCRF;

/**
 * A single feature: an observation string paired with a label pattern
 * and a real-valued weight contribution.
 * @author Nguyen Viet Cuong
 */
public class Feature {

    String obs;   // Observation part of the feature
    String pat;   // Label pattern of the feature
    double value; // Value of the feature

    /**
     * Construct a new feature from observation, pattern, and value.
     * @param obs Observation of the feature
     * @param pat Label pattern of the feature
     * @param value Value of the feature
     */
    public Feature(String obs, String pat, double value) {
        this.value = value;
        this.pat = pat;
        this.obs = obs;
    }
}
package Applications.PuncFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Semi-Markov second order transition with word features
 * @author Nguyen Viet Cuong
 */
public class SecondOrderTransitionWord extends FeatureType {

    /**
     * Emit "E2W.&lt;word&gt;" for every token of the segment, but only for
     * segments starting at position 2 or later (guard clause below).
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList obsList = new ArrayList();
        if (segStart < 2) {
            return obsList; // no length-2 label history available yet
        }
        for (int pos = segStart; pos <= segEnd; pos++) {
            obsList.add("E2W." + seq.x(pos));
        }
        return obsList;
    }

    /**
     * @return Markov order of this feature type (2)
     */
    public int order() {
        return 2;
    }
}
package OCR.Features;

import java.util.*;
import HOCRF.*;
import OCR.*;

/**
 * Pixel features
 * @author Nguyen Viet Cuong
 */
public class Pixel extends FeatureType {

    /**
     * Emit "row.col" for every non-zero pixel of the character image
     * at the given sequence position.
     */
    public ArrayList generateObsAt(DataSequence seq, int pos) {
        ArrayList activePixels = new ArrayList();
        CharDetails character = (CharDetails) seq.x(pos);
        for (int row = 0; row < CharDetails.ROWS; row++) {
            for (int col = 0; col < CharDetails.COLS; col++) {
                boolean isOn = character.getPixels(row, col) != 0;
                if (isOn) {
                    activePixels.add(row + "." + col);
                }
            }
        }
        return activePixels;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
package Applications.RefFeatures;

import java.util.*;
import HOSemiCRF.*;

/**
 * Words from 1st to 4th position before a position within a segment.
 * (With K = 5, the inner loop bound j &gt; i - K covers offsets -1..-4.)
 * @author Nguyen Viet Cuong
 */
public class WordKPositionBeforeBag extends FeatureType {

    // Window size: up to K-1 positions back are examined.
    static final int K = 5;

    /**
     * Generate "WKBB.&lt;word&gt;" observations for the words up to K-1
     * positions before each position of the segment [segStart, segEnd].
     * NOTE(review): j may reach -1, so seq.x(-1) is queried near the sequence
     * start — presumably DataSequence returns a boundary token for
     * out-of-range indices; confirm against DataSequence.x.
     */
    public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) {
        ArrayList obs = new ArrayList();
        for (int i = segStart; i <= segEnd; i++) {
            for (int j = i - 1; j > i - K && j >= -1; j--) {
                obs.add("WKBB." + seq.x(j));
            }
        }
        return obs;
    }

    /**
     * @return Markov order of this feature type (0)
     */
    public int order() {
        return 0;
    }
}
30 | * @param partialResult Partial result 31 | */ 32 | public void update(Object partialResult); // synchronize 33 | } 34 | -------------------------------------------------------------------------------- /src/Applications/PuncFeatures/WordPositionBag.java: -------------------------------------------------------------------------------- 1 | package Applications.PuncFeatures; 2 | 3 | import java.util.*; 4 | import HOSemiCRF.*; 5 | 6 | /** 7 | * Current word and position features 8 | * @author Nguyen Viet Cuong 9 | */ 10 | public class WordPositionBag extends FeatureType { 11 | 12 | static final int K = 5; 13 | 14 | public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) { 15 | ArrayList obs = new ArrayList(); 16 | for (int i = segStart; i <= segEnd; i++) { 17 | for (int j = i; j < i + K && j <= seq.length(); j++) { 18 | obs.add("WPB." + seq.x(j) + "." + (j - i)); 19 | } 20 | for (int j = i - 1; j > i - K && j >= -1; j--) { 21 | obs.add("WPB." + seq.x(j) + "." + (j - i)); 22 | } 23 | } 24 | return obs; 25 | } 26 | 27 | public int order() { 28 | return 0; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/Applications/PuncFeatures/TwoWordPositionBag.java: -------------------------------------------------------------------------------- 1 | package Applications.PuncFeatures; 2 | 3 | import java.util.*; 4 | import HOSemiCRF.*; 5 | 6 | /** 7 | * Two words around the current word and their positions feature 8 | * @author Nguyen Viet Cuong 9 | */ 10 | public class TwoWordPositionBag extends FeatureType { 11 | 12 | static final int K = 4; 13 | 14 | public ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd) { 15 | ArrayList obs = new ArrayList(); 16 | for (int i = segStart; i <= segEnd; i++) { 17 | for (int j = -K; j < K; j++) { 18 | for (int k = j + 1; k <= K; k++) { 19 | if (i + j >= 0 && i + j <= seq.length() && i + k >= 0 && i + k <= seq.length()) { 20 | obs.add("2WPB." 
+ seq.x(i + j) + "." + seq.x(i + k) + "." + j + "." + k); 21 | } 22 | } 23 | } 24 | } 25 | return obs; 26 | } 27 | 28 | public int order() { 29 | return 0; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2012 Cuong V. Nguyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/HOCRF/FeatureIndex.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | /** 4 | * Index of a feature 5 | * @author Nguyen Viet Cuong 6 | */ 7 | public class FeatureIndex { 8 | 9 | int obsID; // ID of the observation part 10 | int patID; // ID of the pattern part 11 | 12 | /** 13 | * Construct a feature index from observation and pattern IDs. 14 | * @param obsID Observation ID 15 | * @param patID Pattern ID 16 | */ 17 | public FeatureIndex(int obsID, int patID) { 18 | this.obsID = obsID; 19 | this.patID = patID; 20 | } 21 | 22 | @Override 23 | public boolean equals(Object o) { 24 | if (this == o) return true; 25 | if (o == null || getClass() != o.getClass()) return false; 26 | 27 | FeatureIndex that = (FeatureIndex) o; 28 | if (obsID != that.obsID) return false; 29 | if (patID != that.patID) return false; 30 | 31 | return true; 32 | } 33 | 34 | @Override 35 | public int hashCode() { 36 | int result = 23; 37 | result = result*31 + obsID; 38 | result = result*31 + patID; 39 | return result; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/HOSemiCRF/FeatureIndex.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | /** 4 | * Index of a feature 5 | * @author Nguyen Viet Cuong 6 | */ 7 | public class FeatureIndex { 8 | 9 | int obsID; // ID of the observation part 10 | int patID; // ID of the pattern part 11 | 12 | /** 13 | * Construct a feature index from observation and pattern IDs. 
14 | * @param obsID Observation ID 15 | * @param patID Pattern ID 16 | */ 17 | public FeatureIndex(int obsID, int patID) { 18 | this.obsID = obsID; 19 | this.patID = patID; 20 | } 21 | 22 | @Override 23 | public boolean equals(Object o) { 24 | if (this == o) return true; 25 | if (o == null || getClass() != o.getClass()) return false; 26 | 27 | FeatureIndex that = (FeatureIndex) o; 28 | if (obsID != that.obsID) return false; 29 | if (patID != that.patID) return false; 30 | 31 | return true; 32 | } 33 | 34 | @Override 35 | public int hashCode() { 36 | int result = 23; 37 | result = result*31 + obsID; 38 | result = result*31 + patID; 39 | return result; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/HOCRF/DataSet.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.util.*; 4 | import java.io.*; 5 | 6 | /** 7 | * Class for a dataset 8 | * @author Nguyen Viet Cuong 9 | */ 10 | public class DataSet { 11 | 12 | ArrayList trainSeqs; // List of all data sequences 13 | 14 | /** 15 | * Construct a dataset from a list of data sequences. 16 | * @param trs List of data sequences 17 | */ 18 | public DataSet(ArrayList trs) { 19 | trainSeqs = trs; 20 | } 21 | 22 | /** 23 | * Get the list of data sequences in the dataset. 24 | * @return List of data sequences 25 | */ 26 | public ArrayList getSeqList() { 27 | return trainSeqs; 28 | } 29 | 30 | /** 31 | * Write the dataset to a file. 
32 | * @param filename Name of the output file 33 | */ 34 | public void writeToFile(String filename) throws Exception { 35 | BufferedWriter bw = new BufferedWriter(new FileWriter(filename, false)); 36 | for (int i = 0; i < trainSeqs.size(); i++) { 37 | trainSeqs.get(i).writeToBuffer(bw); 38 | if (i < trainSeqs.size() - 1) { 39 | bw.write("\n"); 40 | } 41 | } 42 | bw.close(); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/HOSemiCRF/DataSet.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | import java.util.*; 4 | import java.io.*; 5 | 6 | /** 7 | * Class for a dataset 8 | * @author Nguyen Viet Cuong 9 | */ 10 | public class DataSet { 11 | 12 | ArrayList trainSeqs; // List of all data sequences 13 | 14 | /** 15 | * Construct a dataset from a list of data sequences. 16 | * @param trs List of data sequences 17 | */ 18 | public DataSet(ArrayList trs) { 19 | trainSeqs = trs; 20 | } 21 | 22 | /** 23 | * Get the list of data sequences in the dataset. 24 | * @return List of data sequences 25 | */ 26 | public ArrayList getSeqList() { 27 | return trainSeqs; 28 | } 29 | 30 | /** 31 | * Write the dataset to a file. 
32 | * @param filename Name of the output file 33 | */ 34 | public void writeToFile(String filename) throws Exception { 35 | BufferedWriter bw = new BufferedWriter(new FileWriter(filename, false)); 36 | for (int i = 0; i < trainSeqs.size(); i++) { 37 | trainSeqs.get(i).writeToBuffer(bw); 38 | if (i < trainSeqs.size() - 1) { 39 | bw.write("\n"); 40 | } 41 | } 42 | bw.close(); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/HOCRF/FeatureType.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * Abstract class for feature types 7 | * @author Nguyen Viet Cuong 8 | */ 9 | public abstract class FeatureType { 10 | 11 | /** 12 | * Return the order of the feature type. 13 | */ 14 | public abstract int order(); 15 | 16 | /** 17 | * Return the list of observations at a position. 18 | * @param seq Data sequence 19 | * @param pos Input position 20 | * @return List of observations 21 | */ 22 | public abstract ArrayList generateObsAt(DataSequence seq, int pos); 23 | 24 | /** 25 | * Generate the features activated at a position and a label pattern. 
/**
 * Timer class: a static stack of start times supporting nested
 * start()/end() pairs plus per-task accumulated totals.
 * NOTE(review): all state is shared static and unsynchronized, so
 * concurrent use from multiple threads can interleave pushes and pops —
 * confirm whether callers ever time across threads.
 * @author Ye Nan
 */
public class Timer {

    /** Accumulated elapsed seconds per task name. */
    public static Hashtable<String, Double> records = new Hashtable<String, Double>();
    /** Stack of pending start times, in milliseconds. */
    public static Vector<Double> stimes = new Vector<Double>();
    /** Most recent start/end timestamps (kept public for compatibility). */
    public static double stime, etime;

    /** Print all accumulated task timings. */
    public static void report() {
        System.out.println("Running times");
        System.out.println(records);
    }

    /** Push the current time; pair with end() or record(). */
    public static void start() {
        stimes.add((double) System.currentTimeMillis());
    }

    /**
     * Pop the matching start time and print the elapsed seconds.
     * @param msg Message prefix for the printed line
     */
    public static void end(String msg) {
        System.out.println(msg + " Time taken: " + popElapsedSeconds() + "s");
    }

    /**
     * Pop the matching start time and add the elapsed seconds to the task's total.
     * @param task Name under which the time is accumulated
     */
    public static void record(String task) {
        double t = popElapsedSeconds();
        Double oldT = records.get(task); // typed map restored; raw get() could not yield a double
        records.put(task, oldT == null ? t : oldT + t);
    }

    /** Pop the most recent start time and return the seconds elapsed since it. */
    private static double popElapsedSeconds() {
        etime = System.currentTimeMillis();
        stime = stimes.remove(stimes.size() - 1);
        return (etime - stime) / 1000;
    }
}
/**
 * Parameters class: training hyper-parameters loaded from a
 * Java-properties configuration file.
 * @author Nguyen Viet Cuong
 */
public class Params {

    int numLabels; // Number of labels
    int maxIters = 100; // Number of training iterations
    int numthreads = 1; // Number of parallel threads
    double invSigmaSquare = 1.0; // Inverse of Sigma Squared (regularization strength)
    double epsForConvergence = 0.001; // Convergence precision

    /**
     * Construct a parameters object. Keys missing from the file keep
     * their defaults; unknown keys are ignored.
     * @param filename Name of configuration file
     * @param nl Number of labels
     * @throws IOException if the configuration file cannot be read
     */
    public Params(String filename, int nl) throws IOException {
        Properties options = new Properties();
        // try-with-resources: the original never closed this stream (file handle leak)
        try (FileInputStream in = new FileInputStream(filename)) {
            options.load(in);
        }
        String value;
        if ((value = options.getProperty("maxIters")) != null) {
            maxIters = Integer.parseInt(value);
        }
        if ((value = options.getProperty("numthreads")) != null) {
            numthreads = Integer.parseInt(value);
        }
        if ((value = options.getProperty("invSigmaSquare")) != null) {
            invSigmaSquare = Double.parseDouble(value);
        }
        if ((value = options.getProperty("epsForConvergence")) != null) {
            epsForConvergence = Double.parseDouble(value);
        }
        numLabels = nl;
    }
}
/**
 * Worker thread that repeatedly obtains a task ID and computes it.
 * Runs in one of two modes:
 * - dynamic: taskIDs is null and IDs are fetched from the Scheduler until
 *   an ID >= getNumTasks() signals exhaustion;
 * - static: a fixed array of task IDs was assigned via setTaskIDs() and is
 *   processed in order.
 * @author Ye Nan
 */
class TaskThread extends Thread {

    int id;                   // Worker ID, used when fetching from the scheduler
    Schedulable task;         // The shared task being computed
    Scheduler scheduler;      // Source of task IDs in dynamic mode
    boolean done = false;     // True if task done
    int nTasks = 0;           // Number of tasks processed so far by this worker
    int[] taskIDs = null;     // Preassigned IDs (static mode); null means dynamic mode

    /**
     * Construct a worker for a schedulable task.
     * @param task Shared task to compute
     * @param scheduler Owning scheduler (queried in dynamic mode)
     * @param id Worker ID
     */
    public TaskThread(Schedulable task, Scheduler scheduler, int id) {
        this.scheduler = scheduler;
        this.task = task;
        this.id = id;
    }

    /**
     * Switch this worker to static mode with a fixed list of task IDs.
     * @param taskIDs IDs this worker should process, in order
     */
    public void setTaskIDs(int[] taskIDs) {
        this.taskIDs = taskIDs;
    }

    @Override
    public void run() {
        // NOTE(review): Timer uses shared static state and this start() has no
        // matching end()/record() here — presumably the scheduler pops it; confirm.
        Timer.start();
        while (true) {
            int taskID = -1;
            if (taskIDs == null) {
                // dynamic mode: ask the scheduler; an out-of-range ID means no work left
                taskID = scheduler.fetchTaskID(id);
                if (taskID >= task.getNumTasks()) {
                    break;
                }
            } else {
                // static mode: stop after the preassigned list is exhausted
                if (nTasks == taskIDs.length) {
                    break;
                }
                taskID = taskIDs[nTasks];
            }

            nTasks++;

            // compute the partial result, then merge it into the shared task
            Object res = task.compute(taskID);
            task.update(res);
        }
        done = true;
    }

    /** @return True once run() has finished. */
    public boolean done() {
        return done;
    }

    /** @return Number of tasks this worker has processed. */
    public int getNumProcessedTasks() {
        return nTasks;
    }
}
package HOSemiCRF; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * Abstract class for feature types 7 | * @author Nguyen Viet Cuong 8 | */ 9 | public abstract class FeatureType { 10 | 11 | /** 12 | * Return the order of the feature type. 13 | */ 14 | public abstract int order(); 15 | 16 | /** 17 | * Return the list of observations in a subsequence. 18 | * @param seq Data sequence 19 | * @param segStart Start position of the subsequence 20 | * @param segEnd End position of the subsequence 21 | * @return List of observations 22 | */ 23 | public abstract ArrayList generateObsAt(DataSequence seq, int segStart, int segEnd); 24 | 25 | /** 26 | * Generate the features activated at a segment and a label pattern. 27 | * @param seq Data sequence 28 | * @param segStart Start position of the segment 29 | * @param segEnd End position of the segment 30 | * @param labelPat Label pattern of the features 31 | * @return List of activated features 32 | */ 33 | public ArrayList generateFeaturesAt(DataSequence seq, int segStart, int segEnd, String labelPat) { 34 | ArrayList features = new ArrayList(); 35 | if (Utility.getOrder(labelPat) == order()) { 36 | ArrayList obs = generateObsAt(seq, segStart, segEnd); 37 | for (String o : obs) { 38 | features.add(new Feature(o, labelPat, 1.0)); 39 | } 40 | } 41 | return features; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/HOSemiCRF/Params.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | 6 | /** 7 | * Parameters class 8 | * @author Nguyen Viet Cuong 9 | */ 10 | public class Params { 11 | 12 | int numLabels; // Number of labels 13 | int maxIters = 100; // Number of training iterations 14 | int numthreads = 1; // Number of parallel threads 15 | int maxSegment = -1; // Maximum segment length 16 | double invSigmaSquare = 1.0; // Inverse of Sigma Squared 17 | double 
epsForConvergence = 0.001; // Convergence Precision 18 | 19 | /** 20 | * Construct a parameters object. 21 | * @param filename Name of configuration file 22 | * @param nl Number of labels 23 | */ 24 | public Params(String filename, int nl) throws IOException { 25 | Properties options = new Properties(); 26 | options.load(new FileInputStream(filename)); 27 | String value = null; 28 | if ((value = options.getProperty("maxIters")) != null) { 29 | maxIters = Integer.parseInt(value); 30 | } 31 | if ((value = options.getProperty("numthreads")) != null) { 32 | numthreads = Integer.parseInt(value); 33 | } 34 | if ((value = options.getProperty("maxSegment")) != null) { 35 | maxSegment = Integer.parseInt(value); 36 | } 37 | if ((value = options.getProperty("invSigmaSquare")) != null) { 38 | invSigmaSquare = Double.parseDouble(value); 39 | } 40 | if ((value = options.getProperty("epsForConvergence")) != null) { 41 | epsForConvergence = Double.parseDouble(value); 42 | } 43 | numLabels = nl; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/HOCRF/HighOrderCRF.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import edu.stanford.nlp.optimization.*; 6 | import Parallel.*; 7 | 8 | /** 9 | * High-order CRF class 10 | * @author Nguyen Viet Cuong 11 | */ 12 | public class HighOrderCRF { 13 | 14 | FeatureGenerator featureGen; // Feature generator 15 | double[] lambda; // Feature weight vector 16 | 17 | /** 18 | * Construct and initialize a high-order CRF from feature generator. 19 | * @param fgen Feature generator 20 | */ 21 | public HighOrderCRF(FeatureGenerator fgen) { 22 | featureGen = fgen; 23 | lambda = new double[featureGen.featureMap.size()]; 24 | Arrays.fill(lambda, 0.0); 25 | } 26 | 27 | /** 28 | * Train a high-order CRF from data. 
29 | * @param data Training data 30 | */ 31 | public void train(ArrayList data) { 32 | QNMinimizer qn = new QNMinimizer(); 33 | Function df = new Function(featureGen, data); 34 | lambda = qn.minimize(df, featureGen.params.epsForConvergence, lambda, featureGen.params.maxIters); 35 | } 36 | 37 | /** 38 | * Run Viterbi algorithm on testing data. 39 | * @param data Testing data 40 | */ 41 | public void runViterbi(ArrayList data) throws Exception { 42 | Viterbi tester = new Viterbi(featureGen, lambda, data); 43 | Scheduler sch = new Scheduler(tester, featureGen.params.numthreads, Scheduler.DYNAMIC_NEXT_AVAILABLE); 44 | sch.run(); 45 | } 46 | 47 | /** 48 | * Write the high-order CRF to a file. 49 | * @param filename Name of the output file 50 | */ 51 | public void write(String filename) throws Exception { 52 | PrintWriter out = new PrintWriter(new FileOutputStream(filename)); 53 | out.println(lambda.length); 54 | for (int i = 0; i < lambda.length; i++) { 55 | out.println(lambda[i]); 56 | } 57 | out.close(); 58 | } 59 | 60 | /** 61 | * Read the high-order CRF from a file. 
62 | * @param filename Name of the input file 63 | */ 64 | public void read(String filename) throws Exception { 65 | BufferedReader in = new BufferedReader(new FileReader(filename)); 66 | int featureNum = Integer.parseInt(in.readLine()); 67 | lambda = new double[featureNum]; 68 | for (int i = 0; i < featureNum; i++) { 69 | String line = in.readLine(); 70 | lambda[i] = Double.parseDouble(line); 71 | } 72 | in.close(); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/HOSemiCRF/HighOrderSemiCRF.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import edu.stanford.nlp.optimization.*; 6 | import Parallel.*; 7 | 8 | /** 9 | * High-order semi-CRF class 10 | * @author Nguyen Viet Cuong 11 | */ 12 | public class HighOrderSemiCRF { 13 | 14 | FeatureGenerator featureGen; // Feature generator 15 | double[] lambda; // Feature weight vector 16 | 17 | /** 18 | * Construct and initialize a high-order semi-CRF from feature generator. 19 | * @param fgen Feature generator 20 | */ 21 | public HighOrderSemiCRF(FeatureGenerator fgen) { 22 | featureGen = fgen; 23 | lambda = new double[featureGen.featureMap.size()]; 24 | Arrays.fill(lambda, 0.0); 25 | } 26 | 27 | /** 28 | * Train a high-order semi-CRF from data. 29 | * @param data Training data 30 | */ 31 | public void train(ArrayList data) { 32 | QNMinimizer qn = new QNMinimizer(); 33 | Function df = new Function(featureGen, data); 34 | lambda = qn.minimize(df, featureGen.params.epsForConvergence, lambda, featureGen.params.maxIters); 35 | } 36 | 37 | /** 38 | * Run Viterbi algorithm on testing data. 
39 | * @param data Testing data 40 | */ 41 | public void runViterbi(ArrayList data) throws Exception { 42 | Viterbi tester = new Viterbi(featureGen, lambda, data); 43 | Scheduler sch = new Scheduler(tester, featureGen.params.numthreads, Scheduler.DYNAMIC_NEXT_AVAILABLE); 44 | sch.run(); 45 | } 46 | 47 | /** 48 | * Write the high-order semi-CRF to a file. 49 | * @param filename Name of the output file 50 | */ 51 | public void write(String filename) throws Exception { 52 | PrintWriter out = new PrintWriter(new FileOutputStream(filename)); 53 | out.println(lambda.length); 54 | for (int i = 0; i < lambda.length; i++) { 55 | out.println(lambda[i]); 56 | } 57 | out.close(); 58 | } 59 | 60 | /** 61 | * Read the high-order semi-CRF from a file. 62 | * @param filename Name of the input file 63 | */ 64 | public void read(String filename) throws Exception { 65 | BufferedReader in = new BufferedReader(new FileReader(filename)); 66 | int featureNum = Integer.parseInt(in.readLine()); 67 | lambda = new double[featureNum]; 68 | for (int i = 0; i < featureNum; i++) { 69 | String line = in.readLine(); 70 | lambda[i] = Double.parseDouble(line); 71 | } 72 | in.close(); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/HOCRF/DataSequence.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | 6 | /** 7 | * Class for a data sequence 8 | * @author Nguyen Viet Cuong 9 | */ 10 | public class DataSequence { 11 | 12 | Object[] inputs; // Observation array 13 | int[] labels; // Label array 14 | ArrayList[][] features; // Map from [pos,patID] to list of feature IDs 15 | LabelMap labelmap; // Map from label strings to their IDs 16 | 17 | /** 18 | * Construct a data sequence from a label map, labels and observations with default segmentation. 
19 | * @param ls Label array 20 | * @param inps Observation array 21 | * @param labelm Label map 22 | */ 23 | public DataSequence(int[] ls, Object[] inps, LabelMap labelm) { 24 | labels = ls; 25 | inputs = inps; 26 | labelmap = labelm; 27 | } 28 | 29 | /** 30 | * Return length of the current data sequence. 31 | * @return Length of the current sequence 32 | */ 33 | public int length() { 34 | return labels.length; 35 | } 36 | 37 | /** 38 | * Return label at a position. 39 | * @param pos Input position 40 | * @return Label at the input position 41 | */ 42 | public int y(int pos) { 43 | return labels[pos]; 44 | } 45 | 46 | /** 47 | * Return observation at a position. 48 | * @param pos Input position 49 | * @return Observation at the input position 50 | */ 51 | public Object x(int pos) { 52 | if (pos < 0 || pos >= inputs.length) { 53 | return ""; 54 | } 55 | return inputs[pos]; 56 | } 57 | 58 | /** 59 | * Set the label at an input position. 60 | * @param pos Input position 61 | * @param newY New label to be set at the input position 62 | */ 63 | public void set_y(int pos, int newY) { 64 | labels[pos] = newY; 65 | } 66 | 67 | /** 68 | * Return the label map of this sequence. 69 | * @return The label map. 70 | */ 71 | public LabelMap getLabelMap() { 72 | return labelmap; 73 | } 74 | 75 | /** 76 | * Return the list of features at a position and a label pattern. 77 | * @param pos Input position 78 | * @param patID Pattern ID 79 | * @return List of features 80 | */ 81 | public ArrayList getFeatures(int pos, int patID) { 82 | return features[pos][patID]; 83 | } 84 | 85 | /** 86 | * Write a data sequence to a buffered writer. 
87 | * @param bw Buffered writer 88 | */ 89 | public void writeToBuffer(BufferedWriter bw) throws Exception { 90 | for (int i = 0; i < labels.length; i++) { 91 | bw.write(x(i) + " " + labelmap.revMap(labels[i]) + "\n"); 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/HOCRF/SentenceFeatGenerator.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.util.*; 4 | import Parallel.*; 5 | 6 | /** 7 | * Generator class for the features in each sequence 8 | * @author Nguyen Viet Cuong 9 | * @author Sumit Bhagwani 10 | */ 11 | public class SentenceFeatGenerator implements Schedulable { 12 | 13 | int curID; // Current task ID (for parallelization) 14 | ArrayList trainData; // List of training sequences 15 | FeatureGenerator featGen; // Feature generator 16 | 17 | /** 18 | * Construct a generator for the features. 19 | * @param data Training data 20 | * @param fgen Feature generator 21 | */ 22 | public SentenceFeatGenerator(ArrayList data, FeatureGenerator fgen) { 23 | curID = -1; 24 | trainData = data; 25 | featGen = fgen; 26 | } 27 | 28 | /** 29 | * Compute the features for all the positions in a given sequence. 
30 | * @param taskID Index of the training sequence 31 | * @return The updated sequence 32 | */ 33 | public Object compute(int taskID) { 34 | DataSequence seq = (DataSequence) trainData.get(taskID); 35 | seq.features = new ArrayList[seq.length()][featGen.patternMap.size()]; 36 | 37 | for (int pos = 0; pos < seq.length(); pos++) { 38 | for (int patID = 0; patID < featGen.patternMap.size(); patID++) { 39 | seq.features[pos][patID] = new ArrayList(); 40 | ArrayList obs = featGen.generateObs(seq, pos); 41 | for (String o : obs) { 42 | Integer oID = featGen.getObsIndex(o); 43 | if (oID != null) { 44 | Integer feat = (Integer) featGen.featureMap.get(new FeatureIndex(oID, patID)); 45 | if (feat != null) { 46 | seq.features[pos][patID].add(feat); 47 | } 48 | } 49 | } 50 | } 51 | } 52 | 53 | return seq; 54 | } 55 | 56 | /** 57 | * Return the number of tasks (for parallelization). 58 | * @return Training data size 59 | */ 60 | public int getNumTasks() { 61 | return trainData.size(); 62 | } 63 | 64 | /** 65 | * Return the next task ID (for parallelization). 66 | * @return Index of the next sequence 67 | */ 68 | public synchronized int fetchCurrTaskID() { 69 | if (curID < getNumTasks()) { 70 | curID++; 71 | } 72 | return curID; 73 | } 74 | 75 | /** 76 | * Update partial result (for parallelization). 77 | * Note that this method does nothing in this case. 
78 | * @param partialResult Partial result 79 | */ 80 | public synchronized void update(Object partialResult) { 81 | // Do nothing 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/Applications/PuncConverter.java: -------------------------------------------------------------------------------- 1 | package Applications; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import HOSemiCRF.*; 6 | 7 | /** 8 | * Class for converting the punctuation labels 9 | * @author Nguyen Viet Cuong 10 | */ 11 | public class PuncConverter { 12 | 13 | /** 14 | * Convert punctuations of a dataset. 15 | * @param inFilename Name of the original dataset file 16 | * @param outFilename Name of the new dataset file 17 | */ 18 | public static void convert(String inFilename, String outFilename) throws Exception { 19 | LabelMap labelmap = new LabelMap(); 20 | DataSet data = readInFile(inFilename, labelmap); 21 | data.writeToFile(outFilename); 22 | } 23 | 24 | /** 25 | * Read the original file and convert labels. 
26 | * @param filename Name of the input file 27 | * @param labelmap Label map 28 | * @return The training data with new labels 29 | */ 30 | static DataSet readInFile(String filename, LabelMap labelmap) throws Exception { 31 | BufferedReader in = new BufferedReader(new FileReader(filename)); 32 | 33 | ArrayList td = new ArrayList(); 34 | ArrayList inps = new ArrayList(); 35 | ArrayList labels = new ArrayList(); 36 | String line; 37 | 38 | while ((line = in.readLine()) != null) { 39 | if (line.length() > 0) { 40 | StringTokenizer toks = new StringTokenizer(line); 41 | String word = toks.nextToken(); 42 | String tagRel = toks.nextToken(); 43 | inps.add(word); 44 | labels.add(tagRel); 45 | } else if (labels.size() > 0) { 46 | changeLabel(labels); 47 | td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap)); 48 | inps = new ArrayList(); 49 | labels = new ArrayList(); 50 | } 51 | } 52 | if (labels.size() > 0) { 53 | changeLabel(labels); 54 | td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap)); 55 | } 56 | 57 | in.close(); 58 | return new DataSet(td); 59 | } 60 | 61 | /** 62 | * Change the labels of a sequence. 63 | * @param labels Label sequence 64 | */ 65 | static void changeLabel(ArrayList labels) { 66 | for (int i = 0; i < labels.size(); i++) { 67 | if (labels.get(i).matches("[,.!?]")) { 68 | String last = labels.get(i); 69 | int j = i; 70 | while (j > 0 && labels.get(j-1).equals("O")) { 71 | j--; 72 | } 73 | labels.set(j, labels.get(j) + "-" + last); 74 | } 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/HOCRF/Utility.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * Implementations of the utility methods 7 | * @author Nguyen Viet Cuong 8 | */ 9 | public class Utility { 10 | 11 | /** 12 | * Return ln(exp(a) + exp(b)) for two real numbers a and b. 
// File: src/HOCRF/Utility.java

/**
 * Implementations of the utility methods.
 * Label patterns are strings of the form "y1|y2|...|yk", listing label IDs
 * from most recent to oldest.
 * @author Nguyen Viet Cuong
 */
public class Utility {

    /**
     * Return ln(exp(a) + exp(b)) for two real numbers a and b.
     * Computed stably by factoring out the larger exponent, so neither
     * exp() call can overflow.
     * @param a First real number
     * @param b Second real number
     * @return ln(exp(a) + exp(b))
     */
    public static double logSumExp(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY) {
            return b;
        } else if (b == Double.NEGATIVE_INFINITY) {
            return a;
        } else if (a > b) {
            return a + Math.log(1 + Math.exp(b - a));
        } else {
            return b + Math.log(1 + Math.exp(a - b));
        }
    }

    /**
     * Generate all proper prefixes of a label pattern.
     * Each result is obtained by repeatedly dropping the first (most recent)
     * label; single-label remainders are excluded.
     * For example, "a|b|c|d" yields ["b|c|d", "c|d"].
     * @param labelPat Label pattern
     * @return List of proper prefixes
     */
    public static ArrayList<String> generateProperPrefixes(String labelPat) {
        // Fix: dropped the redundant new String(labelPat) defensive copy.
        String pats = labelPat;
        ArrayList<String> res = new ArrayList<String>();
        while (pats.contains("|")) {
            pats = pats.substring(pats.indexOf('|') + 1);
            if (pats.contains("|")) {
                res.add(pats);
            }
        }
        return res;
    }

    /**
     * Generate all suffixes of a label pattern, longest first.
     * For example, "a|b|c" yields ["a|b|c", "a|b", "a"].
     * @param labelPat Label pattern
     * @return List of suffixes
     */
    public static ArrayList<String> generateSuffixes(String labelPat) {
        String pats = labelPat;
        ArrayList<String> res = new ArrayList<String>();
        res.add(pats);
        while (pats.contains("|")) {
            pats = pats.substring(0, pats.lastIndexOf("|"));
            res.add(pats);
        }
        return res;
    }

    /**
     * Return the last label in a label pattern, i.e. the first "|"-separated
     * token (patterns list labels from most recent to oldest).
     * @param labelPat Label pattern
     * @return The last label in the pattern
     */
    public static String getLastLabel(String labelPat) {
        String pats = labelPat;
        if (pats.contains("|")) {
            return pats.substring(0, pats.indexOf("|"));
        } else {
            return pats;
        }
    }

    /**
     * Return the order of a label pattern: the number of "|" separators,
     * i.e. one less than the number of labels.
     * @param labelPat Label pattern
     * @return The order of the pattern
     */
    public static int getOrder(String labelPat) {
        int res = 0;
        String pats = labelPat;
        while (pats.contains("|")) {
            pats = pats.substring(pats.indexOf('|') + 1);
            res++;
        }
        return res;
    }
}
// File: src/HOSemiCRF/Utility.java

/**
 * Implementations of the utility methods.
 * Label patterns are strings of the form "y1|y2|...|yk", listing label IDs
 * from most recent to oldest. This class mirrors HOCRF.Utility.
 * @author Nguyen Viet Cuong
 */
public class Utility {

    /**
     * Return ln(exp(a) + exp(b)) for two real numbers a and b.
     * Computed stably by factoring out the larger exponent, so neither
     * exp() call can overflow.
     * @param a First real number
     * @param b Second real number
     * @return ln(exp(a) + exp(b))
     */
    public static double logSumExp(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY) {
            return b;
        } else if (b == Double.NEGATIVE_INFINITY) {
            return a;
        } else if (a > b) {
            return a + Math.log(1 + Math.exp(b - a));
        } else {
            return b + Math.log(1 + Math.exp(a - b));
        }
    }

    /**
     * Generate all proper prefixes of a label pattern.
     * Each result is obtained by repeatedly dropping the first (most recent)
     * label; single-label remainders are excluded.
     * For example, "a|b|c|d" yields ["b|c|d", "c|d"].
     * @param labelPat Label pattern
     * @return List of proper prefixes
     */
    public static ArrayList<String> generateProperPrefixes(String labelPat) {
        // Fix: dropped the redundant new String(labelPat) defensive copy.
        String pats = labelPat;
        ArrayList<String> res = new ArrayList<String>();
        while (pats.contains("|")) {
            pats = pats.substring(pats.indexOf('|') + 1);
            if (pats.contains("|")) {
                res.add(pats);
            }
        }
        return res;
    }

    /**
     * Generate all suffixes of a label pattern, longest first.
     * For example, "a|b|c" yields ["a|b|c", "a|b", "a"].
     * @param labelPat Label pattern
     * @return List of suffixes
     */
    public static ArrayList<String> generateSuffixes(String labelPat) {
        String pats = labelPat;
        ArrayList<String> res = new ArrayList<String>();
        res.add(pats);
        while (pats.contains("|")) {
            pats = pats.substring(0, pats.lastIndexOf("|"));
            res.add(pats);
        }
        return res;
    }

    /**
     * Return the last label in a label pattern, i.e. the first "|"-separated
     * token (patterns list labels from most recent to oldest).
     * @param labelPat Label pattern
     * @return The last label in the pattern
     */
    public static String getLastLabel(String labelPat) {
        String pats = labelPat;
        if (pats.contains("|")) {
            return pats.substring(0, pats.indexOf("|"));
        } else {
            return pats;
        }
    }

    /**
     * Return the order of a label pattern: the number of "|" separators,
     * i.e. one less than the number of labels.
     * @param labelPat Label pattern
     * @return The order of the pattern
     */
    public static int getOrder(String labelPat) {
        int res = 0;
        String pats = labelPat;
        while (pats.contains("|")) {
            pats = pats.substring(pats.indexOf('|') + 1);
            res++;
        }
        return res;
    }
}
19 | * @param data Training data 20 | * @param fgen Feature generator 21 | */ 22 | public SentenceObsGenerator(ArrayList data, FeatureGenerator fgen) { 23 | curID = -1; 24 | trainData = data; 25 | featGen = fgen; 26 | } 27 | 28 | /** 29 | * Compute the observations for all the subsequences in a given sequence. 30 | * @param taskID Index of the training sequence 31 | * @return The updated sequence 32 | */ 33 | public Object compute(int taskID) { 34 | DataSequence seq = (DataSequence) trainData.get(taskID); 35 | seq.observationMap = new int[seq.length()][][]; 36 | 37 | for (int segStart = 0; segStart < seq.length(); segStart++) { 38 | int maxLength = Math.min(featGen.params.maxSegment, seq.length() - segStart); 39 | seq.observationMap[segStart] = new int[maxLength][]; 40 | for (int segEnd = segStart; segEnd - segStart < maxLength; segEnd++) { 41 | ArrayList obsIndices = new ArrayList(); 42 | ArrayList obs = featGen.generateObs(seq, segStart, segEnd); 43 | for (String o : obs) { 44 | Integer oID = featGen.getObsIndex(o); 45 | if (oID != null) { 46 | obsIndices.add(oID); 47 | } 48 | } 49 | 50 | int d = segEnd-segStart; 51 | seq.observationMap[segStart][d] = new int[obsIndices.size()]; 52 | for (int i = 0; i < obsIndices.size(); i++) { 53 | seq.observationMap[segStart][d][i] = obsIndices.get(i); 54 | } 55 | } 56 | } 57 | 58 | return seq; 59 | } 60 | 61 | /** 62 | * Return the number of tasks (for parallelization). 63 | * @return Training data size 64 | */ 65 | public int getNumTasks() { 66 | return trainData.size(); 67 | } 68 | 69 | /** 70 | * Return the next task ID (for parallelization). 71 | * @return Index of the next sequence 72 | */ 73 | public synchronized int fetchCurrTaskID() { 74 | if (curID < getNumTasks()) { 75 | curID++; 76 | } 77 | return curID; 78 | } 79 | 80 | /** 81 | * Update partial result (for parallelization). 82 | * Note that this method does nothing in this case. 
// File: src/HOCRF/LabelMap.java

/**
 * Label map class: a bidirectional mapping between label strings and dense
 * integer indices (0..size-1, in first-seen order).
 * @author Nguyen Viet Cuong
 */
public class LabelMap {

    ArrayList<String> mapTable; // List of labels; position in the list is the label's index
    private HashMap<String, Integer> index; // Reverse lookup (label -> index), kept in sync with mapTable

    /**
     * Construct an empty label map.
     */
    public LabelMap() {
        mapTable = new ArrayList<String>();
        index = new HashMap<String, Integer>();
    }

    /**
     * Return the size of the map.
     * @return Number of labels
     */
    public int size() {
        return mapTable.size();
    }

    /**
     * Return the index of a label.
     * If there is no such label, add it to the map and return its index.
     * @param labelStr Input label
     * @return Index of the input label
     */
    public int map(String labelStr) {
        // Fix: O(1) hash lookup instead of a linear mapTable.indexOf scan,
        // which made mapping a whole dataset accidentally quadratic.
        Integer id = index.get(labelStr);
        if (id == null) {
            mapTable.add(labelStr);
            id = mapTable.size() - 1;
            index.put(labelStr, id);
        }
        return id;
    }

    /**
     * Return the label with a given index.
     * @param l Index of the label
     * @return The label string
     */
    public String revMap(int l) {
        return mapTable.get(l);
    }

    /**
     * Map labels into their indices.
     * @param labelStrList List of label strings
     * @return List of indices of the input labels
     */
    public int[] mapArrayList(ArrayList<String> labelStrList) {
        int[] result = new int[labelStrList.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = map(labelStrList.get(i));
        }
        return result;
    }

    /**
     * Map label indices into label strings.
     * @param labels Label array
     * @return Label string array
     */
    public String[] revArray(int[] labels) {
        String[] result = new String[labels.length];
        for (int i = 0; i < labels.length; i++) {
            result[i] = revMap(labels[i]);
        }
        return result;
    }

    /**
     * Write the label map into a file: the size on the first line, then one
     * label per line.
     * @param filename Name of the output file
     */
    public void write(String filename) throws IOException {
        PrintWriter out = new PrintWriter(new FileOutputStream(filename));
        try {
            out.println(mapTable.size());
            for (int i = 0; i < mapTable.size(); i++) {
                out.println(mapTable.get(i));
            }
        } finally {
            // Fix: close even if writing throws.
            out.close();
        }
    }

    /**
     * Read the label map from a file, replacing the current contents.
     * @param filename Name of the input file
     */
    public void read(String filename) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader(filename));
        try {
            int size = Integer.parseInt(in.readLine());
            mapTable = new ArrayList<String>();
            index = new HashMap<String, Integer>();
            for (int i = 0; i < size; i++) {
                String label = in.readLine();
                mapTable.add(label);
                // First occurrence wins, matching the old indexOf semantics
                // if the file ever contains duplicates.
                if (!index.containsKey(label)) {
                    index.put(label, i);
                }
            }
        } finally {
            // Fix: close even if parsing throws (previously leaked on error).
            in.close();
        }
    }
}
// File: src/HOSemiCRF/LabelMap.java

/**
 * Label map class: a bidirectional mapping between label strings and dense
 * integer indices (0..size-1, in first-seen order). Mirrors HOCRF.LabelMap.
 * @author Nguyen Viet Cuong
 */
public class LabelMap {

    ArrayList<String> mapTable; // List of labels; position in the list is the label's index
    private HashMap<String, Integer> index; // Reverse lookup (label -> index), kept in sync with mapTable

    /**
     * Construct an empty label map.
     */
    public LabelMap() {
        mapTable = new ArrayList<String>();
        index = new HashMap<String, Integer>();
    }

    /**
     * Return the size of the map.
     * @return Number of labels
     */
    public int size() {
        return mapTable.size();
    }

    /**
     * Return the index of a label.
     * If there is no such label, add it to the map and return its index.
     * @param labelStr Input label
     * @return Index of the input label
     */
    public int map(String labelStr) {
        // Fix: O(1) hash lookup instead of a linear mapTable.indexOf scan,
        // which made mapping a whole dataset accidentally quadratic.
        Integer id = index.get(labelStr);
        if (id == null) {
            mapTable.add(labelStr);
            id = mapTable.size() - 1;
            index.put(labelStr, id);
        }
        return id;
    }

    /**
     * Return the label with a given index.
     * @param l Index of the label
     * @return The label string
     */
    public String revMap(int l) {
        return mapTable.get(l);
    }

    /**
     * Map labels into their indices.
     * @param labelStrList List of label strings
     * @return List of indices of the input labels
     */
    public int[] mapArrayList(ArrayList<String> labelStrList) {
        int[] result = new int[labelStrList.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = map(labelStrList.get(i));
        }
        return result;
    }

    /**
     * Map label indices into label strings.
     * @param labels Label array
     * @return Label string array
     */
    public String[] revArray(int[] labels) {
        String[] result = new String[labels.length];
        for (int i = 0; i < labels.length; i++) {
            result[i] = revMap(labels[i]);
        }
        return result;
    }

    /**
     * Write the label map into a file: the size on the first line, then one
     * label per line.
     * @param filename Name of the output file
     */
    public void write(String filename) throws IOException {
        PrintWriter out = new PrintWriter(new FileOutputStream(filename));
        try {
            out.println(mapTable.size());
            for (int i = 0; i < mapTable.size(); i++) {
                out.println(mapTable.get(i));
            }
        } finally {
            // Fix: close even if writing throws.
            out.close();
        }
    }

    /**
     * Read the label map from a file, replacing the current contents.
     * @param filename Name of the input file
     */
    public void read(String filename) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader(filename));
        try {
            int size = Integer.parseInt(in.readLine());
            mapTable = new ArrayList<String>();
            index = new HashMap<String, Integer>();
            for (int i = 0; i < size; i++) {
                String label = in.readLine();
                mapTable.add(label);
                // First occurrence wins, matching the old indexOf semantics
                // if the file ever contains duplicates.
                if (!index.containsKey(label)) {
                    index.put(label, i);
                }
            }
        } finally {
            // Fix: close even if parsing throws (previously leaked on error).
            in.close();
        }
    }
}
41 | * @param lambda Lambda vector 42 | * @return Loglikelihood value 43 | */ 44 | public double valueAt(double[] lambda) { 45 | if (Arrays.equals(lambda, lambdaCache)) { 46 | return logli.logli; 47 | } else { 48 | lambdaCache = (double[]) lambda.clone(); 49 | computeValueAndDerivatives(lambda); 50 | return logli.logli; 51 | } 52 | } 53 | 54 | /** 55 | * Return the first derivative of the loglikelihood function given a lambda vector. 56 | * @param lambda Lambda vector 57 | * @return First derivatives 58 | */ 59 | public double[] derivativeAt(double[] lambda) { 60 | if (Arrays.equals(lambda, lambdaCache)) { 61 | return logli.derivatives; 62 | } else { 63 | lambdaCache = (double[]) lambda.clone(); 64 | computeValueAndDerivatives(lambda); 65 | return logli.derivatives; 66 | } 67 | } 68 | 69 | /** 70 | * Compute the values and derivatives of the loglikelihood function. 71 | * @param lambda Lambda vector 72 | */ 73 | public void computeValueAndDerivatives(double[] lambda) { 74 | logli = new Loglikelihood(lambda.length); 75 | for (int i = 0; i < lambda.length; i++) { 76 | logli.logli -= ((lambda[i] * lambda[i]) * featureGen.params.invSigmaSquare) / 2; 77 | logli.derivatives[i] -= (lambda[i] * featureGen.params.invSigmaSquare); 78 | } 79 | 80 | LogliComputer logliComp = new LogliComputer(lambda, featureGen, trainData, logli); 81 | Scheduler sch = new Scheduler(logliComp, featureGen.params.numthreads, Scheduler.DYNAMIC_NEXT_AVAILABLE); 82 | try { 83 | sch.run(); 84 | } catch (Exception e) { 85 | System.out.println("Errors occur when training in parallel! 
" + e); 86 | } 87 | 88 | // Change sign to maximize and divide the values by size of dataset 89 | int n = trainData.size(); 90 | for (int i = 0; i < logli.derivatives.length; i++) { 91 | logli.derivatives[i] = -(logli.derivatives[i] / n); 92 | } 93 | logli.logli = -(logli.logli / n); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/HOSemiCRF/Function.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | import java.util.*; 4 | import edu.stanford.nlp.optimization.*; 5 | import Parallel.*; 6 | 7 | /** 8 | * Loglikelihood function class 9 | * @author Nguyen Viet Cuong 10 | */ 11 | public class Function implements DiffFunction { 12 | 13 | FeatureGenerator featureGen; // Feature generator 14 | ArrayList trainData; // List of training sequences 15 | 16 | // Private data structures to compute function value and derivatives 17 | private Loglikelihood logli; // Loglikelihood values 18 | private double lambdaCache[]; // Cache of lambda vector for reuse 19 | 20 | /** 21 | * Construct a function from feature generator and data. 22 | * @param fgen Feature generator 23 | * @param data Training data 24 | */ 25 | public Function(FeatureGenerator fgen, ArrayList data) { 26 | featureGen = fgen; 27 | trainData = data; 28 | lambdaCache = null; 29 | } 30 | 31 | /** 32 | * Return the dimension of the domain. 33 | * @return Domain dimension 34 | */ 35 | public int domainDimension() { 36 | return featureGen.featureMap.size(); 37 | } 38 | 39 | /** 40 | * Return the loglikelihood given a lambda vector. 
41 | * @param lambda Lambda vector 42 | * @return Loglikelihood value 43 | */ 44 | public double valueAt(double[] lambda) { 45 | if (Arrays.equals(lambda, lambdaCache)) { 46 | return logli.logli; 47 | } else { 48 | lambdaCache = (double[]) lambda.clone(); 49 | computeValueAndDerivatives(lambda); 50 | return logli.logli; 51 | } 52 | } 53 | 54 | /** 55 | * Return the first derivative of the loglikelihood function given a lambda vector. 56 | * @param lambda Lambda vector 57 | * @return First derivatives 58 | */ 59 | public double[] derivativeAt(double[] lambda) { 60 | if (Arrays.equals(lambda, lambdaCache)) { 61 | return logli.derivatives; 62 | } else { 63 | lambdaCache = (double[]) lambda.clone(); 64 | computeValueAndDerivatives(lambda); 65 | return logli.derivatives; 66 | } 67 | } 68 | 69 | /** 70 | * Compute the values and derivatives of the loglikelihood function. 71 | * @param lambda Lambda vector 72 | */ 73 | public void computeValueAndDerivatives(double[] lambda) { 74 | logli = new Loglikelihood(lambda.length); 75 | for (int i = 0; i < lambda.length; i++) { 76 | logli.logli -= ((lambda[i] * lambda[i]) * featureGen.params.invSigmaSquare) / 2; 77 | logli.derivatives[i] -= (lambda[i] * featureGen.params.invSigmaSquare); 78 | } 79 | 80 | LogliComputer logliComp = new LogliComputer(lambda, featureGen, trainData, logli); 81 | Scheduler sch = new Scheduler(logliComp, featureGen.params.numthreads, Scheduler.DYNAMIC_NEXT_AVAILABLE); 82 | try { 83 | sch.run(); 84 | } catch (Exception e) { 85 | System.out.println("Errors occur when training in parallel! 
" + e); 86 | } 87 | 88 | // Change sign to maximize and divide the values by size of dataset 89 | int n = trainData.size(); 90 | for (int i = 0; i < logli.derivatives.length; i++) { 91 | logli.derivatives[i] = -(logli.derivatives[i] / n); 92 | } 93 | logli.logli = -(logli.logli / n); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/HOCRF/Viterbi.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.util.*; 4 | import Parallel.*; 5 | 6 | /** 7 | * Implementation of the Viterbi algorithm 8 | * @author Nguyen Viet Cuong 9 | * @author Sumit Bhagwani 10 | */ 11 | public class Viterbi implements Schedulable { 12 | 13 | int curID; // Current task ID (for parallelization) 14 | FeatureGenerator featureGen; // Feature generator 15 | double[] lambda; // Lambda vector 16 | ArrayList data; // List of testing sequences 17 | final int BASE = 1; // Base of the logAlpha array 18 | 19 | /** 20 | * Construct a Viterbi class. 21 | * @param featureGen Feature generator 22 | * @param lambda Lambda vector 23 | * @param data Testing data 24 | */ 25 | public Viterbi(FeatureGenerator featureGen, double[] lambda, ArrayList data) { 26 | curID = -1; 27 | this.featureGen = featureGen; 28 | this.lambda = lambda; 29 | this.data = data; 30 | } 31 | 32 | /** 33 | * Run the Viterbi algorithm for a given sequence. 
34 | * @param taskID Index of the sequence 35 | * @return The updated sequence 36 | */ 37 | public Object compute(int taskID) { 38 | DataSequence seq = (DataSequence) data.get(taskID); 39 | double maxScore[][] = new double[seq.length() + 1][featureGen.forwardStateMap.size()]; 40 | String trace[][] = new String[seq.length()][featureGen.forwardStateMap.size()]; 41 | 42 | Arrays.fill(maxScore[0], Double.NEGATIVE_INFINITY); 43 | maxScore[0][0] = 0.0; 44 | for (int j = 0; j < seq.length(); j++) { 45 | Arrays.fill(maxScore[j + BASE], Double.NEGATIVE_INFINITY); 46 | for (int i = 1; i < featureGen.forwardStateMap.size(); i++) { 47 | ArrayList prevState1 = featureGen.forwardTransition1[i]; 48 | ArrayList prevState2 = featureGen.forwardTransition2[i]; 49 | for (int k = 0; k < prevState1.size(); k++) { 50 | int pkID = prevState1.get(k); 51 | int pkyID = prevState2.get(k); 52 | String pky = featureGen.backwardStateList.get(pkyID); 53 | ArrayList features = featureGen.generateFeatures(seq, j, pky); 54 | ArrayList feats = featureGen.getFeatureID(features); 55 | double featuresScore = featureGen.computeFeatureScores(feats, lambda); 56 | if (maxScore[j + BASE][i] < featuresScore + maxScore[j + BASE - 1][pkID]) { 57 | maxScore[j + BASE][i] = featuresScore + maxScore[j + BASE - 1][pkID]; 58 | trace[j][i] = pkID + " " + Utility.getLastLabel(pky); 59 | } 60 | } 61 | } 62 | } 63 | 64 | // Compute max score for last element 65 | double max = Double.NEGATIVE_INFINITY; 66 | String tracemax = ""; 67 | for (int i = 0; i < featureGen.forwardStateMap.size(); i++) { 68 | if (max < maxScore[seq.length() + BASE - 1][i]) { 69 | max = maxScore[seq.length() + BASE - 1][i]; 70 | tracemax = trace[seq.length() - 1][i]; 71 | } 72 | } 73 | 74 | // Trace back 75 | int currPos = seq.length() - 1; 76 | while (currPos >= 0) { 77 | StringTokenizer toks = new StringTokenizer(tracemax); 78 | int prevPat = Integer.parseInt(toks.nextToken()); 79 | int currY = Integer.parseInt(toks.nextToken()); 80 | 
seq.set_y(currPos, currY); 81 | currPos--; 82 | if (currPos >= 0) { 83 | tracemax = trace[currPos][prevPat]; 84 | } 85 | } 86 | 87 | return seq; 88 | } 89 | 90 | /** 91 | * Return total number of tasks (for parallelization). 92 | * @return Training dataset size 93 | */ 94 | public int getNumTasks() { 95 | return data.size(); 96 | } 97 | 98 | /** 99 | * Return the next task ID (for parallelization). 100 | * @return The next sequence ID 101 | */ 102 | public synchronized int fetchCurrTaskID() { 103 | if (curID < getNumTasks()) { 104 | curID++; 105 | } 106 | return curID; 107 | } 108 | 109 | /** 110 | * Update partial result (for parallelization). 111 | * Note that this method does nothing in this case. 112 | * @param partialResult Partial result 113 | */ 114 | public void update(Object partialResult) { 115 | // Do nothing 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/HOSemiCRF/Viterbi.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | import java.util.*; 4 | import Parallel.*; 5 | 6 | /** 7 | * Implementation of the Viterbi algorithm 8 | * @author Nguyen Viet Cuong 9 | * @author Sumit Bhagwani 10 | */ 11 | public class Viterbi implements Schedulable { 12 | 13 | int curID; // Current task ID (for parallelization) 14 | FeatureGenerator featureGen; // Feature generator 15 | double[] lambda; // Lambda vector 16 | ArrayList data; // List of training sequences 17 | final int BASE = 1; // Base of the logAlpha array 18 | 19 | /** 20 | * Construct a Viterbi class. 21 | * @param featureGen Feature generator 22 | * @param lambda Lambda vector 23 | * @param data Training data 24 | */ 25 | public Viterbi(FeatureGenerator featureGen, double[] lambda, ArrayList data) { 26 | curID = -1; 27 | this.featureGen = featureGen; 28 | this.lambda = lambda; 29 | this.data = data; 30 | } 31 | 32 | /** 33 | * Run the Viterbi algorithm for a given sequence. 
34 | * @param taskID Index of the sequence 35 | * @return The updated sequence 36 | */ 37 | public Object compute(int taskID) { 38 | DataSequence seq = (DataSequence) data.get(taskID); 39 | double maxScore[][] = new double[seq.length() + 1][featureGen.forwardStateMap.size()]; 40 | String trace[][] = new String[seq.length()][featureGen.forwardStateMap.size()]; 41 | 42 | Arrays.fill(maxScore[0], Double.NEGATIVE_INFINITY); 43 | maxScore[0][0] = 0.0; 44 | for (int j = 0; j < seq.length(); j++) { 45 | Arrays.fill(maxScore[j + BASE], Double.NEGATIVE_INFINITY); 46 | for (int i = 0; i < featureGen.forwardStateMap.size(); i++) { 47 | int y = featureGen.lastForwardStateLabel[i]; 48 | int maxmem = (y == -1) ? 0 : featureGen.maxMemory[y]; 49 | 50 | ArrayList prevState1 = featureGen.forwardTransition1[i]; 51 | ArrayList prevState2 = featureGen.forwardTransition2[i]; 52 | for (int d = 0; d < maxmem && j - d >= 0; d++) { 53 | for (int k = 0; k < prevState1.size(); k++) { 54 | int pkID = prevState1.get(k); 55 | int pkyID = prevState2.get(k); 56 | String pky = featureGen.backwardStateList.get(pkyID); 57 | ArrayList features = featureGen.generateFeatures(seq, j-d, j, pky); 58 | ArrayList feats = featureGen.getFeatureID(features); 59 | double featuresScore = featureGen.computeFeatureScores(feats, lambda); 60 | if (maxScore[j + BASE][i] < featuresScore + maxScore[j + BASE - d - 1][pkID]) { 61 | maxScore[j + BASE][i] = featuresScore + maxScore[j + BASE - d - 1][pkID]; 62 | trace[j][i] = (j - d - 1) + " " + pkID + " " + y; 63 | } 64 | } 65 | } 66 | } 67 | } 68 | 69 | // Compute max score for last element 70 | double max = Double.NEGATIVE_INFINITY; 71 | String tracemax = ""; 72 | for (int i = 0; i < featureGen.forwardStateMap.size(); i++) { 73 | if (max < maxScore[seq.length() + BASE - 1][i]) { 74 | max = maxScore[seq.length() + BASE - 1][i]; 75 | tracemax = trace[seq.length() - 1][i]; 76 | } 77 | } 78 | 79 | // Trace back 80 | int currPos = seq.length() - 1; 81 | while (currPos >= 0) { 
82 | StringTokenizer toks = new StringTokenizer(tracemax); 83 | int prevPos = Integer.parseInt(toks.nextToken()); 84 | int prevPat = Integer.parseInt(toks.nextToken()); 85 | int currY = Integer.parseInt(toks.nextToken()); 86 | seq.setSegment(prevPos + 1, currPos, currY); 87 | currPos = prevPos; 88 | if (currPos >= 0) { 89 | tracemax = trace[prevPos][prevPat]; 90 | } 91 | } 92 | 93 | return seq; 94 | } 95 | 96 | /** 97 | * Return total number of tasks (for parallelization). 98 | * @return Training dataset size 99 | */ 100 | public int getNumTasks() { 101 | return data.size(); 102 | } 103 | 104 | /** 105 | * Return the next task ID (for parallelization). 106 | * @return The next sequence ID 107 | */ 108 | public synchronized int fetchCurrTaskID() { 109 | if (curID < getNumTasks()) { 110 | curID++; 111 | } 112 | return curID; 113 | } 114 | 115 | /** 116 | * Update partial result (for parallelization). 117 | * Note that this method does nothing in this case. 118 | * @param partialResult Partial result 119 | */ 120 | public void update(Object partialResult) { 121 | // Do nothing 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/Parallel/Scheduler.java: -------------------------------------------------------------------------------- 1 | package Parallel; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * Parallelization scheduler 7 | * @author Ye Nan 8 | */ 9 | public class Scheduler { 10 | 11 | public final static int STATIC_UNIFORM_ALLOCATE = 0; 12 | public final static int DYNAMIC_UNIFORM_ALLOCATE = 1; 13 | public final static int DYNAMIC_NEXT_AVAILABLE = 2; 14 | 15 | private int[] curIDs = null; 16 | 17 | Schedulable task; 18 | int nThreads; 19 | int policy = 0; 20 | TaskThread[] threads; 21 | 22 | public Scheduler(Schedulable task, int nThreads, int policy) { 23 | this.policy = policy; 24 | this.task = task; 25 | this.nThreads = nThreads; 26 | threads = new TaskThread[nThreads]; 27 | for (int i = 0; i < nThreads; 
i++) { 28 | threads[i] = new TaskThread(task, this, i); 29 | } 30 | 31 | int nTasks = task.getNumTasks(); 32 | if (policy == STATIC_UNIFORM_ALLOCATE) { 33 | int[][] taskIDs = new int[nThreads][]; 34 | for (int i = 0; i < nThreads; i++) { 35 | int n = nTasks / nThreads; 36 | if (i < nTasks % nThreads) { 37 | n++; 38 | } 39 | taskIDs[i] = new int[n]; 40 | for (int j = 0; j < n; j++) { 41 | taskIDs[i][j] = i + j * nThreads; 42 | } 43 | threads[i].setTaskIDs(taskIDs[i]); 44 | } 45 | } 46 | } 47 | 48 | public int fetchTaskID(int threadID) { // no need to synchronize 49 | if (policy == DYNAMIC_UNIFORM_ALLOCATE) { 50 | if (curIDs == null) { 51 | curIDs = new int[nThreads]; 52 | for (int i = 0; i < nThreads; i++) { 53 | curIDs[i] = i - nThreads; 54 | } 55 | } 56 | if (curIDs[threadID] < task.getNumTasks() - 1) { 57 | curIDs[threadID] += nThreads; 58 | } else { 59 | return task.getNumTasks(); 60 | } 61 | return curIDs[threadID]; 62 | } else { 63 | return task.fetchCurrTaskID(); 64 | } 65 | } 66 | 67 | public void run() throws Exception { 68 | for (int i = 0; i < nThreads; i++) { 69 | threads[i].start(); 70 | } 71 | for (int i = 0; i < nThreads; i++) { 72 | threads[i].join(); 73 | } 74 | } 75 | 76 | // Simple test program 77 | public static void main(String[] args) throws Exception { 78 | 79 | class SimpleTask implements Schedulable { 80 | 81 | double[] ans; 82 | int curID; 83 | static final int N = 1000000; 84 | int nTasks; 85 | 86 | public SimpleTask(int nTasks) { 87 | curID = -1; 88 | this.nTasks = nTasks; 89 | ans = new double[2]; 90 | } 91 | 92 | public void showResult() { 93 | System.out.println(ans[0] + " " + ans[1]); 94 | } 95 | 96 | public Object compute(int taskID) { 97 | Random r = new Random(taskID); 98 | int a = r.nextInt(); 99 | for (int i = 0; i < 10000; i++) { 100 | for (int j = 0; j < 10000; j++) { 101 | for (int k = 0; k < 2; k++) { 102 | a = taskID + i + j + k + r.nextInt(); 103 | } 104 | } 105 | } 106 | 107 | double[] result = new double[2]; 108 | 
package HOSemiCRF;

import java.io.*;
import java.util.*;

/**
 * A labeled observation sequence together with its segmentation.
 * The default segmentation groups maximal runs of identical labels
 * into segments.
 *
 * @author Nguyen Viet Cuong
 * @author Sumit Bhagwani
 */
public class DataSequence {

    Object[] inputs; // Observation array
    int[] labels; // Label array
    int[] startPos; // startPos[i] = start of the segment containing position i
    int[] endPos; // endPos[i] = end of the segment containing position i
    int[][][] observationMap; // [startPos][segLength] -> observation IDs (filled externally)
    LabelMap labelmap; // Map from label strings to their IDs

    /**
     * Construct a data sequence from a label map, labels and observations
     * with the default segmentation.
     * @param ls Label array
     * @param inps Observation array
     * @param labelm Label map
     */
    public DataSequence(int[] ls, Object[] inps, LabelMap labelm) {
        labels = ls;
        inputs = inps;
        labelmap = labelm;

        startPos = new int[labels.length];
        endPos = new int[labels.length];
        if (labels.length == 0) {
            // Guard: the original indexed startPos[0] / endPos[length - 1]
            // unconditionally and threw ArrayIndexOutOfBoundsException on an
            // empty label array.
            return;
        }

        // A position starts a new segment iff its label differs from the
        // previous position's label.
        startPos[0] = 0;
        for (int i = 1; i < labels.length; i++) {
            startPos[i] = (labels[i] == labels[i - 1]) ? startPos[i - 1] : i;
        }

        // Symmetric right-to-left pass for segment ends.
        endPos[labels.length - 1] = labels.length - 1;
        for (int i = labels.length - 2; i >= 0; i--) {
            endPos[i] = (labels[i] == labels[i + 1]) ? endPos[i + 1] : i;
        }
    }

    /**
     * Return length of the current data sequence.
     * @return Length of the current sequence
     */
    public int length() {
        return labels.length;
    }

    /**
     * Return label at a position.
     * @param pos Input position
     * @return Label at the input position
     */
    public int y(int pos) {
        return labels[pos];
    }

    /**
     * Return observation at a position; out-of-range positions yield the
     * empty string so feature generators can probe past the boundaries.
     * @param pos Input position
     * @return Observation at the input position, or "" if out of range
     */
    public Object x(int pos) {
        if (pos < 0 || pos >= inputs.length) {
            return "";
        }
        return inputs[pos];
    }

    /**
     * Set the label at an input position.
     * @param pos Input position
     * @param newY New label to be set at the input position
     */
    public void set_y(int pos, int newY) {
        labels[pos] = newY;
    }

    /**
     * Return the start position of the segment that includes a given position.
     * @param pos Input position
     * @return Start position of the segment that includes the input position
     */
    public int getSegmentStart(int pos) {
        return startPos[pos];
    }

    /**
     * Return the end position of the segment that includes a given position.
     * @param pos Input position
     * @return End position of the segment that includes the input position
     */
    public int getSegmentEnd(int pos) {
        return endPos[pos];
    }

    /**
     * Set the label of every position in a segment.
     * @param startPos Start position of the segment
     * @param endPos End position of the segment
     * @param newY New label to be set for the segment
     */
    public void setSegment(int startPos, int endPos, int newY) {
        for (int i = startPos; i <= endPos; i++) {
            set_y(i, newY);
        }
    }

    /**
     * Check if a subsequence is a segment, i.e. carries a single label.
     * @param startPos Start position of the subsequence
     * @param endPos End position of the subsequence
     * @return true if the subsequence is a segment, false otherwise
     */
    public boolean isSegment(int startPos, int endPos) {
        int y = y(startPos);
        for (int i = startPos + 1; i <= endPos; i++) {
            if (y(i) != y) {
                return false;
            }
        }
        return true;
    }

    /**
     * Return the maximum length of any segment in the sequence.
     * @return Maximum segment length (0 for an empty sequence)
     */
    public int getMaxSegLength() {
        int maxSeg = 0, segStart, segEnd;
        for (segStart = 0; segStart < length(); segStart = segEnd + 1) {
            segEnd = getSegmentEnd(segStart);
            if (segEnd - segStart + 1 > maxSeg) {
                maxSeg = segEnd - segStart + 1;
            }
        }
        return maxSeg;
    }

    /**
     * Return the observation IDs for a subsequence. Requires observationMap
     * to have been populated externally beforehand.
     * @param segStart Start position of the subsequence
     * @param segEnd End position of the subsequence
     * @return Array of observation IDs
     */
    public int[] getObservation(int segStart, int segEnd) {
        return observationMap[segStart][segEnd - segStart];
    }

    /**
     * Write a data sequence to a buffered writer, one "observation label"
     * pair per line.
     * @param bw Buffered writer
     */
    public void writeToBuffer(BufferedWriter bw) throws Exception {
        for (int i = 0; i < labels.length; i++) {
            bw.write(x(i) + " " + labelmap.revMap(labels[i]) + "\n");
        }
    }
}
package OCR;

import java.io.*;
import java.util.*;
import HOCRF.*;
import OCR.Features.*;

/**
 * Handwritten Character Recognition with a high-order CRF.
 * @author Nguyen Viet Cuong
 */
public class OCR {

    int trainFold = 0; // Fold used for training; all other folds form the test set
    HighOrderCRF highOrderCrfModel; // High-order CRF model
    FeatureGenerator featureGen; // Feature generator
    LabelMap labelmap = new LabelMap(); // Label map
    String configFile; // Configuration filename

    /**
     * Construct an OCR runner.
     * @param filename Name of the configuration file
     * @param fold Training fold index (parsed as an integer)
     */
    public OCR(String filename, String fold) {
        configFile = filename;
        trainFold = Integer.parseInt(fold);
    }

    /**
     * Read the letter data file and keep only the rows belonging to the
     * requested split: the training split keeps rows whose fold equals
     * trainFold, the test split keeps all other folds.
     * @param filename Name of the data file
     * @param trainFold Training fold index
     * @param isTraining true to read the training split, false for the test split
     * @return The dataset of letter sequences
     */
    public DataSet readTagged(String filename, int trainFold, boolean isTraining) throws Exception {
        BufferedReader in = new BufferedReader(new FileReader(filename));

        ArrayList td = new ArrayList();
        ArrayList inps = new ArrayList();
        ArrayList labels = new ArrayList();
        String line;

        while ((line = in.readLine()) != null) {
            if (line.length() == 0) {
                continue;
            }
            StringTokenizer toks = new StringTokenizer(line);

            toks.nextToken(); // letter id — unused here
            String tagChar = toks.nextToken();
            int nextID = Integer.parseInt(toks.nextToken());
            toks.nextToken(); // skipped field — TODO confirm against letter.data format
            toks.nextToken(); // skipped field — TODO confirm against letter.data format
            int fold = Integer.parseInt(toks.nextToken());

            // Remaining tokens are the ROWS x COLS pixel grid.
            int[][] p = new int[CharDetails.ROWS][CharDetails.COLS];
            for (int r = 0; r < CharDetails.ROWS; r++) {
                for (int c = 0; c < CharDetails.COLS; c++) {
                    p[r][c] = Integer.parseInt(toks.nextToken());
                }
            }

            // Single split test replaces the original duplicated if/else-if.
            boolean inSplit = isTraining ? (fold == trainFold) : (fold != trainFold);
            if (inSplit) {
                inps.add(new CharDetails(p));
                labels.add(tagChar);
            }

            // nextID == -1 appears to mark the last letter of a word:
            // flush the accumulated sequence.
            if (nextID == -1 && labels.size() > 0) {
                td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap));
                inps = new ArrayList();
                labels = new ArrayList();
            }
        }

        in.close();
        return new DataSet(td);
    }

    /**
     * Add the feature types, process the parameters, and initialize the
     * feature generator.
     */
    public void createFeatureGenerator() throws Exception {
        // Add feature types
        ArrayList fts = new ArrayList();

        fts.add(new Pixel());
        fts.add(new FirstOrderTransition());
        fts.add(new SecondOrderTransition());
        // Higher-order transitions can be enabled for higher-order CRFs:
        //fts.add(new ThirdOrderTransition());
        //fts.add(new FourthOrderTransition());
        //fts.add(new FifthOrderTransition());

        // Process parameters
        Params params = new Params(configFile, labelmap.size());

        // Initialize feature generator
        featureGen = new FeatureGenerator(fts, params);
    }

    /**
     * Train the high-order CRF on the training fold and persist the label
     * map, features, and model under learntModels/fold&lt;trainFold&gt;/.
     */
    public void train() throws Exception {
        // Set training file name and create output directory
        String trainFilename = "letter.data";
        File dir = new File("learntModels/fold" + trainFold);
        dir.mkdirs();

        // Read training data and save the label map
        DataSet trainData = readTagged(trainFilename, trainFold, true);
        labelmap.write("learntModels/fold" + trainFold + "/labelmap");

        // Create and save feature generator
        createFeatureGenerator();
        featureGen.initialize(trainData.getSeqList());
        featureGen.write("learntModels/fold" + trainFold + "/features");

        // Train and save model
        highOrderCrfModel = new HighOrderCRF(featureGen);
        highOrderCrfModel.train(trainData.getSeqList());
        highOrderCrfModel.write("learntModels/fold" + trainFold + "/crf");
    }

    /**
     * Load the persisted model, run Viterbi decoding on the test split,
     * write the predictions, and report the token score.
     */
    public void test() throws Exception {
        // Read label map, features, and CRF model
        labelmap.read("learntModels/fold" + trainFold + "/labelmap");
        createFeatureGenerator();
        featureGen.read("learntModels/fold" + trainFold + "/features");
        highOrderCrfModel = new HighOrderCRF(featureGen);
        highOrderCrfModel.read("learntModels/fold" + trainFold + "/crf");

        // Run Viterbi algorithm
        System.out.print("Running Viterbi...");
        String testFilename = "letter.data";
        DataSet testData = readTagged(testFilename, trainFold, false);
        long startTime = System.currentTimeMillis();
        highOrderCrfModel.runViterbi(testData.getSeqList());
        System.out.println("done in " + (System.currentTimeMillis() - startTime) + " ms");

        // Print out the predicted data
        File dir = new File("out/");
        dir.mkdirs();
        testData.writeToFile("out/letter" + trainFold + ".out");

        // Score the result against a freshly read gold-standard copy
        // (runViterbi overwrote the labels in testData).
        System.out.println("Scoring results...");
        startTime = System.currentTimeMillis();
        DataSet trueTestData = readTagged(testFilename, trainFold, false);
        Scorer scr = new Scorer(trueTestData.getSeqList(), testData.getSeqList(), labelmap, false);
        scr.tokenScore();
        System.out.println("done in " + (System.currentTimeMillis() - startTime) + " ms");
    }

    /**
     * Entry point: {@code java OCR (all|train|test) configFile fold}.
     */
    public static void main(String argv[]) throws Exception {
        // Guard: the original indexed argv[1]/argv[2] unconditionally and
        // crashed with ArrayIndexOutOfBoundsException when run without args.
        if (argv.length < 3) {
            System.out.println("Usage: java OCR (all|train|test) configFile fold");
            return;
        }
        OCR ocr = new OCR(argv[1], argv[2]);
        String mode = argv[0].toLowerCase();
        if (mode.equals("all")) {
            ocr.train();
            ocr.test();
        } else if (mode.equals("train")) {
            ocr.train();
        } else if (mode.equals("test")) {
            ocr.test();
        }
    }
}
package Applications;

import java.io.*;
import java.util.*;
import HOSemiCRF.*;
import Applications.RefFeatures.*;

/**
 * Reference Tagger class: trains and applies a high-order semi-CRF
 * to reference-string field tagging.
 * @author Nguyen Viet Cuong
 */
public class ReferenceTagger {

    HighOrderSemiCRF highOrderSemiCrfModel; // High-order semi-CRF model
    FeatureGenerator featureGen; // Feature generator
    LabelMap labelmap = new LabelMap(); // Label map
    String configFile; // Configuration filename

    /**
     * Construct a tagger from a configuration file.
     * @param filename Name of configuration file
     */
    public ReferenceTagger(String filename) {
        configFile = filename;
    }

    /**
     * Read a whitespace-separated "word tag" file; blank lines separate
     * sequences.
     * @param filename Name of the training file
     * @return The training data
     */
    public DataSet readTagged(String filename) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader(filename));

        ArrayList td = new ArrayList();
        ArrayList inps = new ArrayList();
        ArrayList labels = new ArrayList();
        String line;

        while ((line = in.readLine()) != null) {
            if (line.length() > 0) {
                StringTokenizer toks = new StringTokenizer(line);
                String word = toks.nextToken();
                String tagRel = toks.nextToken();
                inps.add(word);
                labels.add(tagRel);
            } else if (labels.size() > 0) {
                td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap));
                inps = new ArrayList();
                labels = new ArrayList();
            }
        }
        // Flush the final sequence when the file has no trailing blank line.
        if (labels.size() > 0) {
            td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap));
        }

        in.close();
        return new DataSet(td);
    }

    /**
     * Add the feature types, process the parameters, and initialize the
     * feature generator.
     */
    public void createFeatureGenerator() throws Exception {
        // Add feature types
        ArrayList fts = new ArrayList();

        // 1st-order CRF features
        fts.add(new WordBag());
        fts.add(new PreviousWordBag());
        fts.add(new NextWordBag());
        fts.add(new WordKPositionBeforeBag());
        fts.add(new WordKPositionAfterBag());
        fts.add(new LetterNGramsBag());

        fts.add(new EdgeBag());
        fts.add(new Edge());
        fts.add(new EdgeWordBag());
        fts.add(new EdgeWord());
        fts.add(new EdgePreviousWordBag());
        fts.add(new EdgePreviousWord());

        // Add this for 1st-order Semi-CRF
        fts.add(new FirstOrderTransition());

        // Add this for 2nd-order CRF and Semi-CRF
        // fts.add(new SecondOrderTransition());

        // Add this for 3rd-order CRF and Semi-CRF
        // fts.add(new ThirdOrderTransition());

        // Process parameters
        Params params = new Params(configFile, labelmap.size());

        // Initialize feature generator
        featureGen = new FeatureGenerator(fts, params);
    }

    /**
     * Train the high-order semi-CRF and persist the label map, features,
     * and model under learntModels/.
     */
    public void train() throws Exception {
        // Set training file name and create output directory
        String trainFilename = "ref.train";
        File dir = new File("learntModels/");
        dir.mkdirs();

        // Read training data and save the label map
        DataSet trainData = readTagged(trainFilename);
        labelmap.write("learntModels/labelmap");

        // Create and save feature generator
        createFeatureGenerator();
        featureGen.initialize(trainData.getSeqList());
        featureGen.write("learntModels/features");

        // Train and save model
        highOrderSemiCrfModel = new HighOrderSemiCRF(featureGen);
        highOrderSemiCrfModel.train(trainData.getSeqList());
        highOrderSemiCrfModel.write("learntModels/crf");
    }

    /**
     * Load the persisted model, decode the test file, write predictions,
     * and report the phrase score.
     */
    public void test() throws Exception {
        // Read label map, features, and CRF model
        labelmap.read("learntModels/labelmap");
        createFeatureGenerator();
        featureGen.read("learntModels/features");
        highOrderSemiCrfModel = new HighOrderSemiCRF(featureGen);
        highOrderSemiCrfModel.read("learntModels/crf");

        // Run Viterbi algorithm
        System.out.print("Running Viterbi...");
        String testFilename = "ref.test";
        DataSet testData = readTagged(testFilename);
        long startTime = System.currentTimeMillis();
        highOrderSemiCrfModel.runViterbi(testData.getSeqList());
        System.out.println("done in " + (System.currentTimeMillis() - startTime) + " ms");

        // Print out the predicted data
        File dir = new File("out/");
        dir.mkdirs();
        testData.writeToFile("out/ref.test");

        // Score against a freshly read gold copy (runViterbi overwrote labels).
        System.out.println("Scoring results...");
        startTime = System.currentTimeMillis();
        DataSet trueTestData = readTagged(testFilename);
        Scorer scr = new Scorer(trueTestData.getSeqList(), testData.getSeqList(), labelmap, false);
        scr.phraseScore();
        System.out.println("done in " + (System.currentTimeMillis() - startTime) + " ms");
    }

    /**
     * Entry point: {@code java ReferenceTagger (all|train|test) configFile}.
     */
    public static void main(String argv[]) throws Exception {
        // Guard: the original indexed argv[1] unconditionally and crashed
        // with ArrayIndexOutOfBoundsException when run without args.
        if (argv.length < 2) {
            System.out.println("Usage: java ReferenceTagger (all|train|test) configFile");
            return;
        }
        ReferenceTagger refTagger = new ReferenceTagger(argv[1]);
        String mode = argv[0].toLowerCase();
        if (mode.equals("all")) {
            refTagger.train();
            refTagger.test();
        } else if (mode.equals("train")) {
            refTagger.train();
        } else if (mode.equals("test")) {
            refTagger.test();
        }
    }
}
package Applications;

import java.io.*;
import java.util.*;
import HOSemiCRF.*;
import Applications.PuncFeatures.*;

/**
 * Main class for the Punctuation Prediction task: trains and applies a
 * high-order semi-CRF to predict punctuation labels for word sequences.
 * @author Nguyen Viet Cuong
 */
public class PunctuationPredictor {

    HighOrderSemiCRF highOrderSemiCrfModel; // High-order semi-CRF model
    FeatureGenerator featureGen; // Feature generator
    LabelMap labelmap = new LabelMap(); // Label map
    String configFile; // Configuration filename

    /**
     * Construct a punctuation tagger from a configuration file.
     * @param filename Name of configuration file
     */
    public PunctuationPredictor(String filename) {
        configFile = filename;
    }

    /**
     * Read a whitespace-separated "word tag" file; blank lines separate
     * sequences.
     * @param filename Name of the training file
     * @return The training data
     */
    public DataSet readTagged(String filename) throws Exception {
        BufferedReader in = new BufferedReader(new FileReader(filename));

        ArrayList td = new ArrayList();
        ArrayList inps = new ArrayList();
        ArrayList labels = new ArrayList();
        String line;

        while ((line = in.readLine()) != null) {
            if (line.length() > 0) {
                StringTokenizer toks = new StringTokenizer(line);
                String word = toks.nextToken();
                String tagRel = toks.nextToken();
                inps.add(word);
                labels.add(tagRel);
            } else if (labels.size() > 0) {
                td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap));
                inps = new ArrayList();
                labels = new ArrayList();
            }
        }
        // Flush the final sequence when the file has no trailing blank line.
        if (labels.size() > 0) {
            td.add(new DataSequence(labelmap.mapArrayList(labels), inps.toArray(), labelmap));
        }

        in.close();
        return new DataSet(td);
    }

    /**
     * Add the feature types, process the parameters, and initialize the
     * feature generator.
     */
    public void createFeatureGenerator() throws Exception {
        // Add feature types
        ArrayList fts = new ArrayList();

        // 1st-order CRF features
        fts.add(new WordPositionBag());
        fts.add(new TwoWordPositionBag());

        fts.add(new EdgeBag());
        fts.add(new Edge());
        fts.add(new EdgeWordBag());
        fts.add(new EdgeWord());
        fts.add(new EdgePreviousWordBag());
        fts.add(new EdgePreviousWord());
        fts.add(new EdgeTwoWordBag());
        fts.add(new EdgeTwoWord());

        // Add these for 1st-order Semi-CRF
        // fts.add(new FirstOrderTransition());
        // fts.add(new FirstOrderTransitionWord());

        // Add these for 2nd-order CRF and Semi-CRF
        // fts.add(new SecondOrderTransition());
        // fts.add(new SecondOrderTransitionWord());

        // Add these for 3rd-order CRF and Semi-CRF
        // fts.add(new ThirdOrderTransition());
        // fts.add(new ThirdOrderTransitionWord());

        // Process parameters
        Params params = new Params(configFile, labelmap.size());

        // Initialize feature generator
        featureGen = new FeatureGenerator(fts, params);
    }

    /**
     * Train the high-order semi-CRF and persist the label map, features,
     * and model under learntModels/.
     */
    public void train() throws Exception {
        // Set training file name and create output directory
        String trainFilename = "punc.train";
        File dir = new File("learntModels/");
        dir.mkdirs();

        // Convert the raw corpus into "word tag" format, then read it
        // and save the label map.
        PuncConverter.convert("punc.tr", trainFilename);
        DataSet trainData = readTagged(trainFilename);
        labelmap.write("learntModels/labelmap");

        // Create and save feature generator
        createFeatureGenerator();
        featureGen.initialize(trainData.getSeqList());
        featureGen.write("learntModels/features");

        // Train and save model
        highOrderSemiCrfModel = new HighOrderSemiCRF(featureGen);
        highOrderSemiCrfModel.train(trainData.getSeqList());
        highOrderSemiCrfModel.write("learntModels/crf");
    }

    /**
     * Load the persisted model, decode the test file, write predictions,
     * and report the token score.
     */
    public void test() throws Exception {
        // Read label map, features, and CRF model
        labelmap.read("learntModels/labelmap");
        createFeatureGenerator();
        featureGen.read("learntModels/features");
        highOrderSemiCrfModel = new HighOrderSemiCRF(featureGen);
        highOrderSemiCrfModel.read("learntModels/crf");

        // Run Viterbi algorithm
        System.out.print("Running Viterbi...");
        String testFilename = "punc.test";
        PuncConverter.convert("punc.ts", testFilename);
        DataSet testData = readTagged(testFilename);
        long startTime = System.currentTimeMillis();
        highOrderSemiCrfModel.runViterbi(testData.getSeqList());
        System.out.println("done in " + (System.currentTimeMillis() - startTime) + " ms");

        // Print out the predicted data
        File dir = new File("out/");
        dir.mkdirs();
        testData.writeToFile("out/punc.test");

        // Score against a freshly read gold copy (runViterbi overwrote labels).
        System.out.println("Scoring results...");
        startTime = System.currentTimeMillis();
        DataSet trueTestData = readTagged(testFilename);
        Scorer scr = new Scorer(trueTestData.getSeqList(), testData.getSeqList(), labelmap, true);
        scr.tokenScore();
        System.out.println("done in " + (System.currentTimeMillis() - startTime) + " ms");
    }

    /**
     * Entry point: {@code java PunctuationPredictor (all|train|test) configFile}.
     */
    public static void main(String argv[]) throws Exception {
        // Guard: the original indexed argv[1] unconditionally and crashed
        // with ArrayIndexOutOfBoundsException when run without args.
        if (argv.length < 2) {
            System.out.println("Usage: java PunctuationPredictor (all|train|test) configFile");
            return;
        }
        PunctuationPredictor puncPredictor = new PunctuationPredictor(argv[1]);
        String mode = argv[0].toLowerCase();
        if (mode.equals("all")) {
            puncPredictor.train();
            puncPredictor.test();
        } else if (mode.equals("train")) {
            puncPredictor.train();
        } else if (mode.equals("test")) {
            puncPredictor.test();
        }
    }
}
package HOCRF;

import java.util.*;
import Parallel.*;

/**
 * Algorithm for computing the partition functions and expected feature scores
 * for the high-order CRF loglikelihood, one training sequence per parallel task.
 * @author Nguyen Viet Cuong
 */
public class LogliComputer implements Schedulable {

    int curID; // Current task ID (for parallelization)
    FeatureGenerator featureGen; // Feature generator
    ArrayList trainData; // List of training sequences
    double[] lambda; // Lambda vector
    Loglikelihood logli; // Loglikelihood value and derivatives
    final int BASE = 1; // Base of the logAlpha array (alpha index = position + BASE)

    /**
     * Construct a loglikelihood computer.
     * @param lambdaValues Lambda vector
     * @param fgen Feature generator
     * @param td List of training sequences
     * @param loglh Initial loglikelihood and its derivatives (partially computed from class Function)
     */
    public LogliComputer(double[] lambdaValues, FeatureGenerator fgen, ArrayList td, Loglikelihood loglh) {
        curID = -1; // first fetchCurrTaskID() call returns 0
        featureGen = fgen;
        trainData = td;
        lambda = lambdaValues;
        logli = loglh;
    }

    /**
     * Compute the partition function and expected feature score (in log scale) for a given sequence.
     * The returned partial result holds the empirical feature scores minus the
     * expectations (gradient contribution) and the sequence loglikelihood term.
     * @param taskID Index of the training sequence
     * @return Partition function value and expected feature scores
     */
    public Object compute(int taskID) {
        Loglikelihood res = new Loglikelihood(lambda.length);
        DataSequence seq = (DataSequence) trainData.get(taskID);

        // Empirical scores first, then the forward/backward quantities.
        addFeatureScores(seq, res);
        double[][] logAlpha = computeLogAlpha(seq);
        double logZx = computeLogZx(seq, logAlpha);
        double[][] logBeta = computeLogBeta(seq);
        double[][] marginal = computeMarginal(seq, logAlpha, logBeta, logZx);
        double[] expectation = computeExpectation(seq, marginal);

        // Gradient contribution: empirical count minus model expectation.
        for (int k = 0; k < lambda.length; k++) {
            res.derivatives[k] -= expectation[k];
        }
        // Loglikelihood contribution: feature score minus log partition function.
        res.logli -= logZx;

        return res;
    }

    /**
     * Return total number of tasks (for parallelization).
     * @return Training dataset size
     */
    public int getNumTasks() {
        return trainData.size();
    }

    /**
     * Return the next task ID (for parallelization).
     * @return The next sequence ID
     */
    public synchronized int fetchCurrTaskID() {
        if (curID < getNumTasks()) {
            curID++;
        }
        return curID;
    }

    /**
     * Add the partition function and expected feature scores into the final
     * loglikelihood and its derivatives. Synchronized: called concurrently
     * by worker threads with their partial results.
     * @param partialResult Partition function and expected feature scores
     */
    public synchronized void update(Object partialResult) {
        Loglikelihood res = (Loglikelihood) partialResult;
        logli.logli += res.logli;
        for (int i = 0; i < lambda.length; i++) {
            logli.derivatives[i] += res.derivatives[i];
        }
    }

    /**
     * Add the empirical feature scores and the sequence probability.
     * For every position, all suffix patterns of the observed label pattern
     * contribute their active features.
     * @param seq Training sequence
     * @param res Partial loglikelihood to be updated after this method call
     */
    public void addFeatureScores(DataSequence seq, Loglikelihood res) {
        for (int pos = 0; pos < seq.length(); pos++) {
            String labelPat = featureGen.generateLabelPattern(seq, pos);
            int sID = featureGen.getBackwardStateIndex(labelPat);
            for (int patID : featureGen.allSuffixes[sID]) {
                ArrayList feats = featureGen.getFeatures(seq, pos, patID);
                for (int index : feats) {
                    Feature feat = featureGen.featureList.get(index);
                    res.derivatives[index] += feat.value; // empirical count
                    res.logli += lambda[index] * feat.value; // weighted feature score
                }
            }
        }
    }

    /**
     * Run the forward algorithm over the expanded forward-state lattice.
     * State 0 at index 0 is the start state (log prob 0); all other entries
     * are initialized to -infinity (log of 0).
     * @param seq Training sequence
     * @return Logarithms of the alpha variables
     */
    public double[][] computeLogAlpha(DataSequence seq) {
        double[][] logAlpha = new double[seq.length() + 1][featureGen.forwardStateMap.size()];
        Arrays.fill(logAlpha[0], Double.NEGATIVE_INFINITY);
        logAlpha[0][0] = 0.0;
        for (int j = 0; j < seq.length(); j++) {
            Arrays.fill(logAlpha[j + BASE], Double.NEGATIVE_INFINITY);
            // State 0 is skipped: presumably the dedicated start state is
            // never re-entered — TODO confirm against FeatureGenerator.
            for (int i = 1; i < featureGen.forwardStateMap.size(); i++) {
                // Parallel lists: prevState1 holds predecessor forward states,
                // prevState2 the corresponding transition (state,label) IDs.
                ArrayList prevState1 = featureGen.forwardTransition1[i];
                ArrayList prevState2 = featureGen.forwardTransition2[i];
                for (int k = 0; k < prevState1.size(); k++) {
                    int pkID = prevState1.get(k);
                    int pkyID = prevState2.get(k);
                    // Sum feature scores over all suffix patterns of the transition.
                    double featuresScore = 0.0;
                    for (Integer patID : featureGen.allSuffixes[pkyID]) {
                        ArrayList feats = featureGen.getFeatures(seq, j, patID);
                        featuresScore += featureGen.computeFeatureScores(feats, lambda);
                    }
                    // Log-domain accumulation: alpha[j][i] += alpha[j-1][pk] * exp(score).
                    logAlpha[j + BASE][i] = Utility.logSumExp(logAlpha[j + BASE][i], logAlpha[j + BASE - 1][pkID] + featuresScore);
                }
            }
        }
        return logAlpha;
    }

    /**
     * Compute the logarithm of partition function from the alpha variables
     * by log-summing the final alpha row over all forward states.
     * @param seq Training sequence
     * @param logAlpha Logarithms of the alpha variables
     * @return Logarithm of the partition function
     */
    public double computeLogZx(DataSequence seq, double[][] logAlpha) {
        double logZx = Double.NEGATIVE_INFINITY;
        int l = seq.length();
        for (int i = 0; i < featureGen.forwardStateMap.size(); i++) {
            logZx = Utility.logSumExp(logZx, logAlpha[l][i]);
        }
        return logZx;
    }

    /**
     * Run the backward algorithm over the backward-state lattice.
     * The row at seq.length() is all zeros (log of 1). Note the loop stops
     * at j == 1; logBeta[0] stays zero-filled and computeMarginal only reads
     * logBeta at pos + 1 >= 1.
     * @param seq Training sequence
     * @return Logarithms of the beta variables
     */
    public double[][] computeLogBeta(DataSequence seq) {
        double[][] logBeta = new double[seq.length() + 1][featureGen.backwardStateMap.size()];
        Arrays.fill(logBeta[seq.length()], 0.0);
        for (int j = seq.length() - 1; j > 0; j--) {
            Arrays.fill(logBeta[j], Double.NEGATIVE_INFINITY);
            for (int i = 0; i < featureGen.backwardStateMap.size(); i++) {
                for (int y = 0; y < featureGen.params.numLabels; y++) {
                    // -1 marks an impossible (state, label) transition.
                    int skID = featureGen.backwardTransition[i][y];
                    if (skID != -1) {
                        double featuresScore = 0.0;
                        for (Integer patID : featureGen.allSuffixes[skID]) {
                            ArrayList feats = featureGen.getFeatures(seq, j, patID);
                            featuresScore += featureGen.computeFeatureScores(feats, lambda);
                        }
                        logBeta[j][i] = Utility.logSumExp(logBeta[j][i], logBeta[j + 1][skID] + featuresScore);
                    }
                }
            }
        }
        return logBeta;
    }

    /**
     * Compute the marginals P(pattern z active at position pos | seq) for
     * every pattern/position pair, combining alpha, beta, and the transition
     * feature scores, then normalizing by the partition function.
     * @param seq Training sequence
     * @param logAlpha Logarithms of the alpha variables
     * @param logBeta Logarithms of the beta variables
     * @param logZx Logarithm of the partition function
     * @return The array of marginals
     */
    public double[][] computeMarginal(DataSequence seq, double[][] logAlpha, double[][] logBeta, double logZx) {
        double[][] marginal = new double[featureGen.patternMap.size()][seq.length()];
        for (int zID = 0; zID < featureGen.patternMap.size(); zID++) {
            for (int pos = 0; pos < seq.length(); pos++) {
                marginal[zID][pos] = Double.NEGATIVE_INFINITY;

                // Parallel lists of (forward state, transition) pairs that
                // realize pattern zID.
                for (int i = 0; i < featureGen.patternTransition1[zID].size(); i++) {
                    int piID = featureGen.patternTransition1[zID].get(i);
                    int piyID = featureGen.patternTransition2[zID].get(i);

                    double featuresScore = 0.0;
                    for (Integer patID : featureGen.allSuffixes[piyID]) {
                        ArrayList feats = featureGen.getFeatures(seq, pos, patID);
                        featuresScore += featureGen.computeFeatureScores(feats, lambda);
                    }
                    marginal[zID][pos] = Utility.logSumExp(marginal[zID][pos], logAlpha[BASE + pos - 1][piID] + logBeta[pos + 1][piyID] + featuresScore);
                }

                // Leave log scale: normalize by Z(x).
                marginal[zID][pos] = Math.exp(marginal[zID][pos] - logZx);
            }
        }
        return marginal;
    }

    /**
     * Compute the feature expectations under the model by weighting each
     * feature's value with the marginal of its pattern at each position.
     * @param seq Training sequence
     * @param marginal The marginals
     * @return The feature expectations
     */
    public double[] computeExpectation(DataSequence seq, double[][] marginal) {
        double[] expectation = new double[lambda.length];
        Arrays.fill(expectation, 0.0);
        for (int zID = 0; zID < featureGen.patternMap.size(); zID++) {
            for (int pos = 0; pos < seq.length(); pos ++) {
                ArrayList feats = featureGen.getFeatures(seq, pos, zID);
                for (int index : feats) {
                    Feature feat = featureGen.featureList.get(index);
                    expectation[index] += feat.value * marginal[zID][pos];
                }
            }
        }
        return expectation;
    }

    /**
     * Print a 2D array to stdout for debugging.
     * @param arr The 2D array
     */
    public void printArray(double[][] arr) {
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j < arr[i].length; j++) {
                System.out.print(arr[i][j] + " ");
            }
            System.out.println();
        }
    }
}
21 | * @param lambdaValues Lambda vector 22 | * @param fgen Feature generator 23 | * @param td List of training sequences 24 | * @param loglh Initial loglikelihood and its derivatives (partially computed from class Function) 25 | */ 26 | public LogliComputer(double[] lambdaValues, FeatureGenerator fgen, ArrayList td, Loglikelihood loglh) { 27 | curID = -1; 28 | featureGen = fgen; 29 | trainData = td; 30 | lambda = lambdaValues; 31 | logli = loglh; 32 | } 33 | 34 | /** 35 | * Compute the partition function and expected feature score (in log scale) for a given sequence. 36 | * @param taskID Index of the training sequence 37 | * @return Partition function value and expected feature scores 38 | */ 39 | public Object compute(int taskID) { 40 | Loglikelihood res = new Loglikelihood(lambda.length); 41 | DataSequence seq = (DataSequence) trainData.get(taskID); 42 | 43 | addFeatureScores(seq, res); 44 | double[][] logAlpha = computeLogAlpha(seq); 45 | double logZx = computeLogZx(seq, logAlpha); 46 | double[][] logBeta = computeLogBeta(seq); 47 | double[][][] marginal = computeMarginal(seq, logAlpha, logBeta, logZx); 48 | double[] expectation = computeExpectation(seq, marginal); 49 | 50 | for (int k = 0; k < lambda.length; k++) { 51 | res.derivatives[k] -= expectation[k]; 52 | } 53 | res.logli -= logZx; 54 | 55 | return res; 56 | } 57 | 58 | /** 59 | * Return total number of tasks (for parallelization). 60 | * @return Training dataset size 61 | */ 62 | public int getNumTasks() { 63 | return trainData.size(); 64 | } 65 | 66 | /** 67 | * Return the next task ID (for parallelization). 68 | * @return The next sequence ID 69 | */ 70 | public synchronized int fetchCurrTaskID() { 71 | if (curID < getNumTasks()) { 72 | curID++; 73 | } 74 | return curID; 75 | } 76 | 77 | /** 78 | * Add the partition function and expected feature scores into the final loglikelihood and its derivatives. 
79 | * @param partialResult Partition function and expected feature scores 80 | */ 81 | public synchronized void update(Object partialResult) { 82 | Loglikelihood res = (Loglikelihood) partialResult; 83 | logli.logli += res.logli; 84 | for (int i = 0; i < lambda.length; i++) { 85 | logli.derivatives[i] += res.derivatives[i]; 86 | } 87 | } 88 | 89 | /** 90 | * Add the empirical feature scores and the sequence probability. 91 | * @param seq Training sequence 92 | * @param res Partial loglikelihood to be updated after this method call 93 | */ 94 | public void addFeatureScores(DataSequence seq, Loglikelihood res) { 95 | int segStart, segEnd; 96 | for (segStart = 0; segStart < seq.length(); segStart = segEnd + 1) { 97 | segEnd = seq.getSegmentEnd(segStart); 98 | 99 | String labelPat = featureGen.generateLabelPattern(seq, segStart, segEnd); 100 | int sID = featureGen.getBackwardStateIndex(labelPat); 101 | for (int patID : featureGen.allSuffixes[sID]) { 102 | ArrayList feats = featureGen.getFeatures(seq, segStart, segEnd, patID); 103 | for (int index : feats) { 104 | Feature feat = featureGen.featureList.get(index); 105 | res.derivatives[index] += feat.value; 106 | res.logli += lambda[index] * feat.value; 107 | } 108 | } 109 | } 110 | } 111 | 112 | /** 113 | * Run the forward algorithm. 114 | * @param seq Training sequence 115 | * @return Logarithms of the alpha variables 116 | */ 117 | public double[][] computeLogAlpha(DataSequence seq) { 118 | double[][] logAlpha = new double[seq.length() + 1][featureGen.forwardStateMap.size()]; 119 | Arrays.fill(logAlpha[0], Double.NEGATIVE_INFINITY); 120 | logAlpha[0][0] = 0.0; 121 | for (int j = 0; j < seq.length(); j++) { 122 | Arrays.fill(logAlpha[j + BASE], Double.NEGATIVE_INFINITY); 123 | for (int i = 0; i < featureGen.forwardStateMap.size(); i++) { 124 | int y = featureGen.lastForwardStateLabel[i]; 125 | int maxmem = (y == -1) ? 
0 : featureGen.maxMemory[y]; 126 | 127 | ArrayList prevState1 = featureGen.forwardTransition1[i]; 128 | ArrayList prevState2 = featureGen.forwardTransition2[i]; 129 | for (int d = 0; d < maxmem && j - d >= 0; d++) { 130 | for (int k = 0; k < prevState1.size(); k++) { 131 | int pkID = prevState1.get(k); 132 | int pkyID = prevState2.get(k); 133 | double featuresScore = 0.0; 134 | for (Integer patID : featureGen.allSuffixes[pkyID]) { 135 | ArrayList feats = featureGen.getFeatures(seq, j - d, j, patID); 136 | featuresScore += featureGen.computeFeatureScores(feats, lambda); 137 | } 138 | logAlpha[j + BASE][i] = Utility.logSumExp(logAlpha[j + BASE][i], logAlpha[j + BASE - d - 1][pkID] + featuresScore); 139 | } 140 | } 141 | } 142 | } 143 | return logAlpha; 144 | } 145 | 146 | /** 147 | * Compute the logarithm of partition function from the alpha variables. 148 | * @param seq Training sequence 149 | * @param logAlpha Logarithms of the alpha variables 150 | * @return Logarithm of the partition function 151 | */ 152 | public double computeLogZx(DataSequence seq, double[][] logAlpha) { 153 | double logZx = Double.NEGATIVE_INFINITY; 154 | int l = seq.length(); 155 | for (int i = 0; i < featureGen.forwardStateMap.size(); i++) { 156 | logZx = Utility.logSumExp(logZx, logAlpha[l][i]); 157 | } 158 | return logZx; 159 | } 160 | 161 | /** 162 | * Run the backward algorithm. 
163 | * @param seq Training sequence 164 | * @return Logarithms of the beta variables 165 | */ 166 | public double[][] computeLogBeta(DataSequence seq) { 167 | double[][] logBeta = new double[seq.length() + 1][featureGen.backwardStateMap.size()]; 168 | Arrays.fill(logBeta[seq.length()], 0.0); 169 | for (int j = seq.length() - 1; j > 0; j--) { 170 | Arrays.fill(logBeta[j], Double.NEGATIVE_INFINITY); 171 | for (int i = 0; i < featureGen.backwardStateMap.size(); i++) { 172 | for (int y = 0; y < featureGen.params.numLabels; y++) { 173 | int skID = featureGen.backwardTransition[i][y]; 174 | if (skID != -1) { 175 | for (int d = 0; d < featureGen.maxMemory[y] && j + d < seq.length(); d++) { 176 | double featuresScore = 0.0; 177 | for (Integer patID : featureGen.allSuffixes[skID]) { 178 | ArrayList feats = featureGen.getFeatures(seq, j, j+d, patID); 179 | featuresScore += featureGen.computeFeatureScores(feats, lambda); 180 | } 181 | logBeta[j][i] = Utility.logSumExp(logBeta[j][i], logBeta[j + d + 1][skID] + featuresScore); 182 | } 183 | } 184 | } 185 | } 186 | } 187 | return logBeta; 188 | } 189 | 190 | /** 191 | * Compute the marginals. 192 | * @param seq Training sequence 193 | * @param logAlpha Logarithms of the alpha variables 194 | * @param logBeta Logarithms of the beta variables 195 | * @param logZx Logarithm of the partition function 196 | * @return The array of marginals 197 | */ 198 | public double[][][] computeMarginal(DataSequence seq, double[][] logAlpha, double[][] logBeta, double logZx) { 199 | double[][][] marginal = new double[featureGen.patternMap.size()][seq.length()][]; 200 | for (int zID = 0; zID < featureGen.patternMap.size(); zID++) { 201 | int y = featureGen.lastPatternLabel[zID]; 202 | int maxmem = (y == -1) ? 
0 : featureGen.maxMemory[y]; 203 | 204 | for (int segStart = 0; segStart < seq.length(); segStart++) { 205 | int maxLength = Math.min(maxmem, seq.length() - segStart); 206 | marginal[zID][segStart] = new double[maxLength]; 207 | for (int d = 0; d < maxLength; d++) { 208 | marginal[zID][segStart][d] = Double.NEGATIVE_INFINITY; 209 | 210 | for (int i = 0; i < featureGen.patternTransition1[zID].size(); i++) { 211 | int piID = featureGen.patternTransition1[zID].get(i); 212 | int piyID = featureGen.patternTransition2[zID].get(i); 213 | 214 | double featuresScore = 0.0; 215 | for (Integer patID : featureGen.allSuffixes[piyID]) { 216 | ArrayList feats = featureGen.getFeatures(seq, segStart, segStart+d, patID); 217 | featuresScore += featureGen.computeFeatureScores(feats, lambda); 218 | } 219 | marginal[zID][segStart][d] = Utility.logSumExp(marginal[zID][segStart][d], logAlpha[BASE + segStart - 1][piID] + logBeta[segStart + d + 1][piyID] + featuresScore); 220 | } 221 | 222 | marginal[zID][segStart][d] = Math.exp(marginal[zID][segStart][d] - logZx); 223 | } 224 | } 225 | } 226 | return marginal; 227 | } 228 | 229 | /** 230 | * Compute the feature expectations. 231 | * @param seq Training sequence 232 | * @param marginals The marginals 233 | * @return The feature expectations 234 | */ 235 | public double[] computeExpectation(DataSequence seq, double[][][] marginal) { 236 | double[] expectation = new double[lambda.length]; 237 | Arrays.fill(expectation, 0.0); 238 | for (int zID = 0; zID < featureGen.patternMap.size(); zID++) { 239 | int y = featureGen.lastPatternLabel[zID]; 240 | int maxmem = (y == -1) ? 
0 : featureGen.maxMemory[y]; 241 | 242 | for (int segStart = 0; segStart < seq.length(); segStart++) { 243 | for (int segEnd = segStart; segEnd < seq.length() && segEnd - segStart < maxmem; segEnd++) { 244 | ArrayList feats = featureGen.getFeatures(seq, segStart, segEnd, zID); 245 | for (int index : feats) { 246 | Feature feat = featureGen.featureList.get(index); 247 | expectation[index] += feat.value * marginal[zID][segStart][segEnd-segStart]; 248 | } 249 | } 250 | } 251 | } 252 | return expectation; 253 | } 254 | 255 | /** 256 | * Print a 2D array to stdout for debugging. 257 | * @param arr The 2D array 258 | */ 259 | public void printArray(double[][] arr) { 260 | for (int i = 0; i < arr.length; i++) { 261 | for (int j = 0; j < arr[i].length; j++) { 262 | System.out.print(arr[i][j] + " "); 263 | } 264 | System.out.println(); 265 | } 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /src/HOSemiCRF/Scorer.java: -------------------------------------------------------------------------------- 1 | package HOSemiCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import java.text.*; 6 | 7 | /** 8 | * Scorer class 9 | * @author Nguyen Viet Cuong 10 | * @author Ye Nan 11 | */ 12 | public class Scorer { 13 | 14 | String[][] labels; // True labels 15 | String[][] predicted; // Predicted labels 16 | 17 | /** 18 | * Construct a scorer with true data, predicted data, and label map. 
19 | * @param trueData List of correct sequences 20 | * @param predictedData List of predicted sequences 21 | * @param labelmap Label map 22 | * @param RM_SUFFIX If set to true, suffixes of labels after '-' will be removed 23 | */ 24 | public Scorer(ArrayList trueData, ArrayList predictedData, LabelMap labelmap, boolean RM_SUFFIX) { 25 | labels = new String[trueData.size()][]; 26 | for (int i = 0; i < trueData.size(); i++) { 27 | DataSequence seq = (DataSequence) trueData.get(i); 28 | labels[i] = labelmap.revArray(seq.labels); 29 | } 30 | 31 | predicted = new String[predictedData.size()][]; 32 | for (int i = 0; i < predictedData.size(); i++) { 33 | DataSequence seq = (DataSequence) predictedData.get(i); 34 | predicted[i] = labelmap.revArray(seq.labels); 35 | } 36 | 37 | if (RM_SUFFIX) { 38 | removeSuffix(labels); 39 | removeSuffix(predicted); 40 | } 41 | } 42 | 43 | /** 44 | * Print the scores based on the correct phrases. 45 | * @return F1 score 46 | */ 47 | public double phraseScore() { 48 | Hashtable labelht = new Hashtable(); 49 | Vector labs = new Vector(); 50 | collectLabels(labels, predicted, labelht, labs); 51 | 52 | double nTokens = 0; 53 | double nMatched = 0; 54 | 55 | double[] nPhrase = new double[labelht.size()]; 56 | double[] nPredicted = new double[labelht.size()]; 57 | double[] nCorrect = new double[labelht.size()]; 58 | for (int s = 0; s < labels.length; s++) { 59 | nTokens += labels[s].length; 60 | String prev = ""; 61 | for (int t = 0; t < labels[s].length; t++) { 62 | if (labels[s][t].equals(predicted[s][t])) { 63 | nMatched++; 64 | } 65 | 66 | if (!labels[s][t].equals(prev)) { 67 | prev = labels[s][t]; 68 | if (!labels[s][t].equals("O")) { 69 | nPhrase[labelht.get(labels[s][t])]++; 70 | } 71 | } 72 | } 73 | } 74 | 75 | for (int s = 0; s < predicted.length; s++) { 76 | String prev = "O"; 77 | int start = -1, end = -1; 78 | for (int t = 0; t < predicted[s].length; t++) { 79 | if (!predicted[s][t].equals(prev)) { //token changed 80 | if 
(!prev.equals("O")) { //phrase ended 81 | end = t; 82 | if (!(start > 0 && labels[s][start - 1].equals(labels[s][start])) 83 | && !(end < labels[s].length && labels[s][end - 1].equals(labels[s][end]))) { 84 | int i = start; 85 | for (; i < end; i++) { 86 | if (!predicted[s][i].equals(labels[s][i])) { 87 | break; 88 | } 89 | } 90 | if (i == end) { 91 | nCorrect[labelht.get(predicted[s][start])]++; 92 | } 93 | } 94 | } 95 | 96 | prev = predicted[s][t]; 97 | if (!prev.equals("O")) { //mark beginning of phrase 98 | start = t; 99 | nPredicted[labelht.get(prev)]++; 100 | } 101 | } 102 | 103 | if (t == predicted[s].length - 1 && !predicted[s][t].equals("O")) { 104 | int i = start; 105 | for (; i <= t; i++) { 106 | if (!predicted[s][i].equals(labels[s][i])) { 107 | break; 108 | } 109 | } 110 | if (i == t + 1) { 111 | nCorrect[labelht.get(predicted[s][start])]++; 112 | } 113 | } 114 | } 115 | } 116 | 117 | double nTotCorrect = sum(nCorrect); 118 | double nTotPhrase = sum(nPhrase); 119 | double nTotPredicted = sum(nPredicted); 120 | System.out.println((int) nTokens + " tokens with " + (int) nTotPhrase + " phrases."); 121 | System.out.println((int) nTotPredicted + " predicted phrases with " + (int) nTotCorrect + " being correct."); 122 | 123 | DisplayTable table = new DisplayTable(); 124 | Object[] row1 = {"", "Precision", "#Pred", "Recall", "#Phrase", "F1", "#Correct"}; 125 | table.addRow(row1); 126 | 127 | String strP = strRatio(nTotCorrect, nTotPredicted); 128 | String strR = strRatio(nTotCorrect, nTotPhrase); 129 | String strF1 = strRatio(2 * nTotCorrect, nTotPredicted + nTotPhrase); 130 | Object[] row2 = {"Accuracy: " + strRatio(nMatched, nTokens), strP, (int) nTotPredicted, strR, (int) nTotPhrase, strF1, (int) nTotCorrect}; 131 | table.addRow(row2); 132 | for (int c = 0; c < labelht.size(); c++) { 133 | strP = strRatio(nCorrect[c], nPredicted[c]); 134 | strR = strRatio(nCorrect[c], nPhrase[c]); 135 | strF1 = strRatio(2 * nCorrect[c], nPredicted[c] + nPhrase[c]); 136 | 
Object[] row = {labs.get(c), strP, (int) nPredicted[c], strR, (int) nPhrase[c], strF1, (int) nCorrect[c]}; 137 | table.addRow(row); 138 | } 139 | System.out.println(table); 140 | return ratio(2 * nTotCorrect, nTotPredicted + nTotPhrase); 141 | } 142 | 143 | /** 144 | * Print the scores based on the correct tokens. 145 | * @return F1 score 146 | */ 147 | public double tokenScore() { 148 | Hashtable labelht = new Hashtable(); 149 | Vector labs = new Vector(); 150 | collectLabels(labels, predicted, labelht, labs); 151 | 152 | double nTokens = 0; 153 | double nMatched = 0; 154 | 155 | double[] nToken = new double[labelht.size()]; 156 | double[] nPredicted = new double[labelht.size()]; 157 | double[] nCorrect = new double[labelht.size()]; 158 | for (int s = 0; s < labels.length; s++) { 159 | for (int t = 0; t < labels[s].length; t++) { 160 | nTokens++; 161 | if (labels[s][t].equals(predicted[s][t])) { 162 | nMatched++; 163 | if (!labels[s][t].equals("O")) { 164 | nCorrect[labelht.get(labels[s][t])]++; 165 | } 166 | } 167 | if (!labels[s][t].equals("O")) { 168 | nToken[labelht.get(labels[s][t])]++; 169 | } 170 | if (!predicted[s][t].equals("O")) { 171 | nPredicted[labelht.get(predicted[s][t])]++; 172 | } 173 | } 174 | } 175 | 176 | double nTotCorrect = sum(nCorrect); 177 | double nTotToken = sum(nToken); 178 | double nTotPredicted = sum(nPredicted); 179 | System.out.println((int) nTokens + " tokens with " + (int) nTotToken + " relevant tokens (not equal to O)."); 180 | System.out.println((int) nTotPredicted + " predicted tokens with " + (int) nTotCorrect + " being correct."); 181 | 182 | DisplayTable table = new DisplayTable(); 183 | Object[] row1 = {"", "Precision", "#Pred", "Recall", "#Phrase", "F1", "#Correct"}; 184 | table.addRow(row1); 185 | 186 | String strP = strRatio(nTotCorrect, nTotPredicted); 187 | String strR = strRatio(nTotCorrect, nTotToken); 188 | String strF1 = strRatio(2 * nTotCorrect, nTotPredicted + nTotToken); 189 | Object[] row2 = {"Accuracy: " + 
strRatio(nMatched, nTokens), strP, (int) nTotPredicted, strR, (int) nTotToken, strF1, (int) nTotCorrect}; 190 | table.addRow(row2); 191 | for (int c = 0; c < labelht.size(); c++) { 192 | strP = strRatio(nCorrect[c], nPredicted[c]); 193 | strR = strRatio(nCorrect[c], nToken[c]); 194 | strF1 = strRatio(2 * nCorrect[c], nPredicted[c] + nToken[c]); 195 | Object[] row = {labs.get(c), strP, (int) nPredicted[c], strR, (int) nToken[c], strF1, (int) nCorrect[c]}; 196 | table.addRow(row); 197 | } 198 | System.out.println(table); 199 | return ratio(2 * nTotCorrect, nTotPredicted + nTotToken); 200 | } 201 | 202 | private void collectLabels(String[][] labels, String[][] predicted, Hashtable labelht, Vector labs) { 203 | for (int s = 0; s < labels.length; s++) { 204 | for (int t = 0; t < labels[s].length; t++) { 205 | if (!labels[s][t].equals("O") && !labelht.containsKey(labels[s][t])) { 206 | labelht.put(labels[s][t], labelht.size()); 207 | labs.add(labels[s][t]); 208 | } 209 | 210 | if (!predicted[s][t].equals("O") && !labelht.containsKey(predicted[s][t])) { 211 | labelht.put(predicted[s][t], labelht.size()); 212 | labs.add(predicted[s][t]); 213 | } 214 | } 215 | } 216 | Collections.sort(labs); 217 | for (int i = 0; i < labs.size(); i++) { 218 | labelht.put(labs.get(i), i); 219 | } 220 | } 221 | 222 | private String strRatio(double a, double b) { 223 | double r = 0; 224 | if (b != 0) { 225 | r = 100 * a / b; 226 | } 227 | DecimalFormat df = new DecimalFormat("#0.00"); 228 | return df.format(r) + "%"; 229 | } 230 | 231 | private double ratio(double a, double b) { 232 | if (b != 0) { 233 | return a / b; 234 | } 235 | return 0; 236 | } 237 | 238 | private double sum(double[] ar) { 239 | double s = 0; 240 | for (double d : ar) { 241 | s += d; 242 | } 243 | return s; 244 | } 245 | 246 | private void removeSuffix(String[][] arr) { 247 | for (int i = 0; i < arr.length; i++) { 248 | for (int j = 0; j < arr[i].length; j++) { 249 | if (arr[i][j].lastIndexOf('-') != -1) { 250 | arr[i][j] 
= arr[i][j].substring(0, arr[i][j].lastIndexOf('-')); 251 | } 252 | } 253 | } 254 | } 255 | } 256 | 257 | /** 258 | * Class for displaying tables 259 | * @author Ye Nan 260 | */ 261 | class DisplayTable { 262 | 263 | Vector> rows = new Vector>(); // Rows of the table 264 | 265 | DisplayTable() { 266 | } 267 | 268 | void addRow(Object[] entries) { 269 | Vector row = new Vector(); 270 | for (int i = 0; i < entries.length; i++) { 271 | row.add(entries[i].toString()); 272 | } 273 | rows.add(row); 274 | } 275 | 276 | String format(String s, int w, String align) { 277 | int n = w - s.length(); 278 | if (align == "l") {//align to the left 279 | for (int i = 0; i < n; i++) { 280 | s += " "; 281 | } 282 | } else {//align to the right 283 | for (int i = 0; i < n; i++) { 284 | s = " " + s; 285 | } 286 | } 287 | return s; 288 | } 289 | 290 | @Override 291 | public String toString() { 292 | int ncols = 0; 293 | int nrows = rows.size(); 294 | for (int r = 0; r < nrows; r++) { 295 | if (rows.get(r).size() > ncols) { 296 | ncols = rows.get(r).size(); 297 | } 298 | } 299 | 300 | for (int r = 0; r < nrows; r++) { 301 | int s = rows.get(r).size(); 302 | if (s != ncols) { 303 | for (int i = 0; i < ncols - s; i++) { 304 | rows.get(r).add(""); 305 | } 306 | } 307 | } 308 | int[] colWidths = new int[ncols]; 309 | for (int c = 0; c < ncols; c++) { 310 | for (int r = 0; r < nrows; r++) { 311 | int w = rows.get(r).get(c).length(); 312 | if (w > colWidths[c]) { 313 | colWidths[c] = w; 314 | } 315 | } 316 | } 317 | 318 | StringBuffer sb = new StringBuffer(); 319 | for (int r = 0; r < nrows; r++) { 320 | for (int c = 0; c < ncols; c++) { 321 | String align = "r"; 322 | if (c == ncols - 1) { 323 | align = "l"; 324 | } 325 | sb.append(format(rows.get(r).get(c), colWidths[c], align) + "\t "); 326 | } 327 | sb.append("\n"); 328 | } 329 | return sb.toString(); 330 | } 331 | } 332 | -------------------------------------------------------------------------------- /src/HOCRF/Scorer.java: 
-------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import java.text.*; 6 | 7 | /** 8 | * Scorer class 9 | * @author Nguyen Viet Cuong 10 | * @author Ye Nan 11 | */ 12 | public class Scorer { 13 | 14 | String[][] labels; // True labels 15 | String[][] predicted; // Predicted labels 16 | 17 | /** 18 | * Construct a scorer with true data, predicted data, and label map. 19 | * @param trueData List of correct sequences 20 | * @param predictedData List of predicted sequences 21 | * @param labelmap Label map 22 | * @param RM_SUFFIX If set to true, suffixes of labels after '-' will be removed 23 | */ 24 | public Scorer(ArrayList trueData, ArrayList predictedData, LabelMap labelmap, boolean RM_SUFFIX) { 25 | labels = new String[trueData.size()][]; 26 | for (int i = 0; i < trueData.size(); i++) { 27 | DataSequence seq = (DataSequence) trueData.get(i); 28 | labels[i] = labelmap.revArray(seq.labels); 29 | } 30 | 31 | predicted = new String[predictedData.size()][]; 32 | for (int i = 0; i < predictedData.size(); i++) { 33 | DataSequence seq = (DataSequence) predictedData.get(i); 34 | predicted[i] = labelmap.revArray(seq.labels); 35 | } 36 | 37 | if (RM_SUFFIX) { 38 | removeSuffix(labels); 39 | removeSuffix(predicted); 40 | } 41 | } 42 | 43 | /** 44 | * Print the scores based on the correct phrases. 
45 | * @return F1 score 46 | */ 47 | public double phraseScore() { 48 | HashMap labelht = new HashMap(); 49 | ArrayList labs = new ArrayList(); 50 | collectLabels(labels, predicted, labelht, labs); 51 | 52 | double nTokens = 0; 53 | double nMatched = 0; 54 | 55 | double[] nPhrase = new double[labelht.size()]; 56 | double[] nPredicted = new double[labelht.size()]; 57 | double[] nCorrect = new double[labelht.size()]; 58 | for (int s = 0; s < labels.length; s++) { 59 | nTokens += labels[s].length; 60 | String prev = ""; 61 | for (int t = 0; t < labels[s].length; t++) { 62 | if (labels[s][t].equals(predicted[s][t])) { 63 | nMatched++; 64 | } 65 | 66 | if (!labels[s][t].equals(prev)) { 67 | prev = labels[s][t]; 68 | if (!labels[s][t].equals("O")) { 69 | nPhrase[labelht.get(labels[s][t])]++; 70 | } 71 | } 72 | } 73 | } 74 | 75 | for (int s = 0; s < predicted.length; s++) { 76 | String prev = "O"; 77 | int start = -1, end = -1; 78 | for (int t = 0; t < predicted[s].length; t++) { 79 | if (!predicted[s][t].equals(prev)) { //token changed 80 | if (!prev.equals("O")) { //phrase ended 81 | end = t; 82 | if (!(start > 0 && labels[s][start - 1].equals(labels[s][start])) 83 | && !(end < labels[s].length && labels[s][end - 1].equals(labels[s][end]))) { 84 | int i = start; 85 | for (; i < end; i++) { 86 | if (!predicted[s][i].equals(labels[s][i])) { 87 | break; 88 | } 89 | } 90 | if (i == end) { 91 | nCorrect[labelht.get(predicted[s][start])]++; 92 | } 93 | } 94 | } 95 | 96 | prev = predicted[s][t]; 97 | if (!prev.equals("O")) { //mark beginning of phrase 98 | start = t; 99 | nPredicted[labelht.get(prev)]++; 100 | } 101 | } 102 | 103 | if (t == predicted[s].length - 1 && !predicted[s][t].equals("O")) { 104 | int i = start; 105 | for (; i <= t; i++) { 106 | if (!predicted[s][i].equals(labels[s][i])) { 107 | break; 108 | } 109 | } 110 | if (i == t + 1) { 111 | nCorrect[labelht.get(predicted[s][start])]++; 112 | } 113 | } 114 | } 115 | } 116 | 117 | double nTotCorrect = 
sum(nCorrect); 118 | double nTotPhrase = sum(nPhrase); 119 | double nTotPredicted = sum(nPredicted); 120 | System.out.println((int) nTokens + " tokens with " + (int) nTotPhrase + " phrases."); 121 | System.out.println((int) nTotPredicted + " predicted phrases with " + (int) nTotCorrect + " being correct."); 122 | 123 | DisplayTable table = new DisplayTable(); 124 | Object[] row1 = {"", "Precision", "#Pred", "Recall", "#Phrase", "F1", "#Correct"}; 125 | table.addRow(row1); 126 | 127 | String strP = strRatio(nTotCorrect, nTotPredicted); 128 | String strR = strRatio(nTotCorrect, nTotPhrase); 129 | String strF1 = strRatio(2 * nTotCorrect, nTotPredicted + nTotPhrase); 130 | Object[] row2 = {"Accuracy: " + strRatio(nMatched, nTokens), strP, (int) nTotPredicted, strR, (int) nTotPhrase, strF1, (int) nTotCorrect}; 131 | table.addRow(row2); 132 | for (int c = 0; c < labelht.size(); c++) { 133 | strP = strRatio(nCorrect[c], nPredicted[c]); 134 | strR = strRatio(nCorrect[c], nPhrase[c]); 135 | strF1 = strRatio(2 * nCorrect[c], nPredicted[c] + nPhrase[c]); 136 | Object[] row = {labs.get(c), strP, (int) nPredicted[c], strR, (int) nPhrase[c], strF1, (int) nCorrect[c]}; 137 | table.addRow(row); 138 | } 139 | System.out.println(table); 140 | return ratio(2 * nTotCorrect, nTotPredicted + nTotPhrase); 141 | } 142 | 143 | /** 144 | * Print the scores based on the correct tokens. 
145 | * @return F1 score 146 | */ 147 | public double tokenScore() { 148 | HashMap labelht = new HashMap(); 149 | ArrayList labs = new ArrayList(); 150 | collectLabels(labels, predicted, labelht, labs); 151 | 152 | double nTokens = 0; 153 | double nMatched = 0; 154 | 155 | double[] nToken = new double[labelht.size()]; 156 | double[] nPredicted = new double[labelht.size()]; 157 | double[] nCorrect = new double[labelht.size()]; 158 | for (int s = 0; s < labels.length; s++) { 159 | for (int t = 0; t < labels[s].length; t++) { 160 | nTokens++; 161 | if (labels[s][t].equals(predicted[s][t])) { 162 | nMatched++; 163 | if (!labels[s][t].equals("O")) { 164 | nCorrect[labelht.get(labels[s][t])]++; 165 | } 166 | } 167 | if (!labels[s][t].equals("O")) { 168 | nToken[labelht.get(labels[s][t])]++; 169 | } 170 | if (!predicted[s][t].equals("O")) { 171 | nPredicted[labelht.get(predicted[s][t])]++; 172 | } 173 | } 174 | } 175 | 176 | double nTotCorrect = sum(nCorrect); 177 | double nTotToken = sum(nToken); 178 | double nTotPredicted = sum(nPredicted); 179 | System.out.println((int) nTokens + " tokens with " + (int) nTotToken + " relevant tokens (not equal to O)."); 180 | System.out.println((int) nTotPredicted + " predicted tokens with " + (int) nTotCorrect + " being correct."); 181 | 182 | DisplayTable table = new DisplayTable(); 183 | Object[] row1 = {"", "Precision", "#Pred", "Recall", "#Phrase", "F1", "#Correct"}; 184 | table.addRow(row1); 185 | 186 | String strP = strRatio(nTotCorrect, nTotPredicted); 187 | String strR = strRatio(nTotCorrect, nTotToken); 188 | String strF1 = strRatio(2 * nTotCorrect, nTotPredicted + nTotToken); 189 | Object[] row2 = {"Accuracy: " + strRatio(nMatched, nTokens), strP, (int) nTotPredicted, strR, (int) nTotToken, strF1, (int) nTotCorrect}; 190 | table.addRow(row2); 191 | for (int c = 0; c < labelht.size(); c++) { 192 | strP = strRatio(nCorrect[c], nPredicted[c]); 193 | strR = strRatio(nCorrect[c], nToken[c]); 194 | strF1 = strRatio(2 * 
nCorrect[c], nPredicted[c] + nToken[c]); 195 | Object[] row = {labs.get(c), strP, (int) nPredicted[c], strR, (int) nToken[c], strF1, (int) nCorrect[c]}; 196 | table.addRow(row); 197 | } 198 | System.out.println(table); 199 | return ratio(2 * nTotCorrect, nTotPredicted + nTotToken); 200 | } 201 | 202 | /** 203 | * Score on sequence level. 204 | * @return Accuracy at sequence level 205 | */ 206 | public double sentenceScore() { 207 | HashMap labelht = new HashMap(); 208 | ArrayList labs = new ArrayList(); 209 | collectLabels(labels, predicted, labelht, labs); 210 | double nMatched = 0; 211 | for (int s = 0; s < labels.length; s++) { 212 | boolean isMatched = true; 213 | for (int t = 0; t < labels[s].length; t++) { 214 | if (!labels[s][t].equals(predicted[s][t])) { 215 | isMatched = false; 216 | break; 217 | } 218 | } 219 | if (isMatched) { 220 | nMatched++; 221 | } 222 | } 223 | System.out.println("Sentence accuracy = " + strRatio(nMatched, labels.length)); 224 | return ratio(nMatched, labels.length); 225 | } 226 | 227 | /** 228 | * Macro-averaged accuracy score. 
229 | * @return Macro-averaged accuracy 230 | */ 231 | public double macroAccuracyScore() { 232 | HashMap labelht = new HashMap(); 233 | ArrayList labs = new ArrayList(); 234 | collectLabels(labels, predicted, labelht, labs); 235 | 236 | double[] nToken = new double[labelht.size()]; 237 | double[] nCorrect = new double[labelht.size()]; 238 | for (int s = 0; s < labels.length; s++) { 239 | for (int t = 0; t < labels[s].length; t++) { 240 | if (!labels[s][t].equals("O")) { 241 | nToken[labelht.get(labels[s][t])]++; 242 | if (labels[s][t].equals(predicted[s][t])) { 243 | nCorrect[labelht.get(labels[s][t])]++; 244 | } 245 | } 246 | } 247 | } 248 | 249 | DisplayTable table = new DisplayTable(); 250 | Object[] row1 = {"", "Acc", "#Phrase", "#Correct"}; 251 | table.addRow(row1); 252 | 253 | double totalAcc = 0.0; 254 | for (int c = 0; c < labelht.size(); c++) { 255 | String strAcc = strRatio(nCorrect[c], nToken[c]); 256 | totalAcc += ratio(nCorrect[c], nToken[c]); 257 | Object[] row = {labs.get(c), strAcc, (int) nToken[c], (int) nCorrect[c]}; 258 | table.addRow(row); 259 | } 260 | System.out.println(table); 261 | System.out.println("Averaged accuracy = " + strRatio(totalAcc, labelht.size())); 262 | return ratio(totalAcc, labelht.size()); 263 | } 264 | 265 | private void collectLabels(String[][] labels, String[][] predicted, HashMap labelht, ArrayList labs) { 266 | for (int s = 0; s < labels.length; s++) { 267 | for (int t = 0; t < labels[s].length; t++) { 268 | if (!labels[s][t].equals("O") && !labelht.containsKey(labels[s][t])) { 269 | labelht.put(labels[s][t], labelht.size()); 270 | labs.add(labels[s][t]); 271 | } 272 | 273 | if (!predicted[s][t].equals("O") && !labelht.containsKey(predicted[s][t])) { 274 | labelht.put(predicted[s][t], labelht.size()); 275 | labs.add(predicted[s][t]); 276 | } 277 | } 278 | } 279 | Collections.sort(labs); 280 | for (int i = 0; i < labs.size(); i++) { 281 | labelht.put(labs.get(i), i); 282 | } 283 | } 284 | 285 | private String 
strRatio(double a, double b) { 286 | double r = 0; 287 | if (b != 0) { 288 | r = 100 * a / b; 289 | } 290 | DecimalFormat df = new DecimalFormat("#0.00"); 291 | return df.format(r) + "%"; 292 | } 293 | 294 | private double ratio(double a, double b) { 295 | if (b != 0) { 296 | return a / b; 297 | } 298 | return 0; 299 | } 300 | 301 | private double sum(double[] ar) { 302 | double s = 0; 303 | for (double d : ar) { 304 | s += d; 305 | } 306 | return s; 307 | } 308 | 309 | private void removeSuffix(String[][] arr) { 310 | for (int i = 0; i < arr.length; i++) { 311 | for (int j = 0; j < arr[i].length; j++) { 312 | if (arr[i][j].lastIndexOf('-') != -1) { 313 | arr[i][j] = arr[i][j].substring(0, arr[i][j].lastIndexOf('-')); 314 | } 315 | } 316 | } 317 | } 318 | } 319 | 320 | /** 321 | * Class for displaying tables 322 | * @author Ye Nan 323 | */ 324 | class DisplayTable { 325 | 326 | ArrayList> rows = new ArrayList>(); // Rows of the table 327 | 328 | DisplayTable() { 329 | } 330 | 331 | void addRow(Object[] entries) { 332 | ArrayList row = new ArrayList(); 333 | for (int i = 0; i < entries.length; i++) { 334 | row.add(entries[i].toString()); 335 | } 336 | rows.add(row); 337 | } 338 | 339 | String format(String s, int w, String align) { 340 | int n = w - s.length(); 341 | if (align == "l") {//align to the left 342 | for (int i = 0; i < n; i++) { 343 | s += " "; 344 | } 345 | } else {//align to the right 346 | for (int i = 0; i < n; i++) { 347 | s = " " + s; 348 | } 349 | } 350 | return s; 351 | } 352 | 353 | @Override 354 | public String toString() { 355 | int ncols = 0; 356 | int nrows = rows.size(); 357 | for (int r = 0; r < nrows; r++) { 358 | if (rows.get(r).size() > ncols) { 359 | ncols = rows.get(r).size(); 360 | } 361 | } 362 | 363 | for (int r = 0; r < nrows; r++) { 364 | int s = rows.get(r).size(); 365 | if (s != ncols) { 366 | for (int i = 0; i < ncols - s; i++) { 367 | rows.get(r).add(""); 368 | } 369 | } 370 | } 371 | int[] colWidths = new int[ncols]; 372 | 
for (int c = 0; c < ncols; c++) { 373 | for (int r = 0; r < nrows; r++) { 374 | int w = rows.get(r).get(c).length(); 375 | if (w > colWidths[c]) { 376 | colWidths[c] = w; 377 | } 378 | } 379 | } 380 | 381 | StringBuffer sb = new StringBuffer(); 382 | for (int r = 0; r < nrows; r++) { 383 | for (int c = 0; c < ncols; c++) { 384 | String align = "r"; 385 | if (c == ncols - 1) { 386 | align = "l"; 387 | } 388 | sb.append(format(rows.get(r).get(c), colWidths[c], align) + "\t "); 389 | } 390 | sb.append("\n"); 391 | } 392 | return sb.toString(); 393 | } 394 | } 395 | -------------------------------------------------------------------------------- /src/HOCRF/FeatureGenerator.java: -------------------------------------------------------------------------------- 1 | package HOCRF; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import Parallel.*; 6 | 7 | /** 8 | * Feature generator class 9 | * @author Nguyen Viet Cuong 10 | */ 11 | public class FeatureGenerator { 12 | 13 | ArrayList featureTypes; // Feature types list 14 | int maxOrder; // Maximum order of the CRF 15 | Params params; // Parameters 16 | 17 | HashMap obsMap; // Map from feature observation to its ID 18 | HashMap patternMap; // Map from feature pattern to index 19 | HashMap featureMap; // Map from FeatureIndex to its ID in lambda vector 20 | ArrayList featureList; // Map from feature ID to features 21 | 22 | HashMap forwardStateMap; // Map from forward state to index 23 | ArrayList[] forwardTransition1; // Map from piID to list of pkID (see paper) 24 | ArrayList[] forwardTransition2; // Map from piID to list of pkyID (see paper) 25 | 26 | HashMap backwardStateMap; // Map from backward state to index 27 | int[][] backwardTransition; // Map from [siID,y] to skID (see paper) 28 | ArrayList[] allSuffixes; // Map from sID to its suffixes patID 29 | ArrayList backwardStateList; // List of backward states 30 | 31 | ArrayList[] patternTransition1; // Map from z to piID (see paper) 32 | ArrayList[] 
patternTransition2; // Map from z to piyID (see paper) 33 | 34 | /** 35 | * Constructor a feature generator. 36 | * @param fts List of feature types 37 | * @param pr Parameters 38 | */ 39 | public FeatureGenerator(ArrayList fts, Params pr) { 40 | featureTypes = fts; 41 | maxOrder = getMaxOrder(); 42 | params = pr; 43 | } 44 | 45 | /** 46 | * Initialize the feature generator with the training data. 47 | * This method needs to be called before the training process. 48 | * @param trainData List of training sequences 49 | */ 50 | public void initialize(ArrayList trainData) throws Exception { 51 | generateFeatureMap(trainData); 52 | generateForwardStatesMap(); 53 | generateBackwardStatesMap(); 54 | generateSentenceFeat(trainData); 55 | buildForwardTransition(); 56 | buildBackwardTransition(); 57 | buildPatternTransition(); 58 | } 59 | 60 | /** 61 | * Write the feature generator to a file. 62 | * @param filename Name of the output file 63 | */ 64 | public void write(String filename) throws Exception { 65 | PrintWriter out = new PrintWriter(new FileOutputStream(filename)); 66 | 67 | // Write observation map 68 | out.println(obsMap.size()); 69 | Iterator iter = obsMap.keySet().iterator(); 70 | while (iter.hasNext()) { 71 | String key = (String) iter.next(); 72 | out.println(key + " " + obsMap.get(key)); 73 | } 74 | 75 | // Write pattern map 76 | out.println(patternMap.size()); 77 | iter = patternMap.keySet().iterator(); 78 | while (iter.hasNext()) { 79 | String key = (String) iter.next(); 80 | out.println(key + " " + patternMap.get(key)); 81 | } 82 | 83 | // Write feature map 84 | out.println(featureMap.size()); 85 | iter = featureMap.keySet().iterator(); 86 | while (iter.hasNext()) { 87 | FeatureIndex fi = (FeatureIndex) iter.next(); 88 | int index = (Integer) featureMap.get(fi); 89 | Feature f = featureList.get(index); 90 | out.println(f.obs + " " + f.pat + " " + f.value + " " + index); 91 | } 92 | 93 | // Write forward state map 94 | out.println(forwardStateMap.size()); 
95 | iter = forwardStateMap.keySet().iterator(); 96 | while (iter.hasNext()) { 97 | String key = (String) iter.next(); 98 | if (!key.equals("")) { 99 | out.println(key + " " + forwardStateMap.get(key)); 100 | } 101 | } 102 | 103 | // Write backward state map 104 | out.println(backwardStateMap.size()); 105 | iter = backwardStateMap.keySet().iterator(); 106 | while (iter.hasNext()) { 107 | String key = (String) iter.next(); 108 | out.println(key + " " + backwardStateMap.get(key)); 109 | } 110 | 111 | out.close(); 112 | } 113 | 114 | /** 115 | * Load the feature generator from a file. 116 | * @param filename Name of the file that contains the feature generator information 117 | */ 118 | public void read(String filename) throws Exception { 119 | BufferedReader in = new BufferedReader(new FileReader(filename)); 120 | 121 | // Read observation map 122 | int mapSize = Integer.parseInt(in.readLine()); 123 | obsMap = new HashMap(); 124 | for (int i = 0; i < mapSize; i++) { 125 | String line = in.readLine(); 126 | StringTokenizer toks = new StringTokenizer(line); 127 | String key = toks.nextToken(); 128 | int index = Integer.parseInt(toks.nextToken()); 129 | obsMap.put(key, index); 130 | } 131 | 132 | // Read pattern map 133 | mapSize = Integer.parseInt(in.readLine()); 134 | patternMap = new HashMap(); 135 | for (int i = 0; i < mapSize; i++) { 136 | String line = in.readLine(); 137 | StringTokenizer toks = new StringTokenizer(line); 138 | String key = toks.nextToken(); 139 | int index = Integer.parseInt(toks.nextToken()); 140 | patternMap.put(key, index); 141 | } 142 | 143 | // Read feature map 144 | mapSize = Integer.parseInt(in.readLine()); 145 | featureMap = new HashMap(); 146 | featureList = new ArrayList(mapSize); 147 | for (int i = 0; i < mapSize; i++) featureList.add(null); 148 | for (int i = 0; i < mapSize; i++) { 149 | String line = in.readLine(); 150 | StringTokenizer toks = new StringTokenizer(line); 151 | String obs = toks.nextToken(); 152 | String pat = 
toks.nextToken(); 153 | double value = Double.parseDouble(toks.nextToken()); 154 | int index = Integer.parseInt(toks.nextToken()); 155 | Feature f = new Feature(obs, pat, value); 156 | featureMap.put(getFeatureIndex(f), index); 157 | featureList.set(index, f); 158 | } 159 | 160 | // Read forward state map 161 | mapSize = Integer.parseInt(in.readLine()); 162 | forwardStateMap = new HashMap(); 163 | forwardStateMap.put("", new Integer(0)); 164 | for (int i = 0; i < mapSize-1; i++) { 165 | String line = in.readLine(); 166 | StringTokenizer toks = new StringTokenizer(line); 167 | String key = toks.nextToken(); 168 | int index = Integer.parseInt(toks.nextToken()); 169 | forwardStateMap.put(key, index); 170 | } 171 | 172 | // Read backward state map 173 | mapSize = Integer.parseInt(in.readLine()); 174 | backwardStateMap = new HashMap(); 175 | backwardStateList = new ArrayList(mapSize); 176 | for (int i = 0; i < mapSize; i++) backwardStateList.add(null); 177 | for (int i = 0; i < mapSize; i++) { 178 | String line = in.readLine(); 179 | StringTokenizer toks = new StringTokenizer(line); 180 | String key = toks.nextToken(); 181 | int index = Integer.parseInt(toks.nextToken()); 182 | backwardStateMap.put(key, index); 183 | backwardStateList.set(index, key); 184 | } 185 | 186 | buildForwardTransition(); 187 | buildBackwardTransition(); 188 | buildPatternTransition(); 189 | 190 | in.close(); 191 | } 192 | 193 | /** 194 | * Get the index of a feature. 195 | * @param f Feature 196 | * @return The feature index 197 | */ 198 | public FeatureIndex getFeatureIndex(Feature f) { 199 | Integer obs = (Integer) getObsIndex(f.obs); 200 | Integer pat = (Integer) getPatternIndex(f.pat); 201 | if (obs == null || pat == null) { 202 | return null; 203 | } else { 204 | return new FeatureIndex(getObsIndex(f.obs), getPatternIndex(f.pat)); 205 | } 206 | } 207 | 208 | /** 209 | * Get the index of an observation string. 
210 | * @param obs Observation string 211 | * @return Observation index 212 | */ 213 | public Integer getObsIndex(String obs) { 214 | return (Integer) obsMap.get(obs); 215 | } 216 | 217 | /** 218 | * Get the index of a pattern string. 219 | * @param p Pattern string 220 | * @return Pattern index 221 | */ 222 | public Integer getPatternIndex(String p) { 223 | return (Integer) patternMap.get(p); 224 | } 225 | 226 | /** 227 | * Get the index of a forward state. 228 | * @param p Forward state 229 | * @return Index of the forward state 230 | */ 231 | public Integer getForwardStateIndex(String p) { 232 | return (Integer) forwardStateMap.get(p); 233 | } 234 | 235 | /** 236 | * Get the index of a backward state. 237 | * @param p Backward state 238 | * @return Index of the backward state 239 | */ 240 | public Integer getBackwardStateIndex(String p) { 241 | return (Integer) backwardStateMap.get(p); 242 | } 243 | 244 | /** 245 | * Get the maximum order of the CRF. 246 | * @return Maximum order of the CRF 247 | */ 248 | public int getMaxOrder() { 249 | int res = -1; 250 | for (int i = 0; i < featureTypes.size(); i++) { 251 | if (res < featureTypes.get(i).order()) { 252 | res = featureTypes.get(i).order(); 253 | } 254 | } 255 | return res; 256 | } 257 | 258 | /** 259 | * Generate the features for each training sequence. 260 | * @param trainData List of training sequences 261 | */ 262 | public void generateSentenceFeat(ArrayList trainData) throws Exception { 263 | SentenceFeatGenerator gen = new SentenceFeatGenerator(trainData, this); 264 | Scheduler sch = new Scheduler(gen, params.numthreads, Scheduler.DYNAMIC_NEXT_AVAILABLE); 265 | sch.run(); 266 | } 267 | 268 | /** 269 | * Generate the observation map, pattern map, feature map, and feature list from training data. 
270 | * @param trainData List of training sequences 271 | */ 272 | public void generateFeatureMap(ArrayList trainData) { 273 | obsMap = new HashMap(); 274 | patternMap = new HashMap(); 275 | featureMap = new HashMap(); 276 | featureList = new ArrayList(); 277 | for (int t = 0; t < trainData.size(); t++) { 278 | DataSequence seq = (DataSequence) trainData.get(t); 279 | for (int pos = 0; pos < seq.length(); pos++) { 280 | String labelPat = generateLabelPattern(seq, pos); 281 | ArrayList features = generateFeatures(seq, pos, labelPat); 282 | 283 | for (Feature f : features) { 284 | Integer obs_index = getObsIndex(f.obs); 285 | if (obs_index == null) { 286 | obsMap.put(f.obs, obsMap.size()); 287 | } 288 | 289 | Integer pat_index = getPatternIndex(f.pat); 290 | if (pat_index == null) { 291 | patternMap.put(f.pat, patternMap.size()); 292 | } 293 | 294 | FeatureIndex index = getFeatureIndex(f); 295 | if (!featureMap.containsKey(index)) { 296 | featureMap.put(index, featureMap.size()); 297 | featureList.add(f); 298 | } 299 | } 300 | } 301 | } 302 | 303 | //System.out.println("Num pattern = " + patternMap.size()); 304 | } 305 | 306 | /** 307 | * Generate the forward state map. 308 | */ 309 | public void generateForwardStatesMap() { 310 | forwardStateMap = new HashMap(); 311 | forwardStateMap.put("", new Integer(0)); 312 | for (int i = 0; i < params.numLabels; i++) { 313 | forwardStateMap.put("" + i, new Integer(forwardStateMap.size())); 314 | } 315 | Iterator iter = patternMap.keySet().iterator(); 316 | while (iter.hasNext()) { 317 | String labelPat = (String) iter.next(); 318 | ArrayList pats = Utility.generateProperPrefixes(labelPat); 319 | for (String pat : pats) { 320 | if (getForwardStateIndex(pat) == null) { 321 | forwardStateMap.put(pat, forwardStateMap.size()); 322 | } 323 | } 324 | } 325 | } 326 | 327 | /** 328 | * Generate the backward state map and the backward state list. 
329 | */ 330 | public void generateBackwardStatesMap() { 331 | backwardStateMap = new HashMap(); 332 | backwardStateList = new ArrayList(); 333 | Iterator iter = forwardStateMap.keySet().iterator(); 334 | while (iter.hasNext()) { 335 | String p = (String) iter.next(); 336 | int lastLabel = p.equals("") ? -1 : Integer.parseInt(Utility.getLastLabel(p)); 337 | for (int y = 0; y < params.numLabels; y++) { 338 | String py = p.equals("") ? y + "" : y + "|" + p; 339 | if (getBackwardStateIndex(py) == null) { 340 | backwardStateMap.put(py, backwardStateMap.size()); 341 | backwardStateList.add(py); 342 | } 343 | } 344 | } 345 | } 346 | 347 | /** 348 | * Generate the maximum posible pattern for a position. 349 | * Note that patterns are in reversed order: y(t)|y(t-1)|y(t-2)|... 350 | * @param seq Data sequence 351 | * @param pos Input position 352 | * @return Pattern string 353 | */ 354 | public String generateLabelPattern(DataSequence seq, int pos) { 355 | String labelPat = ""; 356 | for (int i = 0; i <= maxOrder && pos-i >= 0; i++) { 357 | labelPat = labelPat + "|" + seq.y(pos-i); 358 | } 359 | labelPat = labelPat.substring(1); 360 | return labelPat; 361 | } 362 | 363 | /** 364 | * Generate all features activated at a position with a given label pattern. 365 | * @param seq Data sequence 366 | * @param pos Input position 367 | * @param labelPat Label pattern 368 | * @return List of activated features 369 | */ 370 | public ArrayList generateFeatures(DataSequence seq, int pos, String labelPat) { 371 | ArrayList features = new ArrayList(); 372 | ArrayList suffixes = Utility.generateSuffixes(labelPat); 373 | for (String s : suffixes) { 374 | ArrayList fi = generateFeaturesWithExactPattern(seq, pos, s); 375 | features.addAll(fi); 376 | } 377 | return features; 378 | } 379 | 380 | /** 381 | * Generate all features activated at a position with an exact label pattern. 
382 | * @param seq Data sequence 383 | * @param pos Input position 384 | * @param labelPat Exact label pattern of the activated features 385 | * @return List of activated features 386 | */ 387 | public ArrayList generateFeaturesWithExactPattern(DataSequence seq, int pos, String labelPat) { 388 | ArrayList features = new ArrayList(); 389 | for (FeatureType ft : featureTypes) { 390 | ArrayList fi = ft.generateFeaturesAt(seq, pos, labelPat); 391 | features.addAll(fi); 392 | } 393 | return features; 394 | } 395 | 396 | /** 397 | * Generate all observations at a position. 398 | * @param seq Data sequence 399 | * @param pos Input position 400 | * @return List of observations 401 | */ 402 | public ArrayList generateObs(DataSequence seq, int pos) { 403 | ArrayList obs = new ArrayList(); 404 | for (FeatureType ft : featureTypes) { 405 | obs.addAll(ft.generateObsAt(seq, pos)); 406 | } 407 | return obs; 408 | } 409 | 410 | /** 411 | * Return the index of the longest suffix of a string. 412 | * @param p The input string 413 | * @param map Map from strings to indices 414 | * @return Index of the longest suffix of the input string from the input map. 415 | */ 416 | public Integer getLongestSuffixID(String p, HashMap map) { 417 | ArrayList suffixes = Utility.generateSuffixes(p); 418 | for (int i = 0; i < suffixes.size(); i++) { 419 | Integer index = (Integer) map.get(suffixes.get(i)); 420 | if (index != null) { 421 | return index; 422 | } 423 | } 424 | throw new UnsupportedOperationException("No longest suffix index!\n"); 425 | } 426 | 427 | /** 428 | * Return the longest suffix of a string. 429 | * @param p The input string 430 | * @param map Map from strings to indices 431 | * @return The longest suffix of the input string from the input map. 
432 | */ 433 | public String getLongestSuffix(String p, HashMap map) { 434 | ArrayList suffixes = Utility.generateSuffixes(p); 435 | for (int i = 0; i < suffixes.size(); i++) { 436 | Integer index = (Integer) map.get(suffixes.get(i)); 437 | if (index != null) { 438 | return suffixes.get(i); 439 | } 440 | } 441 | throw new UnsupportedOperationException("No longest suffix!\n"); 442 | } 443 | 444 | /** 445 | * Build the information for the forward algorithm. 446 | */ 447 | public void buildForwardTransition() { 448 | forwardTransition1 = new ArrayList[forwardStateMap.size()]; 449 | forwardTransition2 = new ArrayList[forwardStateMap.size()]; 450 | 451 | Iterator iter = forwardStateMap.keySet().iterator(); 452 | while (iter.hasNext()) { 453 | String pk = (String) iter.next(); 454 | int pkID = getForwardStateIndex(pk); 455 | 456 | for (int y = 0; y < params.numLabels; y++) { 457 | String pky = pk.equals("") ? y + "" : y + "|" + pk; 458 | Integer index = getLongestSuffixID(pky, forwardStateMap); 459 | if (forwardTransition1[index] == null) { 460 | forwardTransition1[index] = new ArrayList(); 461 | forwardTransition2[index] = new ArrayList(); 462 | } 463 | forwardTransition1[index].add(pkID); 464 | forwardTransition2[index].add(getBackwardStateIndex(pky)); 465 | } 466 | } 467 | } 468 | 469 | /** 470 | * Build the information for the backward algorithm. 471 | */ 472 | public void buildBackwardTransition() { 473 | backwardTransition = new int[backwardStateMap.size()][params.numLabels]; 474 | allSuffixes = new ArrayList[backwardStateMap.size()]; 475 | 476 | Iterator iter = backwardStateMap.keySet().iterator(); 477 | while (iter.hasNext()) { 478 | String si = (String) iter.next(); 479 | int siID = getBackwardStateIndex(si); 480 | int lastLabel = si.equals("") ? 
-1 : Integer.parseInt(Utility.getLastLabel(si)); 481 | for (int y = 0; y < params.numLabels; y++) { 482 | String siy = y + "|" + si; 483 | String sk = getLongestSuffix(siy, backwardStateMap); 484 | backwardTransition[siID][y] = getBackwardStateIndex(sk); 485 | } 486 | 487 | allSuffixes[siID] = new ArrayList(); 488 | ArrayList suffixes = Utility.generateSuffixes(si); 489 | for (String suffix : suffixes) { 490 | Integer patID = getPatternIndex(suffix); 491 | if (patID != null) { 492 | allSuffixes[siID].add(patID); 493 | } 494 | } 495 | } 496 | } 497 | 498 | /** 499 | * Build the information to compute the marginals and expected feature scores. 500 | */ 501 | public void buildPatternTransition() { 502 | patternTransition1 = new ArrayList[patternMap.size()]; 503 | patternTransition2 = new ArrayList[patternMap.size()]; 504 | 505 | Iterator forwardIter = forwardStateMap.keySet().iterator(); 506 | while (forwardIter.hasNext()) { 507 | String pi = (String) forwardIter.next(); 508 | int lastLabel = pi.equals("") ? -1 : Integer.parseInt(Utility.getLastLabel(pi)); 509 | int piID = getForwardStateIndex(pi); 510 | for (int y = 0; y < params.numLabels; y++) { 511 | String piy = pi.equals("") ? y + "" : y + "|" + pi; 512 | Integer piyID = getBackwardStateIndex(piy); 513 | ArrayList suffixes = Utility.generateSuffixes(piy); 514 | for (String zi : suffixes) { 515 | Integer ziIndex = getPatternIndex(zi); 516 | if (ziIndex != null) { 517 | if (patternTransition1[ziIndex] == null) { 518 | patternTransition1[ziIndex] = new ArrayList(); 519 | patternTransition2[ziIndex] = new ArrayList(); 520 | } 521 | patternTransition1[ziIndex].add(piID); 522 | patternTransition2[ziIndex].add(piyID); 523 | } 524 | } 525 | } 526 | } 527 | } 528 | 529 | /** 530 | * Get the IDs of the features activated at a position. 
531 | * @param seq Data sequence 532 | * @param pos Input position 533 | * @param patID Pattern ID 534 | * @return List of feature IDs 535 | */ 536 | public ArrayList getFeatures(DataSequence seq, int pos, int patID) { 537 | return seq.getFeatures(pos, patID); 538 | } 539 | 540 | /** 541 | * Get the IDs of a list of features. 542 | * @param fs List of features 543 | * @return List of feature IDs 544 | */ 545 | public ArrayList getFeatureID(ArrayList fs) { 546 | ArrayList feats = new ArrayList(); 547 | for (Feature f : fs) { 548 | Integer feat = (Integer) featureMap.get(getFeatureIndex(f)); 549 | if (feat != null) { 550 | feats.add(feat); 551 | } 552 | } 553 | return feats; 554 | } 555 | 556 | /** 557 | * Compute the feature scores of a list of features and a weight vector. 558 | * @param feats List of feature IDs 559 | * @param lambda Weights of all the features 560 | * @return The total feature score 561 | */ 562 | public double computeFeatureScores(ArrayList feats, double[] lambda) { 563 | double featuresScore = 0.0; 564 | for (int index : feats) { 565 | Feature feat = featureList.get(index); 566 | featuresScore += lambda[index] * feat.value; 567 | } 568 | return featuresScore; 569 | } 570 | 571 | /** 572 | * Print all statistics for testing. 
573 | */ 574 | public void printStatesStatistics() { 575 | System.out.println("Forward Transition:"); 576 | for (int piID = 0; piID < forwardStateMap.size(); piID++) { 577 | System.out.println(piID); 578 | if (forwardTransition1[piID] != null) { 579 | for (int i = 0; i < forwardTransition1[piID].size(); i++) { 580 | System.out.println(forwardTransition1[piID].get(i) + " " + forwardTransition2[piID].get(i)); 581 | } 582 | } 583 | } 584 | 585 | System.out.println("Backward Transition:"); 586 | for (int sID = 0; sID < backwardStateMap.size(); sID++) { 587 | for (int y = 0; y < params.numLabels; y++) { 588 | System.out.println(sID + " " + y + " --> " + backwardTransition[sID][y]); 589 | } 590 | } 591 | 592 | System.out.println("Pattern Transition:"); 593 | for (int pID = 0; pID < patternMap.size(); pID++) { 594 | System.out.println(pID); 595 | if (patternTransition1[pID] != null) { 596 | for (int i = 0; i < patternTransition1[pID].size(); i++) { 597 | System.out.println(patternTransition1[pID].get(i) + " " + patternTransition2[pID].get(i)); 598 | } 599 | } 600 | } 601 | } 602 | } 603 | --------------------------------------------------------------------------------