├── .gitignore ├── src ├── main │ ├── java │ │ ├── fig │ │ │ ├── basic │ │ │ │ ├── DeepCloneable.java │ │ │ │ ├── OptionSet.java │ │ │ │ ├── AbstractT2Map.java │ │ │ │ ├── genCode │ │ │ │ ├── Option.java │ │ │ │ ├── IdentityHashSet.java │ │ │ │ ├── StatFig.java │ │ │ │ ├── BigStatFig.java │ │ │ │ ├── CharEncUtils.java │ │ │ │ ├── Exceptions.java │ │ │ │ ├── AbstractTMap.java │ │ │ │ ├── MapFactory.java │ │ │ │ ├── StopWatchSet.java │ │ │ │ ├── FullStatFig.java │ │ │ │ ├── OrderedStringMap.java │ │ │ │ ├── Fmt.java │ │ │ │ ├── StopWatch.java │ │ │ │ ├── Interner.java │ │ │ │ ├── Pair.java │ │ │ │ ├── OrderedMap.java │ │ │ │ ├── SysInfoUtils.java │ │ │ │ ├── T2VMap.java │ │ │ │ ├── T2DoubleMap.java │ │ │ │ ├── MapUtils.java │ │ │ │ └── PriorityQueue.java │ │ │ ├── record │ │ │ │ ├── Recordable.java │ │ │ │ └── Record.java │ │ │ └── exec │ │ │ │ └── MonitorThread.java │ │ └── fast │ │ │ ├── evaluation │ │ │ ├── Mastery.java │ │ │ ├── Sample.java │ │ │ ├── Degeneracy.java │ │ │ ├── Metrics.java │ │ │ ├── TestSummary.java │ │ │ ├── TrainSummary.java │ │ │ ├── AUC.java │ │ │ └── PredictivePerformance.java │ │ │ ├── common │ │ │ ├── Functions.java │ │ │ ├── Bijection.java │ │ │ ├── Utility.java │ │ │ ├── Matrix.java │ │ │ └── Stats.java │ │ │ ├── data │ │ │ ├── CVStudent.java │ │ │ └── DataPoint.java │ │ │ └── featurehmm │ │ │ ├── PdfFeatureAware.java │ │ │ ├── LogisticRegression.java │ │ │ ├── BaumWelchScaledLearner.java │ │ │ ├── LBFGS.java │ │ │ ├── FeatureHMM.java │ │ │ └── ForwardBackwardScaledCalculator.java │ ├── assembly │ │ ├── jar.xml │ │ └── zip.xml │ └── python │ │ └── split_dataset.py └── log4j.properties ├── data ├── IRT_exp │ ├── README_dataset.txt │ ├── KT1.conf │ ├── KT2.conf │ ├── FAST+IRT1.conf │ ├── FAST+IRT2.conf │ ├── test0.csv │ └── train0.csv ├── others │ ├── KT.conf │ ├── FAST+IRT.conf │ ├── FAST+item.conf │ ├── FAST+subskill.conf │ ├── FAST+item_SplitBySeq.conf │ ├── KT_test0.txt │ ├── KT_train0.txt │ ├── FAST+item_SplitBySeq_test0.txt │ ├── 
FAST+item_SplitBySeq_train0.txt │ ├── FAST+item_train0.txt │ ├── FAST+item_test0.txt │ ├── FAST+subskill_test0.txt │ └── FAST+subskill_train0.txt └── item_exp │ ├── KT1.conf │ ├── KT2.conf │ ├── FAST+item3.conf │ ├── FAST+item1.conf │ ├── FAST+item2.conf │ └── README_dataset.txt ├── .project ├── .classpath ├── pom.xml └── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | /target/
2 | sample-data/output
3 | datasets
4 | src/main/python/.idea
5 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/DeepCloneable.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | /**
4 |  * An object that can produce a deep copy of itself.
5 |  *
6 |  * @param <T> the type returned by {@link #deepClone()}
7 |  */
8 | public interface DeepCloneable<T> {
9 |   /** Returns a deep copy of this object. */
10 |   T deepClone();
11 | }
12 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/OptionSet.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.lang.annotation.*;
4 |
5 | /**
6 |  * Field annotation, retained at runtime, that tags a field with an option-set name.
7 |  */
8 | @Retention(RetentionPolicy.RUNTIME)
9 | @Target(ElementType.FIELD)
10 | public @interface OptionSet {
11 |   /** The name of the option set (required; there is no default). */
12 |   String name();
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/java/fig/record/Recordable.java:
--------------------------------------------------------------------------------
1 | package fig.record;
2 |
3 | /**
4 |  * If an object is recordable, we can add it to the static record.
5 |  * record() should make calls to Record.add() and so on.
6 |  */
7 | public interface Recordable {
8 |   /** Records this object; {@code arg} is passed through unchanged by callers. */
9 |   void record(Object arg);
10 | }
11 |
--------------------------------------------------------------------------------
/src/log4j.properties:
--------------------------------------------------------------------------------
1 | # Root logger option
2 | log4j.rootLogger=DEBUG, stdout
3 |
4 | # Direct log messages to stdout
5 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender
6 | log4j.appender.stdout.Target=System.out
7 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
8 | log4j.appender.stdout.layout.ConversionPattern=%d{ABSOLUTE} %5p %c{1}:%L - %m%n
--------------------------------------------------------------------------------
/data/IRT_exp/README_dataset.txt:
--------------------------------------------------------------------------------
1 | This folder provides a subset of a private dataset from the Quizjet Java tutoring system. The subset contains only observations of a single skill, “Nested Loops”, and is provided solely for testing FAST. Please don’t distribute or use this subset without permission. To obtain permission to use the dataset, please contact Dr. Peter Brusilovsky (peterb@pitt.edu).
-------------------------------------------------------------------------------- /data/others/KT.conf: -------------------------------------------------------------------------------- 1 | modelName KT 2 | parameterizing false 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit false 6 | forceUsingAllInputFeatures false 7 | inDir ./input/ 8 | outDir ./output/ 9 | nbFiles 1 10 | nbRandomRestart 20 11 | trainInFilePrefix KT_train 12 | testInFilePrefix KT_test 13 | inFileSuffix .txt 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/IRT_exp/KT1.conf: -------------------------------------------------------------------------------- 1 | modelName KT1 2 | parameterizing false 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit false 6 | forceUsingAllInputFeatures false 7 | nbRandomRestart 1 8 | nbFiles 1 9 | inDir ./data/IRT_exp/ 10 | outDir ./data/IRT_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/IRT_exp/KT2.conf: -------------------------------------------------------------------------------- 1 | modelName KT2 2 | parameterizing false 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit false 6 | forceUsingAllInputFeatures false 7 | nbRandomRestart 20 8 | nbFiles 1 9 | inDir ./data/IRT_exp/ 10 | outDir ./data/IRT_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/item_exp/KT1.conf: 
-------------------------------------------------------------------------------- 1 | modelName KT1 2 | parameterizing false 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit false 6 | forceUsingAllInputFeatures false 7 | nbRandomRestart 1 8 | nbFiles 1 9 | inDir ./data/item_exp/ 10 | outDir ./data/item_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/item_exp/KT2.conf: -------------------------------------------------------------------------------- 1 | modelName KT2 2 | parameterizing false 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit false 6 | forceUsingAllInputFeatures false 7 | nbRandomRestart 20 8 | nbFiles 1 9 | inDir ./data/item_exp/ 10 | outDir ./data/item_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/item_exp/FAST+item3.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+item3 2 | parameterizing true 3 | parameterizingInit true 4 | parameterizingTran true 5 | parameterizingEmit true 6 | forceUsingAllInputFeatures true 7 | nbRandomRestart 1 8 | nbFiles 1 9 | inDir ./data/item_exp/ 10 | outDir ./data/item_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/others/FAST+IRT.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+IRT 2 | 
parameterizing true 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit true 6 | forceUsingAllInputFeatures true 7 | inDir ./input/ 8 | outDir ./output/ 9 | nbFiles 1 10 | nbRandomRestart 20 11 | trainInFilePrefix FAST+IRT_train 12 | testInFilePrefix FAST+IRT_test 13 | inFileSuffix .txt 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/item_exp/FAST+item1.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+item1 2 | parameterizing true 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit true 6 | forceUsingAllInputFeatures true 7 | nbRandomRestart 1 8 | nbFiles 1 9 | inDir ./data/item_exp/ 10 | outDir ./data/item_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/item_exp/FAST+item2.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+item2 2 | parameterizing true 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit true 6 | forceUsingAllInputFeatures true 7 | nbRandomRestart 20 8 | nbFiles 1 9 | inDir ./data/item_exp/ 10 | outDir ./data/item_exp/ 11 | trainInFilePrefix train 12 | testInFilePrefix test 13 | inFileSuffix .csv 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/others/FAST+item.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+item 2 | parameterizing true 3 | parameterizingInit false 4 | parameterizingTran false 5 | 
parameterizingEmit true
6 | forceUsingAllInputFeatures true
7 | inDir ./input/
8 | outDir ./output/
9 | nbFiles 1
10 | nbRandomRestart 20
11 | trainInFilePrefix FAST+item_train
12 | testInFilePrefix FAST+item_test
13 | inFileSuffix .txt
14 | EMMaxIters 500
15 | LBFGSMaxIters 50
16 | EMTolerance 1.0E-6
17 | LBFGSTolerance 1.0E-6
18 |
--------------------------------------------------------------------------------
/data/others/FAST+subskill.conf:
--------------------------------------------------------------------------------
1 | modelName FAST+subskill
2 | parameterizing true
3 | parameterizingInit false
4 | parameterizingTran false
5 | parameterizingEmit true
6 | forceUsingAllInputFeatures true
7 | inDir ./input/
8 | outDir ./output/
9 | nbFiles 1
10 | nbRandomRestart 20
11 | trainInFilePrefix FAST+subskill_train
12 | testInFilePrefix FAST+subskill_test
13 | inFileSuffix .txt
14 | EMMaxIters 500
15 | LBFGSMaxIters 50
16 | EMTolerance 1.0E-6
17 | LBFGSTolerance 1.0E-6
18 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/AbstractT2Map.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | /**
4 |  * Just a dummy template right now. TODO: move functionality in here.
5 |  *
6 |  * Declares the operations and fields shared by the two-key ("T2") maps.
7 |  */
8 | public abstract class AbstractT2Map<S extends Comparable<S>, T extends Comparable<T>> {
9 |   /** Switches the map's storage to a sorted-list representation. */
10 |   public abstract void switchToSortedList();
11 |
12 |   /** Locks the map (see {@link #locked}). */
13 |   public abstract void lock();
14 |
15 |   /** Returns the number of entries. */
16 |   public abstract int size();
17 |
18 |   protected boolean locked;
19 |   protected AbstractTMap.Functionality keyFunc;
20 | }
21 |
--------------------------------------------------------------------------------
/data/item_exp/README_dataset.txt:
--------------------------------------------------------------------------------
1 | This folder provides the 'Geometry Area (1996-97)' dataset, accessed via DataShop (Koedinger et al., 2010) and available at http://pslcdatashop.org.
2 | 3 | Reference: 4 | Koedinger, K.R., Baker, R.S.J.d., Cunningham, K., Skogsholm, A., Leber, B., Stamper, J. (2010) A Data Repository for the EDM community: The PSLC DataShop. In Romero, C., Ventura, S., Pechenizkiy, M., Baker, R.S.J.d. (Eds.) Handbook of Educational Data Mining. Boca Raton, FL: CRC Press. -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | fast 4 | NO_M2ECLIPSE_SUPPORT: Project files created with the maven-eclipse-plugin are not supported in M2Eclipse. 5 | 6 | 7 | 8 | org.eclipse.jdt.core.javabuilder 9 | 10 | 11 | 12 | org.eclipse.jdt.core.javanature 13 | 14 | -------------------------------------------------------------------------------- /data/others/FAST+item_SplitBySeq.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+item_SplitBySeq 2 | parameterizing true 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit true 6 | forceUsingAllInputFeatures true 7 | inDir ./input/ 8 | outDir ./output/ 9 | nbFiles 1 10 | nbRandomRestart 20 11 | trainInFilePrefix FAST+item_SplitBySeq_train 12 | testInFilePrefix FAST+item_SplitBySeq_test 13 | inFileSuffix .txt 14 | EMMaxIters 500 15 | LBFGSMaxIters 50 16 | EMTolerance 1.0E-6 17 | LBFGSTolerance 1.0E-6 18 | -------------------------------------------------------------------------------- /data/IRT_exp/FAST+IRT1.conf: -------------------------------------------------------------------------------- 1 | modelName FAST+IRT1 2 | parameterizing true 3 | parameterizingInit false 4 | parameterizingTran false 5 | parameterizingEmit true 6 | forceUsingAllInputFeatures true 7 | generateStudentDummy true 8 | generateItemDummy true 9 | nbRandomRestart 1 10 | nbFiles 1 11 | inDir ./data/IRT_exp/ 12 | outDir ./data/IRT_exp/ 13 | trainInFilePrefix train 14 | testInFilePrefix test 15 | inFileSuffix .csv 16 | 
EMMaxIters 500
17 | LBFGSMaxIters 50
18 | EMTolerance 1.0E-6
19 | LBFGSTolerance 1.0E-6
20 |
--------------------------------------------------------------------------------
/data/IRT_exp/FAST+IRT2.conf:
--------------------------------------------------------------------------------
1 | modelName FAST+IRT2
2 | parameterizing true
3 | parameterizingInit false
4 | parameterizingTran false
5 | parameterizingEmit true
6 | forceUsingAllInputFeatures true
7 | generateStudentDummy true
8 | generateItemDummy true
9 | nbRandomRestart 20
10 | nbFiles 1
11 | inDir ./data/IRT_exp/
12 | outDir ./data/IRT_exp/
13 | trainInFilePrefix train
14 | testInFilePrefix test
15 | inFileSuffix .csv
16 | EMMaxIters 500
17 | LBFGSMaxIters 50
18 | EMTolerance 1.0E-6
19 | LBFGSTolerance 1.0E-6
20 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/genCode:
--------------------------------------------------------------------------------
1 | #!/usr/bin/ruby
2 |
3 | # Writes a specialized copy of srcFile to destFile: every word immediately
4 | # followed by a /*key*/ marker (e.g. int/*type*/) is replaced by map[key].
5 | def gen(srcFile, destFile, map)
6 |   puts "generate #{srcFile} -> #{destFile}"
7 |   # Block form closes the output even if reading srcFile raises, and unlike
8 |   # Kernel#open it never treats a "|"-prefixed name as a command to spawn.
9 |   File.open(destFile, "w") { |out|
10 |     IO.foreach(srcFile) { |line|
11 |       map.each_pair { |k,v|
12 |         line = line.gsub(/\w+\/\*#{k}\*\//, v)
13 |       }
14 |       out.puts line
15 |     }
16 |   }
17 | end
18 |
19 | gen("IntVec.java", "DoubleVec.java",
20 |   { "type" => "double", "Type" => "Double", "TypeVec" => "DoubleVec" })
21 | gen("IntVec.java", "FloatVec.java",
22 |   { "type" => "float", "Type" => "Float", "TypeVec" => "FloatVec" })
23 |
--------------------------------------------------------------------------------
/src/main/java/fast/evaluation/Mastery.java:
--------------------------------------------------------------------------------
1 | package fast.evaluation;
2 |
3 | import java.util.ArrayList;
4 | import fast.common.Bijection;
5 |
6 | /**
7 |  * Mastery bookkeeping for a single skill on the test set.
8 |  */
9 | public class Mastery {
10 |
11 |   // A constant shared by every instance, so it is static. NOTE(review): presumably
12 |   // compared against predicted knowledge probability; confirm at the call sites.
13 |   public static final double MASTERY_THRESHOLD = 0.95;
14 |   public Double nbTotalStudents = 0.0;
15 |   // Students already counted as having reached mastery, so that each student is
16 |   // added to nbPracToReachMastery at most once.
17 |   public Bijection studentsReachedMastery = new Bijection();
18 |   public ArrayList nbPracToReachMastery = new ArrayList();
19 |
20 |   // public Mastery(double nbTotalStudents){
21 |   //   this.nbTotalStudents = nbTotalStudents;
22 |   // }
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/Option.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.lang.annotation.*;
4 |
5 | /**
6 |  * Field annotation, retained at runtime, that declares the field as an option.
7 |  */
8 | @Retention(RetentionPolicy.RUNTIME)
9 | @Target(ElementType.FIELD)
10 | public @interface Option {
11 |   String name() default "";
12 |   String gloss() default "";
13 |   boolean required() default false;
14 |
15 |   // Conditionally required option, e.g.
16 |   // - "main.operation": required only when main.operation specified
17 |   // - "main.operation=op1": required only when main.operation takes on value op1
18 |   // - "operation=op1": the group of the option is used
19 |   String condReq() default "";
20 | }
21 |
--------------------------------------------------------------------------------
/.classpath:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/IdentityHashSet.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.util.*;
4 |
5 | /**
6 |  * A Set backed by an IdentityHashMap, so membership is decided by reference
7 |  * identity (==) rather than equals().
8 |  */
9 | public class IdentityHashSet<E> extends AbstractSet<E> {
10 |   private final Map<E, E> map = new IdentityHashMap<E, E>();
11 |
12 |   public IdentityHashSet() { }
13 |   public IdentityHashSet(Collection<? extends E> c) { for(E o : c) add(o); }
14 |
15 |   @Override public int size() { return map.size(); }
16 |   @Override public boolean contains(Object o) { return map.containsKey(o); }
17 |   @Override public Iterator<E> iterator() { return map.keySet().iterator(); }
18 |   /** Returns true iff {@code o} (by identity) was not already present. */
19 |   @Override public boolean add(E o) { return map.put(o, o) == null; }
20 |   @Override public boolean remove(Object o) { return map.remove(o) != null; }
21 |   @Override public void clear() { map.clear(); }
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/StatFig.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.util.*;
4 |
5 | /**
6 |  * For keeping track of statistics.
7 |  * Just keeps average and sum.
8 |  */
9 | public class StatFig {
10 |   public StatFig() { sum = 0; n = 0; }
11 |   /** Counts true as 1 and false as 0. */
12 |   public void add(boolean x) { if(x) sum++; n++; }
13 |   public void add(double x) { sum += x; n++; }
14 |   /** Adds {@code x} as the total of {@code d} observations. */
15 |   public void add(double x, int d) { sum += x; n += d; }
16 |   /** Merges the totals of another figure into this one. */
17 |   public void add(StatFig fig) { sum += fig.sum; n += fig.n; }
18 |   /** Returns sum / count; NaN when nothing has been added. */
19 |   public double mean() { return sum/n; }
20 |   public int size() { return n; }
21 |   public double total() { return sum; }
22 |   @Override
23 |   public String toString() {
24 |     return Fmt.D(mean()) + " (" + n + ")";
25 |   }
26 |
27 |   double sum;
28 |   int n;
29 | }
30 |
--------------------------------------------------------------------------------
/data/others/KT_test0.txt:
--------------------------------------------------------------------------------
1 | student outcome KC 2644 correct Variables 2644 correct Variables 2644 correct Variables 2644 correct Variables 2644 correct Variables 2644 correct Variables 2644 incorrect Variables 2644 incorrect Variables 2644 incorrect Variables 2644 correct Variables 2644 correct Variables 2644 correct Variables 2644 correct Variables 2644 incorrect Objects 2644 correct Objects 2644 correct Objects 2644 correct Objects 2644 correct Objects 2644 correct Objects 2644 correct Objects 2644 correct Objects 2644 correct Objects 2761 incorrect Interfaces 2761 incorrect Interfaces 2761 incorrect Interfaces 2761 incorrect Interfaces 2761 correct Interfaces 2761 correct Interfaces 2761 correct Interfaces 2761 correct Interfaces 2761 incorrect Interfaces 2761 incorrect Interfaces 2761 incorrect
Interfaces 2761 correct Interfaces
--------------------------------------------------------------------------------
/src/main/java/fig/basic/BigStatFig.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.util.*;
4 |
5 | /**
6 |  * For keeping track of statistics.
7 |  * Just keeps average, sum, min, max.
8 |  */
9 | public class BigStatFig extends StatFig {
10 |   public BigStatFig() { min = Double.POSITIVE_INFINITY; max = Double.NEGATIVE_INFINITY; }
11 |
12 |   @Override
13 |   public void add(double x) {
14 |     super.add(x);
15 |     min = Math.min(min, x);
16 |     max = Math.max(max, x);
17 |   }
18 |
19 |   // StatFig.add(boolean) bumps sum/n directly instead of calling add(double), so
20 |   // without this override boolean observations never reached min/max and
21 |   // toString() reported "NaN (0)" despite a non-zero count.
22 |   @Override
23 |   public void add(boolean x) { add(x ? 1.0 : 0.0); }
24 |
25 |   // Merges another figure; its extremes are merged too when it tracks them.
26 |   @Override
27 |   public void add(StatFig fig) {
28 |     super.add(fig);
29 |     if(fig instanceof BigStatFig) {
30 |       BigStatFig big = (BigStatFig) fig;
31 |       min = Math.min(min, big.min);
32 |       max = Math.max(max, big.max);
33 |     }
34 |   }
35 |
36 |   // add(double x, int d) is inherited as-is: a total over d values carries no
37 |   // per-value extremes, so min/max are left unchanged by it.
38 |
39 |   @Override
40 |   public String toString() {
41 |     if(min == Double.POSITIVE_INFINITY) return "NaN (0)";
42 |     return Fmt.D(min) + "/ << " + Fmt.D(mean()) + " >> /" + Fmt.D(max) + " (" + n + ")";
43 |   }
44 |
45 |   public double getMin() { return min; }
46 |   public double getMax() { return max; }
47 |   public double range() { return max-min; }
48 |
49 |   double min, max;
50 | }
51 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/CharEncUtils.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.io.*;
4 |
5 | /**
6 |  * Holds a class-wide character encoding (UTF-8 by default) and builds
7 |  * readers/writers that use it.
8 |  */
9 | public class CharEncUtils {
10 |   //private static String charEncoding = "ISO-8859-1";
11 |   private static String charEncoding = "UTF-8";
12 |
13 |   public static String getCharEncoding() { return charEncoding; }
14 |
15 |   /**
16 |    * Sets the encoding and refreshes the standard streams via
17 |    * LogInfo.updateStdStreams(); an empty argument (per StrUtils.isEmpty) is ignored.
18 |    */
19 |   public static void setCharEncoding(String charEncoding) {
20 |     if(StrUtils.isEmpty(charEncoding)) return;
21 |     CharEncUtils.charEncoding = charEncoding;
22 |     LogInfo.updateStdStreams();
23 |   }
24 |
25 |   /** @throws IOException if the current encoding is not supported */
26 |   public static BufferedReader getReader(InputStream in) throws IOException {
27 |     return new BufferedReader(new InputStreamReader(in, getCharEncoding()));
28 |   }
29 |   /** Returns an auto-flushing writer; @throws IOException if the encoding is not supported */
30 |   public static PrintWriter getWriter(OutputStream out) throws IOException {
31 |     return new PrintWriter(new OutputStreamWriter(out, getCharEncoding()), true);
32 |   }
33 | }
34 |
--------------------------------------------------------------------------------
/src/main/assembly/jar.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | final
6 |
7 | jar
8 |
9 | false
10 |
11 |
12 | true
13 | runtime
14 | false
15 |
16 |
17 |
18 |
19 | ${project.build.outputDirectory}
20 | /
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/src/main/java/fast/common/Functions.java:
--------------------------------------------------------------------------------
1 | package fast.common;
2 |
3 | import java.util.ArrayList;
4 |
5 | public class Functions {
6 |
7 |   /** Returns the logistic (sigmoid) function 1 / (1 + e^(-value)). */
8 |   public static double logistic(double value){
9 |     return (1.0 / (1.0 + Math.exp(-1.0 * value)));
10 |   }
11 |
12 |   /**
13 |    * Returns the logistic of the dot product of coefficients and values, or null
14 |    * if either list is null or their sizes differ.
15 |    */
16 |   public static Double logistic(ArrayList<Double> coefficients, ArrayList<Double> values){
17 |     if (coefficients == null || values == null)
18 |       return null;
19 |     if (coefficients.size() != values.size())
20 |       return null;
21 |     double logit = 0;
22 |     for (int i = 0; i < coefficients.size(); i++)
23 |       logit += coefficients.get(i) * values.get(i);
24 |     return logistic(logit);
25 |   }
26 |
27 |   /**
28 |    * Returns the Euclidean distance between two vectors, or null if either is
29 |    * null or their sizes differ.
30 |    */
31 |   public static Double euclidean_distance(ArrayList<Double> vector1, ArrayList<Double> vector2){
32 |     if (vector1 == null || vector2 == null)
33 |       return null;
34 |     if (vector1.size() != vector2.size())
35 |       return null;
36 |     double distance = 0;
37 |     for (int i = 0; i < vector1.size(); i++)
38 |       distance += Math.pow((vector1.get(i) - vector2.get(i)), 2);
39 |     return (Math.sqrt(distance));
40 |   }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/fig/basic/Exceptions.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | /**
4 |  * Factory methods and shared instances for common RuntimeExceptions.
5 |  */
6 | public class Exceptions {
7 |   // NOTE: the shared instances in this class are created once, at class
8 |   // initialization, so their stack traces point here rather than at the throw
9 |   // site. Prefer the factory methods, which build a fresh exception per call.
10 |   public static final RuntimeException bad = new RuntimeException("BAD");
11 |   public static RuntimeException bad(Object o) {
12 |     return new RuntimeException(""+o);
13 |   }
14 |   public static RuntimeException bad(String fmt, Object... args) {
15 |     return new RuntimeException(String.format(fmt, args));
16 |   }
17 |   public static RuntimeException unknownCase(Object o) {
18 |     return new RuntimeException("Unknown case: " + o);
19 |   }
20 |   public static RuntimeException unsupported(Object o) {
21 |     return new RuntimeException("Function is unsupported: " + o);
22 |   }
23 |
24 |   public static final RuntimeException unsupported =
25 |     new RuntimeException("Function is unsupported");
26 |   public static final RuntimeException unimplemented =
27 |     new RuntimeException("Function has not been implemented");
28 |   public static final RuntimeException unknownCase =
29 |     new RuntimeException("Unknown case");
30 |
31 |   // Replacement for assert: unlike `assert`, this check is always enabled.
32 |   public static void enforce(boolean b, Object... o) {
33 |     if(!b) throw bad(StrUtils.join(o));
34 |   }
35 | }
36 |
--------------------------------------------------------------------------------
/data/others/KT_train0.txt:
--------------------------------------------------------------------------------
1 | student outcome KC 2647 correct Variables 2647 correct Variables 2647 correct Variables 2647 incorrect Variables 2647 correct Variables 2647 correct Variables 2647 correct Variables 2647 correct Variables 2647 correct Variables 2647 incorrect Variables 2647 correct Variables 2647 correct Variables 2647 incorrect Objects 2647 correct Objects 2647 correct Objects 2647 incorrect Objects 2647 correct Objects 2647 correct Objects 2647 correct Objects 2646 incorrect Objects 2646 incorrect Objects 2646 correct Objects 2646 incorrect Objects 2646 correct Objects 2646 correct Objects 2646 incorrect Objects 2646 correct Objects 2646 correct Objects 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect
Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 correct Interfaces 2534 incorrect Interfaces 2534 correct Interfaces 2534 correct Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 correct Interfaces 2534 incorrect Interfaces 2534 incorrect Interfaces 2534 correct Interfaces 2534 incorrect Interfaces 2534 correct Interfaces
--------------------------------------------------------------------------------
/src/main/java/fig/basic/AbstractTMap.java:
--------------------------------------------------------------------------------
1 | package fig.basic;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 |  * Just a dummy class. TODO: move common functionality here.
7 |  *
8 |  * Declares the key storage (keys, num, mapType), locking flag and key
9 |  * functionality fields used by its subclasses.
10 |  */
11 | public abstract class AbstractTMap<T> implements Serializable {
12 |   protected static final long serialVersionUID = 42;
13 |
14 |   /** Hooks for allocating key arrays and canonicalizing keys. */
15 |   public static class Functionality<T> implements Serializable {
16 |     // Generic arrays cannot be created directly; an Object[] viewed as T[] is
17 |     // fine as long as it never escapes as a concrete array type. Override to
18 |     // return a real T[] when that matters.
19 |     @SuppressWarnings("unchecked")
20 |     public T[] createArray(int n) {
21 |       return (T[]) (new Object[n]);
22 |     }
23 |
24 |     public T intern(T x) {
25 |       return x;
26 |     } // Override to get desired behavior, e.g., interning
27 |   }
28 |
29 |   public static class ObjectFunctionality extends Functionality<Object> {
30 |     @Override
31 |     public Object[] createArray(int n) {
32 |       return new Object[n];
33 |     }
34 |   }
35 |   public static Functionality defaultFunctionality = new Functionality();
36 |
37 |   protected static final int growFactor = 2; // How much extra space (times
38 |                                              // size) to give for the capacity
39 |   protected static final int defaultExpectedSize = 2;
40 |   protected static final double loadFactor = 0.75; // For hash table
41 |
42 |   protected enum MapType {
43 |     SORTED_LIST, HASH_TABLE
44 |   }
45 |
46 |   protected MapType mapType;
47 |   protected boolean locked; // Are the keys locked
48 |   protected int num;
49 |   protected T[] keys;
50 |   protected Functionality<T> keyFunc;
51 |   protected int numCollisions; // For debugging
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/fast/data/CVStudent.java:
--------------------------------------------------------------------------------
1 | /**
2 |  * FAST v1.0 08/12/2014
3 |  *
4 |  * This code is intended for research purposes only, not for commercial use.
5 |  * It was originally developed for research and is still being improved.
6 |  * Please email us if you want to keep up with the latest release.
7 |  * We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com) or Jose P. Gonzalez-Brenes (josepablog@gmail.com) about problems in the code or about cooperation.
8 |  * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for the parts of their code on which FAST is based.
9 |  *
10 |  */
11 |
12 | package fast.data;
13 |
14 | import java.util.Vector;
15 |
16 | /**
17 |  * The data points of one student, tagged with the cross-validation fold the
18 |  * student belongs to.
19 |  */
20 | public class CVStudent extends Vector<DataPoint> {
21 |   private static final long serialVersionUID = -7017179401151027439L;
22 |   final int fold;
23 |
24 |   public int getFold() {
25 |     return fold;
26 |   }
27 |
28 |   public CVStudent(int fold) {
29 |     super();
30 |     this.fold = fold;
31 |   }
32 |
33 |   /**
34 |    * Appends a data point. The fold-consistency check below is currently
35 |    * disabled, so no validation is performed.
36 |    */
37 |   @Override
38 |   public boolean add(DataPoint s) {
39 |     // if (s.getFold() != fold) {
40 |     //   if (!this.opts.preDpCurDpFromDifferentSet)
41 |     //     System.out
42 |     //         .println("Warn: Previous datapoint and current Datapoint are from different sets!");
43 |     //   this.opts.preDpCurDpFromDifferentSet = true;
44 |     // }
45 |     // throw new
46 |     // IllegalArgumentException("Multiple occurrences of a student should be in the same fold");
47 |
48 |     return super.add(s);
49 |   }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/fast/evaluation/Sample.java:
--------------------------------------------------------------------------------
1 | /**
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | *
contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | //package edu.uci.jforests.sample; 19 | package fast.evaluation; 20 | 21 | /** 22 | * @author Yasser Ganjisaffar 23 | */ 24 | 25 | public class Sample { 26 | 27 | public int[] indicesInDataset; 28 | public double[] weights; 29 | public double[] targets; 30 | public int size; 31 | // Only used in sub samples 32 | public int[] indicesInParentSample; 33 | 34 | /** 35 | * @author Yun Huang changed the code. For getting AUC. 
36 | * @param targets 37 | * : the labels, one per instance 38 | */ 39 | public Sample(double[] targets) { 40 | this.targets = targets; 41 | this.size = targets.length; 42 | } 43 | 44 | } -------------------------------------------------------------------------------- /src/main/assembly/zip.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | release 6 | 7 | zip 8 | tar.gz 9 | 10 | false 11 | 12 | 13 | ${project.basedir}/data 14 | data 15 | 16 | **/* 17 | 18 | 19 | 20 | ${project.basedir}/src 21 | src 22 | 23 | **/* 24 | 25 | 26 | 27 | ${project.basedir} 28 | 29 | 30 | README.md 31 | LICENSE 32 | 33 | 34 | 35 | ${project.build.directory} 36 | 37 | 38 | *-final.jar 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/MapFactory.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.util.*; 4 | import java.io.Serializable; 5 | 6 | /** 7 | * The MapFactory is a mechanism for specifying what kind of map is to be used 8 | * by some object. For example, if you want a Counter which is backed by an 9 | * IdentityHashMap instead of the defaul HashMap, you can pass in an 10 | * IdentityHashMapFactory. 
11 | * 12 | * @author Dan Klein 13 | */ 14 | 15 | public abstract class MapFactory { 16 | private static final long serialVersionUID = 5724671156522771657L; 17 | public static class HashMapFactory extends MapFactory { 18 | private static final long serialVersionUID = 5724671156522771657L; 19 | public Map buildMap() { 20 | return new HashMap(); 21 | } 22 | } 23 | 24 | public static class IdentityHashMapFactory extends MapFactory { 25 | private static final long serialVersionUID = 5724671156522771657L; 26 | public Map buildMap() { 27 | return new IdentityHashMap(); 28 | } 29 | } 30 | 31 | public static class TreeMapFactory extends MapFactory { 32 | private static final long serialVersionUID = 5724671156522771657L; 33 | public Map buildMap() { 34 | return new TreeMap(); 35 | } 36 | } 37 | 38 | public static class WeakHashMapFactory extends MapFactory { 39 | private static final long serialVersionUID = 5724671156522771657L; 40 | public Map buildMap() { 41 | return new WeakHashMap(); 42 | } 43 | } 44 | 45 | public abstract Map buildMap(); 46 | } 47 | 48 | -------------------------------------------------------------------------------- /data/others/FAST+item_SplitBySeq_test0.txt: -------------------------------------------------------------------------------- 1 | student outcome fold KC features_jArrayList1 features_jArrayList2 features_jArrayList3 features_jArrayList4 features_jArrayList5 features_jIncrement features_jMathFuc1 features_jMathFuc2 features_jOperator1 features_jVariable1 features_jVariables1 features_jVariables2 features_jVariables3 features_jVariables4 features_jVariables5 2812 correct -1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2812 correct -1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2812 correct -1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 2812 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2812 correct 1 Variables NULL NULL NULL NULL 
NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2252 correct -1 Arithmetic_Operations NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL 2252 correct -1 Arithmetic_Operations NULL NULL NULL NULL NULL 0 0 0 1 0 NULL NULL NULL NULL NULL 2252 correct -1 Arithmetic_Operations NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL 2252 incorrect 1 Arithmetic_Operations NULL NULL NULL NULL NULL 0 0 0 0 1 NULL NULL NULL NULL NULL 2252 incorrect 1 Arithmetic_Operations NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL 2528 correct -1 ArrayList 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2528 correct -1 ArrayList 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2528 incorrect -1 ArrayList 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2528 correct 1 ArrayList 0 0 0 1 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2528 correct 1 ArrayList 0 0 0 0 1 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL -------------------------------------------------------------------------------- /src/main/java/fig/basic/StopWatchSet.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.util.*; 4 | import java.lang.ThreadLocal; 5 | 6 | /** 7 | * 4/2/09: StopWatchSet should be re-entrant (can call begin("foo") twice) and thread-safe. 8 | */ 9 | public class StopWatchSet { 10 | // For measuring time of certain types of events. 11 | // Shared across all threads. 
12 | private static Map stopWatches = new LinkedHashMap(); 13 | 14 | // A stack of stop-watches (one per thread) 15 | private static ThreadLocal>> lastStopWatches = new ThreadLocal() { 16 | protected LinkedList> initialValue() { return new LinkedList(); } 17 | }; 18 | 19 | public synchronized static StopWatch getWatch(String s) { 20 | return MapUtils.getMut(stopWatches, s, new StopWatch()); 21 | } 22 | 23 | public static void begin(String s) { 24 | // Create a new stop watch for reentrance and thread safety 25 | lastStopWatches.get().addLast(new Pair(s, new StopWatch().start())); 26 | } 27 | public static void end() { 28 | Pair pair = lastStopWatches.get().removeLast(); 29 | pair.getSecond().stop(); 30 | // Add it 31 | synchronized(stopWatches) { 32 | getWatch(pair.getFirst()).add(pair.getSecond()); 33 | } 34 | } 35 | 36 | public synchronized static OrderedStringMap getStats() { 37 | OrderedStringMap map = new OrderedStringMap(); 38 | for(String key : stopWatches.keySet()) { 39 | StopWatch watch = getWatch(key); 40 | map.put(key, watch + " (" + new StopWatch(watch.n == 0 ? 0 : watch.ms/watch.n) + " x " + watch.n + ")"); 41 | } 42 | return map; 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/FullStatFig.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * For keeping track of statistics. 7 | * Keeps all the data around (can be memory expensive). 
8 | */ 9 | public class FullStatFig extends BigStatFig { 10 | public FullStatFig() { } 11 | public FullStatFig(Iterable c) { 12 | for(double x : c) add(x); 13 | } 14 | public void add(double x) { 15 | super.add(x); 16 | data.add(x); 17 | } 18 | 19 | public double entropy() { 20 | double e = 0; 21 | for(double x : data) { 22 | x /= sum; 23 | if(x > 0) 24 | e += -x * Math.log(x); 25 | } 26 | return e; 27 | } 28 | 29 | public double variance() { 30 | double v = 0; 31 | double m = mean(); 32 | for(double x : data) 33 | v += (x-m)*(x-m); 34 | return v/n; 35 | } 36 | public double stddev() { return Math.sqrt(variance()); } 37 | 38 | public List getData() { return data; } 39 | 40 | // Return for each lag, the correlation 41 | public double[] computeAutocorrelation(int maxLag) { 42 | double mean = mean(); 43 | double stddev = stddev(); 44 | double[] normData = new double[n]; 45 | for(int i = 0; i < n; i++) 46 | normData[i] = (data.get(i) - mean) / stddev; 47 | double[] autocorrelations = new double[maxLag+1]; 48 | for(int lag = 0; lag <= maxLag; lag++) { 49 | double sum = 0; 50 | int count = 0; 51 | for(int i = 0; i+lag < n; i++) { 52 | sum += normData[i] * normData[i+lag]; 53 | count++; 54 | } 55 | autocorrelations[lag] = sum / count; 56 | } 57 | return autocorrelations; 58 | } 59 | 60 | public String toString() { 61 | return Fmt.D(min) + "/ << " + Fmt.D(mean()) + "~" + Fmt.D(stddev()) + " >> /" + Fmt.D(max) + " (" + n + ")"; 62 | } 63 | 64 | private ArrayList data = new ArrayList(); 65 | } 66 | -------------------------------------------------------------------------------- /data/others/FAST+item_SplitBySeq_train0.txt: -------------------------------------------------------------------------------- 1 | student outcome fold KC features_jArrayList1 features_jArrayList2 features_jArrayList3 features_jArrayList4 features_jArrayList5 features_jIncrement features_jMathFuc1 features_jMathFuc2 features_jOperator1 features_jVariable1 features_jVariables1 features_jVariables2 
features_jVariables3 features_jVariables4 features_jVariables5 2236 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2236 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 2236 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2236 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2236 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2812 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2812 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2812 correct 1 Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 2252 correct 1 Arithmetic_Operations NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL 2252 correct 1 Arithmetic_Operations NULL NULL NULL NULL NULL 0 0 0 1 0 NULL NULL NULL NULL NULL 2252 correct 1 Arithmetic_Operations NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL 2528 correct 1 ArrayList 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2528 correct 1 ArrayList 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2528 incorrect 1 ArrayList 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2522 correct 1 ArrayList 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2522 correct 1 ArrayList 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2522 incorrect 1 ArrayList 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2522 correct 1 ArrayList 0 0 0 1 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2522 correct 1 ArrayList 0 0 0 0 1 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL -------------------------------------------------------------------------------- /src/main/java/fig/basic/OrderedStringMap.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.io.*; 4 | import 
java.util.*; 5 | 6 | /** 7 | * An OrderedMap for mapping strings to strings. 8 | */ 9 | public class OrderedStringMap extends OrderedMap { 10 | public OrderedStringMap() { } 11 | public OrderedStringMap(OrderedStringMap map) { 12 | for(String key : map.keys()) 13 | put(key, map.get(key)); 14 | } 15 | 16 | public static OrderedStringMap fromFile(String path) throws IOException { 17 | return fromFile(new File(path)); 18 | } 19 | public static OrderedStringMap fromFile(File path) throws IOException { 20 | OrderedStringMap map = new OrderedStringMap(); 21 | map.read(path); 22 | return map; 23 | } 24 | 25 | public void put(String key, Object val) { 26 | super.put(key, StrUtils.toString(val)); 27 | } 28 | 29 | public void read(String path) throws IOException { read(new File(path)); } 30 | public void read(File path) throws IOException { 31 | BufferedReader in = IOUtils.openIn(path); 32 | read(in); 33 | in.close(); 34 | } 35 | public void read(BufferedReader r) throws IOException { 36 | clear(); 37 | String line; 38 | while((line = r.readLine()) != null) { 39 | StringTokenizer st = new StringTokenizer(line, "\t"); 40 | if(!st.hasMoreTokens()) continue; // Skip blank lines 41 | String key = st.nextToken(); 42 | String val = st.hasMoreTokens() ? 
st.nextToken() : null; 43 | put(key, val); 44 | } 45 | } 46 | 47 | public boolean readEasy(String path) { 48 | if(StrUtils.isEmpty(path)) return false; 49 | return readEasy(new File(path)); 50 | } 51 | public boolean readEasy(File path) { 52 | if(path == null) return false; 53 | try { read(path); return true; } 54 | catch(IOException e) { return false; } 55 | } 56 | 57 | public void readHard(String path) { readHard(new File(path)); } 58 | public void readHard(File path) { 59 | try { read(path); } 60 | catch(IOException e) { throw new RuntimeException(e); } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/fast/evaluation/Degeneracy.java: -------------------------------------------------------------------------------- 1 | package fast.evaluation; 2 | 3 | public class Degeneracy { //cur Hmm per process 4 | 5 | //main 6 | public Double guessPlusSlipFeatureOff = -1.0; 7 | public Double nbDegKcsBasedOnGuessPlusSlipFeatureOff = -1.0;// either 0 or 1; g+s feature off whether >= 1(1) or not (0) 8 | public String degeneracyJudgementInequality = "be"; 9 | 10 | //secondary 11 | public Double guessPlusSlipFeatureOn = -1.0; //if use null for KT, this is underfined. 
just consider turn on the first feature (and bias) 12 | public Double guessPlusSlipAvgPerDP = -1.0; //(g+s on all dp from train and test) 13 | public Double pctDegDps = -1.0; //%dp that have g+s>(=)1 on both train and test set 14 | public Double pctDecProbKnown = -1.0;//on test 15 | public Double pctDecProbCorrect = -1.0; //on test 16 | public Double minGuessPlusSlipPerDpOnTrain = -1.0; 17 | public Double minGuessPlusSlipPerDpOnTest = -1.0; 18 | /* 19 | * 0 # degenerated cases by datapoints in train; 20 | * 1 # datapoints in train; 21 | * 2 # degenerated cases by datapoints in test; 22 | * 3 # datapoints in test; 23 | * 4 # sum of guess+slip across datapoints in train 24 | * 5 # sum of guess+slip across datapoints in test 25 | * 6 % decrease pKnow per dp on test 26 | * 7 % decrease pCorrect per dp on test 27 | * 8 minimum g+s per dp on train 28 | * 9 minimum g+s per dp on test 29 | */ 30 | public double[] degeneracyJudgementsAcrossDataPoints = new double[10]; 31 | 32 | //public double avgPerKcFeatureOffGuessPlusSlip = -1.0; //avg per kc 33 | //public double avgPerKcFeatureOnGuessPlusSlip = -1.0; //avg per kc 34 | ////public double avgPerDpGuessPlusSlip = -1.0; //g+s on all dp from train and test on all kcs 35 | //public double avgPerKcGuessPlusSlipAvgPerDP = -1.0; //avg (g+s on all dp from train and test) per kc 36 | //public double overallNbDegKcs = -1.0; //overall #kcs that have g+s>1 for either train or test set 37 | ////public double overallPctDegDps = -1.0; //overall %datapoints that have g+s>1 on both train and test set 38 | //public double avgPerKcPctDegDps = -1.0; 39 | 40 | } 41 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | 6 | fast 7 | fast 8 | 2.1.1 9 | jar 10 | Feature Aware Student knowledge Tracing 11 | http://maven.apache.org 12 | 13 | 14 | 15 | Jahmm 16 | Jahmm HMM library repository 17 | 
http://jahmm.googlecode.com/svn/repo 18 | 19 | 20 | 21 | 22 | 23 | 24 | edu.berkeley.nlp 25 | berkeleyparser 26 | r32 27 | 28 | 29 | 30 | be.ac.ulg.montefiore.run.jahmm 31 | jahmm 32 | 0.6.2 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | data 41 | ../data/ 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | maven-assembly-plugin 53 | 54 | 55 | 56 | fast.experimenter.Runner 57 | 58 | 59 | 60 | jar-with-dependencies 61 | 62 | 63 | src/main/assembly/jar.xml 64 | src/main/assembly/zip.xml 65 | 66 | 67 | 68 | 69 | make-assembly 70 | package 71 | 72 | single 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | FAST: Feature-Aware Student knowledge Tracing 3 | ============================================= 4 | 5 | This is the repository of FAST, an efficient toolkit for modeling time-changing student performance ([González-Brenes, Huang, Brusilovsky et al, 2014] (http://educationaldatamining.org/EDM2014/uploads/procs2014/long%20papers/84_EDM-2014-Full.pdf)). FAST is alterantive to the [BNT-SM toolkit] (http://www.cs.cmu.edu/~listen/BNT-SM/), a toolkit that requires the researcher to design a different different Bayes Net for each feature set they want to prototype. 6 | The FAST toolkit is up to 300x faster than BNT-SM, and much simpler to use. 7 | 8 | We presented the model in the 7th International Conference on Educational Data Mining (2014) (see [slides] (http://www.cs.cmu.edu/~joseg/files/fast_presentation.pdf) ), where it was selected as one the top 5 paper submissions. 9 | 10 | 11 | 12 | Technical Details 13 | ----------------- 14 | FAST learns per parameters for each skill using an HMM with Features ([Berg-Kirpatrick et al, 2010] (http://www.cs.berkeley.edu/~tberg/papers/naaclhlt2010.pdf)). 15 | 16 | 17 | 18 | 19 | Running FAST 20 | ============ 21 | 22 | Quick Start 23 | ------------ 24 | 25 | 1. 
Download the latest release [here] (https://github.com/ml-smores/fast/releases). 26 | 2. Decompress the file. It includes sample data for getting you started quickly. 27 | 3. Open a terminal and type (you need to be in the same directory as the fast-2.1.1-final.jar file in your console, which can be achieved by the cd command): 28 | ``` java -jar fast-2.1.1-final.jar ++data/IRT_exp/FAST+IRT1.conf ```` 29 | 30 | Congratulations! You just trained a student model (with IRT features) using state of the art technology. 31 | 32 | 33 | Please see the [Wiki](https://github.com/ml-smores/fast/wiki/) for more information. 34 | 35 | Please cite our work (and provide the link https://github.com/ml-smores/fast) if you use our tool in your published papers: González-Brenes, J. P., Huang, Y., & Brusilovsky, P. (2014). General features in knowledge tracing: applications to multiple subskills, temporal item response theory, and expert knowledge. In Proc. 7th Int. Conf. on Educational Data Mining (pp. 84-91). 36 | 37 | Contact us 38 | ========== 39 | We would love to hear your feedback. Please [email us] (mailto:ml-smores@googlegroups.com)! 
40 | 41 | Thanks, 42 | Yun, 43 | Jose, 44 | and Peter 45 | -------------------------------------------------------------------------------- /src/main/java/fast/evaluation/Metrics.java: -------------------------------------------------------------------------------- 1 | package fast.evaluation; 2 | 3 | //import java.text.DateFormat; 4 | import java.text.DecimalFormat; 5 | //import java.text.SimpleDateFormat; 6 | import java.util.HashMap; 7 | import java.util.Locale; 8 | import fast.common.Bijection; 9 | //import fast.evaluation.EvaluationGeneral.Metrics; 10 | 11 | public class Metrics{ 12 | private DecimalFormat formatter; 13 | { 14 | formatter = (DecimalFormat) DecimalFormat.getInstance(Locale.US); 15 | formatter.applyPattern("#.###"); 16 | } 17 | //public static DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss"); 18 | private String modelName = ""; 19 | private HashMap metricNameToValue = new HashMap(); 20 | private Bijection metricNames = new Bijection(); //Need to keep the ordering 21 | 22 | public Metrics(String name){ 23 | modelName = name; 24 | } 25 | 26 | public Metrics(){ 27 | } 28 | 29 | public Bijection getMetricNames(){ 30 | return metricNames; 31 | } 32 | 33 | public double getMetricValue(String name){ 34 | return metricNameToValue.get(name); 35 | } 36 | 37 | 38 | public void setMetricValue(String metricName, Double metricValue){ 39 | metricNames.put(metricName); 40 | metricNameToValue.put(metricName, metricValue); 41 | } 42 | 43 | public void copyMetrics(Metrics eval){ 44 | for (int i = 0; i < eval.metricNames.getSize(); i ++){ 45 | String metricName = eval.metricNames.get(i); 46 | double metricValue = eval.metricNameToValue.get(metricName); 47 | metricNames.put(metricName); 48 | metricNameToValue.put(metricName, metricValue); 49 | } 50 | } 51 | 52 | public String getHeader(String delim){ 53 | String header = "Name" + delim; 54 | for (int i = 0; i < metricNames.getSize(); i ++){ 55 | header += metricNames.get(i) + delim; 56 | } 57 | 
return header; 58 | } 59 | 60 | public String getEvaluationStr(Metrics eval, String delim){ 61 | String evaluationStr = eval.modelName + delim; 62 | for (int i = 0; i < metricNames.getSize(); i ++){ 63 | Double value = metricNameToValue.get(metricNames.get(i)); 64 | String value_str = Double.isNaN(value)? "NaN" : formatter.format(value); 65 | evaluationStr += value_str + delim; 66 | } 67 | return evaluationStr; 68 | } 69 | 70 | public String getEvaluationStr(String delimiter){ 71 | String evaluationStr = getEvaluationStr(this, delimiter); 72 | return evaluationStr; 73 | } 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/Fmt.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.text.*; 4 | 5 | /** 6 | * Formatting class. I'm really lazy. 7 | * D() is a family of default functions for formatting various types of objects. 8 | */ 9 | public class Fmt { 10 | public static String D(double x) { 11 | if(Math.abs(x - (int)x) < 1e-40) // An integer (probably) 12 | return ""+(int)x; 13 | if(Math.abs(x) < 1e-3) // Scientific notation (close to 0) 14 | return String.format("%.2e", x); 15 | return String.format("%.3f", x); 16 | } 17 | public static String D(boolean[] x) { return StrUtils.join(x); } 18 | public static String D(int[] x) { return StrUtils.join(x); } 19 | public static String D(double[] x) { return D(x, " "); } 20 | public static String D(double[] xs, String delim) { 21 | StringBuilder sb = new StringBuilder(); 22 | for(double x : xs) { 23 | if(sb.length() > 0) sb.append(delim); 24 | sb.append(Fmt.D(x)); 25 | } 26 | return sb.toString(); 27 | } 28 | // Print out only first N 29 | public static String D(double[] x, int firstN) { 30 | if(firstN >= x.length) return D(x); 31 | return D(ListUtils.subArray(x, 0, firstN)) + " ...("+(x.length-firstN) + " more)"; 32 | } 33 | public static String D(double[][] x) { return D(x, " 
"); } 34 | public static String D(double[][] xs, String delim) { 35 | StringBuilder sb = new StringBuilder(); 36 | for(double[] x : xs) { 37 | if(sb.length() > 0) sb.append(delim); 38 | sb.append(Fmt.D(x)); 39 | } 40 | return sb.toString(); 41 | } 42 | 43 | public static String D(TDoubleMap map) { return D(map, 20); } 44 | public static String D(TDoubleMap map, int numTop) { 45 | return MapUtils.topNToString(map, numTop); 46 | } 47 | 48 | public static String D(Object o) { 49 | if(o instanceof double[]) return Fmt.D((double[])o); 50 | if(o instanceof double[][]) return Fmt.D((double[][])o); 51 | if(o instanceof double[][][]) return Fmt.D((double[][][])o); 52 | throw Exceptions.unknownCase; 53 | } 54 | 55 | public static String bytesToString(long b) { 56 | double gb = (double)b / (1024*1024*1024); 57 | if(gb >= 1) return gb >= 10 ? (int)gb+"G" : NumUtils.round(gb, 1)+"G"; 58 | double mb = (double)b / (1024*1024); 59 | if(mb >= 1) return mb >= 10 ? (int)mb+"M" : NumUtils.round(mb, 1)+"M"; 60 | double kb = (double)b / (1024); 61 | if(kb >= 1) return kb >= 10 ? 
(int)kb+"K" : NumUtils.round(kb, 1)+"K"; 62 | return b+""; 63 | } 64 | public static String formatEasyDateTime(long t) { 65 | return new SimpleDateFormat("MM/dd HH:mm").format(t); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/fast/common/Bijection.java: -------------------------------------------------------------------------------- 1 | package fast.common; 2 | 3 | import java.io.Serializable; 4 | import java.util.Collection; 5 | import java.util.Collections; 6 | import java.util.HashMap; 7 | import java.util.Iterator; 8 | import java.util.Vector; 9 | 10 | public class Bijection implements Serializable { 11 | private static final long serialVersionUID = 6526310928919842268L; 12 | // hy: name, posterior index 13 | final private HashMap keys; 14 | // hy: posterior index, name 15 | final private Vector values; 16 | private int size; 17 | 18 | public int getSize() { 19 | return size; 20 | } 21 | 22 | public boolean contains(String key) { 23 | return keys.containsKey(key); 24 | } 25 | 26 | @Override 27 | public String toString() { 28 | StringBuilder sb = new StringBuilder("{"); 29 | Iterator v = values.iterator(); 30 | int i = 0; 31 | while (v.hasNext()) { 32 | String s = v.next(); 33 | sb.append(s); 34 | sb.append("="); 35 | sb.append(i++); 36 | if (v.hasNext()) 37 | sb.append(", "); 38 | } 39 | sb.append("}"); 40 | return sb.toString(); 41 | } 42 | 43 | // FIXME: should return an inmutable collection 44 | public Collection values() { 45 | return keys.values(); // return posterior index 46 | } 47 | 48 | // FIXME: should return an inmutable collection 49 | public Collection keys() { 50 | return keys.keySet(); 51 | } 52 | 53 | public Bijection() { 54 | keys = new HashMap(); 55 | values = new Vector(); 56 | size = 0; 57 | } 58 | 59 | // hy: 60 | public Bijection(Bijection b) { 61 | size = 0; 62 | keys = new HashMap(); 63 | values = new Vector(); 64 | for (int i = 0; i < b.getSize(); i++) { 65 | 
this.put(b.get(i)); 66 | } 67 | } 68 | 69 | public Bijection(String[] keys) { 70 | this(); 71 | for (String k : keys) 72 | this.put(k); 73 | 74 | } 75 | 76 | public Integer get(String key) { 77 | return keys.get(key); 78 | } 79 | 80 | public String get(Integer value) { 81 | return values.get(value); 82 | } 83 | 84 | public Integer put(String key) { 85 | Integer value = keys.get(key); 86 | if (value == null) { 87 | value = size++; 88 | keys.put(key, value); 89 | values.add(value, key); 90 | } 91 | return value; 92 | } 93 | 94 | static final Integer max(Collection coll) { 95 | if (coll.isEmpty()) 96 | return new Integer(-1); 97 | else 98 | return Collections.max(coll); 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /src/main/java/fast/evaluation/TestSummary.java: -------------------------------------------------------------------------------- 1 | package fast.evaluation; 2 | 3 | import java.util.ArrayList; 4 | 5 | public class TestSummary{ 6 | public ArrayList actualLabels = new ArrayList(); 7 | public ArrayList predLabels = new ArrayList(); 8 | public ArrayList predProbs = new ArrayList(); 9 | public ArrayList priorProbKnowns = new ArrayList(); 10 | public ArrayList posteriorProbKnowns = new ArrayList(); 11 | public Degeneracy degeneracy = new Degeneracy(); //also contain train info 12 | public Mastery mastery = new Mastery(); 13 | public Metrics eval = new Metrics(); 14 | 15 | 16 | public void update(ArrayList actualLabels, ArrayList predLabels, ArrayList predProbs){ 17 | this.actualLabels = actualLabels; 18 | this.predLabels = predLabels; 19 | this.predProbs = predProbs; 20 | } 21 | 22 | public void update(ArrayList actualLabels, ArrayList predLabels, ArrayList predProbs, 23 | ArrayList priorProbKnowns, ArrayList posteriorProbKnowns){ 24 | this.actualLabels = actualLabels; 25 | this.predLabels = predLabels; 26 | this.predProbs = predProbs; 27 | this.priorProbKnowns = priorProbKnowns; 28 | 
this.posteriorProbKnowns = posteriorProbKnowns; 29 | } 30 | 31 | public void update(ArrayList actualLabels, ArrayList predLabels, ArrayList predProbs, Degeneracy degeneracy, Mastery mastery){ 32 | this.actualLabels.addAll(actualLabels); 33 | this.predLabels.addAll(predLabels); 34 | this.predProbs.addAll(predProbs); 35 | this.degeneracy = degeneracy; 36 | this.mastery = mastery; 37 | } 38 | 39 | // public void update(TestSummary testSummary){ 40 | // this.actualLabels.addAll(testSummary.actualLabels); 41 | // this.predLabels.addAll(testSummary.predLabels); 42 | // this.predProbs.addAll(testSummary.predProbs); 43 | // this.priorProbKnowns = testSummary.priorProbKnowns; 44 | // this.posteriorProbKnowns = testSummary.posteriorProbKnowns; 45 | // this.degeneracy = testSummary.degeneracy; //also contain train info 46 | // this.mastery = testSummary.mastery; 47 | // this.eval = testSummary.eval; 48 | // } 49 | 50 | // public Metrics getEval(){ 51 | // return eval; 52 | // } 53 | // 54 | // public Mastery getMastery(){ 55 | // return mastery; 56 | // } 57 | // 58 | // public Degeneracy getDegeneracy(){ 59 | // return degeneracy; 60 | // } 61 | // 62 | // public void setEval(Metrics eval){ 63 | // this.eval = eval; 64 | // } 65 | // 66 | // public ArrayList getActualLabels(){ 67 | // return actualLabels; 68 | // } 69 | // 70 | // public ArrayList getPredLabels(){ 71 | // return predLabels; 72 | // } 73 | // 74 | // public ArrayList getPredProbs(){ 75 | // return predProbs; 76 | // } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/StopWatch.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | /** 4 | * Simple class for measuring elapsed time. 
5 | */ 6 | public class StopWatch 7 | { 8 | public StopWatch() 9 | { 10 | } 11 | 12 | public StopWatch(long ms) 13 | { 14 | startTime = 0; 15 | endTime = ms; 16 | this.ms = ms; 17 | } 18 | 19 | public void reset() 20 | { 21 | ms = 0; 22 | isRunning = false; 23 | } 24 | 25 | public StopWatch start() 26 | { 27 | assert !isRunning; 28 | isRunning = true; 29 | startTime = System.currentTimeMillis(); 30 | 31 | return this; 32 | } 33 | 34 | public StopWatch stop() 35 | { 36 | assert isRunning; 37 | endTime = System.currentTimeMillis(); 38 | isRunning = false; 39 | ms = endTime - startTime; 40 | n = 1; 41 | return this; 42 | } 43 | 44 | public StopWatch accumStop() 45 | { 46 | // Stop and accumulate time 47 | assert isRunning; 48 | endTime = System.currentTimeMillis(); 49 | isRunning = false; 50 | ms += endTime - startTime; 51 | n++; 52 | return this; 53 | } 54 | 55 | public void add(StopWatch w) { 56 | assert !isRunning && !w.isRunning; 57 | ms += w.ms; 58 | n += w.n; 59 | } 60 | 61 | public long getCurrTimeLong() 62 | { 63 | return ms + (isRunning() ? 
System.currentTimeMillis() - startTime : 0); 64 | } 65 | 66 | @Override 67 | public String toString() 68 | { 69 | long msCopy = ms; 70 | long m = msCopy / 60000; 71 | msCopy %= 60000; 72 | long h = m / 60; 73 | m %= 60; 74 | long d = h / 24; 75 | h %= 24; 76 | long y = d / 365; 77 | d %= 365; 78 | long s = msCopy / 1000; 79 | 80 | StringBuilder sb = new StringBuilder(); 81 | 82 | if (y > 0) 83 | { 84 | sb.append(y); 85 | sb.append('y'); 86 | sb.append(d); 87 | sb.append('d'); 88 | } 89 | if (d > 0) 90 | { 91 | sb.append(d); 92 | sb.append('d'); 93 | sb.append(h); 94 | sb.append('h'); 95 | } 96 | else if (h > 0) 97 | { 98 | sb.append(h); 99 | sb.append('h'); 100 | sb.append(m); 101 | sb.append('m'); 102 | } 103 | else if (m > 0) 104 | { 105 | sb.append(m); 106 | sb.append('m'); 107 | sb.append(s); 108 | sb.append('s'); 109 | } 110 | else if (s > 9) 111 | { 112 | sb.append(s); 113 | sb.append('s'); 114 | } 115 | else if (s > 0) 116 | { 117 | sb.append((ms / 100) / 10.0); 118 | sb.append('s'); 119 | } 120 | else 121 | { 122 | sb.append(ms / 1000.0); 123 | sb.append('s'); 124 | } 125 | return sb.toString(); 126 | } 127 | 128 | public long startTime, endTime, ms; 129 | 130 | public int n; 131 | 132 | private boolean isRunning = false; 133 | 134 | public boolean isRunning() 135 | { 136 | return isRunning; 137 | } 138 | 139 | // Use StopWatchSet instead 140 | @Deprecated 141 | public static void start(String s) 142 | { 143 | } 144 | 145 | @Deprecated 146 | public static void accumStop(String s) 147 | { 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/Interner.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.util.Map; 4 | import java.util.Set; 5 | 6 | /** 7 | * Canonicalizes objects. 
Given an object, the intern() method returns a 8 | * canonical representation of that object, that is, an object which equals() 9 | * the input. Furthermore, given two objects x and y, it is guaranteed that if 10 | * x.equals(y), then intern(x) == intern(y). The default behavior is that the 11 | * interner is backed by a HashMap and the canonical version of an object x is 12 | * simply the first object that equals(x) which is passed to the interner. In 13 | * this case, it can be true that intern(x) == x. The backing map can be 14 | * specified by passing a MapFactory on construction (though the only standard 15 | * option which makes much sense is the WeakHashMap, which is slower than a 16 | * HashMap, but which allows unneeded keys to be reclaimed by the garbage 17 | * collector). The source of canonical elements can be changed by specifying an 18 | * Interner.Factory on construction. 19 | * 20 | * @author Dan Klein 21 | */ 22 | public class Interner { 23 | /** 24 | * The source of canonical objects when a non-interned object is presented to 25 | * the interner. The default implementation is an identity map. 26 | */ 27 | public static interface CanonicalFactory { 28 | T build(T object); 29 | } 30 | 31 | static class IdentityCanonicalFactory implements CanonicalFactory { 32 | public T build(T object) { 33 | return object; 34 | } 35 | } 36 | 37 | Map canonicalMap; 38 | CanonicalFactory cf; 39 | 40 | public Set getCanonicalElements() { 41 | return canonicalMap.keySet(); 42 | } 43 | public boolean isCanonical(T x) { 44 | return canonicalMap.containsKey(x); 45 | } 46 | 47 | /** 48 | * Returns a canonical representation of the given object. If the object has 49 | * no canonical representation, one is built using the interner's 50 | * CanonicalFactory. The default is that new objects will be their own 51 | * canonical instances. 
52 | * 53 | * @param object 54 | * @return a canonical representation of that object 55 | */ 56 | public T intern(T object) { 57 | T canonical = canonicalMap.get(object); 58 | if (canonical == null) { 59 | canonical = cf.build(object); 60 | canonicalMap.put(canonical, canonical); 61 | } 62 | return canonical; 63 | } 64 | 65 | // Like intern, but don't save object if it's not interned already. 66 | public T getCanonical(T object) { 67 | T canonical = canonicalMap.get(object); 68 | return canonical == null ? object : canonical; 69 | } 70 | 71 | public int size( ) { 72 | return canonicalMap.size(); 73 | } 74 | 75 | public Interner() { 76 | this(new MapFactory.HashMapFactory(), new IdentityCanonicalFactory()); 77 | } 78 | 79 | public Interner(MapFactory mf) { 80 | this(mf, new IdentityCanonicalFactory()); 81 | } 82 | 83 | public Interner(CanonicalFactory f) { 84 | this(new MapFactory.HashMapFactory(), f); 85 | } 86 | 87 | public Interner(MapFactory mf, CanonicalFactory cf) { 88 | canonicalMap = mf.buildMap(); 89 | this.cf = cf; 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /data/IRT_exp/test0.csv: -------------------------------------------------------------------------------- 1 | KCs,student,problem,step,outcome,fold Nested_Loops,2528,jNested1,1481,incorrect,-1 Nested_Loops,2528,jNested2,1484,incorrect,-1 Nested_Loops,2528,jNested3,1489,incorrect,1 Nested_Loops,2647,jNested1,3709,correct,-1 Nested_Loops,2647,jNested2,3710,incorrect,1 Nested_Loops,2649,jNested1,4180,correct,-1 Nested_Loops,2649,jNested2,4181,correct,1 Nested_Loops,2650,jNested1,4371,correct,-1 Nested_Loops,2650,jNested2,4374,incorrect,1 Nested_Loops,2666,jNested1,5821,correct,-1 Nested_Loops,2666,jNested2,5822,incorrect,-1 Nested_Loops,2666,jNested3,5823,incorrect,1 Nested_Loops,2667,jNested1,6045,correct,-1 Nested_Loops,2667,jNested2,6048,incorrect,1 Nested_Loops,2669,jNested1,6736,correct,-1 Nested_Loops,2669,jNested2,6737,correct,-1 
Nested_Loops,2669,jNested3,6738,correct,1 Nested_Loops,2671,jNested1,6949,correct,-1 Nested_Loops,2671,jNested3,6952,incorrect,-1 Nested_Loops,2671,jNested2,6950,correct,1 Nested_Loops,2675,jNested1,7635,incorrect,-1 Nested_Loops,2675,jNested2,7637,incorrect,1 Nested_Loops,2761,jNested1,8213,correct,-1 Nested_Loops,2761,jNested2,8214,incorrect,-1 Nested_Loops,2761,jNested3,8222,correct,1 Nested_Loops,2763,jNested1,8621,correct,-1 Nested_Loops,2763,jNested2,8622,incorrect,1 Nested_Loops,2766,jNested1,9266,correct,-1 Nested_Loops,2766,jNested2,9267,correct,-1 Nested_Loops,2766,jNested3,9268,correct,1 Nested_Loops,2787,jNested1,12071,incorrect,-1 Nested_Loops,2787,jNested2,12073,correct,-1 Nested_Loops,2787,jNested3,12074,incorrect,1 Nested_Loops,2792,jNested1,12832,incorrect,-1 Nested_Loops,2792,jNested2,12835,incorrect,-1 Nested_Loops,2792,jNested3,12839,incorrect,1 Nested_Loops,2812,jNested1,15488,correct,-1 Nested_Loops,2812,jNested2,15489,incorrect,-1 Nested_Loops,2812,jNested3,15494,correct,1 Nested_Loops,2986,jNested1,15756,correct,-1 Nested_Loops,2986,jNested2,15757,incorrect,-1 Nested_Loops,2986,jNested3,15770,correct,1 Nested_Loops,2992,jNested1,17033,correct,-1 Nested_Loops,2992,jNested2,17034,incorrect,-1 Nested_Loops,2992,jNested3,17035,incorrect,1 Nested_Loops,3000,jNested1,18071,correct,-1 Nested_Loops,3000,jNested2,18072,incorrect,-1 Nested_Loops,3000,jNested3,18074,incorrect,1 Nested_Loops,3003,jNested1,18374,correct,-1 Nested_Loops,3003,jNested2,18375,incorrect,-1 Nested_Loops,3003,jNested3,18378,incorrect,1 Nested_Loops,3006,jNested1,18476,correct,-1 Nested_Loops,3006,jNested2,18477,incorrect,-1 Nested_Loops,3006,jNested3,18479,incorrect,1 Nested_Loops,3007,jNested1,18924,incorrect,-1 Nested_Loops,3007,jNested2,18930,incorrect,-1 Nested_Loops,3007,jNested3,18938,incorrect,1 Nested_Loops,3011,jNested1,19372,correct,-1 Nested_Loops,3011,jNested2,19375,incorrect,-1 Nested_Loops,3011,jNested3,19383,incorrect,1 Nested_Loops,3018,jNested1,20011,correct,-1 
Nested_Loops,3018,jNested2,20014,correct,-1 Nested_Loops,3018,jNested3,20019,incorrect,1 Nested_Loops,3019,jNested1,20298,correct,-1 Nested_Loops,3019,jNested2,20300,incorrect,-1 Nested_Loops,3019,jNested3,20306,incorrect,1 Nested_Loops,3020,jNested1,20577,incorrect,-1 Nested_Loops,3020,jNested2,20580,incorrect,-1 Nested_Loops,3020,jNested3,20583,incorrect,1 Nested_Loops,3021,jNested1,20765,correct,-1 Nested_Loops,3021,jNested2,20766,incorrect,1 Nested_Loops,3022,jNested1,20858,correct,-1 Nested_Loops,3022,jNested2,20859,correct,1 Nested_Loops,3086,jNested1,21132,correct,-1 Nested_Loops,3086,jNested2,21133,incorrect,-1 Nested_Loops,3086,jNested3,21138,incorrect,1 -------------------------------------------------------------------------------- /src/main/java/fast/evaluation/TrainSummary.java: -------------------------------------------------------------------------------- 1 | package fast.evaluation; 2 | 3 | import java.text.DecimalFormat; 4 | import java.util.LinkedHashMap; 5 | //import edu.berkeley.nlp.classify.Feature; 6 | import fast.common.Bijection; 7 | 8 | public class TrainSummary { 9 | 10 | //TODO: Change to final iteration number 11 | public int nbLLError; 12 | public double maxLLDecrease, maxLLDecreaseRatio, trainLL; 13 | public int nbParameterizingFailed; 14 | public int nbStopByEMIteration; 15 | /* 16 | * Bijection allFeatures includes init, tran, emit; 17 | * Each Bijection contains original feature name, and feature name with "_hidden1" surfix; 18 | * Later when output to parameters file, the code will transfer the name using "init", "guess", "slip", "learn" (etc.). 
19 | */ 20 | public Bijection allFeatures = new Bijection(); 21 | public Bijection initFeatures = new Bijection(); 22 | public Bijection tranFeatures = new Bijection(); 23 | public Bijection emitFeatures = new Bijection(); 24 | public LinkedHashMap parameters = new LinkedHashMap(); 25 | 26 | 27 | public void update(double trainLL, int nbStopByEMIteration, int nbLLError, double maxLLDecrease, double maxLLDecreaseRatio, 28 | int nbParameterizingFailed){ 29 | this.trainLL = trainLL; 30 | this.nbLLError = nbLLError; 31 | this.maxLLDecrease = maxLLDecrease; 32 | this.maxLLDecreaseRatio = maxLLDecreaseRatio; 33 | this.nbParameterizingFailed = nbParameterizingFailed; 34 | this.nbStopByEMIteration = nbStopByEMIteration; 35 | } 36 | 37 | public String getHeader(String delimiter){ 38 | return ("trainLL" + delimiter + "nbStopByEMIteration" + delimiter + "nbLLError" + delimiter + "maxLLDecrease" + delimiter + "maxLLDecreaseRatio" 39 | + delimiter + "nbParameterizingFailed"); 40 | } 41 | 42 | public String getEvaluationStr(String delimiter, DecimalFormat formatter){ 43 | return (formatter.format(trainLL) + delimiter + formatter.format(nbStopByEMIteration) + delimiter + formatter.format(nbLLError) 44 | + delimiter + formatter.format(maxLLDecrease) + delimiter + formatter.format(maxLLDecreaseRatio) 45 | + delimiter + nbParameterizingFailed); 46 | } 47 | 48 | // public void update(TrainSummary trainSummary){ 49 | // this.nbLLError = trainSummary.nbLLError; 50 | // this.maxLLDecrease = trainSummary.maxLLDecrease; 51 | // this.maxLLDecreaseRatio = trainSummary.maxLLDecreaseRatio; 52 | // this.parameterizingSucceeded = trainSummary.parameterizingSucceeded; 53 | // this.stopByEMIteration = trainSummary.stopByEMIteration; 54 | // this.trainLL = trainSummary.trainLL; 55 | // this.allFeatures = trainSummary.allFeatures; 56 | // this.initFeatures = trainSummary.initFeatures; 57 | // this.tranFeatures = trainSummary.tranFeatures; 58 | // this.emitFeatures = trainSummary.emitFeatures; 59 | 
// this.parameters = trainSummary.parameters; 60 | // } 61 | 62 | // public double getTrainLL(){ 63 | // return trainLL; 64 | // } 65 | // 66 | // public int getNbLLError(){ 67 | // return nbLLError; 68 | // } 69 | // 70 | // public double getMaxLLDecrease(){ 71 | // return maxLLDecrease; 72 | // } 73 | // 74 | // public double getMaxLLDecreaseRatio(){ 75 | // return maxLLDecreaseRatio; 76 | // } 77 | // 78 | // public boolean getParameterizingSucceeded(){ 79 | // return parameterizingSucceeded; 80 | // } 81 | // 82 | // public int getStopByEMIteration(){ 83 | // return stopByEMIteration; 84 | // } 85 | 86 | } 87 | -------------------------------------------------------------------------------- /src/main/java/fig/exec/MonitorThread.java: -------------------------------------------------------------------------------- 1 | package fig.exec; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import java.lang.Thread; 6 | import fig.basic.*; 7 | import static fig.basic.LogInfo.*; 8 | 9 | /** 10 | * A separate thread that's responsible for outputting the status 11 | * of this execution and reading in commands. 12 | * The thread is actually contained inside. 
13 | */ 14 | class MonitorThread implements Runnable { 15 | private static final int timeInterval = 300; // Number of milliseconds between monitoring 16 | private boolean stop; 17 | private Thread thread; 18 | 19 | public MonitorThread() { 20 | this.stop = false; 21 | this.thread = new Thread(this); 22 | } 23 | 24 | void processCommand(String cmd) { 25 | cmd = cmd.trim(); 26 | if(cmd.equals("")) { 27 | // Print status 28 | Execution.getInfo().print(stderr); 29 | Execution.printOutputMapToStderr(); 30 | StopWatchSet.getStats().print(stderr); 31 | stderr.println(Execution.getVirtualExecDir()); 32 | } 33 | else if(cmd.equals("kill")) { 34 | stderr.println("MonitorThread: KILLING"); 35 | Execution.setExecStatus("killed", true); 36 | Execution.printOutputMap(Execution.getFile("output.map")); 37 | throw new RuntimeException("Killed by input command"); 38 | } 39 | else if(cmd.equals("bail")) { 40 | // Up to program to look at this flag and actually gracefully stop 41 | stderr.println("MonitorThread: BAILING OUT"); 42 | Execution.shouldBail = true; 43 | } 44 | else 45 | stderr.println("Invalid command: '" + cmd + "'"); 46 | } 47 | 48 | void readAndProcessCommand() { 49 | try { 50 | int nBytes = System.in.available(); 51 | if(nBytes > 0) { 52 | byte[] bytes = new byte[nBytes]; 53 | System.in.read(bytes); 54 | String line = new String(bytes); 55 | processCommand(line); 56 | } 57 | } catch(IOException e) { 58 | // Ignore 59 | } 60 | } 61 | 62 | public void run() { 63 | try { 64 | while(!stop) { 65 | if(LogInfo.writeToStdout) 66 | readAndProcessCommand(); 67 | 68 | // Input commands 69 | Execution.inputMap.readEasy(Execution.getFile("input.map")); 70 | 71 | boolean killed = Execution.create && new File(Execution.getFile("kill")).exists(); 72 | if(killed) Execution.setExecStatus("killed", true); 73 | 74 | // Output status 75 | Execution.putOutput("log.note", LogInfo.note); 76 | Execution.putOutput("exec.memory", SysInfoUtils.getUsedMemoryStr()); 77 | 
Execution.putOutput("exec.time", new StopWatch(LogInfo.getWatch().getCurrTimeLong()).toString()); 78 | Execution.putOutput("exec.errors", "" + LogInfo.getNumErrors()); 79 | Execution.putOutput("exec.warnings", "" + LogInfo.getNumWarnings()); 80 | Execution.setExecStatus("running", false); 81 | Execution.printOutputMap(Execution.getFile("output.map")); 82 | 83 | if(killed) 84 | throw new RuntimeException("Killed by 'kill' file"); 85 | 86 | Utils.sleep(timeInterval); 87 | } 88 | } catch(Exception e) { 89 | e.printStackTrace(); 90 | System.exit(1); // Die completely 91 | } 92 | } 93 | 94 | public void start() { 95 | thread.start(); 96 | } 97 | 98 | public void finish() { 99 | stop = true; 100 | thread.interrupt(); 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/java/fast/common/Utility.java: -------------------------------------------------------------------------------- 1 | package fast.common; 2 | 3 | import java.text.DecimalFormat; 4 | import java.util.ArrayList; 5 | import java.util.LinkedHashMap; 6 | import java.util.Map; 7 | import java.util.Random; 8 | 9 | public class Utility { 10 | 11 | public static void swap(ArrayList dataStrs, int i, int j) { 12 | String temp = dataStrs.get(i); 13 | dataStrs.set(i, dataStrs.get(j)); 14 | dataStrs.set(j, temp); 15 | } 16 | 17 | 18 | public static void printArray(double[] array, String info) { 19 | System.out.println(info); 20 | String outStr = ""; 21 | for (int i = 0; i < array.length; i++) 22 | outStr += array[i] + "\t"; 23 | System.out.println(outStr); 24 | } 25 | 26 | 27 | public static double[] normalizedBySum(double[] values) { 28 | double[] normedValues = new double[values.length]; 29 | double sum = 0.0; 30 | for (int k = 0; k < values.length; k++) { 31 | sum += values[k]; 32 | } 33 | for (int k = 0; k < values.length; k++) { 34 | normedValues[k] = values[k] / sum; 35 | } 36 | return normedValues; 37 | } 38 | 39 | public static double[] 
uniformRandomArray(int dim, double lower, 40 | double upper, Random rand) { 41 | double range = upper - lower; 42 | double[] weights = new double[dim]; 43 | for (int i = 0; i < dim; ++i) { 44 | double randVal = rand.nextDouble(); 45 | weights[i] = lower + (range * randVal); 46 | } 47 | return weights; 48 | } 49 | 50 | public static double[] uniformRandomArraySumToOne(int dim, double lower, 51 | double upper, Random rand) { 52 | double range = upper - lower; 53 | double[] weights = new double[dim]; 54 | double sum = 0.0; 55 | int i = 0; 56 | for (; i < dim-1; ++i) { 57 | double randVal = rand.nextDouble(); 58 | weights[i] = lower + (range * randVal); 59 | sum += weights[i]; 60 | } 61 | weights[i] = 1 - sum; 62 | return weights; 63 | } 64 | 65 | public static double[] intToDoubleArray(int[] labels) { 66 | double[] targets = new double[labels.length]; 67 | for (int i = 0; i < labels.length; i++) { 68 | targets[i] = labels[i]; 69 | } 70 | return targets; 71 | } 72 | 73 | public void printArray(double[] oneArray) { 74 | for (int i = 0; i < oneArray.length; i++) 75 | System.out.print(oneArray[i] + "\t"); 76 | System.out.println(); 77 | } 78 | 79 | public static String doubleArrayListToString(ArrayList oneList, DecimalFormat formatter, String delimiter){ 80 | String str = ""; 81 | for (double value : oneList) 82 | str += getValidString(value, formatter) + delimiter; 83 | return str; 84 | } 85 | 86 | public static void arrayListToArray(ArrayList aList, double[] a) { 87 | for (int i = 0; i < aList.size(); i++) 88 | a[i] = aList.get(i); 89 | } 90 | 91 | public static String[] linkedHashMapToStrings(LinkedHashMap aMap, String delimiter){ 92 | String[] strs = {"", ""}; 93 | for (Map.Entry entry : aMap.entrySet()) { 94 | String key = entry.getKey(); 95 | Double value = entry.getValue(); 96 | strs[0] += key + delimiter; 97 | strs[1] += value + delimiter; 98 | } 99 | return strs; 100 | } 101 | 102 | public static String getValidString(Double value, DecimalFormat formatter){ 103 | 
String str = ""; 104 | if (value == null || Double.isNaN(value)) 105 | str = "NaN"; 106 | else 107 | str = formatter.format(value) + ""; 108 | return str; 109 | } 110 | 111 | 112 | } 113 | 114 | 115 | -------------------------------------------------------------------------------- /src/main/python/split_dataset.py: -------------------------------------------------------------------------------- 1 | __author__ = 'ugonzjo' 2 | import pandas as pd 3 | import random 4 | import os 5 | 6 | pd.set_option('display.width', 1000) 7 | pd.set_option('display.max_rows', 3000) 8 | 9 | ''' 10 | This scripts samples observations from a CSV so that knowledge components appear in training and testing. It relies on pandas. 11 | To install pandas easily, use pip: 12 | 13 | pip install pandas 14 | ''' 15 | 16 | 17 | def split( df, student_column, train_pct= 0.6,): 18 | 19 | if train_pct > 1: 20 | raise RuntimeError("Training and development set have to be (strictly) less than 100% of the students") 21 | 22 | students = df[student_column].unique() 23 | random.shuffle(students) 24 | 25 | train_number = int(train_pct * len(students)) 26 | 27 | 28 | train_students = students[0:train_number] 29 | test_students = students[train_number: ] 30 | 31 | 32 | df_train = df[student_column].isin(train_students) 33 | df_test = df[student_column].isin(test_students) 34 | 35 | 36 | return df_train, df_test,len(train_students), len(test_students) 37 | 38 | 39 | 40 | def lag(df, student_column, amount): 41 | column_order = df.columns 42 | 43 | fixed_columns = [ "fold", "outcome", "problem", "step", "KCs", student_column] 44 | feature_columns = [] 45 | for c in df.columns: 46 | if c not in fixed_columns: 47 | feature_columns.append(c) 48 | 49 | print "Feature columns: ", feature_columns 50 | 51 | 52 | df_lagged_features = df.groupby(student_column)[feature_columns].shift(amount) 53 | l = len(df_lagged_features) 54 | 55 | df_new = df[fixed_columns].join(df_lagged_features) 56 | 57 | assert( l == len(df)) 58 
| assert (l == len(df_lagged_features)) 59 | assert (l == len(df_new)) 60 | 61 | df_new = df_new.fillna(0) 62 | return df_new[column_order] 63 | 64 | 65 | 66 | def main(filename="../../../datasets/sweet.csv", student_column="student", lag_features=1, min_students=2, min_observations=100, sep=",", train=0.8, seed=0): 67 | random.seed(seed) 68 | df = pd.read_csv(filename, sep=sep) 69 | 70 | 71 | # Lag features: 72 | df = lag(df, student_column, 1) 73 | 74 | kcs = df["KCs"].unique() 75 | 76 | trains = [] 77 | tests = [] 78 | for kc in kcs: 79 | df_kc = df[ df["KCs"] == kc] 80 | df_train, df_test, train_students, test_students = split(df_kc, train_pct=train, student_column=student_column) 81 | 82 | if len(df_kc[df_train]) > min_observations and train_students > min_students: 83 | trains.append(df_kc[df_train]) 84 | tests.append(df_kc[df_test]) 85 | 86 | print kc, len(df_kc), len(df_kc[df_train]), len(df_kc[df_test]), train_students, test_students 87 | assert len(df_kc) == len(df_kc[df_train]) + len(df_kc[df_test]) 88 | 89 | df_trains = pd.concat(trains, axis=0) 90 | df_tests = pd.concat(tests, axis=0) 91 | 92 | path = os.path.dirname(filename) 93 | name_ext = os.path.basename(filename) 94 | name, ext = os.path.splitext(name_ext) 95 | 96 | 97 | train_filename = "{}/filtered_{}_train0.csv".format(path, name) 98 | test_filename = "{}/filtered_{}_test0.csv".format(path, name) 99 | 100 | with open(train_filename, "w") as train: 101 | df_trains.to_csv(train, index=False) 102 | with open(test_filename, "w") as test: 103 | df_tests.to_csv(test, index=False) 104 | 105 | if __name__ == "__main__": 106 | import sys 107 | 108 | args = sys.argv 109 | print args 110 | cl = {} 111 | for i in range(1, len(args)): # index 0 is the filename 112 | pair = args[i].split('=') 113 | if pair[1].isdigit(): 114 | cl[pair[0]] = int(pair[1]) 115 | elif pair[1].lower() in ("true", "false"): 116 | cl[pair[0]] = (pair[1].lower() == 'true') 117 | else: 118 | cl[pair[0]] = pair[1] 119 | 120 | 
main(**cl) 121 | -------------------------------------------------------------------------------- /src/main/java/fast/featurehmm/PdfFeatureAware.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their code that FAST is developed based on. 9 | * 10 | */ 11 | 12 | /* 13 | * This is built based on: 14 | * 15 | * jaHMM package - v0.6.1 16 | * Copyright (c) 2004-2006, Jean-Marc Francois. 17 | */ 18 | 19 | package fast.featurehmm; 20 | 21 | import java.io.Serializable; 22 | //import java.text.NumberFormat; 23 | import be.ac.ulg.montefiore.run.jahmm.Observation; 24 | 25 | /** 26 | * Objects implementing this interface represent a probability (distribution) 27 | * function which can be used to paramterize init, transition, or emission 28 | * probabilities. 29 | *

* A PdfFeatureAware can represent a probability function (if the
 * nodes can take discrete values) or a probability distribution (if the nodes
 * are continuous).
 */
public interface PdfFeatureAware extends Cloneable,
		Serializable {

	/**
	 * Returns the probability (density) of an init, transition or emission event
	 * given its contextual features, based on the distribution defined by a set of
	 * weights. Originally written by hy, 10/06/13.
	 *
	 * @param featureValues
	 *            the contextual feature values of the event
	 * @param index
	 *            determines the form of the logistic regression: either an
	 *            observation index or a hidden-state index
	 * @param type
	 *            which kind of probability is requested (init, transition or
	 *            emission). NOTE(review): the exact accepted strings are defined by
	 *            the implementations; confirm there.
	 * @return the probability (density, if the values are continuous) of the event
	 */
	public double probability(double[] featureValues, int index, String type);

	/**
	 * Returns a copy of this probability function.
	 */
	public PdfFeatureAware clone();
}
100 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/Pair.java: --------------------------------------------------------------------------------
package fig.basic;

import java.io.*;
import java.util.*;


/**
 * A generic-typed pair of objects.
 * @author Dan Klein
 */
public class Pair<F, S> implements Serializable {
  static final long serialVersionUID = 42;

  F first;
  S second;

  public F getFirst() {
    return first;
  }

  public S getSecond() {
    return second;
  }

  public void setFirst(F pFirst) {
    first = pFirst;
  }

  public void setSecond(S pSecond) {
    second = pSecond;
  }

  /** Returns a new pair with the two elements swapped. */
  public Pair<S, F> reverse() {
    return new Pair<S, F>(second, first);
  }

  /** Two pairs are equal when both elements are equal (null-safe). */
  @Override
  public boolean equals(Object o) {
    if (this == o)
      return true;
    if (!(o instanceof Pair))
      return false;

    final Pair<?, ?> pair = (Pair<?, ?>) o;

    if (first != null ? !first.equals(pair.first) : pair.first != null)
      return false;
    if (second != null ? !second.equals(pair.second) : pair.second != null)
      return false;

    return true;
  }

  @Override
  public int hashCode() {
    int result;
    result = (first != null ? first.hashCode() : 0);
    result = 29 * result + (second != null ? second.hashCode() : 0);
    return result;
  }

  @Override
  public String toString() {
    return "(" + getFirst() + ", " + getSecond() + ")";
  }

  public Pair(F first, S second) {
    this.first = first;
    this.second = second;
  }

  // Compares only first values
  public static class FirstComparator<S extends Comparable<? super S>, T>
      implements Comparator<Pair<S, T>> {
    public int compare(Pair<S, T> p1, Pair<S, T> p2) {
      return p1.getFirst().compareTo(p2.getFirst());
    }
  }

  public static class ReverseFirstComparator<S extends Comparable<? super S>, T>
      implements Comparator<Pair<S, T>> {
    public int compare(Pair<S, T> p1, Pair<S, T> p2) {
      return p2.getFirst().compareTo(p1.getFirst());
    }
  }

  // Compares only second values
  public static class SecondComparator<S, T extends Comparable<? super T>>
      implements Comparator<Pair<S, T>> {
    public int compare(Pair<S, T> p1, Pair<S, T> p2) {
      return p1.getSecond().compareTo(p2.getSecond());
    }
  }

  public static class ReverseSecondComparator<S, T extends Comparable<? super T>>
      implements Comparator<Pair<S, T>> {
    public int compare(Pair<S, T> p1, Pair<S, T> p2) {
      return p2.getSecond().compareTo(p1.getSecond());
    }
  }

  public static <S, T> Pair<S, T> newPair(S first, T second) {
    return new Pair<S, T>(first, second);
  }
  // Duplicate method to facilitate backwards compatibility
  // - aria42
  public static <S, T> Pair<S, T> makePair(S first, T second) {
    return new Pair<S, T>(first, second);
  }

  /** Orders pairs by their first element, breaking ties by their second, using the given comparators. */
  public static class LexicographicPairComparator<F, S> implements Comparator<Pair<F, S>> {
    Comparator<F> firstComparator;
    Comparator<S> secondComparator;

    public int compare(Pair<F, S> pair1, Pair<F, S> pair2) {
      int firstCompare = firstComparator.compare(pair1.getFirst(), pair2.getFirst());
      if (firstCompare != 0)
        return firstCompare;
      return secondComparator.compare(pair1.getSecond(), pair2.getSecond());
    }

    public LexicographicPairComparator(Comparator<F> firstComparator, Comparator<S> secondComparator) {
      this.firstComparator = firstComparator;
      this.secondComparator = secondComparator;
    }
  }

  /** Orders pairs by the natural order of their first element, breaking ties by their second. */
  public static class DefaultLexicographicPairComparator<F extends Comparable<F>, S extends Comparable<S>>
      implements Comparator<Pair<F, S>> {

    public int compare(Pair<F, S> o1, Pair<F, S> o2) {
      int firstCompare = o1.getFirst().compareTo(o2.getFirst());
      if (firstCompare != 0) {
        return firstCompare;
      }
      // Compare o1's second element against o2's. The previous code compared
      // o2.getSecond() with itself, so every tie on the first element was
      // reported as equal.
      return o1.getSecond().compareTo(o2.getSecond());
    }

  }


}
140 | -------------------------------------------------------------------------------- /data/others/FAST+item_train0.txt: -------------------------------------------------------------------------------- 1 | student outcome KC features_j2D_Arrays1 features_j2D_Arrays2 features_j2D_arrays3 features_j2D_arrays4 features_jArray1 features_jArray2 features_jArray3 features_jArray4 features_jArray5 features_jBa_ques features_jBankAccount features_jClass1 features_jClasses4 features_jClasses_Getter features_jObjects1 features_jObjects2 features_jObjects3 features_jObjects4 features_jObjects5 features_jVariables1 features_jVariables2 features_jVariables3 features_jVariables4 features_jVariables5 2647 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2647 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2647 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 2647 incorrect Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2647 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL
NULL NULL NULL NULL NULL 0 0 0 0 1 2647 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2647 incorrect Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL 2647 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL 2647 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL 2647 incorrect Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL 2646 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2646 correct Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2646 incorrect Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2646 correct Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2646 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2646 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2646 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 2646 incorrect Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2646 incorrect Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2646 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2646 correct Variables NULL NULL NULL NULL NULL 
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2665 incorrect Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 incorrect Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 incorrect Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 incorrect Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 correct Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 correct Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 incorrect Arrays NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2665 correct Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2515 incorrect Two-dimensional_Arrays 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2515 correct Two-dimensional_Arrays 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL -------------------------------------------------------------------------------- /src/main/java/fig/basic/OrderedMap.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | 6 | /** 7 | * A container for mapping objects to objects. 8 | * Keep track of the order of the objects 9 | * (as they were inserted into the data structure). 10 | * No duplicate elements allowed. 
11 | */ 12 | public class OrderedMap { 13 | private ArrayList keys = new ArrayList(); 14 | private Map map = new HashMap(); 15 | 16 | public OrderedMap() { } 17 | public OrderedMap(OrderedMap map) { 18 | for(S key : map.keys()) 19 | put(key, map.get(key)); 20 | } 21 | 22 | public void clear() { keys.clear(); map.clear(); } 23 | 24 | public void log(String title) { 25 | LogInfo.track(title, true); 26 | for(S key : keys()) 27 | LogInfo.logs(key + "\t" + get(key)); 28 | LogInfo.end_track(); 29 | } 30 | 31 | public void put(S key) { put(key, null); } 32 | public void putAtEnd(S key) { put(key, get(key)); } 33 | public void removeAt(int i) { 34 | S key = keys.get(i); 35 | keys.remove(i); 36 | map.remove(key); 37 | } 38 | 39 | public void reput(S key, T val) { // Don't affect order (but insert if key doesn't exist) 40 | if(!map.containsKey(key)) put(key, val); 41 | else map.put(key, val); 42 | } 43 | public void put(S key, T val) { 44 | // If key already exists, we replace its value and move it to the end of the list. 
45 | if(map.containsKey(key)) 46 | keys.remove(key); // Remove last occurrence of key 47 | keys.add(key); 48 | map.put(key, val); 49 | } 50 | 51 | public int size() { return keys.size(); } 52 | public boolean containsKey(S key) { return map.containsKey(key); } 53 | public T get(S key) { return map.get(key); } 54 | public T get(S key, T defaultVal) { return MapUtils.get(map, key, defaultVal); } 55 | 56 | public Set keySet() { return map.keySet(); } 57 | public List keys() { return keys; } 58 | 59 | // Values { 60 | public ValueCollection values() { return new ValueCollection(); } 61 | public class ValueCollection extends AbstractCollection { 62 | public Iterator iterator() { return new ValueIterator(); } 63 | public int size() { return size(); } 64 | public boolean contains(Object o) { throw new UnsupportedOperationException(); } 65 | public void clear() { throw new UnsupportedOperationException(); } 66 | } 67 | private class ValueIterator implements Iterator { 68 | public ValueIterator() { next = 0; } 69 | public boolean hasNext() { return next < size(); } 70 | public T next() { return map.get(keys.get(next++)); } 71 | public void remove() { throw new UnsupportedOperationException(); } 72 | private int next; 73 | } 74 | // } 75 | 76 | /** 77 | * Output each entry in the HashMap on a line separated by a tab. 
78 | */ 79 | public void print(PrintWriter out) { 80 | for(S key : keys) { 81 | print(out, key, map.get(key)); 82 | } 83 | out.flush(); 84 | } 85 | public void print(String path) throws IOException { print(new File(path)); } 86 | public void printHard(String path) { 87 | PrintWriter out = IOUtils.openOutHard(path); 88 | print(out); 89 | out.close(); 90 | } 91 | public void print(File path) throws IOException { 92 | PrintWriter out = IOUtils.openOut(path); 93 | print(out); 94 | out.close(); 95 | } 96 | 97 | public String print() { 98 | StringWriter sw = new StringWriter(); 99 | print(new PrintWriter(sw)); 100 | return sw.toString(); 101 | } 102 | 103 | void print(PrintWriter out, S key, T val) { 104 | out.println(key + (val == null ? "" : "\t" + val)); 105 | } 106 | 107 | public boolean printEasy(String path) { 108 | if(StrUtils.isEmpty(path)) return false; 109 | return printEasy(new File(path)); 110 | } 111 | public boolean printEasy(File path) { 112 | if(path == null) return false; 113 | try { 114 | PrintWriter out = IOUtils.openOut(path); 115 | print(out); 116 | out.close(); 117 | return true; 118 | } catch(Exception e) { 119 | return false; 120 | } 121 | } 122 | 123 | public String toString() { 124 | StringBuilder sb = new StringBuilder(); 125 | for(S key : keys) { 126 | sb.append(key +" "+ map.get(key)+"\n"); 127 | } 128 | return sb.toString(); 129 | } 130 | 131 | } 132 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/SysInfoUtils.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.io.*; 4 | import java.util.*; 5 | import java.net.*; 6 | 7 | public class SysInfoUtils { 8 | public static String getCurrentDateStr() { 9 | return new Date().toString(); 10 | } 11 | public static String getHostName() { 12 | try { 13 | return InetAddress.getLocalHost().getHostName(); 14 | } 15 | catch(UnknownHostException e) { 16 | return 
"(unknown)"; 17 | } 18 | } 19 | public static String getShortHostName() { 20 | String name = getHostName(); 21 | int i = name.indexOf('.'); 22 | if(i == -1) return name; 23 | return name.substring(0, i); 24 | } 25 | public static String getcwd() { return System.getProperty("user.dir"); } 26 | 27 | private static int numCPUs = -1; // Cache: doesn't change 28 | public static int getNumCPUs() { 29 | // Linux 30 | if(numCPUs != -1) return numCPUs; 31 | try { 32 | int n = 0; 33 | for(String line : IOUtils.readLines("/proc/cpuinfo")) { 34 | if(line.startsWith("processor")) 35 | n++; 36 | } 37 | return numCPUs = n; 38 | } 39 | catch(IOException e) { 40 | } 41 | 42 | // MacOS 43 | try { 44 | // Output format: hw.ncpu: 1 45 | return numCPUs = 46 | Integer.parseInt(StrUtils.split(Utils.systemGetStringOutput("sysctl hw.ncpu").trim(), " ")[1]); 47 | } catch(Exception e) { 48 | } 49 | 50 | return 0; 51 | } 52 | 53 | public static int getNumUsedCPUs() { 54 | // This command should return the percent CPU usages of 55 | // all processes, one on each line 56 | // A bit of a hack: if a process uses more than 50% of the CPU, 57 | // then it is considered used 58 | try { 59 | int n = 0; 60 | for(String line : StrUtils.split(Utils.systemGetStringOutput("ps ax -o pcpu"), "\n")) { 61 | double percentCPU = Utils.parseDoubleEasy(line); 62 | if(percentCPU > 50) n++; 63 | if(percentCPU > 150) n++; 64 | if(percentCPU > 250) n++; 65 | if(percentCPU > 350) n++; 66 | } 67 | return n; 68 | } catch(Exception e) { 69 | return -1; 70 | } 71 | } 72 | 73 | public static int getNumFreeCPUs() { 74 | return getNumCPUs() - getNumUsedCPUs(); 75 | } 76 | 77 | // Return in MHz 78 | private static int cpuSpeed = -1; // Cache it since it doesn't change 79 | public static int getCPUSpeed() { 80 | if(cpuSpeed != -1) return cpuSpeed; 81 | 82 | // Linux: take the average of the CPU speeds of all processors 83 | try { 84 | double sum = 0; 85 | int n = 0; 86 | for(String line : IOUtils.readLines("/proc/cpuinfo")) { 
87 | if(line.startsWith("cpu MHz")) { 88 | sum += Double.parseDouble(ListUtils.getLast(StrUtils.split(line))); 89 | n++; 90 | } 91 | } 92 | return cpuSpeed = (int)(sum/n+0.5); 93 | } catch(IOException e) { 94 | } 95 | 96 | // MacOS 97 | try { 98 | // Output format: hw.cpufrequency: 1499999994 99 | return cpuSpeed = 100 | Integer.parseInt(StrUtils.split(Utils.systemGetStringOutput("sysctl hw.cpufrequency").trim(), " ")[1])/1000000; 101 | } catch(Exception e) { 102 | } 103 | 104 | return 0; 105 | } 106 | public static String getCPUSpeedStr() { 107 | return getCPUSpeed() + " MHz"; 108 | } 109 | 110 | // Memory of this java process 111 | public static String getMaxMemoryStr() { 112 | long mem = Runtime.getRuntime().maxMemory(); 113 | return Fmt.bytesToString(mem); 114 | } 115 | public static String getUsedMemoryStr() { 116 | long totalMem = Runtime.getRuntime().totalMemory(); 117 | long freeMem = Runtime.getRuntime().freeMemory(); 118 | return Fmt.bytesToString(totalMem-freeMem); 119 | } 120 | 121 | // Memory 122 | public static long getFreeMemory() { 123 | // Linux 124 | try { 125 | int n = 0; 126 | long memfree = 0, buffers = 0, cached = 0; 127 | for(String line : IOUtils.readLines("/proc/meminfo")) { 128 | if(line.startsWith("MemFree:")) 129 | memfree = Long.parseLong(line.split("\\s+")[1]); 130 | if(line.startsWith("Buffers:")) 131 | buffers = Long.parseLong(line.split("\\s+")[1]); 132 | if(line.startsWith("Cached:")) 133 | cached = Long.parseLong(line.split("\\s+")[1]); 134 | } 135 | return (memfree + buffers + cached) * 1024; 136 | } 137 | catch(Exception e) { 138 | return 0; 139 | } 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /src/main/java/fast/evaluation/AUC.java: -------------------------------------------------------------------------------- 1 | /** 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | // Code from: https://code.google.com/p/jforests/source/browse/trunk/jforests/src/main/java/edu/uci/jforests/eval/AUC.java 18 | //package edu.uci.jforests.eval; 19 | package fast.evaluation; 20 | import java.util.ArrayList; 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | /** 25 | * @author Yasser Ganjisaffar 26 | */ 27 | 28 | public class AUC { 29 | 30 | /** 31 | * @author: Yun Huang changed the original code 32 | */ 33 | public double measure(double[] predictions, Sample sample) { 34 | int totalPositive = 0; 35 | int totalNegative = 0; 36 | List sortedProb = new ArrayList(); 37 | for (int i = 0; i < sample.size; i++) { 38 | double label = sample.targets[i]; 39 | sortedProb.add(new DoubleDoublePair(predictions[i], label)); 40 | if (label == 0) { 41 | totalNegative++; 42 | } 43 | else { 44 | totalPositive++; 45 | } 46 | } 47 | Collections.sort(sortedProb); 48 | 49 | double fp = 0; 50 | double tp = 0; 51 | double fpPrev = 0; 52 | double tpPrev = 0; 53 | double area = 0; 54 | double fPrev = Double.MIN_VALUE; 55 | 56 | int i = 0; 57 | while (i < sortedProb.size()) { 58 | DoubleDoublePair pair = sortedProb.get(i); 59 | double curF = pair.key; 60 | if (curF != fPrev) { 61 | area += Math.abs(fp - fpPrev) * ((tp + tpPrev) / 
2.0); 62 | fPrev = curF; 63 | fpPrev = fp; 64 | tpPrev = tp; 65 | } 66 | double label = pair.value; 67 | if (label == +1) { 68 | tp++; 69 | } 70 | else { 71 | fp++; 72 | } 73 | i++; 74 | } 75 | area += Math.abs(totalNegative - fpPrev) * ((totalPositive + tpPrev) / 2.0); 76 | area /= ((double) totalPositive * totalNegative); 77 | return area; 78 | } 79 | 80 | // hy commented 81 | // @Override 82 | // public double measure(double[] predictions, Sample sample) { 83 | // int totalPositive = 0; 84 | // int totalNegative = 0; 85 | // List sortedProb = new ArrayList(); 86 | // for (int i = 0; i < sample.size; i++) { 87 | // double label = sample.targets[i]; 88 | // sortedProb.add(new DoubleDoublePair(predictions[i], label)); 89 | // if (label == 0) { 90 | // totalNegative++; 91 | // } 92 | // else { 93 | // totalPositive++; 94 | // } 95 | // } 96 | // Collections.sort(sortedProb); 97 | // 98 | // double fp = 0; 99 | // double tp = 0; 100 | // double fpPrev = 0; 101 | // double tpPrev = 0; 102 | // double area = 0; 103 | // double fPrev = Double.MIN_VALUE; 104 | // 105 | // int i = 0; 106 | // while (i < sortedProb.size()) { 107 | // DoubleDoublePair pair = sortedProb.get(i); 108 | // double curF = pair.key; 109 | // if (curF != fPrev) { 110 | // area += Math.abs(fp - fpPrev) * ((tp + tpPrev) / 2.0); 111 | // fPrev = curF; 112 | // fpPrev = fp; 113 | // tpPrev = tp; 114 | // } 115 | // double label = pair.value; 116 | // if (label == +1) { 117 | // tp++; 118 | // } 119 | // else { 120 | // fp++; 121 | // } 122 | // i++; 123 | // } 124 | // area += Math.abs(totalNegative - fpPrev) * ((totalPositive + tpPrev) / 125 | // 2.0); 126 | // area /= ((double) totalPositive * totalNegative); 127 | // return area; 128 | // } 129 | 130 | private static class DoubleDoublePair implements Comparable { 131 | public double key; 132 | public double value; 133 | 134 | public DoubleDoublePair(double key, double value) { 135 | this.key = key; 136 | this.value = value; 137 | } 138 | 139 | 
@Override 140 | public int compareTo(DoubleDoublePair o) { 141 | if (this.key > o.key) { 142 | return -1; 143 | } 144 | else if (this.key < o.key) { 145 | return 1; 146 | } 147 | return 0; 148 | } 149 | } 150 | } -------------------------------------------------------------------------------- /src/main/java/fast/common/Matrix.java: -------------------------------------------------------------------------------- 1 | package fast.common; 2 | 3 | import java.util.Arrays; 4 | 5 | public final class Matrix 6 | { 7 | 8 | public static final double EPS = 0.0000001; 9 | 10 | public static double max(double array[]) 11 | { 12 | double max = Double.MIN_VALUE; 13 | for (double a : array) 14 | if (a > max) 15 | max = a; 16 | return max; 17 | } 18 | public static int argmax(double[] array) 19 | { 20 | int index = -1; 21 | double best = Double.MIN_VALUE; 22 | for (int i = 0; i < array.length; i++) 23 | { 24 | //System.out.print( array[i]); 25 | if (array[i] > best) 26 | { 27 | best = array[i]; 28 | index = i; 29 | 30 | //System.out.println("*"); 31 | } 32 | /*else 33 | System.out.println("");*/ 34 | } 35 | return index; 36 | } 37 | public static double[] add(double[] a, double[] b) 38 | { 39 | if( a.length != b.length ) 40 | throw new ArithmeticException("Attempting to add arrays of different lengths (" + a.length + "," + b.length + ")"); 41 | double c[] = new double[a.length]; 42 | for (int i = 0; i < a.length; i++) 43 | { 44 | c[i] = a[i] + b[i]; 45 | } 46 | return c; 47 | } 48 | 49 | public static double[] add(double[] a, int[] b) 50 | { 51 | if( a.length != b.length ) 52 | throw new ArithmeticException("Attempting to add arrays of different lengths (" + a.length + "," + b.length + ")"); 53 | double c[] = new double[a.length]; 54 | for (int i = 0; i < a.length; i++) 55 | { 56 | c[i] = a[i] + b[i]; 57 | } 58 | return c; 59 | } 60 | 61 | public static double[] mult(double[] a, double b) 62 | { 63 | final double ans[] = new double[a.length]; 64 | for (int i = 0; i < a.length; 
i++) 65 | { 66 | ans[i] = a[i] * b; 67 | } 68 | return ans; 69 | } 70 | 71 | public static double[] dotmult(double[] a, double[] b, double c) 72 | { 73 | if( a.length != b.length ) 74 | throw new ArithmeticException("Attempting to multiply arrays of different lengths (" + a.length + "," + b.length + ")"); 75 | 76 | final double ans[] = new double[a.length]; 77 | for (int i = 0; i < a.length; i++) 78 | { 79 | ans[i] = a[i] * b[i] * c; 80 | } 81 | return ans; 82 | } 83 | 84 | 85 | 86 | public static double[] dotmult(double[] a, double[] b) 87 | { 88 | if( a.length != b.length ) 89 | throw new ArithmeticException("Attempting to multiply arrays of different lengths"); 90 | 91 | final double ans[] = new double[a.length]; 92 | for (int i = 0; i < a.length; i++) 93 | ans[i] = a[i] * b[i]; 94 | return ans; 95 | } 96 | 97 | 98 | public static double[] add(int[] a, double b) 99 | { 100 | double c[] = new double[a.length]; 101 | for (int i = 0; i < a.length; i++) 102 | { 103 | c[i] = a[i] + b; 104 | } 105 | return c; 106 | } 107 | 108 | 109 | 110 | public static double[] add(double[] a, double b) 111 | { 112 | double c[] = new double[a.length]; 113 | for (int i = 0; i < a.length; i++) 114 | { 115 | c[i] = a[i] + b; 116 | } 117 | return c; 118 | } 119 | 120 | public static double sum(double[] vector) 121 | { 122 | double ans = 0; 123 | for (double e: vector) 124 | ans += e; 125 | return ans; 126 | } 127 | 128 | public static double sum(double[][] q, int obs, int dim) 129 | { 130 | double ans = 0; 131 | assert dim == 2: "Parameter value not implemented"; 132 | 133 | for(int i = 0; i < q.length; i++) 134 | ans += q[i][obs]; 135 | 136 | return ans; 137 | } 138 | 139 | 140 | 141 | public static double[] toDouble(String[] split) 142 | { 143 | double[] ans = new double[split.length]; 144 | for(int i = 0; i < ans.length; i++) 145 | ans[i] = Double.valueOf(split[i]); 146 | return ans; 147 | } 148 | 149 | public static void assertProbability(double[] p ) 150 | { 151 | final double sum 
= Math.abs((Matrix.sum(p) - 1)); 152 | assert sum > (-1*EPS) & sum < EPS : "Not a probability: " + Matrix.sum(p) + " " + Arrays.toString(p); 153 | } 154 | 155 | public static int[] addInt(int[] a, int[] b) 156 | { 157 | if( a.length != b.length ) 158 | throw new ArithmeticException("Attempting to multiply arrays of different lengths"); 159 | 160 | final int ans[] = new int[a.length]; 161 | for (int i = 0; i < a.length; i++) 162 | ans[i] = a[i] + b[i]; 163 | 164 | return ans; 165 | 166 | } 167 | 168 | public static double[] div(double[] p, double pNorm) 169 | { 170 | final double ans[] = new double[p.length]; 171 | for (int i = 0; i < p.length; i++) 172 | { 173 | ans[i] = p[i] / pNorm; 174 | } 175 | return ans; 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /src/main/java/fast/data/DataPoint.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their codes that FAST is developed based on. 
9 | * 10 | */ 11 | 12 | package fast.data; 13 | 14 | import java.text.NumberFormat; 15 | import be.ac.ulg.montefiore.run.jahmm.Observation; 16 | 17 | public class DataPoint extends Observation { 18 | 19 | private final int student, problem, step, skill; 20 | private int nbStates; 21 | /** 22 | * expandedFeatures dimensions: 23 | * 1st: hiddenStates; (0 state with _hidden1 features deactivated) 24 | * 2nd: type: 0-init,1-tran,2-emit; 25 | * 3rd: featureValues corresponding to current hiddenState; 26 | * 27 | * expandedFeatures[0] will be null without initialization. 28 | */ 29 | private double[][][] expandedFeatures = null; 30 | // private double[] features = null; 31 | // private boolean oneLogisticRegression = false; 32 | private int fold = -1; 33 | private Double llAprox = -1., llExact = -1.; 34 | private Integer groundTruth, outcome; 35 | 36 | @Override 37 | public String toString(NumberFormat numberFormat) { 38 | return numberFormat.format(outcome); 39 | } 40 | 41 | public void setNbStates(int nbStates) { 42 | this.nbStates = nbStates; 43 | expandedFeatures = new double[this.nbStates][3][]; 44 | } 45 | 46 | public DataPoint(int aOutcome) { 47 | this.student = -1; 48 | this.problem = -1; 49 | this.step = -1; 50 | this.fold = -1; 51 | this.skill = -1; 52 | this.outcome = aOutcome; 53 | } 54 | 55 | public DataPoint(int aStudent, int aSkill, int aProb, int aStep, int aFold, 56 | int aOutcome) { 57 | this.student = aStudent; 58 | this.skill = aSkill; 59 | this.problem = aProb; 60 | this.step = aStep; 61 | this.fold = aFold; 62 | this.outcome = aOutcome; 63 | } 64 | 65 | // For 1LR 66 | public DataPoint(int aStudent, int aSkill, int aProb, int aStep, int aFold, 67 | int aOutcome, double[][][] aFeatures) { 68 | this.student = aStudent; 69 | this.problem = aProb; 70 | this.step = aStep; 71 | this.fold = aFold; 72 | this.skill = aSkill; 73 | this.expandedFeatures = aFeatures; 74 | this.outcome = aOutcome; 75 | // this.oneLogisticRegression = true; 76 | } 77 | 78 | // 
TODO: Reserved for 1 bias feature, check whether I could simplify 79 | // public DataPoint(int aStudent, int aProb, int aStep, int aFold, int 80 | // aOutcome, 81 | // double[] aFeatures) { 82 | // this.student = aStudent; 83 | // this.problem = aProb; 84 | // this.step = aStep; 85 | // this.fold = aFold; 86 | // this.features = aFeatures; 87 | // this.outcome = aOutcome; 88 | // this.oneLogisticRegression = false; 89 | // } 90 | 91 | public DataPoint(int studentId, int skillId, int problemId, int stepId, 92 | int groundTruth, double llAprox, double llExact) { 93 | this.student = studentId; 94 | this.skill = skillId; 95 | this.problem = problemId; 96 | this.step = stepId; 97 | 98 | this.groundTruth = groundTruth; 99 | this.llAprox = llAprox; 100 | this.llExact = llExact; 101 | } 102 | 103 | public void setExpandedFeatures(double[][][] expandedFeatures) { 104 | this.expandedFeatures = expandedFeatures; 105 | } 106 | 107 | // public void setFeatures(double[] features) { 108 | // this.features = features; 109 | // } 110 | 111 | public int getStudent() { 112 | return student; 113 | } 114 | 115 | public int getSkill() { 116 | return skill; 117 | } 118 | 119 | public int getProblem() { 120 | return problem; 121 | } 122 | 123 | public int getStep() { 124 | return step; 125 | } 126 | 127 | public int getFold() { 128 | return fold; 129 | } 130 | 131 | public int getOutcome() { 132 | return outcome; 133 | } 134 | 135 | public int getGroundTruth() { 136 | return groundTruth; 137 | } 138 | 139 | public double getLLAprox() { 140 | return llAprox; 141 | } 142 | 143 | public double getLLExact() { 144 | return llExact; 145 | } 146 | 147 | /* 148 | * type: 0-init, 1-tran, 2-emit 149 | */ 150 | public double[] getFeatures(int hiddenStateIndex, int type) { 151 | if (expandedFeatures == null) 152 | return null; 153 | // this will return null if 2nd dimension is not initialized 154 | return expandedFeatures[hiddenStateIndex][type]; 155 | // else 156 | // return features; 157 | 158 | } 
159 | 160 | // public double[] getFeatures() { 161 | // return features; 162 | // } 163 | 164 | } 165 | -------------------------------------------------------------------------------- /data/others/FAST+item_test0.txt: -------------------------------------------------------------------------------- 1 | student outcome KC features_j2D_Arrays1 features_j2D_Arrays2 features_j2D_arrays3 features_j2D_arrays4 features_jArray1 features_jArray2 features_jArray3 features_jArray4 features_jArray5 features_jBa_ques features_jBankAccount features_jClass1 features_jClasses4 features_jClasses_Getter features_jObjects1 features_jObjects2 features_jObjects3 features_jObjects4 features_jObjects5 features_jVariables1 features_jVariables2 features_jVariables3 features_jVariables4 features_jVariables5 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 2644 incorrect Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL 
NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 NULL NULL NULL NULL NULL 2644 correct Objects NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 NULL NULL NULL NULL NULL 2644 incorrect Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Classes NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2644 incorrect Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 1 0 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2644 correct Variables 
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2644 correct Variables NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 0 0 0 0 1 2644 correct Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Arrays NULL NULL NULL NULL 1 0 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Arrays NULL NULL NULL NULL 0 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 incorrect Arrays NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 incorrect Arrays NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2644 correct Arrays NULL NULL NULL NULL 0 0 1 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2664 incorrect Two-dimensional_Arrays 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2664 incorrect Two-dimensional_Arrays 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL 2664 correct Two-dimensional_Arrays 1 0 0 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL -------------------------------------------------------------------------------- /src/main/java/fig/record/Record.java: -------------------------------------------------------------------------------- 1 | package fig.record; 2 | 3 | import fig.basic.*; 4 | import java.io.*; 5 | import java.util.*; 6 | 7 | /** 8 | * Record is a static class for instrumenting the state of an execution. 9 | * The calls record the state of variables at certain points of program execution. 10 | * The calls can be grouped hierarchically as well. 
11 | * Basic usage: 12 | * init(record file to output to) 13 | * begin(key, value) 14 | * add(key, value) 15 | * end() 16 | * finish() 17 | * Caching feature (not fully supported): 18 | * Recordable objects can be cached, so when they are added a second time, 19 | * only a pointer is written to file. 20 | * 21 | * The record file encodes a tree, where each node of the tree 22 | * consists of a key and optionally a value, and a list of children nodes. 23 | * In the file, each line specifies a node: 24 | * <\t * D, where D is depth of node>\t 25 | */ 26 | public class Record { 27 | public static void init(String path) { 28 | out = IOUtils.openOutEasy(path); 29 | } 30 | public static void finish() { 31 | if(out != null) out.close(); 32 | } 33 | public static void flush() { 34 | if(out != null) out.flush(); 35 | } 36 | 37 | private static void print(Object o) { 38 | if(out == null) return; 39 | for(int i = 0; i < indent; i++) out.print('\t'); 40 | out.println(o+""); 41 | } 42 | 43 | public static void setStruct(Object... keys) { 44 | addTabSepValues(".struct", keys); // Treated specially: enable the structure 45 | } 46 | public static void clearStruct() { 47 | add(".struct"); // Treated specially: disable the structure 48 | } 49 | public static void add(String key, Object... 
val) { 50 | addTabSepValues(key, val); 51 | } 52 | 53 | public static void addArray(String key, int[] values) { 54 | addArray(key, ListUtils.toObjArray(values)); 55 | } 56 | public static void addArray(String key, double[] values) { 57 | addArray(key, ListUtils.toObjArray(values)); 58 | } 59 | public static void addArray(String key, T[] values) { 60 | StringBuilder buf = new StringBuilder(); 61 | buf.append(".array\t"); 62 | buf.append(key); 63 | for(T value : values) { buf.append('\t'); buf.append(value); } 64 | print(buf.toString()); 65 | } 66 | public static void addArray(String key, List values) { 67 | StringBuilder buf = new StringBuilder(); 68 | buf.append(".array\t"); 69 | buf.append(key); 70 | for(Object value : values) { buf.append('\t'); buf.append(value); } 71 | print(buf.toString()); 72 | } 73 | 74 | public static void addObject(Object o, Object arg) { 75 | // Able to handle many different type of objects 76 | if(o instanceof Recordable) 77 | ((Recordable)o).record(arg); 78 | else if(o instanceof List) { 79 | List l = (List)o; 80 | if(l.size() != 0 && l.get(0) instanceof Recordable) { 81 | // If have list of recordable items, can put explicit structure 82 | for(int i = 0; i < l.size(); i++) 83 | addEmbedArg("index", i, l.get(i), arg); 84 | } 85 | else { 86 | // Example: (index=0 (value=a)) (index=1 (value=b)) 87 | setStruct("index", "value"); 88 | for(int i = 0; i < l.size(); i++) 89 | add(""+i, l.get(i)); 90 | } 91 | } 92 | else if(o instanceof StatFig) { 93 | StatFig f = (StatFig)o; 94 | Record.add("n", f.size()); 95 | Record.add("mean", f.mean()); 96 | if(o instanceof BigStatFig) { 97 | BigStatFig bf = (BigStatFig)o; 98 | Record.add("min", bf.getMin()); 99 | Record.add("max", bf.getMax()); 100 | } 101 | if(o instanceof FullStatFig) { 102 | FullStatFig ff = (FullStatFig)o; 103 | Record.add("stddev", ff.stddev()); 104 | Record.add("entropy", ff.entropy()); 105 | } 106 | } 107 | else { 108 | add("value", o); 109 | } 110 | } 111 | 112 | // Embedding 
means just creating a new node and adding the nodes 113 | public static void addEmbed(String key, Object o) { 114 | addEmbedArg(key, o, null); 115 | } 116 | public static void addEmbed(String key, Object val, Object o) { 117 | addEmbedArg(key, val, o, null); 118 | } 119 | public static void addEmbedArg(String key, Object o, Object arg) { 120 | begin(key); addObject(o, arg); end(); 121 | } 122 | public static void addEmbedArg(String key, Object val, Object o, Object arg) { 123 | begin(key, val); addObject(o, arg); end(); 124 | } 125 | 126 | public synchronized static void begin(String key) { add(key); indent++; } 127 | public synchronized static void begin(String key, Object val) { add(key, val); indent++; } 128 | public synchronized static void end() { indent--; flush(); } 129 | 130 | // Add a as the key, b as the list of things 131 | private static void addTabSepValues(String a, Object[] b) { 132 | StringBuilder buf = new StringBuilder(); 133 | buf.append(a); 134 | for(Object o : b) { 135 | buf.append('\t'); 136 | buf.append(o); 137 | } 138 | print(buf.toString()); 139 | } 140 | 141 | private static int indent; 142 | private static PrintWriter out; 143 | } 144 | -------------------------------------------------------------------------------- /src/main/java/fast/featurehmm/LogisticRegression.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their codes that FAST is developed based on. 
*
*/

package fast.featurehmm;

/**
 * Fits the feature weights of a {@link PdfFeatureAwareLogisticRegression} by
 * running {@link LBFGS} over a set of weighted, labelled instances.
 * <p>
 * A new PdfFeatureAwareLogisticRegression is built from the given pdf. The
 * feature matrix, labels, instance weights and starting feature weights are
 * copied, so later changes to the caller's arrays do not affect training.
 * The regularization arrays are stored by reference.
 * <p>
 * Invalid input is reported on standard output and terminates the JVM with
 * {@code System.exit(1)}.
 */
public class LogisticRegression {

    private final double[][] featureValues;
    private final int[] labels;
    private final double[] instanceWeights;
    private final String type;
    private final double[] regularizationWeights, regularizationBiases;
    private final double LBFGS_TOLERANCE;
    private final int LBFGS_MAX_ITERS;
    private final PdfFeatureAwareLogisticRegression pdf;
    private final double[] initialFeatureWeights; // LBFGS starting point, copied from pdf

    private int nbParameterizingFailed = 0;

    /**
     * Validates and copies the training data.
     *
     * @param pdf                   distribution whose current feature weights are the LBFGS
     *                              starting point; its weight count must equal every row's length
     * @param instanceWeights       one weight per instance; must not be null
     * @param featureValues         one row of feature values per instance; must be non-empty,
     *                              with every row as long as the pdf's weight vector
     * @param labels                one label per instance
     * @param type                  passed through to LBFGS unchanged
     * @param regularizationWeights passed through to LBFGS unchanged (not copied)
     * @param regularizationBiases  passed through to LBFGS unchanged (not copied)
     * @param LBFGS_TOLERANCE       convergence tolerance passed to LBFGS
     * @param LBFGS_MAX_ITERS       iteration limit passed to LBFGS
     */
    public LogisticRegression(PdfFeatureAwareLogisticRegression pdf, double[] instanceWeights,
            double[][] featureValues, int[] labels,
            String type,
            double[] regularizationWeights, double[] regularizationBiases,
            double LBFGS_TOLERANCE, int LBFGS_MAX_ITERS) {

        if (featureValues.length != labels.length) {
            System.out.println("featureValues.length != outcomes.length!");
            System.exit(1);
        }
        // Previously featureValues[0] was read unconditionally, crashing with
        // ArrayIndexOutOfBoundsException on an empty training set.
        if (featureValues.length == 0) {
            System.out.println("ERROR: featureValues is empty; cannot train logistic regression!");
            System.exit(1);
        }
        // Check every row, not just row 0: ragged rows used to be silently
        // truncated (if longer) or crash the copy loop (if shorter).
        int nbFeatures = pdf.getFeatureWeights().length;
        for (int i = 0; i < featureValues.length; i++) {
            if (featureValues[i].length != nbFeatures) {
                System.out.println("featureValues[" + i + "].length != opdf.featureWeights.length");
                System.exit(1);
            }
        }
        if (instanceWeights == null || instanceWeights.length != labels.length) {
            System.out.println("ERROR: instance/class weights are null || weights.length != outcomes.length");
            System.exit(1);
        }

        this.pdf = new PdfFeatureAwareLogisticRegression(pdf);
        this.initialFeatureWeights = pdf.getFeatureWeights().clone();
        this.instanceWeights = instanceWeights.clone();

        // Deep copy: clone() on a 2-D array would share the row arrays.
        this.featureValues = new double[featureValues.length][];
        for (int i = 0; i < featureValues.length; i++)
            this.featureValues[i] = featureValues[i].clone();

        this.labels = labels.clone();
        this.type = type;
        this.regularizationWeights = regularizationWeights;
        this.regularizationBiases = regularizationBiases;
        this.LBFGS_TOLERANCE = LBFGS_TOLERANCE;
        this.LBFGS_MAX_ITERS = LBFGS_MAX_ITERS;
    }

    /**
     * Runs LBFGS from the pdf's initial feature weights and stores its
     * parameterizing result for {@link #getParameterizingResult()}.
     *
     * @return the optimized feature weights, one per feature column
     */
    public double[] train() {
        LBFGS lbfgs = new LBFGS(initialFeatureWeights, pdf, instanceWeights, featureValues, labels, type,
                regularizationWeights, regularizationBiases, LBFGS_MAX_ITERS, LBFGS_TOLERANCE);
        double[] finalFeatureWeights = lbfgs.run();
        nbParameterizingFailed = lbfgs.getParameterizingResult();

        // featureValues is guaranteed non-empty by the constructor.
        if (finalFeatureWeights.length != featureValues[0].length) {
            System.out.println("featureWeights.length != features[0].length!");
            System.exit(1);
        }
        return finalFeatureWeights;
    }

    /**
     * Returns the value LBFGS reported from getParameterizingResult() during the
     * most recent {@link #train()} call, or 0 before any training.
     * NOTE(review): the field name suggests a count of failed parameterizations;
     * confirm against LBFGS.
     */
    public int getParameterizingResult() {
        return nbParameterizingFailed;
    }
}
-------------------------------------------------------------------------------- /data/others/FAST+subskill_test0.txt: -------------------------------------------------------------------------------- 1 | student outcome KCs *features_AbstractMethodDefinition *features_ActualMethodParameter *features_AddExpression *features_ConstructorCall *features_DivideExpression *features_DoubleDataType *features_ImplementsSpecification *features_ImportStatement *features_IntDataType *features_InterfaceDefinition *features_MethodImplementation *features_MultiplyExpression *features_NotEqualExpression *features_NotExpression *features_ObjectCreationStatement *features_ObjectMethodInvocation *features_SimpleAssignmentExpression *features_StringDataType *features_StringInitializationStatement *features_StringLiteral *features_StringVariable *features_SubtractExpression *features_ThisReference *features_java.lang.String.equalsIgnoreCase *features_java.lang.String.length *features_java.lang.String.replace
*features_java.lang.System.out.print 2644 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2644 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2644 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2644 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 1 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 0 2644 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 1 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 0 2644 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 1 2644 incorrect Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2644 incorrect Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2644 incorrect Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2644 correct Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2644 correct Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2644 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 0 NULL NULL 0 NULL NULL NULL NULL 0 1 1 1 1 0 NULL NULL NULL NULL 1 2644 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 0 NULL NULL 0 NULL NULL NULL NULL 0 1 1 1 1 0 NULL NULL NULL NULL 1 2644 incorrect Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2644 correct Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2644 correct Objects NULL 1 NULL 1 NULL 1 NULL 1 NULL NULL NULL 
NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2644 correct Objects NULL 1 NULL 1 NULL 1 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2644 correct Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2644 correct Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2644 correct Objects NULL 0 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 0 NULL NULL NULL NULL NULL NULL NULL NULL 1 0 NULL 2644 correct Objects NULL 0 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 0 NULL NULL NULL NULL NULL NULL NULL NULL 1 0 NULL 2644 correct Objects NULL 1 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 1 NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL 
NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2991 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL -------------------------------------------------------------------------------- /src/main/java/fast/featurehmm/BaumWelchScaledLearner.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 
8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their codes that FAST is developed based on. 9 | * 10 | */ 11 | 12 | /* 13 | * This is built based on: 14 | * jahmm package - v0.6.1 15 | * Copyright (c) 2004-2006, Jean-Marc Francois. 16 | * 17 | */ 18 | 19 | package fast.featurehmm; 20 | 21 | //import hmm.ForwardBackwardScaledCalculator; 22 | import java.util.Arrays; 23 | import java.util.EnumSet; 24 | import java.util.Iterator; 25 | import java.util.List; 26 | import fast.data.DataPoint; 27 | //import fast.experiment.Opts; 28 | 29 | //import be.ac.ulg.montefiore.run.jahmm.ForwardBackwardCalculator; 30 | 31 | /** 32 | * An implementation of the Baum-Welch learning algorithm. It uses a scaling 33 | * mechanism so as to avoid underflows. 34 | *

* For more information on the scaling procedure, read Rabiner and
* Juang's Fundamentals of speech recognition (Prentice Hall,
* 1993).
*/
public class BaumWelchScaledLearner extends BaumWelchLearner {

    /**
     * Initializes a scaled Baum-Welch learner. Every argument is forwarded
     * unchanged to {@link BaumWelchLearner}.
     */
    public BaumWelchScaledLearner(boolean parameterizing, boolean parameterizedInit, boolean parameterizedTran,
            boolean parameterizedEmit, int nbHiddenStates,
            double EM_TOLERANCE, double EPS, double ACCETABLE_LL_DECREASE, int EM_MAX_ITERS) {
        super(parameterizing, parameterizedInit, parameterizedTran, parameterizedEmit, nbHiddenStates,
                EM_TOLERANCE, EPS, ACCETABLE_LL_DECREASE, EM_MAX_ITERS);
    }

    /**
     * Creates a {@link ForwardBackwardScaledCalculator} for {@code sequence}
     * with every {@link ForwardBackwardCalculator.Computation} enabled.
     */
    protected ForwardBackwardCalculator generateForwardBackwardCalculator(
            List<DataPoint> sequence, FeatureHMM hmm) {
        return new ForwardBackwardScaledCalculator(sequence, hmm,
                EnumSet.allOf(ForwardBackwardCalculator.Computation.class));
    }

    /**
     * Estimates xi[t][i][j] for t = 0 .. sequence.size() - 2.
     * <p>
     * Here, the xi (and, thus, gamma) values are not divided by the probability
     * of the sequence, because this probability might be too small and induce an
     * underflow. xi[t][i][j] can still be interpreted as
     * P[q_t = i and q_(t+1) = j | obsSeq, hmm], because the scaling factors are
     * assumed to be such that their product equals the inverse of the
     * probability of the sequence.
     * <p>
     * hy: the scaling factors are ctFactors, where each ctFactors[t] = P(O1...Ot).
     *
     * @param sequence the observation sequence; must contain at least 2 elements
     * @param fbc      forward-backward results for {@code sequence}
     * @param hmm      the model the results were computed with
     * @return xi, sized [sequence.size() - 1][hidden states][hidden states]
     * @throws IllegalArgumentException if the sequence has fewer than 2 observations
     */
    protected double[][][] estimateXi(List<DataPoint> sequence,
            ForwardBackwardCalculator fbc, FeatureHMM hmm) {
        if (sequence.size() <= 1)
            throw new IllegalArgumentException("Observation sequence too short");

        double[][][] xi = new double[sequence.size() - 1][hmm.getNbHiddenStates()][hmm.getNbHiddenStates()];

        Iterator<DataPoint> seqIterator = sequence.iterator();
        seqIterator.next(); // observation 0 enters only through alpha

        for (int t = 0; t < sequence.size() - 1; t++) {
            DataPoint nextO = seqIterator.next(); // observation t + 1
            for (int i = 0; i < hmm.getNbHiddenStates(); i++) {
                for (int j = 0; j < hmm.getNbHiddenStates(); j++) {
                    // TODO: hy changed nextO.getFeatures(j) to nextO.getFeatures(i)
                    // for the transition features (feature type 1); emission
                    // features (type 2) are taken for the destination state j.
                    xi[t][i][j] = fbc.alphaElement(t, i)
                            * hmm.getTransitionij(i, j, nextO.getFeatures(i, 1))
                            * hmm.getEmissionjk(j, nextO.getOutcome(), nextO.getFeatures(j, 2))
                            * fbc.betaElement(t + 1, j);
                }
            }
        }

        return xi;
    }

    /**
     * Estimates gamma[t][i] = alpha(t, i) * beta(t, i) * fbc.probability() for
     * every time step t and hidden state i.
     * <p>
     * @author hy, 11/16/13: alpha, beta and P(Dd) are used directly instead of xi
     * so that sequences of length 1 (for which estimateXi throws) are handled, to
     * be consistent with estimateXi.
     * NOTE(review): the original author was unsure whether the product should be
     * divided by P(Dd) rather than multiplied by it. This depends on how
     * ForwardBackwardScaledCalculator scales alpha, beta and probability();
     * confirm against that class.
     *
     * @param sequence the observation sequence
     * @param fbc      forward-backward results for {@code sequence}
     * @param hmm      the model the results were computed with
     * @return gamma, sized [sequence.size()][hidden states]
     */
    protected double[][] estimateGamma(List<DataPoint> sequence,
            ForwardBackwardCalculator fbc, FeatureHMM hmm) {
        // No explicit zero-fill needed: Java zero-initialises arrays, and every
        // cell is assigned below.
        double[][] gamma = new double[sequence.size()][hmm.getNbHiddenStates()];

        for (int t = 0; t < sequence.size(); t++) {
            for (int i = 0; i < hmm.getNbHiddenStates(); i++) {
                gamma[t][i] = fbc.alphaElement(t, i) * fbc.betaElement(t, i)
                        * fbc.probability();
            }
        }

        return gamma;
    }
}
-------------------------------------------------------------------------------- /data/others/FAST+subskill_train0.txt: -------------------------------------------------------------------------------- 1 | student outcome KCs *features_AbstractMethodDefinition *features_ActualMethodParameter *features_AddExpression *features_ConstructorCall *features_DivideExpression *features_DoubleDataType *features_ImplementsSpecification *features_ImportStatement *features_IntDataType *features_InterfaceDefinition *features_MethodImplementation *features_MultiplyExpression *features_NotEqualExpression *features_NotExpression *features_ObjectCreationStatement *features_ObjectMethodInvocation *features_SimpleAssignmentExpression *features_StringDataType *features_StringInitializationStatement *features_StringLiteral
*features_StringVariable *features_SubtractExpression *features_ThisReference *features_java.lang.String.equalsIgnoreCase *features_java.lang.String.length *features_java.lang.String.replace *features_java.lang.System.out.print 2647 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2647 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 1 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 0 2647 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 1 2647 incorrect Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2647 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 0 NULL NULL 0 NULL NULL NULL NULL 0 1 1 1 1 0 NULL NULL NULL NULL 1 2647 correct Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2647 incorrect Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2647 correct Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2647 correct Objects NULL 1 NULL 1 NULL 1 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2647 incorrect Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2647 incorrect Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2647 correct Objects NULL 1 NULL 1 NULL 0 NULL 1 NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 0 NULL 2647 incorrect Objects NULL 0 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 0 NULL NULL NULL NULL NULL NULL NULL NULL 1 0 NULL 2647 correct Objects NULL 0 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 0 NULL 
NULL NULL NULL NULL NULL NULL NULL 1 0 NULL 2647 incorrect Objects NULL 1 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 1 NULL 2647 correct Objects NULL 1 NULL 0 NULL 0 NULL 0 NULL NULL NULL NULL NULL NULL 0 1 NULL NULL NULL NULL NULL NULL NULL NULL 0 1 NULL 2808 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2808 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2808 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2808 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2808 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2808 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2808 correct Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 
NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2534 incorrect Interfaces 1 NULL NULL NULL NULL NULL 1 NULL NULL 1 1 NULL 1 1 NULL NULL NULL NULL NULL NULL NULL NULL 1 1 NULL NULL NULL 2769 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2769 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 1 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 0 2769 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 1 2769 correct Variables NULL NULL 0 NULL 1 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 1 NULL NULL NULL NULL 1 2769 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 0 NULL NULL 0 NULL NULL NULL NULL 0 1 1 1 1 0 NULL NULL NULL NULL 1 2768 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2768 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 2768 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 0 0 0 0 0 0 NULL NULL NULL NULL 1 2768 correct Variables NULL NULL 0 NULL 0 NULL NULL NULL 0 NULL NULL 0 NULL NULL NULL NULL 0 1 1 1 1 0 NULL NULL NULL NULL 1 2768 correct Variables NULL NULL 1 NULL 0 NULL NULL NULL 1 NULL NULL 0 NULL NULL NULL NULL 1 0 0 0 0 1 NULL NULL NULL NULL 0 -------------------------------------------------------------------------------- /data/IRT_exp/train0.csv: -------------------------------------------------------------------------------- 1 | KCs,student,problem,step,outcome,fold 
Nested_Loops,2515,jNested1,576,correct,1 Nested_Loops,2515,jNested2,577,correct,1 Nested_Loops,2515,jNested3,578,correct,1 Nested_Loops,2520,jNested1,787,correct,1 Nested_Loops,2520,jNested2,788,incorrect,1 Nested_Loops,2520,jNested3,805,incorrect,1 Nested_Loops,2522,jNested1,1222,correct,1 Nested_Loops,2522,jNested2,1227,incorrect,1 Nested_Loops,2522,jNested3,1236,incorrect,1 Nested_Loops,2528,jNested1,1481,incorrect,1 Nested_Loops,2528,jNested2,1484,incorrect,1 Nested_Loops,2606,jNested1,1990,correct,1 Nested_Loops,2606,jNested2,1991,incorrect,1 Nested_Loops,2641,jNested1,2892,correct,1 Nested_Loops,2641,jNested2,2896,incorrect,1 Nested_Loops,2641,jNested3,2909,incorrect,1 Nested_Loops,2642,jNested1,3150,correct,1 Nested_Loops,2642,jNested2,3151,incorrect,1 Nested_Loops,2642,jNested3,3158,incorrect,1 Nested_Loops,2647,jNested1,3709,correct,1 Nested_Loops,2648,jNested1,3996,correct,1 Nested_Loops,2648,jNested2,3997,incorrect,1 Nested_Loops,2648,jNested3,4002,incorrect,1 Nested_Loops,2649,jNested1,4180,correct,1 Nested_Loops,2650,jNested1,4371,correct,1 Nested_Loops,2652,jNested1,4519,correct,1 Nested_Loops,2652,jNested2,4523,incorrect,1 Nested_Loops,2652,jNested3,4529,incorrect,1 Nested_Loops,2653,jNested1,4791,correct,1 Nested_Loops,2653,jNested2,4792,correct,1 Nested_Loops,2653,jNested3,4797,incorrect,1 Nested_Loops,2655,jNested1,4929,incorrect,1 Nested_Loops,2655,jNested2,4931,incorrect,1 Nested_Loops,2655,jNested3,4937,incorrect,1 Nested_Loops,2664,jNested1,5524,incorrect,1 Nested_Loops,2664,jNested2,5528,incorrect,1 Nested_Loops,2664,jNested3,5536,incorrect,1 Nested_Loops,2665,jNested1,5768,correct,1 Nested_Loops,2665,jNested2,5769,incorrect,1 Nested_Loops,2666,jNested1,5821,correct,1 Nested_Loops,2666,jNested2,5822,incorrect,1 Nested_Loops,2667,jNested1,6045,correct,1 Nested_Loops,2668,jNested1,6333,correct,1 Nested_Loops,2668,jNested2,6339,incorrect,1 Nested_Loops,2668,jNested3,6346,incorrect,1 Nested_Loops,2669,jNested1,6736,correct,1 
Nested_Loops,2669,jNested2,6737,correct,1 Nested_Loops,2671,jNested1,6949,correct,1 Nested_Loops,2671,jNested3,6952,incorrect,1 Nested_Loops,2674,jNested1,7321,correct,1 Nested_Loops,2674,jNested2,7324,incorrect,1 Nested_Loops,2674,jNested3,7330,incorrect,1 Nested_Loops,2675,jNested1,7635,incorrect,1 Nested_Loops,2676,jNested1,7733,incorrect,1 Nested_Loops,2676,jNested2,7735,correct,1 Nested_Loops,2676,jNested3,7736,correct,1 Nested_Loops,2759,jNested1,7849,correct,1 Nested_Loops,2759,jNested2,7850,incorrect,1 Nested_Loops,2759,jNested3,7852,incorrect,1 Nested_Loops,2761,jNested1,8213,correct,1 Nested_Loops,2761,jNested2,8214,incorrect,1 Nested_Loops,2763,jNested1,8621,correct,1 Nested_Loops,2765,jNested1,8947,correct,1 Nested_Loops,2765,jNested2,8950,incorrect,1 Nested_Loops,2765,jNested3,8959,incorrect,1 Nested_Loops,2766,jNested1,9266,correct,1 Nested_Loops,2766,jNested2,9267,correct,1 Nested_Loops,2772,jNested1,9691,incorrect,1 Nested_Loops,2772,jNested2,9694,incorrect,1 Nested_Loops,2772,jNested3,9702,incorrect,1 Nested_Loops,2773,jNested1,9980,incorrect,1 Nested_Loops,2773,jNested2,9982,incorrect,1 Nested_Loops,2773,jNested3,9986,incorrect,1 Nested_Loops,2774,jNested1,10198,correct,1 Nested_Loops,2774,jNested2,10199,correct,1 Nested_Loops,2774,jNested3,10200,incorrect,1 Nested_Loops,2777,jNested1,10360,correct,1 Nested_Loops,2777,jNested2,10365,incorrect,1 Nested_Loops,2777,jNested3,10370,incorrect,1 Nested_Loops,2780,jNested1,11056,correct,1 Nested_Loops,2780,jNested2,11058,incorrect,1 Nested_Loops,2780,jNested3,11066,incorrect,1 Nested_Loops,2785,jNested1,11769,correct,1 Nested_Loops,2785,jNested2,11771,incorrect,1 Nested_Loops,2785,jNested3,11775,correct,1 Nested_Loops,2787,jNested1,12071,incorrect,1 Nested_Loops,2787,jNested2,12073,correct,1 Nested_Loops,2791,jNested1,12575,correct,1 Nested_Loops,2791,jNested2,12578,correct,1 Nested_Loops,2791,jNested3,12587,incorrect,1 Nested_Loops,2792,jNested1,12832,incorrect,1 
Nested_Loops,2792,jNested2,12835,incorrect,1 Nested_Loops,2805,jNested1,13971,correct,1 Nested_Loops,2805,jNested2,13972,incorrect,1 Nested_Loops,2805,jNested3,13977,correct,1 Nested_Loops,2806,jNested1,14358,correct,1 Nested_Loops,2806,jNested2,14361,incorrect,1 Nested_Loops,2806,jNested3,14370,incorrect,1 Nested_Loops,2807,jNested1,14804,correct,1 Nested_Loops,2807,jNested2,14805,incorrect,1 Nested_Loops,2808,jNested1,14822,correct,1 Nested_Loops,2808,jNested2,14823,incorrect,1 Nested_Loops,2808,jNested3,14826,incorrect,1 Nested_Loops,2812,jNested1,15488,correct,1 Nested_Loops,2812,jNested2,15489,incorrect,1 Nested_Loops,2986,jNested1,15756,correct,1 Nested_Loops,2986,jNested2,15757,incorrect,1 Nested_Loops,2989,jNested1,16159,correct,1 Nested_Loops,2989,jNested2,16167,incorrect,1 Nested_Loops,2989,jNested3,16185,incorrect,1 Nested_Loops,2991,jNested1,16585,correct,1 Nested_Loops,2991,jNested2,16592,incorrect,1 Nested_Loops,2991,jNested3,16600,incorrect,1 Nested_Loops,2992,jNested1,17033,correct,1 Nested_Loops,2992,jNested2,17034,incorrect,1 Nested_Loops,2996,jNested1,17266,correct,1 Nested_Loops,2996,jNested2,17268,incorrect,1 Nested_Loops,2996,jNested3,17270,incorrect,1 Nested_Loops,2997,jNested1,17680,incorrect,1 Nested_Loops,2997,jNested2,17684,incorrect,1 Nested_Loops,2997,jNested3,17698,incorrect,1 Nested_Loops,2999,jNested1,17924,correct,1 Nested_Loops,2999,jNested2,17926,incorrect,1 Nested_Loops,3000,jNested1,18071,correct,1 Nested_Loops,3000,jNested2,18072,incorrect,1 Nested_Loops,3003,jNested1,18374,correct,1 Nested_Loops,3003,jNested2,18375,incorrect,1 Nested_Loops,3006,jNested1,18476,correct,1 Nested_Loops,3006,jNested2,18477,incorrect,1 Nested_Loops,3007,jNested1,18924,incorrect,1 Nested_Loops,3007,jNested2,18930,incorrect,1 Nested_Loops,3011,jNested1,19372,correct,1 Nested_Loops,3011,jNested2,19375,incorrect,1 Nested_Loops,3012,jNested1,19533,correct,1 Nested_Loops,3012,jNested2,19534,incorrect,1 Nested_Loops,3012,jNested3,19537,incorrect,1 
Nested_Loops,3018,jNested1,20011,correct,1 Nested_Loops,3018,jNested2,20014,correct,1 Nested_Loops,3019,jNested1,20298,correct,1 Nested_Loops,3019,jNested2,20300,incorrect,1 Nested_Loops,3020,jNested1,20577,incorrect,1 Nested_Loops,3020,jNested2,20580,incorrect,1 Nested_Loops,3021,jNested1,20765,correct,1 Nested_Loops,3022,jNested1,20858,correct,1 Nested_Loops,3086,jNested1,21132,correct,1 Nested_Loops,3086,jNested2,21133,incorrect,1 -------------------------------------------------------------------------------- /src/main/java/fig/basic/T2VMap.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import fig.basic.*; 4 | import static fig.basic.LogInfo.*; 5 | import java.io.*; 6 | import java.util.*; 7 | 8 | /** 9 | * Maps (object, object) pairs to objects. 10 | * Based on T2VMap. 11 | * It's useful when the number of second objects for a fixed first string 12 | * is small. 13 | * Most of the operations in this class parallel that of T2VMap, 14 | * but just applied to two keys. The implementation is essentially dispatching 15 | * down to T2VMap. 16 | * Typical usage: conditional probability table. 
17 | */ 18 | public class T2VMap extends AbstractT2Map implements Iterable>>, Serializable { 19 | protected static final long serialVersionUID = 42; 20 | 21 | public T2VMap() { 22 | this.keyFunc = AbstractTMap.defaultFunctionality; 23 | this.valueFunc = AbstractTMap.defaultFunctionality; 24 | } 25 | 26 | public T2VMap(AbstractTMap.Functionality keyFunc, AbstractTMap.Functionality valueFunc) { 27 | this.keyFunc = keyFunc; 28 | this.valueFunc = valueFunc; 29 | } 30 | 31 | public void initKeys(AbstractT2Map map) { 32 | this.locked = map.locked; 33 | 34 | // HACK: BAD dependencies 35 | if(map instanceof T2DoubleMap) { 36 | for(Map.Entry> e : (T2DoubleMap)map) 37 | put(e.getKey(), new TVMap(e.getValue(), valueFunc)); 38 | } 39 | else if(map instanceof T2VMap) { // Not exactly right: need to check type of V 40 | for(Map.Entry> e : ((T2VMap)map)) 41 | put(e.getKey(), new TVMap(e.getValue(), valueFunc)); 42 | } 43 | else 44 | throw new RuntimeException(""); 45 | } 46 | 47 | // Main operations 48 | public boolean containsKey(S key1, T key2) { 49 | TVMap map = getMap(key1, false); 50 | return map != null && map.containsKey(key2); 51 | } 52 | public V get(S key1, T key2, V defaultValue) { 53 | TVMap map = getMap(key1, false); 54 | return map == null ? defaultValue : map.get(key2, defaultValue); 55 | } 56 | public V getWithErrorMsg(S key1, T key2, V defaultValue) { 57 | TVMap map = getMap(key1, false); 58 | if(map == null) errors("(%s, %s) not in map, using %f", key1, key2, defaultValue); 59 | return map == null ? defaultValue : map.get(key2, defaultValue); 60 | } 61 | public V getSure(S key1, T key2) { 62 | // Throw exception if key doesn't exist. 
63 | TVMap map = getMap(key1, false); 64 | if(map == null) throw new RuntimeException("Missing key: " + key1); 65 | return map.getSure(key2); 66 | } 67 | public void put(S key1, TVMap map) { // Risky 68 | if(locked) 69 | throw new RuntimeException("Cannot make new entry for " + key1 + ", because map is locked"); 70 | maps.put(key1, map); 71 | } 72 | public void put(S key1, T key2, V value) { 73 | TVMap map = getMap(key1, true); 74 | map.put(key2, value); 75 | } 76 | public int size() { return maps.size(); } 77 | // Return number of entries 78 | public int totalSize() { 79 | int n = 0; 80 | for(TVMap map : maps.values()) 81 | n += map.size(); 82 | return n; 83 | } 84 | public void gut() { 85 | for(TVMap map : maps.values()) 86 | map.gut(); 87 | } 88 | 89 | public Iterator>> iterator() { 90 | return maps.entrySet().iterator(); 91 | } 92 | public Set>> entrySet() { return maps.entrySet(); } 93 | public Set keySet() { return maps.keySet(); } 94 | public Collection> values() { return maps.values(); } 95 | 96 | // If keys are locked, we can share the same keys. 
97 | public T2VMap copy() { 98 | return copy(newMap()); 99 | } 100 | public T2VMap copy(T2VMap newMap) { 101 | newMap.locked = locked; 102 | for(Map.Entry> e : maps.entrySet()) 103 | newMap.maps.put(e.getKey(), e.getValue().copy()); 104 | return newMap; 105 | } 106 | public T2VMap restrict(Set set1, Set set2) { 107 | return restrict(newMap(), set1, set2); 108 | } 109 | public T2VMap restrict(T2VMap newMap, Set set1, Set set2) { 110 | newMap.locked = locked; 111 | for(Map.Entry> e : maps.entrySet()) 112 | if(set1.contains(e.getKey())) 113 | newMap.maps.put(e.getKey(), e.getValue().restrict(set2)); 114 | return newMap; 115 | } 116 | public T2VMap reverse(T2VMap newMap) { // Return a map with (key2, key1) pairs 117 | for(Map.Entry> e1 : maps.entrySet()) { 118 | S key1 = e1.getKey(); 119 | TVMap map = e1.getValue(); 120 | for(TVMap.Entry e2 : map) { 121 | T key2 = e2.getKey(); 122 | V value = e2.getValue(); 123 | newMap.put(key2, key1, value); 124 | } 125 | } 126 | return newMap; 127 | } 128 | 129 | public void lock() { 130 | for(TVMap map : maps.values()) 131 | map.lock(); 132 | } 133 | public void switchToSortedList() { 134 | for(TVMap map : maps.values()) 135 | map.switchToSortedList(); 136 | } 137 | public void switchToHashTable() { 138 | for(TVMap map : maps.values()) 139 | map.switchToHashTable(); 140 | } 141 | 142 | protected T2VMap newMap() { return new T2VMap(keyFunc, valueFunc); } 143 | 144 | //////////////////////////////////////////////////////////// 145 | 146 | public TVMap getMap(S key1, boolean modify) { 147 | if(key1 == lastKey) return lastMap; 148 | 149 | TVMap map = maps.get(key1); 150 | if(map != null) return map; 151 | if(modify) { 152 | if(locked) 153 | throw new RuntimeException("Cannot make new entry for " + key1 + ", because map is locked"); 154 | maps.put(key1, map = new TVMap(keyFunc, valueFunc)); 155 | 156 | lastKey = key1; 157 | lastMap = map; 158 | return map; 159 | } 160 | else 161 | return null; 162 | } 163 | 164 | 
//////////////////////////////////////////////////////////// 165 | 166 | private Map> maps = new HashMap>(); 167 | private S lastKey; // Cache last access 168 | private TVMap lastMap; // Cache last access 169 | protected TVMap.Functionality valueFunc; 170 | } 171 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/T2DoubleMap.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import fig.basic.*; 4 | import static fig.basic.LogInfo.*; 5 | import java.io.*; 6 | import java.util.*; 7 | 8 | /** Maps (object, object) pairs to doubles. 9 | * Based on TDoubleMap. 10 | * It's useful when the number of second objects for a fixed first string 11 | * is small. 12 | * Most of the operations in this class parallel that of TDoubleMap, 13 | * but just applied to two keys. The implementation is essentially dispatching 14 | * down to TDoubleMap. 15 | * Typical usage: conditional probability table. 
16 | */ 17 | public class T2DoubleMap extends AbstractT2Map implements Iterable>>, Serializable { 18 | protected static final long serialVersionUID = 42; 19 | 20 | public T2DoubleMap() { 21 | this.keyFunc = AbstractTMap.defaultFunctionality; 22 | } 23 | 24 | public T2DoubleMap(AbstractTMap.Functionality keyFunc) { 25 | this.keyFunc = keyFunc; 26 | } 27 | 28 | public void initKeys(AbstractT2Map map) { 29 | this.locked = map.locked; 30 | 31 | // HACK: BAD dependencies 32 | if(map instanceof T2DoubleMap) { 33 | for(Map.Entry> e : (T2DoubleMap)map) 34 | put(e.getKey(), new TDoubleMap(e.getValue())); 35 | } 36 | else if(map instanceof T2VMap) { // Not exactly right: need to check type of V 37 | for(Map.Entry> e : ((T2VMap)map)) 38 | put(e.getKey(), new TDoubleMap(e.getValue())); 39 | } 40 | else 41 | throw new RuntimeException(""); 42 | } 43 | 44 | // Main operations 45 | public boolean containsKey(S key1, T key2) { 46 | TDoubleMap map = getMap(key1, false); 47 | return map != null && map.containsKey(key2); 48 | } 49 | public double get(S key1, T key2, double defaultValue) { 50 | TDoubleMap map = getMap(key1, false); 51 | return map == null ? defaultValue : map.get(key2, defaultValue); 52 | } 53 | public double getWithErrorMsg(S key1, T key2, double defaultValue) { 54 | TDoubleMap map = getMap(key1, false); 55 | if(map == null) errors("(%s, %s) not in map, using %f", key1, key2, defaultValue); 56 | return map == null ? defaultValue : map.get(key2, defaultValue); 57 | } 58 | public double getSure(S key1, T key2) { 59 | // Throw exception if key doesn't exist. 
60 | TDoubleMap map = getMap(key1, false); 61 | if(map == null) throw new RuntimeException("Missing key: " + key1); 62 | return map.getSure(key2); 63 | } 64 | public void put(S key1, TDoubleMap map) { // Risky 65 | if(locked) 66 | throw new RuntimeException("Cannot make new entry for " + key1 + ", because map is locked"); 67 | maps.put(key1, map); 68 | } 69 | public void put(S key1, T key2, double value) { 70 | TDoubleMap map = getMap(key1, true); 71 | map.put(key2, value); 72 | } 73 | public void incr(S key1, T key2, double dValue) { 74 | TDoubleMap map = getMap(key1, true); 75 | map.incr(key2, dValue); 76 | } 77 | public int size() { return maps.size(); } 78 | // Return number of entries 79 | public int totalSize() { 80 | int n = 0; 81 | for(TDoubleMap map : maps.values()) 82 | n += map.size(); 83 | return n; 84 | } 85 | public void gut() { 86 | for(TDoubleMap map : maps.values()) 87 | map.gut(); 88 | } 89 | 90 | public Iterator>> iterator() { 91 | return maps.entrySet().iterator(); 92 | } 93 | public Set>> entrySet() { return maps.entrySet(); } 94 | public Set keySet() { return maps.keySet(); } 95 | public Collection> values() { return maps.values(); } 96 | 97 | // If keys are locked, we can share the same keys. 
98 | public T2DoubleMap copy() { 99 | return copy(newMap()); 100 | } 101 | public T2DoubleMap copy(T2DoubleMap newMap) { 102 | newMap.locked = locked; 103 | for(Map.Entry> e : maps.entrySet()) 104 | newMap.maps.put(e.getKey(), e.getValue().copy()); 105 | return newMap; 106 | } 107 | public T2DoubleMap restrict(Set set1, Set set2) { 108 | return restrict(newMap(), set1, set2); 109 | } 110 | public T2DoubleMap restrict(T2DoubleMap newMap, Set set1, Set set2) { 111 | newMap.locked = locked; 112 | for(Map.Entry> e : maps.entrySet()) 113 | if(set1.contains(e.getKey())) 114 | newMap.maps.put(e.getKey(), e.getValue().restrict(set2)); 115 | return newMap; 116 | } 117 | public T2DoubleMap reverse(T2DoubleMap newMap) { // Return a map with (key2, key1) pairs 118 | for(Map.Entry> e1 : maps.entrySet()) { 119 | S key1 = e1.getKey(); 120 | TDoubleMap map = e1.getValue(); 121 | for(TDoubleMap.Entry e2 : map) { 122 | T key2 = e2.getKey(); 123 | double value = e2.getValue(); 124 | newMap.put(key2, key1, value); 125 | } 126 | } 127 | return newMap; 128 | } 129 | 130 | public void lock() { 131 | for(TDoubleMap map : maps.values()) 132 | map.lock(); 133 | } 134 | public void switchToSortedList() { 135 | for(TDoubleMap map : maps.values()) 136 | map.switchToSortedList(); 137 | } 138 | public void switchToHashTable() { 139 | for(TDoubleMap map : maps.values()) 140 | map.switchToHashTable(); 141 | } 142 | 143 | protected T2DoubleMap newMap() { return new T2DoubleMap(keyFunc); } 144 | 145 | //////////////////////////////////////////////////////////// 146 | 147 | public TDoubleMap getMap(S key1, boolean modify) { 148 | if(key1 == lastKey) return lastMap; 149 | 150 | TDoubleMap map = maps.get(key1); 151 | if(map != null) return map; 152 | if(modify) { 153 | if(locked) 154 | throw new RuntimeException("Cannot make new entry for " + key1 + ", because map is locked"); 155 | maps.put(key1, map = new TDoubleMap(keyFunc)); 156 | 157 | lastKey = key1; 158 | lastMap = map; 159 | return map; 160 | } 
161 | else 162 | return null; 163 | } 164 | 165 | //////////////////////////////////////////////////////////// 166 | 167 | private Map> maps = new HashMap>(); 168 | private S lastKey; // Cache last access 169 | private TDoubleMap lastMap; // Cache last access 170 | } 171 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/MapUtils.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.util.*; 4 | 5 | public class MapUtils { 6 | // One-level hash maps 7 | public static boolean contains(Map map, S key) { 8 | return map != null && map.containsKey(key); 9 | } 10 | public static T get(Map map, S key, T defaultVal) { 11 | return map == null || !map.containsKey(key) ? defaultVal : map.get(key); 12 | } 13 | public static T getMut(Map map, S key, T defaultVal) { 14 | if(!map.containsKey(key)) { 15 | map.put(key, defaultVal); // Mutate 16 | return defaultVal; 17 | } 18 | return map.get(key); 19 | } 20 | public static boolean putIfAbsent(Map map, S key, T val) { 21 | if (map.containsKey(key)) return false; 22 | map.put(key, val); 23 | return true; 24 | } 25 | public static void set(Map map, S key, T val) { 26 | map.put(key, val); 27 | } 28 | public static void incr(Map map, S key, int dVal) { 29 | if(!map.containsKey(key)) map.put(key, dVal); 30 | else map.put(key, map.get(key) + dVal); 31 | } 32 | public static void incr(Map map, S key) { 33 | incr(map, key, 1); 34 | } 35 | public static void incr(Map map, S key, double dVal) { 36 | if(!map.containsKey(key)) map.put(key, dVal); 37 | else map.put(key, map.get(key) + dVal); 38 | } 39 | 40 | // Two-level hash maps 41 | public static boolean contains(Map> map, S key1, T key2) { 42 | if(map == null) return false; 43 | Map m = map.get(key1); 44 | return m != null && m.containsKey(key2); 45 | } 46 | public static U get(Map> map, S key1, T key2, U defaultVal) { 47 | if(map == null || 
!map.containsKey(key1)) return defaultVal; 48 | Map m = map.get(key1); 49 | return m == null || !m.containsKey(key2) ? defaultVal : m.get(key2); 50 | } 51 | public static U getMut(Map> map, S key1, T key2, U defaultVal) { 52 | Map m = map.get(key1); 53 | if(m == null) { 54 | map.put(key1, m = new HashMap()); 55 | m.put(key2, defaultVal); 56 | return defaultVal; 57 | } 58 | else if(!m.containsKey(key2)) { 59 | m.put(key2, defaultVal); 60 | return defaultVal; 61 | } 62 | return m.get(key2); 63 | } 64 | public static void add(Map> map, S key1, T key2) { 65 | Set s = map.get(key1); 66 | if(s == null) map.put(key1, s = new HashSet()); 67 | s.add(key2); 68 | } 69 | public static void set(Map> map, S key1, T key2, U val) { 70 | Map m = map.get(key1); 71 | if(m == null) map.put(key1, m = new HashMap()); 72 | m.put(key2, val); 73 | } 74 | public static void incr(Map> map, S key1, T key2, int dVal) { 75 | Map m = map.get(key1); 76 | if(m == null) { 77 | map.put(key1, m = new HashMap()); 78 | m.put(key2, dVal); 79 | } 80 | else if(!m.containsKey(key2)) 81 | m.put(key2, dVal); 82 | else 83 | m.put(key2, m.get(key2) + dVal); 84 | } 85 | public static void incr(Map> map, S key1, T key2) { 86 | incr(map, key1, key2, 1); 87 | } 88 | public static void incr(Map> map, S key1, T key2, double dVal) { 89 | Map m = map.get(key1); 90 | if(m == null) { 91 | map.put(key1, m = new HashMap()); 92 | m.put(key2, dVal); 93 | } 94 | else if(!m.containsKey(key2)) 95 | m.put(key2, dVal); 96 | else 97 | m.put(key2, m.get(key2) + dVal); 98 | } 99 | 100 | // Create a list if it doesn't exist 101 | public static List getListMut(Map> map, S key) { 102 | List list = map.get(key); 103 | if(list == null) 104 | map.put(key, list = new ArrayList()); 105 | return list; 106 | } 107 | 108 | // Hard operations 109 | // Wrapper for operations on maps and sets 110 | public static T getHard(Map map, S key) { 111 | T value = map.get(key); 112 | if(value == null) throw new RuntimeException("Doesn't contain key: " + 
key); 113 | return value; 114 | } 115 | public static void putHard(Map map, S key, T value) { 116 | if(map.containsKey(key)) throw new RuntimeException("Already contains key; " + key); 117 | map.put(key, value); 118 | } 119 | public static T removeHard(Map map, S key) { 120 | T value = map.remove(key); 121 | if(value == null) throw new RuntimeException("Doesn't contain key"); 122 | return value; 123 | } 124 | public static void addHard(Set set, S key) { 125 | if(set.contains(key)) throw new RuntimeException("Already contains key"); 126 | set.add(key); 127 | } 128 | public static void removeHard(Set set, S key) { 129 | if(!set.remove(key)) throw new RuntimeException("Doesn't contain key"); 130 | } 131 | 132 | // Print out the top k values a hash table sorted by descending value 133 | // Should only take O(k \log n) time, 134 | // but right now the implementation is slow 135 | public static PriorityQueue toPriorityQueue(Map map) { 136 | PriorityQueue pq = new PriorityQueue(); 137 | for(Map.Entry e : map.entrySet()) 138 | pq.add(e.getKey(), e.getValue()); 139 | return pq; 140 | } 141 | public static PriorityQueue toPriorityQueue(TDoubleMap map) { 142 | PriorityQueue pq = new PriorityQueue(); 143 | for(TDoubleMap.Entry e : map) 144 | pq.add(e.getKey(), e.getValue()); 145 | return pq; 146 | } 147 | public static String topNToString(TDoubleMap map, int numTop) { 148 | return topNToString(toPriorityQueue(map), numTop); 149 | } 150 | public static String topNToString(Map map, int numTop) { 151 | return topNToString(toPriorityQueue(map), numTop); 152 | } 153 | public static String topNToString(PriorityQueue pq, int numTop) { 154 | StringBuilder sb = new StringBuilder(); 155 | sb.append('{'); 156 | for(Pair pair : getTopN(pq, numTop)) { 157 | Object key = pair.getFirst(); 158 | double value = pair.getSecond(); 159 | sb.append(' '); 160 | sb.append(key); 161 | sb.append(':'); 162 | sb.append(Fmt.D(value)); 163 | } 164 | if(pq.size() > numTop) 165 | sb.append(" 
...("+(pq.size()-numTop)+ " more)"); 166 | sb.append(" }"); 167 | return sb.toString(); 168 | } 169 | // Return a list of the top n elements in the following structures 170 | public static List> getTopN(Map map, int n) { 171 | return getTopN(toPriorityQueue(map), n); 172 | } 173 | public static List> getTopN(TDoubleMap map, int n) { 174 | return getTopN(toPriorityQueue(map), n); 175 | } 176 | public static List> getTopN(PriorityQueue pq, int n) { 177 | List> list = new ArrayList>(); 178 | for(int i = 0; i < n && pq.hasNext(); i++) { 179 | double priority = pq.getPriority(); 180 | T element = pq.next(); 181 | list.add(new Pair(element, priority)); 182 | } 183 | return list; 184 | } 185 | 186 | public static Map compose(Map m1, Map m2, Map mapToFill) { 187 | for (Map.Entry entry: m1.entrySet()) { 188 | V val = m2.get(entry.getValue()); 189 | if (val != null) 190 | mapToFill.put(entry.getKey(), val); 191 | } 192 | return mapToFill; 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /src/main/java/fast/evaluation/PredictivePerformance.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their codes that FAST is developed based on. 
9 | * 10 | */ 11 | 12 | package fast.evaluation; 13 | 14 | //import java.text.DecimalFormat; 15 | import java.util.ArrayList; 16 | 17 | public class PredictivePerformance { 18 | 19 | 20 | // private String splitter = "[,\t]"; 21 | // private String delim = ","; 22 | // confusion matrix's "positive" corresponds to majorityName 23 | private String label1Name = "correct"; 24 | private String label0Name = "incorrect"; 25 | private String majorityName = "correct"; 26 | //private DecimalFormat formatter = null; 27 | //private String minorityName = "incorrect"; 28 | 29 | public PredictivePerformance(){ 30 | } 31 | 32 | public PredictivePerformance(String majorityName, String minorityName, String label1Name, String label0Name){ 33 | this.majorityName = majorityName; 34 | //this.minorityName = minorityName; 35 | this.label1Name = label1Name; 36 | this.label0Name = label0Name; 37 | if (!majorityName.equals(label1Name) && !majorityName.equals(label0Name)) 38 | System.out.println("ERROR: majorityName string and actual label name mismatch!"); 39 | if (!minorityName.equals(label1Name) && !minorityName.equals(label0Name)) 40 | System.out.println("ERROR: minorityName string and actual label name mismatch!"); 41 | } 42 | 43 | // public void setFormatter(DecimalFormat formatter){ 44 | // this.formatter = formatter; 45 | // } 46 | 47 | public Metrics evaluateClassifier(ArrayList actualLabels, ArrayList predictLabels, ArrayList predictProbs, 48 | String name){//, BufferedWriter writer, boolean writeHeader){ 49 | Double nbObs = 0.0; 50 | Double majAUC = 0.0; 51 | Double LL = 0.0; 52 | Double accuracy = 0.0; 53 | Double rmse = 0.0; 54 | Double majFmeasure = 0.0, minFmeasure = 0.0; 55 | Double majPrecision = 0.0, minPrecision = 0.0; 56 | Double majRecall = 0.0, minRecall = 0.0; 57 | Double TP = 0.0, TN = 0.0, FP = 0.0, FN = 0.0; 58 | 59 | int nbAccurate = 0; 60 | Double squaredError = 0.0; 61 | //double totalNbInstances = 0.0; 62 | 63 | if ((actualLabels.size() != predictLabels.size()) 
|| (actualLabels.size() != predictLabels.size()) || (predictLabels.size() != predictProbs.size())) { 64 | System.out.println("ERROR: actualLabel, predictLabel and predictProbs size mismatch!"); 65 | System.exit(1); 66 | } 67 | if (actualLabels.size() == 1) { 68 | System.out.println("WARNING: actualLabel size=1!"); 69 | // System.exit(1); 70 | } 71 | 72 | for (int insId = 0; insId < actualLabels.size(); insId++) { 73 | Integer actualLabel = actualLabels.get(insId); 74 | Integer predictLabel = predictLabels.get(insId); 75 | Double predictProb = predictProbs.get(insId); 76 | //System.out.println(actualLabel + delim + predictLabel + delim + predictProb); 77 | 78 | nbObs++; 79 | 80 | if (predictProb > 1.0 || predictProb < 0.0) { 81 | System.out.println("Error:predictProb > 1.0 || predictProb < 0.0! insId=" + insId); 82 | System.exit(1); 83 | } 84 | 85 | majAUC = getAUC(actualLabels, predictProbs); 86 | if (Double.isNaN(majAUC)) 87 | System.out.println("WARNING: AUC=" + majAUC); 88 | // minAUC = -1; 89 | // Not sure about the following method is correct (a little bit different 90 | // from weka...) 
91 | // ArrayList inverseActualLabels = new ArrayList(); 92 | // ArrayList inversePredictProbs = new ArrayList(); 93 | // for (int ii = 0; ii < actualLabels.size(); ii++) { 94 | // int label = actualLabels.get(ii); 95 | // double prob = predictProbs.get(ii); 96 | // inverseActualLabels.add(1 - label); 97 | // inversePredictProbs.add(1.0 - prob); 98 | // } 99 | // minAUC[foldRunID] = getAUC(inverseActualLabels, inversePredictProbs); 100 | 101 | Double curLL = Double.NaN; 102 | if (actualLabel == 1){// (predictProb >= 0.5) 103 | if (predictProb > 0.0) 104 | curLL = Math.log10(predictProb); 105 | } 106 | else{ 107 | if (predictProb < 1.0) 108 | curLL = Math.log10(1.0 - predictProb); 109 | } 110 | LL += curLL; 111 | 112 | if (predictLabel == actualLabel) { 113 | nbAccurate += 1; 114 | if ((predictLabel == 1 && label1Name.equals(majorityName)) || (predictLabel == 0 && label0Name.equals(majorityName))) 115 | TP += 1; 116 | else 117 | TN += 1; 118 | } 119 | else { 120 | if ((predictLabel == 1 && label1Name.equals(majorityName)) || (predictLabel == 0 && label0Name.equals(majorityName))) 121 | FP += 1; 122 | else 123 | FN += 1; 124 | } 125 | 126 | squaredError += Math.pow(actualLabel - predictProb, 2); 127 | if (Double.isNaN(squaredError)) { 128 | System.out.println("Error: squaredError is NaN"); 129 | System.exit(1); 130 | } 131 | } 132 | 133 | accuracy = (1.0 * nbAccurate) / nbObs; 134 | rmse = Math.sqrt(squaredError / nbObs); 135 | if (rmse == 0.0) 136 | System.out.println("WARNING: RMSE=0"); 137 | majPrecision = (TP + FP == 0.0) ? 0.0 : (1.0 * TP) / (TP + FP); 138 | minPrecision = (TN + FN == 0.0) ? 0.0 : (1.0 * TN) / (TN + FN); 139 | majRecall = (TP + FN == 0.0) ? 0.0 : (1.0 * TP) / (TP + FN); 140 | minRecall = (TN + FP == 0.0) ? 0.0 : (1.0 * TN) / (TN + FP); 141 | double denominator = majPrecision + majRecall; 142 | majFmeasure = (denominator == 0.0) ? 
0.0 : (2 * majPrecision * majRecall) / (majPrecision + majRecall); 143 | denominator = minPrecision + minRecall; 144 | minFmeasure = (denominator == 0.0) ? 0.0 : (2 * minPrecision * minRecall) / (minPrecision + minRecall); 145 | double meanLLPerObs = LL/nbObs; 146 | 147 | Metrics eval = new Metrics(name); 148 | //Metrics.setFormatter(formatter); 149 | eval.setMetricValue("NbObs(test)", nbObs); 150 | eval.setMetricValue("AUC", majAUC); 151 | eval.setMetricValue("LogLikelihood_base10", LL); 152 | eval.setMetricValue("MeanLLPerObs",meanLLPerObs); 153 | eval.setMetricValue("Accuracy", accuracy); 154 | eval.setMetricValue("RMSE", rmse); 155 | eval.setMetricValue("MajFmeasure", majFmeasure); 156 | eval.setMetricValue("MinFmeasure", minFmeasure); 157 | eval.setMetricValue("MajPrecision", majPrecision); 158 | eval.setMetricValue("MinPrecision", minPrecision); 159 | eval.setMetricValue("MajRecall", majRecall); 160 | eval.setMetricValue("MinRecall", minRecall); 161 | eval.setMetricValue("TP", TP); 162 | eval.setMetricValue("TN", TN); 163 | eval.setMetricValue("FP", FP); 164 | eval.setMetricValue("FN", FN); 165 | 166 | return eval; 167 | } 168 | 169 | public static double getAUC(ArrayList actualLabels, ArrayList predictProbs) { 170 | double[] actualLabelsArray = new double[actualLabels.size()]; 171 | double[] predictProbsArray = new double[actualLabels.size()]; 172 | for (int ii = 0; ii < actualLabels.size(); ii++) { 173 | actualLabelsArray[ii] = actualLabels.get(ii) * 1.0; 174 | predictProbsArray[ii] = predictProbs.get(ii); 175 | } 176 | Sample data = new Sample(actualLabelsArray); 177 | AUC aucCalculator = new AUC(); 178 | double auc = aucCalculator.measure(predictProbsArray, data); 179 | return auc; 180 | } 181 | 182 | } 183 | -------------------------------------------------------------------------------- /src/main/java/fast/common/Stats.java: -------------------------------------------------------------------------------- 1 | package fast.common; 2 | //import 
java.text.NumberFormat; 3 | import java.util.ArrayList; 4 | //import java.util.List; 5 | 6 | 7 | public final class Stats 8 | { 9 | 10 | public static class ValueIndexSummary{ 11 | public ArrayList values; 12 | public ArrayList indexes; 13 | 14 | public ValueIndexSummary(ArrayList values, ArrayList indexes){ 15 | this.values = values; 16 | this.indexes = indexes; 17 | } 18 | 19 | public String toString(){//NumberFormat nf) { 20 | String str = ""; 21 | String delimiter = ","; 22 | if (values != null && indexes != null){ 23 | for (int i = 0; i < values.size(); i++) { 24 | double value = values.get(i); 25 | int index = indexes.get(i); 26 | str += value + "(" + index + ")" + delimiter; 27 | } 28 | str = str.substring(0, str.length() - delimiter.length()); 29 | } 30 | return str; 31 | } 32 | } 33 | 34 | public static Double min (ArrayList a) 35 | { 36 | if (a == null || a.size() == 0) 37 | return null; 38 | Double min = null; 39 | if (a.size() > 0) 40 | for (int i = 0; i < a.size(); i++) { 41 | Double value = a.get(i); 42 | if (value == null) 43 | continue; 44 | if (i == 0 || min == null) 45 | min = value; 46 | else 47 | min = (value < min)? value: min; 48 | } 49 | return min; 50 | } 51 | 52 | public static Double max (ArrayList a) 53 | { 54 | if (a == null || a.size() == 0) 55 | return null; 56 | Double max = null; 57 | if (a.size() > 0) 58 | for (int i = 0; i < a.size(); i++) { 59 | Double value = a.get(i); 60 | if (value == null) 61 | continue; 62 | if (i == 0 || max == null) 63 | max = value; 64 | else{ 65 | max = (value > max)? 
value: max; 66 | } 67 | } 68 | return max; 69 | } 70 | 71 | public static ValueIndexSummary max_with_index (ArrayList a) 72 | { 73 | if (a == null || a.size() == 0) 74 | return null; 75 | double max = 0.0; 76 | int index = 0; 77 | ArrayList maxes = new ArrayList(); 78 | ArrayList indexes = new ArrayList(); 79 | if (a.size() > 0) 80 | for (int i = 0; i < a.size(); i++) { 81 | Double value = a.get(i); 82 | if (value == null) 83 | continue; 84 | if (i == 0){ 85 | max = value; 86 | index = i; 87 | maxes.add(max); 88 | indexes.add(index); 89 | } 90 | else if (value > max){ 91 | max = value; 92 | index = i; 93 | maxes = new ArrayList(); 94 | indexes = new ArrayList(); 95 | maxes.add(max); 96 | indexes.add(index); 97 | } 98 | else if (value == max){ 99 | maxes.add(value); 100 | indexes.add(i); 101 | } 102 | } 103 | ValueIndexSummary maxObj = new ValueIndexSummary(maxes,indexes); 104 | return maxObj; 105 | } 106 | 107 | public static ValueIndexSummary min_with_index (ArrayList a) 108 | { 109 | if (a == null || a.size() == 0) 110 | return null; 111 | double min = 0.0; 112 | int index = 0; 113 | ArrayList mins = new ArrayList(); 114 | ArrayList indexes = new ArrayList(); 115 | if (a.size() > 0) 116 | for (int i = 0; i < a.size(); i++) { 117 | Double value = a.get(i); 118 | if (value == null) 119 | continue; 120 | if (i == 0){ 121 | min = value; 122 | index = i; 123 | mins.add(min); 124 | indexes.add(index); 125 | } 126 | else if (value < min){ 127 | min = value; 128 | index = i; 129 | mins = new ArrayList(); 130 | indexes = new ArrayList(); 131 | mins.add(min); 132 | indexes.add(index); 133 | } 134 | else if (value == min){ 135 | mins.add(value); 136 | indexes.add(i); 137 | } 138 | } 139 | ValueIndexSummary minObj = new ValueIndexSummary(mins,indexes); 140 | return minObj; 141 | } 142 | 143 | public static Integer countLessThan (ArrayList a, int lessThanValue) 144 | { 145 | if (a == null || a.size() == 0) 146 | return null; 147 | Integer nb = null; 148 | for (int i = 0; i < 
a.size(); i++) { 149 | Double value = a.get(i); 150 | if (value == null) 151 | continue; 152 | if (value < lessThanValue){ 153 | if (nb == null) 154 | nb = 0; 155 | nb++; 156 | } 157 | } 158 | return nb; 159 | } 160 | 161 | public static Double sum (ArrayList a) 162 | { 163 | if (a == null || a.size() < 1) 164 | return null; 165 | Double sum = null; 166 | for (Double i : a) 167 | if (i != null){ 168 | if (sum == null) 169 | sum = 0.0; 170 | sum += i; 171 | } 172 | return sum; 173 | } 174 | 175 | public static Double mean (ArrayList a) 176 | { 177 | if (a == null || a.size() < 1) 178 | return null; 179 | Double sum = null;//sum(accuracy); 180 | Double mean = null; 181 | int nb = 0; 182 | for (Double i : a) { 183 | if (i != null && !Double.isNaN(i)){ 184 | if (sum == null) 185 | sum = 0.0; 186 | sum += i; 187 | nb++; 188 | } 189 | } 190 | if (sum != null && nb != 0) 191 | mean = sum / (nb * 1.0); 192 | return mean; 193 | } 194 | 195 | /** 196 | * This formula is the corrected sample standard deviation, which is generally known simply as the "sample standard deviation", which is less biased than the uncorrected sample standard deviation (standard formula of variance taking square root) 197 | * sample sd: When only a sample of data from a population is available, the term standard deviation of the sample or sample standard deviation can refer to either the above-mentioned quantity as applied to those data or to a modified quantity that is a better estimate of the population standard deviation (the standard deviation of the entire population). 198 | * not se:The standard error of the mean (SEM) is the standard deviation of the sample-mean's estimate of a population mean. (It can also be viewed as the standard deviation of the error in the sample mean with respect to the true mean, since the sample mean is an unbiased estimator.) 
199 | */ 200 | public static Double sd (ArrayList a) 201 | { 202 | if (a == null || a.size() < 2) 203 | return null; 204 | Double sd = null; 205 | Double sum = null; 206 | Double mean = null; 207 | int nb = 0; 208 | for (Double i : a) { 209 | if (i != null && !Double.isNaN(i)){ 210 | if (sum == null) 211 | sum = 0.0; 212 | sum += i; 213 | nb++; 214 | } 215 | } 216 | if (sum != null && nb != 0) 217 | mean = sum / (nb * 1.0); 218 | 219 | if (mean == null) 220 | return null; 221 | 222 | sum = null; 223 | for (Double i : a) 224 | if (i != null && !Double.isNaN(i)){ 225 | if (sum == null) 226 | sum = 0.0; 227 | sum += Math.pow((i - mean), 2); 228 | } 229 | if (sum != null && nb > 1) 230 | sd = Math.sqrt( sum / ( nb - 1.0 ) ); // sample 231 | 232 | return sd; 233 | } 234 | 235 | } 236 | -------------------------------------------------------------------------------- /src/main/java/fig/basic/PriorityQueue.java: -------------------------------------------------------------------------------- 1 | package fig.basic; 2 | 3 | import java.util.Iterator; 4 | import java.util.NoSuchElementException; 5 | import java.util.List; 6 | import java.util.ArrayList; 7 | import java.io.Serializable; 8 | 9 | /** 10 | * A priority queue based on a binary heap. Note that this implementation does 11 | * not efficiently support containment, removal, or element promotion 12 | * (decreaseKey) -- these methods are therefore not yet implemented. 
13 | * 14 | * @author Dan Klein 15 | */ 16 | public class PriorityQueue implements Iterator, Serializable, 17 | Cloneable { 18 | private static final long serialVersionUID = 5724671156522771658L; 19 | int size; 20 | int capacity; 21 | List elements; 22 | double[] priorities; 23 | 24 | protected void grow(int newCapacity) { 25 | List newElements = new ArrayList(newCapacity); 26 | double[] newPriorities = new double[newCapacity]; 27 | if (size > 0) { 28 | newElements.addAll(elements); 29 | System.arraycopy(priorities, 0, newPriorities, 0, priorities.length); 30 | } 31 | elements = newElements; 32 | priorities = newPriorities; 33 | capacity = newCapacity; 34 | } 35 | 36 | protected int parent(int loc) { 37 | return (loc - 1) / 2; 38 | } 39 | 40 | protected int leftChild(int loc) { 41 | return 2 * loc + 1; 42 | } 43 | 44 | protected int rightChild(int loc) { 45 | return 2 * loc + 2; 46 | } 47 | 48 | protected void heapifyUp(int loc) { 49 | if (loc == 0) return; 50 | int parent = parent(loc); 51 | if (priorities[loc] > priorities[parent]) { 52 | swap(loc, parent); 53 | heapifyUp(parent); 54 | } 55 | } 56 | 57 | protected void heapifyDown(int loc) { 58 | int max = loc; 59 | int leftChild = leftChild(loc); 60 | if (leftChild < size()) { 61 | double priority = priorities[loc]; 62 | double leftChildPriority = priorities[leftChild]; 63 | if (leftChildPriority > priority) 64 | max = leftChild; 65 | int rightChild = rightChild(loc); 66 | if (rightChild < size()) { 67 | double rightChildPriority = priorities[rightChild(loc)]; 68 | if (rightChildPriority > priority && rightChildPriority > leftChildPriority) 69 | max = rightChild; 70 | } 71 | } 72 | if (max == loc) 73 | return; 74 | swap(loc, max); 75 | heapifyDown(max); 76 | } 77 | 78 | protected void swap(int loc1, int loc2) { 79 | double tempPriority = priorities[loc1]; 80 | E tempElement = elements.get(loc1); 81 | priorities[loc1] = priorities[loc2]; 82 | elements.set(loc1, elements.get(loc2)); 83 | priorities[loc2] = 
tempPriority; 84 | elements.set(loc2, tempElement); 85 | } 86 | 87 | protected void removeFirst() { 88 | if (size < 1) return; 89 | swap(0, size - 1); 90 | size--; 91 | elements.remove(size); 92 | heapifyDown(0); 93 | } 94 | 95 | /** 96 | * Returns true if the priority queue is non-empty 97 | */ 98 | public boolean hasNext() { 99 | return ! isEmpty(); 100 | } 101 | 102 | /** 103 | * Returns the element in the queue with highest priority, and pops it from 104 | * the queue. 105 | */ 106 | public E next() { 107 | E first = peek(); 108 | removeFirst(); 109 | return first; 110 | } 111 | 112 | /** 113 | * Not supported -- next() already removes the head of the queue. 114 | */ 115 | public void remove() { 116 | throw new UnsupportedOperationException(); 117 | } 118 | 119 | /** 120 | * Returns the highest-priority element in the queue, but does not pop it. 121 | */ 122 | public E peek() { 123 | if (size() > 0) 124 | return elements.get(0); 125 | throw new NoSuchElementException(); 126 | } 127 | 128 | /** 129 | * Gets the priority of the highest-priority element of the queue. 130 | */ 131 | public double getPriority() { 132 | if (size() > 0) 133 | return priorities[0]; 134 | throw new NoSuchElementException(); 135 | } 136 | 137 | /** 138 | * Number of elements in the queue. 139 | */ 140 | public int size() { 141 | return size; 142 | } 143 | 144 | /** 145 | * True if the queue is empty (size == 0). 146 | */ 147 | public boolean isEmpty() { 148 | return size == 0; 149 | } 150 | 151 | /** 152 | * Adds a key to the queue with the given priority. If the key is already in 153 | * the queue, it will be added an additional time, NOT promoted/demoted. 
154 | * 155 | * @param key 156 | * @param priority 157 | */ 158 | public boolean add(E key, double priority) { 159 | if (size == capacity) { 160 | grow(2 * capacity + 1); 161 | } 162 | elements.add(key); 163 | priorities[size] = priority; 164 | heapifyUp(size); 165 | size++; 166 | return true; 167 | } 168 | 169 | /** 170 | * Returns a representation of the queue in decreasing priority order. 171 | */ 172 | public String toString() { 173 | return toString(size()); 174 | } 175 | 176 | /** 177 | * Returns a representation of the queue in decreasing priority order, 178 | * displaying at most maxKeysToPring elements. 179 | * 180 | * @param maxKeysToPrint 181 | */ 182 | public String toString(int maxKeysToPrint) { 183 | PriorityQueue pq = clone(); 184 | StringBuilder sb = new StringBuilder("["); 185 | int numKeysPrinted = 0; 186 | while (numKeysPrinted < maxKeysToPrint && pq.hasNext()) { 187 | double priority = pq.getPriority(); 188 | E element = pq.next(); 189 | sb.append(element.toString()); 190 | sb.append(" : "); 191 | sb.append(priority); 192 | if (numKeysPrinted < size() - 1) 193 | // sb.append("\n"); 194 | sb.append(", "); 195 | numKeysPrinted++; 196 | } 197 | if (numKeysPrinted < size()) 198 | sb.append("..."); 199 | sb.append("]"); 200 | return sb.toString(); 201 | } 202 | 203 | /** 204 | * Returns a counter whose keys are the elements in this priority queue, and 205 | * whose counts are the priorities in this queue. In the event there are 206 | * multiple instances of the same element in the queue, the counter's count 207 | * will be the sum of the instances' priorities. 208 | * 209 | * @return 210 | */ 211 | /*public Counter asCounter() { 212 | PriorityQueue pq = clone(); 213 | Counter counter = new Counter(); 214 | while (pq.hasNext()) { 215 | double priority = pq.getPriority(); 216 | E element = pq.next(); 217 | counter.incrementCount(element, priority); 218 | } 219 | return counter; 220 | }*/ 221 | 222 | /** 223 | * Returns a clone of this priority queue. 
Modifications to one will not 224 | * affect modifications to the other. 225 | */ 226 | public PriorityQueue clone() { 227 | PriorityQueue clonePQ = new PriorityQueue(); 228 | clonePQ.size = size; 229 | clonePQ.capacity = capacity; 230 | clonePQ.elements = new ArrayList(capacity); 231 | clonePQ.priorities = new double[capacity]; 232 | if (size() > 0) { 233 | clonePQ.elements.addAll(elements); 234 | System.arraycopy(priorities, 0, clonePQ.priorities, 0, size()); 235 | } 236 | return clonePQ; 237 | } 238 | 239 | public PriorityQueue() { 240 | this(15); 241 | } 242 | 243 | public PriorityQueue(int capacity) { 244 | int legalCapacity = 0; 245 | while (legalCapacity < capacity) { 246 | legalCapacity = 2 * legalCapacity + 1; 247 | } 248 | grow(legalCapacity); 249 | } 250 | 251 | public static void main(String[] args) { 252 | PriorityQueue pq = new PriorityQueue(); 253 | System.out.println(pq); 254 | pq.add("one",1); 255 | System.out.println(pq); 256 | pq.add("three",3); 257 | System.out.println(pq); 258 | pq.add("one",1.1); 259 | System.out.println(pq); 260 | pq.add("two",2); 261 | System.out.println(pq); 262 | System.out.println(pq.toString(2)); 263 | while (pq.hasNext()) { 264 | System.out.println(pq.next()); 265 | } 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /src/main/java/fast/featurehmm/LBFGS.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 
8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their codes that FAST is developed based on. 9 | * 10 | */ 11 | 12 | package fast.featurehmm; 13 | 14 | import edu.berkeley.nlp.math.CachingDifferentiableFunction; 15 | import edu.berkeley.nlp.math.LBFGSMinimizer; 16 | import edu.berkeley.nlp.util.Logger; 17 | import edu.berkeley.nlp.util.Pair; 18 | 19 | public class LBFGS { 20 | 21 | private final double tol; 22 | private final int max_iters; 23 | private final double[] regularizationWeights; 24 | private final double[] regularizationBiases; 25 | private final double[][] featureValues; 26 | private final double[] expectedCounts; 27 | private final int[] classes; 28 | private final int nbFeatures; 29 | private final String type; 30 | 31 | private double[] featureWeights; 32 | private PdfFeatureAwareLogisticRegression pdf; 33 | private int nbParameterizingFailed = 0;//must set 0 34 | private boolean verbose; 35 | 36 | //in Logistic regression, it already creates new memory space 37 | public LBFGS(double[] initialFeatureWeights, PdfFeatureAwareLogisticRegression pdf, 38 | double[] expectedCounts, 39 | double[][] featureValues, int[] classes, 40 | String type, 41 | double[] regularizationWeights, double[] regularizationBiases, 42 | int max_iters, double tolerance) { 43 | this.featureWeights = initialFeatureWeights; 44 | this.pdf = pdf; 45 | 46 | this.expectedCounts = expectedCounts; 47 | this.featureValues = featureValues; 48 | this.classes = classes; 49 | this.nbFeatures = initialFeatureWeights.length; 50 | this.type = type; // trans or emit 51 | this.regularizationWeights = regularizationWeights; 52 | this.regularizationBiases = regularizationBiases; 53 | this.max_iters = max_iters; 54 | this.tol = tolerance; 55 | } 56 | 57 | 58 | public double[] run() { 59 | NegativeRegularizedExpectedLogLikelihood negativeLikelihood = new NegativeRegularizedExpectedLogLikelihood(); 60 | 61 | LBFGSMinimizer minimizer = new 
LBFGSMinimizer(); 62 | minimizer.setMaxIterations(max_iters); 63 | minimizer.setVerbose(verbose); 64 | 65 | try { 66 | minimizer.minimize(negativeLikelihood, featureWeights, tol); 67 | } 68 | catch (RuntimeException ex) { 69 | nbParameterizingFailed = 1; 70 | Logger.err("RuntimeException probably caused by [LBFGSMinimizer.implicitMultiply]: Curvature problem. parameterizingSucceeded=false."); 71 | } 72 | 73 | return featureWeights; 74 | } 75 | 76 | public void setFeatureWeights(double[] featureWeights) { 77 | this.featureWeights = featureWeights; 78 | } 79 | 80 | public double[] getFeatureWeights(){ 81 | return featureWeights; 82 | } 83 | 84 | // for getting new theta (emitProb and transProb) by new weights and original 85 | // featureValues 86 | private void computePotentials(double[] featureWeights) { 87 | pdf.setFeatureWeights(featureWeights); 88 | } 89 | 90 | // hy:computes one LL consisting of transition and emission 91 | private class NegativeRegularizedExpectedLogLikelihood extends 92 | CachingDifferentiableFunction { 93 | 94 | // Pair.makePair(negativeRegularizedExpectedLogLikelihood, gradient), both 95 | // of them are updated in every iteration 96 | protected Pair calculate(double[] x) { 97 | // print(x, "featureWeights"); 98 | setFeatureWeights(x); 99 | computePotentials(x); 100 | 101 | // hy: just get the small ll as the paper shows 102 | double negativeRegularizedExpectedLogLikelihood = 0.0; 103 | 104 | // JPG removed this: 105 | // if (opts.oneLogisticRegression) 106 | // negativeRegularizedExpectedLogLikelihood = -(pdf 107 | // .calculateExpectedLogLikelihood(expectedCounts, featureValues, 108 | // outcomes) - calculateRegularizer()); 109 | negativeRegularizedExpectedLogLikelihood = -(pdf.calculateExpectedLogLikelihood(expectedCounts, featureValues, 110 | classes, type) - calculateRegularizer()); 111 | 112 | // Calculate gradient 113 | double[] gradient = new double[featureWeights.length]; 114 | // Gradient of emit weights (hy: doesn't have 
transition part ;-) 115 | int nbDatapoints = expectedCounts.length; 116 | 117 | /* 118 | * JPG COMMENTED THIS: if (opts.forceSetInstanceWeightForLBFGS > 0) { for 119 | * (int e = 0; e < expectedCounts.length; e++) { expectedCounts[e] = 120 | * opts.forceSetInstanceWeightForLBFGS; } } 121 | */ 122 | 123 | // print(expectedCounts, "expected counts:"); 124 | for (int i = 0; i < nbDatapoints; i++) { 125 | // System.out.println("dp id=" + i); 126 | double expectedCount = expectedCounts[i]; 127 | double[] features = featureValues[i]; 128 | if (features.length != featureWeights.length) { 129 | System.out.println("ERROR: features.length !=featureWeights.length"); 130 | System.exit(1); 131 | } 132 | int curClass = classes[i]; 133 | // int hiddenStateIndex = (i >= nbDatapoints / 2) ? 1 : 0; 134 | 135 | // TODO: no bias yet 136 | for (int featureIndex = 0; featureIndex < features.length; featureIndex++) { 137 | if (curClass == 1) { 138 | gradient[featureIndex] -= expectedCount * features[featureIndex] 139 | * (1 - pdf.probability(features, curClass, type)); 140 | } 141 | else { 142 | gradient[featureIndex] -= expectedCount * features[featureIndex] 143 | * (-1.0) * pdf.probability(features, 1, type); 144 | } 145 | 146 | // for (int featureIndex = 0; featureIndex < features.length; 147 | // featureIndex++) { 148 | // if (outcome == 1) { 149 | // gradient[featureIndex] -= expectedCount * features[featureIndex] 150 | // * (1 - pdf.probability(features, outcome)); 151 | // } 152 | // else { 153 | // gradient[featureIndex] -= expectedCount * features[featureIndex] 154 | // * (-1.0) * pdf.probability(features, 1); 155 | // } 156 | 157 | } 158 | } 159 | 160 | /** 161 | * for (int s = 0; s < opts.nbHiddenStates; ++s) { for (int i = 0; i < 162 | * numObservations; ++i) { for (int f = 0; f < 163 | * activeEmitFeatures[s][i].size(); ++f) { Pair feat = 164 | * activeEmitFeatures[s][i].get(f); // sum_dct(e_dct) * f(dct) 165 | * gradient[feat.getFirst()] -= expectedEmitCounts[s][i] 
feat.getSecond(); 166 | * // sum_dct(e_dct)*sum_d'(thita_d'ct*f(d'ct)) // guess: 167 | * expectedLabelCounts[s] = sum_i(expectedEmitCounts[s][i]) 168 | * gradient[feat.getFirst()] -= -expectedLabelCounts[s] * emitProbs[s][i] 169 | * * feat.getSecond(); } 170 | * 171 | * } } 172 | **/ 173 | 174 | // print(gradient, "Gradient"); 175 | 176 | // Add gradient of regularizer 177 | for (int f = 0; f < nbFeatures; ++f) { 178 | gradient[f] += 2.0 * regularizationWeights[f] 179 | * (featureWeights[f] - regularizationBiases[f]); 180 | } 181 | // print(gradient, "RegGradient"); 182 | // print(x, "featureWeights"); 183 | // System.out.println("negativeRegularizedExpectedLogLikelihood:\t" 184 | // + negativeRegularizedExpectedLogLikelihood); 185 | 186 | return Pair.makePair(negativeRegularizedExpectedLogLikelihood, gradient); 187 | } 188 | 189 | public int dimension() { 190 | return nbFeatures; 191 | } 192 | 193 | } 194 | 195 | // public void print(double[] temp, String info) { 196 | // System.out.print(info + ":\t"); 197 | // for (int i = 0; i < temp.length; i++) 198 | // System.out.print(temp[i] + "\t"); 199 | // System.out.println(); 200 | // 201 | // } 202 | 203 | public int getParameterizingResult(){ 204 | return nbParameterizingFailed; 205 | } 206 | 207 | //TODO: for those features that never get activated, shouldn't be included in regularization 208 | private double calculateRegularizer() { 209 | double result = 0.0; 210 | for (int f = 0; f < nbFeatures; ++f) { 211 | result += regularizationWeights[f] 212 | * (featureWeights[f] - regularizationBiases[f]) 213 | * (featureWeights[f] - regularizationBiases[f]); 214 | } 215 | return result; 216 | } 217 | 218 | } 219 | -------------------------------------------------------------------------------- /src/main/java/fast/featurehmm/FeatureHMM.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial 
purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their code that FAST is developed based on. 9 | * 10 | */ 11 | 12 | /* 13 | * This is built based on: 14 | * 15 | * jaHMM package - v0.6.1 16 | * Copyright (c) 2004-2006, Jean-Marc Francois. 17 | */ 18 | 19 | package fast.featurehmm; 20 | 21 | import java.io.Serializable; 22 | //import java.text.NumberFormat; 23 | import java.util.ArrayList; 24 | //import java.util.HashSet; 25 | //import java.util.List; 26 | //import be.ac.ulg.montefiore.run.jahmm.Observation; 27 | import fast.data.DataPoint; 28 | 29 | 30 | public class FeatureHMM implements Serializable, Cloneable { 31 | private static final long serialVersionUID = 2L; 32 | 33 | private final boolean parameterizedInit, parameterizedTran, parameterizedEmit; 34 | private final boolean allowForget; 35 | private final int nbHiddenStates; 36 | private final ArrayList initialPdfs;// = new ArrayList(); 37 | private final ArrayList transitionPdfs;// = new ArrayList(); 38 | private final ArrayList emissionPdfs;// = new ArrayList(); 39 | // private final PdfFeatureAwareLogisticRegression initialPdf;// = new ArrayList(); 40 | // private final PdfFeatureAwareLogisticRegression transitionPdf;// = new ArrayList(); 41 | // private final PdfFeatureAwareLogisticRegression emissionPdf;// = new ArrayList(); 42 | 43 | 44 | public FeatureHMM(ArrayList initPdfs, 45 | ArrayList transPdfs, 46 | ArrayList emitPdfs//, 47 | //boolean parameterizing, boolean parameterizedInit, boolean parameterizedTran, boolean parameterizedEmit, 48 | //boolean allowForget, 49 | ){ 50 | //int restartId, String kcName, 
String modelName) { 51 | if (initPdfs.size() == 0 || initPdfs.size() != transPdfs.size() || emitPdfs.size() != transPdfs.size()) 52 | throw new IllegalArgumentException("ERROR: Wrong initial parameters for HMM (constructor)!"); 53 | this.initialPdfs = initPdfs; 54 | this.transitionPdfs = transPdfs; 55 | this.emissionPdfs = emitPdfs; 56 | this.parameterizedInit = initPdfs.get(0).getParameterizedInit(); 57 | this.parameterizedTran = transPdfs.get(0).getParameterizedTran(); 58 | this.parameterizedEmit = emitPdfs.get(0).getParameterizedEmit(); 59 | // if (this.parameterizedInit || this.parameterizedTran || this.parameterizedEmit) 60 | // this.parameterizing = true; 61 | // else 62 | // this.parameterizing = false; 63 | this.allowForget = transPdfs.get(0).getAllowForget(); 64 | this.nbHiddenStates = initPdfs.size(); 65 | // this.restartId = restartId; 66 | // this.kcName = kcName; 67 | // this.modelName = modelName; 68 | } 69 | 70 | /* The same for i=0 or i=1, because the code expanded features for both hiddenStates to train one logistic regression. */ 71 | public PdfFeatureAwareLogisticRegression getInitialPdf(int i) { 72 | return initialPdfs.get(i); 73 | } 74 | 75 | public PdfFeatureAwareLogisticRegression getTransitionPdf(int i) { 76 | return transitionPdfs.get(i); 77 | } 78 | 79 | public PdfFeatureAwareLogisticRegression getEmissionPdf(int i) { 80 | return emissionPdfs.get(i); 81 | } 82 | 83 | public double getInitiali(int i, double[] featureValues) { 84 | if (!parameterizedInit) 85 | return initialPdfs.get(i).probability(null, i, "init"); 86 | else 87 | return initialPdfs.get(i).probability(featureValues, i, "init"); 88 | 89 | } 90 | 91 | /** The probability associated to the transition going from i to / state j. */ 92 | public double getTransitionij(int i, int j, double[] featureValues) { 93 | if (!parameterizedTran) 94 | return transitionPdfs.get(i).probability(null, j, "trans"); 95 | else{ 96 | if (!allowForget && i == 1){ 97 | return (j == 0 ? 
0.0:1.0); 98 | } 99 | else 100 | return transitionPdfs.get(i).probability(featureValues, j, "trans"); 101 | } 102 | } 103 | 104 | public double getEmissionjk(int j, int k, double[] featureValues) { 105 | if (!parameterizedEmit) 106 | return emissionPdfs.get(j).probability(null, k, "emit"); 107 | else{ 108 | double prob = emissionPdfs.get(j).probability(featureValues, k, "emit"); 109 | return prob; 110 | } 111 | } 112 | 113 | public void setInit(PdfFeatureAwareLogisticRegression pdf, int hiddenStateIndex) { 114 | initialPdfs.set(hiddenStateIndex, pdf); 115 | } 116 | 117 | public void setTransition(PdfFeatureAwareLogisticRegression pdf, int hiddenStateIndex) { 118 | transitionPdfs.set(hiddenStateIndex, pdf); 119 | } 120 | 121 | public void setEmission(PdfFeatureAwareLogisticRegression pdf, int hiddenStateIndex) { 122 | emissionPdfs.set(hiddenStateIndex, pdf); 123 | } 124 | 125 | // public String toString(NumberFormat nf) { 126 | // String s = "Hmm with " + nbHiddenStates + " state(s)\n"; 127 | // 128 | // for (int i = 0; i < nbHiddenStates; i++) { 129 | // s += "\nState " + i + "\n"; 130 | // s += " Initial:\n" 131 | // + getHmmString(initialPdfs.get(i), (parameterizedInit ? true 132 | // : false)) + "\n"; 133 | // s += " Transition:\n" 134 | // + getHmmString(transitionPdfs.get(i), (parameterizedTran ? true 135 | // : false)) + "\n"; 136 | // s += " Emission:\n" 137 | // + getHmmString(emissionPdfs.get(i), (parameterizedEmit ? 
true 138 | // : false)) + "\n"; 139 | // } 140 | // return s; 141 | // } 142 | 143 | // public String getHmmString(PdfFeatureAwareLogisticRegression pdf, 144 | // boolean parameterized) { 145 | // String s = ""; 146 | // //TODO: differentiate between allowing forget vs not 147 | // if (parameterized) { 148 | // s += "\tparameterized:"; 149 | // double[] w = pdf.featureWeights; 150 | // if (w.length == 1) { 151 | // s += "\tprobabilities:\t" + 1 / (1 + Math.exp(-w[0])); 152 | // } 153 | // s += "\tweights:"; 154 | // for (int k = 0; k < w.length; k++) 155 | // s += "\t" + w[k]; 156 | // } 157 | // else { 158 | // double[] p = pdf.probabilities; 159 | // if (p != null) { 160 | // for (int k = 0; k < p.length; k++) 161 | // s += "\t" + pdf.classMapping.get(k) + "\t" + p[k]; 162 | // } 163 | // else 164 | // s += "\t" + pdf.probability; 165 | // 166 | // } 167 | // return s; 168 | // } 169 | // 170 | // public String toString() { 171 | // return toString(NumberFormat.getInstance()); 172 | // } 173 | 174 | public FeatureHMM clone() throws CloneNotSupportedException { 175 | FeatureHMM Hmm = new FeatureHMM(initialPdfs, transitionPdfs, emissionPdfs); 176 | for (int i = 0; i < Hmm.nbHiddenStates; i++) { 177 | Hmm.initialPdfs.set(i, initialPdfs.get(i).clone()); //Can be shallow copy 178 | Hmm.transitionPdfs.set(i, transitionPdfs.get(i).clone()); 179 | Hmm.emissionPdfs.set(i, emissionPdfs.get(i).clone()); 180 | } 181 | return Hmm; 182 | } 183 | 184 | public int getNbHiddenStates() { 185 | // return pi.length; 186 | return nbHiddenStates;//initialPdfs.size(); 187 | } 188 | 189 | public static int getKnownState(FeatureHMM hmm, DataPoint dp, boolean useEmissionToJudgeHiddenStates, boolean allowForget) { 190 | if (!allowForget) 191 | return 1; 192 | int knownState = -1; 193 | if (useEmissionToJudgeHiddenStates){//this requires guess+slip<1; if not, this way of judgement is questionable 194 | double hidden0obs0 = hmm.getEmissionjk(0, 0, dp.getFeatures(0, 2)); 195 | double 
hidden1obs0 = hmm.getEmissionjk(1, 0, dp.getFeatures(1, 2)); 196 | if (hidden0obs0 > hidden1obs0) { 197 | knownState = 1;//known 198 | } 199 | else { 200 | knownState = 0; 201 | } 202 | } 203 | else{//this requires guess+slip<1; if not, this way of judgement is questionable 204 | double a01 = hmm.getTransitionij(0, 1, dp.getFeatures(0, 1)); 205 | double a10 = hmm.getTransitionij(1, 0, dp.getFeatures(1, 1)); 206 | if (a01 < a10) { 207 | knownState = 0;// know 208 | } 209 | else { 210 | knownState= 1;// know 211 | } 212 | } 213 | return knownState; 214 | } 215 | 216 | public static double checkDegeneracy(FeatureHMM hmm, DataPoint dp, int knownState){ 217 | double guess = hmm.getEmissionjk(1 - knownState, 1, dp.getFeatures(1 - knownState, 2)); 218 | double slip = hmm.getEmissionjk(knownState, 0, dp.getFeatures(knownState, 2)); 219 | return (guess + slip); 220 | } 221 | } 222 | 223 | -------------------------------------------------------------------------------- /src/main/java/fast/featurehmm/ForwardBackwardScaledCalculator.java: -------------------------------------------------------------------------------- 1 | /** 2 | * FAST v1.0 08/12/2014 3 | * 4 | * This code is only for research purpose not commercial purpose. 5 | * It is originally developed for research purpose and is still under improvement. 6 | * Please email to us if you want to keep in touch with the latest release. 7 | We sincerely welcome you to contact Yun Huang (huangyun.ai@gmail.com), or Jose P.Gonzalez-Brenes (josepablog@gmail.com) for problems in the code or cooperation. 8 | * We thank Taylor Berg-Kirkpatrick (tberg@cs.berkeley.edu) and Jean-Marc Francois (jahmm) for part of their codes that FAST is developed based on. 9 | * 10 | */ 11 | 12 | /* 13 | * This is built based on: 14 | * jahmm package - v0.6.1 15 | * Copyright (c) 2004-2006, Jean-Marc Francois. 
16 | * 17 | * scaling: this gives P(St+1=qj|O1..Ot+1) for scaled alpha, ctFactors[t] = P(O1..Ot) (for different t, O1~Ot is different) 18 | */ 19 | 20 | package fast.featurehmm; 21 | 22 | import java.util.AbstractList; 23 | import java.util.Arrays; 24 | import java.util.EnumSet; 25 | import java.util.Iterator; 26 | import java.util.List; 27 | import be.ac.ulg.montefiore.run.jahmm.Observation; 28 | import fast.common.Matrix; 29 | import fast.data.DataPoint; 30 | //import be.ac.ulg.montefiore.run.jahmm.ForwardBackwardCalculator; 31 | 32 | //import be.ac.ulg.montefiore.run.jahmm.Hmm; 33 | 34 | public class ForwardBackwardScaledCalculator extends ForwardBackwardCalculator { 35 | /* 36 | * Warning, the semantic of the alpha and beta elements are changed; in this 37 | * class, they have their value scaled. 38 | */ 39 | // Scaling factors 40 | private double[] ctFactors; 41 | private double lnProbability; 42 | private boolean verbose = false; 43 | 44 | /** 45 | * Computes the probability of occurrence of an observation sequence given a 46 | * Hidden Markov Model. The algorithms implemented use scaling to avoid 47 | * underflows. 48 | * 49 | * @param hmm 50 | * A Hidden Markov Model; 51 | * @param oseq 52 | * An observations sequence. 53 | * @param flags 54 | * How the computation should be done. See the 55 | * {@link ForwardBackwardCalculator.Computation}. The alpha array is 56 | * always computed. 
57 | */ 58 | public ForwardBackwardScaledCalculator(List oseq, FeatureHMM hmm, 59 | EnumSet flags) { 60 | // System.out.println("ForwardBackwardScaledCalculator..."); 61 | if (oseq.isEmpty()) 62 | throw new IllegalArgumentException(); 63 | 64 | ctFactors = new double[oseq.size()]; 65 | Arrays.fill(ctFactors, 0.); 66 | 67 | computeAlpha(hmm, oseq); 68 | if (verbose) { 69 | System.out.println("alpha:"); 70 | for (int t = 0; t < alpha.length; t++) { 71 | System.out.print("\ttime=" + t); 72 | for (int state = 0; state < alpha[t].length; state++) { 73 | System.out.print("\tstate" + state + "=\t" + alpha[t][state]); 74 | } 75 | System.out.print("\n"); 76 | } 77 | System.out.print("\n"); 78 | } 79 | 80 | if (verbose) { 81 | System.out.println("scaling ctFactors:"); 82 | for (int t = 0; t < ctFactors.length; t++) 83 | System.out.println("\ttime=" + t + "\t" + ctFactors[t]); 84 | System.out.print("\n"); 85 | } 86 | 87 | if (flags.contains(Computation.BETA)) 88 | computeBeta(hmm, oseq); 89 | if (verbose) { 90 | System.out.println("beta:"); 91 | for (int t = 0; t < beta.length; t++) { 92 | System.out.print("\ttime=" + t); 93 | for (int state = 0; state < beta[t].length; state++) { 94 | System.out.print("\tstate" + state + "=\t" + beta[t][state]); 95 | } 96 | System.out.print("\n"); 97 | } 98 | System.out.print("\n"); 99 | } 100 | 101 | computeProbability(oseq); 102 | } 103 | 104 | /** 105 | * Computes the probability of occurence of an observation sequence given a 106 | * Hidden Markov Model. This computation computes the scaled 107 | * alpha array as a side effect. 
108 | * 109 | * @see #ForwardBackwardScaledCalculator(List, FeatureHMM, EnumSet) 110 | */ 111 | public ForwardBackwardScaledCalculator(List oseq, FeatureHMM hmm) { 112 | this(oseq, hmm, EnumSet.of(Computation.ALPHA)); 113 | } 114 | 115 | /* Computes the content of the scaled alpha array */ 116 | @Override 117 | protected void computeAlpha(FeatureHMM hmm, List oseq) { 118 | // System.out.println("ForwardBackwardScaledCalculator:computeAlpha..."); 119 | alpha = new double[oseq.size()][hmm.getNbHiddenStates()]; 120 | 121 | for (int i = 0; i < hmm.getNbHiddenStates(); i++) 122 | computeAlphaInit(hmm, oseq.get(0), i); 123 | scale(ctFactors, alpha, 0); 124 | 125 | Iterator seqIterator = oseq.iterator(); 126 | if (seqIterator.hasNext()) 127 | seqIterator.next(); 128 | 129 | for (int t = 1; t < oseq.size(); t++) { 130 | DataPoint observation = seqIterator.next(); 131 | 132 | for (int i = 0; i < hmm.getNbHiddenStates(); i++) 133 | computeAlphaStep(hmm, observation, t, i); 134 | // scale the t-th alpha array by dividing [P(O1...Ot,St=qi) + 135 | // P(O1...Ot,St=qj], new alpha = P(St=qi|O1...Ot) 136 | scale(ctFactors, alpha, t); 137 | } 138 | } 139 | 140 | /* 141 | * Computes the content of the scaled beta array. The scaling factors are 142 | * those computed for alpha. 143 | * 144 | * hy: new beta = P(Ot+1,...ON|St=qi)/P(O1..Ot) 145 | */ 146 | protected void computeBeta(FeatureHMM hmm, List oseq) { 147 | // System.out.println("ForwardBackwardScaledCalculator:computeBeta..."); 148 | beta = new double[oseq.size()][hmm.getNbHiddenStates()]; 149 | 150 | for (int i = 0; i < hmm.getNbHiddenStates(); i++) 151 | beta[oseq.size() - 1][i] = 1. 
/ ctFactors[oseq.size() - 1]; 152 | 153 | for (int t = oseq.size() - 2; t >= 0; t--) 154 | for (int i = 0; i < hmm.getNbHiddenStates(); i++) { 155 | computeBetaStep(hmm, oseq.get(t + 1), t, i); 156 | beta[t][i] /= ctFactors[t]; 157 | } 158 | } 159 | 160 | /* Normalize alpha[t] and put the normalization factor in ctFactors[t] */ 161 | // hy: this gives P(St=qj|O1..Ot) for scaled alpha, ctFactors[t] = 162 | // P(O1..Ot) 163 | // hy: array= new double[oseq.size()][hmm.getNbHiddenStates()]; 164 | private void scale(double[] ctFactors, double[][] array, int t) { 165 | // System.out.println("ForwardBackwardScaledCalculator:scale..."); 166 | double[] table = array[t]; 167 | double sum = 0.; 168 | 169 | // hy:i->nbStates 170 | for (int i = 0; i < table.length; i++) 171 | sum += table[i]; 172 | 173 | ctFactors[t] = sum; 174 | for (int i = 0; i < table.length; i++) 175 | table[i] /= sum; 176 | } 177 | 178 | // TODO 179 | // hy* probability = P(O1)*P(O1,O2)*...*P(O1,O2..On) (not conditional 180 | // P(O2|O1)....?) 
181 | private void computeProbability(List oseq) { 182 | // System.out.println("ForwardBackwardScaledCalculator:computeProbability..."); 183 | lnProbability = 0.; 184 | 185 | // System.out.println(Arrays.deepToString(alpha ) + " " + 186 | // Arrays.deepToString(beta) + " " + Arrays.toString(ctFactors)); 187 | for (int t = 0; t < oseq.size(); t++) 188 | lnProbability += Math.log(ctFactors[t]); 189 | 190 | probability = Math.exp(lnProbability); 191 | if (verbose) 192 | System.out.println("probability:\t" + probability); 193 | } 194 | 195 | // added by JPG: 196 | public static double getLL(FeatureHMM hmm, 197 | AbstractList> students) { 198 | // System.out.println("ForwardBackwardScaledCalculator:getLL..."); 199 | double ll = 0; 200 | for (List student : students) { 201 | ForwardBackwardScaledCalculator fwbs = new ForwardBackwardScaledCalculator( 202 | student, hmm, EnumSet.of(Computation.ALPHA, Computation.BETA)); 203 | fwbs.computeProbability(student); 204 | // System.out.println("~" + fwbs.lnProbability); 205 | ll += fwbs.lnProbability; 206 | } 207 | return ll; 208 | } 209 | 210 | /** 211 | * Return the neperian logarithm of the probability of the sequence that 212 | * generated this object. 213 | * 214 | * @return The probability of the sequence of interest's neperian logarithm. 215 | */ 216 | public double lnProbability() { 217 | System.out.println("ForwardBackwardScaledCalculator:lnProbability..."); 218 | return lnProbability; 219 | } 220 | 221 | /* 222 | * Added by JPG. 
See 223 | * http://xenia.media.mit.edu/~rahimi/rabiner/rabiner-errata/ 224 | * rabiner-errata.html 225 | */ 226 | public double[][] getStateProbabilities() { 227 | // System.out 228 | // .println("ForwardBackwardScaledCalculator:getStateProbabilities..."); 229 | final double[][] p = new double[alpha.length][alpha[0].length]; 230 | 231 | // double ct_1 = 1; 232 | for (int t = 0; t < alpha.length; t++) { 233 | // ct_1 = ctFactors[t] / ct_1; 234 | p[t] = Matrix.dotmult(alpha[t], beta[t], ctFactors[t]); // I thought it 235 | // should be alpha 236 | // * beta *ct_1 237 | 238 | Matrix.assertProbability(p[t]); 239 | } 240 | 241 | return p; 242 | 243 | } 244 | 245 | } 246 | --------------------------------------------------------------------------------