├── test └── ciir │ └── umass │ └── edu │ ├── SimpleTest.java │ └── eval │ └── EvaluatorTest.java ├── src └── ciir │ └── umass │ └── edu │ ├── stats │ ├── SignificanceTest.java │ ├── BasicStats.java │ └── RandomPermutationTest.java │ ├── utilities │ ├── WorkerThread.java │ ├── TmpFile.java │ ├── RankLibError.java │ ├── KeyValuePair.java │ ├── SimpleMath.java │ ├── MyThreadPool.java │ ├── MergeSorter.java │ ├── ExpressionEvaluator.java │ └── FileUtils.java │ ├── metric │ ├── METRIC.java │ ├── MetricScorer.java │ ├── MetricScorerFactory.java │ ├── PrecisionScorer.java │ ├── ReciprocalRankScorer.java │ ├── ERRScorer.java │ ├── BestAtKScorer.java │ ├── DCGScorer.java │ ├── APScorer.java │ └── NDCGScorer.java │ ├── learning │ ├── RANKER_TYPE.java │ ├── neuralnet │ │ ├── TransferFunction.java │ │ ├── LogiFunction.java │ │ ├── HyperTangentFunction.java │ │ ├── PropParameter.java │ │ ├── ListNeuron.java │ │ ├── Synapse.java │ │ ├── Layer.java │ │ ├── LambdaRank.java │ │ ├── Neuron.java │ │ └── ListNet.java │ ├── boosting │ │ ├── RBWeakRanker.java │ │ ├── WeakRanker.java │ │ └── AdaRank.java │ ├── DenseDataPoint.java │ ├── Combiner.java │ ├── tree │ │ ├── MART.java │ │ ├── RegressionTree.java │ │ ├── Ensemble.java │ │ ├── Split.java │ │ └── RFRanker.java │ ├── RankerTrainer.java │ ├── Sampler.java │ ├── RankList.java │ ├── SparseDataPoint.java │ ├── RankerFactory.java │ ├── Ranker.java │ ├── DataPoint.java │ └── LinearRegRank.java │ ├── features │ ├── Normalizer.java │ ├── SumNormalizor.java │ ├── LinearNormalizer.java │ └── ZScoreNormalizor.java │ └── eval │ └── Analyzer.java ├── README.md ├── LICENSE.txt └── pom.xml /test/ciir/umass/edu/SimpleTest.java: -------------------------------------------------------------------------------- 1 | package ciir.umass.edu; 2 | 3 | import org.junit.Test; 4 | 5 | import static org.junit.Assert.assertEquals; 6 | 7 | /** 8 | * @author jfoley. 9 | */ 10 | public class SimpleTest { 11 | 12 | @Test 13 | public void testSomething() { 14 | assertEquals(4, 4); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/stats/SignificanceTest.java: -------------------------------------------------------------------------------- 1 | package ciir.umass.edu.stats; 2 | 3 | import java.util.HashMap; 4 | 5 | public class SignificanceTest { 6 | 7 | public double test(HashMap target, HashMap baseline) 8 | { 9 | return 0; 10 | } 11 | protected void makeRCall() 12 | { 13 | 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/utilities/WorkerThread.java: -------------------------------------------------------------------------------- 1 | package ciir.umass.edu.utilities; 2 | 3 | public abstract class WorkerThread implements Runnable { 4 | protected int start = -1; 5 | protected int end = -1; 6 | public void set(int start, int end) 7 | { 8 | this.start = start; 9 | this.end = end; 10 | } 11 | public abstract WorkerThread clone(); 12 | } 13 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/stats/BasicStats.java: -------------------------------------------------------------------------------- 1 | package ciir.umass.edu.stats; 2 | 3 | public class BasicStats { 4 | public static double mean(double[] values) 5 | { 6 | double mean = 0.0; 7 | if(values.length == 0) 8 | { 9 | System.out.println("Error in BasicStats::mean(): Empty input array."); 10 | System.exit(1); 11 | } 12 | for(int i=0;i threshold) 31 | return 1; 32 | return 0; 33 | } 34 | public int getFid() 35 | { 36 | return fid; 37 | } 38 | public double getThreshold() 39 | { 40 | return threshold; 41 | } 42 | public String toString() 43 | { 44 | return fid + ":" + threshold; 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/neuralnet/PropParameter.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning.neuralnet; 11 | 12 | public class PropParameter { 13 | //RankNet 14 | public int current = -1;//index of current data point in the ranked list 15 | public int[][] pairMap = null; 16 | public PropParameter(int current, int[][] pairMap) 17 | { 18 | this.current = current; 19 | this.pairMap = pairMap; 20 | } 21 | //LambdaRank: RankNet + the following 22 | public float[][] pairWeight = null; 23 | public float[][] targetValue = null; 24 | public PropParameter(int current, int[][] pairMap, float[][] pairWeight, float[][] targetValue) 25 | { 26 | this.current = current; 27 | this.pairMap = pairMap; 28 | this.pairWeight = pairWeight; 29 | this.targetValue = targetValue; 30 | } 31 | //ListNet 32 | public float[] labels = null;//relevance label 33 | public PropParameter(float[] labels) 34 | { 35 | this.labels = labels; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/DenseDataPoint.java: -------------------------------------------------------------------------------- 1 | package ciir.umass.edu.learning; 2 | 3 | import ciir.umass.edu.utilities.RankLibError; 4 | 5 | public class DenseDataPoint extends DataPoint { 6 | 7 | public DenseDataPoint(String text) { 8 | super(text); 9 | } 10 | 11 | public DenseDataPoint(DenseDataPoint dp) 12 | { 13 | label = dp.label; // relevance label 14 | id = dp.id; // query id 15 | description = dp.description; 16 | cached = dp.cached; 17 | fVals = new float[dp.fVals.length]; 18 | System.arraycopy(dp.fVals, 0, fVals, 0, dp.fVals.length); 19 | } 20 | 21 | @Override 22 | public float getFeatureValue(int fid) 23 | { 24 | if(fid <= 0 || fid >= fVals.length) 25 | { 26 | throw RankLibError.create("Error in DenseDataPoint::getFeatureValue(): requesting unspecified feature, fid=" + fid); 27 | } 28 | if(isUnknown(fVals[fid]))//value for unspecified feature is 0 29 | return 0; 30 | return fVals[fid]; 31 | } 32 | 33 | @Override 34 | public void setFeatureValue(int fid, float fval) 35 | { 36 | if(fid <= 0 || fid >= fVals.length) 37 | { 38 | throw RankLibError.create("Error in DenseDataPoint::setFeatureValue(): feature (id=" + fid + ") not found."); 39 | } 40 | fVals[fid] = fval; 41 | } 42 | 43 | @Override 44 | public void setFeatureVector(float[] dfVals) { 45 | //fVals = new float[dfVals.length]; 46 | //System.arraycopy(dfVals, 0, fVals, 0, dfVals.length); 47 | fVals = dfVals; 48 | } 49 | 50 | @Override 51 | public float[] getFeatureVector() { 52 | return fVals; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/boosting/WeakRanker.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning.boosting; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | 15 | import ciir.umass.edu.learning.RankList; 16 | import ciir.umass.edu.utilities.Sorter; 17 | 18 | /** 19 | * @author vdang 20 | * 21 | * Weak rankers for AdaRank. 22 | */ 23 | public class WeakRanker { 24 | private int fid = -1; 25 | 26 | public WeakRanker(int fid) 27 | { 28 | this.fid = fid; 29 | } 30 | public int getFID() 31 | { 32 | return fid; 33 | } 34 | 35 | public RankList rank(RankList l) 36 | { 37 | double[] score = new double[l.size()]; 38 | for(int i=0;i rank(List l) 44 | { 45 | List ll = new ArrayList(); 46 | for(int i=0;i the output of the current neuron on the i-th document 22 | { 23 | sumLabelExp += Math.exp(param.labels[i]); 24 | sumScoreExp += Math.exp(outputs.get(i)); 25 | } 26 | 27 | d1 = new double[outputs.size()]; 28 | d2 = new double[outputs.size()]; 29 | for(int i=0;i samples) 28 | { 29 | for(int i=0;i samples, int[] fids) 37 | { 38 | for(int i=0;i uniqueSet = new HashSet(); 44 | for(int i=0;i keys = new ArrayList();; 20 | protected List values = new ArrayList();; 21 | 22 | public KeyValuePair(String text) 23 | { 24 | try { 25 | int idx = text.lastIndexOf("#"); 26 | if(idx != -1)//remove description at the end of the line (if any) 27 | text = text.substring(0, idx).trim();//remove the comment part at the end of the line 28 | 29 | String[] fs = text.split(" "); 30 | for(int i=0;i keys() 46 | { 47 | return keys; 48 | } 49 | public List values() 50 | { 51 | return values; 52 | } 53 | 54 | private String getKey(String pair) 55 | { 56 | return pair.substring(0, pair.indexOf(":")); 57 | } 58 | private String getValue(String pair) 59 | { 60 | return pair.substring(pair.lastIndexOf(":")+1); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/MetricScorer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import ciir.umass.edu.learning.RankList; 13 | 14 | import java.util.List; 15 | 16 | /** 17 | * @author vdang 18 | * A generic retrieval measure computation interface. 19 | */ 20 | public abstract class MetricScorer { 21 | 22 | /** The depth parameter, or how deep of a ranked list to use to score the measure. */ 23 | protected int k = 10; 24 | 25 | public MetricScorer() 26 | { 27 | 28 | } 29 | 30 | /** 31 | * The depth parameter, or how deep of a ranked list to use to score the measure. 32 | * @param k the new depth for this measure. 33 | */ 34 | public void setK(int k) 35 | { 36 | this.k = k; 37 | } 38 | /** The depth parameter, or how deep of a ranked list to use to score the measure. */ 39 | public int getK() 40 | { 41 | return k; 42 | } 43 | public void loadExternalRelevanceJudgment(String qrelFile) 44 | { 45 | 46 | } 47 | public double score(List rl) 48 | { 49 | double score = 0.0; 50 | for(int i=0;ib)?b:a; 36 | } 37 | public static double p(long count, long total) 38 | { 39 | return ((double)count+0.5)/(total+1); 40 | } 41 | public static double round(double val) 42 | { 43 | int precision = 10000; //keep 4 digits 44 | return Math.floor(val * precision +.5)/precision; 45 | } 46 | public static double round(float val) 47 | { 48 | int precision = 10000; //keep 4 digits 49 | return Math.floor(val * precision +.5)/precision; 50 | } 51 | public static double round(double val, int n) 52 | { 53 | int precision = 1; 54 | for(int i=0;i samples, int[] features, MetricScorer scorer) 32 | { 33 | super(samples, features, scorer); 34 | } 35 | 36 | public Ranker createNew() 37 | { 38 | return new MART(); 39 | } 40 | public String name() 41 | { 42 | return "MART"; 43 | } 44 | protected void computePseudoResponses() 45 | { 46 | for(int i=0;i leaves = rt.leaves(); 52 | for(int i=0;i train, int[] features, MetricScorer scorer) 28 | { 29 | Ranker ranker = rf.createRanker(type, train, features, scorer); 30 | long start = System.nanoTime(); 31 | ranker.init(); 32 | ranker.learn(); 33 | trainingTime = System.nanoTime() - start; 34 | //printTrainingTime(); 35 | return ranker; 36 | } 37 | public Ranker train(RANKER_TYPE type, List train, List validation, int[] features, MetricScorer scorer) 38 | { 39 | Ranker ranker = rf.createRanker(type, train, features, scorer); 40 | ranker.setValidationSet(validation); 41 | long start = System.nanoTime(); 42 | ranker.init(); 43 | ranker.learn(); 44 | trainingTime = System.nanoTime() - start; 45 | //printTrainingTime(); 46 | return ranker; 47 | } 48 | public double getTrainingTime() 49 | { 50 | return trainingTime; 51 | } 52 | public void printTrainingTime() 53 | { 54 | System.out.println("Training time: " + SimpleMath.round((trainingTime)/1e9, 2) + " seconds"); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/MetricScorerFactory.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import java.util.HashMap; 13 | 14 | /** 15 | * @author vdang 16 | */ 17 | public class MetricScorerFactory { 18 | 19 | private static MetricScorer[] mFactory = new MetricScorer[]{new APScorer(), new NDCGScorer(), new DCGScorer(), new PrecisionScorer(), new ReciprocalRankScorer(), new BestAtKScorer(), new ERRScorer()}; 20 | private static HashMap map = new HashMap(); 21 | 22 | public MetricScorerFactory() 23 | { 24 | map.put("MAP", new APScorer()); 25 | map.put("NDCG", new NDCGScorer()); 26 | map.put("DCG", new DCGScorer()); 27 | map.put("P", new PrecisionScorer()); 28 | map.put("RR", new ReciprocalRankScorer()); 29 | map.put("BEST", new BestAtKScorer()); 30 | map.put("ERR", new ERRScorer()); 31 | } 32 | public MetricScorer createScorer(METRIC metric) 33 | { 34 | return mFactory[metric.ordinal() - METRIC.MAP.ordinal()].copy(); 35 | } 36 | public MetricScorer createScorer(METRIC metric, int k) 37 | { 38 | MetricScorer s = mFactory[metric.ordinal() - METRIC.MAP.ordinal()].copy(); 39 | s.setK(k); 40 | return s; 41 | } 42 | public MetricScorer createScorer(String metric)//e.g.: metric = "NDCG@5" 43 | { 44 | int k = -1; 45 | String m = ""; 46 | MetricScorer s = null; 47 | if(metric.indexOf("@") != -1) 48 | { 49 | m = metric.substring(0, metric.indexOf("@")); 50 | k = Integer.parseInt(metric.substring(metric.indexOf("@")+1)); 51 | s = map.get(m.toUpperCase()).copy(); 52 | s.setK(k); 53 | } 54 | else 55 | s = map.get(metric.toUpperCase()).copy(); 56 | return s; 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/Sampler.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning; 11 | 12 | import java.util.ArrayList; 13 | import java.util.Arrays; 14 | import java.util.List; 15 | import java.util.Random; 16 | 17 | public class Sampler { 18 | protected List samples = null;//bag data 19 | protected List remains = null;//out-of-bag data 20 | public List doSampling(List samplingPool, float samplingRate, boolean withReplacement) 21 | { 22 | Random r = new Random(); 23 | samples = new ArrayList(); 24 | int size = (int)(samplingRate * samplingPool.size()); 25 | if(withReplacement) 26 | { 27 | int[] used = new int[samplingPool.size()]; 28 | Arrays.fill(used, 0); 29 | for(int i=0;i(); 36 | for(int i=0;i l = new ArrayList(); 43 | for(int i=0;i(); 52 | for(int i=0;i getSamples() 58 | { 59 | return samples; 60 | } 61 | public List getRemains() 62 | { 63 | return remains; 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/PrecisionScorer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import java.util.Arrays; 13 | 14 | import ciir.umass.edu.learning.RankList; 15 | 16 | /** 17 | * @author vdang 18 | */ 19 | public class PrecisionScorer extends MetricScorer { 20 | 21 | public PrecisionScorer() 22 | { 23 | this.k = 10; 24 | } 25 | public PrecisionScorer(int k) 26 | { 27 | this.k = k; 28 | } 29 | public double score(RankList rl) 30 | { 31 | int count = 0; 32 | 33 | int size = k; 34 | if(k > rl.size() || k <= 0) 35 | size = rl.size(); 36 | 37 | for(int i=0;i 0.0)//relevant 40 | count++; 41 | } 42 | return ((double)count)/size; 43 | } 44 | public MetricScorer copy() 45 | { 46 | return new PrecisionScorer(); 47 | } 48 | public String name() 49 | { 50 | return "P@"+k; 51 | } 52 | public double[][] swapChange(RankList rl) 53 | { 54 | int size = (rl.size() > k) ? k : rl.size(); 55 | /*int relCount = 0; 56 | for(int i=0;i 0.0)//relevant 58 | relCount++;*/ 59 | 60 | double[][] changes = new double[rl.size()][]; 61 | for(int i=0;i 0.0) 80 | return 1; 81 | return 0; 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/stats/RandomPermutationTest.java: -------------------------------------------------------------------------------- 1 | package ciir.umass.edu.stats; 2 | 3 | import java.util.HashMap; 4 | import java.util.Random; 5 | 6 | /** 7 | * Randomized permutation test. Adapted from Michael Bendersky's Python script. 8 | * @author vdang 9 | * 10 | */ 11 | public class RandomPermutationTest extends SignificanceTest { 12 | 13 | public static int nPermutation = 10000; 14 | private static String[] pad = new String[]{"", "0", "00", "000", "0000", "00000", "000000", "0000000", "00000000", "000000000"}; 15 | 16 | /** 17 | * Run the randomization test 18 | * @param baseline 19 | * @param target 20 | * @return 21 | */ 22 | public double test(HashMap target, HashMap baseline) 23 | { 24 | double[] b = new double[baseline.keySet().size()];//baseline 25 | double[] t = new double[target.keySet().size()];//target 26 | int c = 0; 27 | for(String key : baseline.keySet()) 28 | { 29 | b[c] = baseline.get(key).doubleValue(); 30 | t[c] = target.get(key).doubleValue(); 31 | c++; 32 | } 33 | double trueDiff = Math.abs(BasicStats.mean(b) - BasicStats.mean(t)); 34 | double pvalue = 0.0; 35 | double[] pb = new double[baseline.keySet().size()];//permutation of baseline 36 | double[] pt = new double[target.keySet().size()];//permutation of target 37 | for(int i=0;i= trueDiff) 55 | pvalue += 1.0; 56 | } 57 | return pvalue/nPermutation; 58 | } 59 | 60 | /** 61 | * Generate a random bit vector of a certain size 62 | * @param size 63 | * @return 64 | */ 65 | private String randomBitVector(int size) 66 | { 67 | Random r = new Random(); 68 | String output = ""; 69 | for(int i=0;i<(size/10)+1;i++) 70 | { 71 | int x = (int)((1<<10) * r.nextDouble()); 72 | String s = Integer.toBinaryString(x); 73 | if(s.length() == 11) 74 | output += s.substring(1); 75 | else 76 | output += pad[10-s.length()] + s; 77 | } 78 | return output; 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/RankList.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning; 11 | 12 | import java.util.List; 13 | 14 | import ciir.umass.edu.utilities.Sorter; 15 | 16 | /** 17 | * @author vdang 18 | * 19 | * This class implement the list of objects (each of which is a DataPoint) to be ranked. 20 | */ 21 | public class RankList { 22 | 23 | protected DataPoint[] rl = null; 24 | 25 | public RankList(List rl) 26 | { 27 | this.rl = new DataPoint[rl.size()]; 28 | for(int i=0;i 3 | 4.0.0 4 | 5 | com.o19s 6 | RankLibPlus 7 | 0.1.0 8 | 9 | jar 10 | RankLib 11 | 12 | UTF-8 13 | ciir.umass.edu.eval.Evaluator 14 | 1.8 15 | 16 | 17 | 18 | 19 | junit 20 | junit 21 | 4.12 22 | test 23 | 24 | 25 | 26 | 27 | 28 | deployment 29 | Internal Releases 30 | http://scm-ciir.cs.umass.edu:8080/nexus/content/repositories/releases/ 31 | 32 | 33 | deployment 34 | Internal Releases 35 | http://scm-ciir.cs.umass.edu:8080/nexus/content/repositories/snapshots/ 36 | 37 | 38 | 39 | 40 | src 41 | test 42 | 43 | 44 | maven-jar-plugin 45 | 2.4 46 | 47 | 48 | true 49 | 50 | ${mainClass} 51 | 52 | 53 | 54 | 55 | 56 | 57 | maven-compiler-plugin 58 | 3.2 59 | 60 | ${javaVersion} 61 | ${javaVersion} 62 | 63 | 64 | 65 | 67 | 68 | maven-javadoc-plugin 69 | 2.10.3 70 | 71 | -Xdoclint:none 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/features/SumNormalizor.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.features; 11 | 12 | import java.util.Arrays; 13 | 14 | import ciir.umass.edu.learning.DataPoint; 15 | import ciir.umass.edu.learning.RankList; 16 | 17 | /** 18 | * @author vdang 19 | */ 20 | public class SumNormalizor extends Normalizer { 21 | @Override 22 | public void normalize(RankList rl) { 23 | if(rl.size() == 0) 24 | { 25 | System.out.println("Error in SumNormalizor::normalize(): The input ranked list is empty"); 26 | System.exit(1); 27 | } 28 | int nFeature = DataPoint.getFeatureCount(); 29 | double[] norm = new double[nFeature]; 30 | Arrays.fill(norm, 0); 31 | for(int i=0;i 0) 43 | dp.setFeatureValue(j, (float)(dp.getFeatureValue(j)/norm[j-1])); 44 | } 45 | } 46 | } 47 | @Override 48 | public void normalize(RankList rl, int[] fids) { 49 | if(rl.size() == 0) 50 | { 51 | System.out.println("Error in SumNormalizor::normalize(): The input ranked list is empty"); 52 | System.exit(1); 53 | } 54 | 55 | //remove duplicate features from the input @fids ==> avoid normalizing the same features multiple times 56 | fids = removeDuplicateFeatures(fids); 57 | 58 | double[] norm = new double[fids.length]; 59 | Arrays.fill(norm, 0); 60 | for(int i=0;i 0) 71 | dp.setFeatureValue(fids[j], (float)(dp.getFeatureValue(fids[j])/norm[j])); 72 | } 73 | } 74 | public String name() 75 | { 76 | return "sum"; 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/neuralnet/Layer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning.neuralnet; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | 15 | /** 16 | * @author vdang 17 | * 18 | * This class implements layers of neurons in neural networks. 19 | */ 20 | public class Layer { 21 | protected List neurons = null; 22 | 23 | public Layer(int size) 24 | { 25 | neurons = new ArrayList(); 26 | for(int i=0;i(); 37 | for(int i=0;i avoid normalizing the same features multiple times 44 | fids = removeDuplicateFeatures(fids); 45 | 46 | float[] min = new float[fids.length]; 47 | float[] max = new float[fids.length]; 48 | //Arrays.fill(min, 0); 49 | Arrays.fill (min, Float.MAX_VALUE); 50 | //Arrays.fill(max, 0); 51 | Arrays.fill(max, Float.MIN_VALUE); 52 | 53 | for(int i=0;i min[j]) 68 | { 69 | float value = (dp.getFeatureValue(fids[j]) - min[j]) / (max[j] - min[j]); 70 | dp.setFeatureValue(fids[j], value); 71 | } 72 | else 73 | dp.setFeatureValue(fids[j], 0); 74 | } 75 | } 76 | } 77 | public String name() 78 | { 79 | return "linear"; 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/utilities/MyThreadPool.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.utilities; 11 | 12 | import java.util.concurrent.LinkedBlockingQueue; 13 | import java.util.concurrent.Semaphore; 14 | import java.util.concurrent.ThreadPoolExecutor; 15 | import java.util.concurrent.TimeUnit; 16 | 17 | /** 18 | * 19 | * @author vdang 20 | * 21 | */ 22 | public class MyThreadPool extends ThreadPoolExecutor { 23 | 24 | private final Semaphore semaphore; 25 | private int size = 0; 26 | 27 | private MyThreadPool(int size) 28 | { 29 | super(size, size, 0, TimeUnit.MILLISECONDS, new LinkedBlockingQueue()); 30 | semaphore = new Semaphore(size, true); 31 | this.size = size; 32 | } 33 | 34 | private static MyThreadPool singleton = null; 35 | public static MyThreadPool getInstance() 36 | { 37 | if(singleton == null) 38 | init(Runtime.getRuntime().availableProcessors()); 39 | return singleton; 40 | } 41 | 42 | public static void init(int poolSize) 43 | { 44 | singleton = new MyThreadPool(poolSize); 45 | } 46 | public int size() 47 | { 48 | return size; 49 | } 50 | public WorkerThread[] execute(WorkerThread worker, int nTasks) 51 | { 52 | MyThreadPool p = MyThreadPool.getInstance(); 53 | int[] partition = p.partition(nTasks); 54 | WorkerThread[] workers = new WorkerThread[partition.length-1]; 55 | for(int i=0;i 0) 51 | { 52 | for(int i=0;i avoid normalizing the same features multiple times 70 | fids = removeDuplicateFeatures(fids); 71 | 72 | double[] means = new double[fids.length]; 73 | Arrays.fill(means, 0); 74 | for(int i=0;i 0.0) 94 | { 95 | for(int i=0;i k) ? k : rl.size(); 28 | int firstRank = -1; 29 | for(int i=0;i 0.0)//relevant 32 | firstRank = i+1; 33 | } 34 | return (firstRank==-1)?0:(1.0f/firstRank); 35 | } 36 | public MetricScorer copy() 37 | { 38 | return new ReciprocalRankScorer(); 39 | } 40 | public String name() 41 | { 42 | return "RR@"+k; 43 | } 44 | public double[][] swapChange(RankList rl) 45 | { 46 | int firstRank = -1; 47 | int secondRank = -1; 48 | int size = (rl.size() > k) ? k : rl.size(); 49 | for(int i=0;i 0.0)//relevant 52 | { 53 | if(firstRank==-1) 54 | firstRank = i; 55 | else if(secondRank == -1) 56 | secondRank = i; 57 | } 58 | } 59 | 60 | //compute the change in RR by swapping each pair 61 | double[][] changes = new double[rl.size()][]; 62 | for(int i=0;i 0) 101 | changes[i][j] = changes[j][i] = 1.0/(i+1) - rr; 102 | } 103 | } 104 | return changes; 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/ERRScorer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import java.util.ArrayList; 13 | import java.util.Arrays; 14 | import java.util.List; 15 | 16 | import ciir.umass.edu.learning.RankList; 17 | 18 | /** 19 | * 20 | * @author Van Dang 21 | * Expected Reciprocal Rank 22 | */ 23 | public class ERRScorer extends MetricScorer { 24 | 25 | public static double MAX = 16;//by default, we assume the relevance scale of {0, 1, 2, 3, 4} => g_max = 4 => 2^g_max = 16 26 | 27 | public ERRScorer() 28 | { 29 | this.k = 10; 30 | } 31 | public ERRScorer(int k) 32 | { 33 | this.k = k; 34 | } 35 | public ERRScorer copy() 36 | { 37 | return new ERRScorer(); 38 | } 39 | /** 40 | * Compute ERR at k. NDCG(k) = DCG(k) / DCG_{perfect}(k). Note that the "perfect ranking" must be computed based on the whole list, 41 | * not just top-k portion of the list. 42 | */ 43 | public double score(RankList rl) 44 | { 45 | int size = k; 46 | if(k > rl.size() || k <= 0) 47 | size = rl.size(); 48 | 49 | List rel = new ArrayList(); 50 | for(int i=0;i k) ? k : rl.size(); 74 | int[] labels = new int[rl.size()]; 75 | double[] R = new double[rl.size()]; 76 | double[] np = new double[rl.size()];//p[i] = (1 - p[0])(1 - p[1])...(1-p[i-1]) 77 | double p = 1.0; 78 | //for(int i=0;i rl.size()-1) 49 | size = rl.size()-1; 50 | 51 | double max = -1.0; 52 | int max_i = 0; 53 | for(int i=0;i<=size;i++) 54 | { 55 | if(max < rl.get(i).getLabel()) 56 | { 57 | max = rl.get(i).getLabel(); 58 | max_i = i; 59 | } 60 | } 61 | return max_i; 62 | } 63 | public String name() 64 | { 65 | return "Best@"+k; 66 | } 67 | public double[][] swapChange(RankList rl) 68 | { 69 | //FIXME: not sure if this implementation is correct! 70 | int[] labels = new int[rl.size()]; 71 | int[] best = new int[rl.size()]; 72 | int max = -1; 73 | int maxVal = -1; 74 | int secondMaxVal = -1;//within top-K 75 | int maxCount = 0;//within top-K 76 | for(int i=0;i= k) 110 | change = 0; 111 | else if(labels[i] == labels[j] || labels[j] == labels[best[k-1]]) 112 | change = 0; 113 | else if(labels[j] > labels[best[k-1]]) 114 | change = labels[j] - labels[best[i]]; 115 | else if(labels[i] < labels[best[k-1]] || maxCount > 1) 116 | change = 0; 117 | else 118 | change = maxVal - Math.max(secondMaxVal, labels[j]); 119 | changes[i][j] = changes[j][i] = change; 120 | } 121 | } 122 | return changes; 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/DCGScorer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import ciir.umass.edu.learning.RankList; 13 | import ciir.umass.edu.utilities.SimpleMath; 14 | 15 | public class DCGScorer extends MetricScorer { 16 | 17 | protected static double[] discount = null;//cache 18 | protected static double[] gain = null;//cache 19 | 20 | public DCGScorer() 21 | { 22 | this.k = 10; 23 | //init cache if we haven't already done so 24 | if(discount == null) 25 | { 26 | discount = new double[5000]; 27 | for(int i=0;i rl.size() || k <= 0) 62 | size = rl.size(); 63 | 64 | int[] rel = getRelevanceLabels(rl); 65 | return getDCG(rel, size); 66 | } 67 | public double[][] swapChange(RankList rl) 68 | { 69 | int[] rel = getRelevanceLabels(rl); 70 | int size = (rl.size() > k) ? k : rl.size(); 71 | double[][] changes = new double[rl.size()][]; 72 | for(int i=0;i samples, int [] features, MetricScorer scorer) 30 | { 31 | super(samples, features, scorer); 32 | } 33 | protected int[][] batchFeedForward(RankList rl) 34 | { 35 | int[][] pairMap = new int[rl.size()][]; 36 | targetValue = new float[rl.size()][]; 37 | for(int i=0;i rl.get(j).getLabel() || rl.get(i).getLabel() < rl.get(j).getLabel()) 45 | count++; 46 | 47 | pairMap[i] = new int[count]; 48 | targetValue[i] = new float[count]; 49 | 50 | int k=0; 51 | for(int j=0;j rl.get(j).getLabel() || rl.get(i).getLabel() < rl.get(j).getLabel()) 53 | { 54 | pairMap[i][k] = j; 55 | if(rl.get(i).getLabel() > rl.get(j).getLabel()) 56 | targetValue[i][k] = 1; 57 | else 58 | targetValue[i][k] = 0; 59 | k++; 60 | } 61 | } 62 | return pairMap; 63 | } 64 | protected void batchBackPropagate(int[][] pairMap, float[][] pairWeight) 65 | { 66 | for(int i=0;i=1;j--)//back-propagate to the first hidden layer 72 | layers.get(j).updateDelta(p); 73 | 74 | //weight update 75 | outputLayer.updateWeight(p); 76 | for(int j=layers.size()-2;j>=1;j--) 77 | layers.get(j).updateWeight(p); 78 | } 79 | } 80 | protected RankList internalReorder(RankList rl) 81 | { 82 | return rank(rl); 83 | } 84 | protected float[][] computePairWeight(int[][] pairMap, RankList rl) 85 | { 86 | double[][] changes = scorer.swapChange(rl); 87 | float[][] weight = new float[pairMap.length][]; 88 | for(int i=0;i rl.get(pairMap[i][j]).getLabel())?1:-1; 94 | weight[i][j] = (float)Math.abs(changes[i][pairMap[i][j]])*sign; 95 | } 96 | } 97 | return weight; 98 | } 99 | protected void estimateLoss() 100 | { 101 | misorderedPairs = 0; 102 | for(int j=0;j rl.get(l).getLabel()) 111 | { 112 | double o2 = eval(rl.get(l)); 113 | //error += crossEntropy(o1, o2, 1.0f); 114 | if(o1 < o2) 115 | misorderedPairs++; 116 | } 117 | } 118 | } 119 | } 120 | error = 1.0 - scoreOnTrainingData; 121 | if(error > lastError) 122 | { 123 | //Neuron.learningRate *= 0.8; 124 | straightLoss++; 125 | } 126 | else 127 | straightLoss = 0; 128 | lastError = error; 129 | } 130 | 131 | public Ranker createNew() 132 | { 133 | return new LambdaRank(); 134 | } 135 | public String name() 136 | { 137 | return "LambdaRank"; 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/SparseDataPoint.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning; 11 | 12 | import java.util.Arrays; 13 | 14 | /** 15 | * Implements a sparse data point using a compressed sparse row data structure 16 | * @author Siddhartha Bagaria 17 | */ 18 | public class SparseDataPoint extends DataPoint { 19 | 20 | // Access pattern of the feature values 21 | private enum accessPattern {SEQUENTIAL, RANDOM}; 22 | private static accessPattern searchPattern = accessPattern.RANDOM; 23 | 24 | // Profiling variables 25 | // private static int numCalls = 0; 26 | // private static float avgTime = 0; 27 | 28 | // The feature ids for known values 29 | int fIds[]; 30 | 31 | // The feature values for corresponding Ids 32 | //float fVals[]; //moved to the parent class 33 | 34 | // Internal search optimizers. Currently unused. 35 | int lastMinId = -1; 36 | int lastMinPos = -1; 37 | 38 | public SparseDataPoint(String text) { 39 | super(text); 40 | } 41 | 42 | public SparseDataPoint(SparseDataPoint dp) 43 | { 44 | label = dp.label; 45 | id = dp.id; 46 | description = dp.description; 47 | cached = dp.cached; 48 | fIds = new int[dp.fIds.length]; 49 | fVals = new float[dp.fVals.length]; 50 | System.arraycopy(dp.fIds, 0, fIds, 0, dp.fIds.length); 51 | System.arraycopy(dp.fVals, 0, fVals, 0, dp.fVals.length); 52 | } 53 | 54 | private int locate(int fid) { 55 | if (searchPattern == accessPattern.SEQUENTIAL) 56 | { 57 | if (lastMinId > fid) 58 | { 59 | lastMinId = -1; 60 | lastMinPos = -1; 61 | } 62 | while (lastMinPos < knownFeatures && lastMinId < fid) 63 | lastMinId = fIds[++lastMinPos]; 64 | if (lastMinId == fid) 65 | return lastMinPos; 66 | } 67 | else if (searchPattern == accessPattern.RANDOM) 68 | { 69 | int pos = Arrays.binarySearch(fIds, fid); 70 | if (pos >= 0) 71 | return pos; 72 | } 73 | else 74 | System.err.println("Invalid search pattern specified for sparse data points."); 75 | 76 | return -1; 77 | } 78 | 79 | public boolean hasFeature(int fid) { 80 | return locate(fid) != -1; 81 | } 82 | 83 | @Override 84 | public float getFeatureValue(int fid) 85 | { 86 | //long time = System.nanoTime(); 87 | if(fid <= 0 || fid > getFeatureCount()) 88 | { 89 | System.out.println("Error in SparseDataPoint::getFeatureValue(): requesting invalid feature, fid=" + fid); 90 | System.exit(1); 91 | } 92 | int pos = locate(fid); 93 | //long completedIn = System.nanoTime() - time; 94 | //avgTime = (avgTime*numCalls + completedIn)/(++numCalls); 95 | //System.out.println("getFeatureValue average time: "+avgTime); 96 | if(pos >= 0) 97 | return fVals[pos]; 98 | 99 | return 0; // Should ideally be returning unknown? 100 | } 101 | 102 | @Override 103 | public void setFeatureValue(int fid, float fval) 104 | { 105 | if(fid <= 0 || fid > getFeatureCount()) 106 | { 107 | System.out.println("Error in SparseDataPoint::setFeatureValue(): feature (id=" + fid + ") out of range."); 108 | System.exit(1); 109 | } 110 | int pos = locate(fid); 111 | if(pos >= 0) 112 | fVals[pos] = fval; 113 | else 114 | { 115 | System.err.println("Error in SparseDataPoint::setFeatureValue(): feature (id=" + fid + ") not found."); 116 | System.exit(1); 117 | } 118 | } 119 | 120 | @Override 121 | public void setFeatureVector(float[] dfVals) 122 | { 123 | fIds = new int[knownFeatures]; 124 | fVals = new float[knownFeatures]; 125 | int pos = 0; 126 | for (int i=1; i leaves = null; 29 | 30 | protected DataPoint[] trainingSamples = null; 31 | protected double[] trainingLabels = null; 32 | protected int[] features = null; 33 | protected float[][] thresholds = null; 34 | protected int[] index = null; 35 | protected FeatureHistogram hist = null; 36 | 37 | public RegressionTree(Split root) 38 | { 39 | this.root = root; 40 | leaves = root.leaves(); 41 | } 42 | public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport) 43 | { 44 | this.nodes = nLeaves; 45 | this.trainingSamples = trainingSamples; 46 | this.trainingLabels = labels; 47 | this.hist = hist; 48 | this.minLeafSupport = minLeafSupport; 49 | index = new int[trainingSamples.length]; 50 | for(int i=0;i queue = new ArrayList(); 60 | root = new Split(index, hist, Float.MAX_VALUE, 0); 61 | root.setRoot(true); 62 | 63 | // Ensure inserts occur only after successful splits 64 | if(root.split(trainingLabels, minLeafSupport)) { 65 | insert(queue, root.getLeft()); 66 | insert(queue, root.getRight()); 67 | } 68 | 69 | int taken = 0; 70 | while( (nodes == -1 || taken + queue.size() < nodes) && queue.size() > 0) 71 | { 72 | Split leaf = queue.get(0); 73 | queue.remove(0); 74 | 75 | if(leaf.getSamples().length < 2 * minLeafSupport) 76 | { 77 | taken++; 78 | continue; 79 | } 80 | 81 | if(!leaf.split(trainingLabels, minLeafSupport))//unsplitable (i.e. variance(s)==0; or after-split variance is higher than before) 82 | taken++; 83 | else 84 | { 85 | insert(queue, leaf.getLeft()); 86 | insert(queue, leaf.getRight()); 87 | } 88 | } 89 | leaves = root.leaves(); 90 | } 91 | 92 | /** 93 | * Get the tree output for the input sample 94 | * @param dp 95 | * @return 96 | */ 97 | public double eval(DataPoint dp) 98 | { 99 | return root.eval(dp); 100 | } 101 | /** 102 | * Retrieve all leave nodes in the tree 103 | * @return 104 | */ 105 | public List leaves() 106 | { 107 | return leaves; 108 | } 109 | /** 110 | * Clear samples associated with each leaves (when they are no longer necessary) in order to save memory 111 | */ 112 | public void clearSamples() 113 | { 114 | trainingSamples = null; 115 | trainingLabels = null; 116 | features = null; 117 | thresholds = null; 118 | index = null; 119 | hist = null; 120 | for(int i=0;i ls, Split s) 149 | { 150 | int i=0; 151 | while(i < ls.size()) 152 | { 153 | if(ls.get(i).getDeviance() > s.getDeviance()) 154 | i++; 155 | else 156 | break; 157 | } 158 | ls.add(i, s); 159 | } 160 | 161 | } 162 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/APScorer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import ciir.umass.edu.learning.RankList; 13 | import ciir.umass.edu.utilities.RankLibError; 14 | 15 | import java.io.BufferedReader; 16 | import java.io.FileInputStream; 17 | import java.io.IOException; 18 | import java.io.InputStreamReader; 19 | import java.util.Arrays; 20 | import java.util.HashMap; 21 | 22 | /** 23 | * @author vdang 24 | * This class implements MAP (Mean Average Precision) 25 | */ 26 | public class APScorer extends MetricScorer { 27 | //This class computes MAP from the *WHOLE* ranked list. "K" will be completely ignored. 28 | //The reason is, if you want MAP@10, you really should be using NDCG@10 or ERR@10 instead. 29 | 30 | public HashMap relDocCount = null; 31 | 32 | public APScorer() 33 | { 34 | this.k = 0;//consider the whole list 35 | } 36 | public MetricScorer copy() 37 | { 38 | return new APScorer(); 39 | } 40 | public void loadExternalRelevanceJudgment(String qrelFile) 41 | { 42 | relDocCount = new HashMap<>(); 43 | try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(qrelFile)))) { 44 | String content = ""; 45 | while((content = in.readLine()) != null) 46 | { 47 | content = content.trim(); 48 | if(content.length() == 0) 49 | continue; 50 | String[] s = content.split(" "); 51 | String qid = s[0].trim(); 52 | //String docid = s[2].trim(); 53 | int label = (int) Math.rint(Double.parseDouble(s[3].trim())); 54 | if(label > 0) { 55 | int prev = relDocCount.getOrDefault(qid, 0); 56 | relDocCount.put(qid, prev+1); 57 | } 58 | } 59 | 60 | System.out.println("Relevance judgment file loaded. [#q=" + relDocCount.size() + "]"); 61 | } 62 | catch(IOException ex) 63 | { 64 | throw RankLibError.create("Error in APScorer::loadExternalRelevanceJudgment(): ", ex); 65 | } 66 | } 67 | /** 68 | * Compute Average Precision (AP) of the list. AP of a list is the average of precision evaluated at ranks where a relevant document 69 | * is observed. 70 | * @return AP of the list. 71 | */ 72 | public double score(RankList rl) 73 | { 74 | double ap = 0.0; 75 | int count = 0; 76 | for(int i=0;i 0.0)//relevant 79 | { 80 | count++; 81 | ap += ((double)count)/(i+1); 82 | } 83 | } 84 | 85 | int rdCount = 0; 86 | if(relDocCount != null) 87 | { 88 | Integer it = relDocCount.get(rl.getID()); 89 | if(it != null) 90 | rdCount = it; 91 | } 92 | else //no qrel-file specified, we can only use the #relevant-docs in the training file 93 | rdCount = count; 94 | 95 | if(rdCount==0) 96 | return 0.0; 97 | return ap / rdCount; 98 | } 99 | public String name() 100 | { 101 | return "MAP"; 102 | } 103 | public double[][] swapChange(RankList rl) 104 | { 105 | //NOTE: Compute swap-change *IGNORING* K (consider the entire ranked list) 106 | int[] relCount = new int[rl.size()]; 107 | int[] labels = new int[rl.size()]; 108 | int count = 0; 109 | for(int i=0;i 0)//relevant 112 | { 113 | labels[i] = 1; 114 | count++; 115 | } 116 | else 117 | labels[i] = 0; 118 | relCount[i] = count; 119 | } 120 | int rdCount = 0;//total number of relevant documents 121 | if(relDocCount != null)//if an external qrels file is specified 122 | { 123 | Integer it = relDocCount.get(rl.getID()); 124 | if(it != null) 125 | rdCount = it; 126 | } 127 | else 128 | rdCount = count; 129 | 130 | double[][] changes = new double[rl.size()][]; 131 | for(int i=0;i 0) 151 | change += ((double)diff) / (k+1); 152 | change += ((double)(-relCount[j]*diff)) / (j+1); 153 | //It is equivalent to: change += ((double)(relCount[j]*labels[i] - relCount[j]*labels[j])) / (j+1); 154 | } 155 | changes[j][i] = changes[i][j] = change/rdCount; 156 | } 157 | } 158 | return changes; 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/RankerFactory.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning; 11 | 12 | import ciir.umass.edu.learning.boosting.AdaRank; 13 | import ciir.umass.edu.learning.boosting.RankBoost; 14 | import ciir.umass.edu.learning.neuralnet.LambdaRank; 15 | import ciir.umass.edu.learning.neuralnet.ListNet; 16 | import ciir.umass.edu.learning.neuralnet.RankNet; 17 | import ciir.umass.edu.learning.tree.LambdaMART; 18 | import ciir.umass.edu.learning.tree.MART; 19 | import ciir.umass.edu.learning.tree.RFRanker; 20 | import ciir.umass.edu.metric.MetricScorer; 21 | import ciir.umass.edu.utilities.FileUtils; 22 | import ciir.umass.edu.utilities.RankLibError; 23 | 24 | import java.io.BufferedReader; 25 | import java.io.StringReader; 26 | import java.util.HashMap; 27 | import java.util.List; 28 | 29 | /** 30 | * @author vdang 31 | * 32 | * This class implements the Ranker factory. All ranking algorithms implemented have to be recognized in this class. 33 | */ 34 | public class RankerFactory { 35 | 36 | protected Ranker[] rFactory = new Ranker[]{new MART(), new RankBoost(), new RankNet(), new AdaRank(), new CoorAscent(), new LambdaRank(), new LambdaMART(), new ListNet(), new RFRanker(), new LinearRegRank()}; 37 | protected static HashMap map = new HashMap(); 38 | 39 | public RankerFactory() 40 | { 41 | map.put(createRanker(RANKER_TYPE.MART).name().toUpperCase(), RANKER_TYPE.MART); 42 | map.put(createRanker(RANKER_TYPE.RANKNET).name().toUpperCase(), RANKER_TYPE.RANKNET); 43 | map.put(createRanker(RANKER_TYPE.RANKBOOST).name().toUpperCase(), RANKER_TYPE.RANKBOOST); 44 | map.put(createRanker(RANKER_TYPE.ADARANK).name().toUpperCase(), RANKER_TYPE.ADARANK); 45 | map.put(createRanker(RANKER_TYPE.COOR_ASCENT).name().toUpperCase(), RANKER_TYPE.COOR_ASCENT); 46 | map.put(createRanker(RANKER_TYPE.LAMBDARANK).name().toUpperCase(), RANKER_TYPE.LAMBDARANK); 47 | map.put(createRanker(RANKER_TYPE.LAMBDAMART).name().toUpperCase(), RANKER_TYPE.LAMBDAMART); 48 | map.put(createRanker(RANKER_TYPE.LISTNET).name().toUpperCase(), RANKER_TYPE.LISTNET); 49 | map.put(createRanker(RANKER_TYPE.RANDOM_FOREST).name().toUpperCase(), RANKER_TYPE.RANDOM_FOREST); 50 | map.put(createRanker(RANKER_TYPE.LINEAR_REGRESSION).name().toUpperCase(), RANKER_TYPE.LINEAR_REGRESSION); 51 | } 52 | public Ranker createRanker(RANKER_TYPE type) 53 | { 54 | return rFactory[type.ordinal() - RANKER_TYPE.MART.ordinal()].createNew(); 55 | } 56 | public Ranker createRanker(RANKER_TYPE type, List samples, int[] features, MetricScorer scorer) 57 | { 58 | Ranker r = createRanker(type); 59 | r.setTrainingSet(samples); 60 | r.setFeatures(features); 61 | r.setMetricScorer(scorer); 62 | return r; 63 | } 64 | @SuppressWarnings("unchecked") 65 | public Ranker createRanker(String className) 66 | { 67 | Ranker r = null; 68 | try { 69 | Class c = Class.forName(className); 70 | r = (Ranker) c.newInstance(); 71 | } 72 | catch (ClassNotFoundException e) { 73 | System.out.println("Could find the class \"" + className + "\" you specified. Make sure the jar library is in your classpath."); 74 | e.printStackTrace(); 75 | System.exit(1); 76 | } 77 | catch (InstantiationException e) { 78 | System.out.println("Cannot create objects from the class \"" + className + "\" you specified."); 79 | e.printStackTrace(); 80 | System.exit(1); 81 | } 82 | catch (IllegalAccessException e) { 83 | System.out.println("The class \"" + className + "\" does not implement the Ranker interface."); 84 | e.printStackTrace(); 85 | System.exit(1); 86 | } 87 | return r; 88 | } 89 | public Ranker createRanker(String className, List samples, int[] features, MetricScorer scorer) 90 | { 91 | Ranker r = createRanker(className); 92 | r.setTrainingSet(samples); 93 | r.setFeatures(features); 94 | r.setMetricScorer(scorer); 95 | return r; 96 | } 97 | public Ranker loadRankerFromFile(String modelFile) 98 | { 99 | return loadRankerFromString(FileUtils.read(modelFile, "ASCII")); 100 | } 101 | public Ranker loadRankerFromString(String fullText) 102 | { 103 | try (BufferedReader in = new BufferedReader(new StringReader(fullText))) { 104 | Ranker r; 105 | String content = in.readLine();//read the first line to get the name of the ranking algorithm 106 | content = content.replace("## ", "").trim(); 107 | System.out.println("Model:\t\t" + content); 108 | r = createRanker(map.get(content.toUpperCase())); 109 | r.loadFromString(fullText); 110 | return r; 111 | } 112 | catch(Exception ex) 113 | { 114 | throw RankLibError.create(ex); 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/metric/NDCGScorer.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.metric; 11 | 12 | import ciir.umass.edu.learning.RankList; 13 | import ciir.umass.edu.utilities.RankLibError; 14 | import ciir.umass.edu.utilities.Sorter; 15 | 16 | import java.io.BufferedReader; 17 | import java.io.FileInputStream; 18 | import java.io.IOException; 19 | import java.io.InputStreamReader; 20 | import java.util.ArrayList; 21 | import java.util.Arrays; 22 | import java.util.HashMap; 23 | import java.util.List; 24 | 25 | /** 26 | * @author vdang 27 | */ 28 | public class NDCGScorer extends DCGScorer { 29 | 30 | protected HashMap idealGains = null; 31 | 32 | public NDCGScorer() 33 | { 34 | super(); 35 | idealGains = new HashMap<>(); 36 | } 37 | public NDCGScorer(int k) 38 | { 39 | super(k); 40 | idealGains = new HashMap<>(); 41 | } 42 | public MetricScorer copy() 43 | { 44 | return new NDCGScorer(); 45 | } 46 | public void loadExternalRelevanceJudgment(String qrelFile) 47 | { 48 | //Queries with external relevance judgment will have their cached ideal gain value overridden 49 | try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(qrelFile)))) 50 | { 51 | String content = ""; 52 | String lastQID = ""; 53 | List rel = new ArrayList(); 54 | int nQueries = 0; 55 | while((content = in.readLine()) != null) 56 | { 57 | content = content.trim(); 58 | if(content.length() == 0) 59 | continue; 60 | String[] s = content.split(" "); 61 | String qid = s[0].trim(); 62 | //String docid = s[2].trim(); 63 | int label = (int) Math.rint(Double.parseDouble(s[3].trim())); 64 | if(lastQID.compareTo("")!=0 && lastQID.compareTo(qid)!=0) 65 | { 66 | int size = (rel.size() > k) ? k : rel.size(); 67 | int[] r = new int[rel.size()]; 68 | for(int i=0;i 0) 79 | { 80 | int size = (rel.size() > k) ? k : rel.size(); 81 | int[] r = new int[rel.size()]; 82 | for(int i=0;i rl.size() || k <= 0) 106 | size = rl.size(); 107 | 108 | int[] rel = getRelevanceLabels(rl); 109 | 110 | double ideal = 0; 111 | Double d = idealGains.get(rl.getID()); 112 | if(d != null) 113 | ideal = d; 114 | else 115 | { 116 | ideal = getIdealDCG(rel, size); 117 | idealGains.put(rl.getID(), ideal); 118 | } 119 | 120 | if(ideal <= 0.0)//I mean precisely "=" 121 | return 0.0; 122 | 123 | return getDCG(rel, size)/ideal; 124 | } 125 | public double[][] swapChange(RankList rl) 126 | { 127 | int size = (rl.size() > k) ? k : rl.size(); 128 | //compute the ideal ndcg 129 | int[] rel = getRelevanceLabels(rl); 130 | double ideal = 0; 131 | Double d = idealGains.get(rl.getID()); 132 | if(d != null) 133 | ideal = d; 134 | else 135 | { 136 | ideal = getIdealDCG(rel, size); 137 | //idealGains.put(rl.getID(), ideal);//DO *NOT* do caching here. It's not thread-safe. 138 | } 139 | 140 | double[][] changes = new double[rl.size()][]; 141 | for(int i=0;i 0) 150 | changes[j][i] = changes[i][j] = (discount(i) - discount(j)) * (gain(rel[i]) - gain(rel[j])) / ideal; 151 | 152 | return changes; 153 | } 154 | public String name() 155 | { 156 | return "NDCG@"+k; 157 | } 158 | 159 | private double getIdealDCG(int[] rel, int topK) 160 | { 161 | int[] idx = Sorter.sort(rel, false); 162 | double dcg = 0; 163 | for(int i=0;i trees = null; 30 | protected List weights = null; 31 | protected int[] features = null; 32 | 33 | public Ensemble() 34 | { 35 | trees = new ArrayList(); 36 | weights = new ArrayList(); 37 | } 38 | public Ensemble(Ensemble e) 39 | { 40 | trees = new ArrayList(); 41 | weights = new ArrayList(); 42 | trees.addAll(e.trees); 43 | weights.addAll(e.weights); 44 | } 45 | public Ensemble(String xmlRep) 46 | { 47 | try { 48 | trees = new ArrayList(); 49 | weights = new ArrayList(); 50 | DocumentBuilderFactory dbFactory = DocumentBuilderFactory.newInstance(); 51 | DocumentBuilder dBuilder = dbFactory.newDocumentBuilder(); 52 | byte[] xmlDATA = xmlRep.getBytes(); 53 | ByteArrayInputStream in = new ByteArrayInputStream(xmlDATA); 54 | Document doc = dBuilder.parse(in); 55 | NodeList nl = doc.getElementsByTagName("tree"); 56 | HashMap fids = new HashMap(); 57 | for(int i=0;i" + "\n"; 126 | strRep += trees.get(i).toString("\t\t"); 127 | strRep += "\t" + "\n"; 128 | } 129 | strRep += "" + "\n"; 130 | return strRep; 131 | } 132 | public int[] getFeatures() 133 | { 134 | return features; 135 | } 136 | 137 | /** 138 | * Each input node @n corersponds to a tag in the model file. 139 | * @param n 140 | * @return 141 | */ 142 | private Split create(Node n, HashMap fids) 143 | { 144 | Split s = null; 145 | if(n.getFirstChild().getNodeName().compareToIgnoreCase("feature") == 0)//this is a split 146 | { 147 | NodeList nl = n.getChildNodes(); 148 | int fid = Integer.parseInt(nl.item(0).getFirstChild().getNodeValue().trim());// 149 | fids.put(fid, 0); 150 | float threshold = Float.parseFloat(nl.item(1).getFirstChild().getNodeValue().trim());// 151 | s = new Split(fid, threshold, 0); 152 | s.setLeft(create(nl.item(2), fids)); 153 | s.setRight(create(nl.item(3), fids)); 154 | } 155 | else//this is a stump 156 | { 157 | float output = Float.parseFloat(n.getFirstChild().getFirstChild().getNodeValue().trim()); 158 | s = new Split(); 159 | s.setOutput(output); 160 | } 161 | return s; 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/tree/Split.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.learning.tree; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | 15 | import ciir.umass.edu.learning.DataPoint; 16 | 17 | /** 18 | * 19 | * @author vdang 20 | * 21 | */ 22 | public class Split { 23 | //Key attributes of a split (tree node) 24 | private int featureID = -1; 25 | private float threshold = 0F; 26 | private double avgLabel = 0.0F; 27 | 28 | //Intermediate variables (ONLY used during learning) 29 | //*DO NOT* attempt to access them once the training is done 30 | private boolean isRoot = false; 31 | private double sumLabel = 0.0; 32 | private double sqSumLabel = 0.0; 33 | private Split left = null; 34 | private Split right = null; 35 | private double deviance = 0F;//mean squared error "S" 36 | private int[][] sortedSampleIDs = null; 37 | public int[] samples = null; 38 | public FeatureHistogram hist = null; 39 | 40 | public Split() 41 | { 42 | 43 | } 44 | public Split(int featureID, float threshold, double deviance) 45 | { 46 | this.featureID = featureID; 47 | this.threshold = threshold; 48 | this.deviance = deviance; 49 | } 50 | public Split(int[][] sortedSampleIDs, double deviance, double sumLabel, double sqSumLabel) 51 | { 52 | this.sortedSampleIDs = sortedSampleIDs; 53 | this.deviance = deviance; 54 | this.sumLabel = sumLabel; 55 | this.sqSumLabel = sqSumLabel; 56 | avgLabel = sumLabel/sortedSampleIDs[0].length; 57 | } 58 | public Split(int[] samples, FeatureHistogram hist, double deviance, double sumLabel) 59 | { 60 | this.samples = samples; 61 | this.hist = hist; 62 | this.deviance = deviance; 63 | this.sumLabel = sumLabel; 64 | avgLabel = sumLabel/samples.length; 65 | } 66 | 67 | public void set(int featureID, float threshold, double deviance) 68 | { 69 | this.featureID = featureID; 70 | this.threshold = threshold; 71 | this.deviance = deviance; 72 | } 73 | public void setLeft(Split s) 74 | { 75 | left = s; 76 | } 77 | public void setRight(Split s) 78 | { 79 | right = s; 80 | } 81 | public void setOutput(float output) 82 | { 83 | avgLabel = output; 84 | } 85 | 86 | public Split getLeft() 87 | { 88 | return left; 89 | } 90 | public Split getRight() 91 | { 92 | return right; 93 | } 94 | public double getDeviance() 95 | { 96 | return deviance; 97 | } 98 | public double getOutput() 99 | { 100 | return avgLabel; 101 | } 102 | 103 | public List leaves() 104 | { 105 | List list = new ArrayList(); 106 | leaves(list); 107 | return list; 108 | } 109 | private void leaves(List leaves) 110 | { 111 | if(featureID == -1) 112 | leaves.add(this); 113 | else 114 | { 115 | left.leaves(leaves); 116 | right.leaves(leaves); 117 | } 118 | } 119 | 120 | public double eval(DataPoint dp) 121 | { 122 | Split n = this; 123 | while(n.featureID != -1) 124 | { 125 | if(dp.getFeatureValue(n.featureID) <= n.threshold) 126 | n = n.left; 127 | else 128 | n = n.right; 129 | } 130 | return n.avgLabel; 131 | } 132 | 133 | public String toString() 134 | { 135 | return toString(""); 136 | } 137 | public String toString(String indent) 138 | { 139 | String strOutput = indent + "" + "\n"; 140 | strOutput += getString(indent + "\t"); 141 | strOutput += indent + "" + "\n"; 142 | return strOutput; 143 | } 144 | public String getString(String indent) 145 | { 146 | String strOutput = ""; 147 | if(featureID == -1) 148 | { 149 | strOutput += indent + " " + avgLabel + " " + "\n"; 150 | } 151 | else 152 | { 153 | strOutput += indent + " " + featureID + " " + "\n"; 154 | strOutput += indent + " " + threshold + " " + "\n"; 155 | strOutput += indent + "" + "\n"; 156 | strOutput += left.getString(indent + "\t"); 157 | strOutput += indent + "" + "\n"; 158 | strOutput += indent + "" + "\n"; 159 | strOutput += right.getString(indent + "\t"); 160 | strOutput += indent + "" + "\n"; 161 | } 162 | return strOutput; 163 | } 164 | 165 | //Internal functions(ONLY used during learning) 166 | //*DO NOT* attempt to call them once the training is done 167 | public boolean split(double[] trainingLabels, int minLeafSupport) 168 | { 169 | return hist.findBestSplit(this, trainingLabels, minLeafSupport); 170 | } 171 | public int[] getSamples() 172 | { 173 | if(sortedSampleIDs != null) 174 | return sortedSampleIDs[0]; 175 | return samples; 176 | } 177 | public int[][] getSampleSortedIndex() 178 | { 179 | return sortedSampleIDs; 180 | } 181 | public double getSumLabel() 182 | { 183 | return sumLabel; 184 | } 185 | public double getSqSumLabel() 186 | { 187 | return sqSumLabel; 188 | } 189 | public void clearSamples() 190 | { 191 | sortedSampleIDs = null; 192 | samples = null; 193 | hist = null; 194 | } 195 | public void setRoot(boolean isRoot) 196 | { 197 | this.isRoot = isRoot; 198 | } 199 | public boolean isRoot() 200 | { 201 | return isRoot; 202 | } 203 | } -------------------------------------------------------------------------------- /src/ciir/umass/edu/utilities/MergeSorter.java: -------------------------------------------------------------------------------- 1 | /*=============================================================================== 2 | * Copyright (c) 2010-2012 University of Massachusetts. All Rights Reserved. 3 | * 4 | * Use of the RankLib package is subject to the terms of the software license set 5 | * forth in the LICENSE file included with this software, and also available at 6 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 7 | *=============================================================================== 8 | */ 9 | 10 | package ciir.umass.edu.utilities; 11 | 12 | import java.util.Random; 13 | 14 | /** 15 | * 16 | * @author vdang 17 | * 18 | */ 19 | public class MergeSorter { 20 | 21 | public static void main(String[] args) 22 | { 23 | float[][] f = new float[1000][]; 24 | for(int r=0;r= list[begin+i-1]) || (!asc && list[begin+i] <= list[begin+i-1]))) i++; 67 | if(i == idx.length) 68 | { 69 | System.arraycopy(idx, start, tmp, k, i-start); 70 | k = i; 71 | } 72 | else 73 | { 74 | j=i+1; 75 | while(j < idx.length && ((asc && list[begin+j] >= list[begin+j-1]) || (!asc && list[begin+j] <= list[begin+j-1]))) j++; 76 | merge(list, idx, start, i-1, i, j-1, tmp, k, asc); 77 | i = j+1; 78 | k=j; 79 | } 80 | ph[p++] = k; 81 | }while(k < idx.length); 82 | System.arraycopy(tmp, 0, idx, 0, idx.length); 83 | 84 | //subsequent iterations 85 | while(p > 2) 86 | { 87 | if(p % 2 == 0) 88 | ph[p++] = idx.length; 89 | k=0; 90 | int np = 1; 91 | for(int w=0;w= list[idx[j]]) 119 | tmp[k++] = idx[i++]; 120 | else 121 | tmp[k++] = idx[j++]; 122 | } 123 | } 124 | while(i <= e1) 125 | tmp[k++] = idx[i++]; 126 | while(j <= e2) 127 | tmp[k++] = idx[j++]; 128 | } 129 | 130 | public static int[] sort(double[] list, boolean asc) 131 | { 132 | return sort(list, 0, list.length-1, asc); 133 | } 134 | public static int[] sort(double[] list, int begin, int end, boolean asc) 135 | { 136 | int len = end - begin + 1; 137 | int[] idx = new int[len]; 138 | int[] tmp = new int[len]; 139 | for(int i=begin;i<=end;i++) 140 | idx[i-begin] = i; 141 | 142 | //identify natural runs and merge them (first iteration) 143 | int i=1; 144 | int j=0; 145 | int k=0; 146 | int start= 0; 147 | int[] ph = new int[len/2+3]; 148 | ph[0] = 0; 149 | int p=1; 150 | do { 151 | start = i-1; 152 | while(i < idx.length && ((asc && list[begin+i] >= list[begin+i-1]) || (!asc && list[begin+i] <= list[begin+i-1]))) i++; 153 | if(i == idx.length) 154 | { 155 | System.arraycopy(idx, start, tmp, k, i-start); 156 | k = i; 157 | } 158 | else 159 | { 160 | j=i+1; 161 | while(j < idx.length && ((asc && list[begin+j] >= list[begin+j-1]) || (!asc && list[begin+j] <= list[begin+j-1]))) j++; 162 | merge(list, idx, start, i-1, i, j-1, tmp, k, asc); 163 | i = j+1; 164 | k=j; 165 | } 166 | ph[p++] = k; 167 | }while(k < idx.length); 168 | System.arraycopy(tmp, 0, idx, 0, idx.length); 169 | 170 | //subsequent iterations 171 | while(p > 2) 172 | { 173 | if(p % 2 == 0) 174 | ph[p++] = idx.length; 175 | k=0; 176 | int np = 1; 177 | for(int w=0;w= list[idx[j]]) 205 | tmp[k++] = idx[i++]; 206 | else 207 | tmp[k++] = idx[j++]; 208 | } 209 | } 210 | while(i <= e1) 211 | tmp[k++] = idx[i++]; 212 | while(j <= e2) 213 | tmp[k++] = idx[j++]; 214 | } 215 | } 216 | -------------------------------------------------------------------------------- /src/ciir/umass/edu/learning/Ranker.java: -------------------------------------------------------------------------------- 1 | 2 | /*=============================================================================== 3 | * Copyright (c) 2010-2015 University of Massachusetts. All Rights Reserved. 4 | * 5 | * Use of the RankLib package is subject to the terms of the software license set 6 | * forth in the LICENSE file included with this software, and also available at 7 | * http://people.cs.umass.edu/~vdang/ranklib_license.html 8 | *=============================================================================== 9 | */ 10 | 11 | package ciir.umass.edu.learning; 12 | 13 | import ciir.umass.edu.metric.MetricScorer; 14 | import ciir.umass.edu.utilities.FileUtils; 15 | import ciir.umass.edu.utilities.MergeSorter; 16 | 17 | import java.text.DateFormat; 18 | import java.text.SimpleDateFormat; 19 | import java.util.ArrayList; 20 | import java.util.Date; 21 | import java.util.List; 22 | 23 | import java.util.Set; 24 | 25 | //- Some Java 7 file utilities for creating directories 26 | import java.nio.file.Files; 27 | import java.nio.file.Path; 28 | import java.nio.file.Paths; 29 | import java.nio.file.attribute.FileAttribute; 30 | import java.nio.file.attribute.PosixFilePermission; 31 | import java.nio.file.attribute.PosixFilePermissions; 32 | 33 | 34 | 35 | /** 36 | * @author vdang 37 | * 38 | * This class implements the generic Ranker interface. Each ranking algorithm implemented has to extend this class. 39 | */ 40 | public abstract class Ranker { 41 | public static boolean verbose = true; 42 | 43 | protected List samples = new ArrayList();//training samples 44 | protected int[] features = null; 45 | protected MetricScorer scorer = null; 46 | protected double scoreOnTrainingData = 0.0; 47 | protected double bestScoreOnValidationData = 0.0; 48 | 49 | protected List validationSamples = null; 50 | 51 | protected Ranker() 52 | { 53 | 54 | } 55 | protected Ranker(List samples, int[] features, MetricScorer scorer) 56 | { 57 | this.samples = samples; 58 | this.features = features; 59 | this.scorer = scorer; 60 | } 61 | 62 | //Utility functions 63 | public void setTrainingSet(List samples) 64 | { 65 | this.samples = samples; 66 | 67 | } 68 | public void setFeatures(int[] features) 69 | { 70 | this.features = features; 71 | } 72 | public void setValidationSet(List samples) 73 | { 74 | this.validationSamples = samples; 75 | } 76 | public void setMetricScorer(MetricScorer scorer) 77 | { 78 | this.scorer = scorer; 79 | } 80 | 81 | public double getScoreOnTrainingData() 82 | { 83 | return scoreOnTrainingData; 84 | } 85 | public double getScoreOnValidationData() 86 | { 87 | return bestScoreOnValidationData; 88 | } 89 | 90 | public int[] getFeatures() 91 | { 92 | return features; 93 | } 94 | 95 | public RankList rank(RankList rl) 96 | { 97 | double[] scores = new double[rl.size()]; 98 | for(int i=0;i rank(List l) 105 | { 106 | List ll = new ArrayList(); 107 | for(int i=0;i perms = PosixFilePermissions.fromString ("rwxr-xr-x"); 122 | FileAttribute> attr = PosixFilePermissions.asFileAttribute (perms); 123 | Path outputDir = Files.createDirectory (parentPath, attr); 124 | } 125 | catch (Exception e) { 126 | System.out.println ("Error creating kcv model file directory " + modelFile); 127 | } 128 | } 129 | 130 | FileUtils.write(modelFile, "ASCII", model()); 131 | } 132 | 133 | protected void PRINT(String msg) 134 | { 135 | if(verbose) 136 | System.out.print(msg); 137 | } 138 | 139 | protected void PRINTLN(String msg) 140 | { 141 | if(verbose) 142 | System.out.println(msg); 143 | } 144 | 145 | protected void PRINT(int[] len, String[] msgs) 146 | { 147 | if(verbose) 148 | { 149 | for(int i=0;i len[i]) 153 | msg = msg.substring(0, len[i]); 154 | else 155 | while(msg.length() < len[i]) 156 | msg += " "; 157 | System.out.print(msg + " | "); 158 | } 159 | } 160 | } 161 | protected void PRINTLN(int[] len, String[] msgs) 162 | { 163 | PRINT(len, msgs); 164 | PRINTLN(""); 165 | } 166 | protected void PRINTTIME() 167 | { 168 | DateFormat dateFormat = new SimpleDateFormat("MM/dd HH:mm:ss"); 169 | Date date = new Date(); 170 | System.out.println(dateFormat.format(date)); 171 | } 172 | protected void PRINT_MEMORY_USAGE() 173 | { 174 | System.out.println("***** " + Runtime.getRuntime().freeMemory() + " / " + Runtime.getRuntime().maxMemory()); 175 | } 176 | 177 | protected void copy(double[] source, double[] target) 178 | { 179 | for(int j=0;j= MAX_FEATURE) 90 | { 91 | while(f >= MAX_FEATURE) 92 | MAX_FEATURE += FEATURE_INCREASE; 93 | float[] tmp = new float [MAX_FEATURE]; 94 | System.arraycopy(fVals, 0, tmp, 0, fVals.length); 95 | Arrays.fill(tmp, fVals.length, MAX_FEATURE, UNKNOWN); 96 | fVals = tmp; 97 | } 98 | fVals[f] = Float.parseFloat(val); 99 | 100 | if(f > featureCount)//#feature will be the max_id observed 101 | featureCount = f; 102 | 103 | if(f > lastFeature)//note that lastFeature is the max_id observed for this current data point, whereas featureCount is the max_id observed on the entire dataset 104 | lastFeature = f; 105 | } 106 | //shrink fVals 107 | float[] tmp = new float[lastFeature+1]; 108 | System.arraycopy(fVals, 0, tmp, 0, lastFeature+1); 109 | fVals = tmp; 110 | } 111 | catch(Exception ex) 112 | { 113 | throw RankLibError.create("Error in DataPoint::parse()", ex); 114 | } 115 | return fVals; 116 | } 117 | 118 | /** 119 | * Get the value of the feature with the given feature ID 120 | * @param fid 121 | * @return 122 | */ 123 | public abstract float getFeatureValue(int fid); 124 | 125 | /** 126 | * Set the value of the feature with the given feature ID 127 | * @param fid 128 | * @param fval 129 | */ 130 | public abstract void setFeatureValue(int fid, float fval); 131 | 132 | /** 133 | * Sets the value of all features with the provided dense array of feature values 134 | */ 135 | public abstract void setFeatureVector(float[] dfVals); 136 | 137 | /** 138 | * Gets the value of all features as a dense array of feature values. 139 | */ 140 | public abstract float[] getFeatureVector(); 141 | 142 | /** 143 | * Default constructor. No-op. 144 | */ 145 | protected DataPoint() {}; 146 | 147 | /** 148 | * The input must have the form: 149 | * @param text 150 | */ 151 | protected DataPoint(String text) 152 | { 153 | float[] fVals = parse(text); 154 | setFeatureVector(fVals); 155 | } 156 | 157 | public String getID() 158 | { 159 | return id; 160 | } 161 | public void setID(String id) 162 | { 163 | this.id = id; 164 | } 165 | public float getLabel() 166 | { 167 | return label; 168 | } 169 | public void setLabel(float label) 170 | { 171 | this.label = label; 172 | } 173 | public String getDescription() 174 | { 175 | return description; 176 | } 177 | public void setDescription(String description) { 178 | assert(description.contains("#")); 179 | this.description = description; 180 | } 181 | public void setCached(double c) 182 | { 183 | cached = c; 184 | } 185 | public double getCached() 186 | { 187 | return cached; 188 | 189 | } 190 | public void resetCached() 191 | { 192 | cached = -100000000.0f;; 193 | } 194 | 195 | public String toString() 196 | { 197 | float[] fVals = getFeatureVector(); 198 | String output = ((int)label) + " " + "qid:" + id + " "; 199 | for(int i=1;i outputs = null; 29 | protected double delta_i = 0.0; 30 | protected double[] deltas_j = null; 31 | 32 | protected List inLinks = null; 33 | protected List outLinks = null; 34 | 35 | public Neuron() 36 | { 37 | output = 0.0; 38 | inLinks = new ArrayList(); 39 | outLinks = new ArrayList(); 40 | 41 | outputs = new ArrayList(); 42 | delta_i = 0.0; 43 | } 44 | public double getOutput() 45 | { 46 | return output; 47 | } 48 | public double getOutput(int k) 49 | { 50 | return outputs.get(k); 51 | } 52 | public List getInLinks() 53 | { 54 | return inLinks; 55 | } 56 | public List getOutLinks() 57 | { 58 | return outLinks; 59 | } 60 | public void setOutput(double output) 61 | { 62 | this.output = output; 63 | } 64 | public void addOutput(double output) 65 | { 66 | outputs.add(output); 67 | } 68 | public void computeOutput() 69 | { 70 | Synapse s = null; 71 | double wsum = 0.0; 72 | for(int j=0;j 1 ==> each bag will contain an ensemble of gradient boosted trees. 32 | public static int nTreeLeaves = 100; 33 | public static float learningRate = 0.1F;//or shrinkage. *ONLY* matters if nTrees > 1. 34 | public static int nThreshold = 256; 35 | public static int minLeafSupport = 1; 36 | 37 | //Variables 38 | protected Ensemble[] ensembles = null;//bag of ensembles, each can be a single tree or an ensemble of gradient boosted trees 39 | 40 | public RFRanker() 41 | { 42 | } 43 | public RFRanker(List samples, int[] features, MetricScorer scorer) 44 | { 45 | super(samples, features, scorer); 46 | } 47 | 48 | public void init() 49 | { 50 | PRINT("Initializing... "); 51 | ensembles = new Ensemble[nBag]; 52 | //initialize parameters for the tree(s) built in each bag 53 | LambdaMART.nTrees = nTrees; 54 | LambdaMART.nTreeLeaves = nTreeLeaves; 55 | LambdaMART.learningRate = learningRate; 56 | LambdaMART.nThreshold = nThreshold; 57 | LambdaMART.minLeafSupport = minLeafSupport; 58 | LambdaMART.nRoundToStopEarly = -1;//no early-stopping since we're doing bagging 59 | //turn on feature sampling 60 | FeatureHistogram.samplingRate = featureSamplingRate; 61 | PRINTLN("[Done]"); 62 | } 63 | public void learn() 64 | { 65 | RankerFactory rf = new RankerFactory(); 66 | PRINTLN("------------------------------------"); 67 | PRINTLN("Training starts..."); 68 | PRINTLN("------------------------------------"); 69 | PRINTLN(new int[]{9, 9, 11}, new String[]{"bag", scorer.name()+"-B", scorer.name()+"-OOB"}); 70 | PRINTLN("------------------------------------"); 71 | //start the bagging process 72 | for(int i=0;i bag = sp.doSampling(samples, subSamplingRate, true); 79 | //"out-of-bag" samples 80 | //List outOfBag = sp.getRemains(); 81 | LambdaMART r = (LambdaMART)rf.createRanker(rType, bag, features, scorer); 82 | //r.setValidationSet(outOfBag); 83 | 84 | boolean tmp = Ranker.verbose; 85 | Ranker.verbose = false;//turn of the progress messages from training this ranker 86 | r.init(); 87 | r.learn(); 88 | Ranker.verbose = tmp; 89 | //PRINTLN(new int[]{9, 9, 11}, new String[]{"b["+(i+1)+"]", SimpleMath.round(r.getScoreOnTrainingData(), 4)+"", SimpleMath.round(r.getScoreOnValidationData(), 4)+""}); 90 | PRINTLN(new int[]{9, 9}, new String[]{"b["+(i+1)+"]", SimpleMath.round(r.getScoreOnTrainingData(), 4)+""}); 91 | ensembles[i] = r.getEnsemble(); 92 | } 93 | //Finishing up 94 | scoreOnTrainingData = scorer.score(rank(samples)); 95 | PRINTLN("------------------------------------"); 96 | PRINTLN("Finished sucessfully."); 97 | PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4)); 98 | if(validationSamples != null) 99 | { 100 | bestScoreOnValidationData = scorer.score(rank(validationSamples)); 101 | PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4)); 102 | } 103 | PRINTLN("------------------------------------"); 104 | } 105 | public double eval(DataPoint dp) 106 | { 107 | double s = 0; 108 | for(int i=0;i ens = new ArrayList(); 145 | while((content = in.readLine()) != null) 146 | { 147 | content = content.trim(); 148 | if(content.length() == 0) 149 | continue; 150 | if(content.indexOf("##")==0) 151 | continue; 152 | //actual model component 153 | model += content; 154 | if(content.indexOf("") != -1) 155 | { 156 | //load the ensemble 157 | ens.add(new Ensemble(model)); 158 | model = ""; 159 | } 160 | } 161 | in.close(); 162 | HashSet uniqueFeatures = new HashSet(); 163 | ensembles = new Ensemble[ens.size()]; 164 | for(int i=0;i samples, int[] features, MetricScorer scorer) 33 | { 34 | super(samples, features, scorer); 35 | } 36 | public void init() 37 | { 38 | PRINTLN("Initializing... [Done]"); 39 | } 40 | public void learn() 41 | { 42 | PRINTLN("--------------------------------"); 43 | PRINTLN("Training starts..."); 44 | PRINTLN("--------------------------------"); 45 | PRINT("Learning the least square model... "); 46 | 47 | //closed form solution: beta = ((xTx - lambda*I)^(-1)) * (xTy) 48 | //where x is an n-by-f matrix (n=#data-points, f=#features), y is an n-element vector of relevance labels 49 | /*int nSample = 0; 50 | for(int i=0;i keys = kvp.keys(); 150 | List values = kvp.values(); 151 | weight = new double[keys.size()]; 152 | features = new int[keys.size()-1];//weight = 153 | int idx = 0; 154 | for(int i=0;i 0) 158 | { 159 | features[idx] = fid; 160 | weight[idx] = Double.parseDouble(values.get(i)); 161 | idx++; 162 | } 163 | else 164 | weight[weight.length-1] = Double.parseDouble(values.get(i)); 165 | } 166 | } 167 | catch(Exception ex) 168 | { 169 | throw RankLibError.create("Error in LinearRegRank::load(): ", ex); 170 | } 171 | } 172 | public void printParameters() 173 | { 174 | PRINTLN("L2-norm regularization: lambda = " + lambda); 175 | } 176 | public String name() 177 | { 178 | return "Linear Regression"; 179 | } 180 | /** 181 | * Solve a system of linear equations Ax=B, in which A has to be a square matrix with the same length as B 182 | * @param A 183 | * @param B 184 | * @return x 185 | */ 186 | protected double[] solve(double[][] A, double[] B) 187 | { 188 | if(A.length == 0 || B.length == 0) 189 | { 190 | System.out.println("Error: some of the input arrays is empty."); 191 | System.exit(1); 192 | } 193 | if(A[0].length == 0) 194 | { 195 | System.out.println("Error: some of the input arrays is empty."); 196 | System.exit(1); 197 | } 198 | if(A.length != B.length) 199 | { 200 | System.out.println("Error: Solving Ax=B: A and B have different dimension."); 201 | System.exit(1); 202 | } 203 | 204 | //init 205 | double[][] a = new double[A.length][]; 206 | double[] b = new double[B.length]; 207 | System.arraycopy(B, 0, b, 0, B.length); 208 | for(int i=0;i 0) 212 | { 213 | if(a[i].length != a[i-1].length) 214 | { 215 | System.out.println("Error: Solving Ax=B: A is NOT a square matrix."); 216 | System.exit(1); 217 | } 218 | } 219 | System.arraycopy(A[i], 0, a[i], 0, A[i].length); 220 | } 221 | //apply the gaussian elimination process to convert the matrix A to upper triangular form 222 | double pivot = 0.0; 223 | double multiplier = 0.0; 224 | for(int j=0;j=0;i--)//walk back up to the first row -- we only need to care about the right to the diagonal 242 | { 243 | double val = b[i]; 244 | for(int j=i+1;j