├── .gitignore ├── README.org ├── pom.xml └── src ├── main └── java │ └── edu │ └── umass │ └── nlp │ ├── dimred │ ├── RandomProjection.java │ └── SmallSparseVector.java │ ├── examples │ └── HMMTest.java │ ├── exec │ ├── Execution.java │ ├── Opt.java │ ├── OptionManager.java │ ├── StandardOptionHandlers.java │ └── package.html │ ├── functional │ ├── CallbackFn.java │ ├── Double2DoubleFn.java │ ├── DoubleFn.java │ ├── FactoryFn.java │ ├── Fn.java │ ├── Functional.java │ ├── PredFn.java │ ├── PredFns.java │ └── Vector2DoubleFn.java │ ├── io │ ├── IOUtils.java │ └── ZipUtils.java │ ├── ml │ ├── F1Stats.java │ ├── LossFn.java │ ├── LossFns.java │ ├── Regularizers.java │ ├── classification │ │ ├── BasicClassifierDatum.java │ │ ├── BasicLabeledClassifierDatum.java │ │ ├── ClassifierDatum.java │ │ ├── IClassifier.java │ │ ├── IProbabilisticClassifier.java │ │ ├── LabeledClassifierDatum.java │ │ ├── MaxEntropyClassifier.java │ │ └── Reranker.java │ ├── feats │ │ ├── IPredExtractor.java │ │ ├── Predicate.java │ │ ├── PredicateManager.java │ │ └── WeightsManager.java │ ├── prob │ │ ├── AbstractDistribution.java │ │ ├── BasicConditionalDistribution.java │ │ ├── DirichletMultinomial.java │ │ ├── Distributions.java │ │ ├── IConditionalDistribution.java │ │ ├── IDistribution.java │ │ └── ISuffStats.java │ ├── regression │ │ └── LinearRegressionModel.java │ └── sequence │ │ ├── BasicLabelSeqDatum.java │ │ ├── CRF.java │ │ ├── ForwardBackwards.java │ │ ├── ILabeledSeqDatum.java │ │ ├── ISeqDatum.java │ │ ├── ProbabilisticSequenceModel.java │ │ ├── SequenceModel.java │ │ ├── State.java │ │ ├── StateSpace.java │ │ ├── StateSpaces.java │ │ ├── TokenF1Eval.java │ │ └── Transition.java │ ├── optimize │ ├── BacktrackingLineMinimizer.java │ ├── CachingDifferentiableFn.java │ ├── GradientDescent.java │ ├── IDifferentiableFn.java │ ├── ILineMinimizer.java │ ├── IOptimizer.java │ ├── LBFGSMinimizer.java │ └── OptimizeUtils.java │ ├── parallel │ └── ParallelUtils.java │ ├── process │ ├── 
Document.java │ └── Token.java │ ├── text │ └── HTMLUtils.java │ ├── trees │ ├── BasicTree.java │ ├── ITree.java │ └── Trees.java │ └── utils │ ├── AtomicDouble.java │ ├── BasicPair.java │ ├── BasicValued.java │ ├── Collections.java │ ├── CounterMap.java │ ├── Counters.java │ ├── DoubleArrays.java │ ├── ICounter.java │ ├── IHasProperties.java │ ├── IIndexed.java │ ├── ILockable.java │ ├── IMergable.java │ ├── IPair.java │ ├── ISpannable.java │ ├── IValuable.java │ ├── IValued.java │ ├── IWrapper.java │ ├── Indexer.java │ ├── LogAdder.java │ ├── MapCounter.java │ ├── Maxer.java │ ├── MergableUtils.java │ ├── MutableDouble.java │ ├── SloppyMath.java │ ├── Span.java │ └── StringUtils.java └── test └── java └── edu └── umass └── nlp └── exec └── ExecutionTest.java /.gitignore: -------------------------------------------------------------------------------- 1 | .hg* 2 | target/** 3 | *~ -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | * NLP Utils 2 | 3 | NLP Utilities in Java and Clojure from Aria Haghighi's grad school and post-doc days. 4 | 5 | ** What's Implemented? 
6 | - Generic unconstrained numerical function optimization, including L-BFGS 7 | - Basic machine learning abstractions: sparse feature vectors and large-scale indexing 8 | - MaxEnt classification 9 | - Sequence CRF Model 10 | - Common dimensionality reduction techniques 11 | 12 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | edu.umass.nlp 7 | umass-nlp 8 | 1.0-SNAPSHOT 9 | 10 | 11 | commons-primitives 12 | commons-primitives 13 | 1.0 14 | 15 | 16 | org.yaml 17 | snakeyaml 18 | 1.6 19 | 20 | 21 | log4j 22 | log4j 23 | 1.2.14 24 | 25 | 26 | net.htmlparser.jericho 27 | jericho-html 28 | 3.1 29 | compile 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | src/main/clojure 48 | 49 | 50 | 51 | 52 | maven-compiler-plugin 53 | 54 | 1.6 55 | 1.6 56 | UTF-8 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/dimred/RandomProjection.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.dimred; 2 | 3 | import edu.umass.nlp.utils.DoubleArrays; 4 | 5 | import java.util.Random; 6 | 7 | 8 | /** 9 | * To do dimension reduction put your sparse matrix 10 | * in a SmallSparseVector[], call setInput and then 11 | * getProjectedMatrix where each row corresponds 12 | * to your reduced vectors. 
13 | */ 14 | public class RandomProjection { 15 | 16 | private double[][] projected ; 17 | 18 | public void setInput(SmallSparseVector[] sparseX, int kPrincipalComponents) { 19 | int maxDimension = 0; 20 | for (SmallSparseVector vec: sparseX) { 21 | for (int i=0; i < vec.size(); ++i) { 22 | int dim = vec.getActiveDimension(i); 23 | maxDimension = Math.max(maxDimension,dim); 24 | } 25 | } 26 | double[][] randomMatrix = getRandomMatrix(maxDimension+1, kPrincipalComponents); 27 | projected = new double[sparseX.length][kPrincipalComponents]; 28 | for (int i=0; i < sparseX.length; ++i) { 29 | SmallSparseVector vec = sparseX[i]; 30 | for (int j=0; j < vec.size(); ++j) { 31 | int dim = vec.getActiveDimension(j); 32 | double count = vec.getActiveDimensionCount(j); 33 | if (count == 0.0) continue; 34 | for (int k=0; k < kPrincipalComponents; ++k) { 35 | double v = count * randomMatrix[dim][k]; 36 | if (v != 0.0) projected[i][k] += v; 37 | } 38 | } 39 | double vecLen = DoubleArrays.vectorLength(projected[i]); 40 | if (vecLen > 0.0) DoubleArrays.scaleInPlace(projected[i], 1.0/vecLen); 41 | } 42 | } 43 | 44 | private double[][] getRandomMatrix(int biggerDim, int smallerDim) { 45 | double[][] randomMatrix = new double[biggerDim][smallerDim]; 46 | Random rand = new Random(0); 47 | double[] probs = {1.0/6.0,2.0/3.0,1.0/6.0}; 48 | double[] vals = {Math.sqrt(3) * 1, 0, Math.sqrt(3) * -1}; 49 | for (int i=0; i < biggerDim; ++i) { 50 | for (int j=0; j < smallerDim; ++j) { 51 | int rIndex = DoubleArrays.sample(probs,rand); 52 | randomMatrix[i][j] = vals[rIndex]; 53 | } 54 | } 55 | return randomMatrix; 56 | } 57 | 58 | public double[][] getProjectedMatrix() { 59 | return projected; 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/dimred/SmallSparseVector.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.dimred; 2 | 3 | import 
java.io.Serializable; 4 | import java.util.Arrays; 5 | 6 | 7 | public class SmallSparseVector implements Serializable { 8 | 9 | private static final long serialVersionUID = 42L; 10 | 11 | float[] data = new float[0]; 12 | int[] indices = new int[0]; 13 | int length = 0; 14 | 15 | private void grow() { 16 | int curSize = data.length; 17 | int newSize = curSize + 10; 18 | 19 | float[] newData = new float[newSize]; 20 | System.arraycopy(data, 0, newData, 0, curSize); 21 | data = newData; 22 | int[] newIndices = new int[newSize]; 23 | System.arraycopy(indices, 0, newIndices, 0, curSize); 24 | for (int i=curSize; i < newIndices.length; ++i) { 25 | newIndices[i] = Integer.MAX_VALUE; 26 | newData[i] = Float.POSITIVE_INFINITY; 27 | } 28 | indices = newIndices; 29 | } 30 | 31 | public double getCount(int index) { 32 | int res = Arrays.binarySearch(indices, index); 33 | if (res >= 0 && res < length) { 34 | return data[res]; 35 | } 36 | return 0.0; 37 | } 38 | 39 | public void incrementCount(int index0, double x0) { 40 | double curCount = getCount(index0); 41 | setCount(index0, curCount + x0); 42 | } 43 | 44 | public int size() { 45 | return length; 46 | } 47 | 48 | public void setCount(int index0, double x0) { 49 | float x = (float) x0; 50 | // short index = (short) index0; 51 | int res = Arrays.binarySearch(indices, index0); 52 | // Greater than everything 53 | if (res >= 0 && res < length) { 54 | data[res] = x; 55 | return; 56 | } 57 | if (length+1 >= data.length) { 58 | grow(); 59 | } 60 | // In the middle 61 | int insertionPoint = -(res+1); 62 | assert insertionPoint >= 0 && insertionPoint <= length : String.format("length: %d insertion: %d",length,insertionPoint); 63 | // Shift The Stuff After 64 | System.arraycopy(data, insertionPoint, data, insertionPoint+1, length-insertionPoint); 65 | System.arraycopy(indices, insertionPoint, indices, insertionPoint+1, length-insertionPoint); 66 | indices[insertionPoint] = index0; 67 | data[insertionPoint] = x; 68 | length++; 69 | } 
70 | 71 | public int getActiveDimension(int i) { 72 | assert i < indices.length; 73 | return indices[i]; 74 | } 75 | 76 | public double getActiveDimensionCount(int i) { 77 | assert i < data.length; 78 | return data[i]; 79 | } 80 | 81 | public double l2Norm() { 82 | double sum = 0.0; 83 | for (int i=0; i < length; ++i) { 84 | sum += data[i] * data[i]; 85 | } 86 | return Math.sqrt(sum); 87 | } 88 | 89 | public void scale(double c) { 90 | for (int i=0; i < length; ++i) { 91 | data[i] *= c; 92 | } 93 | } 94 | 95 | public String toString() { 96 | StringBuilder builder = new StringBuilder(); 97 | builder.append("{ "); 98 | for (int i=0; i < length; ++i) { 99 | builder.append(String.format("%d : %.5f",indices[i],data[i])); 100 | builder.append(" "); 101 | } 102 | builder.append(" }"); 103 | return builder.toString(); 104 | } 105 | 106 | public double dotProduct(SmallSparseVector other) { 107 | double sum = 0.0; 108 | for (int i=0; i < this.size(); ++i) { 109 | int dim = getActiveDimension(i); 110 | sum += this.getCount(dim) * other.getCount(dim); 111 | } 112 | return sum; 113 | } 114 | 115 | public static void main(String[] args) { 116 | SmallSparseVector sv = new SmallSparseVector(); 117 | sv.setCount(0, 1.0); 118 | sv.setCount(1, 2.0); 119 | sv.incrementCount(1, 1.0); 120 | sv.incrementCount(-1, 10.0); 121 | System.out.println(sv); 122 | } 123 | 124 | 125 | } 126 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/examples/HMMTest.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.examples; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.io.IOUtils; 6 | import edu.umass.nlp.ml.prob.BasicConditionalDistribution; 7 | import edu.umass.nlp.ml.prob.DirichletMultinomial; 8 | import edu.umass.nlp.ml.prob.IConditionalDistribution; 9 | import edu.umass.nlp.ml.prob.IDistribution; 10 | import 
edu.umass.nlp.ml.sequence.*; 11 | import edu.umass.nlp.trees.ITree; 12 | import edu.umass.nlp.trees.Trees; 13 | import edu.umass.nlp.utils.DoubleArrays; 14 | import edu.umass.nlp.utils.IPair; 15 | 16 | import java.io.File; 17 | import java.util.*; 18 | 19 | public class HMMTest { 20 | 21 | O startWord, stopWord; 22 | double obsLambda = 1.0; 23 | double transLambda = 1.0e-4; 24 | StateSpace stateSpace; 25 | int numIters = 100; 26 | Random rand = new Random(0); 27 | 28 | HMMTest(List states,O startWord, O stopWord) { 29 | this.stateSpace = StateSpaces.makeFullStateSpace(states); 30 | this.startWord = startWord; 31 | this.stopWord = stopWord; 32 | } 33 | 34 | class HMM { 35 | 36 | IConditionalDistribution obsDistr; 37 | IConditionalDistribution transDistr; 38 | 39 | public HMM() { 40 | obsDistr = new BasicConditionalDistribution( 41 | new Fn>() { 42 | public IDistribution apply(State input) { 43 | return DirichletMultinomial.make(obsLambda); 44 | } 45 | }); 46 | transDistr = new BasicConditionalDistribution( 47 | new Fn>() { 48 | public IDistribution apply(State input) { 49 | return DirichletMultinomial.make(transLambda); 50 | } 51 | }); 52 | } 53 | 54 | public ForwardBackwards.Result random(List seq) { 55 | ForwardBackwards fb = new ForwardBackwards(stateSpace); 56 | // uniform potentials will give uniform marginals, we add some noise 57 | double[][] pots = new double[seq.size()-1][stateSpace.getTransitions().size()]; 58 | ForwardBackwards.Result fbRes = fb.compute(pots); 59 | for (double[] row : fbRes.stateMarginals) { 60 | DoubleArrays.addNoiseInPlace(row, rand, 1.0); 61 | } 62 | for (double[] row : fbRes.transMarginals) { 63 | DoubleArrays.addNoiseInPlace(row, rand, 1.0); 64 | } 65 | 66 | return fbRes; 67 | } 68 | 69 | public ForwardBackwards.Result doInference(List seq) { 70 | double[][] pots = new double[seq.size()-1][stateSpace.getTransitions().size()]; 71 | // fill potentials 72 | for (int i=0; i+1 < seq.size(); ++i) { 73 | for (Transition trans : 
stateSpace.getTransitions()) { 74 | // only first transition can be from start state 75 | if (i > 0 && trans.from.equals(stateSpace.startState)) continue; 76 | // only last transitiion can go to stop state 77 | if (i+2 < seq.size() && trans.to.equals(stateSpace.stopState)) continue; 78 | double logTransProb = transDistr.getDistribution(trans.from).getLogProb(trans.to); 79 | double logObsProb = i > 0 ? 80 | obsDistr.getDistribution(trans.from).getLogProb(seq.get(i)) : 81 | 0.0; 82 | pots[i][trans.index] = logObsProb + logTransProb; 83 | } 84 | } 85 | return (new ForwardBackwards(stateSpace)).compute(pots); 86 | } 87 | 88 | public void observe(ForwardBackwards.Result fbRes, List seq) { 89 | for (int i=1; i+1 < seq.size(); ++i) { 90 | O obs = seq.get(i); 91 | for (int j = 0; j < fbRes.stateMarginals[i].length; j++) { 92 | State state = stateSpace.getStates().get(j); 93 | double post = fbRes.stateMarginals[i][j]; 94 | obsDistr.observe(state,obs,post); 95 | } 96 | } 97 | for (int i=0; i+1 < seq.size(); ++i) { 98 | for (int t=0; t < fbRes.transMarginals[i].length; ++t) { 99 | Transition trans = stateSpace.getTransitions().get(t); 100 | double post = fbRes.transMarginals[i][t]; 101 | transDistr.observe(trans.from,trans.to, post); 102 | } 103 | } 104 | } 105 | 106 | public void mStep() { 107 | // should be a no-op, but the distributions 108 | // may want to do something (i.e LogLinear distributions) 109 | for (IPair> pair : obsDistr) { 110 | pair.getSecond().lock(); 111 | } 112 | for (IPair> pair : transDistr) { 113 | pair.getSecond().lock(); 114 | } 115 | } 116 | } 117 | 118 | public void learn(Iterable> data) { 119 | HMM hmm = new HMM(); 120 | for (int iter=0; iter < numIters; ++iter) { 121 | HMM newHMM = new HMM(); 122 | double logLike = 0.0; 123 | for (List datum : data) { 124 | // on first iteration, we randomize 125 | // E-Step since params are not initialized 126 | ForwardBackwards.Result res = iter == 0 ? 
127 | hmm.random(datum) : 128 | hmm.doInference(datum); 129 | newHMM.observe(res, datum); 130 | logLike += res.logZ; 131 | } 132 | newHMM.mStep(); 133 | if (iter > 0) { 134 | System.out.println("negLogLike: " + (-logLike)); 135 | for (IPair> pair : newHMM.obsDistr) { 136 | System.out.println("for state: " + pair.getFirst() + " probs: " + pair.getSecond()); 137 | } 138 | } 139 | hmm = newHMM; 140 | } 141 | } 142 | 143 | public static void main(String[] args) { 144 | List states = new ArrayList(); 145 | for (int i=0; i < 5; ++i) { 146 | states.add(String.format("H%d",i)); 147 | } 148 | HMMTest test = new HMMTest(states,"",""); 149 | List> data = Functional.map( 150 | Trees.readTrees(IOUtils.text(IOUtils.readerFromResource("samples/trees.mrg"))), 151 | new Fn, List>() { 152 | public List apply(ITree input) { 153 | List tags = Trees.getLeafYield(input); 154 | tags.add(0, ""); 155 | tags.add(""); 156 | return tags; 157 | }}); 158 | test.learn(data); 159 | } 160 | 161 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/exec/Execution.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.exec; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.io.IOUtils; 5 | import org.apache.log4j.*; 6 | import org.apache.log4j.spi.ErrorHandler; 7 | import org.apache.log4j.spi.Filter; 8 | import org.apache.log4j.spi.LoggingEvent; 9 | import org.apache.log4j.spi.RootLogger; 10 | 11 | import java.io.File; 12 | import java.text.SimpleDateFormat; 13 | import java.util.Calendar; 14 | import java.util.regex.Matcher; 15 | import java.util.regex.Pattern; 16 | 17 | 18 | public class Execution { 19 | 20 | private static OptionManager globalOptManager; 21 | 22 | public static class Opts { 23 | @Opt 24 | public String execPoolDir = "execs/"; 25 | 26 | @Opt 27 | public String execDir = null; 28 | 29 | @Opt 30 | public boolean appendDate = false; 31 | 32 | @Opt 33 
| public String loggerPattern = "%-5p [%c]: %m%n"; 34 | 35 | @Opt 36 | public Level logLevel = Level.INFO; 37 | 38 | @Opt 39 | public String tag; 40 | 41 | private void createExecDir() { 42 | File rootDir = new File(execPoolDir); 43 | 44 | if (appendDate) { 45 | SimpleDateFormat sdf = new SimpleDateFormat("MM-dd"); 46 | Calendar c = Calendar.getInstance(); 47 | String dateStr = sdf.format(c.getTime()); 48 | rootDir = new File(rootDir, dateStr); 49 | } 50 | 51 | if (!rootDir.exists()) { 52 | rootDir.mkdirs(); 53 | } 54 | 55 | File[] files = rootDir.listFiles(); 56 | int lastNum = 0; 57 | for (File file : files) { 58 | String fname = file.getName(); 59 | Matcher matcher = Pattern.compile("(\\d+).exec").matcher(fname); 60 | if (matcher.matches()) { 61 | int num = Integer.parseInt(matcher.group(1)); 62 | if (num >= lastNum) lastNum = num + 1; 63 | } 64 | } 65 | File toCreate = new File(rootDir, "" + lastNum + ".exec"); 66 | toCreate.mkdir(); 67 | execDir = toCreate.getPath(); 68 | } 69 | 70 | 71 | public void init() { 72 | if (execDir == null) { 73 | createExecDir(); 74 | } 75 | } 76 | 77 | } 78 | 79 | public static Opts opts; 80 | 81 | public static String getExecutionDirectory() { 82 | return (new File(opts.execDir)).getAbsolutePath(); 83 | } 84 | 85 | public static T fillOptions(String group, T o) { 86 | return (T) globalOptManager.fillOptions(group, o); 87 | } 88 | 89 | public static T fillOptions(String group, Class type) { 90 | try { 91 | return fillOptions(group,type.newInstance()); 92 | } catch (Exception e) { 93 | e.printStackTrace(); 94 | System.exit(0); 95 | } 96 | return null; 97 | } 98 | 99 | public static void addOptionHandler(Class type, Fn handler) { 100 | globalOptManager.addOptionHandler(type,handler); 101 | } 102 | 103 | public static void init(String configFile) { 104 | if (configFile != null) { 105 | globalOptManager = new OptionManager(configFile); 106 | globalOptManager.addOptionHandler(Level.class, StandardOptionHandlers.logLevelHandler); 107 | 
globalOptManager.addOptionHandler(File.class, StandardOptionHandlers.fileHandler); 108 | opts = (Opts) globalOptManager.fillOptions("exec", new Opts()); 109 | 110 | } else { 111 | opts = new Opts(); 112 | } 113 | opts.init(); 114 | if (configFile != null) { 115 | IOUtils.copy(configFile, getExecutionDirectory() + "/config.yaml"); 116 | } 117 | initRootLogger(); 118 | Logger logger = Logger.getLogger("Execution") ; 119 | logger.info("ExecutionDirectory: " + getExecutionDirectory()); 120 | boolean created = (new File(getExecutionDirectory())).mkdirs(); 121 | if (created) { 122 | logger.info("Created " + getExecutionDirectory()); 123 | } 124 | if (opts.tag != null) Logger.getLogger("Execution").info("tag: " + opts.tag); 125 | 126 | } 127 | 128 | private static boolean rootLoggerInited = false; 129 | 130 | private static void initRootLogger() { 131 | if (rootLoggerInited) return; 132 | rootLoggerInited = true; 133 | try { 134 | Logger.getRootLogger().addAppender( 135 | new FileAppender( 136 | new PatternLayout(opts.loggerPattern), 137 | (new File(getExecutionDirectory(),"out.log")).getAbsolutePath())); 138 | Logger.getRootLogger().addAppender(new ConsoleAppender(new PatternLayout(opts.loggerPattern), "System.out")); 139 | Logger.getRootLogger().setLevel(opts.logLevel); 140 | } catch (Exception e) { 141 | e.printStackTrace(); 142 | System.exit(0); 143 | } 144 | } 145 | 146 | } 147 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/exec/Opt.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.exec; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | @Retention(RetentionPolicy.RUNTIME) 9 | @Target(value = {ElementType.FIELD,ElementType.METHOD}) 10 | public @interface Opt { 11 | public abstract String name() default 
"[unassigned]"; 12 | public abstract String gloss() default ""; 13 | public abstract boolean required() default false; 14 | public abstract String defaultVal() default ""; 15 | 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/exec/OptionManager.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.exec; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.io.IOUtils; 5 | import org.apache.log4j.Logger; 6 | import org.yaml.snakeyaml.Yaml; 7 | 8 | import java.io.File; 9 | import java.lang.annotation.Annotation; 10 | import java.lang.reflect.Field; 11 | import java.util.HashMap; 12 | import java.util.HashSet; 13 | import java.util.Map; 14 | import java.util.Set; 15 | 16 | public class OptionManager { 17 | 18 | private final Map> globalOpts; 19 | private final Logger logger = Logger.getLogger("OptionManager"); 20 | private final Map> handlers = new HashMap>(); 21 | 22 | public OptionManager(String confFile) { 23 | globalOpts = (Map) (new Yaml()).load(IOUtils.text(confFile)); 24 | } 25 | 26 | public static String getOptName(Opt opt, Field f) { 27 | Set names = new HashSet(); 28 | if (opt != null && !opt.name().equals("[unassigned]")) { 29 | return opt.name(); 30 | } 31 | return f.getName(); 32 | } 33 | 34 | private Object convertToType(Class type, String val) throws Exception { 35 | Fn handler = handlers.get(type); 36 | if (handler != null) { 37 | return handler.apply(val); 38 | } 39 | if (type.equals(int.class) || type.equals(Integer.class)) { 40 | return Integer.parseInt(val); 41 | } 42 | if (type.equals(float.class) || type.equals(Float.class)) { 43 | return Float.parseFloat(val); 44 | } 45 | if (type.equals(double.class) || type.equals(Double.class)) { 46 | return Double.parseDouble(val); 47 | } 48 | if (type.equals(short.class) || type.equals(Short.class)) { 49 | return Short.parseShort(val); 50 | } 51 | if 
(type.equals(boolean.class) || type.equals(Boolean.class)) { 52 | return !(val != null && val.equalsIgnoreCase("false")); 53 | } 54 | if (type.isEnum()) { 55 | Object[] objs = ((Class) type).getEnumConstants(); 56 | for (int i = 0; i < objs.length; ++i) { 57 | Object enumConst = objs[i]; 58 | if (enumConst.toString().equalsIgnoreCase(val)) { 59 | return enumConst; 60 | } 61 | } 62 | } 63 | if (type.equals(File.class)) { 64 | File f = new File(val); 65 | if (!f.exists()) { 66 | logger.warn(String.format("File %s doesn't exits\n", f.getAbsolutePath())); 67 | } 68 | return f; 69 | } 70 | return val; 71 | } 72 | 73 | public void addOptionHandler(Class type, Fn handler) { 74 | handlers.put(type, handler); 75 | } 76 | 77 | public Object fillOptions(String optGroup, Object o) { 78 | final Map localOpts = (Map) globalOpts.get(optGroup); 79 | if (localOpts == null) { 80 | logger.warn("Couldn't find request optionGroup " + optGroup); 81 | return o; 82 | } 83 | return fillOptions(localOpts, o); 84 | } 85 | 86 | public Object fillOptions(Map localOpts, Object o) { 87 | Class c = (o instanceof Class) ? 
((Class) o) : o.getClass(); 88 | for (Field f : c.getFields()) { 89 | Opt opt = f.getAnnotation(Opt.class); 90 | String optName = getOptName(opt, f); 91 | Object optVal = localOpts.get(optName); 92 | if (optVal == null) continue; 93 | if ((optVal instanceof String) || (optVal instanceof Boolean) || (optVal instanceof Double) || (optVal instanceof Integer)) { 94 | try { 95 | f.set(o, convertToType(f.getType(), optVal.toString())); 96 | } catch (Exception e) { 97 | logger.warn("Error setting " + optName + 98 | " with value " + optVal + "for class " + o.getClass().getSimpleName()); 99 | } 100 | } else if (optVal instanceof Map) { 101 | try { 102 | f.set(o, fillOptions((Map) optVal, f.getType().newInstance())); 103 | } catch (Exception e) { 104 | logger.warn("Error setting " + optName + 105 | " with value " + optVal + "for class " + o.getClass().getSimpleName()); 106 | } 107 | } else { 108 | throw new RuntimeException("Bad YAML Entry for " + optName + " with val " + optVal); 109 | } 110 | } 111 | return o; 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/exec/StandardOptionHandlers.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.exec; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import org.apache.log4j.Level; 5 | 6 | import java.io.File; 7 | 8 | public class StandardOptionHandlers 9 | { 10 | 11 | public static Fn fileHandler = new Fn() { 12 | public Object apply(String input) { 13 | return new File(input); 14 | }}; 15 | 16 | public static Fn logLevelHandler = new Fn() { 17 | public Object apply(String input) { 18 | try { 19 | Level logLevel = Level.toLevel(input); 20 | return logLevel; 21 | } catch (Exception e) { 22 | e.printStackTrace(); 23 | System.exit(0); 24 | } 25 | return null; 26 | }}; 27 | 28 | } 29 | -------------------------------------------------------------------------------- 
/src/main/java/edu/umass/nlp/exec/package.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | To use the Execution framework, in the main of your entry point, the 4 | first line should be 5 |
6 | 7 | edu.umass.nlp.exec.Execution.init(pathToConfigFile); 8 | 9 | The pathToConfigFile should be a string to a 10 | yaml file that stores global options. Global options are used to populate 11 | objects with public mutable fields using reflection. The purpose of this 12 | is to provide easy (but not secure or robust) option management. 13 | 14 | The Execution framework does many things; here's a summary 15 | 16 |

Global Option Configuration

17 | 18 | Options are grouped together in a hierarchical fashion. For instance, the 19 | main Execution option would be given in yaml by 20 | 21 |
22 |    
23 |    exec:
24 |     execPoolDir: execs
25 |     loggerPattern: %-5p [%c]: %m%n
26 |   
27 |   
28 | 29 | So in your code when you call: 30 |
31 |   
32 |     Execution.Opts execOpts = Execution.fillOptions("exec", new Execution.Opts());
33 |                                                                           
34 |   
35 | 36 | We reflectively look up options under the exec part of the configuration 37 | file and fill in the fields of the passed-in object. You can fill in different instances 38 | of the same options object by using different group names (e.g. exec1, exec2, ...). 39 | 40 | You can also do hierarchical option filling. 41 | 42 |

Store Execution Log and Options in a Directory

43 | 44 | Another feature of the Execution framework is that the log of every run goes to a directory 45 | specified in the exec.execDir directory. Typically, you shouldn't specify the 46 | directory and instead use the option exec.execPoolDir which will automatically 47 | make a new directory for each execution run by adding 0.exec,1.exec,2.exec,... 48 | as needed. The directory will store everything sent to the logger as well as a copy of the configuration 49 | needed to re-run the experiment (modulo code changes obviously). 50 | 51 | If you want to store other output in the execution directory, you have access to it in 52 | Execution.getExecutionDirectory. 53 | 54 | You can add option processing behavior using 55 | Execution.addOptionHandler. See 56 | StandardOptionHandlers for example 57 | option handlers. 58 | 59 |

Apache Logger Configuration

60 | 61 | The Execution framework also configures the log4j logger (the pattern 62 | for the logger prefix is configurable via the exec.loggerPattern option 63 | in your global config file). 64 | 65 | This means in your code you should probably not use System.out.println and 66 | opt instead to use the logger. You can read about the Apache Logger system here. 67 | 68 | Two relevant configurable options are: Execution.Opts.loggerPattern 69 | and Execution.Opts.logLevel. 70 | 71 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/CallbackFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | public interface CallbackFn { 4 | public void callback(Object... args); 5 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/Double2DoubleFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | public interface Double2DoubleFn { 4 | public double valAt(double x); 5 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/DoubleFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | public interface DoubleFn { 4 | public double valAt(T x); 5 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/FactoryFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | public interface FactoryFn { 4 | 5 | public T make(); 6 | 7 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/Fn.java: 
-------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | import java.io.Serializable; 4 | 5 | public interface Fn extends Serializable { 6 | public O apply(I input); 7 | 8 | public static class ConstantFn implements Fn 9 | { 10 | 11 | private O c; 12 | 13 | public ConstantFn(O c) { 14 | this.c = c; 15 | } 16 | 17 | public O apply(I input) { 18 | return c; 19 | } 20 | } 21 | 22 | public static class IdentityFn implements Fn 23 | { 24 | 25 | public I apply(I input) 26 | { 27 | return input; 28 | } 29 | } 30 | 31 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/Functional.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | import edu.umass.nlp.ml.sequence.CRF; 4 | import edu.umass.nlp.ml.sequence.ILabeledSeqDatum; 5 | import edu.umass.nlp.utils.*; 6 | 7 | import java.util.*; 8 | 9 | 10 | /** 11 | * Collection of Functional Utilities you'd 12 | * find in any functional programming language. 13 | * Things like map, filter, reduce, etc.. 
14 | * 15 | */ 16 | public class Functional { 17 | 18 | 19 | public static List take(Iterator it, int n) { 20 | List result = new ArrayList(); 21 | for (int i=0; i < n && it.hasNext(); ++i) { 22 | result.add(it.next()); 23 | } 24 | return result; 25 | } 26 | 27 | 28 | public static IValued findMax(Iterable xs, DoubleFn fn) { 29 | double max = Double.NEGATIVE_INFINITY; 30 | T argMax = null; 31 | for (T x : xs) { 32 | double val = fn.valAt(x); 33 | if (val > max) { max = val ; argMax = x; } 34 | } 35 | return BasicValued.make(argMax,max); 36 | } 37 | 38 | public static IValued findMin(Iterable xs, Fn fn) { 39 | double min= Double.POSITIVE_INFINITY; 40 | T argMin = null; 41 | for (T x : xs) { 42 | double val = fn.apply(x); 43 | if (val < min) { min= val ; argMin = x; } 44 | } 45 | return BasicValued.make(argMin,min); 46 | } 47 | 48 | public static Map map(Map map, Fn fn, PredFn pred, Map resultMap) { 49 | for (Map.Entry entry: map.entrySet()) { 50 | K key = entry.getKey(); 51 | I inter = entry.getValue(); 52 | if (pred.holdsAt(key)) resultMap.put(key, fn.apply(inter)); 53 | } 54 | return resultMap; 55 | } 56 | 57 | public static Map mapPairs(Iterable lst, Fn fn) 58 | { 59 | return mapPairs(lst,fn,new HashMap()); 60 | } 61 | 62 | public static Map mapPairs(Iterable lst, Fn fn, Map resultMap) 63 | { 64 | for (I input: lst) { 65 | O output = fn.apply(input); 66 | resultMap.put(input,output); 67 | } 68 | return resultMap; 69 | } 70 | 71 | public static List map(Iterable lst, Fn fn) { 72 | return map(lst,fn,(PredFn) PredFns.getTruePredicate()); 73 | } 74 | 75 | public static Iterator map(final Iterator it, final Fn fn) { 76 | return new Iterator() { 77 | public boolean hasNext() { 78 | return it.hasNext(); 79 | } 80 | 81 | public O next() { 82 | return fn.apply(it.next()); 83 | } 84 | 85 | public void remove() { 86 | throw new RuntimeException("remove() not supported"); 87 | } 88 | }; 89 | } 90 | 91 | public static Map makeMap(Iterable elems, Fn fn, Map map) { 92 | for (I 
elem : elems) { 93 | map.put(elem, fn.apply(elem)); 94 | } 95 | return map; 96 | } 97 | 98 | /** Same as the three-argument makeMap, writing into a fresh HashMap. */ public static Map makeMap(Iterable elems, Fn fn) { 99 | return makeMap(elems, fn, new HashMap()) ; 100 | } 101 | 102 | /** flatMap that keeps every intermediate list, using the always-true predicate. */ public static List flatMap(Iterable lst, 103 | Fn> fn) { 104 | PredFn> p = PredFns.getTruePredicate(); 105 | return flatMap(lst,fn,p); 106 | } 107 | 108 | 109 | /** Applies fn to each element, keeps the result lists accepted by pred, and appends them in order onto one new ArrayList. */ public static List flatMap(Iterable lst, 110 | Fn> fn, 111 | PredFn> pred) { 112 | List> lstOfLsts = map(lst,fn,pred); 113 | List init = new ArrayList(); 114 | return reduce(lstOfLsts, init, 115 | new Fn, List>, List>() { 116 | public List apply(IPair, List> input) { 117 | List result = input.getFirst(); 118 | result.addAll(input.getSecond()); 119 | return result; 120 | } 121 | }); 122 | } 123 | 124 | /** Left fold: starting from initial, repeatedly replaces the accumulator with fn applied to the pair (accumulator, next input); returns initial for empty input. */ public static O reduce(Iterable inputs, 125 | O initial, 126 | Fn,O> fn) { 127 | O output = initial; 128 | for (I input: inputs) { 129 | output = fn.apply(BasicPair.make(output,input)); 130 | } 131 | return output; 132 | } 133 | 134 | /** Applies fn to each element and collects only the outputs for which pred holds; pred is tested on the output, not on the input. */ public static List map(Iterable lst, Fn fn, PredFn pred) { 135 | List outputs = new ArrayList(); 136 | for (I input: lst) { 137 | O output = fn.apply(input); 138 | if (pred.holdsAt(output)) { 139 | outputs.add(output); 140 | } 141 | } 142 | return outputs; 143 | } 144 | 145 | /** Returns a new list of the elements of lst that satisfy pred, in iteration order. */ public static List filter(final Iterable lst, final PredFn pred) { 146 | List ret = new ArrayList(); 147 | for (I input : lst) { 148 | if (pred.holdsAt(input)) ret.add(input); 149 | } 150 | return ret; 151 | } 152 | 153 | 154 | /** Returns the first element satisfying pred, or null when none does. */ public static T first(Iterable objs, PredFn pred) { 155 | for (T obj : objs) { 156 | if (pred.holdsAt(obj)) return obj; 157 | } 158 | return null; 159 | } 160 | 161 | 162 | 163 | /** Returns the integers 0 through n-1 as a new list; the list is empty when n is not positive. */ public static List range(int n) { 164 | List result = new ArrayList(); 165 | for (int i = 0; i < n; i++) { 166 | result.add(i); 167 | } 168 | return result; 169 | } 170 | 171 | /** 172 | * Reports whether p holds for at least one element of elems. 173 | * @return true at the first element satisfying p, false when none does (including empty input) 174 | */ 175 | public static boolean any(Iterable elems, PredFn p) { 176 | for (T elem : elems) { 
177 | if (p.holdsAt(elem)) return true; 178 | } 179 | return false; 180 | } 181 | 182 | /** Reports whether p holds for every element; true for empty input. */ public static boolean all(Iterable elems, PredFn p) { 183 | for (T elem : elems) { 184 | if (!p.holdsAt(elem)) return false; 185 | } 186 | return true; 187 | } 188 | 189 | 190 | /** Alias for first: the first element satisfying pred, or null when none does. */ public static T find(Iterable elems, PredFn pred) { 191 | return first(elems, pred); 192 | } 193 | 194 | /** Returns the zero-based position of the first element satisfying pred, or -1 when none does. */ public static int findIndex(Iterable elems, PredFn pred) { 195 | int index = 0; 196 | for (T elem : elems) { 197 | if (pred.holdsAt(elem)) return index; 198 | index += 1; 199 | } 200 | return -1; 201 | } 202 | 203 | /** Returns the zero-based positions of all elements satisfying pred, in increasing order. */ public static List indicesWhere(Iterable elems, PredFn pred) { 204 | List res = new ArrayList(); 205 | int index = 0; 206 | for (T elem : elems) { 207 | if (pred.holdsAt(elem)) { 208 | res.add(index); 209 | } 210 | index ++; 211 | } 212 | return res; 213 | } 214 | 215 | /** mkString without a custom formatter: passes null as strFn so elements are rendered with toString. */ public static String mkString(Iterable elems, String start, String middle, String stop) { 216 | return mkString(elems, start, middle, stop, null); 217 | } 218 | 219 | /** Builds start, then the elements separated by middle, then stop; each element is rendered with strFn when strFn is non-null, otherwise with its own toString. */ public static String mkString(Iterable elems, String start, String middle, String stop,Fn strFn) { 220 | StringBuilder sb = new StringBuilder(); 221 | sb.append(start); 222 | Iterator it = elems.iterator(); 223 | while (it.hasNext()) { 224 | T t = it.next(); 225 | sb.append((strFn != null ? 
strFn.apply(t) : t.toString())); 226 | if (it.hasNext()) { 227 | sb.append(middle); 228 | } 229 | } 230 | sb.append(stop); 231 | return sb.toString(); 232 | } 233 | 234 | /** Renders the elements as a parenthesized, comma-separated list using toString. */ public static String mkString(Iterable elems) { 235 | return mkString(elems,"(",",",")",null); 236 | } 237 | 238 | /** takeWhile over a fresh iterator obtained from elems. */ public static List takeWhile(Iterable elems, PredFn pred) { 239 | Iterator it = elems.iterator(); 240 | return takeWhile(it,pred); 241 | } 242 | 243 | /** Collects elements from it until the first one that fails pred; that failing element is consumed from the iterator but not returned. */ public static List takeWhile(Iterator it, PredFn pred) { 244 | List res = new ArrayList(); 245 | while (it.hasNext()) { 246 | T elem = it.next(); 247 | if (pred.holdsAt(elem)) res.add(elem); 248 | else break; 249 | } 250 | return res; 251 | } 252 | 253 | /** Finds the maximal runs of consecutive elements satisfying pred and returns one Span per run, built from the index of the first element of the run and the index just past its last element. */ public static List 254 | rangesWhere(Iterable elems, PredFn pred) { 255 | int index = 0; 256 | int lastStart = -1; 257 | List res = new ArrayList(); 258 | for (T elem: elems) { 259 | boolean matches = pred.holdsAt(elem); 260 | if (matches && lastStart < 0) { 261 | lastStart = index; 262 | } 263 | if (!matches && lastStart >= 0) { 264 | res.add(new Span(lastStart,index)); 265 | lastStart = -1; 266 | } 267 | index += 1; 268 | } 269 | if (lastStart >= 0) { 270 | res.add(new Span(lastStart, index)); 271 | } 272 | return res; 273 | } 274 | 275 | /** Copies out, as a new ArrayList, the sublist of elems for every span reported by rangesWhere. */ public static List> subseqsWhere(List elems, PredFn pred) { 276 | List ranges = rangesWhere(elems, pred); 277 | List> res = new ArrayList>(); 278 | for (Span span: ranges) { 279 | res.add(new ArrayList(elems.subList(span.getStart(), span.getStop()))); 280 | } 281 | return res; 282 | } 283 | 284 | /** Binds fixed as the argument of fn, yielding a factory whose make() calls fn.apply(fixed) on every invocation. */ public static FactoryFn curry(final Fn fn, final A fixed) { 285 | return new FactoryFn() { 286 | public R make() { 287 | return fn.apply(fixed); 288 | } 289 | }; 290 | } 291 | 292 | /** Returns an Iterable whose iterators apply fn to the elements of xs as they are requested; each iterator() call starts from a new xs.iterator() and remove() throws RuntimeException. */ public static Iterable lazyMap(final Iterable xs, final Fn fn) { 293 | return new Iterable() { 294 | public Iterator iterator() { 295 | return new Iterator() { 296 | 297 | private Iterator it = xs.iterator(); 298 | 299 | public boolean hasNext() { 300 | return it.hasNext(); 301 | 
} 302 | 303 | public O next() { 304 | return fn.apply(it.next()); 305 | } 306 | 307 | public void remove() { 308 | throw new RuntimeException("Not Implemented"); 309 | } 310 | }; 311 | } 312 | }; 313 | } 314 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/PredFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | /** Boolean test over a single element. */ public interface PredFn { 4 | 5 | /** Reports whether the predicate holds for elem. */ public boolean holdsAt(T elem); 6 | 7 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/PredFns.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | /** Factory for stock predicates. */ public class PredFns { 4 | 5 | /** Returns a new predicate that holds for every element. */ public static PredFn getTruePredicate() { 6 | return new PredFn() { 7 | public boolean holdsAt(T elem) { 8 | return true; 9 | }}; 10 | } 11 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/functional/Vector2DoubleFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.functional; 2 | 3 | /** Function from a double array to a single double. */ public interface Vector2DoubleFn { 4 | public double valAt(double[] x); 5 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/io/IOUtils.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.io; 2 | 3 | 4 | import edu.umass.nlp.functional.Functional; 5 | 6 | import java.io.*; 7 | import java.util.ArrayList; 8 | import java.util.Collections; 9 | import java.util.Iterator; 10 | import java.util.List; 11 | import java.util.zip.GZIPInputStream; 12 | import java.util.zip.ZipInputStream; 13 | 14 | /** Static helpers for reading and writing lines, text and serialized objects. */ public class IOUtils { 15 | 16 | /** Lazily iterates the lines read from is through an InputStreamReader created without an explicit charset, so the platform default charset applies. */ public static Iterable lazyLines(final InputStream is) { 17 | try { 18 | return 
lazyLines(new InputStreamReader(is)); 19 | } catch (Exception e) { 20 | e.printStackTrace(); 21 | } 22 | throw new IllegalStateException(); 23 | } 24 | 25 | /** Lazily iterates the lines of the file; if the FileReader cannot be opened the stack trace is printed and IllegalStateException is thrown. */ public static Iterable lazyLines(final File path) { 26 | try { 27 | return lazyLines(new FileReader(path)); 28 | } catch (Exception e) { 29 | e.printStackTrace(); 30 | } 31 | throw new IllegalStateException(); 32 | } 33 | 34 | /** Same as the File overload, for a path given as a string. */ public static Iterable lazyLines(final String path) { 35 | try { 36 | return lazyLines(new FileReader(path)); 37 | } catch (Exception e) { 38 | e.printStackTrace(); 39 | } 40 | throw new IllegalStateException(); 41 | } 42 | 43 | /** Returns an Iterable over the lines of reader. Every iterator() call wraps the same reader in a new BufferedReader, so a second iteration continues from wherever the reader was left rather than restarting. A failed read is rethrown as a RuntimeException carrying the original cause. */ public static Iterable lazyLines(final Reader reader) { 44 | return new Iterable() { 45 | public Iterator iterator() { 46 | final BufferedReader buffered = new BufferedReader(reader); 47 | return new Iterator() { 48 | private String nextLine; 49 | private boolean consumed = true; 50 | 51 | /** Reads one line ahead unless an unconsumed line is already buffered. */ private void queue() { 52 | if (!consumed) return; 53 | try { 54 | nextLine = buffered.readLine(); 55 | consumed = false; 56 | } catch (Exception e) { 57 | /* Propagate the read failure to the caller; library code must not terminate the JVM, least of all with a success status. */ 58 | throw new RuntimeException(e); 59 | } 60 | } 61 | 62 | public boolean hasNext() { 63 | queue(); 64 | return nextLine != null; 65 | } 66 | 67 | /** Returns the buffered line, or null once the reader is exhausted. */ public String next() { 68 | queue(); 69 | String ret = nextLine; 70 | consumed = true; 71 | return ret; 72 | } 73 | 74 | public void remove() { 75 | throw new RuntimeException("Not Implemented"); 76 | } 77 | }; 78 | } 79 | }; 80 | } 81 | 82 | /** Eagerly reads all lines from is. */ public static List lines(InputStream is) { 83 | return lines(new InputStreamReader(is)); 84 | } 85 | 86 | /** Eagerly reads all lines of the file at path f. */ public static List lines(String f) { return lines(new File(f)); } 87 | 88 | /** Eagerly reads all lines of f; returns an empty list after printing the stack trace if the file cannot be opened. */ public static List lines(File f) { 89 | try { 90 | Reader r = new FileReader(f); 91 | return lines(r); 92 | } catch (Exception e) { 93 | e.printStackTrace(); 94 | } 95 | return Collections.emptyList(); 96 | } 97 | 98 | /** Reads lines from r until end of input; on an exception the stack trace is printed and the lines read so far are returned. */ public static List lines(Reader r) { 99 | List res = new ArrayList(); 100 | try { 101 | BufferedReader br = new BufferedReader(r); 102 | 
while (true) { 103 | String line = br.readLine(); 104 | if (line == null) break; 105 | res.add(line); 106 | } 107 | } catch (Exception e) { 108 | e.printStackTrace(); 109 | } 110 | return res; 111 | } 112 | 113 | /** Opens a FileReader on f; returns null after printing the stack trace if that fails. */ public static Reader reader(File f) { 114 | try { 115 | return new FileReader(f); 116 | } catch (Exception e) { 117 | e.printStackTrace(); 118 | } 119 | return null; 120 | } 121 | 122 | /** Opens the named file as a stream, wrapped in GZIPInputStream for names ending in .gz or in ZipInputStream for names ending in .zip; returns null after printing the stack trace on failure. */ public static InputStream inputStream(String name) { 123 | try { 124 | InputStream is = new FileInputStream(name); 125 | if (name.endsWith(".gz")) return new GZIPInputStream(is); 126 | if (name.endsWith(".zip")) return new ZipInputStream(is); 127 | return is; 128 | } catch (Exception e) { 129 | e.printStackTrace(); 130 | } 131 | return null; 132 | } 133 | 134 | /** Opens a reader on the resource located through ClassLoader.getSystemResourceAsStream. */ public static Reader readerFromResource(String resourcePath) { 135 | return new InputStreamReader(ClassLoader.getSystemResourceAsStream(resourcePath)); 136 | } 137 | 138 | /** Reads all lines of the given system resource. */ public static List linesFromResource(String resourcePath) { 139 | return lines(readerFromResource(resourcePath)); 140 | } 141 | 142 | /** Reports whether a file exists at path f. */ public static boolean exists(String f) { 143 | return (new File(f)).exists(); 144 | } 145 | 146 | /** Reports whether f exists. */ public static boolean exists(File f) { 147 | return f.exists(); 148 | } 149 | 150 | /** Replaces the final dot-extension of path with newExt, adding the leading dot when newExt lacks one; a path without a dot-extension is returned unchanged. NOTE(review): newExt is passed to replaceAll as a replacement string, so dollar signs or backslashes in it are interpreted specially. */ public static String changeExt(String path, String newExt) { 151 | if (!newExt.startsWith(".")) { 152 | newExt = "." 
+ newExt; 153 | } 154 | return path.replaceAll("\\.[^.]+$",newExt); 155 | } 156 | 157 | /** Returns the path formed by placing the file name of path under newDir. */ public static String changeDir(String path, String newDir) { 158 | File f = new File(path); 159 | return (new File(newDir,f.getName())).getPath(); 160 | } 161 | 162 | /** Returns the lines of the stream joined with newline characters, with no trailing newline. */ public static String text(InputStream is) { 163 | return Functional.mkString(lines(is),"","\n",""); 164 | } 165 | 166 | /** Returns the newline-joined lines of the file at path. */ public static String text(String path) { 167 | return text(new File(path)); 168 | } 169 | 170 | /** Returns the newline-joined lines read from r. */ public static String text(Reader r) { 171 | return Functional.mkString(lines(r),"","\n",""); 172 | } 173 | 174 | /** Returns the newline-joined lines of f. */ public static String text(File f) { 175 | return Functional.mkString(lines(f),"","\n",""); 176 | } 177 | 178 | /** Writes each line to f with println, then flushes and closes the writer; if an exception is thrown the stack trace is printed and the method returns normally. */ public static void writeLines(File f, List lines) { 179 | try { 180 | PrintWriter writer = new PrintWriter(new FileWriter(f)); 181 | for (String line : lines) { 182 | writer.println(line); 183 | } 184 | writer.flush(); 185 | writer.close(); 186 | } catch (Exception e) { 187 | e.printStackTrace(); 188 | } 189 | } 190 | 191 | /** Same as the File overload, for a path given as a string. */ public static void writeLines(String f, List lines) { 192 | writeLines(new File(f), lines); 193 | } 194 | 195 | /** Reads objects from is until readObject returns null or throws; any exception, including the one raised at end of stream, is printed and the objects read so far are returned. */ public static List readObjects(InputStream is) { 196 | List ret = new ArrayList(); 197 | try { 198 | ObjectInputStream ois = new ObjectInputStream(is); 199 | while (true) { 200 | Object o = ois.readObject(); 201 | if (o == null) break; 202 | ret.add(o); 203 | } 204 | } catch (Exception e) { 205 | e.printStackTrace(); 206 | } 207 | return ret; 208 | } 209 | 210 | /** Reads a single object from the file at path, opened through inputStream. */ public static Object readObject(String path) { 211 | InputStream is = inputStream(path); 212 | return readObject(is); 213 | } 214 | 215 | /** Reads a single object from is; returns null after printing the stack trace on failure. */ public static Object readObject(InputStream is) { 216 | try { 217 | ObjectInputStream ois = new ObjectInputStream(is); 218 | return ois.readObject(); 219 | } catch (Exception e) { 220 | e.printStackTrace(); 221 | } 222 | return null; 223 | } 224 | 225 | /** Reads all objects from the file at path and closes the stream afterwards. */ public static List readObjects(String path) { 226 | InputStream is = inputStream(path); 227 | List ret = 
readObjects(is); 228 | try { 229 | is.close(); 230 | } catch (Exception e) { 231 | e.printStackTrace(); 232 | } 233 | return ret; 234 | } 235 | 236 | /** Serializes o to the file at path and closes the stream; exceptions are printed, not rethrown. */ public static void writeObject(Object o, String path) { 237 | try { 238 | OutputStream os = new FileOutputStream(new File(path)); 239 | ObjectOutputStream oos = new ObjectOutputStream(os); 240 | oos.writeObject(o); 241 | oos.close(); 242 | } catch (Exception e) { 243 | e.printStackTrace(); 244 | } 245 | } 246 | 247 | /** Opens a PrintWriter on path; returns null after printing the stack trace on failure. */ public static PrintWriter getPrintWriter(String path) { 248 | try { 249 | return new PrintWriter(new FileWriter(path)); 250 | } catch (Exception e) { 251 | e.printStackTrace(); 252 | return null; 253 | } 254 | } 255 | 256 | /** Copies src to dest line by line by reading with lines and writing with writeLines. */ public static void copy(String src, String dest) { 257 | List lines = IOUtils.lines(src); 258 | IOUtils.writeLines(dest,lines); 259 | } 260 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/io/ZipUtils.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.io; 2 | 3 | 4 | import edu.umass.nlp.functional.CallbackFn; 5 | import edu.umass.nlp.functional.Fn; 6 | 7 | import java.io.IOException; 8 | import java.io.InputStream; 9 | import java.util.List; 10 | import java.util.zip.ZipEntry; 11 | import java.util.zip.ZipFile; 12 | import java.util.zip.ZipOutputStream; 13 | 14 | /** Convenience wrappers around java.util.zip that print exceptions rather than propagating them. */ public class ZipUtils { 15 | 16 | /** Opens the named zip file; returns null after printing the stack trace on failure. */ public static ZipFile getZipFile(String name) { 17 | try { 18 | return new ZipFile(name); 19 | } catch (Exception e) { 20 | e.printStackTrace(); 21 | } 22 | return null; 23 | } 24 | 25 | /** Returns a stream over the named entry of zf; returns null after printing the stack trace if an exception is thrown. */ public static InputStream getEntryInputStream(ZipFile zf, String entryName) { 26 | try { 27 | return zf.getInputStream(zf.getEntry(entryName)); 28 | } catch (Exception e) { 29 | e.printStackTrace(); 30 | } 31 | return null; 32 | } 33 | 34 | /** Reads all lines of the named entry; returns null if an exception is thrown along the way. */ public static List getEntryLines(ZipFile zf, String entryName) { 35 | try { 36 | return IOUtils.lines(getEntryInputStream(zf, entryName)); 37 | } 
catch (Exception e) { e.printStackTrace(); } 38 | return null; 39 | } 40 | 41 | /** Reports whether the zip file contains an entry with the given name. */ public static boolean entryExists(ZipFile zipFile, String entryName) { 42 | return zipFile.getEntry(entryName) != null; 43 | } 44 | 45 | /** Opens the zip file named by the first argument and does nothing further with it. */ public static void main(String[] args) { 46 | ZipFile root = ZipUtils.getZipFile(args[0]); 47 | 48 | } 49 | 50 | /** Starts a new entry named entryName on zos, runs entryFn, then closes the entry; an IOException is printed, not rethrown. */ public static void doZipEntry(ZipOutputStream zos, String entryName, CallbackFn entryFn) { 51 | try { 52 | ZipEntry ze = new ZipEntry(entryName); 53 | zos.putNextEntry(ze); 54 | entryFn.callback(); 55 | zos.closeEntry(); 56 | } catch (IOException e) { 57 | e.printStackTrace(); 58 | } 59 | } 60 | 61 | /** Writes the bytes of text to zos; getBytes is called without a charset, so the platform default encoding applies. */ public static void print(ZipOutputStream zos, String text) { 62 | try { 63 | zos.write(text.getBytes()); 64 | } catch (IOException e) { 65 | e.printStackTrace(); 66 | } 67 | } 68 | 69 | /** print with a newline character appended to text. */ public static void println(ZipOutputStream zos, String text) { print(zos, text + "\n"); } 70 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/F1Stats.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml; 2 | 3 | import edu.umass.nlp.utils.IMergable; 4 | 5 | /** Mergeable true-positive, false-positive and false-negative counts for one label, with precision, recall and F-measure derived from them. */ public class F1Stats implements IMergable { 6 | 7 | public int tp = 0, fp = 0, fn = 0; 8 | public final String label; 9 | 10 | public F1Stats(String label) { 11 | this.label = label; 12 | } 13 | 14 | /** Adds the three counts of other into this instance. */ public void merge(F1Stats other) { 15 | tp += other.tp; 16 | fp += other.fp; 17 | fn += other.fn; 18 | } 19 | 20 | /** Returns tp / (tp + fp), or 0.0 when tp + fp is zero. */ public double getPrecision() { 21 | if (tp + fp > 0.0) { 22 | return (tp / (tp + fp + 0.0)); 23 | } else { 24 | return 0.0; 25 | } 26 | } 27 | 28 | /** Returns tp / (tp + fn), or 0.0 when tp + fn is zero. */ public double getRecall() { 29 | if (tp + fn > 0.0) { 30 | return (tp / (tp + fn + 0.0)); 31 | } else { 32 | return 0.0; 33 | } 34 | } 35 | 36 | /** Returns (1 + beta^2) * p * r / (beta^2 * p + r) for precision p and recall r, or 0.0 when p + r is not positive. */ public double getFMeasure(double beta) { 37 | double p = getPrecision(); 38 | double r = getRecall(); 39 | if (p + r > 0.0) { 40 | return ((1+beta*beta)* p * r) / ((beta*beta)*p + r); 
41 | } else { 42 | return 0.0; 43 | } 44 | } 45 | 46 | /** Records one (true label, guessed label) observation for this label, counting it exactly once: tp when both equal label, fn when only the true label equals it, fp when only the guess equals it. An observation matching neither violates the assertion and, with assertions disabled, changes no count. */ public void observe(String trueLabel, String guessLabel) { 47 | assert (label.equals(trueLabel) || label.equals(guessLabel)); 48 | if (label.equals(trueLabel)) { 49 | if (trueLabel.equals(guessLabel)) { 50 | tp++; 51 | } else { 52 | fn++; 53 | } 54 | } else if (label.equals(guessLabel)) { 55 | fp++; 56 | } 57 | } 58 | 59 | /** Formats F1, F2, precision and recall to three decimals followed by the raw counts. */ public String toString() { 60 | return String.format("f1: %.3f f2: %.3f prec: %.3f recall: %.3f (tp: %d, fp: %d, fn: %d)", 61 | getFMeasure(1.0), getFMeasure(2.0), getPrecision(), getRecall(), tp, fp, fn); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/LossFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml; 2 | 3 | /** Loss function over pairs of labels. */ public interface LossFn { 4 | /** Returns the loss for guessing guessLabel when the true label is trueLabel. */ public double getLoss(L trueLabel, L guessLabel); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/LossFns.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.utils.Collections; 5 | import edu.umass.nlp.utils.ICounter; 6 | import edu.umass.nlp.utils.IPair; 7 | import edu.umass.nlp.utils.MapCounter; 8 | 9 | import java.util.Collection; 10 | import java.util.HashMap; 11 | import java.util.Map; 12 | 13 | 14 | public class LossFns { 15 | 16 | /** Tabulates lossFn over every ordered pair drawn from labels: for each label, the loss against each other label is added to that label's counter in the result map. Presumably Collections.getMut fetches the existing counter or installs the supplied one; confirm against that helper. */ public static Map> compileLossFn(LossFn lossFn, Collection labels) { 17 | Map> res = new HashMap>(); 18 | for (L label : labels) { 19 | for (L otherLabel : labels) { 20 | Collections.getMut(res, label, new MapCounter()) 21 | .incCount(otherLabel, lossFn.getLoss(label, otherLabel)); 22 | } 23 | } 24 | return res; 25 | } 26 | 27 | } 28 
| -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/Regularizers.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.utils.BasicPair; 5 | import edu.umass.nlp.utils.IPair; 6 | 7 | public class Regularizers { 8 | 9 | /** Returns a function mapping a weight array to the pair (sum over i of w_i * w_i / sigmaSq, gradient array with entries 2 * w_i / sigmaSq). */ public static Fn> getL2Regularizer(final double sigmaSq) { 10 | return new Fn>() { 11 | public IPair apply(double[] input) { 12 | double obj = 0.0; 13 | double[] grad = new double[input.length]; 14 | for (int i = 0; i < input.length; ++i) { 15 | double w = input[i]; 16 | obj += w * w / sigmaSq; 17 | grad[i] += 2 * w / sigmaSq; 18 | } 19 | return new BasicPair(obj, grad); 20 | } 21 | }; 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/BasicClassifierDatum.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | import edu.umass.nlp.utils.IValued; 4 | 5 | import java.util.List; 6 | 7 | /** Unlabeled datum backed by the predicate list passed to the constructor; the list is stored by reference, not copied. */ public class BasicClassifierDatum implements ClassifierDatum { 8 | private final List> preds; 9 | 10 | public BasicClassifierDatum(List> preds) { 11 | this.preds = preds; 12 | } 13 | 14 | @Override 15 | public List> getPredicates() { 16 | return preds; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/BasicLabeledClassifierDatum.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | import edu.umass.nlp.utils.BasicValued; 4 | import edu.umass.nlp.utils.IValued; 5 | 6 | import java.util.ArrayList; 7 | import java.util.List; 8 | 9 | /** Datum holding a predicate list together with its true label. */ public class BasicLabeledClassifierDatum implements LabeledClassifierDatum { 10 | private 
final List> preds; 11 | private final L label; 12 | 13 | public BasicLabeledClassifierDatum(List> preds, L label) { 14 | this.preds = preds; 15 | this.label = label; 16 | } 17 | 18 | @Override 19 | public L getTrueLabel() { 20 | return label; 21 | } 22 | 23 | @Override 24 | public List> getPredicates() { 25 | return preds; 26 | } 27 | 28 | /** 29 | * Builds a labeled datum in which every given predicate name carries the value 1.0. 30 | */ 31 | public static LabeledClassifierDatum 32 | getBinaryDatum(L label,String... preds) { 33 | List> predPairs = new ArrayList>(); 34 | for (String pred : preds) { 35 | predPairs.add(BasicValued.make(pred,1.0)); 36 | } 37 | return new BasicLabeledClassifierDatum(predPairs, label); 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/ClassifierDatum.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | import edu.umass.nlp.utils.IValued; 4 | 5 | import java.util.List; 6 | 7 | /** A datum that exposes its valued predicates. */ public interface ClassifierDatum { 8 | public List> getPredicates(); 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/IClassifier.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | import java.io.Serializable; 4 | 5 | 6 | /** Serializable classifier that is trained from labeled data and then maps a datum to a label. */ public interface IClassifier extends Serializable { 7 | 8 | public void train(Iterable> data, Object opts); 9 | 10 | public L classify(ClassifierDatum datum); 11 | 12 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/IProbabilisticClassifier.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | import edu.umass.nlp.ml.prob.IDistribution; 4 | 5 | /** Classifier that can also report a distribution over labels for a datum. */ public interface IProbabilisticClassifier extends 
IClassifier { 6 | 7 | public IDistribution getLabelDistribution(ClassifierDatum datum); 8 | 9 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/LabeledClassifierDatum.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | /** ClassifierDatum that also knows its true label. */ public interface LabeledClassifierDatum extends ClassifierDatum { 4 | public L getTrueLabel(); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/MaxEntropyClassifier.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.ml.Regularizers; 6 | import edu.umass.nlp.ml.prob.DirichletMultinomial; 7 | import edu.umass.nlp.ml.prob.IDistribution; 8 | import edu.umass.nlp.optimize.CachingDifferentiableFn; 9 | import edu.umass.nlp.optimize.IDifferentiableFn; 10 | import edu.umass.nlp.optimize.IOptimizer; 11 | import edu.umass.nlp.optimize.LBFGSMinimizer; 12 | import edu.umass.nlp.utils.*; 13 | 14 | import java.util.ArrayList; 15 | import java.util.List; 16 | 17 | /** Probabilistic classifier that keeps one weight per (predicate, label) pair and fits the weights with LBFGSMinimizer. */ public class MaxEntropyClassifier implements IProbabilisticClassifier { 18 | 19 | private double[] weights; 20 | private Indexer predIndexer; 21 | private Indexer labelIndexer; 22 | 23 | /** Scores datum against every indexed label and wraps the per-label probabilities with DirichletMultinomial.make. */ @Override 24 | public IDistribution getLabelDistribution(ClassifierDatum datum) { 25 | InnerDatum innerDatum = toInnerDatum(datum); 26 | double[] probs = getLabelProbs(innerDatum); 27 | ICounter probCounts = new MapCounter(); 28 | for (int labelIndex = 0; labelIndex < probs.length; labelIndex++) { 29 | double prob = probs[labelIndex]; 30 | probCounts.setCount(labelIndexer.get(labelIndex),prob); 31 | } 32 | return DirichletMultinomial.make(probCounts); 33 | } 34 | 35 | 
private static class PredValPair { 36 | public final int pred; 37 | public final double val; 38 | 39 | private PredValPair(int pred, double val) { 40 | assert pred >= 0; 41 | this.pred = pred; 42 | this.val = val; 43 | } 44 | } 45 | 46 | private static class InnerDatum { 47 | List pvs = new ArrayList(); 48 | int trueLabelIndex; 49 | } 50 | 51 | private InnerDatum toInnerDatum(ClassifierDatum datum) { 52 | InnerDatum res = new InnerDatum(); 53 | for (IValued valued : datum.getPredicates()) { 54 | String pred = valued.getElem(); 55 | int predIndex = predIndexer.indexOf(pred); 56 | if (predIndex >= 0) { 57 | res.pvs.add(new PredValPair(predIndex, valued.getValue())); 58 | } 59 | } 60 | return res; 61 | } 62 | 63 | private InnerDatum toInnerLabeledDatum(LabeledClassifierDatum datum) { 64 | InnerDatum res = toInnerDatum(datum); 65 | res.trueLabelIndex = labelIndexer.indexOf(datum.getTrueLabel()); 66 | return res; 67 | } 68 | 69 | private void indexPredicatesAndLabels(Iterable> data) { 70 | predIndexer = new Indexer(); 71 | labelIndexer = new Indexer(); 72 | for (LabeledClassifierDatum datum : data) { 73 | for (IValued valued : datum.getPredicates()) { 74 | predIndexer.add(valued.getElem()); 75 | } 76 | labelIndexer.add(datum.getTrueLabel()); 77 | } 78 | predIndexer.lock(); 79 | labelIndexer.lock(); 80 | } 81 | 82 | @Override 83 | public void train(Iterable> data, Object opts) { 84 | indexPredicatesAndLabels(data); 85 | Iterable innerData = 86 | Functional.map(data, new Fn, InnerDatum>() { 87 | @Override 88 | public InnerDatum apply(LabeledClassifierDatum input) { 89 | return toInnerLabeledDatum(input); 90 | }}); 91 | IDifferentiableFn objFn = new CachingDifferentiableFn(new ObjFn(innerData)); 92 | IOptimizer.Result optRes = (new LBFGSMinimizer()).minimize( 93 | objFn, 94 | new double[objFn.getDimension()], 95 | new LBFGSMinimizer.Opts()); 96 | weights = DoubleArrays.clone(optRes.minArg); 97 | } 98 | 99 | public double[] getLabelProbs(InnerDatum datum) { 100 | double[] 
logProbs = new double[labelIndexer.size()]; 101 | for (int l = 0; l < logProbs.length; l++) { 102 | double sum = 0.0; 103 | for (PredValPair pair : datum.pvs) { 104 | int f = getWeightIndex(pair.pred, l); 105 | sum += pair.val * weights[f]; 106 | } 107 | logProbs[l] = sum; 108 | } 109 | return SloppyMath.logScoresToProbs(logProbs); 110 | } 111 | 112 | private int getWeightIndex(int predIndex, int labelIndex) { 113 | return predIndex * labelIndexer.size() + labelIndex; 114 | } 115 | 116 | class ObjFn implements IDifferentiableFn { 117 | 118 | private Iterable data; 119 | 120 | ObjFn(Iterable data) { 121 | this.data = data; 122 | } 123 | 124 | @Override 125 | public IPair computeAt(double[] x) { 126 | weights = DoubleArrays.clone(x); 127 | double logObj = 0.0; 128 | double[] grad = new double[getDimension()]; 129 | 130 | for (InnerDatum datum : data) { 131 | double[] probs = getLabelProbs(datum); 132 | logObj += Math.log(probs[datum.trueLabelIndex]); 133 | for (int l = 0; l < probs.length; l++) { 134 | for (PredValPair pair : datum.pvs) { 135 | int f = getWeightIndex(pair.pred, l); 136 | grad[f] -= pair.val * probs[l]; 137 | if (l == datum.trueLabelIndex) { 138 | grad[f] += pair.val * 1.0; 139 | } 140 | } 141 | } 142 | } 143 | 144 | // Negate 145 | logObj *= -1; 146 | DoubleArrays.scaleInPlace(grad, -1); 147 | 148 | IPair regRes = Regularizers.getL2Regularizer(1.0).apply(x); 149 | logObj += regRes.getFirst(); 150 | DoubleArrays.addInPlace(grad,regRes.getSecond()); 151 | 152 | return BasicPair.make(logObj, grad); 153 | } 154 | 155 | @Override 156 | public int getDimension() { 157 | return predIndexer.size() * labelIndexer.size(); 158 | } 159 | } 160 | 161 | @Override 162 | public L classify(ClassifierDatum datum) { 163 | return getLabelDistribution(datum).getMode(); 164 | } 165 | 166 | } 167 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/classification/Reranker.java: 
-------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.classification; 2 | 3 | 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | import java.util.Set; 7 | 8 | import edu.umass.nlp.functional.DoubleFn; 9 | import edu.umass.nlp.functional.Fn; 10 | import edu.umass.nlp.functional.Functional; 11 | import edu.umass.nlp.ml.Regularizers; 12 | import edu.umass.nlp.optimize.CachingDifferentiableFn; 13 | import edu.umass.nlp.optimize.IDifferentiableFn; 14 | import edu.umass.nlp.optimize.IOptimizer; 15 | import edu.umass.nlp.optimize.LBFGSMinimizer; 16 | import edu.umass.nlp.parallel.ParallelUtils; 17 | import edu.umass.nlp.utils.*; 18 | import org.apache.commons.collections.primitives.ArrayDoubleList; 19 | import org.apache.commons.collections.primitives.ArrayIntList; 20 | import org.apache.commons.collections.primitives.DoubleList; 21 | import org.apache.commons.collections.primitives.IntList; 22 | import org.apache.log4j.Logger; 23 | 24 | public class Reranker { 25 | 26 | public static class Opts { 27 | public LBFGSMinimizer.Opts optimizerOpts = new LBFGSMinimizer.Opts(); 28 | public double sigmaSq = 1.0; 29 | } 30 | 31 | private Indexer featIndexer; 32 | private double[] weights; 33 | private transient Logger logger = Logger.getLogger("Reranker"); 34 | 35 | public double[] getWeights() { 36 | 37 | return weights; 38 | } 39 | 40 | public static interface Datum { 41 | public List> getFeatures(L label); 42 | public Set getAllowedLabels(); 43 | } 44 | 45 | private static class InternalDatum { 46 | IntList[] featIndices; 47 | DoubleList[] featValues; 48 | int trueLabel; 49 | 50 | private InternalDatum(IntList[] featIndices, DoubleList[] featValues, int trueLabel) { 51 | this.featIndices = featIndices; 52 | this.featValues = featValues; 53 | this.trueLabel = trueLabel; 54 | } 55 | } 56 | 57 | private InternalDatum toInternalDatum(LabeledDatum datum) { 58 | List labels = new ArrayList(datum.getAllowedLabels()); 
59 | IntList[] featIndices = new IntList[labels.size()]; 60 | DoubleList[] featVals = new DoubleList[labels.size()]; 61 | int trueLabelIndex = labels.indexOf(datum.getTrueLabel()); 62 | for (int i = 0; i < labels.size(); i++) { 63 | L label = labels.get(i); 64 | List> feats = datum.getFeatures(label); 65 | featIndices[i] = new ArrayIntList(); 66 | featVals[i] = new ArrayDoubleList(); 67 | for (int f=0; f < feats.size(); ++f) { 68 | int featIndex = featIndexer.indexOf(feats.get(f).getFirst()); 69 | if (featIndex >= 0) { 70 | featIndices[i].add(featIndex); 71 | featVals[i].add(feats.get(f).getSecond()); 72 | } 73 | } 74 | } 75 | return new InternalDatum(featIndices, featVals, trueLabelIndex); 76 | } 77 | 78 | public static interface LabeledDatum extends Datum { 79 | public L getTrueLabel(); 80 | } 81 | 82 | public ICounter getLabelProbs(Datum datum) { 83 | List labels = new ArrayList(datum.getAllowedLabels()); 84 | double[] logProbs = new double[labels.size()]; 85 | for (int i=0; i < logProbs.length; ++i) { 86 | L label = labels.get(i); 87 | double logProb = 0.0; 88 | for (IValued valued : datum.getFeatures(label)) { 89 | int featIndex = featIndexer.indexOf(valued.getElem()); if (featIndex >= 0) logProb += valued.getValue() * weights[featIndex]; // features never indexed by train() have no weight, so skip them as toInternalDatum does 90 | } 91 | logProbs[i] = logProb; 92 | } 93 | final double logSum = SloppyMath.logAdd(logProbs); 94 | ICounter res = new MapCounter(); 95 | for (int i=0; i < labels.size(); ++i) { 96 | L label = labels.get(i); 97 | res.setCount(label, Math.exp(logProbs[i]-logSum)); 98 | } 99 | return res; 100 | } 101 | 102 | private double[] getLabelProbs(InternalDatum datum) { 103 | double[] logProbs = new double[datum.featIndices.length]; 104 | for (int i=0; i < logProbs.length; ++i) { 105 | double logProb = 0.0; 106 | for (int f=0; f < datum.featIndices[i].size(); ++f) { 107 | logProb += datum.featValues[i].get(f) * weights[datum.featIndices[i].get(f)]; 108 | } 109 | logProbs[i] = logProb; 110 | } 111 | final double logSum = SloppyMath.logAdd(logProbs); 112 | double[] res = new 
double[logProbs.length]; 113 | for (int i=0; i < logProbs.length; ++i) { 114 | res[i] = Math.exp(logProbs[i]-logSum); 115 | } 116 | return res; 117 | } 118 | 119 | private class ObjFn implements IDifferentiableFn { 120 | Iterable data ; 121 | Opts opts; 122 | ObjFn(Iterable data,Opts opts) { 123 | this.data = data; 124 | this.opts = opts; 125 | } 126 | 127 | private IPair computeInternal(Iterable datums) { 128 | double obj = 0.0; 129 | double[] grad = new double[getDimension()]; 130 | for (InternalDatum datum: datums) { // only the slice handed to this call, not the enclosing ObjFn.data field 131 | double[] labelProbs = getLabelProbs(datum); 132 | //System.out.println(labelProbs); 133 | obj += Math.log(labelProbs[datum.trueLabel]); 134 | for (int l=0; l < datum.featIndices.length ;++l) { 135 | for (int f=0; f < datum.featIndices[l].size(); ++f) { 136 | int featIndex = datum.featIndices[l].get(f); 137 | double val = datum.featValues[l].get(f); 138 | grad[featIndex] -= val * labelProbs[l]; 139 | if (l == datum.trueLabel) { 140 | grad[featIndex] += val * 1.0; 141 | } 142 | } 143 | } 144 | } 145 | return BasicPair.make(obj, grad); 146 | } 147 | 148 | class Worker implements Runnable { 149 | Iterable data; 150 | IPair res ; 151 | 152 | Worker(Iterable data) { 153 | this.data = data; 154 | } 155 | 156 | public void run() { 157 | res = computeInternal(data); 158 | } 159 | } 160 | 161 | public IPair computeAt(double[] x) { 162 | double logObj = 0.0; 163 | double[] grad = new double[getDimension()]; 164 | weights = DoubleArrays.clone(x); 165 | 166 | if (data instanceof List) { 167 | List> parts = 168 | Collections.partition((List) data, Runtime.getRuntime().availableProcessors()); 169 | List workers = Functional.map(parts, new Fn, Worker>() { 170 | public Worker apply(Iterable input) { 171 | return new Worker(input); 172 | }}); 173 | ParallelUtils.doParallelWork(workers, workers.size()); 174 | for (Worker worker : workers) { 175 | logObj += worker.res.getFirst(); 176 | DoubleArrays.addInPlace(grad, worker.res.getSecond()); 177 | } 178 | } else { 
179 | Worker singleton = new Worker(data); 180 | singleton.run(); 181 | logObj += singleton.res.getFirst(); 182 | grad = singleton.res.getSecond(); 183 | } 184 | 185 | logObj *= -1; 186 | DoubleArrays.scaleInPlace(grad, -1); 187 | 188 | // Regularizer 189 | IPair regRes = (Regularizers.getL2Regularizer(opts.sigmaSq)).apply(x); 190 | logObj += regRes.getFirst(); 191 | 192 | DoubleArrays.addInPlace(grad, regRes.getSecond()); 193 | return BasicPair.make(logObj, grad); 194 | } 195 | 196 | public int getDimension() { 197 | return featIndexer.size(); 198 | } 199 | } 200 | 201 | public void train(final Iterable> data, final Opts opts) { 202 | featIndexer = new Indexer(); 203 | logger.trace("Start Indexing Features"); 204 | for (LabeledDatum datum : data) { 205 | for (L label : datum.getAllowedLabels()) { 206 | for (IValued valued : datum.getFeatures(label)) { 207 | featIndexer.add(valued.getElem()); 208 | } 209 | } 210 | } 211 | featIndexer.lock(); 212 | logger.trace("Done Indexing Features"); 213 | logger.info("Number of Features: " + featIndexer.size()); 214 | logger.trace("Start Caching Data to Internal Representation"); 215 | final Iterable internalData = Functional.map(data, new Fn, InternalDatum>() { 216 | public InternalDatum apply(LabeledDatum input) { 217 | return toInternalDatum(input); 218 | }}); 219 | logger.trace("Done Caching Data"); 220 | IDifferentiableFn objFn = new CachingDifferentiableFn(new ObjFn(internalData,opts)); 221 | logger.trace("Starting Optimization"); 222 | IOptimizer.Result res = (new LBFGSMinimizer()).minimize(objFn, new double[objFn.getDimension()],opts.optimizerOpts); 223 | this.weights = DoubleArrays.clone(res.minArg); 224 | logger.trace("Done with optimization"); 225 | } 226 | 227 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/feats/IPredExtractor.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.feats; 
2 | 3 | 4 | import edu.umass.nlp.utils.IValued; 5 | 6 | import java.util.List; 7 | 8 | public interface IPredExtractor { 9 | 10 | public List> getPredicates(T elem); 11 | 12 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/feats/Predicate.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.feats; 2 | 3 | 4 | 5 | import edu.umass.nlp.utils.IIndexed; 6 | 7 | import java.io.Serializable; 8 | 9 | public class Predicate implements IIndexed, Serializable { 10 | 11 | private final String pred; 12 | private final int index; 13 | 14 | public Predicate(String pred, int index) { 15 | this.pred = pred; 16 | this.index = index; 17 | } 18 | 19 | public Predicate(String pred) { 20 | this(pred,-1); 21 | } 22 | 23 | public boolean isIndexed() { 24 | return index >= 0; 25 | } 26 | 27 | public int getIndex() { 28 | return index; 29 | } 30 | 31 | public Predicate getElem() { 32 | return this; 33 | } 34 | 35 | public Predicate withIndex(int index) { 36 | return new Predicate(pred, index); 37 | } 38 | 39 | @Override 40 | public String toString() { 41 | return "Pred(" + pred + ')'; 42 | } 43 | 44 | @Override 45 | public boolean equals(Object o) { 46 | if (this == o) return true; 47 | if (o == null || getClass() != o.getClass()) return false; 48 | 49 | Predicate predicate = (Predicate) o; 50 | 51 | if (pred != null ? !pred.equals(predicate.pred) : predicate.pred != null) return false; 52 | 53 | return true; 54 | } 55 | 56 | @Override 57 | public int hashCode() { 58 | return pred != null ? 
pred.hashCode() : 0; 59 | } 60 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/feats/PredicateManager.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.feats; 2 | 3 | 4 | import edu.umass.nlp.functional.Fn; 5 | import edu.umass.nlp.functional.Functional; 6 | import edu.umass.nlp.utils.BasicValued; 7 | import edu.umass.nlp.utils.IValued; 8 | import edu.umass.nlp.utils.Indexer; 9 | 10 | import java.io.Serializable; 11 | import java.util.List; 12 | 13 | public class PredicateManager implements Serializable { 14 | 15 | private IPredExtractor predFn; 16 | private Indexer predIndexer = new Indexer(); 17 | private boolean isLocked = false; 18 | 19 | public PredicateManager(IPredExtractor predFn) { 20 | this.predFn = predFn; 21 | } 22 | 23 | public void lock() { 24 | if (!isLocked) { 25 | predIndexer.lock(); 26 | isLocked = true; 27 | } 28 | } 29 | 30 | public boolean isLocked() { 31 | return isLocked; 32 | } 33 | 34 | public List> getPredicates(T elem) { 35 | return Functional.map(predFn.getPredicates(elem), new Fn,IValued>() { 36 | public IValued apply(IValued input) { 37 | return BasicValued.make(getIndexedPredicate(input.getElem()),input.getValue()); 38 | }}); 39 | } 40 | 41 | private Predicate getIndexedPredicate(Predicate pred) { 42 | if (pred.isIndexed()) return pred; 43 | int index = predIndexer.indexOf(pred); 44 | if (index < 0) { 45 | if (isLocked) return null; 46 | pred = pred.withIndex(predIndexer.size()); 47 | predIndexer.add(pred); 48 | return pred; 49 | } 50 | return predIndexer.get(index); 51 | } 52 | 53 | public Indexer getPredIndexer() { 54 | return predIndexer; 55 | } 56 | 57 | public void indexAll(Iterable elems) { 58 | assert !isLocked(); 59 | for (T elem : elems) { 60 | for (IValued pv : predFn.getPredicates(elem)) { 61 | // side-effect: indexs pred 62 | getIndexedPredicate(pv.getElem()); 63 | } 64 | } 65 | } 66 | } 
-------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/feats/WeightsManager.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.feats; 2 | 3 | import edu.umass.nlp.utils.CounterMap; 4 | import edu.umass.nlp.utils.ICounter; 5 | import edu.umass.nlp.utils.Indexer; 6 | import edu.umass.nlp.utils.Span; 7 | 8 | import java.io.Serializable; 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | /** 13 | * Strange abstraction, but basically, this 14 | * combines the fact that for each predicate their 15 | * is a different weight for each possible label 16 | * (a predicate plus a label yield a feature). 17 | * 18 | * @param 19 | */ 20 | public class WeightsManager implements Serializable { 21 | 22 | public final Indexer

predIndexer; 23 | private final int numPreds; 24 | private final Indexer labelIndexer; 25 | private final int numFeats; 26 | private double[] weights = null; 27 | 28 | public WeightsManager(Indexer

predIndexer, Indexer labelIndexer) { 29 | predIndexer.lock(); 30 | labelIndexer.lock(); 31 | this.predIndexer = predIndexer; 32 | this.labelIndexer = labelIndexer; 33 | this.numPreds = predIndexer.size(); 34 | this.numFeats = numPreds * labelIndexer.size(); 35 | } 36 | 37 | public void setWeights(double[] weights) { 38 | this.weights = weights; 39 | } 40 | 41 | public void addScores(P pred, double val, double[] scores) { 42 | //assert pred.isIndexed(); 43 | int predIndex = predIndexer.indexOf(pred); 44 | assert predIndex >= 0; 45 | for (int l=0; l < getNumLabels(); ++l) { 46 | int weightIndex = getWeightIndex(predIndex, l); 47 | scores[l] += val * weights[weightIndex]; 48 | } 49 | } 50 | 51 | public void addFeatExpecations(P pred, double[] labelWeights, double[] accumExpectations) { 52 | //assert pred.isIndexed(); 53 | int predIndex = predIndexer.indexOf(pred); 54 | assert predIndex >= 0; 55 | assert labelWeights.length == getNumLabels(); 56 | assert accumExpectations.length == getNumFeats(); 57 | for (int l=0; l < getNumLabels(); ++l) { 58 | int weightIndex = getWeightIndex(predIndex, l); 59 | accumExpectations[weightIndex] += labelWeights[l]; 60 | } 61 | } 62 | 63 | public Span getIndexSpan(Predicate p) { 64 | int start = p.getIndex() * getNumLabels(); 65 | return new Span(start, start + getNumLabels()); 66 | } 67 | 68 | public int getNumFeats() { 69 | return numFeats; 70 | } 71 | 72 | public int getWeightIndex(int predIndex, int labelIndex) { 73 | return predIndex * getNumLabels() + labelIndex; 74 | } 75 | 76 | // public Map> getWeightsByLabel() { 77 | // Map> res = CounterMap.make(); 78 | // for (int l=0; l < getNumLabels(); ++l) { 79 | // L label = labelIndexer.get(l); 80 | // for (P pred : predIndexer) { 81 | // int w = getWeightIndex(pred, l); 82 | // CounterMap.setCount(res,label,pred,weights[w]); 83 | // } 84 | // } 85 | // return res; 86 | // } 87 | 88 | public int getNumLabels() { 89 | return labelIndexer.size(); // same label count the constructor uses to size numFeats; the indexer is locked there 90 | } 91 | 92 | // public void inspect() { 93 
| // Logger.startTrack("Weight Inspect"); 94 | // CounterMap cm = getWeightsByLabel(); 95 | // for (L l : labelIndexer) { 96 | // Logger.startTrack("Label: " + l); 97 | // Counter labelWeights = cm.getCounter(l); 98 | // List sortedKeys = Counters.absCounts(labelWeights).getSortedKeys().subList(0,20); 99 | // labelWeights.pruneExcept(new HashSet(sortedKeys)); 100 | // Logger.logs(labelWeights.toString()); 101 | // Logger.endTrack(); 102 | // } 103 | // Logger.endTrack(); 104 | // } 105 | 106 | public int getLabelIndex(L l) { 107 | return labelIndexer.indexOf(l); 108 | } 109 | 110 | public List getLabelIndexer() { 111 | return labelIndexer; 112 | } 113 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/AbstractDistribution.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | import edu.umass.nlp.functional.DoubleFn; 4 | import edu.umass.nlp.functional.Fn; 5 | import edu.umass.nlp.functional.Functional; 6 | import edu.umass.nlp.utils.BasicValued; 7 | import edu.umass.nlp.utils.Collections; 8 | import edu.umass.nlp.utils.IMergable; 9 | import edu.umass.nlp.utils.IValued; 10 | import org.apache.log4j.Logger; 11 | 12 | import java.io.Serializable; 13 | import java.util.HashSet; 14 | import java.util.List; 15 | import java.util.Random; 16 | import java.util.Set; 17 | 18 | 19 | public abstract class AbstractDistribution implements IDistribution, Serializable { 20 | 21 | public static final long serialVersionUID = 42L; 22 | 23 | protected boolean locked = false; 24 | protected Logger logger = Logger.getLogger("AbstractDistribution"); 25 | 26 | // 27 | // IDistribution 28 | // 29 | 30 | public abstract double getProb(T elem); 31 | 32 | public double getLogProb(T elem) { 33 | return Math.log(getProb(elem)); 34 | } 35 | 36 | public T getMode() { 37 | assert !getSupport().isEmpty(); 38 | return Functional.findMax(this, new 
DoubleFn>() { 39 | public double valAt(IValued input) { 40 | return input.getValue(); 41 | }}).getElem().getElem(); 42 | } 43 | 44 | public T getSample(Random r) { 45 | double target = r.nextDouble(); 46 | double sofar = 0.0; 47 | for (IValued valued : this) { 48 | double p = valued.getValue(); 49 | if (target > sofar && target < (sofar+p)) { 50 | return valued.getElem(); 51 | } 52 | sofar += p; 53 | } 54 | throw new RuntimeException("error: Couldn't get sample. Only saw mass " + sofar); 55 | } 56 | 57 | public Set getSupport() { 58 | Set supp = new HashSet(); 59 | for (IValued valued : this) { 60 | supp.add(valued.getElem()); 61 | } 62 | return supp; 63 | } 64 | 65 | 66 | // 67 | // ILockable 68 | // 69 | 70 | 71 | public boolean isLocked() { 72 | return locked; 73 | } 74 | 75 | public void lock() { 76 | if (locked) { 77 | logger.warn("Trying to lock an already locked distribution"); 78 | } 79 | else locked = true; 80 | } 81 | 82 | public String toString() { 83 | return toString(20); 84 | } 85 | 86 | public String toString(int numEntries) { 87 | List> entries = Collections.toList(this); 88 | return Functional.mkString(entries.subList(0, Math.min(entries.size(), numEntries)), 89 | "[", ",", "]", 90 | new Fn, String>() { 91 | public String apply(IValued input) { 92 | return String.format("%s : %.4f", input.getFirst(), input.getSecond()); 93 | } 94 | }).toString(); 95 | } 96 | 97 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/BasicConditionalDistribution.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.utils.BasicPair; 6 | import edu.umass.nlp.utils.Collections; 7 | import edu.umass.nlp.utils.IPair; 8 | import org.apache.log4j.Logger; 9 | 10 | import java.util.HashMap; 11 | import java.util.Iterator; 12 | 
import java.util.Map; 13 | 14 | public class BasicConditionalDistribution implements IConditionalDistribution { 15 | 16 | private final Logger logger = Logger.getLogger("BasicCondDistr"); 17 | private final Map> distrs ; 18 | private final Fn> distrFact ; 19 | private boolean locked = false; 20 | 21 | public BasicConditionalDistribution(Fn> distrFact) { 22 | this.distrs = new HashMap>(); 23 | this.distrFact = distrFact; 24 | } 25 | 26 | // 27 | // IConditional Distribution 28 | // 29 | 30 | public IDistribution getDistribution(C cond) { 31 | assert isLocked(); 32 | return Collections.getMut(distrs, cond, Functional.curry(distrFact, cond)); 33 | } 34 | 35 | public IDistribution apply(C input) { 36 | return getDistribution(input); 37 | } 38 | 39 | public void observe(C cond, O obs, double weight) { 40 | assert !isLocked(); 41 | Collections.getMut(distrs, cond, Functional.curry(distrFact, cond)).observe(obs, weight); 42 | } 43 | 44 | public Iterator>> iterator() { 45 | return Functional.map(distrs.entrySet().iterator(),new Fn>, IPair>>() { 46 | public IPair> apply(Map.Entry> input) { 47 | return BasicPair.make(input.getKey(), input.getValue()); 48 | }}); 49 | } 50 | 51 | // 52 | // ILockable 53 | // 54 | 55 | public boolean isLocked() { 56 | return locked; 57 | } 58 | 59 | public void lock() { 60 | if (locked) { 61 | logger.warn("lock() called on locked object"); 62 | } 63 | locked = true; 64 | } 65 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/DirichletMultinomial.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.utils.*; 6 | 7 | import java.util.Iterator; 8 | 9 | public class DirichletMultinomial extends AbstractDistribution 10 | implements IMergable> { 11 | 12 | ICounter counts_ = new MapCounter(); 13 | 
double lambda_ = 0.0; 14 | // int numKeys = -1; 15 | 16 | // 17 | // IDistribution 18 | // 19 | 20 | public double getProb(T elem) { 21 | double numer = counts_.getCount(elem) + lambda_; 22 | double denom = counts_.totalCount() + counts_.size() * lambda_; 23 | assert denom > 0.0 : 24 | String.format("Bad Denom: %.3f %s %d", denom, counts_, counts_.size()); 25 | double prob = numer / denom; 26 | assert prob > 0.0 && prob <= 1.0 : String.format("Bad prob: %.5f for key: %s", prob, elem); 27 | return prob; 28 | } 29 | 30 | public Iterator> iterator() { 31 | return Functional.map(counts_, new Fn, IValued>() { 32 | public IValued apply(IValued input) { 33 | return input.withValue(getProb(input.getElem())); 34 | }}).iterator(); 35 | } 36 | 37 | public void observe(T elem, double count) { 38 | if (count != 0.0) counts_.incCount(elem, count); 39 | } 40 | 41 | // 42 | // IMergeable 43 | // 44 | 45 | public void merge(DirichletMultinomial other) { 46 | Counters.incAll(this.counts_, other.counts_); 47 | } 48 | 49 | // 50 | // DirichletMultinomial 51 | // 52 | 53 | public double getLambda() { 54 | return lambda_; 55 | } 56 | 57 | public void setLambda(double lambda) { 58 | assert lambda > 0.0 : "Bad Lambda: " + lambda; 59 | this.lambda_ = lambda; 60 | } 61 | 62 | 63 | // 64 | // Object 65 | // 66 | 67 | public String toString() { 68 | return Counters.from(this).toString(); 69 | } 70 | 71 | 72 | // 73 | // Factory Methods 74 | // 75 | 76 | public static DirichletMultinomial make(double lambda) { 77 | DirichletMultinomial d = new DirichletMultinomial(); 78 | d.setLambda(lambda); 79 | return d; 80 | } 81 | 82 | public static DirichletMultinomial make(ICounter elems) { 83 | DirichletMultinomial d = new DirichletMultinomial(); 84 | for (IValued entry : elems) { 85 | d.observe(entry.getElem(), entry.getValue()); 86 | } 87 | return d; 88 | } 89 | 90 | 91 | // 92 | // Main 93 | // 94 | 95 | public static void main(String[] args) { 96 | DirichletMultinomial d = 
DirichletMultinomial.make(1.0); 97 | d.observe("a",2.0); 98 | d.observe("c",1.0); 99 | System.out.println(d); 100 | } 101 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/Distributions.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | 4 | public class Distributions { 5 | 6 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/IConditionalDistribution.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.utils.ILockable; 5 | import edu.umass.nlp.utils.IPair; 6 | 7 | 8 | public interface IConditionalDistribution extends Iterable>>, 9 | Fn>, 10 | ILockable { 11 | public IDistribution getDistribution(C cond); 12 | public void observe(C cond, O obs, double weight); 13 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/IDistribution.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | import edu.umass.nlp.utils.ILockable; 4 | import edu.umass.nlp.utils.IValued; 5 | 6 | import java.io.Serializable; 7 | import java.util.Collection; 8 | import java.util.Random; 9 | import java.util.Set; 10 | 11 | /** 12 | * Distribution abstraction. Most will be fine just sub-classing 13 | * AbstractDistribution, which has default implementations 14 | * of many of these functions. 
15 | */ 16 | public interface IDistribution extends Iterable>, 17 | ILockable, 18 | Serializable { 19 | public Set getSupport(); 20 | public double getProb(T elem); 21 | public double getLogProb(T elem); 22 | public T getMode(); 23 | public T getSample(Random r); 24 | public void observe(T elem, double weight); 25 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/prob/ISuffStats.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.prob; 2 | 3 | import edu.umass.nlp.utils.IMergable; 4 | 5 | import java.io.Serializable; 6 | 7 | public interface ISuffStats extends IMergable>, Serializable { 8 | public void observe(T elem, double weight); 9 | public IDistribution toDistribution(); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/regression/LinearRegressionModel.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.regression; 2 | 3 | import edu.umass.nlp.exec.Execution; 4 | import edu.umass.nlp.functional.Fn; 5 | import edu.umass.nlp.functional.Functional; 6 | import edu.umass.nlp.ml.Regularizers; 7 | import edu.umass.nlp.ml.feats.IPredExtractor; 8 | import edu.umass.nlp.ml.feats.Predicate; 9 | import edu.umass.nlp.ml.feats.PredicateManager; 10 | import edu.umass.nlp.optimize.GradientDescent; 11 | import edu.umass.nlp.optimize.IDifferentiableFn; 12 | import edu.umass.nlp.optimize.IOptimizer; 13 | import edu.umass.nlp.optimize.LBFGSMinimizer; 14 | import edu.umass.nlp.utils.*; 15 | import org.apache.commons.collections.primitives.ArrayDoubleList; 16 | import org.apache.commons.collections.primitives.ArrayIntList; 17 | import org.apache.commons.collections.primitives.DoubleList; 18 | import org.apache.commons.collections.primitives.IntList; 19 | 20 | import java.util.List; 21 | 22 | public class 
LinearRegressionModel { 23 | 24 | private Indexer featIndexer; 25 | private double[] weights; 26 | 27 | public static interface Datum { 28 | public List> getFeatures(); 29 | public double getWeight(); 30 | } 31 | 32 | public static class BasicLabeledDatum implements LabeledDatum { 33 | private final List> feats; 34 | private final double weight; 35 | private final double target; 36 | 37 | public BasicLabeledDatum(List> feats, double target, double weight) { 38 | this.feats = feats; 39 | this.target = target; 40 | this.weight = weight; 41 | } 42 | 43 | public double getTarget() { 44 | return target; 45 | } 46 | 47 | public List> getFeatures() { 48 | return feats; 49 | } 50 | 51 | public double getWeight() { 52 | return weight; 53 | } 54 | } 55 | 56 | public static interface LabeledDatum extends Datum { 57 | public double getTarget(); 58 | } 59 | 60 | private class InternalDatum { 61 | public IntList featIndices; 62 | public DoubleList featVals; 63 | public double target; 64 | public double weight; 65 | 66 | private InternalDatum(LabeledDatum datum) { 67 | featIndices = new ArrayIntList(); 68 | featVals = new ArrayDoubleList(); 69 | target = datum.getTarget(); 70 | weight = datum.getWeight(); 71 | for (IValued valued : datum.getFeatures()) { 72 | int featIndex = featIndexer.getIndex(valued.getElem()); 73 | if (featIndex < 0) continue; 74 | featIndices.add(featIndex); 75 | featVals.add(valued.getSecond()); 76 | } 77 | } 78 | } 79 | 80 | public static class Opts { 81 | public IOptimizer.Opts optimizerOpts = new LBFGSMinimizer.Opts(); 82 | public Fn> regularizer; 83 | } 84 | 85 | private double getPredictionInternal(List> pvs) { 86 | double sum = 0.0; 87 | for (IValued pv : pvs) { 88 | int featIndex = featIndexer.getIndex(pv.getElem()); 89 | if (featIndex < 0) continue; 90 | sum += weights[featIndex] * pv.getValue(); 91 | } 92 | return sum; 93 | } 94 | 95 | private double getPredictionInternal(InternalDatum datum) { 96 | double sum = 0.0; 97 | for (int i=0; i < 
datum.featIndices.size(); ++i) { 98 | int featIndex = datum.featIndices.get(i); 99 | double val = datum.featVals.get(i); 100 | if (featIndex < 0) continue; 101 | sum += weights[featIndex] * val; 102 | } 103 | return sum; 104 | } 105 | 106 | public void train(final Iterable data, final Opts opts) { 107 | featIndexer = new Indexer(); 108 | for (LabeledDatum datum : data) { 109 | for (IValued valued : datum.getFeatures()) { 110 | featIndexer.add(valued.getElem()); 111 | } 112 | } 113 | featIndexer.lock(); 114 | 115 | final List internalData = Functional.map(data, new Fn() { 116 | public InternalDatum apply(LabeledDatum input) { 117 | return new InternalDatum(input); 118 | } 119 | }); 120 | 121 | IDifferentiableFn objFn = new IDifferentiableFn() { 122 | public IPair computeAt(double[] x) { 123 | weights = DoubleArrays.clone(x); 124 | double obj = 0.0; 125 | double[] grad = new double[getDimension()]; 126 | 127 | for (InternalDatum datum: internalData) { 128 | double trueY = datum.target; 129 | double predictY = getPredictionInternal(datum); 130 | double diffY = (trueY - predictY); 131 | assert !Double.isNaN(diffY); 132 | obj += datum.weight * 0.5 * diffY * diffY; 133 | for (int i=0; i < datum.featIndices.size(); ++i) { 134 | int featIndex = datum.featIndices.get(i); 135 | if (featIndex < 0) continue; 136 | double featVal = datum.featVals.get(i); 137 | grad[featIndex] += datum.weight * diffY * featVal; 138 | } 139 | } 140 | DoubleArrays.scaleInPlace(grad, -1); 141 | 142 | if (opts.regularizer != null) { 143 | IPair res = opts.regularizer.apply(x); 144 | obj += res.getFirst(); 145 | DoubleArrays.addInPlace(grad, res.getSecond()); 146 | } 147 | 148 | 149 | return BasicPair.make(obj, grad); 150 | } 151 | 152 | public int getDimension() { 153 | return featIndexer.size(); 154 | } 155 | }; 156 | IOptimizer.Result res = 157 | (new LBFGSMinimizer()).minimize(objFn, new double[objFn.getDimension()], opts.optimizerOpts); 158 | weights = res.minArg; 159 | } 160 | 161 | public 
double getPrediction(Datum datum) { 162 | return getPredictionInternal(datum.getFeatures()); 163 | } 164 | 165 | public static void main(String[] args) { 166 | Execution.init(null); 167 | List doc = Collections.makeList("fuzzy", "wuzzy"); 168 | class MyDatum implements LabeledDatum { 169 | private Iterable elems; 170 | MyDatum(String...elems) { 171 | this.elems = Collections.toList(elems); 172 | } 173 | public double getTarget() { 174 | return 1.0; 175 | } 176 | public double getWeight() { 177 | return 1.0; 178 | } 179 | 180 | public List> getFeatures() { 181 | return Functional.map(elems, new Fn>() { 182 | public IValued apply(String input) { 183 | return new BasicValued(input, 1.0); 184 | }}); 185 | } 186 | } 187 | 188 | // IPredExtractor> predFn = new IPredExtractor>() { 189 | // public List> getPredicates(List elem) { 190 | // return Functional.map(elem, new Fn>() { 191 | // public IValued apply(String input) { 192 | // return BasicValued.make(new Predicate(input), 1.0); 193 | // } 194 | // }); 195 | // } 196 | // }; 197 | (new LinearRegressionModel()).train(Collections.makeList(new MyDatum("fuzzy","wuzzy")), new Opts()); 198 | } 199 | 200 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/BasicLabelSeqDatum.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | import java.util.List; 4 | 5 | public class BasicLabelSeqDatum implements ILabeledSeqDatum { 6 | 7 | private final List> nodePreds; 8 | private final List labels; 9 | private final double weight; 10 | 11 | public BasicLabelSeqDatum(List> nodePreds, List labels, double weight) { 12 | this.nodePreds = nodePreds; 13 | this.labels = labels; 14 | this.weight = weight; 15 | } 16 | 17 | public List getLabels() { 18 | return labels; 19 | } 20 | 21 | public List> getNodePredicates() { 22 | return nodePreds; 23 | } 24 | 25 | public boolean isLabeled() { 26 | 
return labels != null; 27 | } 28 | 29 | public double getWeight() { 30 | return weight; 31 | } 32 | 33 | 34 | 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/ForwardBackwards.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | import edu.umass.nlp.utils.LogAdder; 4 | import edu.umass.nlp.utils.Maxer; 5 | 6 | import java.util.ArrayList; 7 | import java.util.Arrays; 8 | import java.util.List; 9 | 10 | public class ForwardBackwards { 11 | 12 | // Result Class 13 | public static class Result { 14 | // Log sum over all paths 15 | public double logZ; 16 | // For each seq pos, marginals 17 | // over states 18 | public double[][] stateMarginals; 19 | // For each transition (seq-len - 1) 20 | // marginal over transitions 21 | public double[][] transMarginals; 22 | } 23 | 24 | 25 | // Final 26 | private final StateSpace stateSpace; 27 | private final int numStates, numTrans; 28 | 29 | // Mutable 30 | private double[][] potentials; 31 | private int seqLen; 32 | private boolean doneAlphas, doneBetas, doneNodeMarginals, doneEdgeMarginals; 33 | private double[][] alphas, betas, nodeMarginals, edgeMarginals; 34 | 35 | public ForwardBackwards(StateSpace stateSpace) { 36 | this.stateSpace = stateSpace; 37 | this.numStates = stateSpace.getStates().size(); 38 | this.numTrans = stateSpace.getTransitions().size(); 39 | } 40 | 41 | public Result compute(double[][] potentials) { 42 | this.potentials = potentials; 43 | this.seqLen = potentials.length+1; 44 | doneAlphas = false; 45 | doneBetas = false; 46 | doneNodeMarginals = false; 47 | doneEdgeMarginals = false; 48 | Result res = new Result(); 49 | res.logZ = getLogZ(); 50 | res.stateMarginals = getNodeMarginals(); 51 | res.transMarginals = getEdgeMarginals(); 52 | return res; 53 | } 54 | 55 | private void computeAlphas() { 56 | alphas = new 
double[seqLen][numStates]; 57 | for (double[] row: alphas) Arrays.fill(row, Double.NEGATIVE_INFINITY); 58 | // Start 59 | alphas[0][stateSpace.startState.index] = 0.0; 60 | // Subseqent 61 | for (int i=1; i < seqLen; ++i) { 62 | for (int s=0; s < numStates; ++s) { 63 | LogAdder logAdder = new LogAdder(); 64 | List inTrans = stateSpace.getTransitionsTo(s); 65 | for (Transition trans: inTrans) { 66 | logAdder.add( 67 | alphas[i-1][trans.from.index] + 68 | potentials[i-1][trans.index] 69 | ); 70 | } 71 | alphas[i][s] = logAdder.logSum(); 72 | } 73 | } 74 | } 75 | 76 | private void ensureAlphas() { 77 | if (!doneAlphas) { 78 | computeAlphas(); 79 | doneAlphas = true; 80 | } 81 | } 82 | 83 | public double getLogZ() { 84 | ensureAlphas(); 85 | return alphas[seqLen-1][stateSpace.stopState.index]; 86 | } 87 | 88 | public double[][] getAlphas() { 89 | ensureAlphas(); 90 | return alphas; 91 | } 92 | 93 | private void ensureBetas() { 94 | ensureAlphas(); 95 | if (!doneBetas) { 96 | computeBetas(); 97 | doneBetas = true; 98 | } 99 | } 100 | 101 | private void computeBetas() { 102 | betas = new double[seqLen][numStates]; 103 | for (double[] row: betas) Arrays.fill(row, Double.NEGATIVE_INFINITY); 104 | // Start 105 | betas[seqLen-1][stateSpace.stopState.index] = 0.0; 106 | // Subseqent 107 | for (int i=seqLen-2; i >= 0; --i) { 108 | for (int s=0; s < numStates; ++s) { 109 | LogAdder logAdder = new LogAdder(); 110 | List outTrans = stateSpace.getTransitionsFrom(s); 111 | for (Transition trans: outTrans) { 112 | logAdder.add( 113 | betas[i+1][trans.to.index] + 114 | potentials[i][trans.index] 115 | ); 116 | } 117 | betas[i][s] = logAdder.logSum(); 118 | } 119 | } 120 | } 121 | 122 | public double[][] getBetas() { 123 | ensureBetas(); 124 | return betas; 125 | } 126 | 127 | public void setInput(double[][] potentials) { 128 | this.potentials = potentials; 129 | this.seqLen = potentials.length+1; 130 | doneAlphas = false; 131 | doneBetas = false; 132 | doneNodeMarginals = false; 133 
| doneEdgeMarginals = false; 134 | } 135 | 136 | public double[][] getEdgeMarginals() { 137 | ensureEdgeMarginals(); 138 | return edgeMarginals; 139 | } 140 | 141 | private void ensureEdgeMarginals() { 142 | ensureAlphas(); 143 | ensureBetas(); 144 | if (!doneEdgeMarginals) { 145 | computeEdgeMarginals(); 146 | doneEdgeMarginals = true; 147 | } 148 | } 149 | 150 | private void computeEdgeMarginals() { 151 | edgeMarginals = new double[seqLen-1][numTrans]; 152 | double logZ = getLogZ(); 153 | for (int i=0; i < seqLen-1; ++i) { 154 | for (int s=0; s < numStates; ++s) { 155 | if (alphas[i][s] == Double.NEGATIVE_INFINITY) continue; 156 | for (Transition trans: stateSpace.getTransitionsFrom(s)) { 157 | double numer = alphas[i][s] + 158 | potentials[i][trans.index] + betas[i+1][trans.to.index]; 159 | edgeMarginals[i][trans.index] = Math.exp(numer-logZ); 160 | } 161 | } 162 | } 163 | } 164 | 165 | public double[][] getNodeMarginals() { 166 | ensureNodeMarginals(); 167 | return nodeMarginals; 168 | } 169 | 170 | private void ensureNodeMarginals() { 171 | ensureEdgeMarginals(); 172 | if (!doneNodeMarginals) { 173 | computeNodeMarginals(); 174 | doneNodeMarginals = true; 175 | } 176 | } 177 | 178 | private void computeNodeMarginals() { 179 | ensureEdgeMarginals(); 180 | nodeMarginals = new double[seqLen][numStates]; 181 | // Fist: Must Have All Mass on Start State 182 | nodeMarginals[0][stateSpace.startState.index] = 1.0; 183 | // Middle States 184 | for (int i=1; i < seqLen-1; ++i) { 185 | for (int s=0; s < numStates; ++s) { 186 | double nodeSum = 0.0; 187 | for (Transition trans : stateSpace.getTransitionsFrom(s)) { 188 | nodeSum += edgeMarginals[i][trans.index]; 189 | } 190 | nodeMarginals[i][s] = nodeSum; 191 | } 192 | } 193 | // Last: Must Have All Mass on Stop State 194 | nodeMarginals[seqLen-1][stateSpace.stopState.index] = 1.0; 195 | } 196 | 197 | public List viterbiDecode() { 198 | double[][] viterbiAlphas = new double[seqLen][numStates]; 199 | for (double[] row: 
viterbiAlphas) Arrays.fill(row, Double.NEGATIVE_INFINITY); 200 | // Start 201 | viterbiAlphas[0][stateSpace.startState.index] = 0.0; 202 | // Subseqent 203 | for (int i=1; i < seqLen; ++i) { 204 | for (int s=0; s < numStates; ++s) { 205 | List inTrans = stateSpace.getTransitionsTo(s); 206 | double max = Double.NEGATIVE_INFINITY; 207 | for (Transition trans: inTrans) { 208 | double val = 209 | viterbiAlphas[i-1][trans.from.index] + 210 | potentials[i-1][trans.index]; 211 | if (val > max) { 212 | max = val; 213 | } 214 | } 215 | viterbiAlphas[i][s] = max; 216 | } 217 | } 218 | List res = new ArrayList(); 219 | res.add(stateSpace.stopState.label); 220 | double trgVal = viterbiAlphas[seqLen-1][stateSpace.stopState.index]; 221 | int trgState = stateSpace.stopState.index; 222 | for (int pos=seqLen-2; pos >= 0; --pos) { 223 | boolean found = false; 224 | for (Transition trans : stateSpace.getTransitionsTo(trgState)) { 225 | double guess = potentials[pos][trans.index] + viterbiAlphas[pos][trans.from.index]; 226 | if (Math.abs(guess-trgVal) < 1.0e-8) { 227 | trgVal = viterbiAlphas[pos][trans.from.index]; 228 | trgState = trans.from.index; 229 | res.add(stateSpace.getStates().get(trgState).label); 230 | found = true; 231 | break; 232 | } 233 | } 234 | if (!found) throw new RuntimeException("Bad"); 235 | } 236 | java.util.Collections.reverse(res); 237 | return res; 238 | } 239 | 240 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/ILabeledSeqDatum.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | import java.util.List; 4 | 5 | 6 | public interface ILabeledSeqDatum extends ISeqDatum { 7 | public List getLabels(); 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/ISeqDatum.java: 
/**
 * A (possibly unlabeled) sequence instance for sequence models.
 */
public interface ISeqDatum {
  /**
   * @return one list of predicates per sequence position.
   *         NOTE(review): the element type was lost in this copy of the
   *         file; List&lt;List&lt;String&gt;&gt; is reconstructed from how
   *         BasicLabelSeqDatum stores the value. Confirm against upstream.
   */
  public List<List<String>> getNodePredicates();

  /** @return the weight attached to this instance */
  public double getWeight();
}
return true; 21 | if (o == null || getClass() != o.getClass()) return false; 22 | 23 | State state = (State) o; 24 | 25 | if (index != state.index) return false; 26 | if (label != null ? !label.equals(state.label) : state.label != null) return false; 27 | 28 | return true; 29 | } 30 | 31 | @Override 32 | public int hashCode() { 33 | int result = label != null ? label.hashCode() : 0; 34 | result = 31 * result + index; 35 | return result; 36 | } 37 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/StateSpace.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | public class StateSpace implements Serializable { 10 | 11 | private final List states = new ArrayList(); 12 | private final Map stateIndexer = new HashMap(); 13 | private boolean lockedStates = false; 14 | 15 | public final State startState ; 16 | public final State stopState ; 17 | 18 | private final List allTrans = new ArrayList(); 19 | private List[] transFrom; 20 | private List[] transTo; 21 | 22 | public final static String startLabel = ""; 23 | public final static String stopLabel = ""; 24 | 25 | public StateSpace() { 26 | startState = addState(startLabel); 27 | stopState = addState(stopLabel); 28 | } 29 | 30 | public synchronized State addState(String label) { 31 | assert (!lockedStates); 32 | State existing = stateIndexer.get(label) ; 33 | if (existing != null) return existing; 34 | State state = new State(label, states.size()); 35 | states.add(state); 36 | stateIndexer.put(label, state); 37 | return state; 38 | } 39 | 40 | public synchronized void lockStates() { 41 | this.lockedStates = true; 42 | final int numStates = getStates().size(); 43 | transFrom = new List[numStates]; 44 | transTo = new 
List[numStates]; 45 | for (int s = 0; s < getStates().size(); ++s) { 46 | transFrom[s] = new ArrayList(); 47 | transTo[s] = new ArrayList(); 48 | } 49 | } 50 | 51 | 52 | 53 | public Transition findTransition(String start, String stop) { 54 | List transs = getTransitionsFrom(getState(start).index); 55 | for (Transition trans : transs) { 56 | if (trans.to.label.equals(stop)) { 57 | return trans; 58 | } 59 | } 60 | return null; 61 | } 62 | 63 | public synchronized boolean isStateLocked() { 64 | return lockedStates; 65 | } 66 | 67 | public List getStates() { 68 | return states; 69 | } 70 | 71 | public State getState(String label) { 72 | return stateIndexer.get(label); 73 | } 74 | 75 | public synchronized Transition addTransition(String fromLabel, String toLabel) { 76 | if (!lockedStates) { 77 | lockStates(); 78 | } 79 | State fromState = stateIndexer.get(fromLabel); 80 | assert (fromState != null); 81 | State toState = stateIndexer.get(toLabel); 82 | if (toState == startState) { 83 | throw new RuntimeException("Added transition to start-state: " + startState); 84 | } 85 | assert (toState != null); 86 | Transition found = findTransition(fromLabel, toLabel); 87 | if (found != null) return found; 88 | Transition trans = new Transition(fromState,toState,allTrans.size()); 89 | allTrans.add(trans); 90 | transFrom[trans.from.index].add(trans); 91 | transTo[trans.to.index].add(trans); 92 | return trans; 93 | } 94 | 95 | public synchronized List getTransitions() { 96 | return allTrans; 97 | } 98 | 99 | public List getTransitionsFrom(int s) { 100 | return transFrom[s]; 101 | } 102 | 103 | public List getTransitionsTo(int s) { 104 | return transTo[s]; 105 | } 106 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/StateSpaces.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | 4 | import java.util.ArrayList; 5 | import 
java.util.List; 6 | 7 | public class StateSpaces { 8 | 9 | public static StateSpace makeFullStateSpace(List labels) { 10 | assert !labels.contains(StateSpace.startLabel); 11 | assert !labels.contains(StateSpace.stopLabel); 12 | StateSpace res = new StateSpace(); 13 | for (String label : labels) { 14 | res.addState(label); 15 | } 16 | for (String label : labels) { 17 | res.addTransition(StateSpace.startLabel,label); 18 | res.addTransition(label,StateSpace.stopLabel); 19 | for (String nextLabel : labels) { 20 | res.addTransition(label, nextLabel); 21 | } 22 | } 23 | return res; 24 | } 25 | 26 | // private static Tree> buildTransitionTree(Tree> root, StateSpace stateSpace, int depth) { 27 | // if (depth == 0) { 28 | // return root; 29 | // } 30 | // List curNGram = root.getLabel(); 31 | // List>> children = new ArrayList>>(); 32 | // if (curNGram.isEmpty()) { 33 | // for (State state : stateSpace.getStates()) { 34 | // Tree> child = 35 | // buildTransitionTree(new Tree(Collections.singletonList(state.label)),stateSpace,depth-1); 36 | // children.sum(child); 37 | // } 38 | // } else { 39 | // L last = curNGram.get(curNGram.size()-1); 40 | // State lastState = stateSpace.getState(last); 41 | // for (Transition trans : stateSpace.getTransitionsFrom(lastState.index)) { 42 | // List newNGram = new ArrayList(curNGram); 43 | // newNGram.sum(trans.to.label); 44 | // Tree> child = 45 | // buildTransitionTree(new Tree(newNGram),stateSpace,depth-1); 46 | // children.sum(child); 47 | // } 48 | // } 49 | // return new Tree>(curNGram, children); 50 | // } 51 | // 52 | // public static StateSpace> makeNGramStateSpace(StateSpace stateSpace, int nGram) { 53 | // Tree> transTree = buildTransitionTree(new Tree>(new ArrayList()),stateSpace, nGram); 54 | // System.out.println(transTree.getTerminalYield()); 55 | // return null; 56 | // } 57 | 58 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/TokenF1Eval.java: 
-------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | import edu.umass.nlp.ml.F1Stats; 4 | import edu.umass.nlp.utils.Collections; 5 | 6 | import java.util.HashSet; 7 | import java.util.List; 8 | import java.util.Map; 9 | import java.util.Set; 10 | 11 | public class TokenF1Eval { 12 | 13 | public static void updateEval(Map stats, 14 | List trueLabels, 15 | List guessLabels) 16 | { 17 | 18 | for (int i = 0; i < trueLabels.size(); i++) { 19 | String trueLabel = trueLabels.get(i); 20 | String guessLabel = guessLabels.get(i); 21 | 22 | if (trueLabel.equals(guessLabel)) { 23 | Collections.getMut(stats,trueLabel, new F1Stats(trueLabel)).tp++; 24 | } else { 25 | Collections.getMut(stats,trueLabel, new F1Stats(trueLabel)).fn++; 26 | Collections.getMut(stats,guessLabel, new F1Stats(guessLabel)).fp++; 27 | } 28 | } 29 | } 30 | 31 | public static F1Stats getAvgStats(Map stats) { 32 | return getAvgStats(stats,Collections.set(StateSpace.startLabel, StateSpace.stopLabel, "NONE")) ; 33 | } 34 | 35 | public static F1Stats getAvgStats(Map stats,Set toIgnore) { 36 | F1Stats avgStats = new F1Stats("AVG"); 37 | for (Map.Entry entry : stats.entrySet()) { 38 | String label = entry.getKey(); 39 | if (!toIgnore.contains(label)) { 40 | avgStats.merge(entry.getValue()); 41 | } 42 | } 43 | return avgStats; 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/ml/sequence/Transition.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.ml.sequence; 2 | 3 | import java.io.Serializable; 4 | 5 | public class Transition implements Serializable { 6 | 7 | public final State from; 8 | public final State to; 9 | public final int index; 10 | 11 | public Transition(State from, State to, int index) { 12 | this.from = from; 13 | this.to = to; 14 | this.index = index; 15 | } 16 | 17 | @Override 
18 | public String toString() { 19 | return String.format("Trans(%s,%s)",from.label,to.label); 20 | } 21 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/BacktrackingLineMinimizer.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import edu.umass.nlp.exec.Execution; 4 | import edu.umass.nlp.utils.BasicPair; 5 | import edu.umass.nlp.utils.DoubleArrays; 6 | import edu.umass.nlp.utils.IPair; 7 | import org.apache.log4j.Logger; 8 | 9 | public class BacktrackingLineMinimizer implements ILineMinimizer { 10 | 11 | public Result minimizeAlongDirection(IDifferentiableFn fn, double[] initial, double[] direction, Opts opts) { 12 | final Logger logger = Logger.getLogger(BacktrackingLineMinimizer.class.getSimpleName()); 13 | logger.setLevel(opts.logLevel); 14 | double stepSize = opts.initialStepSize; 15 | final IPair initPair = fn.computeAt(initial); 16 | final double initVal = initPair.getFirst(); 17 | final double[] initDeriv = initPair.getSecond(); 18 | final double directDeriv = DoubleArrays.innerProduct(initDeriv, direction); 19 | logger.trace("DirectionalDeriv: " + directDeriv); 20 | for (int iter=0; iter < opts.maxIterations; ++iter) { 21 | final double[] guess = DoubleArrays.addMultiples(initial, 1.0, direction, stepSize); 22 | final double curVal = fn.computeAt(guess).getFirst(); 23 | final double targetVal = initVal + opts.sufficientDecreaseConstant * directDeriv * stepSize; 24 | final double diff = curVal - targetVal; 25 | logger.trace(String.format("iter=%d stepSize=%.6f curVal=%.4f targetVal=%.4f diff=%.5f",iter,stepSize, curVal,targetVal,diff)); 26 | if (curVal <= targetVal) { 27 | Result res = new Result(); 28 | res.minimized = guess; 29 | res.stepSize = stepSize; 30 | return res; 31 | } 32 | stepSize *= opts.stepSizeMultiplier ; 33 | if (stepSize < ILineMinimizer.STEP_SIZE_TOLERANCE) { 34 | logger.warn("step 
size underflow"); 35 | break; 36 | } 37 | } 38 | Result deflt = new Result(); 39 | deflt.minimized = initial; 40 | deflt.stepSize = stepSize; 41 | return deflt; 42 | } 43 | 44 | public static void main(String[] args) { 45 | Execution.init(null); 46 | IDifferentiableFn function = new IDifferentiableFn() { 47 | public int getDimension() { 48 | return 1; 49 | } 50 | 51 | public IPair computeAt(double[] x) { 52 | 53 | double val = x[0] * (x[0] - 0.01); 54 | double[] grad = new double[] { 2*x[0] - 0.01 }; 55 | return BasicPair.make(val, grad); 56 | } 57 | }; 58 | 59 | ILineMinimizer.Opts opts = new ILineMinimizer.Opts(); 60 | Result res = (new BacktrackingLineMinimizer()).minimizeAlongDirection(function, 61 | new double[] { 0 }, 62 | new double[] { 1 }, 63 | opts); 64 | System.out.println(res.stepSize); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/CachingDifferentiableFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import edu.umass.nlp.utils.BasicPair; 4 | import edu.umass.nlp.utils.DoubleArrays; 5 | import edu.umass.nlp.utils.IPair; 6 | 7 | 8 | public class CachingDifferentiableFn implements IDifferentiableFn { 9 | private final IDifferentiableFn fn ; 10 | private double[] lastX; 11 | private IPair lastVal; 12 | 13 | public CachingDifferentiableFn(IDifferentiableFn fn) { 14 | this.fn = fn; 15 | } 16 | 17 | public IPair computeAt(double[] x) { 18 | if (lastX != null && java.util.Arrays.equals(x,lastX)) { 19 | return BasicPair.make(lastVal.getFirst(),DoubleArrays.clone(lastVal.getSecond())); 20 | } 21 | this.lastX = DoubleArrays.clone(x); 22 | this.lastVal = fn.computeAt(x); 23 | return BasicPair.make(lastVal.getFirst(),DoubleArrays.clone(lastVal.getSecond())); 24 | } 25 | 26 | public int getDimension() { 27 | return fn.getDimension(); 28 | } 29 | } 30 | 
-------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/GradientDescent.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import edu.umass.nlp.exec.Execution; 4 | import edu.umass.nlp.utils.BasicPair; 5 | import edu.umass.nlp.utils.DoubleArrays; 6 | import edu.umass.nlp.utils.IPair; 7 | import edu.umass.nlp.utils.SloppyMath; 8 | import org.apache.log4j.Logger; 9 | 10 | public class GradientDescent implements IOptimizer { 11 | 12 | 13 | public Result minimize(IDifferentiableFn fn, double[] initial, Opts opts) { 14 | Logger logger = Logger.getLogger(GradientDescent.class.getSimpleName()); 15 | logger.setLevel(opts.logLevel); 16 | double lastVal = Double.POSITIVE_INFINITY; 17 | double[] curGuess = initial; 18 | for (int iter=0; iter < opts.maxIters; ++iter) { 19 | final IPair fnPair = fn.computeAt(curGuess); 20 | final double curVal = fnPair.getFirst(); 21 | final double[] curGrad = fnPair.getSecond(); 22 | final double relDiff = SloppyMath.relativeDifference(curVal, lastVal); 23 | logger.info("curGrad: " + DoubleArrays.toString(curGrad)); 24 | logger.info(String.format("iter: %d curVal: %.3f lastVal: %.3f relDiff: %.3f",iter,curVal,lastVal,relDiff)); 25 | if (relDiff <= opts.tol) { 26 | Result res = new Result(); 27 | res.minArg = curGuess; 28 | res.minObjVal = curVal; 29 | res.minGrad = curGrad; 30 | res.didConverge = true; 31 | return res; 32 | } 33 | ILineMinimizer.Result lineMinResult = OptimizeUtils.doLineMinimization(fn,curGuess,DoubleArrays.scale(curGrad,-1),opts,iter); 34 | lastVal = curVal; 35 | curGuess = lineMinResult.minimized; 36 | logger.info("curGuess: " + DoubleArrays.toString(curGuess)); 37 | } 38 | Result deflt = new Result(); 39 | deflt.minArg = curGuess; 40 | return deflt; 41 | } 42 | 43 | public static void main(String[] args) { 44 | Execution.init(null); 45 | IDifferentiableFn fn = new 
IDifferentiableFn() { 46 | public IPair computeAt(double[] x) { 47 | return BasicPair.make( (x[0] -1.0)* (x[0]-1.0), new double[] { 2 * x[0] - 2 } ); 48 | } 49 | 50 | public int getDimension() { 51 | return 1; 52 | } 53 | }; 54 | IOptimizer.Result res = (new LBFGSMinimizer()).minimize(fn, new double[] { 1.0 }, new LBFGSMinimizer.Opts()); 55 | System.out.println("res: " + DoubleArrays.toString(res.minArg)); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/IDifferentiableFn.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import edu.umass.nlp.utils.IPair; 4 | 5 | /** 6 | * A differentiable function is one that can take a vector of 7 | * numbers and return the pair f(x) and gradient of f at x 8 | */ 9 | public interface IDifferentiableFn { 10 | public IPair computeAt(double[] x); 11 | public int getDimension(); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/ILineMinimizer.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import org.apache.log4j.Level; 4 | 5 | 6 | public interface ILineMinimizer { 7 | 8 | public static final double STEP_SIZE_TOLERANCE = 1.0e-10; 9 | 10 | public static class Opts { 11 | public double stepSizeMultiplier = 0.5;//was 0.9; 12 | public double sufficientDecreaseConstant = 1e-4; 13 | public double initialStepSize = 1.0; 14 | public double tol = 1.e0-6; 15 | public int maxIterations = Integer.MAX_VALUE; 16 | public Level logLevel = Level.OFF; 17 | } 18 | 19 | public static class Result { 20 | double stepSize; 21 | double[] minimized = null; 22 | public boolean didStepSizeUnderflow() { 23 | return stepSize < STEP_SIZE_TOLERANCE; 24 | } 25 | } 26 | 27 | /** 28 | * Given a function fn and an initial point x0 and a 
direction dir 29 | * we want to return scalar alpha that minimizes 30 | * min_alpha f(x0 + alpha * dir) 31 | * where alpha * dir scales the vector dir by scalar alpha 32 | */ 33 | public Result minimizeAlongDirection(IDifferentiableFn fn, double[] initial, double[] direction, Opts opts); 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/IOptimizer.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import edu.umass.nlp.functional.CallbackFn; 4 | import org.apache.log4j.Level; 5 | 6 | /** 7 | * 8 | * 9 | * 10 | * @author aria42 11 | */ 12 | public interface IOptimizer { 13 | 14 | 15 | public static class Opts { 16 | public int minIters = 10; 17 | public int maxIters = 50; 18 | public double tol = 1.0e-4; 19 | public Level logLevel = Level.INFO; 20 | public double initialStepSizeMultiplier = 0.01; 21 | public double stepSizeMultiplier = 0.5; 22 | // iterCallback at the end of each 23 | // iteration with a single Result argument 24 | public CallbackFn iterCallback = null; 25 | } 26 | 27 | public static class Result { 28 | public double minObjVal; 29 | public double[] minGrad; 30 | public double[] minArg; 31 | public boolean didConverge = false; 32 | } 33 | 34 | Result minimize(IDifferentiableFn fn, double[] initial, Opts opts); 35 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/LBFGSMinimizer.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | import edu.umass.nlp.utils.DoubleArrays; 4 | import edu.umass.nlp.utils.IPair; 5 | import edu.umass.nlp.utils.SloppyMath; 6 | import org.apache.log4j.Logger; 7 | 8 | import java.util.LinkedList; 9 | import java.util.List; 10 | 11 | /** 12 | * 13 | * @author Dan Klein 14 | * @author aria42 (significant mods) 15 | */ 
16 | public class LBFGSMinimizer implements IOptimizer { 17 | 18 | /** 19 | * LBFGS specific options, when you call minimize() 20 | * you should pass in an instance of this class 21 | * (or you get the defaults) 22 | */ 23 | public static class Opts extends IOptimizer.Opts { 24 | public int maxHistorySize = 6; 25 | public int maxHistoryResets = 0; 26 | public String checkpointPath = null; 27 | } 28 | 29 | private List inputDiffs = new LinkedList(); 30 | private List derivDiffs = new LinkedList(); 31 | private Logger logger = Logger.getLogger("LBFGSMinimizer"); 32 | private Opts opts; 33 | 34 | private double[] getInitialInverseHessianDiagonal(IDifferentiableFn fn) { 35 | double scale = 1.0; 36 | if (!derivDiffs.isEmpty()) { 37 | double[] lastDerivativeDifference = derivDiffs.get(0); 38 | double[] lastInputDifference = inputDiffs.get(0); 39 | double num = DoubleArrays.innerProduct(lastDerivativeDifference, 40 | lastInputDifference); 41 | double den = DoubleArrays.innerProduct(lastDerivativeDifference, 42 | lastDerivativeDifference); 43 | scale = num / den; 44 | } 45 | return DoubleArrays.constantArray(scale, fn.getDimension()); 46 | } 47 | 48 | private int getHistorySize() { 49 | assert inputDiffs.size() == derivDiffs.size(); 50 | return inputDiffs.size(); 51 | } 52 | 53 | public double[] getSearchDirection(IDifferentiableFn fn, double[] derivative) { 54 | double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(fn); 55 | double[] direction = implicitMultiply(initialInverseHessianDiagonal, derivative); 56 | return direction; 57 | } 58 | 59 | 60 | private double[] implicitMultiply(double[] initialInverseHessianDiagonal, 61 | double[] derivative) { 62 | double[] rho = new double[getHistorySize()]; 63 | double[] alpha = new double[getHistorySize()]; 64 | double[] right = DoubleArrays.clone(derivative); 65 | // loop last backward 66 | for (int i = getHistorySize() - 1; i >= 0; i--) { 67 | double[] inputDifference = inputDiffs.get(i); 68 | double[] 
derivativeDifference = derivDiffs.get(i); 69 | rho[i] = DoubleArrays.innerProduct(inputDifference, derivativeDifference); 70 | if (rho[i] == 0.0) { 71 | logger.fatal("LBFGSMinimizer.implicitMultiply: Curvature problem."); 72 | } 73 | alpha[i] = DoubleArrays.innerProduct(inputDifference, right) / rho[i]; 74 | right = DoubleArrays.addMultiples(right, 1.0, derivativeDifference, -1.0 * alpha[i]); 75 | } 76 | double[] left = DoubleArrays.pointwiseMultiply(initialInverseHessianDiagonal, right); 77 | for (int i = 0; i < getHistorySize(); i++) { 78 | double[] inputDifference = inputDiffs.get(i); 79 | double[] derivativeDifference = derivDiffs.get(i); 80 | double beta = DoubleArrays.innerProduct(derivativeDifference, left) / rho[i]; 81 | left = DoubleArrays.addMultiples(left, 1.0, inputDifference, alpha[i] - beta); 82 | } 83 | return left; 84 | } 85 | 86 | private void clearHistories() { 87 | inputDiffs.clear(); 88 | derivDiffs.clear(); 89 | } 90 | 91 | protected void updateHistories(double[] cur, double[] next, List diffs) { 92 | double[] diff = DoubleArrays.addMultiples(next, 1.0, cur, -1.0); 93 | diffs.add(0, diff); 94 | if (diffs.size() > opts.maxHistorySize) { 95 | diffs.remove(diffs.size() - 1); 96 | } 97 | } 98 | 99 | private Result getResult(IDifferentiableFn fn, double[] x) { 100 | Result res = new Result(); 101 | res.minArg = DoubleArrays.clone(x); 102 | IPair fnEval = fn.computeAt(res.minArg); 103 | res.minObjVal = fnEval.getFirst(); 104 | res.minGrad = fnEval.getSecond(); 105 | return res; 106 | } 107 | 108 | public Result minimize(IDifferentiableFn fn, double[] initial, IOptimizer.Opts opts) { 109 | int numResets = 0; 110 | if (opts == null) opts = new Opts(); 111 | clearHistories(); 112 | logger.setLevel(opts.logLevel); 113 | if (!(opts instanceof Opts)) { 114 | logger.warn("opts are not LBFGS specific, reverting to all-default opts"); 115 | opts = new Opts(); 116 | } 117 | this.opts = (Opts) opts; 118 | 119 | Result cur = getResult(fn, initial); 120 | for 
(int iter = 0; iter < opts.maxIters; ++iter) { 121 | double[] dir = getSearchDirection(fn, cur.minGrad); 122 | DoubleArrays.scaleInPlace(dir, -1); 123 | ILineMinimizer.Result lineMinRes = 124 | OptimizeUtils.doLineMinimization(fn, cur.minArg, dir, opts, iter); 125 | Result next = getResult(fn,lineMinRes.minimized); 126 | final double relDiff = SloppyMath.relativeDifference(cur.minObjVal, next.minObjVal); 127 | //logger.info("relDiff: " + relDiff); 128 | if (iter > opts.minIters && relDiff < opts.tol) { 129 | if (numResets < ((Opts) opts).maxHistoryResets) { 130 | logger.info("Dumping Cache"); 131 | iter--; 132 | numResets++; 133 | clearHistories(); 134 | continue; 135 | } 136 | logger.info(String.format("Finished: value: %.5f",cur.minObjVal)); 137 | return cur; 138 | } 139 | // Updates 140 | updateHistories(cur.minArg,next.minArg, inputDiffs); 141 | updateHistories(cur.minGrad,next.minGrad, derivDiffs); 142 | logger.info(String.format("End of iter %d: value: %.5f relDiff: %.3f", 143 | iter+1, 144 | cur.minObjVal, 145 | SloppyMath.relativeDifference(cur.minObjVal, next.minObjVal))); 146 | cur = next; 147 | 148 | } 149 | return cur; 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/optimize/OptimizeUtils.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.optimize; 2 | 3 | 4 | public class OptimizeUtils { 5 | public static ILineMinimizer.Opts getLineMinimizerOpts(IOptimizer.Opts opts, int iter) { 6 | ILineMinimizer.Opts lineMineOpts = new ILineMinimizer.Opts(); 7 | lineMineOpts.logLevel = opts.logLevel; 8 | lineMineOpts.stepSizeMultiplier = iter > 0 ? 
9 | opts.stepSizeMultiplier : 10 | opts.initialStepSizeMultiplier; 11 | return lineMineOpts; 12 | } 13 | 14 | public static ILineMinimizer.Result doLineMinimization(IDifferentiableFn fn, 15 | double[] initial, 16 | double[] direction, 17 | IOptimizer.Opts opts, 18 | int iter) { 19 | return (new BacktrackingLineMinimizer()) 20 | .minimizeAlongDirection( 21 | fn, 22 | initial, 23 | direction, 24 | getLineMinimizerOpts(opts, iter)); 25 | } 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/parallel/ParallelUtils.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.parallel; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.utils.Collections; 6 | 7 | import java.io.PrintWriter; 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | import java.util.concurrent.ExecutorService; 11 | import java.util.concurrent.Executors; 12 | import java.util.concurrent.Semaphore; 13 | import java.util.concurrent.TimeUnit; 14 | 15 | 16 | public class ParallelUtils { 17 | 18 | public static void doParallelWork(List runnables, 19 | int numThreads) 20 | { 21 | final ExecutorService executor = Executors.newFixedThreadPool(numThreads); 22 | for (final Runnable runnable : runnables) { 23 | executor.execute(runnable); 24 | } 25 | try { 26 | executor.shutdown(); 27 | executor.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); 28 | } catch (InterruptedException ie) { 29 | ie.printStackTrace(); 30 | } 31 | } 32 | 33 | public static void main(String[] args) { 34 | List ints = new ArrayList(); 35 | for (int i=1; i <= 100; ++i) { 36 | ints.add(i); 37 | } 38 | List> parts = Collections.partition(ints, Runtime.getRuntime().availableProcessors()); 39 | class Worker implements Runnable { 40 | double sum = 0.0; 41 | List part; 42 | Worker(List part) { 43 | this.part = part; 44 | } 45 | public 
void run() { 46 | for (Integer i : part) { 47 | sum += i; 48 | } 49 | } 50 | } 51 | List workers = Functional.map(parts, new Fn, Worker>() { 52 | public Worker apply(List input) { 53 | return new Worker(input); 54 | }}); 55 | doParallelWork(workers, Runtime.getRuntime().availableProcessors()); 56 | double totalSum = 0.0; 57 | System.out.println("numThreads: " + workers.size()); 58 | for (Worker worker : workers) { 59 | totalSum += worker.sum; 60 | } 61 | System.out.println("total sum: " + totalSum); 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/process/Document.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.process; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.utils.BasicPair; 6 | import edu.umass.nlp.utils.IHasProperties; 7 | import edu.umass.nlp.utils.IPair; 8 | 9 | import java.util.ArrayList; 10 | import java.util.HashMap; 11 | import java.util.List; 12 | import java.util.Map; 13 | 14 | public class Document implements IHasProperties { 15 | private final Map properties = new HashMap(); 16 | public final List> sentences = new ArrayList>(); 17 | 18 | public Document(List rawTokens) { 19 | for (Token token : rawTokens) { 20 | while (token.sentIndex <= sentences.size()) { 21 | sentences.add(new ArrayList()); 22 | } 23 | sentences.get(token.sentIndex).add(token); 24 | } 25 | } 26 | 27 | @Override 28 | public Object getProperty(String name) { 29 | return properties.get(name); 30 | } 31 | 32 | @Override 33 | public List> getProperties() { 34 | return Functional.map(properties.entrySet(), new Fn, IPair>() { 35 | @Override 36 | public IPair apply(Map.Entry input) { 37 | return BasicPair.make(input.getKey(), input.getValue()); 38 | }}); 39 | } 40 | 41 | @Override 42 | public void addProperty(String name, Object val) { 43 | properties.put(name, val); 
44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/process/Token.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.process; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.io.IOUtils; 5 | import edu.umass.nlp.utils.Span; 6 | 7 | import java.io.InputStream; 8 | import java.util.ArrayList; 9 | import java.util.Collections; 10 | import java.util.List; 11 | 12 | /** 13 | * Represents a token grounded in a document. Also 14 | * tracks the character span from the source in which 15 | * the token appears. Also, a field for a label 16 | * in case its needed. 17 | */ 18 | public class Token implements Comparable { 19 | 20 | public final String origWord; 21 | public int tokIndex; 22 | public int sentIndex; 23 | public Span charSpan; 24 | public String word ; 25 | public String label = NO_LABEL; 26 | 27 | public static String NO_LABEL = "O"; 28 | 29 | public Token(String origWord, Span charSpan, int tokIndex, int sentIndex) { 30 | //assert origWord.length() == charSpan.getLength(); 31 | this.origWord = origWord; 32 | this.charSpan = charSpan; 33 | this.tokIndex = tokIndex; 34 | this.sentIndex = sentIndex; 35 | this.word = origWord; 36 | } 37 | 38 | public String toLine() { 39 | return String.format("%s %s %d %d %d %d %s",word,origWord,charSpan.getStart(),charSpan.getStop(),tokIndex,sentIndex,label); 40 | } 41 | 42 | public String getOrigWord() { 43 | return origWord; 44 | } 45 | 46 | public int getTokIndex() { 47 | return tokIndex; 48 | } 49 | 50 | public int getSentIndex() { 51 | return sentIndex; 52 | } 53 | 54 | public int getStartChar() { 55 | return charSpan.getStart(); 56 | } 57 | 58 | public int getStopChar() { 59 | return charSpan.getStop(); 60 | } 61 | 62 | public String getWord() { 63 | return word; 64 | } 65 | 66 | public String getLabel() { 67 | return label; 68 | } 69 | 70 | public void setLabel(String 
label) { 71 | this.label = label; 72 | } 73 | 74 | public void setWord(String word) { 75 | 76 | this.word = word; 77 | } 78 | 79 | public Span getCharSpan() { 80 | return charSpan; 81 | } 82 | 83 | @Override 84 | public boolean equals(Object o) { 85 | if (this == o) return true; 86 | if (o == null || getClass() != o.getClass()) return false; 87 | 88 | Token token = (Token) o; 89 | 90 | if (sentIndex != token.sentIndex) return false; 91 | if (tokIndex != token.tokIndex) return false; 92 | 93 | return true; 94 | } 95 | 96 | @Override 97 | public int hashCode() { 98 | int result = tokIndex; 99 | result = 31 * result + sentIndex; 100 | return result; 101 | } 102 | 103 | public int compareTo(Token o) { 104 | if (o.sentIndex != this.sentIndex) { 105 | return this.sentIndex - o.sentIndex; 106 | } 107 | return this.tokIndex - o.tokIndex; 108 | } 109 | 110 | @Override 111 | public String toString() { 112 | return String.format("Token(%d,%d,%s,%s)",sentIndex,tokIndex,word,label); 113 | } 114 | 115 | 116 | /* 117 | * 118 | * Static Factory Methods 119 | * 120 | */ 121 | 122 | public static Token fromLine(String line) { 123 | String[] pieces = line.split("\\s+"); 124 | if (pieces.length != 7) throw new RuntimeException("Token, bad line: " + line); 125 | Token tok = new Token(pieces[1],new Span(Integer.parseInt(pieces[2]),Integer.parseInt(pieces[3])), 126 | Integer.parseInt(pieces[4]),Integer.parseInt(pieces[5])); 127 | tok.label = pieces[6]; 128 | tok.word = pieces[0]; 129 | return tok; 130 | } 131 | 132 | public static Fn fromLineFn = new Fn() { 133 | public Token apply(String input) { return fromLine(input); } 134 | }; 135 | 136 | public static List> fromInputStream(InputStream is) { 137 | List lines = IOUtils.lines(is); 138 | List> res = new ArrayList>(); 139 | for (int i = 0; i < lines.size(); ++i) { 140 | Token tok = fromLine(lines.get(i)); 141 | if (tok.getSentIndex() >= res.size()) { 142 | assert tok.getSentIndex() == res.size(); 143 | res.add(new ArrayList()); 144 | } 
145 | res.get(tok.getSentIndex()).add(tok); 146 | } 147 | for (List toks : res) { 148 | Collections.sort(toks); 149 | } 150 | return res; 151 | } 152 | 153 | 154 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/text/HTMLUtils.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.text; 2 | 3 | 4 | import edu.umass.nlp.utils.Span; 5 | import net.htmlparser.jericho.*; 6 | 7 | import java.util.List; 8 | 9 | public class HTMLUtils { 10 | 11 | public static Span getSpan(Segment e) { 12 | return new Span(e.getBegin(), e.getEnd()); 13 | } 14 | 15 | public static Span getContentSpan(Element e) { 16 | return getSpan(e.getContent()); 17 | } 18 | 19 | public static int[] getCharOffsets(String html) { 20 | int[] res = new int[html.length()]; 21 | Source source = new Source(html); 22 | List elems = source.getAllElements(); 23 | for (int i=0; i < res.length; ++i) { 24 | res[i] = i; 25 | } 26 | for (Element elem : elems) { 27 | StartTag startTag = elem.getStartTag(); 28 | Span startSpan = getSpan(startTag); 29 | EndTag stopTag = elem.getEndTag(); 30 | Span stopSpan = getSpan(stopTag); 31 | for (int i=startSpan.getStart(); i < startSpan.getStop(); ++i) { 32 | res[i] -= (startSpan.getLength() - (startSpan.getStop()-i)); 33 | } 34 | for (int i=startSpan.getStop(); i < stopSpan.getStart(); ++i) { 35 | res[i] -= startSpan.getLength(); 36 | } 37 | for (int i=stopSpan.getStart(); i < stopSpan.getStop(); ++i) { 38 | res[i] -= (stopSpan.getLength() - (stopSpan.getStop()-i)) ; 39 | } 40 | for (int i=stopSpan.getStop(); i < html.length(); ++i) { 41 | res[i] -= (startSpan.getLength() + stopSpan.getLength()); 42 | } 43 | } 44 | return res; 45 | } 46 | 47 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/trees/BasicTree.java: -------------------------------------------------------------------------------- 1 | 
package edu.umass.nlp.trees; 2 | 3 | import edu.umass.nlp.utils.Collections; 4 | 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | 8 | public class BasicTree implements ITree { 9 | 10 | private final L label; 11 | private final List> children; 12 | 13 | public BasicTree(L label) { 14 | this(label, java.util.Collections.>emptyList()); 15 | } 16 | 17 | public BasicTree(L label, List> children) { 18 | this.label = label; 19 | this.children = java.util.Collections.unmodifiableList(new ArrayList>(children)); 20 | } 21 | 22 | // 23 | // ITree 24 | // 25 | 26 | public List> getChildren() { 27 | return children; 28 | } 29 | 30 | public L getLabel() { 31 | return label; 32 | } 33 | 34 | // 35 | // Object 36 | // 37 | 38 | public boolean equals(Object o) { 39 | if (this == o) return true; 40 | if (o == null || getClass() != o.getClass()) return false; 41 | 42 | BasicTree basicTree = (BasicTree) o; 43 | 44 | if (children != null ? !children.equals(basicTree.children) : basicTree.children != null) return false; 45 | if (label != null ? !label.equals(basicTree.label) : basicTree.label != null) return false; 46 | 47 | return true; 48 | } 49 | 50 | public int hashCode() { 51 | int result = label != null ? label.hashCode() : 0; 52 | result = 31 * result + (children != null ? children.hashCode() : 0); 53 | return result; 54 | } 55 | 56 | public String toString() { 57 | return Trees.toString(this); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/trees/ITree.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.trees; 2 | 3 | import edu.umass.nlp.utils.Span; 4 | import java.util.List; 5 | 6 | /** 7 | * Abstraction of a Tree. 
See class Tree for 8 | * methods on a tree (getNodes, 9 | * @param 10 | */ 11 | public interface ITree { 12 | 13 | public L getLabel(); 14 | 15 | public List> getChildren(); 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/trees/Trees.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.trees; 2 | 3 | import edu.umass.nlp.functional.Fn; 4 | import edu.umass.nlp.functional.Functional; 5 | import edu.umass.nlp.functional.PredFn; 6 | import edu.umass.nlp.io.IOUtils; 7 | import edu.umass.nlp.utils.*; 8 | import edu.umass.nlp.utils.Collections; 9 | 10 | import java.io.File; 11 | import java.util.*; 12 | 13 | /** 14 | * Methods that can be performed 15 | * on an ITree. 16 | */ 17 | public class Trees { 18 | 19 | public static boolean isLeaf(ITree tree) { 20 | return tree.getChildren().isEmpty(); 21 | } 22 | 23 | public static boolean isPreLeaf(ITree tree) { 24 | return tree.getChildren().size() == 1 && 25 | isLeaf(tree.getChildren().get(0)); 26 | } 27 | 28 | public static String toString(ITree tree) { 29 | StringBuilder sb = new StringBuilder(); 30 | if (isLeaf(tree)) { 31 | sb.append(tree.getLabel().toString()); 32 | } else { 33 | sb.append("("); 34 | sb.append(tree.getLabel().toString()); 35 | List> childs = tree.getChildren(); 36 | for (int i = 0; i < childs.size(); i++) { 37 | ITree child = childs.get(i); 38 | sb.append(" "); 39 | sb.append(toString(child)); 40 | } 41 | sb.append(")"); 42 | } 43 | return sb.toString(); 44 | } 45 | 46 | public static List> getNodes(ITree tree) { 47 | List> res = new ArrayList>(); 48 | res.add(tree); 49 | for (ITree c : tree.getChildren()) { 50 | res.addAll(getNodes(c)); 51 | } 52 | return res; 53 | } 54 | 55 | public static List> getLeaves(ITree tree) { 56 | return Functional.filter(getNodes(tree), new PredFn>() { 57 | public boolean holdsAt(ITree elem) { 58 | return isLeaf(elem); 59 | }}); 60 | } 
61 | 62 | public static List> getPreLeaves(ITree tree) { 63 | return Functional.filter(getNodes(tree), new PredFn>() { 64 | public boolean holdsAt(ITree elem) { 65 | return isPreLeaf(elem); 66 | }}); 67 | } 68 | 69 | public static List getPreLeafYield(ITree tree) { 70 | return Functional.map(getPreLeaves(tree), new Fn, L>() { 71 | public L apply(ITree input) { 72 | return input.getLabel(); 73 | }}); 74 | } 75 | 76 | public static List getLeafYield(ITree t) { 77 | return Functional.map(getLeaves(t), new Fn, L>() { 78 | public L apply(ITree input) { 79 | return input.getLabel(); 80 | }}); 81 | } 82 | 83 | public static IdentityHashMap, Span> getSpanMap(ITree root) { 84 | return getSpanMap(root,0); 85 | } 86 | 87 | public static IdentityHashMap, Span> getSpanMap(ITree root, int start) { 88 | IdentityHashMap,Span> res = new IdentityHashMap,Span>(); 89 | int newStart = start; 90 | for (ITree child : root.getChildren()) { 91 | res.putAll(getSpanMap(child, newStart)); 92 | newStart += Trees.getLeafYield(child).size(); 93 | } 94 | res.put(root, new Span(start,newStart)); 95 | return res; 96 | } 97 | 98 | private static class TreeReader { 99 | static IPair,List> nodeFromString(List chars) { 100 | if (chars.get(0) != '(') { 101 | throw new RuntimeException("Error"); 102 | } 103 | chars = dropInitWhiteSpace(Collections.subList(chars,1)); 104 | IPair> labelPair = labelFromString(chars); 105 | String label = labelPair.getFirst(); 106 | List rest = labelPair.getSecond(); 107 | final List> children = new ArrayList>(); 108 | while (!rest.isEmpty() && rest.get(0) != ')') { 109 | IPair,List> childPair = treeFromString(rest); 110 | children.add(childPair.getFirst()); 111 | rest = dropInitWhiteSpace(childPair.getSecond()); 112 | } 113 | rest = Collections.subList(rest,1); 114 | return BasicPair.,List>make(new BasicTree(label,children),rest); 115 | } 116 | 117 | final static Set parens = Collections.set('(',')'); 118 | 119 | static boolean isLabel(Character ch) { 120 | return 
!Character.isWhitespace(ch.charValue()) && 121 | !parens.contains(ch); 122 | } 123 | 124 | static List dropInitWhiteSpace(List chars) { 125 | for (int i=0; i < chars.size(); ++i) { 126 | if (!Character.isWhitespace(chars.get(i))) { 127 | return Collections.subList(chars,i); 128 | } 129 | } 130 | return java.util.Collections.emptyList(); 131 | } 132 | 133 | static IPair> labelFromString(List chars) { 134 | List labelChars = Functional.takeWhile(chars, new PredFn() { 135 | public boolean holdsAt(Character elem) { 136 | return isLabel(elem); 137 | }}); 138 | List rest = Collections.subList(chars,labelChars.size()); 139 | return BasicPair.make(StringUtils.toString(labelChars), dropInitWhiteSpace(rest)); 140 | } 141 | 142 | static IPair,List> treeFromString(List chars) { 143 | chars = dropInitWhiteSpace(chars); 144 | if (chars.isEmpty()) return null; 145 | try { 146 | if (chars.get(0) == '(') return nodeFromString(chars); 147 | else { 148 | IPair> labelPair = labelFromString(chars); 149 | ITree leaf = new BasicTree(labelPair.getFirst()); 150 | return BasicPair.make(leaf, labelPair.getSecond()); 151 | } 152 | } catch (Exception e) { 153 | throw new RuntimeException("Error parsing String: " + StringUtils.toString(chars)); 154 | } 155 | } 156 | } 157 | 158 | public static ITree readTree(String s) { 159 | return TreeReader.treeFromString(StringUtils.getCharacters(s)).getFirst(); 160 | } 161 | 162 | public static Iterable> readTrees(final String s) { 163 | return new Iterable>() { 164 | List chars = TreeReader.dropInitWhiteSpace(StringUtils.getCharacters(s)); 165 | public Iterator> iterator() { 166 | return new Iterator>() { 167 | public boolean hasNext() { 168 | return !chars.isEmpty(); 169 | } 170 | 171 | public ITree next() { 172 | IPair,List> pair = TreeReader.treeFromString(chars); 173 | ITree t = pair.getFirst(); 174 | chars = TreeReader.dropInitWhiteSpace(pair.getSecond()); 175 | return t; 176 | } 177 | 178 | public void remove() { 179 | throw new 
RuntimeException("remove() not accepted"); 180 | } 181 | };}}; 182 | } 183 | 184 | public static void main(String[] args) { 185 | // ITree c = new BasicTree("c",new ArrayList()); 186 | // ITree p = new BasicTree("p", Collections.makeList(c,c)); 187 | ITree t = readTree("(S (NP (DT the) (NN man)) (VP (VBD ran)))"); 188 | System.out.println("spanMap: " + getSpanMap(t)); 189 | // List> tags = Functional.filter(getNodes(t), new PredFn>() { 190 | // public boolean holdsAt(ITree elem) { 191 | // return isPreLeaf(elem); 192 | // }}); 193 | // Iterable> trees = 194 | // readTrees(IOUtils.text(new File("/Users/aria42/Dropbox/projs/umass-nlp/trees.mrg"))); 195 | // for (ITree tree : trees) { 196 | // System.out.println(tree); 197 | // } 198 | } 199 | 200 | } 201 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/AtomicDouble.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | /** 4 | * Like AtomicInteger, AtomicDouble 5 | * allows multiple threads to mutate a double safely, by 6 | * putting a lock around it. 
/**
 * Like AtomicInteger, AtomicDouble
 * allows multiple threads to mutate a double safely, by
 * putting a lock around it.
 *
 * <p>Every operation runs entirely inside the instance's monitor, so compound
 * operations ({@link #getAndSet}, {@link #increment}) are atomic as a whole.
 *
 * @author aria42 (Aria Haghighi)
 */
public final class AtomicDouble implements java.io.Serializable {

  double x = 0.0;

  /** Creates a holder with the given starting value. */
  public AtomicDouble(double initialValue) {
    this.x = initialValue;
  }

  /** Creates a holder starting at 0.0. */
  public AtomicDouble() {
    this.x = 0.0;
  }

  /** Returns the current value. */
  public double get() {
    synchronized (this) {
      return x;
    }
  }

  /** Replaces the current value. */
  public void set(double newValue) {
    synchronized (this) {
      x = newValue;
    }
  }

  /**
   * Atomically stores {@code newValue} and returns the value held before the call.
   * (The previous implementation returned the new value and released the lock
   * between the write and the read.)
   */
  public double getAndSet(double newValue) {
    synchronized (this) {
      double previous = x;
      x = newValue;
      return previous;
    }
  }


  /**
   * Atomically adds {@code inc} and returns the resulting value. The result is read
   * under the same lock as the addition, so it cannot include another thread's update.
   */
  public double increment(double inc) {
    synchronized (this) {
      x += inc;
      return x;
    }
  }

}
t.hashCode() : 0); 42 | return result; 43 | } 44 | 45 | public static IPair make(S s, T t) { 46 | return new BasicPair(s,t); 47 | } 48 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/BasicValued.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | /** 4 | * Immutable IValued 5 | * @param 6 | */ 7 | public class BasicValued implements IValued { 8 | 9 | private final K elem; 10 | private final double val; 11 | 12 | public BasicValued(K elem, double val) { 13 | this.elem = elem; 14 | this.val = val; 15 | } 16 | 17 | public K getElem() { 18 | return elem; 19 | } 20 | 21 | public double getValue() { 22 | return val; 23 | } 24 | 25 | public IValued withValue(double x) { 26 | return new BasicValued(elem, x); 27 | } 28 | 29 | public K getFirst() { 30 | return getElem(); 31 | } 32 | 33 | public Double getSecond() { 34 | return val; 35 | } 36 | 37 | 38 | public int compareTo(IValued o) { 39 | double v1 = this.getValue(); 40 | double v2 = o.getValue(); 41 | if (v1 < v2) return -1; 42 | if (v2 > v1) return 1; 43 | return 0; 44 | } 45 | 46 | public boolean equals(Object o) { 47 | if (this == o) return true; 48 | if (o == null || getClass() != o.getClass()) return false; 49 | 50 | BasicValued that = (BasicValued) o; 51 | 52 | if (Double.compare(that.val, val) != 0) return false; 53 | if (elem != null ? !elem.equals(that.elem) : that.elem != null) return false; 54 | 55 | return true; 56 | } 57 | 58 | public int hashCode() { 59 | int result; 60 | long temp; 61 | result = elem != null ? elem.hashCode() : 0; 62 | temp = val != +0.0d ? 
Double.doubleToLongBits(val) : 0L; 63 | result = 31 * result + (int) (temp ^ (temp >>> 32)); 64 | return result; 65 | } 66 | 67 | @Override 68 | public String toString() { 69 | return String.format("(%s,%.4f)", elem, val); 70 | } 71 | 72 | public static BasicValued make(K elem, double val) { 73 | return new BasicValued(elem, val); 74 | } 75 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/Collections.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | 4 | import edu.umass.nlp.functional.FactoryFn; 5 | 6 | import java.util.*; 7 | 8 | /** 9 | * Methods on Collections which are generally 10 | * non-destructive and functional in nature. 11 | * 12 | * @author aria42 (Aria Haghighi) 13 | */ 14 | public class Collections { 15 | 16 | public static List makeList(T... arr) { 17 | List res = new ArrayList(); 18 | for (T t : arr) { 19 | res.add(t); 20 | } 21 | return res; 22 | } 23 | 24 | public static List toList(T[] arr) { 25 | List res = new ArrayList(); 26 | for (T t : arr) { 27 | res.add(t); 28 | } 29 | return res; 30 | } 31 | 32 | public static List toList(Iterable arr) { 33 | List res = new ArrayList(); 34 | for (T t : arr) { 35 | res.add(t); 36 | } 37 | return res; 38 | } 39 | 40 | public static List toList(Iterator it) { 41 | List res = new ArrayList(); 42 | while (it.hasNext()) { 43 | T t = it.next(); 44 | res.add(t); 45 | } 46 | return res; 47 | } 48 | 49 | public static T randChoice(List elems, Random rand) { 50 | int index = rand.nextInt(elems.size()); 51 | return elems.get(index); 52 | } 53 | 54 | public static V getMut(Map map, K key, V notFound) { 55 | V val = map.get(key); 56 | if (val == null) { 57 | val = notFound; 58 | map.put(key,val); 59 | } 60 | return val; 61 | } 62 | 63 | public static V getMut(Map map, K key, FactoryFn fn) { 64 | V val = map.get(key); 65 | if (val == null) { 66 | val = fn.make(); 67 | 
map.put(key,val); 68 | } 69 | return val; 70 | } 71 | 72 | public static V get(Map map, K key, V notFound) { 73 | V val = map.get(key); 74 | if (val == null) { 75 | return notFound; 76 | } 77 | return val; 78 | } 79 | 80 | public static IPair,List> splitAt(Iterable elems, int index) { 81 | 82 | List before = new ArrayList(); 83 | List after = new ArrayList(); 84 | Iterator it = elems.iterator(); 85 | for (int i=0; it.hasNext(); ++i) { 86 | T elem = it.next(); 87 | (i < index ? before : after).add(elem); 88 | } 89 | return BasicPair.make(before, after); 90 | } 91 | 92 | public static T randChoice(List elems) { 93 | return randChoice(elems, new Random()); 94 | } 95 | 96 | public static List subList(List elems, List indices) { 97 | List res = new ArrayList(); 98 | for (Integer index: indices) { 99 | res.add(elems.get(index)); 100 | } 101 | return res; 102 | } 103 | 104 | public static List subList(List elems, Span span) { 105 | return subList(elems, span.getStart(), span.getStop()); 106 | } 107 | 108 | public static List subList(List elems, int start) { 109 | return subList(elems, start, elems.size()); 110 | } 111 | 112 | public static Set intersect(Iterable s1, Set s2) { 113 | Set res = new HashSet(); 114 | for (T elem : s1) { 115 | if (s2.contains(elem)) res.add(elem); 116 | } 117 | return res; 118 | } 119 | 120 | public static Set intersect(Iterable s1, Collection s2) { 121 | return intersect(s1, new HashSet(s2)); 122 | } 123 | 124 | /** 125 | * Partitions elems into numParts 126 | * each of which are the same size (except possibly the last) 127 | * 128 | * Shouldn't copy list just have views 129 | */ 130 | public static List> partition(List elems, int numParts) { 131 | List> res = new ArrayList>(); 132 | int sizeOfPart = (int) Math.ceil(((double) elems.size()) / numParts); 133 | for (int i=0; i < numParts; ++i) { 134 | int start = i*sizeOfPart; 135 | int stop = Math.min((i + 1) * sizeOfPart, elems.size()); 136 | res.add(Collections.subList(elems,start,stop)); 137 
| } 138 | return res; 139 | } 140 | 141 | /** 142 | * Make a set from varargs 143 | */ 144 | public static Set set(T...elems) { 145 | Set res = new HashSet(); 146 | for (T elem : elems) { 147 | res.add(elem); 148 | } 149 | return res; 150 | } 151 | 152 | /** 153 | * Much faster than java.util.List.subList, might be a little less safe. 154 | * 155 | * @author aria42 156 | */ 157 | public static List subList(List elems, int start, int stop) { 158 | class SubList extends AbstractList { 159 | List elems; 160 | int start; 161 | int stop; 162 | 163 | SubList(final List elems, final int start, final int stop) { 164 | if (elems instanceof SubList) { 165 | this.elems = ((SubList)elems).elems; 166 | this.start = ((SubList)elems).start + start; 167 | this.stop = ((SubList)elems).start + stop; 168 | } else { 169 | this.elems = elems; 170 | this.start = start; 171 | this.stop = stop; 172 | } 173 | } 174 | 175 | public T get(int index) { 176 | return elems.get(index+start); 177 | } 178 | 179 | public int size() { 180 | return (stop-start); 181 | } 182 | } 183 | return new SubList(elems,start,stop); 184 | } 185 | 186 | public static List concat(List...items) { 187 | List res = new ArrayList(); 188 | for (List item : items) { 189 | res.addAll(item); 190 | } 191 | return res; 192 | } 193 | 194 | public static List take(Iterable elems, int n) { 195 | List res = new ArrayList(); 196 | for (T elem: elems) { 197 | res.add(elem); 198 | if (res.size() == n) break; 199 | } 200 | return res; 201 | } 202 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/CounterMap.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | 6 | public class CounterMap { 7 | 8 | private static final ICounter zero = new MapCounter(); 9 | 10 | public static void incCount(Map> counters, K key, V innerKey, double count) { 
11 | Collections.getMut(counters, key, new MapCounter()).incCount(innerKey, count); 12 | } 13 | 14 | public static void setCount(Map> counters, K key, V innerKey, double count) { 15 | Collections.getMut(counters, key, new MapCounter()).setCount(innerKey, count); 16 | } 17 | 18 | public static void getCount(Map> counters, K key, V innerKey) { 19 | Collections.get(counters, key, (ICounter)zero).getCount(innerKey); 20 | } 21 | 22 | public static Map> make() { 23 | return new HashMap>(); 24 | } 25 | 26 | public static void main(String[] args) { 27 | Map> counts = CounterMap.make(); 28 | CounterMap.incCount(counts, "a","b",1.0); 29 | System.out.println(counts); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/Counters.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | 4 | import edu.umass.nlp.functional.*; 5 | import edu.umass.nlp.ml.prob.DirichletMultinomial; 6 | 7 | import java.util.Comparator; 8 | import java.util.HashSet; 9 | import java.util.List; 10 | import java.util.Set; 11 | 12 | public class Counters { 13 | 14 | public static ICounter scale(ICounter counts, final double alpha) { 15 | return counts.map(new DoubleFn>() { 16 | public double valAt(IValued x) { 17 | return x.getValue() * alpha; 18 | } 19 | }); 20 | } 21 | 22 | public static void scaleDestructive(ICounter counts, final double alpha) { 23 | counts.mapDestructive(new DoubleFn>() { 24 | public double valAt(IValued x) { 25 | return x.getValue() * alpha; 26 | } 27 | }); 28 | } 29 | 30 | public static List> getTopK(ICounter counts, int k) { 31 | return getTopK(counts, k, null); 32 | } 33 | 34 | public static List> getTopK(ICounter counts, int k, final DoubleFn> f) { 35 | List> vals = Collections.toList(counts.iterator()); 36 | if (f != null) { 37 | vals = Functional.map(vals, new Fn, IValued>() { 38 | public IValued apply(IValued input) { 39 | 
return BasicValued.make(input.getElem(), f.valAt(input)); 40 | } 41 | }); 42 | } 43 | java.util.Collections.sort(vals, new Comparator>() { 44 | public int compare(IValued o1, IValued o2) { 45 | if (o2.getValue() > o1.getValue()) return 1; 46 | if (o2.getValue() < o1.getValue()) return -1; 47 | return 0; 48 | } 49 | }); 50 | return k < vals.size() ? vals.subList(0, k) : vals; 51 | } 52 | 53 | public static void incAll(ICounter accum, ICounter counts) { 54 | for (IValued elem : counts) { 55 | accum.incCount(elem.getElem(), elem.getValue()); 56 | } 57 | } 58 | 59 | public static Set getKeySet(ICounter counter, PredFn> prefFn) { 60 | Set res = new HashSet(); 61 | for (IValued valued : counter) { 62 | if (prefFn.holdsAt(valued)) res.add(valued.getElem()); 63 | } 64 | return res; 65 | } 66 | 67 | public static Set getKeySet(ICounter counter) { 68 | return getKeySet(counter, PredFns.>getTruePredicate()); 69 | } 70 | 71 | public static ICounter from(Iterable> values) { 72 | ICounter counts = new MapCounter(); 73 | for (IValued value : values) { 74 | counts.incCount(value.getElem(), value.getValue()); 75 | } 76 | return counts; 77 | } 78 | 79 | public static ICounter from(double[] vals, List list) { 80 | ICounter counts = new MapCounter(); 81 | for (int i = 0; i < vals.length; i++) { 82 | counts.incCount(list.get(i),vals[i]); 83 | } 84 | return counts; 85 | } 86 | 87 | public static void incAll(ICounter counts, Iterable elems) { 88 | for (T elem : elems) { 89 | counts.incCount(elem,1.0); 90 | } 91 | } 92 | 93 | 94 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/DoubleArrays.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | import java.util.Arrays; 4 | import java.util.Random; 5 | 6 | /** 7 | */ 8 | public class DoubleArrays { 9 | 10 | public static double[] clone(double[] x) { 11 | return clone(x, 0, x.length); 12 | } 13 | 14 
| public static double[] clone(double[] x, int start, int stop) { 15 | double[] y = new double[stop-start]; 16 | System.arraycopy(x,start,y,0,stop-start); 17 | return y; 18 | } 19 | 20 | public static double innerProduct(double[] x, double[] y) { 21 | if (x.length != y.length) 22 | throw new RuntimeException("diff lengths: " + x.length + " " 23 | + y.length); 24 | double result = 0.0; 25 | for (int i = 0; i < x.length; i++) { 26 | result += x[i] * y[i]; 27 | } 28 | return result; 29 | } 30 | 31 | public static double[] addMultiples(double[] x, double xMultiplier, 32 | double[] y, double yMuliplier) { 33 | if (x.length != y.length) 34 | throw new RuntimeException("diff lengths: " + x.length + " " 35 | + y.length); 36 | double[] z = new double[x.length]; 37 | for (int i = 0; i < z.length; i++) { 38 | z[i] = x[i] * xMultiplier + y[i] * yMuliplier; 39 | } 40 | return z; 41 | } 42 | 43 | public static double[] constantArray(double c, int length) { 44 | double[] x = new double[length]; 45 | Arrays.fill(x, c); 46 | return x; 47 | } 48 | 49 | public static double[] pointwiseMultiply(double[] x, double[] y) { 50 | if (x.length != y.length) 51 | throw new RuntimeException("diff lengths: " + x.length + " " 52 | + y.length); 53 | double[] z = new double[x.length]; 54 | for (int i = 0; i < z.length; i++) { 55 | z[i] = x[i] * y[i]; 56 | } 57 | return z; 58 | } 59 | 60 | public static boolean probNormInPlace(double[] x) { 61 | double sum = sum(x); 62 | if (sum <= 0.0) return false; 63 | scaleInPlace(x, 1.0 / sum); 64 | return true; 65 | } 66 | 67 | public static double[] uniformDraw(int n, Random rand) { 68 | double[] x = new double[n]; 69 | for (int i=0; i < n; ++i) { 70 | x[i] = rand.nextDouble(); 71 | } 72 | DoubleArrays.probNormInPlace(x); 73 | return x; 74 | } 75 | 76 | public static String toString(double[] x) { 77 | return toString(x, x.length); 78 | } 79 | 80 | public static String toString(double[][] x) { 81 | StringBuilder sb = new StringBuilder(); 82 | for (double[] 
row : x) { 83 | sb.append(DoubleArrays.toString(row)); 84 | sb.append("\n"); 85 | } 86 | return sb.toString(); 87 | } 88 | 89 | public static String toString(double[] x, int length) { 90 | StringBuffer sb = new StringBuffer(); 91 | sb.append("["); 92 | for (int i = 0; i < SloppyMath.min(x.length, length); i++) { 93 | sb.append(String.format("%.5f", x[i])); 94 | if (i + 1 < SloppyMath.min(x.length, length)) sb.append(", "); 95 | } 96 | sb.append("]"); 97 | return sb.toString(); 98 | } 99 | 100 | public static void scaleInPlace(double[] x, double s) { 101 | if (s == 1.0) return; 102 | for (int i = 0; i < x.length; i++) { 103 | x[i] *= s; 104 | } 105 | } 106 | 107 | public static double[] scale(double[] x, double s) { 108 | double[] res = DoubleArrays.clone(x); 109 | scaleInPlace(res,s); 110 | return res; 111 | 112 | } 113 | 114 | public static int argMax(double[] v) { 115 | int maxI = -1; 116 | double maxV = Double.NEGATIVE_INFINITY; 117 | for (int i = 0; i < v.length; i++) { 118 | if (v[i] > maxV) { 119 | maxV = v[i]; 120 | maxI = i; 121 | } 122 | } 123 | return maxI; 124 | } 125 | 126 | public static double max(double[] v) { 127 | double maxV = Double.NEGATIVE_INFINITY; 128 | for (int i = 0; i < v.length; i++) { 129 | if (v[i] > maxV) { 130 | maxV = v[i]; 131 | } 132 | } 133 | return maxV; 134 | } 135 | 136 | public static double max(double[][] m) { 137 | double max = Double.NEGATIVE_INFINITY; 138 | for (double[] row : m) { 139 | max = Math.max(max(row), max); 140 | } 141 | return max; 142 | } 143 | 144 | public static int argMin(double[] v) { 145 | int minI = -1; 146 | double minV = Double.POSITIVE_INFINITY; 147 | for (int i = 0; i < v.length; i++) { 148 | if (v[i] < minV) { 149 | minV = v[i]; 150 | minI = i; 151 | } 152 | } 153 | return minI; 154 | } 155 | 156 | public static double min(double[] v) { 157 | double minV = Double.POSITIVE_INFINITY; 158 | for (int i = 0; i < v.length; i++) { 159 | if (v[i] < minV) { 160 | minV = v[i]; 161 | } 162 | } 163 | return 
minV; 164 | } 165 | 166 | public static double min(double[][] m) { 167 | double min = Double.POSITIVE_INFINITY; 168 | for (double[] row : m) { 169 | min = Math.min(min(row), min); 170 | } 171 | return min; 172 | } 173 | 174 | public static double maxAbs(double[] v) { 175 | double maxV = 0; 176 | for (int i = 0; i < v.length; i++) { 177 | double abs = (v[i] <= 0.0d) ? 0.0d - v[i] : v[i]; 178 | if (abs > maxV) { 179 | maxV = abs; 180 | } 181 | } 182 | return maxV; 183 | } 184 | 185 | public static double[] add(double[] a, double b) { 186 | double[] result = new double[a.length]; 187 | for (int i = 0; i < a.length; i++) { 188 | double v = a[i]; 189 | result[i] = v + b; 190 | } 191 | return result; 192 | } 193 | 194 | public static double sum(double[] a) { 195 | double sum = 0.0; 196 | for (int i = 0; i < a.length; i++) { 197 | sum += a[i]; 198 | } 199 | return sum; 200 | } 201 | 202 | 203 | public static double add(double[] a, int first, int last) { 204 | if (last >= a.length) 205 | throw new RuntimeException("last beyond end of array"); 206 | if (first < 0) throw new RuntimeException("first must be at least 0"); 207 | double sum = 0.0; 208 | for (int i = first; i <= last; i++) { 209 | sum += a[i]; 210 | } 211 | return sum; 212 | } 213 | 214 | public static double vectorLength(double[] x) { 215 | return Math.sqrt(innerProduct(x, x)); 216 | } 217 | 218 | public static double[] add(double[] x, double[] y) { 219 | if (x.length != y.length) 220 | throw new RuntimeException("diff lengths: " + x.length + " " 221 | + y.length); 222 | double[] result = new double[x.length]; 223 | for (int i = 0; i < x.length; i++) { 224 | result[i] = x[i] + y[i]; 225 | } 226 | return result; 227 | } 228 | 229 | public static void subtractInPlace(double[] x, double[] y) { 230 | // be in cvs 231 | for (int i = 0; i < x.length; ++i) { 232 | x[i] -= y[i]; 233 | } 234 | } 235 | 236 | /** 237 | * If a subtraction results in NaN (i.e -inf - (-inf)) 238 | * does not perform the computation. 
239 | * 240 | * @param x 241 | * @param y 242 | */ 243 | public static void subtractInPlaceUnsafe(double[] x, double[] y) { 244 | // be in cvs 245 | for (int i = 0; i < x.length; ++i) { 246 | if (Double.isNaN(x[i] - y[i])) { 247 | continue; 248 | } 249 | x[i] -= y[i]; 250 | } 251 | } 252 | 253 | public static double[] subtract(double[] x, double[] y) { 254 | if (x.length != y.length) 255 | throw new RuntimeException("diff lengths: " + x.length + " " 256 | + y.length); 257 | double[] result = new double[x.length]; 258 | for (int i = 0; i < x.length; i++) { 259 | result[i] = x[i] - y[i]; 260 | } 261 | return result; 262 | } 263 | 264 | public static double[] exponentiate(double[] pUnexponentiated) { 265 | double[] exponentiated = new double[pUnexponentiated.length]; 266 | for (int index = 0; index < pUnexponentiated.length; index++) { 267 | exponentiated[index] = SloppyMath.exp(pUnexponentiated[index]); 268 | } 269 | return exponentiated; 270 | } 271 | 272 | public static double[][] exponentiate(double[][] pUnexponentiated) { 273 | double[][] exponentiated = new double[pUnexponentiated.length][]; 274 | for (int index = 0; index < pUnexponentiated.length; index++) { 275 | exponentiated[index] = exponentiate(pUnexponentiated[index]); 276 | } 277 | return exponentiated; 278 | } 279 | 280 | public static void truncateInPlace(double[] x, double maxVal) { 281 | for (int index = 0; index < x.length; index++) { 282 | if (x[index] > maxVal) x[index] = maxVal; 283 | else if (x[index] < -maxVal) x[index] = -maxVal; 284 | } 285 | } 286 | 287 | public static void addInPlace(double[] x, double c) { 288 | for (int i = 0; i < x.length; i++) { 289 | x[i] += c; 290 | } 291 | } 292 | 293 | public static void addInPlace(double[] x, double[] y) { 294 | assert y.length >= x.length; 295 | for (int i = 0; i < x.length; ++i) { 296 | x[i] += y[i]; 297 | } 298 | } 299 | 300 | public static void addInPlace2D(double[][] x, double[][] y) { 301 | // TODO Auto-generated method stub 302 | assert 
y.length >= x.length; 303 | for (int i = 0; i < x.length; ++i) { 304 | DoubleArrays.addInPlace(x[i], y[i]); 305 | } 306 | } 307 | 308 | public static void multiplyInPlace(double[] x, double[] y) { 309 | for (int i = 0; i < x.length; i++) { 310 | x[i] *= y[i]; 311 | } 312 | } 313 | 314 | public static double[] average(double[][] x) { 315 | if (x.length == 0) { 316 | return null; 317 | } 318 | double[] sum = x[0]; 319 | for (int i = 1; i < x.length; i++) { 320 | sum = add(sum, x[i]); 321 | } 322 | double[] avg = scale(sum, (1.0 / x.length)); 323 | return avg; 324 | } 325 | 326 | public static void checkNonNegative(double[] x) { 327 | for (double v : x) { 328 | if (v < -1.0e-10) { 329 | throw new RuntimeException("Negative number " + v); 330 | } 331 | } 332 | } 333 | 334 | public static void checkNonNegative(double[][] m) { 335 | for (double[] row : m) { 336 | checkNonNegative(row); 337 | } 338 | } 339 | 340 | public static void checkValid(double[] x) { 341 | for (double v : x) { 342 | if (Double.isNaN(v)) { 343 | throw new RuntimeException("Invalid entry " + v); 344 | } 345 | } 346 | } 347 | 348 | public static void checkValid(double[][] m) { 349 | for (double[] row : m) { 350 | checkValid(row); 351 | } 352 | } 353 | 354 | public static double lInfinityDist(double[] x, double[] y) { 355 | double max = Double.NEGATIVE_INFINITY; 356 | for (int i = 0; i < x.length; i++) { 357 | max = Math.max(max, Math.abs(x[i] - y[i])); 358 | } 359 | return max; 360 | } 361 | 362 | public static void logInPlace(double[] vec) { 363 | for (int i = 0; i < vec.length; i++) { 364 | vec[i] = Math.log(vec[i]); 365 | } 366 | } 367 | 368 | public static void checkNonInfinite(double[] vec) { 369 | for (double v : vec) { 370 | if (Double.isInfinite(v)) { 371 | throw new RuntimeException("Invalid Entry: " + v); 372 | } 373 | } 374 | } 375 | 376 | public static void checkNonInfinite(double[][] m) { 377 | for (double[] row : m) { 378 | checkNonInfinite(row); 379 | } 380 | } 381 | 382 | public static 
void addInPlace(double[] a, double[] b, double c) { 383 | for (int i = 0; i < a.length; i++) { 384 | a[i] += b[i] * c; 385 | } 386 | } 387 | 388 | public static void addInPlace(double[][] a, double[][] b, double c) { 389 | for (int i = 0; i < a.length; i++) { 390 | addInPlace(a[i], b[i], c); 391 | } 392 | } 393 | 394 | public static double outerProduct(double[][] M, double[] x) { 395 | double sum = 0.0; 396 | for (int i = 0; i < M.length; i++) { 397 | for (int j = 0; j < M[i].length; j++) { 398 | sum += M[i][j] * x[i] * x[j]; 399 | } 400 | } 401 | return sum; 402 | } 403 | 404 | public static int sample(double[] arr, Random r) { 405 | double sum = DoubleArrays.sum(arr); 406 | assert sum > 0.0; 407 | double goal = r.nextDouble(); 408 | double massSoFar = 0.0; 409 | for (int i = 0; i < arr.length; i++) { 410 | double x = arr[i]; 411 | assert x >= 0.0; 412 | double probSoFar = massSoFar / sum; 413 | double probNext = (massSoFar + x) / sum; 414 | if (goal >= probSoFar && goal <= probNext) { 415 | return i; 416 | } 417 | massSoFar += x; 418 | } 419 | throw new RuntimeException(); 420 | } 421 | 422 | public static void exponentiateInPlace(double[] arr) { 423 | for (int i = 0; i < arr.length; i++) { 424 | arr[i] = Math.exp(arr[i]); 425 | } 426 | } 427 | 428 | 429 | public static void addNoiseInPlace(double[] row, Random rand, double noiseLevel) { 430 | for (int i = 0; i < row.length; i++) { 431 | double v = row[i]; 432 | if (v > 0.0) { 433 | row[i] += rand.nextDouble() * noiseLevel; 434 | } 435 | } 436 | DoubleArrays.probNormInPlace(row); 437 | } 438 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/ICounter.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | import edu.umass.nlp.functional.DoubleFn; 4 | 5 | import java.util.Collection; 6 | 7 | 8 | public interface ICounter extends Collection> { 9 | 10 | public void incCount(K 
/**
 * Abstraction for an object with a locking state (boo!).
 * The behaviour of an object may change when it is locked.
 * A canonical example is a collection that should only accept
 * additions for a limited period of time.
 *
 * All implementations should issue a Logger.warn or crash if you try
 * to lock an already locked object. Safety first.
 * @author aria42
 */
public interface ILockable {
  /** Reports the locking state of this object. */
  public boolean isLocked();
  /** Requests that this object enter its locked state; see the class comment for re-locking. */
  public void lock();
}
7 | * 8 | * We don't just use IPair to avoid double auto-boxing 9 | * @param 10 | */ 11 | public interface IValued extends IPair, Comparable>, IWrapper { 12 | 13 | public double getValue(); 14 | public IValued withValue(double x); 15 | 16 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/IWrapper.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | public interface IWrapper { 4 | public T getElem(); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/Indexer.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | import java.io.Serializable; 4 | import java.util.*; 5 | 6 | public class Indexer extends AbstractList implements Serializable, ILockable { 7 | 8 | private final List elems = new ArrayList(); 9 | private final Map indexMap = new HashMap(); 10 | private boolean locked = false; 11 | 12 | public class IndexedWrapper implements IIndexed { 13 | private final L elem ; 14 | 15 | public L getElem() { 16 | return elem; 17 | } 18 | 19 | public IndexedWrapper(L elem) { 20 | this.elem = elem; 21 | } 22 | 23 | public int getIndex() { 24 | return indexMap.get(elem); 25 | } 26 | } 27 | 28 | public IIndexed getIndexed(L elem) { 29 | return new IndexedWrapper(elem); 30 | } 31 | 32 | public boolean isLocked() { 33 | return locked; 34 | } 35 | 36 | public void lock() { 37 | if (locked) { 38 | throw new RuntimeException("Tryed to lock() a locked Indexer"); 39 | } 40 | locked = true; 41 | } 42 | 43 | public int indexOf(Object elem) { 44 | Integer i = indexMap.get(elem); 45 | return i == null ? 
-1 : i; 46 | } 47 | 48 | @Override 49 | public boolean add(L elem) { 50 | if (isLocked()) { 51 | throw new RuntimeException("Tryed to sum() to a locked Indexer"); 52 | } 53 | Integer index = indexMap.get(elem); 54 | if (index != null) { 55 | return false; 56 | } 57 | elems.add(elem); 58 | indexMap.put(elem, elems.size() - 1); 59 | return true; 60 | } 61 | 62 | public int getIndex(L elem) { 63 | Integer index = indexMap.get(elem); 64 | return index != null ? index.intValue() : -1; 65 | } 66 | 67 | public L get(int index) { 68 | return elems.get(index); 69 | } 70 | 71 | public int size() { 72 | return elems.size(); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/LogAdder.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | import org.apache.commons.collections.primitives.ArrayDoubleList; 4 | import org.apache.commons.collections.primitives.DoubleList; 5 | 6 | public class LogAdder { 7 | 8 | DoubleList xs = new ArrayDoubleList(); 9 | 10 | public void add(double x) { 11 | if (x > Double.NEGATIVE_INFINITY) xs.add(x); 12 | } 13 | 14 | public double logSum() { 15 | return SloppyMath.logAdd(xs.toArray()); 16 | } 17 | 18 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/MapCounter.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | 4 | import edu.umass.nlp.functional.DoubleFn; 5 | import edu.umass.nlp.functional.Fn; 6 | import edu.umass.nlp.functional.Functional; 7 | 8 | import java.util.*; 9 | 10 | public class MapCounter extends AbstractCollection> implements ICounter { 11 | 12 | private final static MutableDouble zero_ = new MutableDouble(0.0); 13 | 14 | private final Map counts_; 15 | private final MutableDouble totalCount_; 16 | 17 | private MapCounter(Map counts_, 
MutableDouble totalCount_) { 18 | this.counts_ = counts_; 19 | this.totalCount_ = totalCount_; 20 | } 21 | 22 | public MapCounter(Map counts_) { 23 | this.counts_ = counts_; 24 | this.totalCount_ = new MutableDouble(computeTotalCount()); 25 | } 26 | 27 | public MapCounter() { 28 | this(new HashMap()); 29 | } 30 | 31 | private double computeTotalCount() { 32 | double sum = 0.0; 33 | for (Map.Entry entry : counts_.entrySet()) { 34 | sum += entry.getValue().doubleValue(); 35 | } 36 | return sum; 37 | } 38 | 39 | public double getCount(K key) { 40 | return Collections.get(counts_, key, zero_).doubleValue(); 41 | } 42 | 43 | public void incCount(K key, double incAmt) { 44 | MutableDouble oldVal = Collections.getMut(counts_, key, new MutableDouble(0.0)); 45 | double newVal = oldVal.doubleValue() + incAmt; 46 | if (newVal == 0.0) { 47 | counts_.remove(key); 48 | } else { 49 | oldVal.inc(incAmt); 50 | } 51 | totalCount_.inc(incAmt); 52 | } 53 | 54 | public void setCount(K key, double v) { 55 | if (v == 0.0) { 56 | MutableDouble oldVal = Collections.get(counts_, key, zero_); 57 | totalCount_.inc(v - oldVal.doubleValue()); 58 | counts_.remove(key); 59 | } else { 60 | MutableDouble oldVal = Collections.getMut(counts_, key, new MutableDouble(0.0)); 61 | totalCount_.inc(v - oldVal.doubleValue()); 62 | oldVal.set(v); 63 | } 64 | } 65 | 66 | public double totalCount() { 67 | return totalCount_.doubleValue(); 68 | } 69 | 70 | public ICounter map(DoubleFn> f) { 71 | Map newCounts = new HashMap(); 72 | double newTotalCount = 0.0; 73 | for (IValued valued : this) { 74 | double newVal = f.valAt(valued); 75 | newCounts.put(valued.getElem(), new MutableDouble(newVal)); 76 | newTotalCount += newVal; 77 | } 78 | return new MapCounter(newCounts, new MutableDouble(newTotalCount)); 79 | } 80 | 81 | public void mapDestructive(DoubleFn> f) { 82 | double newTotalCount = 0.0; 83 | for (Map.Entry entry : counts_.entrySet()) { 84 | double oldVal = entry.getValue().doubleValue(); 85 | double newVal 
= f.valAt(BasicValued.make(entry.getKey(), oldVal)); 86 | entry.getValue().set(newVal); 87 | newTotalCount += newVal ; 88 | } 89 | totalCount_.set(newTotalCount); 90 | } 91 | 92 | public Iterator> iterator() { 93 | return Functional.map(counts_.entrySet().iterator(), new Fn, IValued>() { 94 | public IValued apply(Map.Entry input) { 95 | return BasicValued.make(input.getKey(), input.getValue().doubleValue()); 96 | } 97 | }); 98 | } 99 | 100 | public int size() { 101 | return counts_.size(); 102 | } 103 | 104 | public boolean add(IValued valued) { 105 | incCount(valued.getElem(), valued.getValue()); 106 | return true; 107 | } 108 | 109 | 110 | @Override 111 | public String toString() { 112 | return toString(25); 113 | } 114 | 115 | public String toString(int maxKeys) { 116 | List> vals = Counters.getTopK(this, maxKeys, null); 117 | return Functional.mkString( 118 | vals, // elems 119 | "[", // start 120 | ",", // middle 121 | "]", // stop 122 | new Fn, String>() { 123 | public String apply(IValued input) { 124 | return String.format("%s: %.4f\n", input.getElem(), input.getValue()); 125 | } 126 | }); 127 | } 128 | 129 | 130 | 131 | 132 | public static void main(String[] args) { 133 | ICounter counts = new MapCounter(); 134 | System.out.println(counts); 135 | counts.incCount("planets", 7); 136 | System.out.println(counts); 137 | counts.incCount("planets", 1); 138 | System.out.println(counts); 139 | counts.setCount("suns", 1); 140 | System.out.println(counts); 141 | counts.incCount("aliens", 0); 142 | counts.add(BasicValued.make("aria",42.0)); 143 | System.out.println(counts.toString(1)); 144 | System.out.println(Counters.getTopK(counts,1,null)); 145 | Counters.scaleDestructive(counts,2.0); 146 | System.out.println(counts.toString()); 147 | System.out.println("Total: " + counts.totalCount()); 148 | 149 | } 150 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/Maxer.java: 
-------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | public class Maxer { 4 | public double max = Double.NEGATIVE_INFINITY; 5 | public L argMax = null; 6 | public void observe(L elem, double val) { 7 | if (val > max) { 8 | argMax = elem; 9 | max = val; 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/MergableUtils.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | 6 | public class MergableUtils { 7 | public static > 8 | void mergeInto(Map map, 9 | Map omap) 10 | { 11 | Map res = new HashMap(); 12 | for (Map.Entry entry : map.entrySet()) { 13 | K k = entry.getKey(); 14 | M v = entry.getValue(); 15 | M ov = omap.get(k); 16 | if (ov != null) v.merge(ov); 17 | } 18 | for (Map.Entry entry : omap.entrySet()) { 19 | K k = entry.getKey(); 20 | M v = map.get(k); 21 | if (v == null) { 22 | map.put(k, v); 23 | } 24 | } 25 | } 26 | } -------------------------------------------------------------------------------- /src/main/java/edu/umass/nlp/utils/MutableDouble.java: -------------------------------------------------------------------------------- 1 | package edu.umass.nlp.utils; 2 | 3 | /** 4 | * A class for Double objects that you can change. 5 | * 6 | * @author Dan Klein 7 | */ 8 | public final class MutableDouble extends Number implements Comparable { 9 | 10 | private double d; 11 | 12 | // Mutable 13 | public void set(double d) { 14 | this.d = d; 15 | } 16 | 17 | public int hashCode() { 18 | long bits = Double.doubleToLongBits(d); 19 | return (int) (bits ^ (bits >>> 32)); 20 | } 21 | 22 | /** 23 | * Compares this object to the specified object. 
The result is 24 | * true if and only if the argument is not 25 | * null and is an MutableDouble object that 26 | * contains the same double value as this object. 27 | * Note that a MutableDouble isn't and can't be equal to an Double. 28 | * 29 | * @param obj the object to compare with. 30 | * @return true if the objects are the same; 31 | * false otherwise. 32 | */ 33 | public boolean equals(Object obj) { 34 | if (obj instanceof MutableDouble) { 35 | return d == ((MutableDouble) obj).d; 36 | } 37 | return false; 38 | } 39 | 40 | public String toString() { 41 | return Double.toString(d); 42 | } 43 | 44 | // Comparable interface 45 | 46 | /** 47 | * Compares two MutableDouble objects numerically. 48 | * 49 | * @param anotherMutableDouble the MutableDouble to be 50 | * compared. 51 | * @return Tthe value 0 if this MutableDouble is 52 | * equal to the argument MutableDouble; a value less than 53 | * 0 if this MutableDouble is numerically less 54 | * than the argument MutableDouble; and a value greater 55 | * than 0 if this MutableDouble is numerically 56 | * greater than the argument MutableDouble (signed 57 | * comparison). 58 | */ 59 | public int compareTo(MutableDouble anotherMutableDouble) { 60 | double thisVal = this.d; 61 | double anotherVal = anotherMutableDouble.d; 62 | return (thisVal < anotherVal ? -1 : (thisVal == anotherVal ? 0 : 1)); 63 | } 64 | 65 | /** 66 | * Compares this MutableDouble object to another object. 67 | * If the object is an MutableDouble, this function behaves 68 | * like compareTo(MutableDouble). Otherwise, it throws a 69 | * ClassCastException (as MutableDouble 70 | * objects are only comparable to other MutableDouble 71 | * objects). 72 | * 73 | * @param o the Object to be compared. 74 | * @return 0/-1/1 75 | * @throws ClassCastException 76 | * if the argument is not an 77 | * MutableDouble. 
/**
 * Half-open integer interval {@code [start, stop)}.
 *
 * <p>The ordering {@code stop >= start} is guarded by an {@code assert} only, so it is
 * enforced only when assertions are enabled.
 */
public class Span {

  private int start, stop;

  public Span(int start, int stop) {
    setAndEnsure(start, stop);
  }

  /** Writes both bounds after asserting that they are ordered. */
  private void setAndEnsure(int start, int stop) {
    assert stop >= start;
    this.start = start;
    this.stop = stop;
  }

  public int getStart() {
    return start;
  }

  public int getStop() {
    return stop;
  }

  public void setStart(int start) {
    setAndEnsure(start, this.stop);
  }

  public void setStop(int stop) {
    setAndEnsure(this.start, stop);
  }

  /** Number of positions covered, i.e. {@code stop - start}. */
  public int getLength() {
    return stop - start;
  }

  /** Returns a new span with both bounds moved by {@code offset}; this span is unchanged. */
  public Span shift(int offset) {
    return new Span(start + offset, stop + offset);
  }

  /** True when position {@code i} lies in {@code [start, stop)}. */
  public boolean contains(int i) {
    return start <= i && i < stop;
  }

  /** True when {@code other} lies entirely within this span (bounds inclusive). */
  public boolean contains(Span other) {
    return other.start >= this.start && other.stop <= this.stop;
  }

  public String toString() {
    return String.format("(%d,%d)", start, stop);
  }

  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (o == null || o.getClass() != getClass()) {
      return false;
    }
    Span that = (Span) o;
    return start == that.start && stop == that.stop;
  }

  public int hashCode() {
    return 31 * start + stop;
  }
}
logger.debug("Debug"); 18 | logger.trace("ACK"); 19 | logger.info(opts.doStuff); 20 | } 21 | } 22 | --------------------------------------------------------------------------------