├── .gitignore ├── LICENSE ├── README.md ├── pom.xml └── src ├── main └── java │ └── it │ └── uniroma2 │ └── sag │ └── kelp │ ├── data │ └── examplegenerator │ │ ├── SequenceExampleGenerator.java │ │ ├── SequenceExampleGeneratorKernel.java │ │ ├── SequenceExampleGeneratorLinear.java │ │ └── SequenceExampleGeneratorTypeResolver.java │ ├── learningalgorithm │ ├── MultiEpochLearning.java │ ├── PassiveAggressive.java │ ├── budgetedAlgorithm │ │ ├── BudgetedLearningAlgorithm.java │ │ ├── RandomizedBudgetPerceptron.java │ │ └── Stoptron.java │ ├── classification │ │ ├── dcd │ │ │ ├── DCDLearningAlgorithm.java │ │ │ └── DCDLoss.java │ │ ├── hmm │ │ │ ├── SequenceClassificationKernelBasedLearningAlgorithm.java │ │ │ ├── SequenceClassificationLearningAlgorithm.java │ │ │ └── SequenceClassificationLinearLearningAlgorithm.java │ │ ├── liblinear │ │ │ ├── LibLinearLearningAlgorithm.java │ │ │ └── solver │ │ │ │ ├── COPYRIGHT │ │ │ │ ├── L2R_L2_SvcFunction.java │ │ │ │ ├── L2R_L2_SvrFunction.java │ │ │ │ ├── LibLinearFeature.java │ │ │ │ ├── LibLinearFeatureNode.java │ │ │ │ ├── Problem.java │ │ │ │ ├── Tron.java │ │ │ │ └── TronFunction.java │ │ ├── passiveaggressive │ │ │ ├── BudgetedPassiveAggressiveClassification.java │ │ │ ├── KernelizedPassiveAggressiveClassification.java │ │ │ ├── LinearPassiveAggressiveClassification.java │ │ │ └── PassiveAggressiveClassification.java │ │ ├── pegasos │ │ │ └── PegasosLearningAlgorithm.java │ │ ├── perceptron │ │ │ ├── KernelizedPerceptron.java │ │ │ ├── LinearPerceptron.java │ │ │ └── Perceptron.java │ │ ├── probabilityestimator │ │ │ └── platt │ │ │ │ ├── BinaryPlattNormalizer.java │ │ │ │ ├── MulticlassPlattNormalizer.java │ │ │ │ ├── PlattInputElement.java │ │ │ │ ├── PlattInputList.java │ │ │ │ └── PlattMethod.java │ │ └── scw │ │ │ ├── SCWType.java │ │ │ └── SoftConfidenceWeightedClassification.java │ ├── clustering │ │ └── kernelbasedkmeans │ │ │ ├── KernelBasedKMeansEngine.java │ │ │ └── KernelBasedKMeansExample.java │ └── 
regression │ │ ├── liblinear │ │ └── LibLinearRegression.java │ │ └── passiveaggressive │ │ ├── KernelizedPassiveAggressiveRegression.java │ │ ├── LinearPassiveAggressiveRegression.java │ │ └── PassiveAggressiveRegression.java │ ├── linearization │ ├── LinearizationFunction.java │ └── nystrom │ │ ├── NystromMethod.java │ │ └── NystromMethodEnsemble.java │ ├── predictionfunction │ ├── SequencePrediction.java │ ├── SequencePredictionFunction.java │ └── model │ │ └── SequenceModel.java │ └── utils │ └── evaluation │ ├── ClusteringEvaluator.java │ └── MulticlassSequenceClassificationEvaluator.java └── test ├── java └── it │ └── uniroma2 │ └── sag │ └── kelp │ ├── algorithms │ ├── binary │ │ └── liblinear │ │ │ └── LibLinearDenseVsSparseClassificationEvaluator.java │ └── incrementalTrain │ │ └── IncrementalTrainTest.java │ └── learningalgorithm │ └── classification │ └── hmm │ ├── SequenceLearningKernelTest.java │ └── SequenceLearningLinearTest.java └── resources ├── sequence_learning ├── README.txt ├── declaration_of_independence.klp.gz ├── gettysburg_address.klp.gz ├── prediction_test_kernel.txt └── prediction_test_linear.txt └── svmTest └── binary ├── binary_test.klp ├── binary_train.klp └── liblinear └── polarity_sparse_dense_repr.txt.gz /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | 3 | # Mobile Tools for Java (J2ME) 4 | .mtj.tmp/ 5 | 6 | # Package Files # 7 | *.jar 8 | *.war 9 | *.ear 10 | .settings 11 | .project 12 | .classpath 13 | target 14 | 15 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 16 | hs_err_pid* 17 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4.0.0 3 | it.uniroma2.sag.kelp 4 | kelp-additional-algorithms 5 | 2.2.4-SNAPSHOT 6 | kelp-additional-algorithms 7 | http://www.kelp-ml.org 8 | 9 | Semantic Analytics 
Group @ Uniroma2 10 | http://sag.art.uniroma2.it 11 | 12 | 13 | 14 | sfilice 15 | Simone Filice 16 | simone.filice@gmail.com 17 | http://sag.art.uniroma2.it/people/filice/ 18 | SAG group @University of Roma Tor Vergata 19 | http://sag.art.uniroma2.it/ 20 | 21 | 22 | dcroce 23 | Danilo Croce 24 | croce@info.uniroma2.it 25 | http://sag.art.uniroma2.it/people/croce/ 26 | SAG group @University of Roma Tor Vergata 27 | http://sag.art.uniroma2.it/ 28 | 29 | 30 | gcastellucci 31 | Giuseppe Castellucci 32 | castellucci.giuseppe@gmail.com 33 | http://sag.art.uniroma2.it/people/castellucci/ 34 | SAG group @University of Roma Tor Vergata 35 | http://sag.art.uniroma2.it/ 36 | 37 | 38 | 39 | https://github.com/SAG-KeLP/kelp-additional-algorithms 40 | scm:git:https://github.com/SAG-KeLP/kelp-additional-algorithms.git 41 | scm:git:https://github.com/SAG-KeLP/kelp-additional-algorithms.git 42 | HEAD 43 | 44 | 45 | 46 | kelp_repo_release 47 | Sag Libs Repository Stable 48 | http://sag.art.uniroma2.it:8081/artifactory/kelp-release/ 49 | 50 | 51 | kelp_repo_snap 52 | Sag Libs Repository Snapshots 53 | http://sag.art.uniroma2.it:8081/artifactory/kelp-snapshot/ 54 | 55 | 56 | 57 | 58 | kelp_repo_snap 59 | Sag Libs Repository Snapshots 60 | 61 | false 62 | always 63 | warn 64 | 65 | 66 | true 67 | always 68 | fail 69 | 70 | http://sag.art.uniroma2.it:8081/artifactory/kelp-snapshot/ 71 | 72 | 73 | kelp_repo_release 74 | Sag Libs Repository Stable 75 | 76 | true 77 | always 78 | warn 79 | 80 | 81 | false 82 | always 83 | fail 84 | 85 | http://sag.art.uniroma2.it:8081/artifactory/kelp-release/ 86 | 87 | 88 | 89 | 90 | it.uniroma2.sag.kelp 91 | kelp-core 92 | ${project.version} 93 | 94 | 95 | org.apache.commons 96 | commons-math3 97 | 3.2 98 | 99 | 100 | 101 | 102 | 103 | src/main/resources 104 | 105 | **/* 106 | 107 | 108 | 109 | src/test/resources 110 | 111 | **/* 112 | 113 | 114 | 115 | 116 | 117 | org.apache.maven.plugins 118 | maven-compiler-plugin 119 | 3.1 120 | 121 | 1.6 122 | 1.6 
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.data.examplegenerator;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;

import java.io.Serializable;

import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver;

/**
 * A SequenceExampleGenerator generates a copy of an input
 * Example (reflecting an item in a SequenceExample)
 * enriched with information derived from the labels assigned to the previous
 * n examples.
 * This allows the SequenceClassificationLearningAlgorithm to learn
 * from the observations that are derived from a targeted example, but also from
 * its history, in terms of labels assigned to previous examples.
 *
 * @author Danilo Croce
 *
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, include = JsonTypeInfo.As.PROPERTY, property = "sequenceExamplesGeneratorType")
@JsonTypeIdResolver(SequenceExampleGeneratorTypeResolver.class)
public interface SequenceExampleGenerator extends Serializable{

	/**
	 * At labeling time, this method allows to enrich a specific
	 * Example with the labels assigned by the classifier to the
	 * previous Examples
	 *
	 * @param sequenceExample
	 *            The targeted sequence
	 * @param sequencePath
	 *            the sequence of Label assigned from a classifier
	 *            to the SequenceExample
	 * @param offset
	 *            the offset of the targeted word in the sequence
	 * @return a copy of the Example at position offset, enriched
	 *         with history information taken from sequencePath
	 */
	public Example generateExampleWithHistory(SequenceExample sequenceExample, SequencePath sequencePath, int offset);

	/**
	 * This method allows to enrich each Example from an input
	 * SequenceExample with the labels assigned by the classifier
	 * to the previous Examples
	 *
	 * @param sequenceExample
	 *            The input sequence
	 *
	 * @return a copy of the input sequence in which every element is enriched
	 *         with the gold labels of the preceding elements
	 */
	public SequenceExample generateSequenceExampleEnrichedWithHistory(SequenceExample sequenceExample);

	/**
	 * @return the number n of elements (in the sequence) whose
	 *         labels are to be considered to enrich a targeted element
	 */
	public int getTransitionsOrder();
}
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.data.examplegenerator;

import java.io.IOException;

import com.fasterxml.jackson.annotation.JsonTypeName;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.representation.Representation;
import it.uniroma2.sag.kelp.data.representation.vector.SparseVector;

/**
 * A SequenceExampleGeneratorKernel allows to implicitly enrich a targeted
 * Example (reflecting an item in a SequenceExample) with information derived
 * from the labels assigned to the previous n examples.
 *
 * It generates an additional representation (a SparseVector, stored under
 * {@code transitionRepresentationName}) containing only features reflecting
 * the classes assigned to the previous examples in the sequence.
 *
 * This class should be used when a kernel-based learning algorithm is used:
 * while a kernel function operates on the original representation, an
 * additional kernel function should operate on the representation generated
 * by this example generator.
 *
 * @author Danilo Croce
 */
@JsonTypeName("se_gen_kb")
public class SequenceExampleGeneratorKernel implements SequenceExampleGenerator {

	private static final long serialVersionUID = -6307705638697955450L;

	/**
	 * The identifier under which the transition-history representation is
	 * stored in the enriched examples
	 */
	private String transitionRepresentationName;

	/**
	 * The number of preceding examples whose labels are encoded as history
	 */
	private int transitionsOrder;

	public SequenceExampleGeneratorKernel() {
	}

	/**
	 * @param transitionsOrder
	 *            The number of examples preceding a target example to be
	 *            considered during the manipulation process
	 * @param transitionRepresentationName
	 *            The identifier of the new representation produced in the
	 *            manipulation process
	 */
	public SequenceExampleGeneratorKernel(int transitionsOrder, String transitionRepresentationName) {
		this.transitionRepresentationName = transitionRepresentationName;
		this.transitionsOrder = transitionsOrder;
	}

	/**
	 * Enriches the element at position elementId with a representation
	 * encoding the labels assigned (in the path p) to the previous
	 * transitionsOrder elements.
	 */
	public Example generateExampleWithHistory(SequenceExample sequenceExample, SequencePath p, int elementId) {

		Example innerExample = sequenceExample.getExample(elementId);

		Example enrichedObservedExample = innerExample.duplicate();

		String transitionString = p.getHistoryBefore(elementId, transitionsOrder);

		Representation enrichedObservationRepresentation = generateManipulatedRepresentation(transitionString);
		/*
		 * Enrich the observed representation with the previous transition
		 */
		enrichedObservedExample.addRepresentation(transitionRepresentationName, enrichedObservationRepresentation);

		return enrichedObservedExample;
	}

	/**
	 * Builds a SparseVector with a single feature (weight 1.0) encoding the
	 * given transition history. An empty history produces an empty vector.
	 *
	 * NOTE(review): on a parsing failure this method logs the stack trace and
	 * returns null (callers store the null representation); kept as-is to
	 * preserve the original best-effort behavior.
	 */
	private Representation generateManipulatedRepresentation(String transitionString) {
		try {
			SparseVector newRepresentation = new SparseVector();
			if (!transitionString.trim().isEmpty()) {
				newRepresentation.setDataFromText(transitionString + ":1.0");
			}
			return newRepresentation;
		} catch (IOException e) {
			e.printStackTrace();
			return null;
		}
	}

	@Override
	public SequenceExample generateSequenceExampleEnrichedWithHistory(SequenceExample sequenceExample) {
		SequenceExample res = (SequenceExample) (sequenceExample.duplicate());

		for (int elementId = 0; elementId < res.getLenght(); elementId++) {

			Example e = res.getExample(elementId);

			// Build the gold transition history of the preceding
			// transitionsOrder elements; positions before the sequence start
			// are encoded with a synthetic "-kinit" marker.
			StringBuilder transitionString = new StringBuilder();
			for (int j = elementId - transitionsOrder; j < elementId; j++) {
				if (j < 0) {
					transitionString.append(SequenceExample.SEQDELIM).append(j).append("init");
				} else {
					Example ej = sequenceExample.getExample(j);
					transitionString.append(SequenceExample.SEQDELIM)
							.append(ej.getClassificationLabels().iterator().next());
				}
			}
			Representation newRepresentation = generateManipulatedRepresentation(transitionString.toString());
			e.getRepresentations().put(transitionRepresentationName, newRepresentation);
		}

		return res;
	}

	public String getTransitionRepresentationName() {
		return transitionRepresentationName;
	}

	@Override
	public int getTransitionsOrder() {
		return transitionsOrder;
	}

	public void setTransitionRepresentationName(String transitionRepresentationName) {
		this.transitionRepresentationName = transitionRepresentationName;
	}

	public void setTransitionsOrder(int transitionsOrder) {
		this.transitionsOrder = transitionsOrder;
	}

}
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.data.examplegenerator; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 22 | import it.uniroma2.sag.kelp.data.example.SequencePath; 23 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 24 | import it.uniroma2.sag.kelp.data.representation.Representation; 25 | import it.uniroma2.sag.kelp.data.representation.Vector; 26 | import it.uniroma2.sag.kelp.data.representation.vector.SparseVector; 27 | 28 | /** 29 | * A SequenceExampleGeneratorLinearAlg allows to explicitly 30 | * enrich a targeted Example (reflecting an item in a 31 | * SequenceExample) with information derived from the s 32 | * assigned to the previous n examples. 33 | * 34 | * 35 | * 36 | * 37 | * Given a representation used to represent an example, this class generates a 38 | * copy of an input Example enriched with additional features 39 | * reflecting the classes assigned to the previous examples in the sequence. 40 | * 41 | * 42 | * 43 | * This class should be used when the learning algorithm used within the 44 | * SequenceClassificationLearningAlgorithm implements the 45 | * LinearMethod interface. 
46 | * 47 | * @author Danilo Croce 48 | * 49 | */ 50 | @JsonTypeName("se_gen_lin") 51 | public class SequenceExampleGeneratorLinear implements SequenceExampleGenerator { 52 | 53 | /** 54 | * 55 | */ 56 | private static final long serialVersionUID = -6889374446991783357L; 57 | 58 | /** 59 | * The identifier of the representation used to represent an example in the 60 | * sequence and which will be enriched 61 | */ 62 | private String originalRepresentationName; 63 | 64 | /** 65 | * The number of examples preceding a target example to be considered during 66 | * the manipulation process 67 | */ 68 | private int transitionsOrder; 69 | 70 | /** 71 | * The weight to assign to each new feature added in the manipulation 72 | * process 73 | */ 74 | private float transitionWeight; 75 | 76 | public SequenceExampleGeneratorLinear() { 77 | 78 | } 79 | 80 | /** 81 | * @param transitionsOrder 82 | * The number of examples preceding a target example to be 83 | * considered during the manipulation process 84 | * @param originalRepresentationName 85 | * The identifier of the representation used to represent an 86 | * example in the sequence and which will be enriched 87 | * @param enrichedWithHistoryRepresentationName 88 | * The identifier of the new representation produced in the 89 | * manipulation process 90 | * @param transitionWeight 91 | * The weight to assign to each new feature added in the 92 | * manipulation process 93 | */ 94 | public SequenceExampleGeneratorLinear(int transitionsOrder, String originalRepresentationName, 95 | float transitionWeight) { 96 | this.originalRepresentationName = originalRepresentationName; 97 | this.transitionsOrder = transitionsOrder; 98 | this.transitionWeight = transitionWeight; 99 | } 100 | 101 | public Example generateExampleWithHistory(SequenceExample sequenceExample, SequencePath p, int elementId) { 102 | 103 | Example innerExample = sequenceExample.getExample(elementId); 104 | Representation observationRepresentation = 
innerExample.getRepresentation(originalRepresentationName); 105 | 106 | Example enrichedObservedExample = new SimpleExample(); 107 | 108 | if (transitionsOrder > 0) { 109 | String transitionString = p.getHistoryBefore(elementId, transitionsOrder); 110 | 111 | Representation enrichedObservationRepresentation = generateManipulatedRepresentation( 112 | observationRepresentation, transitionString); 113 | /* 114 | * Enrich the observed representation with the previous transition 115 | */ 116 | enrichedObservedExample.addRepresentation(originalRepresentationName, enrichedObservationRepresentation); 117 | } else { 118 | enrichedObservedExample.addRepresentation(originalRepresentationName, observationRepresentation); 119 | } 120 | 121 | return enrichedObservedExample; 122 | } 123 | 124 | /** 125 | * Given the representation of a targeted example and a string containing 126 | * the sequence of labels assigned to the previous examples, this method 127 | * produces a new representation with additional features reflecting the 128 | * sequence of labels 129 | * 130 | * @param representation 131 | * the name of the targeted representation 132 | * @param transitionString 133 | * the string a string containing the sequence of labels assigned 134 | * to the previous examples 135 | * @return the new enriched representation 136 | */ 137 | private Representation generateManipulatedRepresentation(Representation representation, String transitionString) { 138 | Representation newRepresentation = null; 139 | 140 | if (representation instanceof SparseVector) { 141 | try { 142 | newRepresentation = new SparseVector(); 143 | newRepresentation.setDataFromText( 144 | representation.toString().trim() + " " + transitionString + ":" + transitionWeight); 145 | return newRepresentation; 146 | } catch (Exception e1) { 147 | e1.printStackTrace(); 148 | return null; 149 | } 150 | } else { 151 | System.err.println("Warning: SequenceExampleGeneratorLinearAlg only work on SparseVector... 
now "); 152 | return null; 153 | } 154 | } 155 | 156 | public SequenceExample generateSequenceExampleEnrichedWithHistory(SequenceExample sequenceExample) { 157 | 158 | SequenceExample res = (SequenceExample) sequenceExample.duplicate(); 159 | 160 | for (int elementId = 0; elementId < res.getLenght(); elementId++) { 161 | 162 | Example e = res.getExample(elementId); 163 | 164 | Representation newRepresentation = null; 165 | Vector vector = (Vector) e.getRepresentation(originalRepresentationName); 166 | if (transitionsOrder > 0) { 167 | String transitionString = new String(); 168 | for (int j = elementId - transitionsOrder; j < elementId; j++) { 169 | if (j < 0) { 170 | transitionString += SequenceExample.SEQDELIM + j + "init"; 171 | } else { 172 | Example ej = res.getExample(j); 173 | transitionString += SequenceExample.SEQDELIM + ej.getClassificationLabels().iterator().next(); 174 | } 175 | } 176 | newRepresentation = generateManipulatedRepresentation(vector, transitionString); 177 | } else { 178 | newRepresentation = vector.copyVector(); 179 | } 180 | 181 | e.addRepresentation(originalRepresentationName, newRepresentation); 182 | } 183 | 184 | return res; 185 | } 186 | 187 | /** 188 | * @return The identifier of the representation used to represent an example 189 | * in the sequence and which will be enriched 190 | */ 191 | public String getRepresentationName() { 192 | return originalRepresentationName; 193 | } 194 | 195 | public int getTransitionsOrder() { 196 | return transitionsOrder; 197 | } 198 | 199 | /** 200 | * @return The weight to assign to each new feature added in the 201 | * manipulation process 202 | */ 203 | public float getTransitionWeight() { 204 | return transitionWeight; 205 | } 206 | 207 | /** 208 | * @param representationName 209 | * The identifier of the representation used to represent a 210 | * example in the sequence and which will be enriched 211 | */ 212 | public void setRepresentationName(String representationName) { 213 | 
this.originalRepresentationName = representationName; 214 | } 215 | 216 | } 217 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/data/examplegenerator/SequenceExampleGeneratorTypeResolver.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.data.examplegenerator;

import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.reflections.Reflections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.annotation.JsonTypeInfo.Id;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.databind.DatabindContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.jsontype.TypeIdResolver;

/**
 * It is a class implementing TypeIdResolver which will be used by the Jackson
 * library during the serialization in JSON and deserialization of
 * SequenceExampleGenerators.
 *
 * At class-loading time it scans the classpath (packages under "it") for all
 * concrete implementations of SequenceExampleGenerator and maps each one to a
 * type id: the value of its @JsonTypeName annotation when present, the simple
 * class name otherwise.
 *
 * @author Simone Filice
 */
public class SequenceExampleGeneratorTypeResolver implements TypeIdResolver {
	private static Logger logger = LoggerFactory.getLogger(SequenceExampleGeneratorTypeResolver.class);

	// Bidirectional mapping between JSON type ids and implementation classes
	private static Map<String, Class<? extends SequenceExampleGenerator>> idToClassMapping;
	private static Map<Class<? extends SequenceExampleGenerator>, String> classToIdMapping;

	static {
		Reflections reflections = new Reflections("it");
		idToClassMapping = new HashMap<String, Class<? extends SequenceExampleGenerator>>();
		classToIdMapping = new HashMap<Class<? extends SequenceExampleGenerator>, String>();
		Set<Class<? extends SequenceExampleGenerator>> classes = reflections
				.getSubTypesOf(SequenceExampleGenerator.class);
		for (Class<? extends SequenceExampleGenerator> clazz : classes) {
			// Abstract implementations cannot be instantiated during
			// deserialization, so they are not registered
			if (Modifier.isAbstract(clazz.getModifiers())) {
				continue;
			}
			String abbreviation;
			if (clazz.isAnnotationPresent(JsonTypeName.class)) {
				JsonTypeName info = clazz.getAnnotation(JsonTypeName.class);
				abbreviation = info.value();
			} else {
				abbreviation = clazz.getSimpleName();
			}
			idToClassMapping.put(abbreviation, clazz);
			classToIdMapping.put(clazz, abbreviation);
		}
		logger.debug("Label Implementations: {}", idToClassMapping);
	}

	private JavaType mBaseType;

	@Override
	public Id getMechanism() {
		return Id.CUSTOM;
	}

	@Override
	public String idFromBaseType() {
		return idFromValueAndType(null, mBaseType.getRawClass());
	}

	@Override
	public String idFromValue(Object obj) {
		return idFromValueAndType(obj, obj.getClass());
	}

	@Override
	public String idFromValueAndType(Object arg0, Class<?> arg1) {
		// BUG FIX: the value may legitimately be null (idFromBaseType passes
		// null); in that case the suggested type arg1 must be used instead of
		// dereferencing arg0, which previously caused a NullPointerException.
		if (arg0 != null) {
			return classToIdMapping.get(arg0.getClass());
		}
		return classToIdMapping.get(arg1);
	}

	@Override
	public void init(JavaType arg0) {
		mBaseType = arg0;
	}

	@Override
	public JavaType typeFromId(DatabindContext context, String arg0) {

		Class<? extends SequenceExampleGenerator> clazz = idToClassMapping.get(arg0);
		if (clazz != null) {
			JavaType type = context.constructSpecializedType(mBaseType, clazz);
			return type;
		}
		throw new IllegalStateException("cannot find mapping for '" + arg0 + "'");
	}

	@Override
	public String getDescForKnownTypeIds() {
		// Used by Jackson only for error messages; listing the registered ids
		// makes "unknown type id" failures diagnosable.
		return idToClassMapping.keySet().toString();
	}

}
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;

/**
 * It is a meta learning algorithm for online learning methods. It performs
 * multiple iterations (epochs) on the training data, delegating every pass to
 * the wrapped base algorithm.
 *
 * NOTE(review): parts of this class were lost in extraction; the learn loop
 * and the List&lt;Label&gt; generics were restored from the surrounding
 * semantics — confirm against the upstream KeLP sources.
 *
 * @author Simone Filice
 */
@JsonTypeName("multiEpoch")
public class MultiEpochLearning implements MetaLearningAlgorithm {

	/** The wrapped online learning algorithm applied at each epoch */
	private LearningAlgorithm baseAlgorithm;

	/** The number of passes over the training data */
	private int epochs;

	public MultiEpochLearning() {

	}

	/**
	 * @param epochs
	 *            the number of iterations on the training data
	 * @param baseAlgorithm
	 *            the online algorithm to be iterated
	 * @param labels
	 *            the labels the base algorithm must learn
	 */
	public MultiEpochLearning(int epochs, LearningAlgorithm baseAlgorithm, List<Label> labels) {
		this.setEpochs(epochs);
		this.setBaseAlgorithm(baseAlgorithm);
		this.setLabels(labels);
	}

	/**
	 * @param epochs
	 *            the number of iterations on the training data
	 * @param baseAlgorithm
	 *            the online algorithm to be iterated
	 */
	public MultiEpochLearning(int epochs, LearningAlgorithm baseAlgorithm) {
		this.setEpochs(epochs);
		this.setBaseAlgorithm(baseAlgorithm);
	}

	@Override
	public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) {
		this.baseAlgorithm = baseAlgorithm;
	}

	@Override
	public LearningAlgorithm getBaseAlgorithm() {
		return this.baseAlgorithm;
	}

	/**
	 * @return the number of epochs
	 */
	public int getEpochs() {
		return epochs;
	}

	/**
	 * @param epochs the number of epochs to set
	 */
	public void setEpochs(int epochs) {
		this.epochs = epochs;
	}

	@Override
	public void learn(Dataset dataset) {
		// Each epoch is a full pass of the base (online) algorithm over the
		// training data; the base algorithm accumulates its model across calls
		for (int i = 0; i < epochs; i++) {
			this.baseAlgorithm.learn(dataset);
		}
	}

	@Override
	public void setLabels(List<Label> labels) {
		this.baseAlgorithm.setLabels(labels);
	}

	@Override
	@JsonIgnore
	public List<Label> getLabels() {
		return this.baseAlgorithm.getLabels();
	}

	@Override
	public MultiEpochLearning duplicate() {
		MultiEpochLearning copy = new MultiEpochLearning();
		copy.epochs = epochs;
		copy.setBaseAlgorithm(baseAlgorithm.duplicate());
		return copy;
	}

	@Override
	public void reset() {
		this.baseAlgorithm.reset();
	}

	@Override
	public PredictionFunction getPredictionFunction() {
		return this.baseAlgorithm.getPredictionFunction();
	}

	@Override
	public void setPredictionFunction(PredictionFunction predictionFunction) {
		this.baseAlgorithm.setPredictionFunction(predictionFunction);
	}

}
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm; 17 | 18 | import java.util.Arrays; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 22 | import it.uniroma2.sag.kelp.data.example.Example; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 25 | 26 | /** 27 | * It is an online learning algorithms that implements the Passive Aggressive algorithms described in 28 | * 29 | * [Crammer, JMLR2006] K. Crammer, O. Dekel, J. Keshet and S. Shalev-Shwartz. Online passive-aggressive algorithms. 30 | * Journal of Machine Learning Research 7:551–585, 2006. 31 | * 32 | * @author Simone Filice 33 | * 34 | */ 35 | public abstract class PassiveAggressive implements OnlineLearningAlgorithm, BinaryLearningAlgorithm{ 36 | 37 | /** 38 | * It is the updating policy applied by the Passive Aggressive Algorithm when a miss-prediction occurs 39 | * 40 | * @author Simone Filice 41 | */ 42 | public enum Policy{ 43 | /** 44 | * The new prediction hypothesis after a new example \( \mathbf{x}_t\) with label \(y_t\) is observed is: 45 | * 46 | * \(argmin_{\mathbf{w}} \frac{1}{2} \left \| \mathbf{w}-\mathbf{w}_t \right \|^2\) 47 | * such that \( l(\mathbf{w};(\mathbf{x}_t,y_t))=0 \) 48 | */ 49 | HARD_PA, 50 | 51 | /** 52 | * The new prediction hypothesis after a new example \( \mathbf{x}_t\) with label \(y_t\) is observed is: 53 | * 54 | * \(argmin_{\mathbf{w}} \frac{1}{2} \left \| \mathbf{w}-\mathbf{w}_t \right \|^2 + C\xi \) 55 | * such that \( l(\mathbf{w};(\mathbf{x}_t,y_t))\leq \xi \) and \( \xi\geq 0\) 56 | */ 57 | PA_I, 58 | 59 | /** 60 | * The new prediction hypothesis after a new example \( \mathbf{x}_t\) with label \(y_t\) is observed is: 61 | * 62 | * \(argmin_{\mathbf{w}} \frac{1}{2} \left \| \mathbf{w}-\mathbf{w}_t \right \|^2 + C\xi^2 \) 63 | * such that \( l(\mathbf{w};(\mathbf{x}_t,y_t))\leq \xi \) and \( \xi\geq 0\) 64 | */ 65 | PA_II 66 | } 67 | 68 | 69 
| protected Label label; 70 | 71 | 72 | 73 | protected Policy policy = Policy.PA_II; 74 | 75 | protected float c = 1;//the aggressiveness parameter 76 | 77 | 78 | 79 | @Override 80 | public void reset() { 81 | this.getPredictionFunction().reset(); 82 | } 83 | 84 | 85 | /** 86 | * @return the updating policy 87 | */ 88 | public Policy getPolicy() { 89 | return policy; 90 | } 91 | 92 | 93 | /** 94 | * @param policy the updating policy to set 95 | */ 96 | public void setPolicy(Policy policy) { 97 | this.policy = policy; 98 | } 99 | 100 | 101 | /** 102 | * @return the aggressiveness parameter 103 | */ 104 | public float getC() { 105 | return c; 106 | } 107 | 108 | 109 | /** 110 | * @param c the aggressiveness to set 111 | */ 112 | public void setC(float c) { 113 | this.c = c; 114 | } 115 | 116 | 117 | protected float computeWeight(Example example, float lossValue, float exampleSquaredNorm, float aggressiveness) { 118 | float weight=1; 119 | 120 | switch(policy){ 121 | case HARD_PA: 122 | weight=lossValue/exampleSquaredNorm; 123 | break; 124 | case PA_I: 125 | weight=lossValue/exampleSquaredNorm; 126 | if(weight>aggressiveness){ 127 | weight=aggressiveness; 128 | } 129 | break; 130 | case PA_II: 131 | weight=lossValue/(exampleSquaredNorm+1/(2*aggressiveness)); 132 | break; 133 | } 134 | 135 | return weight; 136 | } 137 | 138 | 139 | @Override 140 | public void setLabels(List labels){ 141 | if(labels.size()!=1){ 142 | throw new IllegalArgumentException("The Passive Aggressive algorithm is a binary method which can learn a single Label"); 143 | } 144 | else{ 145 | this.label=labels.get(0); 146 | this.getPredictionFunction().setLabels(labels); 147 | } 148 | } 149 | 150 | 151 | @Override 152 | public List getLabels() { 153 | return Arrays.asList(label); 154 | } 155 | 156 | @Override 157 | public void learn(Dataset dataset){ 158 | while(dataset.hasNextExample()){ 159 | this.learn(dataset.getNextExample()); 160 | } 161 | dataset.reset(); 162 | } 163 | 164 | @Override 165 | 
public Label getLabel(){ 166 | return this.label; 167 | } 168 | 169 | @Override 170 | public void setLabel(Label label){ 171 | this.setLabels(Arrays.asList(label)); 172 | } 173 | 174 | } 175 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/budgetedAlgorithm/BudgetedLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm; 17 | 18 | 19 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.label.Label; 22 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 23 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 24 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 26 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 27 | 28 | import java.util.Arrays; 29 | import java.util.List; 30 | 31 | /** 32 | * It is binary kernel-based online learning method that binds the number of support vector to a fix number (i.e. 
the budget) 33 | * When the budget is full, a particular updating policy (that must be specified by extending classes) is adopted 34 | * 35 | * @author Simone Filice 36 | * 37 | */ 38 | public abstract class BudgetedLearningAlgorithm implements OnlineLearningAlgorithm, BinaryLearningAlgorithm, KernelMethod{ 39 | 40 | protected int budget; 41 | protected Label label; 42 | 43 | /** 44 | * Returns the budget, i.e. the maximum number of support vectors 45 | * 46 | * @return the budget 47 | */ 48 | public int getBudget() { 49 | return budget; 50 | } 51 | 52 | /** 53 | * Sets the budget, i.e. the maximum number of support vectors 54 | * 55 | * @param budget the budget to set 56 | */ 57 | public void setBudget(int budget) { 58 | this.budget = budget; 59 | } 60 | 61 | @Override 62 | public void learn(Dataset dataset){ 63 | while(dataset.hasNextExample()){ 64 | this.learn(dataset.getNextExample()); 65 | } 66 | dataset.reset(); 67 | } 68 | 69 | @Override 70 | public Prediction learn(Example example){ 71 | BinaryKernelMachineModel model = (BinaryKernelMachineModel) this.getPredictionFunction().getModel(); 72 | if(model.getSupportVectors().size() labels){ 90 | if(labels.size()!=1){ 91 | throw new IllegalArgumentException("Any budgeted learning algorithm is a binary method which can learn a single Label"); 92 | } 93 | else{ 94 | this.label=labels.get(0); 95 | this.getPredictionFunction().setLabels(labels); 96 | } 97 | } 98 | 99 | 100 | @Override 101 | public List getLabels() { 102 | return Arrays.asList(label); 103 | } 104 | 105 | @Override 106 | public Label getLabel(){ 107 | return this.label; 108 | } 109 | 110 | @Override 111 | public void setLabel(Label label){ 112 | this.setLabels(Arrays.asList(label)); 113 | } 114 | 115 | 116 | } 117 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/budgetedAlgorithm/RandomizedBudgetPerceptron.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm; 17 | 18 | import it.uniroma2.sag.kelp.data.example.Example; 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 27 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 28 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 29 | import it.uniroma2.sag.kelp.predictionfunction.model.SupportVector; 30 | 31 | import java.util.Random; 32 | 33 | import com.fasterxml.jackson.annotation.JsonIgnore; 34 | import com.fasterxml.jackson.annotation.JsonTypeName; 35 | 36 | /** 37 | * It is a variation of the Randomized Budget Perceptron proposed in 38 | * [CavallantiCOLT2006] G. Cavallanti, N. Cesa-Bianchi, C. Gentile. 
Tracking the best hyperplane with a simple budget Perceptron. In proc. of the 19-th annual conference on Computational Learning Theory. (2006) 39 | * 40 | * Until the budget is not reached the online learning updating policy is the one of the baseAlgorithm that this 41 | * meta-algorithm is exploiting. When the budget is full, a random support vector is deleted and the perceptron updating policy is 42 | * adopted 43 | * 44 | * @author Simone Filice 45 | * 46 | */ 47 | @JsonTypeName("randomizedPerceptron") 48 | public class RandomizedBudgetPerceptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm{ 49 | 50 | private static final long DEFAULT_SEED=1; 51 | private long initialSeed = DEFAULT_SEED; 52 | @JsonIgnore 53 | private Random randomGenerator; 54 | 55 | private OnlineLearningAlgorithm baseAlgorithm; 56 | 57 | public RandomizedBudgetPerceptron(){ 58 | randomGenerator = new Random(initialSeed); 59 | } 60 | 61 | public RandomizedBudgetPerceptron(int budget, OnlineLearningAlgorithm baseAlgorithm, long seed, Label label){ 62 | randomGenerator = new Random(initialSeed); 63 | this.setBudget(budget); 64 | this.setBaseAlgorithm(baseAlgorithm); 65 | this.setSeed(seed); 66 | this.setLabel(label); 67 | } 68 | 69 | /** 70 | * Sets the seed for the random generator adopted to select the support vector to delete 71 | * 72 | * @param seed the seed of the randomGenerator 73 | */ 74 | public void setSeed(long seed){ 75 | this.initialSeed = seed; 76 | this.randomGenerator.setSeed(seed); 77 | } 78 | 79 | @Override 80 | public RandomizedBudgetPerceptron duplicate() { 81 | RandomizedBudgetPerceptron copy = new RandomizedBudgetPerceptron(); 82 | copy.setBudget(budget); 83 | copy.setBaseAlgorithm(baseAlgorithm.duplicate()); 84 | copy.setSeed(initialSeed); 85 | return copy; 86 | } 87 | 88 | @Override 89 | public void reset() { 90 | this.baseAlgorithm.reset(); 91 | this.randomGenerator.setSeed(initialSeed); 92 | } 93 | 94 | @Override 95 | protected Prediction 
predictAndLearnWithFullBudget(Example example) { 96 | Prediction prediction = this.baseAlgorithm.getPredictionFunction().predict(example); 97 | 98 | if((prediction.getScore(getLabel())>0) != example.isExampleOf(getLabel())){ 99 | int svToDelete = this.randomGenerator.nextInt(budget); 100 | float weight = 1; 101 | if(!example.isExampleOf(getLabels().get(0))){ 102 | weight=-1; 103 | } 104 | SupportVector sv = new SupportVector(weight, example); 105 | 106 | ((BinaryKernelMachineModel)this.baseAlgorithm.getPredictionFunction().getModel()).setSupportVector(sv, svToDelete); 107 | } 108 | return prediction; 109 | } 110 | 111 | @Override 112 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 113 | if(baseAlgorithm instanceof OnlineLearningAlgorithm && baseAlgorithm instanceof KernelMethod && baseAlgorithm instanceof BinaryLearningAlgorithm){ 114 | this.baseAlgorithm = (OnlineLearningAlgorithm) baseAlgorithm; 115 | }else{ 116 | throw new IllegalArgumentException("a valid baseAlgorithm for the Randomized Budget Perceptron must implement OnlineLearningAlgorithm, BinaryLeaningAlgorithm and KernelMethod"); 117 | } 118 | } 119 | 120 | @Override 121 | public OnlineLearningAlgorithm getBaseAlgorithm() { 122 | return this.baseAlgorithm; 123 | } 124 | 125 | @Override 126 | public PredictionFunction getPredictionFunction() { 127 | return this.baseAlgorithm.getPredictionFunction(); 128 | } 129 | 130 | @Override 131 | public Kernel getKernel() { 132 | return ((KernelMethod)this.baseAlgorithm).getKernel(); 133 | } 134 | 135 | @Override 136 | public void setKernel(Kernel kernel) { 137 | ((KernelMethod)this.baseAlgorithm).setKernel(kernel); 138 | 139 | } 140 | 141 | @Override 142 | protected Prediction predictAndLearnWithAvailableBudget(Example example) { 143 | return this.baseAlgorithm.learn(example); 144 | } 145 | 146 | @Override 147 | public void setPredictionFunction(PredictionFunction predictionFunction) { 148 | 
this.baseAlgorithm.setPredictionFunction(predictionFunction); 149 | } 150 | 151 | } 152 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/budgetedAlgorithm/Stoptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm; 17 | 18 | import it.uniroma2.sag.kelp.data.example.Example; 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 27 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 28 | 29 | import com.fasterxml.jackson.annotation.JsonTypeName; 30 | 31 | /** 32 | * It is a variation of the Stoptron proposed in 33 | * [OrabonaICML2008] Francesco Orabona, Joseph Keshet, and Barbara Caputo. 
The projectron: a bounded kernel-based perceptron. In Int. Conf. on Machine Learning (2008) 34 | * 35 | * Until the budget is not reached the online learning updating policy is the one of the baseAlgorithm that this 36 | * meta-algorithm is exploiting. When the budget is full, the learning process ends 37 | * 38 | * @author Simone Filice 39 | * 40 | */ 41 | @JsonTypeName("stoptron") 42 | public class Stoptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm{ 43 | 44 | private OnlineLearningAlgorithm baseAlgorithm; 45 | 46 | public Stoptron(){ 47 | 48 | } 49 | 50 | public Stoptron(int budget, OnlineLearningAlgorithm baseAlgorithm, Label label){ 51 | this.setBudget(budget); 52 | this.setBaseAlgorithm(baseAlgorithm); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Stoptron duplicate() { 58 | Stoptron copy = new Stoptron(); 59 | copy.setBudget(budget); 60 | copy.setBaseAlgorithm(baseAlgorithm.duplicate()); 61 | return copy; 62 | } 63 | 64 | @Override 65 | public void reset() { 66 | this.baseAlgorithm.reset(); 67 | } 68 | 69 | @Override 70 | protected Prediction predictAndLearnWithFullBudget(Example example) { 71 | return this.baseAlgorithm.getPredictionFunction().predict(example); 72 | } 73 | 74 | @Override 75 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 76 | if(baseAlgorithm instanceof OnlineLearningAlgorithm && baseAlgorithm instanceof KernelMethod && baseAlgorithm instanceof BinaryLearningAlgorithm){ 77 | this.baseAlgorithm = (OnlineLearningAlgorithm) baseAlgorithm; 78 | }else{ 79 | throw new IllegalArgumentException("a valid baseAlgorithm for the Stoptron must implement OnlineLearningAlgorithm, BinaryLeaningAlgorithm and KernelMethod"); 80 | } 81 | } 82 | 83 | @Override 84 | public OnlineLearningAlgorithm getBaseAlgorithm() { 85 | return this.baseAlgorithm; 86 | } 87 | 88 | @Override 89 | public PredictionFunction getPredictionFunction() { 90 | return this.baseAlgorithm.getPredictionFunction(); 91 | } 92 | 93 
| @Override 94 | public Kernel getKernel() { 95 | return ((KernelMethod)this.baseAlgorithm).getKernel(); 96 | } 97 | 98 | @Override 99 | public void setKernel(Kernel kernel) { 100 | ((KernelMethod)this.baseAlgorithm).setKernel(kernel); 101 | 102 | } 103 | 104 | @Override 105 | protected Prediction predictAndLearnWithAvailableBudget(Example example) { 106 | return this.baseAlgorithm.learn(example); 107 | } 108 | 109 | @Override 110 | public void setPredictionFunction(PredictionFunction predictionFunction) { 111 | this.baseAlgorithm.setPredictionFunction(predictionFunction); 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/dcd/DCDLoss.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.dcd; 2 | 3 | public enum DCDLoss { 4 | L1, L2 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceClassificationKernelBasedLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGeneratorKernel; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.kernel.cache.KernelCache; 22 | import it.uniroma2.sag.kelp.kernel.standard.LinearKernelCombination; 23 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 24 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 26 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 27 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 28 | 29 | /** 30 | * /** This class implements a sequential labeling paradigm. 31 | * Given sequences of items (each implemented as an Example and 32 | * associated to one Label) this class allow to apply a generic 33 | * LearningAlgorithm to use the "history" of each item in the 34 | * sequence in order to improve the classification quality. In other words, the 35 | * classification of each example does not depend only its representation, but 36 | * it also depend on its "history", in terms of the classed assigned to the 37 | * preceding examples. 38 | * This class should be used when a kernel-based learning algorithm is 39 | * used, thus directly operating in the implicit space underlying a kernel 40 | * function. 41 | * 42 | * 43 | * This algorithms was inspired by the work of: 44 | * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector 45 | * machines. In Proceedings of the Twentieth International Conference on Machine 46 | * Learning, 2003. 
/**
 * This class implements a sequential labeling paradigm.
 * Given sequences of items (each implemented as an Example and associated to one
 * Label) this class allows to apply a generic LearningAlgorithm to use the "history"
 * of each item in the sequence in order to improve the classification quality. In other
 * words, the classification of each example does not depend only on its representation, but
 * it also depends on its "history", in terms of the classes assigned to the preceding examples.
 * This class should be used when a kernel-based learning algorithm is used, thus directly
 * operating in the implicit space underlying a kernel function.
 *
 * This algorithm was inspired by the work of:
 * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector
 * machines. In Proceedings of the Twentieth International Conference on Machine
 * Learning, 2003.
 *
 * @author Danilo Croce
 *
 */
public class SequenceClassificationKernelBasedLearningAlgorithm extends SequenceClassificationLearningAlgorithm
		implements KernelMethod {

	// Name of the artificial representation that encodes transition-based features.
	private final static String TRANSITION_REPRESENTATION_NAME = "__trans_rep__";

	// Combination of the input kernel and the transition-based linear kernel.
	private LinearKernelCombination sequenceBasedKernel;

	public SequenceClassificationKernelBasedLearningAlgorithm() {

	}

	/**
	 * @param baseLearningAlgorithm
	 *            the learning algorithm devoted to the acquisition of a model
	 *            after that each example has been enriched with its "history"
	 * @param transitionsOrder
	 *            given a targeted item in the sequence, this variable
	 *            determines the number of previous examples considered in the
	 *            learning/labeling process.
	 * @param transitionWeight
	 *            the importance of the transition-based features during the
	 *            learning process. Higher values will assign more importance
	 *            to the transitions.
	 * @throws Exception
	 *             The input baseLearningAlgorithm is not a
	 *             kernel-based method
	 */
	public SequenceClassificationKernelBasedLearningAlgorithm(BinaryLearningAlgorithm baseLearningAlgorithm,
			int transitionsOrder, float transitionWeight) throws Exception {

		if (!(baseLearningAlgorithm instanceof KernelMethod)) {
			throw new Exception("ERROR: the input baseLearningAlgorithm is not a kernel-based method!");
		}

		Kernel inputKernel = ((KernelMethod) baseLearningAlgorithm).getKernel();

		// The sequence kernel is a weighted combination of the original input kernel
		// and a linear kernel over the transition-based representation.
		sequenceBasedKernel = new LinearKernelCombination();
		sequenceBasedKernel.addKernel(1, inputKernel);
		Kernel transitionBasedKernel = new LinearKernel(TRANSITION_REPRESENTATION_NAME);
		sequenceBasedKernel.addKernel(transitionWeight, transitionBasedKernel);
		sequenceBasedKernel.normalizeWeights();

		setKernel(sequenceBasedKernel);

		// The base algorithm is duplicated so the caller's instance is left untouched
		// when its kernel is replaced with the combined sequence kernel.
		BinaryLearningAlgorithm binaryLearningAlgorithmCopy = (BinaryLearningAlgorithm) baseLearningAlgorithm
				.duplicate();

		((KernelMethod) binaryLearningAlgorithmCopy).setKernel(sequenceBasedKernel);

		// Sequential labeling is reduced to multi-class classification via One-vs-All.
		OneVsAllLearning oneVsAllLearning = new OneVsAllLearning();
		oneVsAllLearning.setBaseAlgorithm(binaryLearningAlgorithmCopy);

		super.setBaseLearningAlgorithm(oneVsAllLearning);

		SequenceExampleGenerator sequenceExamplesGenerator = new SequenceExampleGeneratorKernel(
				transitionsOrder, TRANSITION_REPRESENTATION_NAME);

		super.setSequenceExampleGenerator(sequenceExamplesGenerator);
	}

	@Override
	public LearningAlgorithm duplicate() {
		// NOTE(review): duplication is not implemented and silently returns null;
		// callers relying on duplicate() (e.g. meta-algorithms) will fail later — confirm intended.
		return null;
	}

	@Override
	public LearningAlgorithm getBaseAlgorithm() {
		return super.getBaseLearningAlgorithm();
	}

	@Override
	public Kernel getKernel() {
		return sequenceBasedKernel;
	}

	@Override
	public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) {
		super.setBaseLearningAlgorithm(baseAlgorithm);
	}

	@Override
	public void setKernel(Kernel kernel) {
		// Assumes the provided kernel is a LinearKernelCombination, as built by the constructor.
		this.sequenceBasedKernel = (LinearKernelCombination) kernel;
	}

	/**
	 * Sets a cache to speed up the kernel computations.
	 *
	 * @param cache the kernel cache to be used by the sequence-based kernel
	 */
	public void setKernelCache(KernelCache cache) {
		this.getKernel().setKernelCache(cache);
	}

}
27 | * Given sequences of items (each implemented as an Example and 28 | * associated to one Label) this class allow to apply a generic 29 | * LearningAlgorithm to use the "history" of each item in the 30 | * sequence in order to improve the classification quality. In other words, the 31 | * classification of each example does not depend only its representation, but 32 | * it also depend on its "history", in terms of the classed assigned to the 33 | * preceding examples. 34 | * This class should be used when a linear learning algorithm is used, 35 | * thus directly operating in the representation space. 36 | * 37 | * 38 | * This algorithms was inspired by the work of: 39 | * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector 40 | * machines. In Proceedings of the Twentieth International Conference on Machine 41 | * Learning, 2003. 42 | * 43 | * @author Danilo Croce 44 | * 45 | */ 46 | public class SequenceClassificationLinearLearningAlgorithm extends SequenceClassificationLearningAlgorithm 47 | implements LinearMethod { 48 | 49 | /** 50 | * @param baseLearningAlgorithm 51 | * the "linear" learning algorithm devoted to the acquisition of 52 | * a model after that each example has been enriched with its 53 | * "history" 54 | * @param transitionsOrder 55 | * given a targeted item in the sequence, this variable 56 | * determines the number of previous example considered in the 57 | * learning/labeling process. 58 | * @param transitionWeight 59 | * the importance of the transition-based features during the 60 | * learning process. Higher valuers will assign more importance 61 | * to the transitions. 
62 | * @throws Exception The input baseLearningAlgorithm is not a Linear method 63 | */ 64 | public SequenceClassificationLinearLearningAlgorithm(BinaryLearningAlgorithm baseLearningAlgorithm, 65 | int transitionsOrder, float transitionWeight) throws Exception { 66 | 67 | if (!(baseLearningAlgorithm instanceof LinearMethod)) { 68 | throw new Exception("ERROR: the input baseLearningAlgorithm is not a Linear method!"); 69 | } 70 | 71 | OneVsAllLearning oneVsAllLearning = new OneVsAllLearning(); 72 | oneVsAllLearning.setBaseAlgorithm(baseLearningAlgorithm); 73 | 74 | super.setBaseLearningAlgorithm(oneVsAllLearning); 75 | String representation = ((LinearMethod) baseLearningAlgorithm).getRepresentation(); 76 | 77 | SequenceExampleGenerator sequenceExamplesGenerator = new SequenceExampleGeneratorLinear(transitionsOrder, 78 | representation, transitionWeight); 79 | 80 | super.setSequenceExampleGenerator(sequenceExamplesGenerator); 81 | } 82 | 83 | @Override 84 | public LearningAlgorithm duplicate() { 85 | // TODO Auto-generated method stub 86 | return null; 87 | } 88 | 89 | @Override 90 | public LearningAlgorithm getBaseAlgorithm() { 91 | return super.getBaseLearningAlgorithm(); 92 | } 93 | 94 | @Override 95 | public String getRepresentation() { 96 | return ((SequenceClassificationLinearLearningAlgorithm) getSequenceExampleGenerator()).getRepresentation(); 97 | } 98 | 99 | @Override 100 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 101 | super.setBaseLearningAlgorithm(baseAlgorithm); 102 | } 103 | 104 | @Override 105 | public void setRepresentation(String representationName) { 106 | ((SequenceClassificationLinearLearningAlgorithm) getSequenceExampleGenerator()) 107 | .setRepresentation(representationName); 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/COPYRIGHT: 
-------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2007-2013 The LIBLINEAR Project. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | 3. Neither name of copyright holders nor the names of its contributors 17 | may be used to endorse or promote products derived from this software 18 | without specific prior written permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR 25 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/L2R_L2_SvcFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. 
Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public class L2R_L2_SvcFunction implements TronFunction {

	// The training problem: labels y, feature rows x, sizes l and n.
	protected final Problem prob;
	// Per-example regularization parameters C_i.
	protected final double[] C;
	// Indices of the margin-violating examples (the active set).
	protected final int[] I;
	// Work buffer: holds X*w after fun(), then active-set quantities in grad().
	protected final double[] z;

	// Number of valid entries in I (size of the active set).
	protected int sizeI;

	public L2R_L2_SvcFunction(Problem prob, double[] C) {
		int l = prob.l;

		this.prob = prob;

		z = new double[l];
		I = new int[l];
		this.C = C;
	}

	/**
	 * Objective value of the L2-regularized L2-loss SVC problem:
	 * 0.5 * w'w + sum_i C_i * max(0, 1 - y_i * w'x_i)^2.
	 *
	 * Side effect: leaves z[i] = y_i * w'x_i, which grad(...) relies on
	 * (the Tron solver always calls fun before grad for the same w).
	 */
	public double fun(double[] w) {
		int i;
		double f = 0;
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();

		Xv(w, z);

		for (i = 0; i < w_size; i++)
			f += w[i] * w[i];
		f /= 2.0;
		for (i = 0; i < l; i++) {
			z[i] = y[i] * z[i];
			double d = 1 - z[i];
			if (d > 0)
				f += C[i] * d * d;
		}

		return (f);
	}

	/** @return the number of variables (features) of the problem */
	public int get_nr_variable() {
		return prob.n;
	}

	/**
	 * Gradient: g = w + 2 * X_I' * (C_I .* y_I .* (z_I - 1)), where I is the
	 * set of examples with y_i * w'x_i &lt; 1.
	 *
	 * NOTE: assumes fun(w) was called first, so z holds y_i * w'x_i. The
	 * active set (I, sizeI) is rebuilt here and reused by Hv(...).
	 */
	public void grad(double[] w, double[] g) {
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();

		sizeI = 0;
		for (int i = 0; i < l; i++) {
			if (z[i] < 1) {
				// Compact the violating examples to the front of z.
				z[sizeI] = C[i] * y[i] * (z[i] - 1);
				I[sizeI] = i;
				sizeI++;
			}
		}
		subXTv(z, g);

		for (int i = 0; i < w_size; i++)
			g[i] = w[i] + 2 * g[i];
	}

	/**
	 * Hessian-vector product restricted to the active set computed by the
	 * last grad(...) call: Hs = s + 2 * X_I' * (C_I .* (X_I * s)).
	 */
	public void Hv(double[] s, double[] Hs) {
		int i;
		int w_size = get_nr_variable();
		double[] wa = new double[sizeI];

		subXv(s, wa);
		for (i = 0; i < sizeI; i++)
			wa[i] = C[I[i]] * wa[i];

		subXTv(wa, Hs);
		for (i = 0; i < w_size; i++)
			Hs[i] = s[i] + 2 * Hs[i];
	}

	/** XTv = X_I' * v, using only the rows in the active set I. */
	protected void subXTv(double[] v, double[] XTv) {
		int i;
		int w_size = get_nr_variable();

		for (i = 0; i < w_size; i++)
			XTv[i] = 0;

		for (i = 0; i < sizeI; i++) {
			for (LibLinearFeature s : prob.x[I[i]]) {
				// Feature indices are 1-based, hence the -1 offset.
				XTv[s.getIndex() - 1] += v[i] * s.getValue();
			}
		}
	}

	/** Xv = X_I * v, using only the rows in the active set I. */
	private void subXv(double[] v, double[] Xv) {

		for (int i = 0; i < sizeI; i++) {
			Xv[i] = 0;
			for (LibLinearFeature s : prob.x[I[i]]) {
				Xv[i] += v[s.getIndex() - 1] * s.getValue();
			}
		}
	}

	/** Xv = X * v over all the training examples. */
	protected void Xv(double[] v, double[] Xv) {

		for (int i = 0; i < prob.l; i++) {
			Xv[i] = 0;
			for (LibLinearFeature s : prob.x[i]) {
				Xv[i] += v[s.getIndex() - 1] * s.getValue();
			}
		}
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/L2R_L2_SvrFunction.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver;

/**
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources.
Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public class L2R_L2_SvrFunction extends L2R_L2_SvcFunction {

	// Epsilon of the epsilon-insensitive loss (the "tube" half-width).
	private double p;

	public L2R_L2_SvrFunction( Problem prob, double[] C, double p ) {
		super(prob, C);
		this.p = p;
	}

	/**
	 * Objective value of the L2-regularized L2-loss SVR problem:
	 * 0.5 * w'w + sum_i C_i * max(0, |w'x_i - y_i| - p)^2.
	 *
	 * Side effect: leaves z[i] = w'x_i, which grad(...) relies on.
	 */
	@Override
	public double fun(double[] w) {
		double f = 0;
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();
		double d;

		Xv(w, z);

		for (int i = 0; i < w_size; i++)
			f += w[i] * w[i];
		f /= 2;
		for (int i = 0; i < l; i++) {
			d = z[i] - y[i];
			// Only residuals outside the epsilon tube contribute to the loss.
			if (d < -p)
				f += C[i] * (d + p) * (d + p);
			else if (d > p) f += C[i] * (d - p) * (d - p);
		}

		return f;
	}

	/**
	 * Gradient of the SVR objective. The active set I collects the examples
	 * whose residual falls outside the epsilon tube; z is compacted in place
	 * with the corresponding loss derivatives (same buffer-reuse contract as
	 * the parent class: fun(w) must have been called first).
	 */
	@Override
	public void grad(double[] w, double[] g) {
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();

		sizeI = 0;
		for (int i = 0; i < l; i++) {
			double d = z[i] - y[i];

			// generate index set I
			if (d < -p) {
				z[sizeI] = C[i] * (d + p);
				I[sizeI] = i;
				sizeI++;
			} else if (d > p) {
				z[sizeI] = C[i] * (d - p);
				I[sizeI] = i;
				sizeI++;
			}

		}
		subXTv(z, g);

		for (int i = 0; i < w_size; i++)
			g[i] = w[i] + 2 * g[i];

	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/LibLinearFeature.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver;

/**
 * A single (index, value) entry of a sparse feature vector used by the
 * LIBLINEAR solver. Indices are 1-based (the solver code accesses weight
 * slot getIndex() - 1).
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public interface LibLinearFeature {

	/** @return the 1-based index of this feature */
	int getIndex();

	/** @return the value of this feature */
	double getValue();

	/** Sets the value of this feature. */
	void setValue(double value);
}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/LibLinearFeatureNode.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | public class LibLinearFeatureNode implements LibLinearFeature { 26 | 27 | public final int index; 28 | public double value; 29 | 30 | public LibLinearFeatureNode( final int index, final double value ) { 31 | if (index < 0) throw new IllegalArgumentException("index must be >= 0"); 32 | this.index = index; 33 | this.value = value; 34 | } 35 | 36 | /** 37 | * @since 1.9 38 | */ 39 | public int getIndex() { 40 | return index; 41 | } 42 | 43 | /** 44 | * @since 1.9 45 | */ 46 | public double getValue() { 47 | return value; 48 | } 49 | 50 | /** 51 | * @since 1.9 52 | */ 53 | public void setValue(double value) { 54 | this.value = value; 55 | } 56 | 57 | @Override 58 | public int hashCode() { 59 | final int prime = 31; 60 | int result = 1; 61 | result = prime * result + index; 62 | long temp; 63 | temp = Double.doubleToLongBits(value); 64 | result = prime * result + (int)(temp ^ (temp >>> 32)); 65 | return result; 66 | } 67 | 68 | @Override 69 | public boolean equals(Object obj) { 70 | if (this == obj) return true; 71 | if (obj == null) return false; 72 | if (getClass() != obj.getClass()) return false; 73 | LibLinearFeatureNode other = (LibLinearFeatureNode)obj; 74 | if (index != other.index) return false; 75 | if (Double.doubleToLongBits(value) != Double.doubleToLongBits(other.value)) return false; 76 | return true; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | return "FeatureNode(idx=" + index + ", value=" + value + ")"; 82 | } 83 | } 84 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/Problem.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver;

/**
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SimpleExample;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.representation.Representation;
import it.uniroma2.sag.kelp.data.representation.Vector;
import it.uniroma2.sag.kelp.data.representation.vector.DenseVector;
import it.uniroma2.sag.kelp.data.representation.vector.SparseVector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;

/**
 * Describes the problem, i.e. the LIBLINEAR view of a KeLP dataset.
 *
 * <pre>
 * For example, if we have the following training data:
 *
 * LABEL ATTR1 ATTR2 ATTR3 ATTR4 ATTR5
 * ----- ----- ----- ----- ----- -----
 * 1     0     0.1   0.2   0     0
 * 2     0     0.1   0.3   -1.2  0
 * 1     0.4   0     0     0     0
 * 2     0     0.1   0     1.4   0.5
 * 3     -0.1  -0.2  0.1   1.1   0.1
 *
 * and bias = 1, then the components of problem are:
 *
 * l = 5
 * n = 6
 *
 * y -&gt; 1 2 1 2 3
 *
 * x -&gt; [ ] -&gt; (2,0.1) (3,0.2) (6,1) (-1,?)
 *      [ ] -&gt; (2,0.1) (3,0.3) (4,-1.2) (6,1) (-1,?)
 *      [ ] -&gt; (1,0.4) (6,1) (-1,?)
 *      [ ] -&gt; (2,0.1) (4,1.4) (5,0.5) (6,1) (-1,?)
 *      [ ] -&gt; (1,-0.1) (2,-0.2) (3,0.1) (4,1.1) (5,0.1) (6,1) (-1,?)
 * </pre>
 */
public class Problem {

	public enum LibLinearSolverType {
		CLASSIFICATION, REGRESSION
	}

	// Maps a feature name to its 1-based index in the solver feature space
	// (used in the sparse-input case only).
	public TObjectIntHashMap<Object> featureDict = new TObjectIntHashMap<Object>();

	// Inverse mapping: 1-based feature index back to the feature name.
	public TIntObjectHashMap<Object> featureInverseDict = new TIntObjectHashMap<Object>();

	/** the number of training data */
	public int l;

	/** the number of features (including the bias feature if bias &gt;= 0) */
	public int n;

	/** an array containing the target values */
	public double[] y;

	/** array of sparse feature nodes */
	public LibLinearFeature[][] x;

	/**
	 * If bias &gt;= 0, we assume that one additional feature is added to the
	 * end of each data instance
	 */
	public double bias;

	// Whether the input representation is dense; decided by inspecting the
	// FIRST example only (assumes a homogeneous dataset).
	private boolean isInputDense;

	/**
	 * Builds the LIBLINEAR problem from a KeLP dataset.
	 *
	 * @param dataset the training dataset
	 * @param reprentationName the name of the vectorial representation to use
	 *            (parameter name typo kept for source compatibility)
	 * @param label in CLASSIFICATION mode, examples of this label get y=+1
	 *            and all the others y=-1; in REGRESSION mode, the example's
	 *            regression value for this label is used as target
	 * @param solverType selects the classification or regression encoding
	 */
	public Problem(Dataset dataset, String reprentationName, Label label,
			LibLinearSolverType solverType) {

		this.l = dataset.getNumberOfExamples();
		this.y = new double[l];
		this.x = new LibLinearFeature[l][];

		ArrayList<Vector> vectorlist = new ArrayList<Vector>();

		if (dataset.getExamples().get(0).getRepresentation(reprentationName) instanceof DenseVector)
			isInputDense = true;

		int i = 0;
		for (Example e : dataset.getExamples()) {
			SimpleExample simpleExample = (SimpleExample) e;
			Representation r = simpleExample
					.getRepresentation(reprentationName);
			Vector vector = (Vector) r;

			vectorlist.add(vector);

			if (solverType == LibLinearSolverType.CLASSIFICATION) {
				// Binary encoding: +1 for the target label, -1 otherwise.
				if (e.isExampleOf(label))
					y[i] = 1;
				else
					y[i] = -1;
			} else {
				y[i] = e.getRegressionValue(label);
			}

			i++;
		}

		initializeExamples(vectorlist);

	}

	/**
	 * Converts the learned weights into a DenseVector, dropping the last
	 * (bias) slot of w.
	 */
	private DenseVector getDenseW(double[] w) {
		double[] tmp = new double[w.length - 1];
		for (int i = 0; i < w.length - 1; i++) {
			tmp[i] = w[i];
		}
		return new DenseVector(tmp);
	}

	/**
	 * Converts the learned weights into a SparseVector, mapping each index
	 * back to its original feature name via featureInverseDict; the last slot
	 * of w is stored under the reserved name "__LIB_LINEAR_BIAS_".
	 */
	private SparseVector getSparseW(double[] w) {
		SparseVector res = new SparseVector();

		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < w.length - 1; i++) {
			sb.append(this.featureInverseDict.get(i + 1) + ":" + w[i] + " ");
		}
		sb.append("__LIB_LINEAR_BIAS_:" + w[w.length - 1]);

		try {
			res.setDataFromText(sb.toString().trim());
		} catch (IOException e) {
			// Parsing our own serialization should never fail; if it does,
			// report the problem and return null.
			e.printStackTrace();
			return null;
		}
		return res;
	}

	/**
	 * Converts the learned weight array into a KeLP Vector of the same kind
	 * (dense or sparse) as the input representation.
	 */
	public Vector getW(double[] w) {
		if (isInputDense) {
			return getDenseW(w);
		}
		return getSparseW(w);
	}

	// Dispatches to the dense or sparse row initialization.
	public void initializeExamples(ArrayList<Vector> vectorlist) {
		if (isInputDense) {
			initializeExamplesDense(vectorlist);
		} else {
			initializeExamplesSparse(vectorlist);
		}
	}

	// Dense case: feature j of each vector becomes node (j+1, value);
	// indices are 1-based.
	private void initializeExamplesDense(ArrayList<Vector> vectorlist) {
		for (int vectorId = 0; vectorId < vectorlist.size(); vectorId++) {
			DenseVector denseVector = (DenseVector) (vectorlist.get(vectorId));
			if (vectorId == 0) {
				bias = 0;
				// NOTE(review): n reserves one extra slot although no explicit
				// bias node is appended to the rows; the extra weight is the
				// one exported as the bias by getW — TODO confirm.
				n = denseVector.getNumberOfFeatures() + 1;
			}
			this.x[vectorId] = new LibLinearFeatureNode[denseVector
					.getNumberOfFeatures()];
			for (int j = 0; j < denseVector.getNumberOfFeatures(); j++)
				this.x[vectorId][j] = new LibLinearFeatureNode(j + 1,
						denseVector.getFeatureValue(j));
		}
	}

	// Sparse case: first builds the name->index dictionary over the whole
	// dataset, then materializes one feature-node row per example.
	private void initializeExamplesSparse(ArrayList<Vector> vectorlist) {
		/*
		 * Building dictionary
		 */
		int featureIndex = 1;
		for (Vector v : vectorlist) {
			for (Object dimLabel : v.getActiveFeatures().keySet()) {
				if (!featureDict.containsKey(dimLabel)) {
					featureDict.put(dimLabel, featureIndex);
					featureInverseDict.put(featureIndex, dimLabel);
					featureIndex++;
				}
			}
		}

		/*
		 * Initialize the object
		 */
		n = featureDict.size() + 1;
		bias = 0;
		int i = 0;
		for (Vector v : vectorlist) {
			Map<Object, Number> activeFeatures = v.getActiveFeatures();
			this.x[i] = new LibLinearFeatureNode[activeFeatures.size()];
			int j = 0;
			for (Object dimLabel : activeFeatures.keySet()) {
				this.x[i][j] = new LibLinearFeatureNode(
						featureDict.get(dimLabel), activeFeatures.get(dimLabel).doubleValue());
				j++;
			}
			i++;
		}
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/Tron.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Trust Region Newton Method optimization
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources.
Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public class Tron {
	private Logger logger = LoggerFactory.getLogger(Tron.class);

	// Function to minimize: provides value, gradient, Hessian-vector products.
	private final TronFunction fun_obj;
	// Relative stopping tolerance: stop when ||g|| <= eps * ||g0||.
	private final double eps;
	// Maximum number of outer trust-region iterations.
	private final int max_iter;

	public Tron(final TronFunction fun_obj) {
		this(fun_obj, 0.1);
	}

	public Tron(final TronFunction fun_obj, double eps) {
		this(fun_obj, eps, 1000);
	}

	public Tron(final TronFunction fun_obj, double eps, int max_iter) {
		this.fun_obj = fun_obj;
		this.eps = eps;
		this.max_iter = max_iter;
	}

	/**
	 * Minimizes fun_obj starting from the origin; the solution is written
	 * into w (whose previous content is discarded: w is reset to zero first).
	 *
	 * @param w output buffer for the weight vector; its length must equal
	 *            fun_obj.get_nr_variable()
	 */
	public void tron(double[] w) {
		// Parameters for updating the iterates.
		double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;

		// Parameters for updating the trust region size delta.
		double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;

		int n = fun_obj.get_nr_variable();
		int i, cg_iter;
		double delta, snorm, one = 1.0;
		double alpha, f, fnew, prered, actred, gs;
		int search = 1, iter = 1;
		double[] s = new double[n];
		double[] r = new double[n];
		double[] w_new = new double[n];
		double[] g = new double[n];

		// Start from the origin.
		for (i = 0; i < n; i++)
			w[i] = 0;

		f = fun_obj.fun(w);
		fun_obj.grad(w, g);
		delta = euclideanNorm(g);
		double gnorm1 = delta;
		double gnorm = gnorm1;

		if (gnorm <= eps * gnorm1)
			search = 0;

		iter = 1;

		while (iter <= max_iter && search != 0) {
			// Approximately solve the trust-region subproblem by CG.
			cg_iter = trcg(delta, g, s, r);

			System.arraycopy(w, 0, w_new, 0, n);
			daxpy(one, s, w_new);

			gs = dot(g, s);
			// Predicted reduction of the quadratic model.
			prered = -0.5 * (gs - dot(s, r));
			fnew = fun_obj.fun(w_new);

			// Compute the actual reduction.
			actred = f - fnew;

			// On the first iteration, adjust the initial step bound.
			snorm = euclideanNorm(s);
			if (iter == 1)
				delta = Math.min(delta, snorm);

			// Compute prediction alpha*snorm of the step.
			if (fnew - f - gs <= 0)
				alpha = sigma3;
			else
				alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs)));

			// Update the trust region bound according to the ratio of actual to
			// predicted reduction.
			if (actred < eta0 * prered)
				delta = Math.min(Math.max(alpha, sigma1) * snorm, sigma2
						* delta);
			else if (actred < eta1 * prered)
				delta = Math.max(sigma1 * delta,
						Math.min(alpha * snorm, sigma2 * delta));
			else if (actred < eta2 * prered)
				delta = Math.max(sigma1 * delta,
						Math.min(alpha * snorm, sigma3 * delta));
			else
				delta = Math
						.max(delta, Math.min(alpha * snorm, sigma3 * delta));

			// info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d%n",
			// iter, actred, prered, delta, f, gnorm, cg_iter);
			info("iter {} act {} pre {} delta {} f {} |g| {} CG {}",
					iter, actred, prered, delta, f, gnorm, cg_iter);

			// Accept the step only if it achieved a sufficient reduction.
			if (actred > eta0 * prered) {
				iter++;
				System.arraycopy(w_new, 0, w, 0, n);
				f = fnew;
				fun_obj.grad(w, g);

				gnorm = euclideanNorm(g);
				if (gnorm <= eps * gnorm1)
					break;
			}
			// Safeguards against divergence and numerical stagnation.
			if (f < -1.0e+32) {
				info("WARNING: f < -1.0e+32%n");
				break;
			}
			if (Math.abs(actred) <= 0 && prered <= 0) {
				info("WARNING: actred and prered <= 0%n");
				break;
			}
			if (Math.abs(actred) <= 1.0e-12 * Math.abs(f)
					&& Math.abs(prered) <= 1.0e-12 * Math.abs(f)) {
				info("WARNING: actred and prered too small%n");
				break;
			}
		}
	}

	private void info(String msg) {
		logger.debug(msg);
	}

	// SLF4J parameterized logging ({} placeholders).
	private void info(String msgFormatted, Object... args) {
		// Formatter formatter = new Formatter();
		// Formatter format = formatter.format(msgFormatted, args);
		logger.debug(msgFormatted,args);
		// formatter.close();
	}

	/**
	 * Conjugate-gradient solver for the trust-region subproblem: minimizes
	 * the local quadratic model within the ball of radius delta, truncating
	 * the step at the boundary if needed.
	 *
	 * @param delta trust region radius
	 * @param g current gradient
	 * @param s output: the computed step
	 * @param r output: the CG residual
	 * @return the number of CG iterations performed
	 */
	private int trcg(double delta, double[] g, double[] s, double[] r) {
		int n = fun_obj.get_nr_variable();
		double one = 1;
		double[] d = new double[n];
		double[] Hd = new double[n];
		double rTr, rnewTrnew, cgtol;

		for (int i = 0; i < n; i++) {
			s[i] = 0;
			r[i] = -g[i];
			d[i] = r[i];
		}
		cgtol = 0.1 * euclideanNorm(g);

		int cg_iter = 0;
		rTr = dot(r, r);

		while (true) {
			if (euclideanNorm(r) <= cgtol)
				break;
			cg_iter++;
			fun_obj.Hv(d, Hd);

			double alpha = rTr / dot(d, Hd);
			daxpy(alpha, d, s);
			if (euclideanNorm(s) > delta) {
				// The step left the region: backtrack, then advance exactly
				// to the boundary along the current direction.
				info("cg reaches trust region boundary%n");
				alpha = -alpha;
				daxpy(alpha, d, s);

				double std = dot(s, d);
				double sts = dot(s, s);
				double dtd = dot(d, d);
				double dsq = delta * delta;
				double rad = Math.sqrt(std * std + dtd * (dsq - sts));
				// Pick the numerically stable root of the boundary quadratic.
				if (std >= 0)
					alpha = (dsq - sts) / (std + rad);
				else
					alpha = (rad - std) / dtd;
				daxpy(alpha, d, s);
				alpha = -alpha;
				daxpy(alpha, Hd, r);
				break;
			}
			alpha = -alpha;
			daxpy(alpha, Hd, r);
			rnewTrnew = dot(r, r);
			double beta = rnewTrnew / rTr;
			scale(beta, d);
			daxpy(one, r, d);
			rTr = rnewTrnew;
		}

		return (cg_iter);
	}

	/**
	 * constant times a vector plus a vector
	 *
	 * <pre>
	 * vector2 += constant * vector1
	 * </pre>
	 *
	 * @since 1.8
	 */
	private static void daxpy(double constant, double vector1[],
			double vector2[]) {
		if (constant == 0)
			return;

		assert vector1.length == vector2.length;
		for (int i = 0; i < vector1.length; i++) {
			vector2[i] += constant * vector1[i];
		}
	}

	/**
	 * returns the dot product of two vectors
	 *
	 * @since 1.8
	 */
	private static double dot(double vector1[], double vector2[]) {

		double product = 0;
		assert vector1.length == vector2.length;
		for (int i = 0; i < vector1.length; i++) {
			product += vector1[i] * vector2[i];
		}
		return product;

	}

	/**
	 * returns the euclidean norm of a vector
	 *
	 * @since 1.8
	 */
	private static double euclideanNorm(double vector[]) {

		int n = vector.length;

		if (n < 1) {
			return 0;
		}

		if (n == 1) {
			return Math.abs(vector[0]);
		}

		// this algorithm is (often) more accurate than just summing up the
		// squares and taking the square-root afterwards

		double scale = 0; // scaling factor that is factored out
		double sum = 1; // basic sum of squares from which scale has been
						// factored out
		for (int i = 0; i < n; i++) {
			if (vector[i] != 0) {
				double abs = Math.abs(vector[i]);
				// try to get the best scaling factor
				if (scale < abs) {
					double t = scale / abs;
					sum = 1 + sum * (t * t);
					scale = abs;
				} else {
					double t = abs / scale;
					sum += t * t;
				}
			}
		}

		return scale * Math.sqrt(sum);
	}

	/**
	 * scales a vector by a constant
	 *
	 * @since 1.8
	 */
	private static void scale(double constant, double vector[]) {
		if (constant == 1.0)
			return;
		for (int i = 0; i < vector.length; i++) {
			vector[i] *= constant;
		}

	}
}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/TronFunction.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver;


/**
 * Contract of a function optimized by the {@link Tron} solver: it exposes the
 * objective value, the gradient and Hessian-vector products.
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
interface TronFunction {

	/** @return the objective value at w */
	double fun(double[] w);

	/** Writes the gradient at w into g. */
	void grad(double[] w, double[] g);

	/** Writes the Hessian-vector product H*s into Hs. */
	void Hv(double[] s, double[] Hs);

	/** @return the number of variables of the problem */
	int get_nr_variable();
}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/KernelizedPassiveAggressiveClassification.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | 19 | import com.fasterxml.jackson.annotation.JsonTypeName; 20 | 21 | import it.uniroma2.sag.kelp.data.label.Label; 22 | import it.uniroma2.sag.kelp.kernel.Kernel; 23 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 24 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 26 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 27 | 28 | /** 29 | * Online Passive-Aggressive Learning Algorithm for classification tasks (Kernel Machine version) . 30 | * Every time an example is misclassified it is added as support vector, with the weight that solves the 31 | * passive aggressive minimization problem 32 | * 33 | * reference: 34 | * 35 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 36 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 37 | * 38 | * The standard algorithm is modified, including the fairness extention from 39 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 40 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347–358, Springer International Publishing, 2014. 
 *
 *
 * @author Simone Filice
 */

@JsonTypeName("kernelizedPA")
public class KernelizedPassiveAggressiveClassification extends PassiveAggressiveClassification implements KernelMethod{

	// The kernel function used by the underlying kernel machine.
	private Kernel kernel;

	public KernelizedPassiveAggressiveClassification(){
		this.classifier = new BinaryKernelMachineClassifier();
		this.classifier.setModel(new BinaryKernelMachineModel());
	}

	/**
	 * @param cp the aggressiveness parameter for positive examples
	 * @param cn the aggressiveness parameter for negative examples
	 * @param loss the loss function to apply
	 * @param policy the updating policy
	 * @param kernel the kernel function
	 * @param label the label to learn
	 */
	public KernelizedPassiveAggressiveClassification(float cp, float cn, Loss loss, Policy policy, Kernel kernel, Label label){
		this.classifier = new BinaryKernelMachineClassifier();
		this.classifier.setModel(new BinaryKernelMachineModel());
		// setKernel must follow the model creation above: it propagates the
		// kernel into the freshly created model.
		this.setKernel(kernel);
		this.setLoss(loss);
		this.setCp(cp);
		this.setCn(cn);
		this.setLabel(label);
		this.setPolicy(policy);
	}


	@Override
	public Kernel getKernel() {
		return kernel;
	}

	@Override
	public void setKernel(Kernel kernel) {
		this.kernel = kernel;
		// Keep the prediction function's model in sync with this algorithm.
		this.getPredictionFunction().getModel().setKernel(kernel);
	}


	/**
	 * Returns a copy of this algorithm with the same configuration.
	 * NOTE(review): the label is deliberately not copied (see the
	 * commented-out line); 'c' is the inherited aggressiveness field —
	 * presumably the storage behind cn, as in the linear variant — confirm
	 * against PassiveAggressiveClassification.
	 */
	@Override
	public KernelizedPassiveAggressiveClassification duplicate(){
		KernelizedPassiveAggressiveClassification copy = new KernelizedPassiveAggressiveClassification();
		copy.setCp(this.cp);
		copy.setCn(c);
		copy.setFairness(this.fairness);
		copy.setKernel(this.kernel);
		copy.setLoss(this.loss);
		copy.setPolicy(this.policy);
		//copy.setLabel(label);
		return copy;
	}

	@Override
	public BinaryKernelMachineClassifier getPredictionFunction(){
		return (BinaryKernelMachineClassifier) this.classifier;
	}

	@Override
	public void setPredictionFunction(PredictionFunction predictionFunction) {
		this.classifier = (BinaryKernelMachineClassifier) predictionFunction;
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/LinearPassiveAggressiveClassification.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.label.Label; 19 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 20 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 21 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 22 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 23 | 24 | import com.fasterxml.jackson.annotation.JsonTypeName; 25 | 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for classification tasks (linear version) . 29 | * Every time an example is misclassified it is added the the current hyperplane, with the weight that solves the 30 | * passive aggressive minimization problem 31 | * 32 | * reference: 33 | * 34 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 35 | * Online Passive-Aggressive Algorithms. 
Journal of Machine Learning Research (2006) 36 | * 37 | * The standard algorithm is modified, including the fairness extention from 38 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 39 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347–358, Springer International Publishing, 2014. 40 | * 41 | * @author Simone Filice 42 | */ 43 | @JsonTypeName("linearPA") 44 | public class LinearPassiveAggressiveClassification extends PassiveAggressiveClassification implements LinearMethod{ 45 | 46 | private String representation; 47 | 48 | public LinearPassiveAggressiveClassification(){ 49 | this.classifier = new BinaryLinearClassifier(); 50 | this.classifier.setModel(new BinaryLinearModel()); 51 | } 52 | 53 | public LinearPassiveAggressiveClassification(float cp, float cn, Loss loss, Policy policy, String representation, Label label){ 54 | this.classifier = new BinaryLinearClassifier(); 55 | this.classifier.setModel(new BinaryLinearModel()); 56 | this.setCp(cp); 57 | this.setCn(cn); 58 | this.setLoss(loss); 59 | this.setPolicy(policy); 60 | this.setRepresentation(representation); 61 | this.setLabel(label); 62 | } 63 | 64 | @Override 65 | public String getRepresentation() { 66 | return representation; 67 | } 68 | 69 | @Override 70 | public void setRepresentation(String representation) { 71 | this.representation = representation; 72 | BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel(); 73 | model.setRepresentation(representation); 74 | } 75 | 76 | @Override 77 | public LinearPassiveAggressiveClassification duplicate(){ 78 | LinearPassiveAggressiveClassification copy = new LinearPassiveAggressiveClassification(); 79 | copy.setRepresentation(this.representation); 80 | copy.setCp(this.cp); 81 | copy.setCn(this.c); 82 | copy.setFairness(this.fairness); 83 | copy.setLoss(this.loss); 84 | copy.setPolicy(this.policy); 85 | return copy; 86 | } 87 | 88 | @Override 
89 | public BinaryLinearClassifier getPredictionFunction(){ 90 | return (BinaryLinearClassifier) this.classifier; 91 | } 92 | 93 | @Override 94 | public void setPredictionFunction(PredictionFunction predictionFunction) { 95 | this.classifier = (BinaryLinearClassifier) predictionFunction; 96 | } 97 | 98 | } 99 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/PassiveAggressiveClassification.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier; 23 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | import com.fasterxml.jackson.annotation.JsonProperty; 27 | 28 | /** 29 | * Online Passive-Aggressive Learning Algorithm for classification tasks. 30 | * Every time an example is misclassified it is added the the current hyperplane, with the weight that solves the 31 | * passive aggressive minimization problem 32 | * 33 | * reference: 34 | * 35 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 36 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 37 | * 38 | * The standard algorithm is modified, including the fairness extention from 39 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 40 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347-358, Springer International Publishing, 2014. 41 | * 42 | * @author Simone Filice 43 | */ 44 | public abstract class PassiveAggressiveClassification extends PassiveAggressive implements ClassificationLearningAlgorithm{ 45 | 46 | public enum Loss{ 47 | HINGE, 48 | RAMP 49 | } 50 | 51 | protected Loss loss = Loss.HINGE; 52 | protected float cp = c;//cp is the aggressiveness w.r.t. positive examples. c will be considered the aggressiveness w.r.t. 
negative examples 53 | protected boolean fairness = false; 54 | 55 | @JsonIgnore 56 | protected BinaryClassifier classifier; 57 | 58 | 59 | /** 60 | * @return the fairness 61 | */ 62 | public boolean isFairness() { 63 | return fairness; 64 | } 65 | 66 | 67 | /** 68 | * @param fairness the fairness to set 69 | */ 70 | public void setFairness(boolean fairness) { 71 | this.fairness = fairness; 72 | } 73 | 74 | /** 75 | * @return the aggressiveness parameter for positive examples 76 | */ 77 | public float getCp() { 78 | return cp; 79 | } 80 | 81 | 82 | /** 83 | * @param cp the aggressiveness parameter for positive examples 84 | */ 85 | public void setCp(float cp) { 86 | this.cp = cp; 87 | } 88 | 89 | /** 90 | * @return the aggressiveness parameter for negative examples 91 | */ 92 | public float getCn() { 93 | return c; 94 | } 95 | 96 | 97 | /** 98 | * @param cn the aggressiveness parameter for negative examples 99 | */ 100 | public void setCn(float cn) { 101 | this.c = cn; 102 | } 103 | 104 | @Override 105 | @JsonIgnore 106 | public float getC(){ 107 | return c; 108 | } 109 | 110 | @Override 111 | @JsonProperty 112 | public void setC(float c){ 113 | super.setC(c); 114 | this.cp=c; 115 | } 116 | 117 | /** 118 | * @return the loss function type 119 | */ 120 | public Loss getLoss() { 121 | return loss; 122 | } 123 | 124 | 125 | /** 126 | * @param loss the loss function type to set 127 | */ 128 | public void setLoss(Loss loss) { 129 | this.loss = loss; 130 | } 131 | 132 | @Override 133 | public BinaryClassifier getPredictionFunction() { 134 | return this.classifier; 135 | } 136 | 137 | @Override 138 | public BinaryMarginClassifierOutput learn(Example example){ 139 | 140 | BinaryMarginClassifierOutput prediction=this.classifier.predict(example); 141 | 142 | float lossValue = 0;//it represents the distance from the correct semi-space 143 | if(prediction.isClassPredicted(label)!=example.isExampleOf(label)){ 144 | lossValue = 1 + Math.abs(prediction.getScore(label)); 145 | 
}else if(Math.abs(prediction.getScore(label))<1){ 146 | lossValue = 1 - Math.abs(prediction.getScore(label)); 147 | } 148 | 149 | if(lossValue>0 && (lossValue<2 || this.loss!=Loss.RAMP)){ 150 | float exampleAggressiveness=this.c; 151 | if(example.isExampleOf(label)){ 152 | exampleAggressiveness=cp; 153 | } 154 | float exampleSquaredNorm = this.classifier.getModel().getSquaredNorm(example); 155 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm ,exampleAggressiveness); 156 | if(!example.isExampleOf(label)){ 157 | weight*=-1; 158 | } 159 | this.getPredictionFunction().getModel().addExample(weight, example); 160 | } 161 | return prediction; 162 | 163 | } 164 | 165 | @Override 166 | public void learn(Dataset dataset){ 167 | if(this.fairness){ 168 | float positiveExample = dataset.getNumberOfPositiveExamples(label); 169 | float negativeExample = dataset.getNumberOfNegativeExamples(label); 170 | cp = c * negativeExample / positiveExample; 171 | } 172 | //System.out.println("cn: " + c + " cp: " + cp); 173 | super.learn(dataset); 174 | } 175 | 176 | } 177 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/pegasos/PegasosLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.pegasos; 17 | 18 | import java.util.ArrayList; 19 | import java.util.Arrays; 20 | import java.util.List; 21 | 22 | import com.fasterxml.jackson.annotation.JsonTypeName; 23 | 24 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 25 | import it.uniroma2.sag.kelp.data.example.Example; 26 | import it.uniroma2.sag.kelp.data.label.Label; 27 | import it.uniroma2.sag.kelp.data.representation.Vector; 28 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 29 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 30 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 31 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 32 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 33 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 34 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 35 | 36 | /** 37 | * It implements the Primal Estimated sub-GrAdient SOlver (PEGASOS) for SVM. It is a learning 38 | * algorithm for binary linear classification Support Vector Machines. It operates in an explicit 39 | * feature space (i.e. it does not relies on any kernel). Further details can be found in: 40 | * 41 | * [SingerICML2007] Y. Singer and N. Srebro. Pegasos: Primal estimated sub-gradient solver for SVM. 42 | * In Proceeding of ICML 2007. 
43 | * 44 | * @author Simone Filice 45 | * 46 | */ 47 | @JsonTypeName("pegasos") 48 | public class PegasosLearningAlgorithm implements LinearMethod, ClassificationLearningAlgorithm, BinaryLearningAlgorithm{ 49 | 50 | private Label label; 51 | 52 | private BinaryLinearClassifier classifier; 53 | 54 | private int k = 1; 55 | private int iterations = 1000; 56 | private float lambda = 0.01f; 57 | 58 | private String representation; 59 | 60 | /** 61 | * Returns the number of examples k that Pegasos exploits in its 62 | * mini-batch learning approach 63 | * 64 | * @return k 65 | */ 66 | public int getK() { 67 | return k; 68 | } 69 | 70 | /** 71 | * Sets the number of examples k that Pegasos exploits in its 72 | * mini-batch learning approach 73 | * 74 | * @param k the k to set 75 | */ 76 | public void setK(int k) { 77 | this.k = k; 78 | } 79 | 80 | /** 81 | * Returns the number of iterations 82 | * 83 | * @return the number of iterations 84 | */ 85 | public int getIterations() { 86 | return iterations; 87 | } 88 | 89 | /** 90 | * Sets the number of iterations 91 | * 92 | * @param T the number of iterations to set 93 | */ 94 | public void setIterations(int T) { 95 | this.iterations = T; 96 | } 97 | 98 | /** 99 | * Returns the regularization coefficient 100 | * 101 | * @return the lambda 102 | */ 103 | public float getLambda() { 104 | return lambda; 105 | } 106 | 107 | /** 108 | * Sets the regularization coefficient 109 | * 110 | * @param lambda the lambda to set 111 | */ 112 | public void setLambda(float lambda) { 113 | this.lambda = lambda; 114 | } 115 | 116 | public PegasosLearningAlgorithm(){ 117 | this.classifier = new BinaryLinearClassifier(); 118 | this.classifier.setModel(new BinaryLinearModel()); 119 | } 120 | 121 | public PegasosLearningAlgorithm(int k, float lambda, int T, String Representation, Label label){ 122 | this.classifier = new BinaryLinearClassifier(); 123 | this.classifier.setModel(new BinaryLinearModel()); 124 | this.setK(k); 125 | 
this.setLabel(label); 126 | this.setLambda(lambda); 127 | this.setRepresentation(Representation); 128 | this.setIterations(T); 129 | } 130 | 131 | @Override 132 | public String getRepresentation() { 133 | return representation; 134 | } 135 | 136 | @Override 137 | public void setRepresentation(String representation) { 138 | this.representation = representation; 139 | BinaryLinearModel model = this.classifier.getModel(); 140 | model.setRepresentation(representation); 141 | } 142 | 143 | @Override 144 | public void learn(Dataset dataset) { 145 | if(this.getPredictionFunction().getModel().getHyperplane()==null){ 146 | this.getPredictionFunction().getModel().setHyperplane(dataset.getZeroVector(representation)); 147 | } 148 | 149 | for(int t=1;t<=iterations;t++){ 150 | 151 | List A_t = dataset.getRandExamples(k); 152 | List A_tp = new ArrayList(); 153 | List signA_tp = new ArrayList(); 154 | float eta_t = ((float)1)/(lambda*t); 155 | Vector w_t = this.getPredictionFunction().getModel().getHyperplane(); 156 | 157 | //creating A_tp 158 | for(Example example: A_t){ 159 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 160 | float y = -1; 161 | if(example.isExampleOf(label)){ 162 | y=1; 163 | } 164 | 165 | if(prediction.getScore(label)*y<1){ 166 | A_tp.add(example); 167 | signA_tp.add(y); 168 | } 169 | } 170 | //creating w_(t+1/2) 171 | w_t.scale(1-eta_t*lambda); 172 | float miscassificationFactor = eta_t/k; 173 | for(int i=0; i labels){ 211 | if(labels.size()!=1){ 212 | throw new IllegalArgumentException("Pegasos algorithm is a binary method which can learn a single Label"); 213 | } 214 | else{ 215 | this.label=labels.get(0); 216 | this.classifier.setLabels(labels); 217 | } 218 | } 219 | 220 | 221 | @Override 222 | public List getLabels() { 223 | return Arrays.asList(label); 224 | } 225 | 226 | @Override 227 | public Label getLabel(){ 228 | return this.label; 229 | } 230 | 231 | @Override 232 | public void setLabel(Label label){ 233 | 
this.setLabels(Arrays.asList(label)); 234 | } 235 | 236 | @Override 237 | public void setPredictionFunction(PredictionFunction predictionFunction) { 238 | this.classifier = (BinaryLinearClassifier) predictionFunction; 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/KernelizedPerceptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 24 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 25 | 26 | import com.fasterxml.jackson.annotation.JsonTypeName; 27 | 28 | /** 29 | * The perceptron learning algorithm algorithm for classification tasks (Kernel machine version). Reference: 30 | * [Rosenblatt1957] F. Rosenblatt. 
The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 31 | * 32 | * @author Simone Filice 33 | * 34 | */ 35 | @JsonTypeName("kernelizedPerceptron") 36 | public class KernelizedPerceptron extends Perceptron implements KernelMethod{ 37 | 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPerceptron(){ 42 | this.classifier = new BinaryKernelMachineClassifier(); 43 | this.classifier.setModel(new BinaryKernelMachineModel()); 44 | } 45 | 46 | public KernelizedPerceptron(float alpha, float margin, boolean unbiased, Kernel kernel, Label label){ 47 | this.classifier = new BinaryKernelMachineClassifier(); 48 | this.classifier.setModel(new BinaryKernelMachineModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setKernel(kernel); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Kernel getKernel() { 58 | return kernel; 59 | } 60 | 61 | @Override 62 | public void setKernel(Kernel kernel) { 63 | this.kernel = kernel; 64 | this.getPredictionFunction().getModel().setKernel(kernel); 65 | } 66 | 67 | @Override 68 | public KernelizedPerceptron duplicate(){ 69 | KernelizedPerceptron copy = new KernelizedPerceptron(); 70 | copy.setKernel(this.kernel); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setUnbiased(this.unbiased); 74 | return copy; 75 | } 76 | 77 | @Override 78 | public BinaryKernelMachineClassifier getPredictionFunction(){ 79 | return (BinaryKernelMachineClassifier) this.classifier; 80 | } 81 | 82 | @Override 83 | public void setPredictionFunction(PredictionFunction predictionFunction) { 84 | this.classifier = (BinaryKernelMachineClassifier) predictionFunction; 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/LinearPerceptron.java: 
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron;


import com.fasterxml.jackson.annotation.JsonTypeName;

import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel;

/**
 * The perceptron learning algorithm for classification tasks (linear version). Reference:
 * [Rosenblatt1957] F. Rosenblatt. The Perceptron - a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957)
 *
 * @author Simone Filice
 *
 */
@JsonTypeName("linearPerceptron")
public class LinearPerceptron extends Perceptron implements LinearMethod{


	private String representation;

	/**
	 * Builds an instance backed by an empty linear model. Learning rate,
	 * margin, representation and label still have to be set.
	 */
	public LinearPerceptron(){
		BinaryLinearClassifier linearClassifier = new BinaryLinearClassifier();
		linearClassifier.setModel(new BinaryLinearModel());
		this.classifier = linearClassifier;
	}

	/**
	 * Builds a fully configured instance.
	 *
	 * @param alpha the learning rate
	 * @param margin the minimum distance from the hyperplane required to avoid an update
	 * @param unbiased whether the bias term is kept fixed at 0
	 * @param representation the identifier of the representation to be exploited
	 * @param label the label to be learned
	 */
	public LinearPerceptron(float alpha, float margin, boolean unbiased, String representation, Label label){
		this();
		this.setAlpha(alpha);
		this.setMargin(margin);
		this.setUnbiased(unbiased);
		this.setRepresentation(representation);
		this.setLabel(label);
	}

	@Override
	public String getRepresentation() {
		return representation;
	}

	@Override
	public void setRepresentation(String representation) {
		// keep the model aligned with the representation this learner operates on
		this.representation = representation;
		BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel();
		model.setRepresentation(representation);
	}

	/**
	 * Returns an untrained copy sharing this learner's configuration.
	 * NOTE(review): the label is not copied, consistently with the sibling implementations.
	 */
	@Override
	public LinearPerceptron duplicate(){
		LinearPerceptron copy = new LinearPerceptron();
		copy.setRepresentation(this.representation);
		copy.setAlpha(this.alpha);
		copy.setMargin(this.margin);
		copy.setUnbiased(this.unbiased);
		return copy;
	}

	@Override
	public BinaryLinearClassifier getPredictionFunction(){
		return (BinaryLinearClassifier) this.classifier;
	}

	@Override
	public void setPredictionFunction(PredictionFunction predictionFunction) {
		this.classifier = (BinaryLinearClassifier) predictionFunction;
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/Perceptron.java:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | import java.util.Arrays; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 22 | import it.uniroma2.sag.kelp.data.example.Example; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 27 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier; 28 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 29 | 30 | import com.fasterxml.jackson.annotation.JsonIgnore; 31 | 32 | /** 33 | * The perceptron learning algorithm algorithm for classification tasks. Reference: 34 | * [Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. 
Report 85-460-1, Cornell Aeronautical Laboratory (1957) 35 | * 36 | * @author Simone Filice 37 | * 38 | */ 39 | public abstract class Perceptron implements ClassificationLearningAlgorithm, OnlineLearningAlgorithm, BinaryLearningAlgorithm{ 40 | 41 | @JsonIgnore 42 | protected BinaryClassifier classifier; 43 | 44 | protected Label label; 45 | 46 | protected float alpha=1; 47 | protected float margin = 1; 48 | protected boolean unbiased=false; 49 | 50 | /** 51 | * Returns the learning rate, i.e. the weight associated to misclassified examples during the learning process 52 | * 53 | * @return the learning rate 54 | */ 55 | public float getAlpha() { 56 | return alpha; 57 | } 58 | 59 | /** 60 | * Sets the learning rate, i.e. the weight associated to misclassified examples during the learning process 61 | * 62 | * @param alpha the learning rate to set 63 | */ 64 | public void setAlpha(float alpha) { 65 | if(alpha<=0 || alpha>1){ 66 | throw new IllegalArgumentException("Invalid learning rate for the perceptron algorithm: valid alphas in (0,1]"); 67 | } 68 | this.alpha = alpha; 69 | } 70 | 71 | /** 72 | * Returns the desired margin, i.e. the minimum distance from the hyperplane that an example must have 73 | * in order to be not considered misclassified 74 | * 75 | * @return the margin 76 | */ 77 | public float getMargin() { 78 | return margin; 79 | } 80 | 81 | /** 82 | * Sets the desired margin, i.e. the minimum distance from the hyperplane that an example must have 83 | * in order to be not considered misclassified 84 | * 85 | * @param margin the margin to set 86 | */ 87 | public void setMargin(float margin) { 88 | this.margin = margin; 89 | } 90 | 91 | /** 92 | * Returns whether the bias, i.e. the constant term of the hyperplane, is always 0, or can be modified during 93 | * the learning process 94 | * 95 | * @return the unbiased 96 | */ 97 | public boolean isUnbiased() { 98 | return unbiased; 99 | } 100 | 101 | /** 102 | * Sets whether the bias, i.e. 
the constant term of the hyperplane, is always 0, or can be modified during 103 | * the learning process 104 | * 105 | * @param unbiased the unbiased to set 106 | */ 107 | public void setUnbiased(boolean unbiased) { 108 | this.unbiased = unbiased; 109 | } 110 | 111 | 112 | @Override 113 | public void learn(Dataset dataset) { 114 | 115 | while(dataset.hasNextExample()){ 116 | Example example = dataset.getNextExample(); 117 | this.learn(example); 118 | } 119 | dataset.reset(); 120 | } 121 | 122 | @Override 123 | public BinaryMarginClassifierOutput learn(Example example){ 124 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 125 | 126 | float predValue = prediction.getScore(label); 127 | if(Math.abs(predValue) labels){ 154 | if(labels.size()!=1){ 155 | throw new IllegalArgumentException("The Perceptron algorithm is a binary method which can learn a single Label"); 156 | } 157 | else{ 158 | this.label=labels.get(0); 159 | this.classifier.setLabels(labels); 160 | } 161 | } 162 | 163 | 164 | @Override 165 | public List getLabels() { 166 | 167 | return Arrays.asList(label); 168 | } 169 | 170 | @Override 171 | public Label getLabel(){ 172 | return this.label; 173 | } 174 | 175 | @Override 176 | public void setLabel(Label label){ 177 | this.setLabels(Arrays.asList(label)); 178 | } 179 | 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/BinaryPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import it.uniroma2.sag.kelp.data.label.Label; 4 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 5 | 6 | public class BinaryPlattNormalizer { 7 | 8 | private float A; 9 | private float B; 10 | 11 | public BinaryPlattNormalizer() 
{ 12 | 13 | } 14 | 15 | public BinaryPlattNormalizer(float a, float b) { 16 | super(); 17 | A = a; 18 | B = b; 19 | } 20 | 21 | public float normalizeScore(float nonNomalizedScore) { 22 | return (float) (1.0 / (1.0 + Math.exp(A * nonNomalizedScore + B))); 23 | } 24 | 25 | public float getA() { 26 | return A; 27 | } 28 | 29 | public float getB() { 30 | return B; 31 | } 32 | 33 | public void setA(float a) { 34 | A = a; 35 | } 36 | 37 | public void setB(float b) { 38 | B = b; 39 | } 40 | 41 | @Override 42 | public String toString() { 43 | return "PlattSigmoidFunction [A=" + A + ", B=" + B + "]"; 44 | } 45 | 46 | public BinaryMarginClassifierOutput getNormalizedScore(BinaryMarginClassifierOutput binaryMarginClassifierOutput) { 47 | 48 | Label positiveLabel = binaryMarginClassifierOutput.getAllClasses().get(0); 49 | 50 | Float nonNormalizedScore = binaryMarginClassifierOutput.getScore(positiveLabel); 51 | 52 | BinaryMarginClassifierOutput res = new BinaryMarginClassifierOutput(positiveLabel, 53 | normalizeScore(nonNormalizedScore)); 54 | 55 | return res; 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/MulticlassPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.label.Label; 6 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 7 | 8 | public class MulticlassPlattNormalizer { 9 | 10 | private HashMap binaryPlattNormalizers; 11 | 12 | public void addBinaryPlattNormalizer(Label label, BinaryPlattNormalizer binaryPlattNormalizer) { 13 | if (binaryPlattNormalizers == null) { 14 | binaryPlattNormalizers = new HashMap(); 15 | } 16 | 
binaryPlattNormalizers.put(label, binaryPlattNormalizer); 17 | } 18 | 19 | public OneVsAllClassificationOutput getNormalizedScores(OneVsAllClassificationOutput oneVsAllClassificationOutput) { 20 | OneVsAllClassificationOutput res = new OneVsAllClassificationOutput(); 21 | 22 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 23 | float nonNormalizedScore = oneVsAllClassificationOutput.getScore(l); 24 | BinaryPlattNormalizer binaryPlattNormalizer = binaryPlattNormalizers.get(l); 25 | float normalizedScore = binaryPlattNormalizer.normalizeScore(nonNormalizedScore); 26 | 27 | res.addBinaryPrediction(l, normalizedScore); 28 | } 29 | 30 | return res; 31 | } 32 | 33 | public static OneVsAllClassificationOutput softmax(OneVsAllClassificationOutput oneVsAllClassificationOutput) { 34 | OneVsAllClassificationOutput res = new OneVsAllClassificationOutput(); 35 | 36 | float denom = 0; 37 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 38 | float score = oneVsAllClassificationOutput.getScore(l); 39 | denom += Math.exp(score); 40 | } 41 | 42 | 43 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 44 | float score = oneVsAllClassificationOutput.getScore(l); 45 | float newScore = (float)Math.exp(score)/denom; 46 | 47 | res.addBinaryPrediction(l, newScore); 48 | } 49 | 50 | return res; 51 | } 52 | 53 | public HashMap getBinaryPlattNormalizers() { 54 | return binaryPlattNormalizers; 55 | } 56 | 57 | public void setBinaryPlattNormalizers(HashMap binaryPlattNormalizers) { 58 | this.binaryPlattNormalizers = binaryPlattNormalizers; 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputElement.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | public 
class PlattInputElement { 4 | 5 | private int label; 6 | private float value; 7 | 8 | public PlattInputElement(int label, float value) { 9 | super(); 10 | this.label = label; 11 | this.value = value; 12 | } 13 | 14 | public int getLabel() { 15 | return label; 16 | } 17 | 18 | public float getValue() { 19 | return value; 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputList.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.Vector; 4 | 5 | public class PlattInputList { 6 | 7 | private Vector list; 8 | private int positiveElement; 9 | private int negativeElement; 10 | 11 | public PlattInputList() { 12 | list = new Vector(); 13 | } 14 | 15 | public void add(PlattInputElement arg0) { 16 | if (arg0.getLabel() > 0) 17 | positiveElement++; 18 | else 19 | negativeElement++; 20 | 21 | list.add(arg0); 22 | } 23 | 24 | public PlattInputElement get(int index) { 25 | return list.get(index); 26 | } 27 | 28 | public int size() { 29 | return list.size(); 30 | } 31 | 32 | public int getPositiveElement() { 33 | return positiveElement; 34 | } 35 | 36 | public int getNegativeElement() { 37 | return negativeElement; 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattMethod.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 6 | import it.uniroma2.sag.kelp.data.example.Example; 7 | import 
it.uniroma2.sag.kelp.data.label.Label; 8 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 9 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 10 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 11 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 12 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 13 | 14 | public class PlattMethod { 15 | 16 | /** 17 | * Input parameters: 18 | * 19 | * deci = array of SVM decision values 20 | * 21 | * label = array of booleans: is the example labeled +1? 22 | * 23 | * prior1 = number of positive examples 24 | * 25 | * prior0 = number of negative examples 26 | * 27 | * Outputs: 28 | * 29 | * A, B = parameters of sigmoid 30 | * 31 | * @return 32 | **/ 33 | private static BinaryPlattNormalizer estimateSigmoid(float[] deci, float[] label, int prior1, int prior0) { 34 | 35 | /** 36 | * Parameter setting 37 | */ 38 | // Maximum number of iterations 39 | int maxiter = 100; 40 | // Minimum step taken in line search 41 | // minstep=1e-10; 42 | double minstep = 1e-10; 43 | double stopping = 1e-5; 44 | // Sigma: Set to any value > 0 45 | double sigma = 1e-12; 46 | // Construct initial values: target support in array t, 47 | // initial function value in fval 48 | double hiTarget = ((double) prior1 + 1.0f) / ((double) prior1 + 2.0f); 49 | double loTarget = 1 / (prior0 + 2.0f); 50 | 51 | int len = prior1 + prior0; // Total number of data 52 | double A; 53 | double B; 54 | 55 | double t[] = new double[len]; 56 | 57 | for (int i = 0; i < len; i++) { 58 | if (label[i] > 0) 59 | t[i] = hiTarget; 60 | else 61 | t[i] = loTarget; 62 | } 63 | 64 | A = 0; 65 | B = Math.log((prior0 + 1.0) / (prior1 + 1.0)); 66 | double fval = 0f; 67 | 68 | for (int i = 0; i < len; i++) { 69 | double fApB = deci[i] * A + B; 70 | if (fApB >= 0) 71 | fval += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 72 | else 73 | fval += (t[i] - 1) * fApB + 
Math.log(1 + Math.exp(fApB)); 74 | } 75 | 76 | int it = 1; 77 | for (it = 1; it <= maxiter; it++) { 78 | // Update Gradient and Hessian (use H� = H + sigma I) 79 | double h11 = sigma; 80 | double h22 = sigma; 81 | double h21 = 0; 82 | double g1 = 0; 83 | double g2 = 0; 84 | for (int i = 0; i < len; i++) { 85 | double fApB = deci[i] * A + B; 86 | double p; 87 | double q; 88 | if (fApB >= 0) { 89 | p = (Math.exp(-fApB) / (1.0 + Math.exp(-fApB))); 90 | q = (1.0 / (1.0 + Math.exp(-fApB))); 91 | } else { 92 | p = 1.0 / (1.0 + Math.exp(fApB)); 93 | q = Math.exp(fApB) / (1.0 + Math.exp(fApB)); 94 | } 95 | double d2 = p * q; 96 | h11 += deci[i] * deci[i] * d2; 97 | h22 += d2; 98 | h21 += deci[i] * d2; 99 | double d1 = t[i] - p; 100 | g1 += deci[i] * d1; 101 | g2 += d1; 102 | } 103 | if (Math.abs(g1) < stopping && Math.abs(g2) < stopping) // Stopping 104 | // criteria 105 | break; 106 | 107 | // Compute modified Newton directions 108 | double det = h11 * h22 - h21 * h21; 109 | double dA = -(h22 * g1 - h21 * g2) / det; 110 | double dB = -(-h21 * g1 + h11 * g2) / det; 111 | double gd = g1 * dA + g2 * dB; 112 | double stepsize = 1; 113 | 114 | while (stepsize >= minstep) { // Line search 115 | double newA = A + stepsize * dA; 116 | double newB = B + stepsize * dB; 117 | double newf = 0.0; 118 | for (int i = 0; i < len; i++) { 119 | double fApB = deci[i] * newA + newB; 120 | if (fApB >= 0) 121 | newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 122 | else 123 | newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); 124 | } 125 | 126 | if (newf < fval + 1e-4 * stepsize * gd) { 127 | A = newA; 128 | B = newB; 129 | fval = newf; 130 | break; // Sufficient decrease satisfied 131 | } else 132 | stepsize /= 2.0; 133 | } 134 | if (stepsize < minstep) { 135 | System.out.println("Line search fails"); 136 | break; 137 | } 138 | } 139 | if (it >= maxiter) 140 | System.out.println("Reaching maximum iterations"); 141 | 142 | return new BinaryPlattNormalizer((float) A, (float) B); 143 
| 144 | } 145 | 146 | public static BinaryPlattNormalizer esitmateSigmoid(SimpleDataset dataset, 147 | BinaryLearningAlgorithm binaryLearningAlgorithm, int nFolds) { 148 | 149 | PlattInputList plattInputList = new PlattInputList(); 150 | 151 | Label positiveLabel = binaryLearningAlgorithm.getLabel(); 152 | 153 | SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds); 154 | 155 | for (int f = 0; f < folds.length; f++) { 156 | 157 | SimpleDataset fold = folds[f]; 158 | 159 | SimpleDataset localTrainDataset = new SimpleDataset(); 160 | SimpleDataset localTestDataset = new SimpleDataset(); 161 | for (int i = 0; i < folds.length; i++) { 162 | if (i != f) { 163 | localTrainDataset.addExamples(fold); 164 | } else { 165 | localTestDataset.addExamples(fold); 166 | } 167 | } 168 | 169 | LearningAlgorithm duplicatedLearningAlgorithm = binaryLearningAlgorithm.duplicate(); 170 | 171 | duplicatedLearningAlgorithm.learn(fold); 172 | 173 | PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction(); 174 | 175 | for (Example example : localTestDataset.getExamples()) { 176 | Prediction predict = predictionFunction.predict(example); 177 | 178 | float value = predict.getScore(positiveLabel); 179 | 180 | int label = 1; 181 | if (!example.isExampleOf(positiveLabel)) 182 | label = -1; 183 | plattInputList.add(new PlattInputElement(label, value)); 184 | } 185 | } 186 | 187 | return estimateSigmoid(plattInputList); 188 | } 189 | 190 | public static MulticlassPlattNormalizer esitmateSigmoid(SimpleDataset dataset, OneVsAllLearning oneVsAllLearning, 191 | int nFolds) { 192 | 193 | HashMap plattInputLists = new HashMap(); 194 | for(Label label: dataset.getClassificationLabels()){ 195 | plattInputLists.put(label, new PlattInputList()); 196 | } 197 | 198 | SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds); 199 | 200 | MulticlassPlattNormalizer res = new MulticlassPlattNormalizer(); 201 | 202 | for (int f = 0; f < folds.length; 
f++) { 203 | 204 | SimpleDataset fold = folds[f]; 205 | 206 | SimpleDataset localTrainDataset = new SimpleDataset(); 207 | SimpleDataset localTestDataset = new SimpleDataset(); 208 | for (int i = 0; i < folds.length; i++) { 209 | if (i != f) { 210 | localTrainDataset.addExamples(fold); 211 | } else { 212 | localTestDataset.addExamples(fold); 213 | } 214 | } 215 | 216 | LearningAlgorithm duplicatedLearningAlgorithm = oneVsAllLearning.duplicate(); 217 | 218 | duplicatedLearningAlgorithm.learn(fold); 219 | 220 | PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction(); 221 | 222 | for (Example example : localTestDataset.getExamples()) { 223 | Prediction predict = predictionFunction.predict(example); 224 | 225 | for (Label label : dataset.getClassificationLabels()) { 226 | 227 | float valueOfLabel = predict.getScore(label); 228 | 229 | int binaryLabel = 1; 230 | if (!example.isExampleOf(label)) 231 | binaryLabel = -1; 232 | plattInputLists.get(label).add(new PlattInputElement(binaryLabel, valueOfLabel)); 233 | } 234 | } 235 | } 236 | 237 | for (Label label : dataset.getClassificationLabels()) { 238 | res.addBinaryPlattNormalizer(label, estimateSigmoid(plattInputLists.get(label))); 239 | } 240 | 241 | return res; 242 | } 243 | 244 | protected static BinaryPlattNormalizer estimateSigmoid(PlattInputList inputList) { 245 | float[] deci = new float[inputList.size()]; 246 | float[] label = new float[inputList.size()]; 247 | int prior1 = inputList.getPositiveElement(); 248 | int prior0 = inputList.getNegativeElement(); 249 | 250 | for (int i = 0; i < inputList.size(); i++) { 251 | deci[i] = inputList.get(i).getValue(); 252 | label[i] = inputList.get(i).getLabel(); 253 | } 254 | 255 | return estimateSigmoid(deci, label, prior1, prior0); 256 | } 257 | 258 | } 259 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/scw/SCWType.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.scw; 17 | 18 | /** 19 | * The two types of Soft Confidence-Weighted implemented variants 20 | * 21 | * @author Danilo Croce 22 | * 23 | */ 24 | public enum SCWType { 25 | 26 | SCW_I, SCW_II 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/clustering/kernelbasedkmeans/KernelBasedKMeansExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 21 | import it.uniroma2.sag.kelp.data.example.Example; 22 | 23 | @JsonTypeName("kernelbasedkmeansexample") 24 | public class KernelBasedKMeansExample extends ClusterExample { 25 | 26 | /** 27 | * 28 | */ 29 | private static final long serialVersionUID = -5368757832244686390L; 30 | 31 | public KernelBasedKMeansExample() { 32 | super(); 33 | } 34 | 35 | public KernelBasedKMeansExample(Example e, float dist) { 36 | super(e, dist); 37 | } 38 | 39 | @Override 40 | public Example getExample() { 41 | return example; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/liblinear/LibLinearRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
/*
 * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.regression.liblinear;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvrFunction;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem.LibLinearSolverType;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Tron;
import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel;
import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction;

import java.util.Arrays;
import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;

/**
 * This class implements linear SVM regression trained using a coordinate
 * descent algorithm [Fan et al, 2008]. It operates in an explicit feature
 * space (i.e. it does not rely on any kernel). This code has been adapted from
 * the Java port of the original LIBLINEAR C++ sources.
 *
 * Further details can be found in:
 *
 * [Fan et al, 2008] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J.
 * Lin. LIBLINEAR: A Library for Large Linear Classification, Journal of
 * Machine Learning Research 9(2008), 1871-1874.
 *
 * The original LIBLINEAR code: http://www.csie.ntu.edu.tw/~cjlin/liblinear
 *
 * The original JAVA porting (v 1.94): http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
@JsonTypeName("liblinearregression")
public class LibLinearRegression implements LinearMethod,
		RegressionLearningAlgorithm, BinaryLearningAlgorithm {

	/**
	 * The stopping tolerance used by the TRON optimizer.
	 */
	private static final double TRON_EPS = 0.001;

	/**
	 * The property corresponding to the variable to be learned
	 */
	private Label label;

	/**
	 * The regularization parameter
	 */
	private double c = 1;

	/**
	 * The regressor to be returned
	 */
	@JsonIgnore
	private UnivariateLinearRegressionFunction regressionFunction;

	/**
	 * The epsilon in the loss function of SVR (default 0.1)
	 */
	private double p = 0.1f;

	/**
	 * The identifier of the representation to be considered for the training
	 * step
	 */
	private String representation;

	/**
	 * @param label the regression property to be learned
	 * @param c the regularization parameter
	 * @param p the epsilon in the loss function of SVR
	 * @param representationName the identifier of the representation to be
	 *            considered for the training step
	 */
	public LibLinearRegression(Label label, double c, double p,
			String representationName) {
		this();
		this.setLabel(label);
		this.c = c;
		this.p = p;
		this.setRepresentation(representationName);
	}

	/**
	 * @param c the regularization parameter
	 * @param p the epsilon in the loss function of SVR
	 * @param representationName the identifier of the representation to be
	 *            considered for the training step
	 */
	public LibLinearRegression(double c, double p, String representationName) {
		this();
		this.c = c;
		this.p = p;
		this.setRepresentation(representationName);
	}

	public LibLinearRegression() {
		this.regressionFunction = new UnivariateLinearRegressionFunction();
		this.regressionFunction.setModel(new BinaryLinearModel());
	}

	/**
	 * @return the regularization parameter
	 */
	public double getC() {
		return c;
	}

	/**
	 * @param c the regularization parameter
	 */
	public void setC(double c) {
		this.c = c;
	}

	/**
	 * @return the epsilon in the loss function
	 */
	public double getP() {
		return p;
	}

	/**
	 * @param p the epsilon in the loss function
	 */
	public void setP(double p) {
		this.p = p;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#getRepresentation()
	 */
	@Override
	public String getRepresentation() {
		return representation;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#setRepresentation(java.lang.String)
	 */
	@Override
	public void setRepresentation(String representation) {
		this.representation = representation;
		BinaryLinearModel model = this.regressionFunction.getModel();
		model.setRepresentation(representation);
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#setLabels(java.util.List)
	 */
	@Override
	public void setLabels(List<Label> labels) {
		if (labels.size() != 1) {
			throw new IllegalArgumentException(
					"LibLinear algorithm is a binary method which can learn a single Label");
		} else {
			this.label = labels.get(0);
			this.regressionFunction.setLabels(labels);
		}
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#getLabels()
	 */
	@Override
	public List<Label> getLabels() {
		return Arrays.asList(label);
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#getLabel()
	 */
	@Override
	public Label getLabel() {
		return this.label;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#setLabel(it.uniroma2.sag.kelp.data.label.Label)
	 */
	@Override
	public void setLabel(Label label) {
		this.setLabels(Arrays.asList(label));
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#learn(it.uniroma2.sag.kelp.data.dataset.Dataset)
	 */
	@Override
	public void learn(Dataset dataset) {

		int l = dataset.getNumberOfExamples();

		// Per-example regularization weights: the same C for every example
		double[] C = new double[l];
		Arrays.fill(C, c);

		Problem problem = new Problem(dataset, representation, label,
				LibLinearSolverType.REGRESSION);

		// The SVR objective (L2-regularized L2-loss, epsilon-insensitive)
		L2R_L2_SvrFunction fun_obj = new L2R_L2_SvrFunction(problem, C, p);

		Tron tron = new Tron(fun_obj, TRON_EPS);

		double[] w = new double[problem.n];
		tron.tron(w);

		this.regressionFunction.getModel().setHyperplane(problem.getW(w));
		this.regressionFunction.getModel().setRepresentation(representation);
		this.regressionFunction.getModel().setBias(0);
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#duplicate()
	 */
	@Override
	public LibLinearRegression duplicate() {
		LibLinearRegression copy = new LibLinearRegression();
		copy.setRepresentation(representation);
		copy.setC(c);
		copy.setP(p);
		return copy;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#reset()
	 */
	@Override
	public void reset() {
		this.regressionFunction.reset();
	}

	@Override
	public UnivariateLinearRegressionFunction getPredictionFunction() {
		return regressionFunction;
	}

	@Override
	public void setPredictionFunction(PredictionFunction predictionFunction) {
		this.regressionFunction = (UnivariateLinearRegressionFunction) predictionFunction;
	}

}
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.kernel.Kernel; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateKernelMachineRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (kernel machine version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("kernelizedPA-R") 37 | public class KernelizedPassiveAggressiveRegression extends PassiveAggressiveRegression implements KernelMethod{ 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPassiveAggressiveRegression(){ 42 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 43 | } 44 | 45 | public KernelizedPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, Kernel kernel, Label label){ 46 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 47 | this.setC(aggressiveness); 48 | this.setEpsilon(epsilon); 49 | this.setPolicy(policy); 50 | this.setKernel(kernel); 51 | this.setLabel(label); 52 | } 53 | 54 | @Override 55 | public Kernel getKernel(){ 56 | return kernel; 57 | } 58 | 59 | @Override 60 | public void setKernel(Kernel kernel) { 61 | this.kernel = kernel; 62 | this.getPredictionFunction().getModel().setKernel(kernel); 63 | } 64 | 65 | @Override 66 | public KernelizedPassiveAggressiveRegression duplicate() { 67 | KernelizedPassiveAggressiveRegression copy = new KernelizedPassiveAggressiveRegression(); 
68 | copy.setC(this.c); 69 | copy.setKernel(this.kernel); 70 | copy.setPolicy(this.policy); 71 | copy.setEpsilon(epsilon); 72 | return copy; 73 | } 74 | 75 | @Override 76 | public UnivariateKernelMachineRegressionFunction getPredictionFunction(){ 77 | return (UnivariateKernelMachineRegressionFunction) this.regressor; 78 | } 79 | 80 | @Override 81 | public void setPredictionFunction(PredictionFunction predictionFunction) { 82 | this.regressor = (UnivariateKernelMachineRegressionFunction) predictionFunction; 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/LinearPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (linear version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("linearPA-R") 37 | public class LinearPassiveAggressiveRegression extends PassiveAggressiveRegression implements LinearMethod{ 38 | 39 | private String representation; 40 | 41 | public LinearPassiveAggressiveRegression(){ 42 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 43 | regressor.setModel(new BinaryLinearModel()); 44 | this.regressor = regressor; 45 | 46 | } 47 | 48 | public LinearPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, String representation, Label label){ 49 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 50 | regressor.setModel(new BinaryLinearModel()); 51 | this.regressor = regressor; 52 | this.setC(aggressiveness); 53 | this.setEpsilon(epsilon); 54 | this.setPolicy(policy); 55 | this.setRepresentation(representation); 56 | this.setLabel(label); 57 | } 58 | 59 | @Override 60 | public LinearPassiveAggressiveRegression duplicate() { 61 | LinearPassiveAggressiveRegression copy = new LinearPassiveAggressiveRegression(); 62 | 
copy.setC(this.c); 63 | copy.setRepresentation(this.representation); 64 | copy.setPolicy(this.policy); 65 | copy.setEpsilon(epsilon); 66 | return copy; 67 | } 68 | 69 | @Override 70 | public String getRepresentation() { 71 | return representation; 72 | } 73 | 74 | @Override 75 | public void setRepresentation(String representation) { 76 | this.representation = representation; 77 | this.getPredictionFunction().getModel().setRepresentation(representation); 78 | } 79 | 80 | @Override 81 | public UnivariateLinearRegressionFunction getPredictionFunction(){ 82 | return (UnivariateLinearRegressionFunction) this.regressor; 83 | } 84 | 85 | @Override 86 | public void setPredictionFunction(PredictionFunction predictionFunction) { 87 | this.regressor = (UnivariateLinearRegressionFunction) predictionFunction; 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/PassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput; 23 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionFunction; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for regression tasks. 29 | * 30 | * reference: 31 | * 32 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 33 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 34 | * 35 | * @author Simone Filice 36 | */ 37 | public abstract class PassiveAggressiveRegression extends PassiveAggressive implements RegressionLearningAlgorithm{ 38 | 39 | @JsonIgnore 40 | protected UnivariateRegressionFunction regressor; 41 | 42 | protected float epsilon; 43 | 44 | /** 45 | * Returns epsilon, i.e. the accepted distance between the predicted and the real regression values 46 | * 47 | * @return the epsilon 48 | */ 49 | public float getEpsilon() { 50 | return epsilon; 51 | } 52 | 53 | /** 54 | * Sets epsilon, i.e. 
the accepted distance between the predicted and the real regression values 55 | * 56 | * @param epsilon the epsilon to set 57 | */ 58 | public void setEpsilon(float epsilon) { 59 | this.epsilon = epsilon; 60 | } 61 | 62 | @Override 63 | public UnivariateRegressionFunction getPredictionFunction() { 64 | return this.regressor; 65 | } 66 | 67 | @Override 68 | public void learn(Dataset dataset){ 69 | 70 | while(dataset.hasNextExample()){ 71 | Example example = dataset.getNextExample(); 72 | this.learn(example); 73 | } 74 | dataset.reset(); 75 | } 76 | 77 | @Override 78 | public UnivariateRegressionOutput learn(Example example){ 79 | UnivariateRegressionOutput prediction=this.regressor.predict(example); 80 | float difference = example.getRegressionValue(label) - prediction.getScore(label); 81 | float lossValue = Math.abs(difference) - epsilon;//it represents the distance from the correct semi-space 82 | if(lossValue>0){ 83 | float exampleSquaredNorm = this.regressor.getModel().getSquaredNorm(example); 84 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm, c); 85 | if(difference<0){ 86 | weight = -weight; 87 | } 88 | this.regressor.getModel().addExample(weight, example); 89 | } 90 | return prediction; 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/linearization/LinearizationFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.linearization; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.representation.Vector; 22 | 23 | /** 24 | * This interface allows implementing function to linearized examples through 25 | * linear representations, i.e. vectors 26 | * 27 | * 28 | * @author Danilo Croce 29 | * 30 | */ 31 | public interface LinearizationFunction { 32 | 33 | /** 34 | * Given an input Example, this method generates a linear 35 | * Representation>, i.e. a Vector. 36 | * 37 | * @param example 38 | * The input example. 39 | * @return The linearized representation of the input example. 40 | */ 41 | public Vector getLinearRepresentation(Example example); 42 | 43 | /** 44 | * This method linearizes an input example, providing a new example 45 | * containing only a representation with a specific name, provided as input. 46 | * The produced example inherits the labels of the input example. 47 | * 48 | * @param example 49 | * The input example. 50 | * @param vectorName 51 | * The name of the linear representation inside the new example 52 | * @return 53 | */ 54 | public Example getLinearizedExample(Example example, String representationName); 55 | 56 | /** 57 | * This method linearizes all the examples in the input dataset 58 | * , generating a corresponding linearized dataset. 
The produced examples 59 | * inherit the labels of the corresponding input examples. 60 | * 61 | * @param dataset 62 | * The input dataset 63 | * @param representationName 64 | * The name of the linear representation inside the new examples 65 | * @return 66 | */ 67 | public SimpleDataset getLinearizedDataset(Dataset dataset, String representationName); 68 | 69 | /** 70 | * @return the size of the resulting embedding, i.e. the number of resulting 71 | * vector dimensions 72 | */ 73 | public int getEmbeddingSize(); 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/SequencePrediction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.example.SequencePath; 22 | import it.uniroma2.sag.kelp.data.label.Label; 23 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 24 | 25 | /** 26 | * It is a output provided by a machine learning systems on a sequence. 
This 27 | * specific implementation allows to assign multiple labelings to single 28 | * sequence, useful for some labeling strategies, such as Beam Search. Notice 29 | * that each labeling requires a score to select the more promising labeling. 30 | * 31 | * @author Danilo Croce 32 | * 33 | */ 34 | public class SequencePrediction implements Prediction { 35 | 36 | /** 37 | * 38 | */ 39 | private static final long serialVersionUID = -1040539866977906008L; 40 | /** 41 | * This list contains multiple labelings to be assigned to a single sequence 42 | */ 43 | private List paths; 44 | 45 | public SequencePrediction() { 46 | paths = new ArrayList(); 47 | } 48 | 49 | /** 50 | * @return The best path, i.e., the labeling with the highest score in the 51 | * list of labelings provided by a classifier 52 | */ 53 | public SequencePath bestPath() { 54 | return paths.get(0); 55 | } 56 | 57 | /** 58 | * @return a list containing multiple labelings to be assigned to a single 59 | * sequence 60 | */ 61 | public List getPaths() { 62 | return paths; 63 | } 64 | 65 | @Override 66 | public Float getScore(Label label) { 67 | return null; 68 | } 69 | 70 | /** 71 | * @param paths 72 | * a list contains multiple labelings to be assigned to a single 73 | * sequence 74 | */ 75 | public void setPaths(List paths) { 76 | this.paths = paths; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | StringBuilder sb = new StringBuilder(); 82 | for (int i = 0; i < paths.size(); i++) { 83 | if (i == 0) 84 | sb.append("Best Path\t"); 85 | else 86 | sb.append("Altern. 
Path\t"); 87 | SequencePath sequencePath = paths.get(i); 88 | sb.append(sequencePath + "\n"); 89 | } 90 | return sb.toString(); 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/model/SequenceModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction.model; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 20 | 21 | /** 22 | * This class implements a model produced by a 23 | * SequenceClassificationLearningAlgorithm 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class SequenceModel implements Model { 29 | 30 | /** 31 | * 32 | */ 33 | private static final long serialVersionUID = -2749198158786953940L; 34 | 35 | /** 36 | * The prediction function producing the emission scores to be considered in 37 | * the Viterbi Decoding 38 | */ 39 | private PredictionFunction basePredictionFunction; 40 | 41 | private SequenceExampleGenerator sequenceExampleGenerator; 42 | 43 | public SequenceModel() { 44 | super(); 45 | } 46 | 47 | public SequenceModel(PredictionFunction basePredictionFunction, SequenceExampleGenerator sequenceExampleGenerator) { 48 | super(); 49 | this.basePredictionFunction = basePredictionFunction; 50 | this.sequenceExampleGenerator = sequenceExampleGenerator; 51 | } 52 | 53 | public PredictionFunction getBasePredictionFunction() { 54 | return basePredictionFunction; 55 | } 56 | 57 | public SequenceExampleGenerator getSequenceExampleGenerator() { 58 | return sequenceExampleGenerator; 59 | } 60 | 61 | @Override 62 | public void reset() { 63 | } 64 | 65 | public void setBasePredictionFunction(PredictionFunction basePredictionFunction) { 66 | this.basePredictionFunction = basePredictionFunction; 67 | } 68 | 69 | public void setSequenceExampleGenerator(SequenceExampleGenerator sequenceExampleGenerator) { 70 | this.sequenceExampleGenerator = sequenceExampleGenerator; 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/ClusteringEvaluator.java: -------------------------------------------------------------------------------- 
1 | package it.uniroma2.sag.kelp.utils.evaluation; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashSet; 5 | import java.util.TreeMap; 6 | 7 | import it.uniroma2.sag.kelp.data.clustering.Cluster; 8 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 9 | import it.uniroma2.sag.kelp.data.clustering.ClusterList; 10 | import it.uniroma2.sag.kelp.data.example.Example; 11 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 12 | import it.uniroma2.sag.kelp.data.label.Label; 13 | import it.uniroma2.sag.kelp.data.label.StringLabel; 14 | import it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans.KernelBasedKMeansExample; 15 | 16 | /** 17 | * 18 | * Implements Evaluation methods for clustering algorithms. 19 | * 20 | * More details about Purity and NMI can be found here: 21 | * 22 | * https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1. 23 | * html 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class ClusteringEvaluator { 29 | 30 | public static float getPurity(ClusterList clusters) { 31 | 32 | float res = 0; 33 | int k = clusters.size(); 34 | 35 | for (int clustId = 0; clustId < k; clustId++) { 36 | 37 | TreeMap classSizes = new TreeMap(); 38 | 39 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 40 | HashSet labels = vce.getExample().getClassificationLabels(); 41 | for (Label label : labels) 42 | if (!classSizes.containsKey(label)) 43 | classSizes.put(label, 1); 44 | else 45 | classSizes.put(label, classSizes.get(label) + 1); 46 | } 47 | 48 | int maxSize = 0; 49 | for (int size : classSizes.values()) { 50 | if (size > maxSize) { 51 | maxSize = size; 52 | } 53 | } 54 | res += maxSize; 55 | } 56 | 57 | return res / (float) clusters.getNumberOfExamples(); 58 | } 59 | 60 | public static float getMI(ClusterList clusters) { 61 | 62 | float res = 0; 63 | 64 | float N = clusters.getNumberOfExamples(); 65 | 66 | int k = clusters.size(); 67 | 68 | TreeMap classCardinality = 
getClassCardinality(clusters); 69 | 70 | for (int clustId = 0; clustId < k; clustId++) { 71 | 72 | TreeMap classSizes = getClassCardinalityWithinCluster(clusters, clustId); 73 | 74 | for (Label className : classSizes.keySet()) { 75 | int wSize = classSizes.get(className); 76 | res += ((float) wSize / N) * myLog(N * (float) wSize 77 | / (clusters.get(clustId).getExamples().size() * (float) classCardinality.get(className))); 78 | } 79 | 80 | } 81 | 82 | return res; 83 | 84 | } 85 | 86 | private static TreeMap getClassCardinalityWithinCluster(ClusterList clusters, int clustId) { 87 | 88 | TreeMap classSizes = new TreeMap(); 89 | 90 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 91 | HashSet labels = vce.getExample().getClassificationLabels(); 92 | for (Label label : labels) 93 | if (!classSizes.containsKey(label)) 94 | classSizes.put(label, 1); 95 | else 96 | classSizes.put(label, classSizes.get(label) + 1); 97 | } 98 | 99 | return classSizes; 100 | } 101 | 102 | private static float getClusterEntropy(ClusterList clusters) { 103 | 104 | float res = 0; 105 | float N = clusters.getNumberOfExamples(); 106 | int k = clusters.size(); 107 | 108 | for (int clustId = 0; clustId < k; clustId++) { 109 | int clusterElementSize = clusters.get(clustId).getExamples().size(); 110 | if (clusterElementSize != 0) 111 | res -= ((float) clusterElementSize / N) * myLog((float) clusterElementSize / N); 112 | } 113 | return res; 114 | 115 | } 116 | 117 | private static float getClassEntropy(ClusterList clusters) { 118 | 119 | float res = 0; 120 | float N = clusters.getNumberOfExamples(); 121 | 122 | TreeMap classCardinality = getClassCardinality(clusters); 123 | 124 | for (int classSize : classCardinality.values()) { 125 | res -= ((float) classSize / N) * myLog((float) classSize / N); 126 | } 127 | return res; 128 | 129 | } 130 | 131 | private static float myLog(float f) { 132 | return (float) (Math.log(f) / Math.log(2f)); 133 | } 134 | 135 | private static TreeMap 
getClassCardinality(ClusterList clusters) { 136 | TreeMap classSizes = new TreeMap(); 137 | 138 | int k = clusters.size(); 139 | 140 | for (int clustId = 0; clustId < k; clustId++) { 141 | 142 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 143 | HashSet labels = vce.getExample().getClassificationLabels(); 144 | for (Label label : labels) 145 | if (!classSizes.containsKey(label)) 146 | classSizes.put(label, 1); 147 | else 148 | classSizes.put(label, classSizes.get(label) + 1); 149 | } 150 | } 151 | return classSizes; 152 | } 153 | 154 | public static float getNMI(ClusterList clusters) { 155 | return getMI(clusters) / ((getClusterEntropy(clusters) + getClassEntropy(clusters)) / 2f); 156 | } 157 | 158 | public static String getStatistics(ClusterList clusters) { 159 | StringBuilder sb = new StringBuilder(); 160 | 161 | sb.append("Purity:\t" + getPurity(clusters) + "\n"); 162 | sb.append("Mutual Information:\t" + getMI(clusters) + "\n"); 163 | sb.append("Cluster Entropy:\t" + getClusterEntropy(clusters) + "\n"); 164 | sb.append("Class Entropy:\t" + getClassEntropy(clusters) + "\n"); 165 | sb.append("NMI:\t" + getNMI(clusters)); 166 | 167 | return sb.toString(); 168 | } 169 | 170 | public static void main(String[] args) { 171 | ClusterList clusters = new ClusterList(); 172 | 173 | Cluster c1 = new Cluster("C1"); 174 | ArrayList list1 = new ArrayList(); 175 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 176 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 177 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 178 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 179 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 180 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 181 | for (Example e : list1) { 182 | c1.add(new KernelBasedKMeansExample(e, 1f)); 183 | } 184 
| 185 | Cluster c2 = new Cluster("C2"); 186 | ArrayList list2 = new ArrayList(); 187 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 188 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 189 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 190 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 191 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 192 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 193 | for (Example e : list2) { 194 | c2.add(new KernelBasedKMeansExample(e, 1f)); 195 | } 196 | 197 | Cluster c3 = new Cluster("C3"); 198 | ArrayList list3 = new ArrayList(); 199 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 200 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 201 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 202 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 203 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 204 | for (Example e : list3) { 205 | c3.add(new KernelBasedKMeansExample(e, 1f)); 206 | } 207 | 208 | clusters.add(c1); 209 | clusters.add(c2); 210 | clusters.add(c3); 211 | 212 | System.out.println(ClusteringEvaluator.getStatistics(clusters)); 213 | 214 | //From https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html 215 | //Purity = 0.71 216 | //NMI = 0.36 217 | 218 | } 219 | 220 | } 221 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/MulticlassSequenceClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.utils.evaluation; 17 | 18 | import java.util.List; 19 | 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 22 | import it.uniroma2.sag.kelp.data.example.SequencePath; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.data.label.SequenceEmission; 25 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 26 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 27 | 28 | /** 29 | * This is an instance of an Evaluator. It allows to compute the some common 30 | * measure for classification tasks acting over SequenceExamples. It 31 | * computes precision, recall, f1s for each class, and a global accuracy. 
32 | * 33 | * @author Danilo Croce 34 | */ 35 | public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator{ 36 | 37 | /** 38 | * Initialize a new F1Evaluator that will work on the specified classes 39 | * 40 | * @param labels 41 | */ 42 | public MulticlassSequenceClassificationEvaluator(List labels) { 43 | super(labels); 44 | } 45 | 46 | public void addCount(Example test, Prediction prediction) { 47 | addCount((SequenceExample) test, (SequencePrediction) prediction); 48 | } 49 | 50 | /** 51 | * This method should be implemented in the subclasses to update counters 52 | * useful to compute the performance measure 53 | * 54 | * @param test 55 | * the test example 56 | * @param predicted 57 | * the prediction of the system 58 | */ 59 | public void addCount(SequenceExample test, SequencePrediction predicted) { 60 | 61 | SequencePath bestPath = predicted.bestPath(); 62 | 63 | for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) { 64 | 65 | Example testItem = test.getExample(seqIdx); 66 | SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx); 67 | 68 | for (Label l : this.labels) { 69 | ClassStats stats = this.classStats.get(l); 70 | if(testItem.isExampleOf(l)){ 71 | if(sequenceLabel.getLabel().equals(l)){ 72 | stats.tp++; 73 | totalTp++; 74 | }else{ 75 | stats.fn++; 76 | totalFn++; 77 | } 78 | }else{ 79 | if(sequenceLabel.getLabel().equals(l)){ 80 | stats.fp++; 81 | totalFp++; 82 | }else{ 83 | stats.tn++; 84 | totalTn++; 85 | } 86 | } 87 | 88 | } 89 | 90 | //TODO: check (i) e' giusto valutare l'accuracy dei singoli elementi della sequenza e non della sequenza completa 91 | //(ii) va considerato il caso multilabel 92 | total++; 93 | 94 | if (testItem.isExampleOf(sequenceLabel.getLabel())) { 95 | correct++; 96 | } 97 | 98 | this.computed = false; 99 | } 100 | } 101 | 102 | } 103 | -------------------------------------------------------------------------------- 
/src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 
33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = 
new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | 
evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
SequenceExampleGenerator
Example
SequenceExample
n
SequenceClassificationLearningAlgorithm
Label
SequenceExample
SequenceExampleGeneratorKernelBasedAlg
SequenceExampleGeneratorLinearAlg
LinearMethod
TypeIdResolver
SequenceExamplesGenerator
46 | * \(argmin_{\mathbf{w}} \frac{1}{2} \left \| \mathbf{w}-\mathbf{w}_t \right \|^2\) 47 | *
such that \( l(\mathbf{w};(\mathbf{x}_t,y_t))=0 \) 48 | */ 49 | HARD_PA, 50 | 51 | /** 52 | * The new prediction hypothesis after a new example \( \mathbf{x}_t\) with label \(y_t\) is observed is: 53 | *
54 | * \(argmin_{\mathbf{w}} \frac{1}{2} \left \| \mathbf{w}-\mathbf{w}_t \right \|^2 + C\xi \) 55 | *
such that \( l(\mathbf{w};(\mathbf{x}_t,y_t))\leq \xi \) and \( \xi\geq 0\) 56 | */ 57 | PA_I, 58 | 59 | /** 60 | * The new prediction hypothesis after a new example \( \mathbf{x}_t\) with label \(y_t\) is observed is: 61 | *
62 | * \(argmin_{\mathbf{w}} \frac{1}{2} \left \| \mathbf{w}-\mathbf{w}_t \right \|^2 + C\xi^2 \) 63 | *
such that \( l(\mathbf{w};(\mathbf{x}_t,y_t))\leq \xi \) and \( \xi\geq 0\) 64 | */ 65 | PA_II 66 | } 67 | 68 | 69 | protected Label label; 70 | 71 | 72 | 73 | protected Policy policy = Policy.PA_II; 74 | 75 | protected float c = 1;//the aggressiveness parameter 76 | 77 | 78 | 79 | @Override 80 | public void reset() { 81 | this.getPredictionFunction().reset(); 82 | } 83 | 84 | 85 | /** 86 | * @return the updating policy 87 | */ 88 | public Policy getPolicy() { 89 | return policy; 90 | } 91 | 92 | 93 | /** 94 | * @param policy the updating policy to set 95 | */ 96 | public void setPolicy(Policy policy) { 97 | this.policy = policy; 98 | } 99 | 100 | 101 | /** 102 | * @return the aggressiveness parameter 103 | */ 104 | public float getC() { 105 | return c; 106 | } 107 | 108 | 109 | /** 110 | * @param c the aggressiveness to set 111 | */ 112 | public void setC(float c) { 113 | this.c = c; 114 | } 115 | 116 | 117 | protected float computeWeight(Example example, float lossValue, float exampleSquaredNorm, float aggressiveness) { 118 | float weight=1; 119 | 120 | switch(policy){ 121 | case HARD_PA: 122 | weight=lossValue/exampleSquaredNorm; 123 | break; 124 | case PA_I: 125 | weight=lossValue/exampleSquaredNorm; 126 | if(weight>aggressiveness){ 127 | weight=aggressiveness; 128 | } 129 | break; 130 | case PA_II: 131 | weight=lossValue/(exampleSquaredNorm+1/(2*aggressiveness)); 132 | break; 133 | } 134 | 135 | return weight; 136 | } 137 | 138 | 139 | @Override 140 | public void setLabels(List labels){ 141 | if(labels.size()!=1){ 142 | throw new IllegalArgumentException("The Passive Aggressive algorithm is a binary method which can learn a single Label"); 143 | } 144 | else{ 145 | this.label=labels.get(0); 146 | this.getPredictionFunction().setLabels(labels); 147 | } 148 | } 149 | 150 | 151 | @Override 152 | public List getLabels() { 153 | return Arrays.asList(label); 154 | } 155 | 156 | @Override 157 | public void learn(Dataset dataset){ 158 | 
while(dataset.hasNextExample()){ 159 | this.learn(dataset.getNextExample()); 160 | } 161 | dataset.reset(); 162 | } 163 | 164 | @Override 165 | public Label getLabel(){ 166 | return this.label; 167 | } 168 | 169 | @Override 170 | public void setLabel(Label label){ 171 | this.setLabels(Arrays.asList(label)); 172 | } 173 | 174 | } 175 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/budgetedAlgorithm/BudgetedLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm;


import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod;
import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel;

import java.util.Arrays;
import java.util.List;

/**
 * Base class for binary kernel-based online learning methods that bound the number
 * of support vectors to a fixed number (i.e. the budget).
 * While the model contains fewer support vectors than the budget, the standard
 * updating policy of the concrete algorithm is applied; when the budget is full, a
 * particular updating policy (that must be specified by extending classes) is adopted.
 *
 * @author Simone Filice
 */
public abstract class BudgetedLearningAlgorithm implements OnlineLearningAlgorithm, BinaryLearningAlgorithm, KernelMethod{

	protected int budget;
	protected Label label;

	/**
	 * Returns the budget, i.e. the maximum number of support vectors
	 *
	 * @return the budget
	 */
	public int getBudget() {
		return budget;
	}

	/**
	 * Sets the budget, i.e. the maximum number of support vectors
	 *
	 * @param budget the budget to set
	 */
	public void setBudget(int budget) {
		this.budget = budget;
	}

	@Override
	public void learn(Dataset dataset){
		// Single sequential pass over the dataset; the iterator is rewound afterwards
		// so the dataset can be re-used by the caller.
		while(dataset.hasNextExample()){
			this.learn(dataset.getNextExample());
		}
		dataset.reset();
	}

	@Override
	public Prediction learn(Example example){
		// Dispatch on whether the budget is exhausted: subclasses provide both policies.
		// NOTE(review): the comparison below was reconstructed (the original text was
		// garbled in extraction); the dispatch on model size vs. budget follows from the
		// two abstract policies both subclasses override — TODO confirm against upstream.
		BinaryKernelMachineModel model = (BinaryKernelMachineModel) this.getPredictionFunction().getModel();
		if(model.getSupportVectors().size() < budget){
			return predictAndLearnWithAvailableBudget(example);
		}
		return predictAndLearnWithFullBudget(example);
	}

	/**
	 * Learns from a single example assuming the budget is not full, i.e. the standard
	 * updating policy of the concrete algorithm can be applied.
	 *
	 * @param example the example to exploit in the learning process
	 * @return the prediction associated to <code>example</code>
	 */
	protected abstract Prediction predictAndLearnWithAvailableBudget(Example example);

	/**
	 * Learns from a single example when the budget is full, applying the
	 * budget-specific updating policy.
	 *
	 * @param example the example to exploit in the learning process
	 * @return the prediction associated to <code>example</code>
	 */
	protected abstract Prediction predictAndLearnWithFullBudget(Example example);

	@Override
	public void setLabels(List<Label> labels){
		// Budgeted algorithms are strictly binary: exactly one label is admitted.
		if(labels.size()!=1){
			throw new IllegalArgumentException("Any budgeted learning algorithm is a binary method which can learn a single Label");
		}
		else{
			this.label=labels.get(0);
			this.getPredictionFunction().setLabels(labels);
		}
	}


	@Override
	public List<Label> getLabels() {
		return Arrays.asList(label);
	}

	@Override
	public Label getLabel(){
		return this.label;
	}

	@Override
	public void setLabel(Label label){
		this.setLabels(Arrays.asList(label));
	}


}
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel;
import it.uniroma2.sag.kelp.predictionfunction.model.SupportVector;

import java.util.Random;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;

/**
 * It is a variation of the Randomized Budget Perceptron proposed in
 * [CavallantiCOLT2006] G. Cavallanti, N. Cesa-Bianchi, C. Gentile. Tracking the best hyperplane with a simple
 * budget Perceptron. In proc. of the 19-th annual conference on Computational Learning Theory. (2006)
 *
 * Until the budget is not reached the online learning updating policy is the one of the baseAlgorithm that this
 * meta-algorithm is exploiting. When the budget is full, a random support vector is deleted and the perceptron
 * updating policy is adopted.
 *
 * @author Simone Filice
 */
@JsonTypeName("randomizedPerceptron")
public class RandomizedBudgetPerceptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm{

	private static final long DEFAULT_SEED=1;
	private long initialSeed = DEFAULT_SEED;
	@JsonIgnore
	private Random randomGenerator;

	private OnlineLearningAlgorithm baseAlgorithm;

	public RandomizedBudgetPerceptron(){
		randomGenerator = new Random(initialSeed);
	}

	/**
	 * @param baseAlgorithm the online algorithm whose updating policy is used while the budget is available
	 * @param budget the maximum number of support vectors
	 * @param seed the seed of the random generator used to select the support vector to delete
	 * @param label the label this binary learner must learn
	 */
	public RandomizedBudgetPerceptron(int budget, OnlineLearningAlgorithm baseAlgorithm, long seed, Label label){
		randomGenerator = new Random(initialSeed);
		this.setBudget(budget);
		this.setBaseAlgorithm(baseAlgorithm);
		this.setSeed(seed);
		this.setLabel(label);
	}

	/**
	 * Sets the seed for the random generator adopted to select the support vector to delete
	 *
	 * @param seed the seed of the randomGenerator
	 */
	public void setSeed(long seed){
		this.initialSeed = seed;
		this.randomGenerator.setSeed(seed);
	}

	@Override
	public RandomizedBudgetPerceptron duplicate() {
		RandomizedBudgetPerceptron copy = new RandomizedBudgetPerceptron();
		copy.setBudget(budget);
		copy.setBaseAlgorithm(baseAlgorithm.duplicate());
		copy.setSeed(initialSeed);
		// FIX: the configured label was previously not propagated, leaving the
		// duplicated instance mis-configured (its getLabels() would contain null).
		if(this.label != null){
			copy.setLabel(this.label);
		}
		return copy;
	}

	@Override
	public void reset() {
		// Resetting restores both the base learner and the pseudo-random sequence,
		// so a re-run over the same data is fully reproducible.
		this.baseAlgorithm.reset();
		this.randomGenerator.setSeed(initialSeed);
	}

	@Override
	protected Prediction predictAndLearnWithFullBudget(Example example) {
		Prediction prediction = this.baseAlgorithm.getPredictionFunction().predict(example);

		// Perceptron policy on a full budget: update only on a mistake, overwriting a
		// randomly chosen support vector with the new misclassified example.
		if((prediction.getScore(getLabel())>0) != example.isExampleOf(getLabel())){
			int svToDelete = this.randomGenerator.nextInt(budget);
			float weight = 1;
			if(!example.isExampleOf(getLabel())){
				weight=-1;
			}
			SupportVector sv = new SupportVector(weight, example);

			((BinaryKernelMachineModel)this.baseAlgorithm.getPredictionFunction().getModel()).setSupportVector(sv, svToDelete);
		}
		return prediction;
	}

	@Override
	public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) {
		if(baseAlgorithm instanceof OnlineLearningAlgorithm && baseAlgorithm instanceof KernelMethod && baseAlgorithm instanceof BinaryLearningAlgorithm){
			this.baseAlgorithm = (OnlineLearningAlgorithm) baseAlgorithm;
		}else{
			throw new IllegalArgumentException("a valid baseAlgorithm for the Randomized Budget Perceptron must implement OnlineLearningAlgorithm, BinaryLeaningAlgorithm and KernelMethod");
		}
	}

	@Override
	public OnlineLearningAlgorithm getBaseAlgorithm() {
		return this.baseAlgorithm;
	}

	@Override
	public PredictionFunction getPredictionFunction() {
		return this.baseAlgorithm.getPredictionFunction();
	}

	@Override
	public Kernel getKernel() {
		return ((KernelMethod)this.baseAlgorithm).getKernel();
	}

	@Override
	public void setKernel(Kernel kernel) {
		((KernelMethod)this.baseAlgorithm).setKernel(kernel);
	}

	@Override
	protected Prediction predictAndLearnWithAvailableBudget(Example example) {
		// While the budget is available, simply delegate to the base algorithm's policy.
		return this.baseAlgorithm.learn(example);
	}

	@Override
	public void setPredictionFunction(PredictionFunction predictionFunction) {
		this.baseAlgorithm.setPredictionFunction(predictionFunction);
	}

}
Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm; 17 | 18 | import it.uniroma2.sag.kelp.data.example.Example; 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 27 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 28 | 29 | import com.fasterxml.jackson.annotation.JsonTypeName; 30 | 31 | /** 32 | * It is a variation of the Stoptron proposed in 33 | * [OrabonaICML2008] Francesco Orabona, Joseph Keshet, and Barbara Caputo. The projectron: a bounded kernel-based perceptron. In Int. Conf. on Machine Learning (2008) 34 | * 35 | * Until the budget is not reached the online learning updating policy is the one of the baseAlgorithm that this 36 | * meta-algorithm is exploiting. 
When the budget is full, the learning process ends 37 | * 38 | * @author Simone Filice 39 | * 40 | */ 41 | @JsonTypeName("stoptron") 42 | public class Stoptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm{ 43 | 44 | private OnlineLearningAlgorithm baseAlgorithm; 45 | 46 | public Stoptron(){ 47 | 48 | } 49 | 50 | public Stoptron(int budget, OnlineLearningAlgorithm baseAlgorithm, Label label){ 51 | this.setBudget(budget); 52 | this.setBaseAlgorithm(baseAlgorithm); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Stoptron duplicate() { 58 | Stoptron copy = new Stoptron(); 59 | copy.setBudget(budget); 60 | copy.setBaseAlgorithm(baseAlgorithm.duplicate()); 61 | return copy; 62 | } 63 | 64 | @Override 65 | public void reset() { 66 | this.baseAlgorithm.reset(); 67 | } 68 | 69 | @Override 70 | protected Prediction predictAndLearnWithFullBudget(Example example) { 71 | return this.baseAlgorithm.getPredictionFunction().predict(example); 72 | } 73 | 74 | @Override 75 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 76 | if(baseAlgorithm instanceof OnlineLearningAlgorithm && baseAlgorithm instanceof KernelMethod && baseAlgorithm instanceof BinaryLearningAlgorithm){ 77 | this.baseAlgorithm = (OnlineLearningAlgorithm) baseAlgorithm; 78 | }else{ 79 | throw new IllegalArgumentException("a valid baseAlgorithm for the Stoptron must implement OnlineLearningAlgorithm, BinaryLeaningAlgorithm and KernelMethod"); 80 | } 81 | } 82 | 83 | @Override 84 | public OnlineLearningAlgorithm getBaseAlgorithm() { 85 | return this.baseAlgorithm; 86 | } 87 | 88 | @Override 89 | public PredictionFunction getPredictionFunction() { 90 | return this.baseAlgorithm.getPredictionFunction(); 91 | } 92 | 93 | @Override 94 | public Kernel getKernel() { 95 | return ((KernelMethod)this.baseAlgorithm).getKernel(); 96 | } 97 | 98 | @Override 99 | public void setKernel(Kernel kernel) { 100 | ((KernelMethod)this.baseAlgorithm).setKernel(kernel); 101 | 102 | } 
103 | 104 | @Override 105 | protected Prediction predictAndLearnWithAvailableBudget(Example example) { 106 | return this.baseAlgorithm.learn(example); 107 | } 108 | 109 | @Override 110 | public void setPredictionFunction(PredictionFunction predictionFunction) { 111 | this.baseAlgorithm.setPredictionFunction(predictionFunction); 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/dcd/DCDLoss.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.dcd; 2 | 3 | public enum DCDLoss { 4 | L1, L2 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceClassificationKernelBasedLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm;

import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator;
import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGeneratorKernel;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.kernel.cache.KernelCache;
import it.uniroma2.sag.kelp.kernel.standard.LinearKernelCombination;
import it.uniroma2.sag.kelp.kernel.vector.LinearKernel;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;

/**
 * This class implements a sequential labeling paradigm. Given sequences of items (each
 * implemented as an Example and associated to one Label) this class allows applying a
 * generic LearningAlgorithm that exploits the "history" of each item in the sequence in
 * order to improve the classification quality: the classification of each example does
 * not depend only on its representation, but also on the classes assigned to the
 * preceding examples.
 *
 * This class should be used when a kernel-based learning algorithm is used, thus
 * directly operating in the implicit space underlying a kernel function.
 *
 * This algorithm was inspired by the work of:
 * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector
 * machines. In Proceedings of the Twentieth International Conference on Machine
 * Learning, 2003.
 *
 * @author Danilo Croce
 */
public class SequenceClassificationKernelBasedLearningAlgorithm extends SequenceClassificationLearningAlgorithm
		implements KernelMethod {

	// Name of the artificial representation carrying the transition-based features
	// produced by the SequenceExampleGeneratorKernel.
	private final static String TRANSITION_REPRESENTATION_NAME = "__trans_rep__";

	// Weighted combination of the observation kernel (provided by the base algorithm)
	// and a linear kernel over the transition representation.
	private LinearKernelCombination sequenceBasedKernel;

	public SequenceClassificationKernelBasedLearningAlgorithm() {

	}

	/**
	 * @param baseLearningAlgorithm
	 *            the learning algorithm devoted to the acquisition of a model
	 *            after that each example has been enriched with its "history"
	 * @param transitionsOrder
	 *            given a targeted item in the sequence, this variable
	 *            determines the number of previous examples considered in the
	 *            learning/labeling process.
	 * @param transitionWeight
	 *            the importance of the transition-based features during the
	 *            learning process. Higher values will assign more importance
	 *            to the transitions.
	 * @throws Exception
	 *             The input <code>baseLearningAlgorithm</code> is not a
	 *             kernel-based method
	 */
	public SequenceClassificationKernelBasedLearningAlgorithm(BinaryLearningAlgorithm baseLearningAlgorithm,
			int transitionsOrder, float transitionWeight) throws Exception {

		if (!(baseLearningAlgorithm instanceof KernelMethod)) {
			throw new Exception("ERROR: the input baseLearningAlgorithm is not a kernel-based method!");
		}

		Kernel inputKernel = ((KernelMethod) baseLearningAlgorithm).getKernel();

		// Build the combined kernel: 1 * observationKernel + transitionWeight * transitionKernel,
		// then normalize the weights so they sum consistently.
		sequenceBasedKernel = new LinearKernelCombination();
		sequenceBasedKernel.addKernel(1, inputKernel);
		Kernel transitionBasedKernel = new LinearKernel(TRANSITION_REPRESENTATION_NAME);
		sequenceBasedKernel.addKernel(transitionWeight, transitionBasedKernel);
		sequenceBasedKernel.normalizeWeights();

		setKernel(sequenceBasedKernel);

		// Duplicate the base algorithm so the caller's instance is not mutated, and
		// make the copy operate on the combined sequence-aware kernel.
		BinaryLearningAlgorithm binaryLearningAlgorithmCopy = (BinaryLearningAlgorithm) baseLearningAlgorithm
				.duplicate();

		((KernelMethod) binaryLearningAlgorithmCopy).setKernel(sequenceBasedKernel);

		// Sequence labeling is reduced to multi-class classification via one-vs-all.
		OneVsAllLearning oneVsAllLearning = new OneVsAllLearning();
		oneVsAllLearning.setBaseAlgorithm(binaryLearningAlgorithmCopy);

		super.setBaseLearningAlgorithm(oneVsAllLearning);

		// The generator enriches each item with transition features derived from the
		// labels of the preceding transitionsOrder items.
		SequenceExampleGenerator sequenceExamplesGenerator = new SequenceExampleGeneratorKernel(
				transitionsOrder, TRANSITION_REPRESENTATION_NAME);

		super.setSequenceExampleGenerator(sequenceExamplesGenerator);
	}

	@Override
	public LearningAlgorithm duplicate() {
		// NOTE(review): duplication is not implemented and returns null; callers
		// invoking duplicate() on this class will get a null reference — confirm
		// whether this is intentional or should throw/copy instead.
		return null;
	}

	@Override
	public LearningAlgorithm getBaseAlgorithm() {
		return super.getBaseLearningAlgorithm();
	}

	@Override
	public Kernel getKernel() {
		return sequenceBasedKernel;
	}

	@Override
	public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) {
		super.setBaseLearningAlgorithm(baseAlgorithm);
	}

	@Override
	public void setKernel(Kernel kernel) {
		// Assumes the provided kernel is the linear combination built by the
		// constructor (observation kernel + transition kernel).
		this.sequenceBasedKernel = (LinearKernelCombination) kernel;
	}

	/**
	 * Sets a cache to store the kernel computations performed by the combined
	 * sequence-based kernel.
	 *
	 * @param cache the kernel cache to be used
	 */
	public void setKernelCache(KernelCache cache) {
		this.getKernel().setKernelCache(cache);
	}

}
27 | * Given sequences of items (each implemented as an Example and 28 | * associated to one Label) this class allow to apply a generic 29 | * LearningAlgorithm to use the "history" of each item in the 30 | * sequence in order to improve the classification quality. In other words, the 31 | * classification of each example does not depend only its representation, but 32 | * it also depend on its "history", in terms of the classed assigned to the 33 | * preceding examples. 34 | * This class should be used when a linear learning algorithm is used, 35 | * thus directly operating in the representation space. 36 | * 37 | * 38 | * This algorithms was inspired by the work of: 39 | * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector 40 | * machines. In Proceedings of the Twentieth International Conference on Machine 41 | * Learning, 2003. 42 | * 43 | * @author Danilo Croce 44 | * 45 | */ 46 | public class SequenceClassificationLinearLearningAlgorithm extends SequenceClassificationLearningAlgorithm 47 | implements LinearMethod { 48 | 49 | /** 50 | * @param baseLearningAlgorithm 51 | * the "linear" learning algorithm devoted to the acquisition of 52 | * a model after that each example has been enriched with its 53 | * "history" 54 | * @param transitionsOrder 55 | * given a targeted item in the sequence, this variable 56 | * determines the number of previous example considered in the 57 | * learning/labeling process. 58 | * @param transitionWeight 59 | * the importance of the transition-based features during the 60 | * learning process. Higher valuers will assign more importance 61 | * to the transitions. 
62 | * @throws Exception The input baseLearningAlgorithm is not a Linear method 63 | */ 64 | public SequenceClassificationLinearLearningAlgorithm(BinaryLearningAlgorithm baseLearningAlgorithm, 65 | int transitionsOrder, float transitionWeight) throws Exception { 66 | 67 | if (!(baseLearningAlgorithm instanceof LinearMethod)) { 68 | throw new Exception("ERROR: the input baseLearningAlgorithm is not a Linear method!"); 69 | } 70 | 71 | OneVsAllLearning oneVsAllLearning = new OneVsAllLearning(); 72 | oneVsAllLearning.setBaseAlgorithm(baseLearningAlgorithm); 73 | 74 | super.setBaseLearningAlgorithm(oneVsAllLearning); 75 | String representation = ((LinearMethod) baseLearningAlgorithm).getRepresentation(); 76 | 77 | SequenceExampleGenerator sequenceExamplesGenerator = new SequenceExampleGeneratorLinear(transitionsOrder, 78 | representation, transitionWeight); 79 | 80 | super.setSequenceExampleGenerator(sequenceExamplesGenerator); 81 | } 82 | 83 | @Override 84 | public LearningAlgorithm duplicate() { 85 | // TODO Auto-generated method stub 86 | return null; 87 | } 88 | 89 | @Override 90 | public LearningAlgorithm getBaseAlgorithm() { 91 | return super.getBaseLearningAlgorithm(); 92 | } 93 | 94 | @Override 95 | public String getRepresentation() { 96 | return ((SequenceClassificationLinearLearningAlgorithm) getSequenceExampleGenerator()).getRepresentation(); 97 | } 98 | 99 | @Override 100 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 101 | super.setBaseLearningAlgorithm(baseAlgorithm); 102 | } 103 | 104 | @Override 105 | public void setRepresentation(String representationName) { 106 | ((SequenceClassificationLinearLearningAlgorithm) getSequenceExampleGenerator()) 107 | .setRepresentation(representationName); 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/COPYRIGHT: 
-------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2007-2013 The LIBLINEAR Project. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | 3. Neither name of copyright holders nor the names of its contributors 17 | may be used to endorse or promote products derived from this software 18 | without specific prior written permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR 25 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver;

/**
 * Objective function of L2-regularized L2-loss (squared hinge) support vector
 * classification, used by the TRON trust-region Newton solver:
 *
 *   f(w) = (1/2) w'w + sum_i C_i * max(0, 1 - y_i w'x_i)^2
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public class L2R_L2_SvcFunction implements TronFunction {

	// The training problem (l examples, n features, sparse rows x, targets y).
	protected final Problem prob;
	// Per-example regularization parameters C_i.
	protected final double[] C;
	// Indices of the examples active in the loss (margin violators); first sizeI entries valid.
	protected final int[] I;
	// Scratch vector of length l: holds y_i * w'x_i in fun(), then reused as the
	// active-set loss derivative buffer in grad().
	protected final double[] z;

	// Number of valid entries in I (and in the reused prefix of z).
	protected int sizeI;

	public L2R_L2_SvcFunction(Problem prob, double[] C) {
		int l = prob.l;

		this.prob = prob;

		z = new double[l];
		I = new int[l];
		this.C = C;
	}

	/**
	 * Evaluates f(w) = (1/2)||w||^2 + sum_i C_i * max(0, 1 - y_i w'x_i)^2.
	 * Side effect: caches y_i * w'x_i in z, which grad() relies on — grad(w)
	 * must be called after fun(w) with the same w.
	 */
	public double fun(double[] w) {
		int i;
		double f = 0;
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();

		Xv(w, z);

		// Regularization term (1/2) w'w.
		for (i = 0; i < w_size; i++)
			f += w[i] * w[i];
		f /= 2.0;
		// Squared hinge loss over margin violators (d > 0 means y_i w'x_i < 1).
		for (i = 0; i < l; i++) {
			z[i] = y[i] * z[i];
			double d = 1 - z[i];
			if (d > 0)
				f += C[i] * d * d;
		}

		return (f);
	}

	/** Returns the number of optimization variables (the feature dimensionality). */
	public int get_nr_variable() {
		return prob.n;
	}

	/**
	 * Computes the gradient g = w + 2 * X_I' * (C .* y .* (z - 1)) restricted to
	 * the active set I = { i : y_i w'x_i < 1 }.
	 * Relies on z having been filled by a preceding fun(w) call; compacts the
	 * active-set derivatives into the prefix of z.
	 */
	public void grad(double[] w, double[] g) {
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();

		sizeI = 0;
		for (int i = 0; i < l; i++) {
			if (z[i] < 1) {
				z[sizeI] = C[i] * y[i] * (z[i] - 1);
				I[sizeI] = i;
				sizeI++;
			}
		}
		subXTv(z, g);

		for (int i = 0; i < w_size; i++)
			g[i] = w[i] + 2 * g[i];
	}

	/**
	 * Hessian-vector product Hs = s + 2 * X_I' * diag(C_I) * X_I * s, using the
	 * active set determined by the last grad() call (TRON never forms H explicitly).
	 */
	public void Hv(double[] s, double[] Hs) {
		int i;
		int w_size = get_nr_variable();
		double[] wa = new double[sizeI];

		subXv(s, wa);
		for (i = 0; i < sizeI; i++)
			wa[i] = C[I[i]] * wa[i];

		subXTv(wa, Hs);
		for (i = 0; i < w_size; i++)
			Hs[i] = s[i] + 2 * Hs[i];
	}

	/**
	 * XTv = X_I' * v over the active rows only. Feature indices are 1-based in
	 * LIBLINEAR format, hence the getIndex() - 1.
	 */
	protected void subXTv(double[] v, double[] XTv) {
		int i;
		int w_size = get_nr_variable();

		for (i = 0; i < w_size; i++)
			XTv[i] = 0;

		for (i = 0; i < sizeI; i++) {
			for (LibLinearFeature s : prob.x[I[i]]) {
				XTv[s.getIndex() - 1] += v[i] * s.getValue();
			}
		}
	}

	// Xv = X_I * v over the active rows only (sparse dot products).
	private void subXv(double[] v, double[] Xv) {

		for (int i = 0; i < sizeI; i++) {
			Xv[i] = 0;
			for (LibLinearFeature s : prob.x[I[i]]) {
				Xv[i] += v[s.getIndex() - 1] * s.getValue();
			}
		}
	}

	// Xv = X * v over all rows (sparse dot products).
	protected void Xv(double[] v, double[] Xv) {

		for (int i = 0; i < prob.l; i++) {
			Xv[i] = 0;
			for (LibLinearFeature s : prob.x[i]) {
				Xv[i] += v[s.getIndex() - 1] * s.getValue();
			}
		}
	}

}
/**
 * Objective function of L2-regularized L2-loss support vector regression
 * (epsilon-insensitive squared loss), used by the TRON solver:
 *
 *   f(w) = (1/2) w'w + sum_i C_i * max(0, |w'x_i - y_i| - p)^2
 *
 * where p is the epsilon-insensitivity parameter. Reuses fun-state caching and
 * the Hv/subXTv machinery of the parent L2R_L2_SvcFunction.
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public class L2R_L2_SvrFunction extends L2R_L2_SvcFunction {

	// Epsilon of the insensitive tube: residuals with |d| <= p incur no loss.
	private double p;

	public L2R_L2_SvrFunction( Problem prob, double[] C, double p ) {
		super(prob, C);
		this.p = p;
	}

	/**
	 * Evaluates f(w). Side effect: caches the predictions w'x_i in z (via Xv),
	 * which grad() relies on — grad(w) must be called after fun(w) with the same w.
	 */
	@Override
	public double fun(double[] w) {
		double f = 0;
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();
		double d;

		Xv(w, z);

		// Regularization term (1/2) w'w.
		for (int i = 0; i < w_size; i++)
			f += w[i] * w[i];
		f /= 2;
		// Squared epsilon-insensitive loss on residuals d = w'x_i - y_i outside the tube.
		for (int i = 0; i < l; i++) {
			d = z[i] - y[i];
			if (d < -p)
				f += C[i] * (d + p) * (d + p);
			else if (d > p) f += C[i] * (d - p) * (d - p);
		}

		return f;
	}

	/**
	 * Computes g = w + 2 * X_I' * (C .* sign-shifted residuals), where the active
	 * set I contains the examples whose residual falls outside the tube.
	 * Relies on z having been filled by a preceding fun(w) call.
	 */
	@Override
	public void grad(double[] w, double[] g) {
		double[] y = prob.y;
		int l = prob.l;
		int w_size = get_nr_variable();

		sizeI = 0;
		for (int i = 0; i < l; i++) {
			double d = z[i] - y[i];

			// generate index set I: residuals outside the insensitive tube
			if (d < -p) {
				z[sizeI] = C[i] * (d + p);
				I[sizeI] = i;
				sizeI++;
			} else if (d > p) {
				z[sizeI] = C[i] * (d - p);
				I[sizeI] = i;
				sizeI++;
			}

		}
		subXTv(z, g);

		for (int i = 0; i < w_size; i++)
			g[i] = w[i] + 2 * g[i];

	}

}
/**
 * A single (index, value) component of a sparse feature vector, in the format
 * expected by the LIBLINEAR solvers.
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public interface LibLinearFeature {

	/** @return the feature index (1-based in the solver feature space) */
	int getIndex();

	/** @return the feature value */
	double getValue();

	/** @param value the new feature value */
	void setValue(double value);
}
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | public class LibLinearFeatureNode implements LibLinearFeature { 26 | 27 | public final int index; 28 | public double value; 29 | 30 | public LibLinearFeatureNode( final int index, final double value ) { 31 | if (index < 0) throw new IllegalArgumentException("index must be >= 0"); 32 | this.index = index; 33 | this.value = value; 34 | } 35 | 36 | /** 37 | * @since 1.9 38 | */ 39 | public int getIndex() { 40 | return index; 41 | } 42 | 43 | /** 44 | * @since 1.9 45 | */ 46 | public double getValue() { 47 | return value; 48 | } 49 | 50 | /** 51 | * @since 1.9 52 | */ 53 | public void setValue(double value) { 54 | this.value = value; 55 | } 56 | 57 | @Override 58 | public int hashCode() { 59 | final int prime = 31; 60 | int result = 1; 61 | result = prime * result + index; 62 | long temp; 63 | temp = Double.doubleToLongBits(value); 64 | result = prime * result + (int)(temp ^ (temp >>> 32)); 65 | return result; 66 | } 67 | 68 | @Override 69 | public boolean equals(Object obj) { 70 | if (this == obj) return true; 71 | if (obj == null) return false; 72 | if (getClass() != obj.getClass()) return false; 73 | LibLinearFeatureNode other = (LibLinearFeatureNode)obj; 74 | if (index != other.index) return false; 75 | if (Double.doubleToLongBits(value) != Double.doubleToLongBits(other.value)) return false; 76 | return true; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | return "FeatureNode(idx=" + index + ", value=" + value + ")"; 82 | } 83 | } 84 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/Problem.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. 
Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | import gnu.trove.map.hash.TIntObjectHashMap; 26 | import gnu.trove.map.hash.TObjectIntHashMap; 27 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 28 | import it.uniroma2.sag.kelp.data.example.Example; 29 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 30 | import it.uniroma2.sag.kelp.data.label.Label; 31 | import it.uniroma2.sag.kelp.data.representation.Representation; 32 | import it.uniroma2.sag.kelp.data.representation.Vector; 33 | import it.uniroma2.sag.kelp.data.representation.vector.DenseVector; 34 | import it.uniroma2.sag.kelp.data.representation.vector.SparseVector; 35 | 36 | import java.io.IOException; 37 | import java.util.ArrayList; 38 | import java.util.Map; 39 | 40 | /** 41 | * 42 | * Describes the problem 43 | * 44 | * 45 | * For example, if we have the following training data: 46 | * 47 | * 48 | * LABEL ATTR1 ATTR2 ATTR3 ATTR4 ATTR5 49 | * ----- ----- ----- ----- ----- ----- 50 | * 1 0 0.1 0.2 0 0 51 | * 2 0 0.1 0.3 -1.2 0 52 | * 1 0.4 0 0 0 0 53 | * 2 0 0.1 0 1.4 0.5 54 | * 3 -0.1 -0.2 0.1 1.1 0.1 55 | * 56 | * and bias = 1, then the components of problem are: 57 | * 58 | * l = 5 59 | * n = 6 60 | * 61 | * y -> 1 2 1 2 3 62 | * 63 | * x -> [ ] -> (2,0.1) (3,0.2) (6,1) (-1,?) 64 | * [ ] -> (2,0.1) (3,0.3) (4,-1.2) (6,1) (-1,?) 65 | * [ ] -> (1,0.4) (6,1) (-1,?) 66 | * [ ] -> (2,0.1) (4,1.4) (5,0.5) (6,1) (-1,?) 67 | * [ ] -> (1,-0.1) (2,-0.2) (3,0.1) (4,1.1) (5,0.1) (6,1) (-1,?) 
68 | * 69 | */ 70 | public class Problem { 71 | 72 | public enum LibLinearSolverType { 73 | CLASSIFICATION, REGRESSION 74 | } 75 | 76 | public TObjectIntHashMap featureDict = new TObjectIntHashMap(); 77 | 78 | public TIntObjectHashMap featureInverseDict = new TIntObjectHashMap(); 79 | 80 | /** the number of training data */ 81 | public int l; 82 | 83 | /** the number of features (including the bias feature if bias >= 0) */ 84 | public int n; 85 | 86 | /** an array containing the target values */ 87 | public double[] y; 88 | /** array of sparse feature nodes */ 89 | public LibLinearFeature[][] x; 90 | 91 | /** 92 | * If bias >= 0, we assume that one additional feature is added to the 93 | * end of each data instance 94 | */ 95 | public double bias; 96 | 97 | private boolean isInputDense; 98 | 99 | public Problem(Dataset dataset, String reprentationName, Label label, 100 | LibLinearSolverType solverType) { 101 | 102 | this.l = dataset.getNumberOfExamples(); 103 | this.y = new double[l]; 104 | this.x = new LibLinearFeature[l][]; 105 | 106 | ArrayList vectorlist = new ArrayList(); 107 | 108 | if (dataset.getExamples().get(0).getRepresentation(reprentationName) instanceof DenseVector) 109 | isInputDense = true; 110 | 111 | int i = 0; 112 | for (Example e : dataset.getExamples()) { 113 | SimpleExample simpleExample = (SimpleExample) e; 114 | Representation r = simpleExample 115 | .getRepresentation(reprentationName); 116 | Vector vector = (Vector) r; 117 | 118 | vectorlist.add(vector); 119 | 120 | if (solverType == LibLinearSolverType.CLASSIFICATION) { 121 | if (e.isExampleOf(label)) 122 | y[i] = 1; 123 | else 124 | y[i] = -1; 125 | } else { 126 | y[i] = e.getRegressionValue(label); 127 | } 128 | 129 | i++; 130 | } 131 | 132 | initializeExamples(vectorlist); 133 | 134 | } 135 | 136 | private DenseVector getDenseW(double[] w) { 137 | double[] tmp = new double[w.length - 1]; 138 | for (int i = 0; i < w.length - 1; i++) { 139 | tmp[i] = w[i]; 140 | } 141 | return new 
DenseVector(tmp); 142 | } 143 | 144 | private SparseVector getSparseW(double[] w) { 145 | SparseVector res = new SparseVector(); 146 | 147 | StringBuilder sb = new StringBuilder(); 148 | for (int i = 0; i < w.length - 1; i++) { 149 | sb.append(this.featureInverseDict.get(i + 1) + ":" + w[i] + " "); 150 | } 151 | sb.append("__LIB_LINEAR_BIAS_:" + w[w.length - 1]); 152 | 153 | try { 154 | res.setDataFromText(sb.toString().trim()); 155 | } catch (IOException e) { 156 | e.printStackTrace(); 157 | return null; 158 | } 159 | return res; 160 | } 161 | 162 | public Vector getW(double[] w) { 163 | if (isInputDense) { 164 | return getDenseW(w); 165 | } 166 | return getSparseW(w); 167 | } 168 | 169 | public void initializeExamples(ArrayList vectorlist) { 170 | if (isInputDense) { 171 | initializeExamplesDense(vectorlist); 172 | } else { 173 | initializeExamplesSparse(vectorlist); 174 | } 175 | } 176 | 177 | private void initializeExamplesDense(ArrayList vectorlist) { 178 | for (int vectorId = 0; vectorId < vectorlist.size(); vectorId++) { 179 | DenseVector denseVector = (DenseVector) (vectorlist.get(vectorId)); 180 | if (vectorId == 0) { 181 | bias = 0; 182 | n = denseVector.getNumberOfFeatures() + 1; 183 | } 184 | this.x[vectorId] = new LibLinearFeatureNode[denseVector 185 | .getNumberOfFeatures()]; 186 | for (int j = 0; j < denseVector.getNumberOfFeatures(); j++) 187 | this.x[vectorId][j] = new LibLinearFeatureNode(j + 1, 188 | denseVector.getFeatureValue(j)); 189 | } 190 | } 191 | 192 | private void initializeExamplesSparse(ArrayList vectorlist) { 193 | /* 194 | * Building dictionary 195 | */ 196 | int featureIndex = 1; 197 | for (Vector v : vectorlist) { 198 | //for (String dimLabel : v.getActiveFeatures().keySet()) { 199 | for (Object dimLabel : v.getActiveFeatures().keySet()) { 200 | if (!featureDict.containsKey(dimLabel)) { 201 | featureDict.put(dimLabel, featureIndex); 202 | featureInverseDict.put(featureIndex, dimLabel); 203 | featureIndex++; 204 | // 
System.out.println(featureIndex + " " + dimLabel); 205 | } 206 | } 207 | } 208 | 209 | /* 210 | * Initialize the object 211 | */ 212 | n = featureDict.size() + 1; 213 | bias = 0; 214 | int i = 0; 215 | for (Vector v : vectorlist) { 216 | Map, Number> activeFeatures = v.getActiveFeatures(); 217 | this.x[i] = new LibLinearFeatureNode[activeFeatures.size()]; 218 | int j = 0; 219 | for (Object dimLabel : activeFeatures.keySet()) { 220 | this.x[i][j] = new LibLinearFeatureNode( 221 | featureDict.get(dimLabel), activeFeatures.get(dimLabel).doubleValue()); 222 | j++; 223 | } 224 | i++; 225 | } 226 | } 227 | 228 | } 229 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/Tron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | import org.slf4j.Logger; 19 | import org.slf4j.LoggerFactory; 20 | 21 | 22 | /** 23 | * Trust Region Newton Method optimization 24 | * 25 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 26 | * C++ sources. 
/**
 * Trust Region Newton Method optimization
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
public class Tron {
	private Logger logger = LoggerFactory.getLogger(Tron.class);

	// Objective to minimize: provides value, gradient and Hessian-vector products
	private final TronFunction fun_obj;
	// Relative stopping tolerance on the gradient norm
	private final double eps;
	// Hard cap on the number of outer Newton iterations
	private final int max_iter;

	public Tron(final TronFunction fun_obj) {
		this(fun_obj, 0.1);
	}

	public Tron(final TronFunction fun_obj, double eps) {
		this(fun_obj, eps, 1000);
	}

	public Tron(final TronFunction fun_obj, double eps, int max_iter) {
		this.fun_obj = fun_obj;
		this.eps = eps;
		this.max_iter = max_iter;
	}

	/**
	 * Minimizes fun_obj starting from w = 0; the solution is written into w.
	 * Each iteration solves the trust-region subproblem by conjugate gradient
	 * (trcg) and then grows/shrinks the trust region delta according to how
	 * well the quadratic model predicted the actual reduction.
	 *
	 * @param w output buffer for the solution (overwritten; initialized to 0)
	 */
	public void tron(double[] w) {
		// Parameters for updating the iterates.
		double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;

		// Parameters for updating the trust region size delta.
		double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;

		int n = fun_obj.get_nr_variable();
		int i, cg_iter;
		double delta, snorm, one = 1.0;
		double alpha, f, fnew, prered, actred, gs;
		int search = 1, iter = 1;
		double[] s = new double[n];      // step proposed by the CG subproblem
		double[] r = new double[n];      // CG residual
		double[] w_new = new double[n];  // trial iterate w + s
		double[] g = new double[n];      // gradient at the current iterate

		for (i = 0; i < n; i++)
			w[i] = 0;

		f = fun_obj.fun(w);
		fun_obj.grad(w, g);
		// Initial trust region radius is the initial gradient norm
		delta = euclideanNorm(g);
		double gnorm1 = delta;
		double gnorm = gnorm1;

		if (gnorm <= eps * gnorm1)
			search = 0;

		iter = 1;

		while (iter <= max_iter && search != 0) {
			// Approximately solve the trust-region subproblem; fills s and r
			cg_iter = trcg(delta, g, s, r);

			System.arraycopy(w, 0, w_new, 0, n);
			daxpy(one, s, w_new); // w_new = w + s

			gs = dot(g, s);
			// Predicted reduction from the quadratic model
			prered = -0.5 * (gs - dot(s, r));
			fnew = fun_obj.fun(w_new);

			// Compute the actual reduction.
			actred = f - fnew;

			// On the first iteration, adjust the initial step bound.
			snorm = euclideanNorm(s);
			if (iter == 1)
				delta = Math.min(delta, snorm);

			// Compute prediction alpha*snorm of the step.
			if (fnew - f - gs <= 0)
				alpha = sigma3;
			else
				alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs)));

			// Update the trust region bound according to the ratio of actual to
			// predicted reduction.
			if (actred < eta0 * prered)
				delta = Math.min(Math.max(alpha, sigma1) * snorm, sigma2
						* delta);
			else if (actred < eta1 * prered)
				delta = Math.max(sigma1 * delta,
						Math.min(alpha * snorm, sigma2 * delta));
			else if (actred < eta2 * prered)
				delta = Math.max(sigma1 * delta,
						Math.min(alpha * snorm, sigma3 * delta));
			else
				delta = Math
						.max(delta, Math.min(alpha * snorm, sigma3 * delta));

			// info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d%n",
			// iter, actred, prered, delta, f, gnorm, cg_iter);
			info("iter {} act {} pre {} delta {} f {} |g| {} CG {}",
					iter, actred, prered, delta, f, gnorm, cg_iter);

			// Accept the step only if the actual reduction is large enough
			if (actred > eta0 * prered) {
				iter++;
				System.arraycopy(w_new, 0, w, 0, n);
				f = fnew;
				fun_obj.grad(w, g);

				gnorm = euclideanNorm(g);
				// Converged: gradient small relative to the initial gradient
				if (gnorm <= eps * gnorm1)
					break;
			}
			// Safeguards against divergence and numerical stagnation
			if (f < -1.0e+32) {
				info("WARNING: f < -1.0e+32%n");
				break;
			}
			if (Math.abs(actred) <= 0 && prered <= 0) {
				info("WARNING: actred and prered <= 0%n");
				break;
			}
			if (Math.abs(actred) <= 1.0e-12 * Math.abs(f)
					&& Math.abs(prered) <= 1.0e-12 * Math.abs(f)) {
				info("WARNING: actred and prered too small%n");
				break;
			}
		}
	}

	private void info(String msg) {
		logger.debug(msg);
	}

	// Uses slf4j {}-placeholder formatting (the commented-out code used printf-style)
	private void info(String msgFormatted, Object... args) {
		// Formatter formatter = new Formatter();
		// Formatter format = formatter.format(msgFormatted, args);
		logger.debug(msgFormatted,args);
		// formatter.close();
	}

	/**
	 * Conjugate gradient procedure for the trust-region subproblem: approximately
	 * minimizes the quadratic model subject to ||s|| <= delta. When the CG iterate
	 * would leave the trust region, the step is truncated to the boundary.
	 *
	 * @param delta trust region radius
	 * @param g     gradient at the current iterate (input)
	 * @param s     output: the computed step
	 * @param r     output: the final residual
	 * @return the number of CG iterations performed
	 */
	private int trcg(double delta, double[] g, double[] s, double[] r) {
		int n = fun_obj.get_nr_variable();
		double one = 1;
		double[] d = new double[n];  // CG search direction
		double[] Hd = new double[n]; // Hessian-vector product H*d
		double rTr, rnewTrnew, cgtol;

		// Start from s = 0, r = d = -g
		for (int i = 0; i < n; i++) {
			s[i] = 0;
			r[i] = -g[i];
			d[i] = r[i];
		}
		cgtol = 0.1 * euclideanNorm(g);

		int cg_iter = 0;
		rTr = dot(r, r);

		while (true) {
			if (euclideanNorm(r) <= cgtol)
				break;
			cg_iter++;
			fun_obj.Hv(d, Hd);

			double alpha = rTr / dot(d, Hd);
			daxpy(alpha, d, s);
			if (euclideanNorm(s) > delta) {
				info("cg reaches trust region boundary%n");
				// Undo the last step, then re-take it truncated to the boundary
				alpha = -alpha;
				daxpy(alpha, d, s);

				// Solve ||s + alpha*d|| = delta for alpha (quadratic formula,
				// arranged for numerical stability depending on sign of s'd)
				double std = dot(s, d);
				double sts = dot(s, s);
				double dtd = dot(d, d);
				double dsq = delta * delta;
				double rad = Math.sqrt(std * std + dtd * (dsq - sts));
				if (std >= 0)
					alpha = (dsq - sts) / (std + rad);
				else
					alpha = (rad - std) / dtd;
				daxpy(alpha, d, s);
				alpha = -alpha;
				daxpy(alpha, Hd, r);
				break;
			}
			// Standard CG update of residual and direction
			alpha = -alpha;
			daxpy(alpha, Hd, r);
			rnewTrnew = dot(r, r);
			double beta = rnewTrnew / rTr;
			scale(beta, d);
			daxpy(one, r, d);
			rTr = rnewTrnew;
		}

		return (cg_iter);
	}

	/**
	 * constant times a vector plus a vector
	 *
	 * <pre>
	 * vector2 += constant * vector1
	 * </pre>
	 *
	 * @since 1.8
	 */
	private static void daxpy(double constant, double vector1[],
			double vector2[]) {
		if (constant == 0)
			return;

		assert vector1.length == vector2.length;
		for (int i = 0; i < vector1.length; i++) {
			vector2[i] += constant * vector1[i];
		}
	}

	/**
	 * returns the dot product of two vectors
	 *
	 * @since 1.8
	 */
	private static double dot(double vector1[], double vector2[]) {

		double product = 0;
		assert vector1.length == vector2.length;
		for (int i = 0; i < vector1.length; i++) {
			product += vector1[i] * vector2[i];
		}
		return product;

	}

	/**
	 * returns the euclidean norm of a vector
	 *
	 * @since 1.8
	 */
	private static double euclideanNorm(double vector[]) {

		int n = vector.length;

		if (n < 1) {
			return 0;
		}

		if (n == 1) {
			return Math.abs(vector[0]);
		}

		// this algorithm is (often) more accurate than just summing up the
		// squares and taking the square-root afterwards

		double scale = 0; // scaling factor that is factored out
		double sum = 1; // basic sum of squares from which scale has been
						// factored out
		for (int i = 0; i < n; i++) {
			if (vector[i] != 0) {
				double abs = Math.abs(vector[i]);
				// try to get the best scaling factor
				if (scale < abs) {
					double t = scale / abs;
					sum = 1 + sum * (t * t);
					scale = abs;
				} else {
					double t = abs / scale;
					sum += t * t;
				}
			}
		}

		return scale * Math.sqrt(sum);
	}

	/**
	 * scales a vector by a constant
	 *
	 * @since 1.8
	 */
	private static void scale(double constant, double vector[]) {
		if (constant == 1.0)
			return;
		for (int i = 0; i < vector.length; i++) {
			vector[i] *= constant;
		}

	}
}
/**
 * Interface of a twice-differentiable objective function optimized by the
 * trust region Newton method implemented in {@link Tron}.
 *
 * NOTE: This code has been adapted from the Java port of the original LIBLINEAR
 * C++ sources. Original Java sources (v 1.94) are available at
 * http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
interface TronFunction {

	/** @return the objective value at the point w */
	double fun(double[] w);

	/** Computes the gradient at w, writing it into g */
	void grad(double[] w, double[] g);

	/** Computes the Hessian-vector product H*s, writing it into Hs */
	void Hv(double[] s, double[] Hs);

	/** @return the number of variables (the dimensionality of w) */
	int get_nr_variable();
}
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | 19 | import com.fasterxml.jackson.annotation.JsonTypeName; 20 | 21 | import it.uniroma2.sag.kelp.data.label.Label; 22 | import it.uniroma2.sag.kelp.kernel.Kernel; 23 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 24 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 26 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 27 | 28 | /** 29 | * Online Passive-Aggressive Learning Algorithm for classification tasks (Kernel Machine version) . 30 | * Every time an example is misclassified it is added as support vector, with the weight that solves the 31 | * passive aggressive minimization problem 32 | * 33 | * reference: 34 | * 35 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 36 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 37 | * 38 | * The standard algorithm is modified, including the fairness extention from 39 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 40 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347–358, Springer International Publishing, 2014. 
41 | * 42 | * 43 | * @author Simone Filice 44 | */ 45 | 46 | @JsonTypeName("kernelizedPA") 47 | public class KernelizedPassiveAggressiveClassification extends PassiveAggressiveClassification implements KernelMethod{ 48 | 49 | private Kernel kernel; 50 | 51 | public KernelizedPassiveAggressiveClassification(){ 52 | this.classifier = new BinaryKernelMachineClassifier(); 53 | this.classifier.setModel(new BinaryKernelMachineModel()); 54 | } 55 | 56 | public KernelizedPassiveAggressiveClassification(float cp, float cn, Loss loss, Policy policy, Kernel kernel, Label label){ 57 | this.classifier = new BinaryKernelMachineClassifier(); 58 | this.classifier.setModel(new BinaryKernelMachineModel()); 59 | this.setKernel(kernel); 60 | this.setLoss(loss); 61 | this.setCp(cp); 62 | this.setCn(cn); 63 | this.setLabel(label); 64 | this.setPolicy(policy); 65 | } 66 | 67 | 68 | @Override 69 | public Kernel getKernel() { 70 | return kernel; 71 | } 72 | 73 | @Override 74 | public void setKernel(Kernel kernel) { 75 | this.kernel = kernel; 76 | this.getPredictionFunction().getModel().setKernel(kernel); 77 | } 78 | 79 | 80 | @Override 81 | public KernelizedPassiveAggressiveClassification duplicate(){ 82 | KernelizedPassiveAggressiveClassification copy = new KernelizedPassiveAggressiveClassification(); 83 | copy.setCp(this.cp); 84 | copy.setCn(c); 85 | copy.setFairness(this.fairness); 86 | copy.setKernel(this.kernel); 87 | copy.setLoss(this.loss); 88 | copy.setPolicy(this.policy); 89 | //copy.setLabel(label); 90 | return copy; 91 | } 92 | 93 | @Override 94 | public BinaryKernelMachineClassifier getPredictionFunction(){ 95 | return (BinaryKernelMachineClassifier) this.classifier; 96 | } 97 | 98 | @Override 99 | public void setPredictionFunction(PredictionFunction predictionFunction) { 100 | this.classifier = (BinaryKernelMachineClassifier) predictionFunction; 101 | } 102 | 103 | } 104 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/LinearPassiveAggressiveClassification.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.label.Label; 19 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 20 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 21 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 22 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 23 | 24 | import com.fasterxml.jackson.annotation.JsonTypeName; 25 | 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for classification tasks (linear version) . 29 | * Every time an example is misclassified it is added the the current hyperplane, with the weight that solves the 30 | * passive aggressive minimization problem 31 | * 32 | * reference: 33 | * 34 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 35 | * Online Passive-Aggressive Algorithms. 
Journal of Machine Learning Research (2006) 36 | * 37 | * The standard algorithm is modified, including the fairness extention from 38 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 39 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347–358, Springer International Publishing, 2014. 40 | * 41 | * @author Simone Filice 42 | */ 43 | @JsonTypeName("linearPA") 44 | public class LinearPassiveAggressiveClassification extends PassiveAggressiveClassification implements LinearMethod{ 45 | 46 | private String representation; 47 | 48 | public LinearPassiveAggressiveClassification(){ 49 | this.classifier = new BinaryLinearClassifier(); 50 | this.classifier.setModel(new BinaryLinearModel()); 51 | } 52 | 53 | public LinearPassiveAggressiveClassification(float cp, float cn, Loss loss, Policy policy, String representation, Label label){ 54 | this.classifier = new BinaryLinearClassifier(); 55 | this.classifier.setModel(new BinaryLinearModel()); 56 | this.setCp(cp); 57 | this.setCn(cn); 58 | this.setLoss(loss); 59 | this.setPolicy(policy); 60 | this.setRepresentation(representation); 61 | this.setLabel(label); 62 | } 63 | 64 | @Override 65 | public String getRepresentation() { 66 | return representation; 67 | } 68 | 69 | @Override 70 | public void setRepresentation(String representation) { 71 | this.representation = representation; 72 | BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel(); 73 | model.setRepresentation(representation); 74 | } 75 | 76 | @Override 77 | public LinearPassiveAggressiveClassification duplicate(){ 78 | LinearPassiveAggressiveClassification copy = new LinearPassiveAggressiveClassification(); 79 | copy.setRepresentation(this.representation); 80 | copy.setCp(this.cp); 81 | copy.setCn(this.c); 82 | copy.setFairness(this.fairness); 83 | copy.setLoss(this.loss); 84 | copy.setPolicy(this.policy); 85 | return copy; 86 | } 87 | 88 | @Override 
89 | public BinaryLinearClassifier getPredictionFunction(){ 90 | return (BinaryLinearClassifier) this.classifier; 91 | } 92 | 93 | @Override 94 | public void setPredictionFunction(PredictionFunction predictionFunction) { 95 | this.classifier = (BinaryLinearClassifier) predictionFunction; 96 | } 97 | 98 | } 99 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/PassiveAggressiveClassification.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
/**
 * Online Passive-Aggressive Learning Algorithm for classification tasks.
 * Every time an example is misclassified it is added to the current hyperplane,
 * with the weight that solves the passive aggressive minimization problem.
 *
 * reference:
 *
 * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer
 * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006)
 *
 * The standard algorithm is modified, including the fairness extention from
 * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning
 * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347-358, Springer International Publishing, 2014.
 *
 * @author Simone Filice
 */
public abstract class PassiveAggressiveClassification extends PassiveAggressive implements ClassificationLearningAlgorithm{

	/** The loss function: plain hinge, or ramp (hinge clipped at 2) */
	public enum Loss{
		HINGE,
		RAMP
	}

	protected Loss loss = Loss.HINGE;
	protected float cp = c;//cp is the aggressiveness w.r.t. positive examples. c will be considered the aggressiveness w.r.t. negative examples
	protected boolean fairness = false;

	// The wrapped binary classifier; concrete subclasses instantiate the linear or kernel variant
	@JsonIgnore
	protected BinaryClassifier classifier;


	/**
	 * @return the fairness
	 */
	public boolean isFairness() {
		return fairness;
	}


	/**
	 * @param fairness the fairness to set
	 */
	public void setFairness(boolean fairness) {
		this.fairness = fairness;
	}

	/**
	 * @return the aggressiveness parameter for positive examples
	 */
	public float getCp() {
		return cp;
	}


	/**
	 * @param cp the aggressiveness parameter for positive examples
	 */
	public void setCp(float cp) {
		this.cp = cp;
	}

	/**
	 * @return the aggressiveness parameter for negative examples
	 */
	public float getCn() {
		return c;
	}


	/**
	 * @param cn the aggressiveness parameter for negative examples
	 */
	public void setCn(float cn) {
		this.c = cn;
	}

	@Override
	@JsonIgnore
	public float getC(){
		return c;
	}

	/**
	 * Sets both aggressiveness parameters (positive and negative) to the same value.
	 */
	@Override
	@JsonProperty
	public void setC(float c){
		super.setC(c);
		this.cp=c;
	}

	/**
	 * @return the loss function type
	 */
	public Loss getLoss() {
		return loss;
	}


	/**
	 * @param loss the loss function type to set
	 */
	public void setLoss(Loss loss) {
		this.loss = loss;
	}

	@Override
	public BinaryClassifier getPredictionFunction() {
		return this.classifier;
	}

	/**
	 * Performs one PA update step on a single example: computes the hinge loss
	 * of the current prediction and, when it is non-zero (and, for RAMP loss,
	 * below the clipping threshold 2), adds the example to the model with the
	 * PA-optimal weight (signed by the example's class).
	 *
	 * @param example the example to learn from
	 * @return the prediction computed before the update
	 */
	@Override
	public BinaryMarginClassifierOutput learn(Example example){

		BinaryMarginClassifierOutput prediction=this.classifier.predict(example);

		float lossValue = 0;//it represents the distance from the correct semi-space
		if(prediction.isClassPredicted(label)!=example.isExampleOf(label)){
			// Misclassified: loss is 1 + |margin|
			lossValue = 1 + Math.abs(prediction.getScore(label));
		}else if(Math.abs(prediction.getScore(label))<1){
			// Correct side but inside the margin: loss is 1 - |margin|
			lossValue = 1 - Math.abs(prediction.getScore(label));
		}

		// RAMP loss ignores examples with loss >= 2 (too far on the wrong side)
		if(lossValue>0 && (lossValue<2 || this.loss!=Loss.RAMP)){
			float exampleAggressiveness=this.c;
			if(example.isExampleOf(label)){
				exampleAggressiveness=cp;
			}
			float exampleSquaredNorm = this.classifier.getModel().getSquaredNorm(example);
			float weight = this.computeWeight(example, lossValue, exampleSquaredNorm ,exampleAggressiveness);
			// Negative examples are pushed to the negative semi-space
			if(!example.isExampleOf(label)){
				weight*=-1;
			}
			this.getPredictionFunction().getModel().addExample(weight, example);
		}
		return prediction;

	}

	/**
	 * Batch learning. When fairness is enabled, cp is rescaled by the
	 * negative/positive class ratio before delegating to the online procedure.
	 */
	@Override
	public void learn(Dataset dataset){
		if(this.fairness){
			float positiveExample = dataset.getNumberOfPositiveExamples(label);
			float negativeExample = dataset.getNumberOfNegativeExamples(label);
			// NOTE(review): if the dataset has no positive examples this divides
			// by zero, making cp Infinity/NaN — confirm callers guarantee at
			// least one positive example
			cp = c * negativeExample / positiveExample;
		}
		//System.out.println("cn: " + c + " cp: " + cp);
		super.learn(dataset);
	}

}
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.pegasos; 17 | 18 | import java.util.ArrayList; 19 | import java.util.Arrays; 20 | import java.util.List; 21 | 22 | import com.fasterxml.jackson.annotation.JsonTypeName; 23 | 24 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 25 | import it.uniroma2.sag.kelp.data.example.Example; 26 | import it.uniroma2.sag.kelp.data.label.Label; 27 | import it.uniroma2.sag.kelp.data.representation.Vector; 28 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 29 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 30 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 31 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 32 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 33 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 34 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 35 | 36 | /** 37 | * It implements the Primal Estimated sub-GrAdient SOlver (PEGASOS) for SVM. It is a learning 38 | * algorithm for binary linear classification Support Vector Machines. It operates in an explicit 39 | * feature space (i.e. it does not relies on any kernel). Further details can be found in: 40 | * 41 | * [SingerICML2007] Y. Singer and N. Srebro. Pegasos: Primal estimated sub-gradient solver for SVM. 42 | * In Proceeding of ICML 2007. 
43 | * 44 | * @author Simone Filice 45 | * 46 | */ 47 | @JsonTypeName("pegasos") 48 | public class PegasosLearningAlgorithm implements LinearMethod, ClassificationLearningAlgorithm, BinaryLearningAlgorithm{ 49 | 50 | private Label label; 51 | 52 | private BinaryLinearClassifier classifier; 53 | 54 | private int k = 1; 55 | private int iterations = 1000; 56 | private float lambda = 0.01f; 57 | 58 | private String representation; 59 | 60 | /** 61 | * Returns the number of examples k that Pegasos exploits in its 62 | * mini-batch learning approach 63 | * 64 | * @return k 65 | */ 66 | public int getK() { 67 | return k; 68 | } 69 | 70 | /** 71 | * Sets the number of examples k that Pegasos exploits in its 72 | * mini-batch learning approach 73 | * 74 | * @param k the k to set 75 | */ 76 | public void setK(int k) { 77 | this.k = k; 78 | } 79 | 80 | /** 81 | * Returns the number of iterations 82 | * 83 | * @return the number of iterations 84 | */ 85 | public int getIterations() { 86 | return iterations; 87 | } 88 | 89 | /** 90 | * Sets the number of iterations 91 | * 92 | * @param T the number of iterations to set 93 | */ 94 | public void setIterations(int T) { 95 | this.iterations = T; 96 | } 97 | 98 | /** 99 | * Returns the regularization coefficient 100 | * 101 | * @return the lambda 102 | */ 103 | public float getLambda() { 104 | return lambda; 105 | } 106 | 107 | /** 108 | * Sets the regularization coefficient 109 | * 110 | * @param lambda the lambda to set 111 | */ 112 | public void setLambda(float lambda) { 113 | this.lambda = lambda; 114 | } 115 | 116 | public PegasosLearningAlgorithm(){ 117 | this.classifier = new BinaryLinearClassifier(); 118 | this.classifier.setModel(new BinaryLinearModel()); 119 | } 120 | 121 | public PegasosLearningAlgorithm(int k, float lambda, int T, String Representation, Label label){ 122 | this.classifier = new BinaryLinearClassifier(); 123 | this.classifier.setModel(new BinaryLinearModel()); 124 | this.setK(k); 125 | 
this.setLabel(label); 126 | this.setLambda(lambda); 127 | this.setRepresentation(Representation); 128 | this.setIterations(T); 129 | } 130 | 131 | @Override 132 | public String getRepresentation() { 133 | return representation; 134 | } 135 | 136 | @Override 137 | public void setRepresentation(String representation) { 138 | this.representation = representation; 139 | BinaryLinearModel model = this.classifier.getModel(); 140 | model.setRepresentation(representation); 141 | } 142 | 143 | @Override 144 | public void learn(Dataset dataset) { 145 | if(this.getPredictionFunction().getModel().getHyperplane()==null){ 146 | this.getPredictionFunction().getModel().setHyperplane(dataset.getZeroVector(representation)); 147 | } 148 | 149 | for(int t=1;t<=iterations;t++){ 150 | 151 | List A_t = dataset.getRandExamples(k); 152 | List A_tp = new ArrayList(); 153 | List signA_tp = new ArrayList(); 154 | float eta_t = ((float)1)/(lambda*t); 155 | Vector w_t = this.getPredictionFunction().getModel().getHyperplane(); 156 | 157 | //creating A_tp 158 | for(Example example: A_t){ 159 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 160 | float y = -1; 161 | if(example.isExampleOf(label)){ 162 | y=1; 163 | } 164 | 165 | if(prediction.getScore(label)*y<1){ 166 | A_tp.add(example); 167 | signA_tp.add(y); 168 | } 169 | } 170 | //creating w_(t+1/2) 171 | w_t.scale(1-eta_t*lambda); 172 | float miscassificationFactor = eta_t/k; 173 | for(int i=0; i labels){ 211 | if(labels.size()!=1){ 212 | throw new IllegalArgumentException("Pegasos algorithm is a binary method which can learn a single Label"); 213 | } 214 | else{ 215 | this.label=labels.get(0); 216 | this.classifier.setLabels(labels); 217 | } 218 | } 219 | 220 | 221 | @Override 222 | public List getLabels() { 223 | return Arrays.asList(label); 224 | } 225 | 226 | @Override 227 | public Label getLabel(){ 228 | return this.label; 229 | } 230 | 231 | @Override 232 | public void setLabel(Label label){ 233 | 
this.setLabels(Arrays.asList(label)); 234 | } 235 | 236 | @Override 237 | public void setPredictionFunction(PredictionFunction predictionFunction) { 238 | this.classifier = (BinaryLinearClassifier) predictionFunction; 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/KernelizedPerceptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 24 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 25 | 26 | import com.fasterxml.jackson.annotation.JsonTypeName; 27 | 28 | /** 29 | * The perceptron learning algorithm algorithm for classification tasks (Kernel machine version). Reference: 30 | * [Rosenblatt1957] F. Rosenblatt. 
The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 31 | * 32 | * @author Simone Filice 33 | * 34 | */ 35 | @JsonTypeName("kernelizedPerceptron") 36 | public class KernelizedPerceptron extends Perceptron implements KernelMethod{ 37 | 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPerceptron(){ 42 | this.classifier = new BinaryKernelMachineClassifier(); 43 | this.classifier.setModel(new BinaryKernelMachineModel()); 44 | } 45 | 46 | public KernelizedPerceptron(float alpha, float margin, boolean unbiased, Kernel kernel, Label label){ 47 | this.classifier = new BinaryKernelMachineClassifier(); 48 | this.classifier.setModel(new BinaryKernelMachineModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setKernel(kernel); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Kernel getKernel() { 58 | return kernel; 59 | } 60 | 61 | @Override 62 | public void setKernel(Kernel kernel) { 63 | this.kernel = kernel; 64 | this.getPredictionFunction().getModel().setKernel(kernel); 65 | } 66 | 67 | @Override 68 | public KernelizedPerceptron duplicate(){ 69 | KernelizedPerceptron copy = new KernelizedPerceptron(); 70 | copy.setKernel(this.kernel); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setUnbiased(this.unbiased); 74 | return copy; 75 | } 76 | 77 | @Override 78 | public BinaryKernelMachineClassifier getPredictionFunction(){ 79 | return (BinaryKernelMachineClassifier) this.classifier; 80 | } 81 | 82 | @Override 83 | public void setPredictionFunction(PredictionFunction predictionFunction) { 84 | this.classifier = (BinaryKernelMachineClassifier) predictionFunction; 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/LinearPerceptron.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | 19 | import com.fasterxml.jackson.annotation.JsonTypeName; 20 | 21 | import it.uniroma2.sag.kelp.data.label.Label; 22 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 25 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 26 | 27 | /** 28 | * The perceptron learning algorithm algorithm for classification tasks (linear version). Reference: 29 | * [Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. 
Report 85-460-1, Cornell Aeronautical Laboratory (1957) 30 | * 31 | * @author Simone Filice 32 | * 33 | */ 34 | @JsonTypeName("linearPerceptron") 35 | public class LinearPerceptron extends Perceptron implements LinearMethod{ 36 | 37 | 38 | private String representation; 39 | 40 | 41 | public LinearPerceptron(){ 42 | this.classifier = new BinaryLinearClassifier(); 43 | this.classifier.setModel(new BinaryLinearModel()); 44 | } 45 | 46 | public LinearPerceptron(float alpha, float margin, boolean unbiased, String representation, Label label){ 47 | this.classifier = new BinaryLinearClassifier(); 48 | this.classifier.setModel(new BinaryLinearModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setRepresentation(representation); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public String getRepresentation() { 58 | return representation; 59 | } 60 | 61 | @Override 62 | public void setRepresentation(String representation) { 63 | this.representation = representation; 64 | BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel(); 65 | model.setRepresentation(representation); 66 | } 67 | 68 | @Override 69 | public LinearPerceptron duplicate(){ 70 | LinearPerceptron copy = new LinearPerceptron(); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setRepresentation(representation); 74 | copy.setUnbiased(this.unbiased); 75 | return copy; 76 | } 77 | 78 | @Override 79 | public BinaryLinearClassifier getPredictionFunction(){ 80 | return (BinaryLinearClassifier) this.classifier; 81 | } 82 | 83 | @Override 84 | public void setPredictionFunction(PredictionFunction predictionFunction) { 85 | this.classifier = (BinaryLinearClassifier) predictionFunction; 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/Perceptron.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | import java.util.Arrays; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 22 | import it.uniroma2.sag.kelp.data.example.Example; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 27 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier; 28 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 29 | 30 | import com.fasterxml.jackson.annotation.JsonIgnore; 31 | 32 | /** 33 | * The perceptron learning algorithm algorithm for classification tasks. Reference: 34 | * [Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. 
Report 85-460-1, Cornell Aeronautical Laboratory (1957) 35 | * 36 | * @author Simone Filice 37 | * 38 | */ 39 | public abstract class Perceptron implements ClassificationLearningAlgorithm, OnlineLearningAlgorithm, BinaryLearningAlgorithm{ 40 | 41 | @JsonIgnore 42 | protected BinaryClassifier classifier; 43 | 44 | protected Label label; 45 | 46 | protected float alpha=1; 47 | protected float margin = 1; 48 | protected boolean unbiased=false; 49 | 50 | /** 51 | * Returns the learning rate, i.e. the weight associated to misclassified examples during the learning process 52 | * 53 | * @return the learning rate 54 | */ 55 | public float getAlpha() { 56 | return alpha; 57 | } 58 | 59 | /** 60 | * Sets the learning rate, i.e. the weight associated to misclassified examples during the learning process 61 | * 62 | * @param alpha the learning rate to set 63 | */ 64 | public void setAlpha(float alpha) { 65 | if(alpha<=0 || alpha>1){ 66 | throw new IllegalArgumentException("Invalid learning rate for the perceptron algorithm: valid alphas in (0,1]"); 67 | } 68 | this.alpha = alpha; 69 | } 70 | 71 | /** 72 | * Returns the desired margin, i.e. the minimum distance from the hyperplane that an example must have 73 | * in order to be not considered misclassified 74 | * 75 | * @return the margin 76 | */ 77 | public float getMargin() { 78 | return margin; 79 | } 80 | 81 | /** 82 | * Sets the desired margin, i.e. the minimum distance from the hyperplane that an example must have 83 | * in order to be not considered misclassified 84 | * 85 | * @param margin the margin to set 86 | */ 87 | public void setMargin(float margin) { 88 | this.margin = margin; 89 | } 90 | 91 | /** 92 | * Returns whether the bias, i.e. the constant term of the hyperplane, is always 0, or can be modified during 93 | * the learning process 94 | * 95 | * @return the unbiased 96 | */ 97 | public boolean isUnbiased() { 98 | return unbiased; 99 | } 100 | 101 | /** 102 | * Sets whether the bias, i.e. 
the constant term of the hyperplane, is always 0, or can be modified during 103 | * the learning process 104 | * 105 | * @param unbiased the unbiased to set 106 | */ 107 | public void setUnbiased(boolean unbiased) { 108 | this.unbiased = unbiased; 109 | } 110 | 111 | 112 | @Override 113 | public void learn(Dataset dataset) { 114 | 115 | while(dataset.hasNextExample()){ 116 | Example example = dataset.getNextExample(); 117 | this.learn(example); 118 | } 119 | dataset.reset(); 120 | } 121 | 122 | @Override 123 | public BinaryMarginClassifierOutput learn(Example example){ 124 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 125 | 126 | float predValue = prediction.getScore(label); 127 | if(Math.abs(predValue) labels){ 154 | if(labels.size()!=1){ 155 | throw new IllegalArgumentException("The Perceptron algorithm is a binary method which can learn a single Label"); 156 | } 157 | else{ 158 | this.label=labels.get(0); 159 | this.classifier.setLabels(labels); 160 | } 161 | } 162 | 163 | 164 | @Override 165 | public List getLabels() { 166 | 167 | return Arrays.asList(label); 168 | } 169 | 170 | @Override 171 | public Label getLabel(){ 172 | return this.label; 173 | } 174 | 175 | @Override 176 | public void setLabel(Label label){ 177 | this.setLabels(Arrays.asList(label)); 178 | } 179 | 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/BinaryPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import it.uniroma2.sag.kelp.data.label.Label; 4 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 5 | 6 | public class BinaryPlattNormalizer { 7 | 8 | private float A; 9 | private float B; 10 | 11 | public BinaryPlattNormalizer() 
{ 12 | 13 | } 14 | 15 | public BinaryPlattNormalizer(float a, float b) { 16 | super(); 17 | A = a; 18 | B = b; 19 | } 20 | 21 | public float normalizeScore(float nonNormalizedScore) { 22 | return (float) (1.0 / (1.0 + Math.exp(A * nonNormalizedScore + B))); 23 | } 24 | 25 | public float getA() { 26 | return A; 27 | } 28 | 29 | public float getB() { 30 | return B; 31 | } 32 | 33 | public void setA(float a) { 34 | A = a; 35 | } 36 | 37 | public void setB(float b) { 38 | B = b; 39 | } 40 | 41 | @Override 42 | public String toString() { 43 | return "PlattSigmoidFunction [A=" + A + ", B=" + B + "]"; 44 | } 45 | 46 | public BinaryMarginClassifierOutput getNormalizedScore(BinaryMarginClassifierOutput binaryMarginClassifierOutput) { 47 | 48 | Label positiveLabel = binaryMarginClassifierOutput.getAllClasses().get(0); 49 | 50 | Float nonNormalizedScore = binaryMarginClassifierOutput.getScore(positiveLabel); 51 | 52 | BinaryMarginClassifierOutput res = new BinaryMarginClassifierOutput(positiveLabel, 53 | normalizeScore(nonNormalizedScore)); 54 | 55 | return res; 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/MulticlassPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.label.Label; 6 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 7 | 8 | public class MulticlassPlattNormalizer { 9 | 10 | private HashMap<Label, BinaryPlattNormalizer> binaryPlattNormalizers; 11 | 12 | public void addBinaryPlattNormalizer(Label label, BinaryPlattNormalizer binaryPlattNormalizer) { 13 | if (binaryPlattNormalizers == null) { 14 | binaryPlattNormalizers = new HashMap<Label, BinaryPlattNormalizer>(); 15 | } 16 | 
binaryPlattNormalizers.put(label, binaryPlattNormalizer); 17 | } 18 | 19 | public OneVsAllClassificationOutput getNormalizedScores(OneVsAllClassificationOutput oneVsAllClassificationOutput) { 20 | OneVsAllClassificationOutput res = new OneVsAllClassificationOutput(); 21 | 22 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 23 | float nonNormalizedScore = oneVsAllClassificationOutput.getScore(l); 24 | BinaryPlattNormalizer binaryPlattNormalizer = binaryPlattNormalizers.get(l); 25 | float normalizedScore = binaryPlattNormalizer.normalizeScore(nonNormalizedScore); 26 | 27 | res.addBinaryPrediction(l, normalizedScore); 28 | } 29 | 30 | return res; 31 | } 32 | 33 | public static OneVsAllClassificationOutput softmax(OneVsAllClassificationOutput oneVsAllClassificationOutput) { 34 | OneVsAllClassificationOutput res = new OneVsAllClassificationOutput(); 35 | 36 | float denom = 0; 37 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 38 | float score = oneVsAllClassificationOutput.getScore(l); 39 | denom += Math.exp(score); 40 | } 41 | 42 | 43 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 44 | float score = oneVsAllClassificationOutput.getScore(l); 45 | float newScore = (float)Math.exp(score)/denom; 46 | 47 | res.addBinaryPrediction(l, newScore); 48 | } 49 | 50 | return res; 51 | } 52 | 53 | public HashMap<Label, BinaryPlattNormalizer> getBinaryPlattNormalizers() { 54 | return binaryPlattNormalizers; 55 | } 56 | 57 | public void setBinaryPlattNormalizers(HashMap<Label, BinaryPlattNormalizer> binaryPlattNormalizers) { 58 | this.binaryPlattNormalizers = binaryPlattNormalizers; 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputElement.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | public 
class PlattInputElement { 4 | 5 | private int label; 6 | private float value; 7 | 8 | public PlattInputElement(int label, float value) { 9 | super(); 10 | this.label = label; 11 | this.value = value; 12 | } 13 | 14 | public int getLabel() { 15 | return label; 16 | } 17 | 18 | public float getValue() { 19 | return value; 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputList.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.Vector; 4 | 5 | public class PlattInputList { 6 | 7 | private Vector list; 8 | private int positiveElement; 9 | private int negativeElement; 10 | 11 | public PlattInputList() { 12 | list = new Vector(); 13 | } 14 | 15 | public void add(PlattInputElement arg0) { 16 | if (arg0.getLabel() > 0) 17 | positiveElement++; 18 | else 19 | negativeElement++; 20 | 21 | list.add(arg0); 22 | } 23 | 24 | public PlattInputElement get(int index) { 25 | return list.get(index); 26 | } 27 | 28 | public int size() { 29 | return list.size(); 30 | } 31 | 32 | public int getPositiveElement() { 33 | return positiveElement; 34 | } 35 | 36 | public int getNegativeElement() { 37 | return negativeElement; 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattMethod.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 6 | import it.uniroma2.sag.kelp.data.example.Example; 7 | import 
package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt;

import java.util.HashMap;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;

/**
 * Estimates the parameters of Platt's sigmoid, mapping raw classifier
 * decision values into posterior probabilities, via n-fold cross-validation.
 */
public class PlattMethod {

	/**
	 * Fits the sigmoid 1/(1+exp(A*f+B)) with the modified Newton method of
	 * Lin, Lin and Weng ("A Note on Platt's Probabilistic Outputs for Support
	 * Vector Machines").
	 *
	 * @param deci array of SVM decision values
	 * @param label array of gold labels: positive when {@code > 0}
	 * @param prior1 number of positive examples
	 * @param prior0 number of negative examples
	 * @return the fitted sigmoid parameters A and B
	 **/
	private static BinaryPlattNormalizer estimateSigmoid(float[] deci, float[] label, int prior1, int prior0) {

		/**
		 * Parameter setting
		 */
		// Maximum number of iterations
		int maxiter = 100;
		// Minimum step taken in line search
		double minstep = 1e-10;
		double stopping = 1e-5;
		// Sigma: set to any value > 0 to make the Hessian positive definite
		double sigma = 1e-12;
		// Construct initial values: target support in array t,
		// initial function value in fval
		double hiTarget = ((double) prior1 + 1.0f) / ((double) prior1 + 2.0f);
		double loTarget = 1 / (prior0 + 2.0f);

		int len = prior1 + prior0; // Total number of data
		double A;
		double B;

		double t[] = new double[len];

		for (int i = 0; i < len; i++) {
			if (label[i] > 0)
				t[i] = hiTarget;
			else
				t[i] = loTarget;
		}

		A = 0;
		B = Math.log((prior0 + 1.0) / (prior1 + 1.0));
		double fval = 0f;

		// Initial objective value (cross-entropy), written in the
		// numerically stable form for both signs of fApB.
		for (int i = 0; i < len; i++) {
			double fApB = deci[i] * A + B;
			if (fApB >= 0)
				fval += t[i] * fApB + Math.log(1 + Math.exp(-fApB));
			else
				fval += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB));
		}

		int it = 1;
		for (it = 1; it <= maxiter; it++) {
			// Update Gradient and Hessian (use H' = H + sigma I)
			double h11 = sigma;
			double h22 = sigma;
			double h21 = 0;
			double g1 = 0;
			double g2 = 0;
			for (int i = 0; i < len; i++) {
				double fApB = deci[i] * A + B;
				double p;
				double q;
				if (fApB >= 0) {
					p = (Math.exp(-fApB) / (1.0 + Math.exp(-fApB)));
					q = (1.0 / (1.0 + Math.exp(-fApB)));
				} else {
					p = 1.0 / (1.0 + Math.exp(fApB));
					q = Math.exp(fApB) / (1.0 + Math.exp(fApB));
				}
				double d2 = p * q;
				h11 += deci[i] * deci[i] * d2;
				h22 += d2;
				h21 += deci[i] * d2;
				double d1 = t[i] - p;
				g1 += deci[i] * d1;
				g2 += d1;
			}
			// Stopping criteria
			if (Math.abs(g1) < stopping && Math.abs(g2) < stopping)
				break;

			// Compute modified Newton directions
			double det = h11 * h22 - h21 * h21;
			double dA = -(h22 * g1 - h21 * g2) / det;
			double dB = -(-h21 * g1 + h11 * g2) / det;
			double gd = g1 * dA + g2 * dB;
			double stepsize = 1;

			while (stepsize >= minstep) { // Line search
				double newA = A + stepsize * dA;
				double newB = B + stepsize * dB;
				double newf = 0.0;
				for (int i = 0; i < len; i++) {
					double fApB = deci[i] * newA + newB;
					if (fApB >= 0)
						newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB));
					else
						newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB));
				}

				if (newf < fval + 1e-4 * stepsize * gd) {
					A = newA;
					B = newB;
					fval = newf;
					break; // Sufficient decrease satisfied
				} else
					stepsize /= 2.0;
			}
			if (stepsize < minstep) {
				System.out.println("Line search fails");
				break;
			}
		}
		if (it >= maxiter)
			System.out.println("Reaching maximum iterations");

		return new BinaryPlattNormalizer((float) A, (float) B);

	}

	/**
	 * Estimates the sigmoid for a binary classifier with n-fold
	 * cross-validation: each fold is scored by a model trained on the other
	 * folds, and the collected (label, score) pairs feed the sigmoid fit.
	 *
	 * @param dataset the labeled dataset
	 * @param binaryLearningAlgorithm the algorithm to cross-validate
	 * @param nFolds the number of folds
	 * @return the fitted sigmoid normalizer
	 */
	public static BinaryPlattNormalizer esitmateSigmoid(SimpleDataset dataset,
			BinaryLearningAlgorithm binaryLearningAlgorithm, int nFolds) {

		PlattInputList plattInputList = new PlattInputList();

		Label positiveLabel = binaryLearningAlgorithm.getLabel();

		SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds);

		for (int f = 0; f < folds.length; f++) {

			SimpleDataset localTrainDataset = new SimpleDataset();
			SimpleDataset localTestDataset = new SimpleDataset();
			for (int i = 0; i < folds.length; i++) {
				if (i != f) {
					// BUGFIX: the original added folds[f] here, so the
					// training set was nFolds-1 copies of the held-out fold.
					localTrainDataset.addExamples(folds[i]);
				} else {
					localTestDataset.addExamples(folds[i]);
				}
			}

			LearningAlgorithm duplicatedLearningAlgorithm = binaryLearningAlgorithm.duplicate();

			// BUGFIX: the original trained on the held-out fold itself,
			// making the cross-validation scores trivially optimistic.
			duplicatedLearningAlgorithm.learn(localTrainDataset);

			PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction();

			for (Example example : localTestDataset.getExamples()) {
				Prediction predict = predictionFunction.predict(example);

				float value = predict.getScore(positiveLabel);

				int label = example.isExampleOf(positiveLabel) ? 1 : -1;
				plattInputList.add(new PlattInputElement(label, value));
			}
		}

		return estimateSigmoid(plattInputList);
	}

	/**
	 * Estimates one sigmoid per class for a one-vs-all multiclass learner,
	 * with the same n-fold cross-validation protocol as the binary variant.
	 *
	 * @param dataset the labeled dataset
	 * @param oneVsAllLearning the one-vs-all algorithm to cross-validate
	 * @param nFolds the number of folds
	 * @return a multiclass normalizer holding one fitted sigmoid per label
	 */
	public static MulticlassPlattNormalizer esitmateSigmoid(SimpleDataset dataset, OneVsAllLearning oneVsAllLearning,
			int nFolds) {

		HashMap<Label, PlattInputList> plattInputLists = new HashMap<Label, PlattInputList>();
		for (Label label : dataset.getClassificationLabels()) {
			plattInputLists.put(label, new PlattInputList());
		}

		SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds);

		MulticlassPlattNormalizer res = new MulticlassPlattNormalizer();

		for (int f = 0; f < folds.length; f++) {

			SimpleDataset localTrainDataset = new SimpleDataset();
			SimpleDataset localTestDataset = new SimpleDataset();
			for (int i = 0; i < folds.length; i++) {
				if (i != f) {
					// BUGFIX: was addExamples(fold), i.e. folds[f]; see the
					// binary variant above.
					localTrainDataset.addExamples(folds[i]);
				} else {
					localTestDataset.addExamples(folds[i]);
				}
			}

			LearningAlgorithm duplicatedLearningAlgorithm = oneVsAllLearning.duplicate();

			// BUGFIX: was learn(fold); train on the remaining folds instead.
			duplicatedLearningAlgorithm.learn(localTrainDataset);

			PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction();

			for (Example example : localTestDataset.getExamples()) {
				Prediction predict = predictionFunction.predict(example);

				for (Label label : dataset.getClassificationLabels()) {

					float valueOfLabel = predict.getScore(label);

					int binaryLabel = example.isExampleOf(label) ? 1 : -1;
					plattInputLists.get(label).add(new PlattInputElement(binaryLabel, valueOfLabel));
				}
			}
		}

		for (Label label : dataset.getClassificationLabels()) {
			res.addBinaryPlattNormalizer(label, estimateSigmoid(plattInputLists.get(label)));
		}

		return res;
	}

	/**
	 * Unpacks a {@link PlattInputList} into the parallel arrays expected by
	 * the core Newton fit.
	 */
	protected static BinaryPlattNormalizer estimateSigmoid(PlattInputList inputList) {
		float[] deci = new float[inputList.size()];
		float[] label = new float[inputList.size()];
		int prior1 = inputList.getPositiveElement();
		int prior0 = inputList.getNegativeElement();

		for (int i = 0; i < inputList.size(); i++) {
			deci[i] = inputList.get(i).getValue();
			label[i] = inputList.get(i).getLabel();
		}

		return estimateSigmoid(deci, label, prior1, prior0);
	}

}
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.scw; 17 | 18 | /** 19 | * The two types of Soft Confidence-Weighted implemented variants 20 | * 21 | * @author Danilo Croce 22 | * 23 | */ 24 | public enum SCWType { 25 | 26 | SCW_I, SCW_II 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/clustering/kernelbasedkmeans/KernelBasedKMeansExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 21 | import it.uniroma2.sag.kelp.data.example.Example; 22 | 23 | @JsonTypeName("kernelbasedkmeansexample") 24 | public class KernelBasedKMeansExample extends ClusterExample { 25 | 26 | /** 27 | * 28 | */ 29 | private static final long serialVersionUID = -5368757832244686390L; 30 | 31 | public KernelBasedKMeansExample() { 32 | super(); 33 | } 34 | 35 | public KernelBasedKMeansExample(Example e, float dist) { 36 | super(e, dist); 37 | } 38 | 39 | @Override 40 | public Example getExample() { 41 | return example; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/liblinear/LibLinearRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
package it.uniroma2.sag.kelp.learningalgorithm.regression.liblinear;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvrFunction;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem.LibLinearSolverType;
import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Tron;
import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel;
import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction;

import java.util.Arrays;
import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;

/**
 * This class implements linear SVM regression trained using a coordinate descent
 * algorithm [Fan et al, 2008]. It operates in an explicit feature space (i.e.
 * it does not rely on any kernel). This code has been adapted from the Java
 * port of the original LIBLINEAR C++ sources.
 *
 * Further details can be found in:
 *
 * [Fan et al, 2008] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J.
 * Lin. LIBLINEAR: A Library for Large Linear Classification, Journal of Machine
 * Learning Research 9(2008), 1871-1874.
 *
 * The original LIBLINEAR code:
 * http://www.csie.ntu.edu.tw/~cjlin/liblinear
 *
 * The original JAVA porting (v 1.94): http://liblinear.bwaldvogel.de
 *
 * @author Danilo Croce
 */
@JsonTypeName("liblinearregression")
public class LibLinearRegression implements LinearMethod,
		RegressionLearningAlgorithm, BinaryLearningAlgorithm {

	/**
	 * The property corresponding to the variable to be learned
	 */
	private Label label;

	/**
	 * The regularization parameter
	 */
	private double c = 1;

	/**
	 * The regressor to be returned
	 */
	@JsonIgnore
	private UnivariateLinearRegressionFunction regressionFunction;

	/**
	 * The epsilon in the loss function of SVR (default 0.1)
	 */
	private double p = 0.1f;

	/**
	 * The identifier of the representation to be considered for the training
	 * step
	 */
	private String representation;

	/**
	 * @param label
	 *            The regression property to be learned
	 * @param c
	 *            The regularization parameter
	 * @param p
	 *            The epsilon in the loss function of SVR
	 * @param representationName
	 *            The identifier of the representation to be considered for the
	 *            training step
	 */
	public LibLinearRegression(Label label, double c, double p,
			String representationName) {
		this();

		this.setLabel(label);
		this.c = c;
		this.p = p;
		this.setRepresentation(representationName);
	}

	/**
	 * @param c
	 *            The regularization parameter
	 * @param p
	 *            The epsilon in the loss function of SVR
	 * @param representationName
	 *            The identifier of the representation to be considered for the
	 *            training step
	 */
	public LibLinearRegression(double c, double p, String representationName) {
		this();
		this.c = c;
		this.p = p;
		this.setRepresentation(representationName);
	}

	public LibLinearRegression() {
		this.regressionFunction = new UnivariateLinearRegressionFunction();
		this.regressionFunction.setModel(new BinaryLinearModel());
	}

	/**
	 * @return the regularization parameter
	 */
	public double getC() {
		return c;
	}

	/**
	 * @param c
	 *            the regularization parameter
	 */
	public void setC(double c) {
		this.c = c;
	}

	/**
	 * @return the epsilon in the loss function
	 */
	public double getP() {
		return p;
	}

	/**
	 * @param p
	 *            the epsilon in the loss function
	 */
	public void setP(double p) {
		this.p = p;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see
	 * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#getRepresentation()
	 */
	@Override
	public String getRepresentation() {
		return representation;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see
	 * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#setRepresentation
	 * (java.lang.String)
	 */
	@Override
	public void setRepresentation(String representation) {
		this.representation = representation;
		BinaryLinearModel model = this.regressionFunction.getModel();
		model.setRepresentation(representation);
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see
	 * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#setLabels(java
	 * .util.List)
	 */
	@Override
	public void setLabels(List<Label> labels) {
		if (labels.size() != 1) {
			throw new IllegalArgumentException(
					"LibLinear algorithm is a binary method which can learn a single Label");
		} else {
			this.label = labels.get(0);
			this.regressionFunction.setLabels(labels);
		}
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#getLabels()
	 */
	@Override
	public List<Label> getLabels() {
		return Arrays.asList(label);
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see
	 * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#getLabel()
	 */
	@Override
	public Label getLabel() {
		return this.label;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see
	 * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#setLabel
	 * (it.uniroma2.sag.kelp.data.label.Label)
	 */
	@Override
	public void setLabel(Label label) {
		this.setLabels(Arrays.asList(label));
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see
	 * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#learn(it.uniroma2
	 * .sag.kelp.data.dataset.Dataset)
	 */
	@Override
	public void learn(Dataset dataset) {

		// Stopping tolerance of the TRON solver.
		double eps = 0.001;

		int l = dataset.getNumberOfExamples();

		// Uniform per-example regularization weight.
		double[] C = new double[l];
		Arrays.fill(C, c);

		Problem problem = new Problem(dataset, representation, label,
				LibLinearSolverType.REGRESSION);

		// Declared with its concrete SVR type: the previous declaration used
		// the SVC base type, which was misleading.
		L2R_L2_SvrFunction fun_obj = new L2R_L2_SvrFunction(problem, C, p);

		Tron tron = new Tron(fun_obj, eps);

		double[] w = new double[problem.n];
		tron.tron(w);

		this.regressionFunction.getModel().setHyperplane(problem.getW(w));
		this.regressionFunction.getModel().setRepresentation(representation);
		this.regressionFunction.getModel().setBias(0);
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#duplicate()
	 */
	@Override
	public LibLinearRegression duplicate() {
		LibLinearRegression copy = new LibLinearRegression();
		copy.setRepresentation(representation);
		copy.setC(c);
		copy.setP(p);
		// NOTE(review): the label is intentionally not copied, consistently
		// with the duplicate() implementations of the other regression
		// algorithms in this library.
		return copy;
	}

	/*
	 * (non-Javadoc)
	 *
	 * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#reset()
	 */
	@Override
	public void reset() {
		this.regressionFunction.reset();
	}

	@Override
	public UnivariateLinearRegressionFunction getPredictionFunction() {
		return regressionFunction;
	}

	@Override
	public void setPredictionFunction(PredictionFunction predictionFunction) {
		this.regressionFunction = (UnivariateLinearRegressionFunction) predictionFunction;
	}

}
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.kernel.Kernel; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateKernelMachineRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (kernel machine version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("kernelizedPA-R") 37 | public class KernelizedPassiveAggressiveRegression extends PassiveAggressiveRegression implements KernelMethod{ 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPassiveAggressiveRegression(){ 42 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 43 | } 44 | 45 | public KernelizedPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, Kernel kernel, Label label){ 46 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 47 | this.setC(aggressiveness); 48 | this.setEpsilon(epsilon); 49 | this.setPolicy(policy); 50 | this.setKernel(kernel); 51 | this.setLabel(label); 52 | } 53 | 54 | @Override 55 | public Kernel getKernel(){ 56 | return kernel; 57 | } 58 | 59 | @Override 60 | public void setKernel(Kernel kernel) { 61 | this.kernel = kernel; 62 | this.getPredictionFunction().getModel().setKernel(kernel); 63 | } 64 | 65 | @Override 66 | public KernelizedPassiveAggressiveRegression duplicate() { 67 | KernelizedPassiveAggressiveRegression copy = new KernelizedPassiveAggressiveRegression(); 
68 | copy.setC(this.c); 69 | copy.setKernel(this.kernel); 70 | copy.setPolicy(this.policy); 71 | copy.setEpsilon(epsilon); 72 | return copy; 73 | } 74 | 75 | @Override 76 | public UnivariateKernelMachineRegressionFunction getPredictionFunction(){ 77 | return (UnivariateKernelMachineRegressionFunction) this.regressor; 78 | } 79 | 80 | @Override 81 | public void setPredictionFunction(PredictionFunction predictionFunction) { 82 | this.regressor = (UnivariateKernelMachineRegressionFunction) predictionFunction; 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/LinearPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (linear version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("linearPA-R") 37 | public class LinearPassiveAggressiveRegression extends PassiveAggressiveRegression implements LinearMethod{ 38 | 39 | private String representation; 40 | 41 | public LinearPassiveAggressiveRegression(){ 42 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 43 | regressor.setModel(new BinaryLinearModel()); 44 | this.regressor = regressor; 45 | 46 | } 47 | 48 | public LinearPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, String representation, Label label){ 49 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 50 | regressor.setModel(new BinaryLinearModel()); 51 | this.regressor = regressor; 52 | this.setC(aggressiveness); 53 | this.setEpsilon(epsilon); 54 | this.setPolicy(policy); 55 | this.setRepresentation(representation); 56 | this.setLabel(label); 57 | } 58 | 59 | @Override 60 | public LinearPassiveAggressiveRegression duplicate() { 61 | LinearPassiveAggressiveRegression copy = new LinearPassiveAggressiveRegression(); 62 | 
copy.setC(this.c); 63 | copy.setRepresentation(this.representation); 64 | copy.setPolicy(this.policy); 65 | copy.setEpsilon(epsilon); 66 | return copy; 67 | } 68 | 69 | @Override 70 | public String getRepresentation() { 71 | return representation; 72 | } 73 | 74 | @Override 75 | public void setRepresentation(String representation) { 76 | this.representation = representation; 77 | this.getPredictionFunction().getModel().setRepresentation(representation); 78 | } 79 | 80 | @Override 81 | public UnivariateLinearRegressionFunction getPredictionFunction(){ 82 | return (UnivariateLinearRegressionFunction) this.regressor; 83 | } 84 | 85 | @Override 86 | public void setPredictionFunction(PredictionFunction predictionFunction) { 87 | this.regressor = (UnivariateLinearRegressionFunction) predictionFunction; 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/PassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput; 23 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionFunction; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for regression tasks. 29 | * 30 | * reference: 31 | * 32 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 33 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 34 | * 35 | * @author Simone Filice 36 | */ 37 | public abstract class PassiveAggressiveRegression extends PassiveAggressive implements RegressionLearningAlgorithm{ 38 | 39 | @JsonIgnore 40 | protected UnivariateRegressionFunction regressor; 41 | 42 | protected float epsilon; 43 | 44 | /** 45 | * Returns epsilon, i.e. the accepted distance between the predicted and the real regression values 46 | * 47 | * @return the epsilon 48 | */ 49 | public float getEpsilon() { 50 | return epsilon; 51 | } 52 | 53 | /** 54 | * Sets epsilon, i.e. 
the accepted distance between the predicted and the real regression values 55 | * 56 | * @param epsilon the epsilon to set 57 | */ 58 | public void setEpsilon(float epsilon) { 59 | this.epsilon = epsilon; 60 | } 61 | 62 | @Override 63 | public UnivariateRegressionFunction getPredictionFunction() { 64 | return this.regressor; 65 | } 66 | 67 | @Override 68 | public void learn(Dataset dataset){ 69 | 70 | while(dataset.hasNextExample()){ 71 | Example example = dataset.getNextExample(); 72 | this.learn(example); 73 | } 74 | dataset.reset(); 75 | } 76 | 77 | @Override 78 | public UnivariateRegressionOutput learn(Example example){ 79 | UnivariateRegressionOutput prediction=this.regressor.predict(example); 80 | float difference = example.getRegressionValue(label) - prediction.getScore(label); 81 | float lossValue = Math.abs(difference) - epsilon;//it represents the distance from the correct semi-space 82 | if(lossValue>0){ 83 | float exampleSquaredNorm = this.regressor.getModel().getSquaredNorm(example); 84 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm, c); 85 | if(difference<0){ 86 | weight = -weight; 87 | } 88 | this.regressor.getModel().addExample(weight, example); 89 | } 90 | return prediction; 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/linearization/LinearizationFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.linearization;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.representation.Vector;

/**
 * This interface allows implementing functions that linearize examples into
 * linear representations, i.e. vectors.
 *
 * @author Danilo Croce
 */
public interface LinearizationFunction {

	/**
	 * Given an input Example, this method generates a linear
	 * Representation, i.e. a Vector.
	 *
	 * @param example
	 *            The input example.
	 * @return The linearized representation of the input example.
	 */
	public Vector getLinearRepresentation(Example example);

	/**
	 * This method linearizes an input example, providing a new example
	 * containing only a representation with a specific name, provided as input.
	 * The produced example inherits the labels of the input example.
	 *
	 * @param example
	 *            The input example.
	 * @param representationName
	 *            The name of the linear representation inside the new example.
	 * @return A new example containing only the named linear representation.
	 */
	public Example getLinearizedExample(Example example, String representationName);

	/**
	 * This method linearizes all the examples in the input dataset,
	 * generating a corresponding linearized dataset. The produced examples
	 * inherit the labels of the corresponding input examples.
	 *
	 * @param dataset
	 *            The input dataset.
	 * @param representationName
	 *            The name of the linear representation inside the new examples.
	 * @return A new dataset whose examples contain the named linear representation.
	 */
	public SimpleDataset getLinearizedDataset(Dataset dataset, String representationName);

	/**
	 * @return the size of the resulting embedding, i.e. the number of resulting
	 *         vector dimensions
	 */
	public int getEmbeddingSize();

}
This 27 | * specific implementation allows to assign multiple labelings to single 28 | * sequence, useful for some labeling strategies, such as Beam Search. Notice 29 | * that each labeling requires a score to select the more promising labeling. 30 | * 31 | * @author Danilo Croce 32 | * 33 | */ 34 | public class SequencePrediction implements Prediction { 35 | 36 | /** 37 | * 38 | */ 39 | private static final long serialVersionUID = -1040539866977906008L; 40 | /** 41 | * This list contains multiple labelings to be assigned to a single sequence 42 | */ 43 | private List paths; 44 | 45 | public SequencePrediction() { 46 | paths = new ArrayList(); 47 | } 48 | 49 | /** 50 | * @return The best path, i.e., the labeling with the highest score in the 51 | * list of labelings provided by a classifier 52 | */ 53 | public SequencePath bestPath() { 54 | return paths.get(0); 55 | } 56 | 57 | /** 58 | * @return a list containing multiple labelings to be assigned to a single 59 | * sequence 60 | */ 61 | public List getPaths() { 62 | return paths; 63 | } 64 | 65 | @Override 66 | public Float getScore(Label label) { 67 | return null; 68 | } 69 | 70 | /** 71 | * @param paths 72 | * a list contains multiple labelings to be assigned to a single 73 | * sequence 74 | */ 75 | public void setPaths(List paths) { 76 | this.paths = paths; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | StringBuilder sb = new StringBuilder(); 82 | for (int i = 0; i < paths.size(); i++) { 83 | if (i == 0) 84 | sb.append("Best Path\t"); 85 | else 86 | sb.append("Altern. 
Path\t"); 87 | SequencePath sequencePath = paths.get(i); 88 | sb.append(sequencePath + "\n"); 89 | } 90 | return sb.toString(); 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/model/SequenceModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
 */

package it.uniroma2.sag.kelp.predictionfunction.model;

import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;

/**
 * This class implements a model produced by a
 * SequenceClassificationLearningAlgorithm
 *
 * @author Danilo Croce
 *
 */
public class SequenceModel implements Model {

	private static final long serialVersionUID = -2749198158786953940L;

	/**
	 * The prediction function producing the emission scores to be considered in
	 * the Viterbi Decoding
	 */
	private PredictionFunction basePredictionFunction;

	// Strategy used to enrich sequence items with transition information
	// before they are fed to the base prediction function.
	private SequenceExampleGenerator sequenceExampleGenerator;

	/**
	 * Empty constructor, required for JSON serialization/deserialization.
	 */
	public SequenceModel() {
		super();
	}

	/**
	 * @param basePredictionFunction
	 *            the prediction function producing the emission scores
	 * @param sequenceExampleGenerator
	 *            the generator enriching sequence items with transition
	 *            information
	 */
	public SequenceModel(PredictionFunction basePredictionFunction, SequenceExampleGenerator sequenceExampleGenerator) {
		super();
		this.basePredictionFunction = basePredictionFunction;
		this.sequenceExampleGenerator = sequenceExampleGenerator;
	}

	/**
	 * @return the prediction function producing the emission scores
	 */
	public PredictionFunction getBasePredictionFunction() {
		return basePredictionFunction;
	}

	/**
	 * @return the generator enriching sequence items with transition
	 *         information
	 */
	public SequenceExampleGenerator getSequenceExampleGenerator() {
		return sequenceExampleGenerator;
	}

	// Does nothing: this model holds no incremental state of its own.
	// NOTE(review): presumably resetting the wrapped base prediction function
	// is intentionally left to the caller — confirm.
	@Override
	public void reset() {
	}

	public void setBasePredictionFunction(PredictionFunction basePredictionFunction) {
		this.basePredictionFunction = basePredictionFunction;
	}

	public void setSequenceExampleGenerator(SequenceExampleGenerator sequenceExampleGenerator) {
		this.sequenceExampleGenerator = sequenceExampleGenerator;
	}

}
1 | package it.uniroma2.sag.kelp.utils.evaluation; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashSet; 5 | import java.util.TreeMap; 6 | 7 | import it.uniroma2.sag.kelp.data.clustering.Cluster; 8 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 9 | import it.uniroma2.sag.kelp.data.clustering.ClusterList; 10 | import it.uniroma2.sag.kelp.data.example.Example; 11 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 12 | import it.uniroma2.sag.kelp.data.label.Label; 13 | import it.uniroma2.sag.kelp.data.label.StringLabel; 14 | import it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans.KernelBasedKMeansExample; 15 | 16 | /** 17 | * 18 | * Implements Evaluation methods for clustering algorithms. 19 | * 20 | * More details about Purity and NMI can be found here: 21 | * 22 | * https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1. 23 | * html 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class ClusteringEvaluator { 29 | 30 | public static float getPurity(ClusterList clusters) { 31 | 32 | float res = 0; 33 | int k = clusters.size(); 34 | 35 | for (int clustId = 0; clustId < k; clustId++) { 36 | 37 | TreeMap classSizes = new TreeMap(); 38 | 39 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 40 | HashSet labels = vce.getExample().getClassificationLabels(); 41 | for (Label label : labels) 42 | if (!classSizes.containsKey(label)) 43 | classSizes.put(label, 1); 44 | else 45 | classSizes.put(label, classSizes.get(label) + 1); 46 | } 47 | 48 | int maxSize = 0; 49 | for (int size : classSizes.values()) { 50 | if (size > maxSize) { 51 | maxSize = size; 52 | } 53 | } 54 | res += maxSize; 55 | } 56 | 57 | return res / (float) clusters.getNumberOfExamples(); 58 | } 59 | 60 | public static float getMI(ClusterList clusters) { 61 | 62 | float res = 0; 63 | 64 | float N = clusters.getNumberOfExamples(); 65 | 66 | int k = clusters.size(); 67 | 68 | TreeMap classCardinality = 
getClassCardinality(clusters); 69 | 70 | for (int clustId = 0; clustId < k; clustId++) { 71 | 72 | TreeMap classSizes = getClassCardinalityWithinCluster(clusters, clustId); 73 | 74 | for (Label className : classSizes.keySet()) { 75 | int wSize = classSizes.get(className); 76 | res += ((float) wSize / N) * myLog(N * (float) wSize 77 | / (clusters.get(clustId).getExamples().size() * (float) classCardinality.get(className))); 78 | } 79 | 80 | } 81 | 82 | return res; 83 | 84 | } 85 | 86 | private static TreeMap getClassCardinalityWithinCluster(ClusterList clusters, int clustId) { 87 | 88 | TreeMap classSizes = new TreeMap(); 89 | 90 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 91 | HashSet labels = vce.getExample().getClassificationLabels(); 92 | for (Label label : labels) 93 | if (!classSizes.containsKey(label)) 94 | classSizes.put(label, 1); 95 | else 96 | classSizes.put(label, classSizes.get(label) + 1); 97 | } 98 | 99 | return classSizes; 100 | } 101 | 102 | private static float getClusterEntropy(ClusterList clusters) { 103 | 104 | float res = 0; 105 | float N = clusters.getNumberOfExamples(); 106 | int k = clusters.size(); 107 | 108 | for (int clustId = 0; clustId < k; clustId++) { 109 | int clusterElementSize = clusters.get(clustId).getExamples().size(); 110 | if (clusterElementSize != 0) 111 | res -= ((float) clusterElementSize / N) * myLog((float) clusterElementSize / N); 112 | } 113 | return res; 114 | 115 | } 116 | 117 | private static float getClassEntropy(ClusterList clusters) { 118 | 119 | float res = 0; 120 | float N = clusters.getNumberOfExamples(); 121 | 122 | TreeMap classCardinality = getClassCardinality(clusters); 123 | 124 | for (int classSize : classCardinality.values()) { 125 | res -= ((float) classSize / N) * myLog((float) classSize / N); 126 | } 127 | return res; 128 | 129 | } 130 | 131 | private static float myLog(float f) { 132 | return (float) (Math.log(f) / Math.log(2f)); 133 | } 134 | 135 | private static TreeMap 
getClassCardinality(ClusterList clusters) { 136 | TreeMap classSizes = new TreeMap(); 137 | 138 | int k = clusters.size(); 139 | 140 | for (int clustId = 0; clustId < k; clustId++) { 141 | 142 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 143 | HashSet labels = vce.getExample().getClassificationLabels(); 144 | for (Label label : labels) 145 | if (!classSizes.containsKey(label)) 146 | classSizes.put(label, 1); 147 | else 148 | classSizes.put(label, classSizes.get(label) + 1); 149 | } 150 | } 151 | return classSizes; 152 | } 153 | 154 | public static float getNMI(ClusterList clusters) { 155 | return getMI(clusters) / ((getClusterEntropy(clusters) + getClassEntropy(clusters)) / 2f); 156 | } 157 | 158 | public static String getStatistics(ClusterList clusters) { 159 | StringBuilder sb = new StringBuilder(); 160 | 161 | sb.append("Purity:\t" + getPurity(clusters) + "\n"); 162 | sb.append("Mutual Information:\t" + getMI(clusters) + "\n"); 163 | sb.append("Cluster Entropy:\t" + getClusterEntropy(clusters) + "\n"); 164 | sb.append("Class Entropy:\t" + getClassEntropy(clusters) + "\n"); 165 | sb.append("NMI:\t" + getNMI(clusters)); 166 | 167 | return sb.toString(); 168 | } 169 | 170 | public static void main(String[] args) { 171 | ClusterList clusters = new ClusterList(); 172 | 173 | Cluster c1 = new Cluster("C1"); 174 | ArrayList list1 = new ArrayList(); 175 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 176 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 177 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 178 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 179 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 180 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 181 | for (Example e : list1) { 182 | c1.add(new KernelBasedKMeansExample(e, 1f)); 183 | } 184 
| 185 | Cluster c2 = new Cluster("C2"); 186 | ArrayList list2 = new ArrayList(); 187 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 188 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 189 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 190 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 191 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 192 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 193 | for (Example e : list2) { 194 | c2.add(new KernelBasedKMeansExample(e, 1f)); 195 | } 196 | 197 | Cluster c3 = new Cluster("C3"); 198 | ArrayList list3 = new ArrayList(); 199 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 200 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 201 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 202 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 203 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 204 | for (Example e : list3) { 205 | c3.add(new KernelBasedKMeansExample(e, 1f)); 206 | } 207 | 208 | clusters.add(c1); 209 | clusters.add(c2); 210 | clusters.add(c3); 211 | 212 | System.out.println(ClusteringEvaluator.getStatistics(clusters)); 213 | 214 | //From https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html 215 | //Purity = 0.71 216 | //NMI = 0.36 217 | 218 | } 219 | 220 | } 221 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/MulticlassSequenceClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * 
Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.utils.evaluation;

import java.util.List;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.SequenceEmission;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction;

/**
 * This is an instance of an Evaluator. It allows to compute some common
 * measures for classification tasks acting over SequenceExamples. It
 * computes precision, recall, f1s for each class, and a global accuracy
 * over the individual sequence elements.
 *
 * @author Danilo Croce
 */
public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator {

	/**
	 * Initialize a new F1Evaluator that will work on the specified classes
	 *
	 * @param labels the classes to be evaluated
	 */
	public MulticlassSequenceClassificationEvaluator(List<Label> labels) {
		super(labels);
	}

	public void addCount(Example test, Prediction prediction) {
		addCount((SequenceExample) test, (SequencePrediction) prediction);
	}

	/**
	 * Updates the per-class and global counters by comparing, element by
	 * element, the gold labels of the test sequence with the labels assigned
	 * by the best path of the prediction.
	 *
	 * @param test
	 *            the test example
	 * @param predicted
	 *            the prediction of the system
	 */
	public void addCount(SequenceExample test, SequencePrediction predicted) {

		SequencePath bestPath = predicted.bestPath();

		for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) {

			Example testItem = test.getExample(seqIdx);
			SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx);

			for (Label l : this.labels) {
				ClassStats stats = this.classStats.get(l);
				if (testItem.isExampleOf(l)) {
					if (sequenceLabel.getLabel().equals(l)) {
						stats.tp++;
						totalTp++;
					} else {
						stats.fn++;
						totalFn++;
					}
				} else {
					if (sequenceLabel.getLabel().equals(l)) {
						stats.fp++;
						totalFp++;
					} else {
						stats.tn++;
						totalTn++;
					}
				}

			}

			// TODO: check (i) whether it is correct to evaluate the accuracy of the
			// single sequence elements rather than of the complete sequence, and
			// (ii) whether the multilabel case should be handled
			total++;

			if (testItem.isExampleOf(sequenceLabel.getLabel())) {
				correct++;
			}
		}

		// invalidate any cached measures: new counts have been accumulated
		// (setting this once after the loop is equivalent to the previous
		// per-element assignment)
		this.computed = false;
	}

}
/src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 
33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = 
new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | 
evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
38 | * [CavallantiCOLT2006] G. Cavallanti, N. Cesa-Bianchi, C. Gentile. Tracking the best hyperplane with a simple budget Perceptron. In proc. of the 19-th annual conference on Computational Learning Theory. (2006) 39 | *
40 | * Until the budget is not reached the online learning updating policy is the one of the baseAlgorithm that this 41 | * meta-algorithm is exploiting. When the budget is full, a random support vector is deleted and the perceptron updating policy is 42 | * adopted 43 | * 44 | * @author Simone Filice 45 | * 46 | */ 47 | @JsonTypeName("randomizedPerceptron") 48 | public class RandomizedBudgetPerceptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm{ 49 | 50 | private static final long DEFAULT_SEED=1; 51 | private long initialSeed = DEFAULT_SEED; 52 | @JsonIgnore 53 | private Random randomGenerator; 54 | 55 | private OnlineLearningAlgorithm baseAlgorithm; 56 | 57 | public RandomizedBudgetPerceptron(){ 58 | randomGenerator = new Random(initialSeed); 59 | } 60 | 61 | public RandomizedBudgetPerceptron(int budget, OnlineLearningAlgorithm baseAlgorithm, long seed, Label label){ 62 | randomGenerator = new Random(initialSeed); 63 | this.setBudget(budget); 64 | this.setBaseAlgorithm(baseAlgorithm); 65 | this.setSeed(seed); 66 | this.setLabel(label); 67 | } 68 | 69 | /** 70 | * Sets the seed for the random generator adopted to select the support vector to delete 71 | * 72 | * @param seed the seed of the randomGenerator 73 | */ 74 | public void setSeed(long seed){ 75 | this.initialSeed = seed; 76 | this.randomGenerator.setSeed(seed); 77 | } 78 | 79 | @Override 80 | public RandomizedBudgetPerceptron duplicate() { 81 | RandomizedBudgetPerceptron copy = new RandomizedBudgetPerceptron(); 82 | copy.setBudget(budget); 83 | copy.setBaseAlgorithm(baseAlgorithm.duplicate()); 84 | copy.setSeed(initialSeed); 85 | return copy; 86 | } 87 | 88 | @Override 89 | public void reset() { 90 | this.baseAlgorithm.reset(); 91 | this.randomGenerator.setSeed(initialSeed); 92 | } 93 | 94 | @Override 95 | protected Prediction predictAndLearnWithFullBudget(Example example) { 96 | Prediction prediction = this.baseAlgorithm.getPredictionFunction().predict(example); 97 | 98 | 
if((prediction.getScore(getLabel())>0) != example.isExampleOf(getLabel())){ 99 | int svToDelete = this.randomGenerator.nextInt(budget); 100 | float weight = 1; 101 | if(!example.isExampleOf(getLabels().get(0))){ 102 | weight=-1; 103 | } 104 | SupportVector sv = new SupportVector(weight, example); 105 | 106 | ((BinaryKernelMachineModel)this.baseAlgorithm.getPredictionFunction().getModel()).setSupportVector(sv, svToDelete); 107 | } 108 | return prediction; 109 | } 110 | 111 | @Override 112 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 113 | if(baseAlgorithm instanceof OnlineLearningAlgorithm && baseAlgorithm instanceof KernelMethod && baseAlgorithm instanceof BinaryLearningAlgorithm){ 114 | this.baseAlgorithm = (OnlineLearningAlgorithm) baseAlgorithm; 115 | }else{ 116 | throw new IllegalArgumentException("a valid baseAlgorithm for the Randomized Budget Perceptron must implement OnlineLearningAlgorithm, BinaryLeaningAlgorithm and KernelMethod"); 117 | } 118 | } 119 | 120 | @Override 121 | public OnlineLearningAlgorithm getBaseAlgorithm() { 122 | return this.baseAlgorithm; 123 | } 124 | 125 | @Override 126 | public PredictionFunction getPredictionFunction() { 127 | return this.baseAlgorithm.getPredictionFunction(); 128 | } 129 | 130 | @Override 131 | public Kernel getKernel() { 132 | return ((KernelMethod)this.baseAlgorithm).getKernel(); 133 | } 134 | 135 | @Override 136 | public void setKernel(Kernel kernel) { 137 | ((KernelMethod)this.baseAlgorithm).setKernel(kernel); 138 | 139 | } 140 | 141 | @Override 142 | protected Prediction predictAndLearnWithAvailableBudget(Example example) { 143 | return this.baseAlgorithm.learn(example); 144 | } 145 | 146 | @Override 147 | public void setPredictionFunction(PredictionFunction predictionFunction) { 148 | this.baseAlgorithm.setPredictionFunction(predictionFunction); 149 | } 150 | 151 | } 152 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/budgetedAlgorithm/Stoptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm; 17 | 18 | import it.uniroma2.sag.kelp.data.example.Example; 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.MetaLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 27 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 28 | 29 | import com.fasterxml.jackson.annotation.JsonTypeName; 30 | 31 | /** 32 | * It is a variation of the Stoptron proposed in
33 | * [OrabonaICML2008] Francesco Orabona, Joseph Keshet, and Barbara Caputo. The projectron: a bounded kernel-based perceptron. In Int. Conf. on Machine Learning (2008) 34 | *
35 | * Until the budget is not reached the online learning updating policy is the one of the baseAlgorithm that this 36 | * meta-algorithm is exploiting. When the budget is full, the learning process ends 37 | * 38 | * @author Simone Filice 39 | * 40 | */ 41 | @JsonTypeName("stoptron") 42 | public class Stoptron extends BudgetedLearningAlgorithm implements MetaLearningAlgorithm{ 43 | 44 | private OnlineLearningAlgorithm baseAlgorithm; 45 | 46 | public Stoptron(){ 47 | 48 | } 49 | 50 | public Stoptron(int budget, OnlineLearningAlgorithm baseAlgorithm, Label label){ 51 | this.setBudget(budget); 52 | this.setBaseAlgorithm(baseAlgorithm); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Stoptron duplicate() { 58 | Stoptron copy = new Stoptron(); 59 | copy.setBudget(budget); 60 | copy.setBaseAlgorithm(baseAlgorithm.duplicate()); 61 | return copy; 62 | } 63 | 64 | @Override 65 | public void reset() { 66 | this.baseAlgorithm.reset(); 67 | } 68 | 69 | @Override 70 | protected Prediction predictAndLearnWithFullBudget(Example example) { 71 | return this.baseAlgorithm.getPredictionFunction().predict(example); 72 | } 73 | 74 | @Override 75 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 76 | if(baseAlgorithm instanceof OnlineLearningAlgorithm && baseAlgorithm instanceof KernelMethod && baseAlgorithm instanceof BinaryLearningAlgorithm){ 77 | this.baseAlgorithm = (OnlineLearningAlgorithm) baseAlgorithm; 78 | }else{ 79 | throw new IllegalArgumentException("a valid baseAlgorithm for the Stoptron must implement OnlineLearningAlgorithm, BinaryLeaningAlgorithm and KernelMethod"); 80 | } 81 | } 82 | 83 | @Override 84 | public OnlineLearningAlgorithm getBaseAlgorithm() { 85 | return this.baseAlgorithm; 86 | } 87 | 88 | @Override 89 | public PredictionFunction getPredictionFunction() { 90 | return this.baseAlgorithm.getPredictionFunction(); 91 | } 92 | 93 | @Override 94 | public Kernel getKernel() { 95 | return 
((KernelMethod)this.baseAlgorithm).getKernel(); 96 | } 97 | 98 | @Override 99 | public void setKernel(Kernel kernel) { 100 | ((KernelMethod)this.baseAlgorithm).setKernel(kernel); 101 | 102 | } 103 | 104 | @Override 105 | protected Prediction predictAndLearnWithAvailableBudget(Example example) { 106 | return this.baseAlgorithm.learn(example); 107 | } 108 | 109 | @Override 110 | public void setPredictionFunction(PredictionFunction predictionFunction) { 111 | this.baseAlgorithm.setPredictionFunction(predictionFunction); 112 | } 113 | 114 | } 115 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/dcd/DCDLoss.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.dcd; 2 | 3 | public enum DCDLoss { 4 | L1, L2 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceClassificationKernelBasedLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGeneratorKernel; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.kernel.cache.KernelCache; 22 | import it.uniroma2.sag.kelp.kernel.standard.LinearKernelCombination; 23 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 24 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 26 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 27 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 28 | 29 | /** 30 | * /** This class implements a sequential labeling paradigm. 31 | * Given sequences of items (each implemented as an Example and 32 | * associated to one Label) this class allow to apply a generic 33 | * LearningAlgorithm to use the "history" of each item in the 34 | * sequence in order to improve the classification quality. In other words, the 35 | * classification of each example does not depend only its representation, but 36 | * it also depend on its "history", in terms of the classed assigned to the 37 | * preceding examples. 38 | * This class should be used when a kernel-based learning algorithm is 39 | * used, thus directly operating in the implicit space underlying a kernel 40 | * function. 41 | * 42 | * 43 | * This algorithms was inspired by the work of: 44 | * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector 45 | * machines. In Proceedings of the Twentieth International Conference on Machine 46 | * Learning, 2003. 
47 | * 48 | * @author Danilo Croce 49 | * 50 | */ 51 | public class SequenceClassificationKernelBasedLearningAlgorithm extends SequenceClassificationLearningAlgorithm 52 | implements KernelMethod { 53 | 54 | private final static String TRANSITION_REPRESENTATION_NAME = "__trans_rep__"; 55 | 56 | private LinearKernelCombination sequenceBasedKernel; 57 | 58 | public SequenceClassificationKernelBasedLearningAlgorithm() { 59 | 60 | } 61 | 62 | /** 63 | * @param baseLearningAlgorithm 64 | * the learning algorithm devoted to the acquisition of a model 65 | * after that each example has been enriched with its "history" 66 | * @param transitionsOrder 67 | * given a targeted item in the sequence, this variable 68 | * determines the number of previous example considered in the 69 | * learning/labeling process. 70 | * @param transitionWeight 71 | * the importance of the transition-based features during the 72 | * learning process. Higher valuers will assign more importance 73 | * to the transitions. 
74 | * @throws Exception 75 | * The input baseLearningAlgorithm is not a 76 | * kernel-based method 77 | */ 78 | public SequenceClassificationKernelBasedLearningAlgorithm(BinaryLearningAlgorithm baseLearningAlgorithm, 79 | int transitionsOrder, float transitionWeight) throws Exception { 80 | 81 | if (!(baseLearningAlgorithm instanceof KernelMethod)) { 82 | throw new Exception("ERROR: the input baseLearningAlgorithm is not a kernel-based method!"); 83 | } 84 | 85 | Kernel inputKernel = ((KernelMethod) baseLearningAlgorithm).getKernel(); 86 | 87 | sequenceBasedKernel = new LinearKernelCombination(); 88 | sequenceBasedKernel.addKernel(1, inputKernel); 89 | Kernel transitionBasedKernel = new LinearKernel(TRANSITION_REPRESENTATION_NAME); 90 | sequenceBasedKernel.addKernel(transitionWeight, transitionBasedKernel); 91 | sequenceBasedKernel.normalizeWeights(); 92 | 93 | setKernel(sequenceBasedKernel); 94 | 95 | BinaryLearningAlgorithm binaryLearningAlgorithmCopy = (BinaryLearningAlgorithm) baseLearningAlgorithm 96 | .duplicate(); 97 | 98 | ((KernelMethod) binaryLearningAlgorithmCopy).setKernel(sequenceBasedKernel); 99 | 100 | OneVsAllLearning oneVsAllLearning = new OneVsAllLearning(); 101 | oneVsAllLearning.setBaseAlgorithm(binaryLearningAlgorithmCopy); 102 | 103 | super.setBaseLearningAlgorithm(oneVsAllLearning); 104 | 105 | SequenceExampleGenerator sequenceExamplesGenerator = new SequenceExampleGeneratorKernel( 106 | transitionsOrder, TRANSITION_REPRESENTATION_NAME); 107 | 108 | super.setSequenceExampleGenerator(sequenceExamplesGenerator); 109 | } 110 | 111 | @Override 112 | public LearningAlgorithm duplicate() { 113 | return null; 114 | } 115 | 116 | @Override 117 | public LearningAlgorithm getBaseAlgorithm() { 118 | return super.getBaseLearningAlgorithm(); 119 | } 120 | 121 | @Override 122 | public Kernel getKernel() { 123 | return sequenceBasedKernel; 124 | } 125 | 126 | @Override 127 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 128 | 
super.setBaseLearningAlgorithm(baseAlgorithm); 129 | } 130 | 131 | @Override 132 | public void setKernel(Kernel kernel) { 133 | this.sequenceBasedKernel = (LinearKernelCombination) kernel; 134 | } 135 | 136 | public void setKernelCache(KernelCache cache) { 137 | this.getKernel().setKernelCache(cache); 138 | } 139 | 140 | } 141 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceClassificationLinearLearningAlgorithm.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGeneratorLinear; 20 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 22 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 24 | 25 | /** 26 | * This class implements a sequential labeling paradigm. 
27 | * Given sequences of items (each implemented as an Example and 28 | * associated to one Label) this class allow to apply a generic 29 | * LearningAlgorithm to use the "history" of each item in the 30 | * sequence in order to improve the classification quality. In other words, the 31 | * classification of each example does not depend only its representation, but 32 | * it also depend on its "history", in terms of the classed assigned to the 33 | * preceding examples. 34 | * This class should be used when a linear learning algorithm is used, 35 | * thus directly operating in the representation space. 36 | * 37 | * 38 | * This algorithms was inspired by the work of: 39 | * Y. Altun, I. Tsochantaridis, and T. Hofmann. Hidden Markov support vector 40 | * machines. In Proceedings of the Twentieth International Conference on Machine 41 | * Learning, 2003. 42 | * 43 | * @author Danilo Croce 44 | * 45 | */ 46 | public class SequenceClassificationLinearLearningAlgorithm extends SequenceClassificationLearningAlgorithm 47 | implements LinearMethod { 48 | 49 | /** 50 | * @param baseLearningAlgorithm 51 | * the "linear" learning algorithm devoted to the acquisition of 52 | * a model after that each example has been enriched with its 53 | * "history" 54 | * @param transitionsOrder 55 | * given a targeted item in the sequence, this variable 56 | * determines the number of previous example considered in the 57 | * learning/labeling process. 58 | * @param transitionWeight 59 | * the importance of the transition-based features during the 60 | * learning process. Higher valuers will assign more importance 61 | * to the transitions. 
62 | * @throws Exception The input baseLearningAlgorithm is not a Linear method 63 | */ 64 | public SequenceClassificationLinearLearningAlgorithm(BinaryLearningAlgorithm baseLearningAlgorithm, 65 | int transitionsOrder, float transitionWeight) throws Exception { 66 | 67 | if (!(baseLearningAlgorithm instanceof LinearMethod)) { 68 | throw new Exception("ERROR: the input baseLearningAlgorithm is not a Linear method!"); 69 | } 70 | 71 | OneVsAllLearning oneVsAllLearning = new OneVsAllLearning(); 72 | oneVsAllLearning.setBaseAlgorithm(baseLearningAlgorithm); 73 | 74 | super.setBaseLearningAlgorithm(oneVsAllLearning); 75 | String representation = ((LinearMethod) baseLearningAlgorithm).getRepresentation(); 76 | 77 | SequenceExampleGenerator sequenceExamplesGenerator = new SequenceExampleGeneratorLinear(transitionsOrder, 78 | representation, transitionWeight); 79 | 80 | super.setSequenceExampleGenerator(sequenceExamplesGenerator); 81 | } 82 | 83 | @Override 84 | public LearningAlgorithm duplicate() { 85 | // TODO Auto-generated method stub 86 | return null; 87 | } 88 | 89 | @Override 90 | public LearningAlgorithm getBaseAlgorithm() { 91 | return super.getBaseLearningAlgorithm(); 92 | } 93 | 94 | @Override 95 | public String getRepresentation() { 96 | return ((SequenceClassificationLinearLearningAlgorithm) getSequenceExampleGenerator()).getRepresentation(); 97 | } 98 | 99 | @Override 100 | public void setBaseAlgorithm(LearningAlgorithm baseAlgorithm) { 101 | super.setBaseLearningAlgorithm(baseAlgorithm); 102 | } 103 | 104 | @Override 105 | public void setRepresentation(String representationName) { 106 | ((SequenceClassificationLinearLearningAlgorithm) getSequenceExampleGenerator()) 107 | .setRepresentation(representationName); 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/COPYRIGHT: 
-------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2007-2013 The LIBLINEAR Project. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | 3. Neither name of copyright holders nor the names of its contributors 17 | may be used to endorse or promote products derived from this software 18 | without specific prior written permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR 25 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/L2R_L2_SvcFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. 
Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | public class L2R_L2_SvcFunction implements TronFunction { 26 | 27 | protected final Problem prob; 28 | protected final double[] C; 29 | protected final int[] I; 30 | protected final double[] z; 31 | 32 | protected int sizeI; 33 | 34 | public L2R_L2_SvcFunction(Problem prob, double[] C) { 35 | int l = prob.l; 36 | 37 | this.prob = prob; 38 | 39 | z = new double[l]; 40 | I = new int[l]; 41 | this.C = C; 42 | } 43 | 44 | public double fun(double[] w) { 45 | int i; 46 | double f = 0; 47 | double[] y = prob.y; 48 | int l = prob.l; 49 | int w_size = get_nr_variable(); 50 | 51 | Xv(w, z); 52 | 53 | for (i = 0; i < w_size; i++) 54 | f += w[i] * w[i]; 55 | f /= 2.0; 56 | for (i = 0; i < l; i++) { 57 | z[i] = y[i] * z[i]; 58 | double d = 1 - z[i]; 59 | if (d > 0) 60 | f += C[i] * d * d; 61 | } 62 | 63 | return (f); 64 | } 65 | 66 | public int get_nr_variable() { 67 | return prob.n; 68 | } 69 | 70 | public void grad(double[] w, double[] g) { 71 | double[] y = prob.y; 72 | int l = prob.l; 73 | int w_size = get_nr_variable(); 74 | 75 | sizeI = 0; 76 | for (int i = 0; i < l; i++) { 77 | if (z[i] < 1) { 78 | z[sizeI] = C[i] * y[i] * (z[i] - 1); 79 | I[sizeI] = i; 80 | sizeI++; 81 | } 82 | } 83 | subXTv(z, g); 84 | 85 | for (int i = 0; i < w_size; i++) 86 | g[i] = w[i] + 2 * g[i]; 87 | } 88 | 89 | public void Hv(double[] s, double[] Hs) { 90 | int i; 91 | int w_size = get_nr_variable(); 92 | double[] wa = new double[sizeI]; 93 | 94 | subXv(s, wa); 95 | for (i = 0; i < sizeI; i++) 96 | wa[i] = C[I[i]] * wa[i]; 97 | 98 | subXTv(wa, Hs); 99 | for (i = 0; i < w_size; i++) 100 | Hs[i] = s[i] + 2 * Hs[i]; 101 | } 102 | 103 | protected void subXTv(double[] v, double[] XTv) { 104 | int i; 105 | int w_size = get_nr_variable(); 106 | 107 | for (i = 0; i < w_size; i++) 108 | XTv[i] = 0; 109 | 110 | for (i = 0; i < sizeI; i++) { 111 | for (LibLinearFeature s : 
prob.x[I[i]]) { 112 | XTv[s.getIndex() - 1] += v[i] * s.getValue(); 113 | } 114 | } 115 | } 116 | 117 | private void subXv(double[] v, double[] Xv) { 118 | 119 | for (int i = 0; i < sizeI; i++) { 120 | Xv[i] = 0; 121 | for (LibLinearFeature s : prob.x[I[i]]) { 122 | Xv[i] += v[s.getIndex() - 1] * s.getValue(); 123 | } 124 | } 125 | } 126 | 127 | protected void Xv(double[] v, double[] Xv) { 128 | 129 | for (int i = 0; i < prob.l; i++) { 130 | Xv[i] = 0; 131 | for (LibLinearFeature s : prob.x[i]) { 132 | Xv[i] += v[s.getIndex() - 1] * s.getValue(); 133 | } 134 | } 135 | } 136 | 137 | } 138 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/L2R_L2_SvrFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. 
Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | public class L2R_L2_SvrFunction extends L2R_L2_SvcFunction { 26 | 27 | private double p; 28 | 29 | public L2R_L2_SvrFunction( Problem prob, double[] C, double p ) { 30 | super(prob, C); 31 | this.p = p; 32 | } 33 | 34 | @Override 35 | public double fun(double[] w) { 36 | double f = 0; 37 | double[] y = prob.y; 38 | int l = prob.l; 39 | int w_size = get_nr_variable(); 40 | double d; 41 | 42 | Xv(w, z); 43 | 44 | for (int i = 0; i < w_size; i++) 45 | f += w[i] * w[i]; 46 | f /= 2; 47 | for (int i = 0; i < l; i++) { 48 | d = z[i] - y[i]; 49 | if (d < -p) 50 | f += C[i] * (d + p) * (d + p); 51 | else if (d > p) f += C[i] * (d - p) * (d - p); 52 | } 53 | 54 | return f; 55 | } 56 | 57 | @Override 58 | public void grad(double[] w, double[] g) { 59 | double[] y = prob.y; 60 | int l = prob.l; 61 | int w_size = get_nr_variable(); 62 | 63 | sizeI = 0; 64 | for (int i = 0; i < l; i++) { 65 | double d = z[i] - y[i]; 66 | 67 | // generate index set I 68 | if (d < -p) { 69 | z[sizeI] = C[i] * (d + p); 70 | I[sizeI] = i; 71 | sizeI++; 72 | } else if (d > p) { 73 | z[sizeI] = C[i] * (d - p); 74 | I[sizeI] = i; 75 | sizeI++; 76 | } 77 | 78 | } 79 | subXTv(z, g); 80 | 81 | for (int i = 0; i < w_size; i++) 82 | g[i] = w[i] + 2 * g[i]; 83 | 84 | } 85 | 86 | } -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/LibLinearFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | public interface LibLinearFeature { 26 | 27 | int getIndex(); 28 | 29 | double getValue(); 30 | 31 | void setValue(double value); 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/LibLinearFeatureNode.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | public class LibLinearFeatureNode implements LibLinearFeature { 26 | 27 | public final int index; 28 | public double value; 29 | 30 | public LibLinearFeatureNode( final int index, final double value ) { 31 | if (index < 0) throw new IllegalArgumentException("index must be >= 0"); 32 | this.index = index; 33 | this.value = value; 34 | } 35 | 36 | /** 37 | * @since 1.9 38 | */ 39 | public int getIndex() { 40 | return index; 41 | } 42 | 43 | /** 44 | * @since 1.9 45 | */ 46 | public double getValue() { 47 | return value; 48 | } 49 | 50 | /** 51 | * @since 1.9 52 | */ 53 | public void setValue(double value) { 54 | this.value = value; 55 | } 56 | 57 | @Override 58 | public int hashCode() { 59 | final int prime = 31; 60 | int result = 1; 61 | result = prime * result + index; 62 | long temp; 63 | temp = Double.doubleToLongBits(value); 64 | result = prime * result + (int)(temp ^ (temp >>> 32)); 65 | return result; 66 | } 67 | 68 | @Override 69 | public boolean equals(Object obj) { 70 | if (this == obj) return true; 71 | if (obj == null) return false; 72 | if (getClass() != obj.getClass()) return false; 73 | LibLinearFeatureNode other = (LibLinearFeatureNode)obj; 74 | if (index != other.index) return false; 75 | if (Double.doubleToLongBits(value) != Double.doubleToLongBits(other.value)) return false; 76 | return true; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | return "FeatureNode(idx=" + index + ", value=" + value + ")"; 82 | } 83 | } 84 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/liblinear/solver/Problem.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver; 17 | 18 | /** 19 | * NOTE: This code has been adapted from the Java port of the original LIBLINEAR 20 | * C++ sources. Original Java sources (v 1.94) are available at 21 | * http://liblinear.bwaldvogel.de 22 | * 23 | * @author Danilo Croce 24 | */ 25 | import gnu.trove.map.hash.TIntObjectHashMap; 26 | import gnu.trove.map.hash.TObjectIntHashMap; 27 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 28 | import it.uniroma2.sag.kelp.data.example.Example; 29 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 30 | import it.uniroma2.sag.kelp.data.label.Label; 31 | import it.uniroma2.sag.kelp.data.representation.Representation; 32 | import it.uniroma2.sag.kelp.data.representation.Vector; 33 | import it.uniroma2.sag.kelp.data.representation.vector.DenseVector; 34 | import it.uniroma2.sag.kelp.data.representation.vector.SparseVector; 35 | 36 | import java.io.IOException; 37 | import java.util.ArrayList; 38 | import java.util.Map; 39 | 40 | /** 41 | *
LearningAlgorithm
baseLearningAlgorithm
http://liblinear.bwaldvogel.de
42 | * Describes the problem 43 | *
48 | * LABEL ATTR1 ATTR2 ATTR3 ATTR4 ATTR5 49 | * ----- ----- ----- ----- ----- ----- 50 | * 1 0 0.1 0.2 0 0 51 | * 2 0 0.1 0.3 -1.2 0 52 | * 1 0.4 0 0 0 0 53 | * 2 0 0.1 0 1.4 0.5 54 | * 3 -0.1 -0.2 0.1 1.1 0.1 55 | * 56 | * and bias = 1, then the components of problem are: 57 | * 58 | * l = 5 59 | * n = 6 60 | * 61 | * y -> 1 2 1 2 3 62 | * 63 | * x -> [ ] -> (2,0.1) (3,0.2) (6,1) (-1,?) 64 | * [ ] -> (2,0.1) (3,0.3) (4,-1.2) (6,1) (-1,?) 65 | * [ ] -> (1,0.4) (6,1) (-1,?) 66 | * [ ] -> (2,0.1) (4,1.4) (5,0.5) (6,1) (-1,?) 67 | * [ ] -> (1,-0.1) (2,-0.2) (3,0.1) (4,1.1) (5,0.1) (6,1) (-1,?) 68 | *
224 | * vector2 += constant * vector1 225 | *
35 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 36 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 37 | * 38 | *
The standard algorithm is modified, including the fairness extention from
39 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 40 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347–358, Springer International Publishing, 2014. 41 | * 42 | * 43 | * @author Simone Filice 44 | */ 45 | 46 | @JsonTypeName("kernelizedPA") 47 | public class KernelizedPassiveAggressiveClassification extends PassiveAggressiveClassification implements KernelMethod{ 48 | 49 | private Kernel kernel; 50 | 51 | public KernelizedPassiveAggressiveClassification(){ 52 | this.classifier = new BinaryKernelMachineClassifier(); 53 | this.classifier.setModel(new BinaryKernelMachineModel()); 54 | } 55 | 56 | public KernelizedPassiveAggressiveClassification(float cp, float cn, Loss loss, Policy policy, Kernel kernel, Label label){ 57 | this.classifier = new BinaryKernelMachineClassifier(); 58 | this.classifier.setModel(new BinaryKernelMachineModel()); 59 | this.setKernel(kernel); 60 | this.setLoss(loss); 61 | this.setCp(cp); 62 | this.setCn(cn); 63 | this.setLabel(label); 64 | this.setPolicy(policy); 65 | } 66 | 67 | 68 | @Override 69 | public Kernel getKernel() { 70 | return kernel; 71 | } 72 | 73 | @Override 74 | public void setKernel(Kernel kernel) { 75 | this.kernel = kernel; 76 | this.getPredictionFunction().getModel().setKernel(kernel); 77 | } 78 | 79 | 80 | @Override 81 | public KernelizedPassiveAggressiveClassification duplicate(){ 82 | KernelizedPassiveAggressiveClassification copy = new KernelizedPassiveAggressiveClassification(); 83 | copy.setCp(this.cp); 84 | copy.setCn(c); 85 | copy.setFairness(this.fairness); 86 | copy.setKernel(this.kernel); 87 | copy.setLoss(this.loss); 88 | copy.setPolicy(this.policy); 89 | //copy.setLabel(label); 90 | return copy; 91 | } 92 | 93 | @Override 94 | public BinaryKernelMachineClassifier getPredictionFunction(){ 95 | return (BinaryKernelMachineClassifier) this.classifier; 96 | } 97 | 98 | @Override 99 | public void 
setPredictionFunction(PredictionFunction predictionFunction) { 100 | this.classifier = (BinaryKernelMachineClassifier) predictionFunction; 101 | } 102 | 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/LinearPassiveAggressiveClassification.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.label.Label; 19 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 20 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 21 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 22 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 23 | 24 | import com.fasterxml.jackson.annotation.JsonTypeName; 25 | 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for classification tasks (linear version) . 29 | * Every time an example is misclassified it is added the the current hyperplane, with the weight that solves the 30 | * passive aggressive minimization problem 31 | * 32 | * reference: 33 | *
34 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 35 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 36 | * 37 | *
38 | * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning 39 | * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347–358, Springer International Publishing, 2014. 40 | * 41 | * @author Simone Filice 42 | */ 43 | @JsonTypeName("linearPA") 44 | public class LinearPassiveAggressiveClassification extends PassiveAggressiveClassification implements LinearMethod{ 45 | 46 | private String representation; 47 | 48 | public LinearPassiveAggressiveClassification(){ 49 | this.classifier = new BinaryLinearClassifier(); 50 | this.classifier.setModel(new BinaryLinearModel()); 51 | } 52 | 53 | public LinearPassiveAggressiveClassification(float cp, float cn, Loss loss, Policy policy, String representation, Label label){ 54 | this.classifier = new BinaryLinearClassifier(); 55 | this.classifier.setModel(new BinaryLinearModel()); 56 | this.setCp(cp); 57 | this.setCn(cn); 58 | this.setLoss(loss); 59 | this.setPolicy(policy); 60 | this.setRepresentation(representation); 61 | this.setLabel(label); 62 | } 63 | 64 | @Override 65 | public String getRepresentation() { 66 | return representation; 67 | } 68 | 69 | @Override 70 | public void setRepresentation(String representation) { 71 | this.representation = representation; 72 | BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel(); 73 | model.setRepresentation(representation); 74 | } 75 | 76 | @Override 77 | public LinearPassiveAggressiveClassification duplicate(){ 78 | LinearPassiveAggressiveClassification copy = new LinearPassiveAggressiveClassification(); 79 | copy.setRepresentation(this.representation); 80 | copy.setCp(this.cp); 81 | copy.setCn(this.c); 82 | copy.setFairness(this.fairness); 83 | copy.setLoss(this.loss); 84 | copy.setPolicy(this.policy); 85 | return copy; 86 | } 87 | 88 | @Override 89 | public BinaryLinearClassifier getPredictionFunction(){ 90 | return (BinaryLinearClassifier) this.classifier; 91 | } 92 | 93 | 
@Override 94 | public void setPredictionFunction(PredictionFunction predictionFunction) { 95 | this.classifier = (BinaryLinearClassifier) predictionFunction; 96 | } 97 | 98 | } 99 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/passiveaggressive/PassiveAggressiveClassification.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier; 23 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | import com.fasterxml.jackson.annotation.JsonProperty; 27 | 28 | /** 29 | * Online Passive-Aggressive Learning Algorithm for classification tasks. 
 * Every time an example is misclassified it is added to the current hyperplane, with the weight that solves the
 * passive aggressive minimization problem.
 *
 * reference:
 *
 * [FiliceECIR2014] S. Filice, G. Castellucci, D. Croce, and R. Basili. Effective Kernelized Online Learning
 * in Language Processing Tasks. In collection of Advances in Information Retrieval, pp. 347-358, Springer International Publishing, 2014.
 *
 * @author Simone Filice
 */
public abstract class PassiveAggressiveClassification extends PassiveAggressive implements ClassificationLearningAlgorithm{

	/** The loss functions the passive aggressive updates can optimize. */
	public enum Loss{
		HINGE,
		RAMP
	}

	// loss function to be optimized (default: hinge loss)
	protected Loss loss = Loss.HINGE;
	// cp is the aggressiveness w.r.t. positive examples; the inherited c is the aggressiveness w.r.t. negative examples
	protected float cp = c;
	// when true, learn(Dataset) rescales cp by the negative/positive example ratio to balance the two classes
	protected boolean fairness = false;

	@JsonIgnore
	protected BinaryClassifier classifier;


	/**
	 * @return the fairness
	 */
	public boolean isFairness() {
		return fairness;
	}


	/**
	 * @param fairness the fairness to set
	 */
	public void setFairness(boolean fairness) {
		this.fairness = fairness;
	}

	/**
	 * @return the aggressiveness parameter for positive examples
	 */
	public float getCp() {
		return cp;
	}


	/**
	 * @param cp the aggressiveness parameter for positive examples
	 */
	public void setCp(float cp) {
		this.cp = cp;
	}

	/**
	 * @return the aggressiveness parameter for negative examples
	 */
	public float getCn() {
		return c;
	}


	/**
	 * @param cn the aggressiveness parameter for negative examples
	 */
	public void setCn(float cn) {
		this.c = cn;
	}

	@Override
	@JsonIgnore
	public float getC(){
		return c;
	}

	// setting c aligns both the negative (c) and the positive (cp) aggressiveness
	@Override
	@JsonProperty
	public void setC(float c){
		super.setC(c);
		this.cp=c;
	}

	/**
	 * @return the loss function type
	 */
	public Loss getLoss() {
		return loss;
	}


	/**
	 * @param loss the loss function type to set
	 */
	public void setLoss(Loss loss) {
		this.loss = loss;
	}

	@Override
	public BinaryClassifier getPredictionFunction() {
		return this.classifier;
	}

	/**
	 * Performs a single passive-aggressive update step: if the example is
	 * misclassified, or falls inside the margin, it is added to the model with
	 * the weight that solves the PA minimization problem.
	 */
	@Override
	public BinaryMarginClassifierOutput learn(Example example){

		BinaryMarginClassifierOutput prediction=this.classifier.predict(example);

		float lossValue = 0;//it represents the distance from the correct semi-space
		if(prediction.isClassPredicted(label)!=example.isExampleOf(label)){
			// misclassified: loss is 1 plus the distance from the (wrong side of the) margin
			lossValue = 1 + Math.abs(prediction.getScore(label));
		}else if(Math.abs(prediction.getScore(label))<1){
			// correctly classified but inside the margin
			lossValue = 1 - Math.abs(prediction.getScore(label));
		}

		// the ramp loss saturates at 2, so no update is performed beyond that point
		if(lossValue>0 && (lossValue<2 || this.loss!=Loss.RAMP)){
			float exampleAggressiveness=this.c;
			if(example.isExampleOf(label)){
				exampleAggressiveness=cp;
			}
			float exampleSquaredNorm = this.classifier.getModel().getSquaredNorm(example);
			float weight = this.computeWeight(example, lossValue, exampleSquaredNorm ,exampleAggressiveness);
			// negative examples are added with a negative weight
			if(!example.isExampleOf(label)){
				weight*=-1;
			}
			this.getPredictionFunction().getModel().addExample(weight, example);
		}
		return prediction;

	}

	/**
	 * Learns from the whole dataset; when fairness is enabled, cp is first
	 * rescaled by the negative/positive example ratio of the dataset.
	 */
	@Override
	public void learn(Dataset dataset){
		if(this.fairness){
			float positiveExample = dataset.getNumberOfPositiveExamples(label);
			float negativeExample = dataset.getNumberOfNegativeExamples(label);
			cp = c * negativeExample / positiveExample;
		}
		super.learn(dataset);
	}

}
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.pegasos; 17 | 18 | import java.util.ArrayList; 19 | import java.util.Arrays; 20 | import java.util.List; 21 | 22 | import com.fasterxml.jackson.annotation.JsonTypeName; 23 | 24 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 25 | import it.uniroma2.sag.kelp.data.example.Example; 26 | import it.uniroma2.sag.kelp.data.label.Label; 27 | import it.uniroma2.sag.kelp.data.representation.Vector; 28 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 29 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 30 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 31 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 32 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 33 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 34 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 35 | 36 | /** 37 | * It implements the Primal Estimated sub-GrAdient SOlver (PEGASOS) for SVM. 
It is a learning 38 | * algorithm for binary linear classification Support Vector Machines. It operates in an explicit 39 | * feature space (i.e. it does not relies on any kernel). Further details can be found in:
40 | * 41 | * [SingerICML2007] Y. Singer and N. Srebro. Pegasos: Primal estimated sub-gradient solver for SVM. 42 | * In Proceeding of ICML 2007. 43 | * 44 | * @author Simone Filice 45 | * 46 | */ 47 | @JsonTypeName("pegasos") 48 | public class PegasosLearningAlgorithm implements LinearMethod, ClassificationLearningAlgorithm, BinaryLearningAlgorithm{ 49 | 50 | private Label label; 51 | 52 | private BinaryLinearClassifier classifier; 53 | 54 | private int k = 1; 55 | private int iterations = 1000; 56 | private float lambda = 0.01f; 57 | 58 | private String representation; 59 | 60 | /** 61 | * Returns the number of examples k that Pegasos exploits in its 62 | * mini-batch learning approach 63 | * 64 | * @return k 65 | */ 66 | public int getK() { 67 | return k; 68 | } 69 | 70 | /** 71 | * Sets the number of examples k that Pegasos exploits in its 72 | * mini-batch learning approach 73 | * 74 | * @param k the k to set 75 | */ 76 | public void setK(int k) { 77 | this.k = k; 78 | } 79 | 80 | /** 81 | * Returns the number of iterations 82 | * 83 | * @return the number of iterations 84 | */ 85 | public int getIterations() { 86 | return iterations; 87 | } 88 | 89 | /** 90 | * Sets the number of iterations 91 | * 92 | * @param T the number of iterations to set 93 | */ 94 | public void setIterations(int T) { 95 | this.iterations = T; 96 | } 97 | 98 | /** 99 | * Returns the regularization coefficient 100 | * 101 | * @return the lambda 102 | */ 103 | public float getLambda() { 104 | return lambda; 105 | } 106 | 107 | /** 108 | * Sets the regularization coefficient 109 | * 110 | * @param lambda the lambda to set 111 | */ 112 | public void setLambda(float lambda) { 113 | this.lambda = lambda; 114 | } 115 | 116 | public PegasosLearningAlgorithm(){ 117 | this.classifier = new BinaryLinearClassifier(); 118 | this.classifier.setModel(new BinaryLinearModel()); 119 | } 120 | 121 | public PegasosLearningAlgorithm(int k, float lambda, int T, String Representation, Label label){ 122 | 
this.classifier = new BinaryLinearClassifier(); 123 | this.classifier.setModel(new BinaryLinearModel()); 124 | this.setK(k); 125 | this.setLabel(label); 126 | this.setLambda(lambda); 127 | this.setRepresentation(Representation); 128 | this.setIterations(T); 129 | } 130 | 131 | @Override 132 | public String getRepresentation() { 133 | return representation; 134 | } 135 | 136 | @Override 137 | public void setRepresentation(String representation) { 138 | this.representation = representation; 139 | BinaryLinearModel model = this.classifier.getModel(); 140 | model.setRepresentation(representation); 141 | } 142 | 143 | @Override 144 | public void learn(Dataset dataset) { 145 | if(this.getPredictionFunction().getModel().getHyperplane()==null){ 146 | this.getPredictionFunction().getModel().setHyperplane(dataset.getZeroVector(representation)); 147 | } 148 | 149 | for(int t=1;t<=iterations;t++){ 150 | 151 | List A_t = dataset.getRandExamples(k); 152 | List A_tp = new ArrayList(); 153 | List signA_tp = new ArrayList(); 154 | float eta_t = ((float)1)/(lambda*t); 155 | Vector w_t = this.getPredictionFunction().getModel().getHyperplane(); 156 | 157 | //creating A_tp 158 | for(Example example: A_t){ 159 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 160 | float y = -1; 161 | if(example.isExampleOf(label)){ 162 | y=1; 163 | } 164 | 165 | if(prediction.getScore(label)*y<1){ 166 | A_tp.add(example); 167 | signA_tp.add(y); 168 | } 169 | } 170 | //creating w_(t+1/2) 171 | w_t.scale(1-eta_t*lambda); 172 | float miscassificationFactor = eta_t/k; 173 | for(int i=0; i labels){ 211 | if(labels.size()!=1){ 212 | throw new IllegalArgumentException("Pegasos algorithm is a binary method which can learn a single Label"); 213 | } 214 | else{ 215 | this.label=labels.get(0); 216 | this.classifier.setLabels(labels); 217 | } 218 | } 219 | 220 | 221 | @Override 222 | public List getLabels() { 223 | return Arrays.asList(label); 224 | } 225 | 226 | @Override 227 | public 
Label getLabel(){ 228 | return this.label; 229 | } 230 | 231 | @Override 232 | public void setLabel(Label label){ 233 | this.setLabels(Arrays.asList(label)); 234 | } 235 | 236 | @Override 237 | public void setPredictionFunction(PredictionFunction predictionFunction) { 238 | this.classifier = (BinaryLinearClassifier) predictionFunction; 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/KernelizedPerceptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.kernel.Kernel; 21 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 24 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel; 25 | 26 | import com.fasterxml.jackson.annotation.JsonTypeName; 27 | 28 | /** 29 | * The perceptron learning algorithm algorithm for classification tasks (Kernel machine version). Reference: 30 | * [Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 31 | * 32 | * @author Simone Filice 33 | * 34 | */ 35 | @JsonTypeName("kernelizedPerceptron") 36 | public class KernelizedPerceptron extends Perceptron implements KernelMethod{ 37 | 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPerceptron(){ 42 | this.classifier = new BinaryKernelMachineClassifier(); 43 | this.classifier.setModel(new BinaryKernelMachineModel()); 44 | } 45 | 46 | public KernelizedPerceptron(float alpha, float margin, boolean unbiased, Kernel kernel, Label label){ 47 | this.classifier = new BinaryKernelMachineClassifier(); 48 | this.classifier.setModel(new BinaryKernelMachineModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setKernel(kernel); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Kernel getKernel() { 58 | return kernel; 59 | } 60 | 61 | @Override 62 | public void setKernel(Kernel kernel) { 63 | this.kernel = kernel; 64 | this.getPredictionFunction().getModel().setKernel(kernel); 65 | } 66 | 67 | @Override 68 | public KernelizedPerceptron duplicate(){ 69 | KernelizedPerceptron copy = new 
KernelizedPerceptron(); 70 | copy.setKernel(this.kernel); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setUnbiased(this.unbiased); 74 | return copy; 75 | } 76 | 77 | @Override 78 | public BinaryKernelMachineClassifier getPredictionFunction(){ 79 | return (BinaryKernelMachineClassifier) this.classifier; 80 | } 81 | 82 | @Override 83 | public void setPredictionFunction(PredictionFunction predictionFunction) { 84 | this.classifier = (BinaryKernelMachineClassifier) predictionFunction; 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/LinearPerceptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | 19 | import com.fasterxml.jackson.annotation.JsonTypeName; 20 | 21 | import it.uniroma2.sag.kelp.data.label.Label; 22 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 25 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 26 | 27 | /** 28 | * The perceptron learning algorithm algorithm for classification tasks (linear version). Reference: 29 | * [Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 30 | * 31 | * @author Simone Filice 32 | * 33 | */ 34 | @JsonTypeName("linearPerceptron") 35 | public class LinearPerceptron extends Perceptron implements LinearMethod{ 36 | 37 | 38 | private String representation; 39 | 40 | 41 | public LinearPerceptron(){ 42 | this.classifier = new BinaryLinearClassifier(); 43 | this.classifier.setModel(new BinaryLinearModel()); 44 | } 45 | 46 | public LinearPerceptron(float alpha, float margin, boolean unbiased, String representation, Label label){ 47 | this.classifier = new BinaryLinearClassifier(); 48 | this.classifier.setModel(new BinaryLinearModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setRepresentation(representation); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public String getRepresentation() { 58 | return representation; 59 | } 60 | 61 | @Override 62 | public void setRepresentation(String representation) { 63 | this.representation = representation; 64 | BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel(); 65 | model.setRepresentation(representation); 66 | } 67 | 68 | @Override 69 | public LinearPerceptron duplicate(){ 70 | LinearPerceptron copy = 
new LinearPerceptron(); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setRepresentation(representation); 74 | copy.setUnbiased(this.unbiased); 75 | return copy; 76 | } 77 | 78 | @Override 79 | public BinaryLinearClassifier getPredictionFunction(){ 80 | return (BinaryLinearClassifier) this.classifier; 81 | } 82 | 83 | @Override 84 | public void setPredictionFunction(PredictionFunction predictionFunction) { 85 | this.classifier = (BinaryLinearClassifier) predictionFunction; 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/Perceptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | import java.util.Arrays; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 22 | import it.uniroma2.sag.kelp.data.example.Example; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 27 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier; 28 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 29 | 30 | import com.fasterxml.jackson.annotation.JsonIgnore; 31 | 32 | /** 33 | * The perceptron learning algorithm algorithm for classification tasks. Reference: 34 | * [Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 35 | * 36 | * @author Simone Filice 37 | * 38 | */ 39 | public abstract class Perceptron implements ClassificationLearningAlgorithm, OnlineLearningAlgorithm, BinaryLearningAlgorithm{ 40 | 41 | @JsonIgnore 42 | protected BinaryClassifier classifier; 43 | 44 | protected Label label; 45 | 46 | protected float alpha=1; 47 | protected float margin = 1; 48 | protected boolean unbiased=false; 49 | 50 | /** 51 | * Returns the learning rate, i.e. the weight associated to misclassified examples during the learning process 52 | * 53 | * @return the learning rate 54 | */ 55 | public float getAlpha() { 56 | return alpha; 57 | } 58 | 59 | /** 60 | * Sets the learning rate, i.e. 
the weight associated to misclassified examples during the learning process 61 | * 62 | * @param alpha the learning rate to set 63 | */ 64 | public void setAlpha(float alpha) { 65 | if(alpha<=0 || alpha>1){ 66 | throw new IllegalArgumentException("Invalid learning rate for the perceptron algorithm: valid alphas in (0,1]"); 67 | } 68 | this.alpha = alpha; 69 | } 70 | 71 | /** 72 | * Returns the desired margin, i.e. the minimum distance from the hyperplane that an example must have 73 | * in order to be not considered misclassified 74 | * 75 | * @return the margin 76 | */ 77 | public float getMargin() { 78 | return margin; 79 | } 80 | 81 | /** 82 | * Sets the desired margin, i.e. the minimum distance from the hyperplane that an example must have 83 | * in order to be not considered misclassified 84 | * 85 | * @param margin the margin to set 86 | */ 87 | public void setMargin(float margin) { 88 | this.margin = margin; 89 | } 90 | 91 | /** 92 | * Returns whether the bias, i.e. the constant term of the hyperplane, is always 0, or can be modified during 93 | * the learning process 94 | * 95 | * @return the unbiased 96 | */ 97 | public boolean isUnbiased() { 98 | return unbiased; 99 | } 100 | 101 | /** 102 | * Sets whether the bias, i.e. 
the constant term of the hyperplane, is always 0, or can be modified during 103 | * the learning process 104 | * 105 | * @param unbiased the unbiased to set 106 | */ 107 | public void setUnbiased(boolean unbiased) { 108 | this.unbiased = unbiased; 109 | } 110 | 111 | 112 | @Override 113 | public void learn(Dataset dataset) { 114 | 115 | while(dataset.hasNextExample()){ 116 | Example example = dataset.getNextExample(); 117 | this.learn(example); 118 | } 119 | dataset.reset(); 120 | } 121 | 122 | @Override 123 | public BinaryMarginClassifierOutput learn(Example example){ 124 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 125 | 126 | float predValue = prediction.getScore(label); 127 | if(Math.abs(predValue) labels){ 154 | if(labels.size()!=1){ 155 | throw new IllegalArgumentException("The Perceptron algorithm is a binary method which can learn a single Label"); 156 | } 157 | else{ 158 | this.label=labels.get(0); 159 | this.classifier.setLabels(labels); 160 | } 161 | } 162 | 163 | 164 | @Override 165 | public List getLabels() { 166 | 167 | return Arrays.asList(label); 168 | } 169 | 170 | @Override 171 | public Label getLabel(){ 172 | return this.label; 173 | } 174 | 175 | @Override 176 | public void setLabel(Label label){ 177 | this.setLabels(Arrays.asList(label)); 178 | } 179 | 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/BinaryPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import it.uniroma2.sag.kelp.data.label.Label; 4 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 5 | 6 | public class BinaryPlattNormalizer { 7 | 8 | private float A; 9 | private float B; 10 | 11 | public BinaryPlattNormalizer() 
{ 12 | 13 | } 14 | 15 | public BinaryPlattNormalizer(float a, float b) { 16 | super(); 17 | A = a; 18 | B = b; 19 | } 20 | 21 | public float normalizeScore(float nonNomalizedScore) { 22 | return (float) (1.0 / (1.0 + Math.exp(A * nonNomalizedScore + B))); 23 | } 24 | 25 | public float getA() { 26 | return A; 27 | } 28 | 29 | public float getB() { 30 | return B; 31 | } 32 | 33 | public void setA(float a) { 34 | A = a; 35 | } 36 | 37 | public void setB(float b) { 38 | B = b; 39 | } 40 | 41 | @Override 42 | public String toString() { 43 | return "PlattSigmoidFunction [A=" + A + ", B=" + B + "]"; 44 | } 45 | 46 | public BinaryMarginClassifierOutput getNormalizedScore(BinaryMarginClassifierOutput binaryMarginClassifierOutput) { 47 | 48 | Label positiveLabel = binaryMarginClassifierOutput.getAllClasses().get(0); 49 | 50 | Float nonNormalizedScore = binaryMarginClassifierOutput.getScore(positiveLabel); 51 | 52 | BinaryMarginClassifierOutput res = new BinaryMarginClassifierOutput(positiveLabel, 53 | normalizeScore(nonNormalizedScore)); 54 | 55 | return res; 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/MulticlassPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.label.Label; 6 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 7 | 8 | public class MulticlassPlattNormalizer { 9 | 10 | private HashMap binaryPlattNormalizers; 11 | 12 | public void addBinaryPlattNormalizer(Label label, BinaryPlattNormalizer binaryPlattNormalizer) { 13 | if (binaryPlattNormalizers == null) { 14 | binaryPlattNormalizers = new HashMap(); 15 | } 16 | 
binaryPlattNormalizers.put(label, binaryPlattNormalizer); 17 | } 18 | 19 | public OneVsAllClassificationOutput getNormalizedScores(OneVsAllClassificationOutput oneVsAllClassificationOutput) { 20 | OneVsAllClassificationOutput res = new OneVsAllClassificationOutput(); 21 | 22 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 23 | float nonNormalizedScore = oneVsAllClassificationOutput.getScore(l); 24 | BinaryPlattNormalizer binaryPlattNormalizer = binaryPlattNormalizers.get(l); 25 | float normalizedScore = binaryPlattNormalizer.normalizeScore(nonNormalizedScore); 26 | 27 | res.addBinaryPrediction(l, normalizedScore); 28 | } 29 | 30 | return res; 31 | } 32 | 33 | public static OneVsAllClassificationOutput softmax(OneVsAllClassificationOutput oneVsAllClassificationOutput) { 34 | OneVsAllClassificationOutput res = new OneVsAllClassificationOutput(); 35 | 36 | float denom = 0; 37 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 38 | float score = oneVsAllClassificationOutput.getScore(l); 39 | denom += Math.exp(score); 40 | } 41 | 42 | 43 | for (Label l : oneVsAllClassificationOutput.getAllClasses()) { 44 | float score = oneVsAllClassificationOutput.getScore(l); 45 | float newScore = (float)Math.exp(score)/denom; 46 | 47 | res.addBinaryPrediction(l, newScore); 48 | } 49 | 50 | return res; 51 | } 52 | 53 | public HashMap getBinaryPlattNormalizers() { 54 | return binaryPlattNormalizers; 55 | } 56 | 57 | public void setBinaryPlattNormalizers(HashMap binaryPlattNormalizers) { 58 | this.binaryPlattNormalizers = binaryPlattNormalizers; 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputElement.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | public 
class PlattInputElement { 4 | 5 | private int label; 6 | private float value; 7 | 8 | public PlattInputElement(int label, float value) { 9 | super(); 10 | this.label = label; 11 | this.value = value; 12 | } 13 | 14 | public int getLabel() { 15 | return label; 16 | } 17 | 18 | public float getValue() { 19 | return value; 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputList.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.Vector; 4 | 5 | public class PlattInputList { 6 | 7 | private Vector list; 8 | private int positiveElement; 9 | private int negativeElement; 10 | 11 | public PlattInputList() { 12 | list = new Vector(); 13 | } 14 | 15 | public void add(PlattInputElement arg0) { 16 | if (arg0.getLabel() > 0) 17 | positiveElement++; 18 | else 19 | negativeElement++; 20 | 21 | list.add(arg0); 22 | } 23 | 24 | public PlattInputElement get(int index) { 25 | return list.get(index); 26 | } 27 | 28 | public int size() { 29 | return list.size(); 30 | } 31 | 32 | public int getPositiveElement() { 33 | return positiveElement; 34 | } 35 | 36 | public int getNegativeElement() { 37 | return negativeElement; 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattMethod.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 6 | import it.uniroma2.sag.kelp.data.example.Example; 7 | import 
it.uniroma2.sag.kelp.data.label.Label; 8 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 9 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 10 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 11 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 12 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 13 | 14 | public class PlattMethod { 15 | 16 | /** 17 | * Input parameters: 18 | * 19 | * deci = array of SVM decision values 20 | * 21 | * label = array of booleans: is the example labeled +1? 22 | * 23 | * prior1 = number of positive examples 24 | * 25 | * prior0 = number of negative examples 26 | * 27 | * Outputs: 28 | * 29 | * A, B = parameters of sigmoid 30 | * 31 | * @return 32 | **/ 33 | private static BinaryPlattNormalizer estimateSigmoid(float[] deci, float[] label, int prior1, int prior0) { 34 | 35 | /** 36 | * Parameter setting 37 | */ 38 | // Maximum number of iterations 39 | int maxiter = 100; 40 | // Minimum step taken in line search 41 | // minstep=1e-10; 42 | double minstep = 1e-10; 43 | double stopping = 1e-5; 44 | // Sigma: Set to any value > 0 45 | double sigma = 1e-12; 46 | // Construct initial values: target support in array t, 47 | // initial function value in fval 48 | double hiTarget = ((double) prior1 + 1.0f) / ((double) prior1 + 2.0f); 49 | double loTarget = 1 / (prior0 + 2.0f); 50 | 51 | int len = prior1 + prior0; // Total number of data 52 | double A; 53 | double B; 54 | 55 | double t[] = new double[len]; 56 | 57 | for (int i = 0; i < len; i++) { 58 | if (label[i] > 0) 59 | t[i] = hiTarget; 60 | else 61 | t[i] = loTarget; 62 | } 63 | 64 | A = 0; 65 | B = Math.log((prior0 + 1.0) / (prior1 + 1.0)); 66 | double fval = 0f; 67 | 68 | for (int i = 0; i < len; i++) { 69 | double fApB = deci[i] * A + B; 70 | if (fApB >= 0) 71 | fval += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 72 | else 73 | fval += (t[i] - 1) * fApB + 
Math.log(1 + Math.exp(fApB)); 74 | } 75 | 76 | int it = 1; 77 | for (it = 1; it <= maxiter; it++) { 78 | // Update Gradient and Hessian (use H� = H + sigma I) 79 | double h11 = sigma; 80 | double h22 = sigma; 81 | double h21 = 0; 82 | double g1 = 0; 83 | double g2 = 0; 84 | for (int i = 0; i < len; i++) { 85 | double fApB = deci[i] * A + B; 86 | double p; 87 | double q; 88 | if (fApB >= 0) { 89 | p = (Math.exp(-fApB) / (1.0 + Math.exp(-fApB))); 90 | q = (1.0 / (1.0 + Math.exp(-fApB))); 91 | } else { 92 | p = 1.0 / (1.0 + Math.exp(fApB)); 93 | q = Math.exp(fApB) / (1.0 + Math.exp(fApB)); 94 | } 95 | double d2 = p * q; 96 | h11 += deci[i] * deci[i] * d2; 97 | h22 += d2; 98 | h21 += deci[i] * d2; 99 | double d1 = t[i] - p; 100 | g1 += deci[i] * d1; 101 | g2 += d1; 102 | } 103 | if (Math.abs(g1) < stopping && Math.abs(g2) < stopping) // Stopping 104 | // criteria 105 | break; 106 | 107 | // Compute modified Newton directions 108 | double det = h11 * h22 - h21 * h21; 109 | double dA = -(h22 * g1 - h21 * g2) / det; 110 | double dB = -(-h21 * g1 + h11 * g2) / det; 111 | double gd = g1 * dA + g2 * dB; 112 | double stepsize = 1; 113 | 114 | while (stepsize >= minstep) { // Line search 115 | double newA = A + stepsize * dA; 116 | double newB = B + stepsize * dB; 117 | double newf = 0.0; 118 | for (int i = 0; i < len; i++) { 119 | double fApB = deci[i] * newA + newB; 120 | if (fApB >= 0) 121 | newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 122 | else 123 | newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); 124 | } 125 | 126 | if (newf < fval + 1e-4 * stepsize * gd) { 127 | A = newA; 128 | B = newB; 129 | fval = newf; 130 | break; // Sufficient decrease satisfied 131 | } else 132 | stepsize /= 2.0; 133 | } 134 | if (stepsize < minstep) { 135 | System.out.println("Line search fails"); 136 | break; 137 | } 138 | } 139 | if (it >= maxiter) 140 | System.out.println("Reaching maximum iterations"); 141 | 142 | return new BinaryPlattNormalizer((float) A, (float) B); 143 
| 144 | } 145 | 146 | public static BinaryPlattNormalizer esitmateSigmoid(SimpleDataset dataset, 147 | BinaryLearningAlgorithm binaryLearningAlgorithm, int nFolds) { 148 | 149 | PlattInputList plattInputList = new PlattInputList(); 150 | 151 | Label positiveLabel = binaryLearningAlgorithm.getLabel(); 152 | 153 | SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds); 154 | 155 | for (int f = 0; f < folds.length; f++) { 156 | 157 | SimpleDataset fold = folds[f]; 158 | 159 | SimpleDataset localTrainDataset = new SimpleDataset(); 160 | SimpleDataset localTestDataset = new SimpleDataset(); 161 | for (int i = 0; i < folds.length; i++) { 162 | if (i != f) { 163 | localTrainDataset.addExamples(fold); 164 | } else { 165 | localTestDataset.addExamples(fold); 166 | } 167 | } 168 | 169 | LearningAlgorithm duplicatedLearningAlgorithm = binaryLearningAlgorithm.duplicate(); 170 | 171 | duplicatedLearningAlgorithm.learn(fold); 172 | 173 | PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction(); 174 | 175 | for (Example example : localTestDataset.getExamples()) { 176 | Prediction predict = predictionFunction.predict(example); 177 | 178 | float value = predict.getScore(positiveLabel); 179 | 180 | int label = 1; 181 | if (!example.isExampleOf(positiveLabel)) 182 | label = -1; 183 | plattInputList.add(new PlattInputElement(label, value)); 184 | } 185 | } 186 | 187 | return estimateSigmoid(plattInputList); 188 | } 189 | 190 | public static MulticlassPlattNormalizer esitmateSigmoid(SimpleDataset dataset, OneVsAllLearning oneVsAllLearning, 191 | int nFolds) { 192 | 193 | HashMap plattInputLists = new HashMap(); 194 | for(Label label: dataset.getClassificationLabels()){ 195 | plattInputLists.put(label, new PlattInputList()); 196 | } 197 | 198 | SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds); 199 | 200 | MulticlassPlattNormalizer res = new MulticlassPlattNormalizer(); 201 | 202 | for (int f = 0; f < folds.length; 
f++) { 203 | 204 | SimpleDataset fold = folds[f]; 205 | 206 | SimpleDataset localTrainDataset = new SimpleDataset(); 207 | SimpleDataset localTestDataset = new SimpleDataset(); 208 | for (int i = 0; i < folds.length; i++) { 209 | if (i != f) { 210 | localTrainDataset.addExamples(fold); 211 | } else { 212 | localTestDataset.addExamples(fold); 213 | } 214 | } 215 | 216 | LearningAlgorithm duplicatedLearningAlgorithm = oneVsAllLearning.duplicate(); 217 | 218 | duplicatedLearningAlgorithm.learn(fold); 219 | 220 | PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction(); 221 | 222 | for (Example example : localTestDataset.getExamples()) { 223 | Prediction predict = predictionFunction.predict(example); 224 | 225 | for (Label label : dataset.getClassificationLabels()) { 226 | 227 | float valueOfLabel = predict.getScore(label); 228 | 229 | int binaryLabel = 1; 230 | if (!example.isExampleOf(label)) 231 | binaryLabel = -1; 232 | plattInputLists.get(label).add(new PlattInputElement(binaryLabel, valueOfLabel)); 233 | } 234 | } 235 | } 236 | 237 | for (Label label : dataset.getClassificationLabels()) { 238 | res.addBinaryPlattNormalizer(label, estimateSigmoid(plattInputLists.get(label))); 239 | } 240 | 241 | return res; 242 | } 243 | 244 | protected static BinaryPlattNormalizer estimateSigmoid(PlattInputList inputList) { 245 | float[] deci = new float[inputList.size()]; 246 | float[] label = new float[inputList.size()]; 247 | int prior1 = inputList.getPositiveElement(); 248 | int prior0 = inputList.getNegativeElement(); 249 | 250 | for (int i = 0; i < inputList.size(); i++) { 251 | deci[i] = inputList.get(i).getValue(); 252 | label[i] = inputList.get(i).getLabel(); 253 | } 254 | 255 | return estimateSigmoid(deci, label, prior1, prior0); 256 | } 257 | 258 | } 259 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/scw/SCWType.java: 
/**
 * The two types of Soft Confidence-Weighted implemented variants
 *
 * @author Danilo Croce
 *
 */
public enum SCWType {

	// The two variants presumably correspond to the SCW-I and SCW-II update
	// rules of the Soft Confidence-Weighted algorithm — TODO confirm against
	// SoftConfidenceWeightedClassification, which consumes this enum.
	SCW_I, SCW_II

}
12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 21 | import it.uniroma2.sag.kelp.data.example.Example; 22 | 23 | @JsonTypeName("kernelbasedkmeansexample") 24 | public class KernelBasedKMeansExample extends ClusterExample { 25 | 26 | /** 27 | * 28 | */ 29 | private static final long serialVersionUID = -5368757832244686390L; 30 | 31 | public KernelBasedKMeansExample() { 32 | super(); 33 | } 34 | 35 | public KernelBasedKMeansExample(Example e, float dist) { 36 | super(e, dist); 37 | } 38 | 39 | @Override 40 | public Example getExample() { 41 | return example; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/liblinear/LibLinearRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvcFunction; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvrFunction; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem; 25 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem.LibLinearSolverType; 26 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Tron; 27 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 28 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 29 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 30 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 31 | 32 | import java.util.Arrays; 33 | import java.util.List; 34 | 35 | import com.fasterxml.jackson.annotation.JsonIgnore; 36 | import com.fasterxml.jackson.annotation.JsonTypeName; 37 | 38 | /** 39 | * This class implements linear SVM regression trained using a coordinate descent 40 | * algorithm [Fan et al, 2008]. It operates in an explicit feature space (i.e. 41 | * it does not relies on any kernel). This code has been adapted from the Java 42 | * port of the original LIBLINEAR C++ sources. 43 | * 44 | * Further details can be found in: 45 | * 46 | * [Fan et al, 2008] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J. 47 | * Lin. LIBLINEAR: A Library for Large Linear Classification, Journal of Machine 48 | * Learning Research 9(2008), 1871-1874. 
Software available at 49 | * 50 | * The original LIBLINEAR code: 51 | * http://www.csie.ntu.edu.tw/~cjlin/liblinear 52 | * 53 | * The original JAVA porting (v 1.94): http://liblinear.bwaldvogel.de 54 | * 55 | * @author Danilo Croce 56 | */ 57 | @JsonTypeName("liblinearregression") 58 | public class LibLinearRegression implements LinearMethod, 59 | RegressionLearningAlgorithm, BinaryLearningAlgorithm { 60 | 61 | /** 62 | * The property corresponding to the variable to be learned 63 | */ 64 | private Label label; 65 | /** 66 | * The regularization parameter 67 | */ 68 | private double c = 1; 69 | 70 | /** 71 | * The regressor to be returned 72 | */ 73 | @JsonIgnore 74 | private UnivariateLinearRegressionFunction regressionFunction; 75 | 76 | /** 77 | * The epsilon in loss function of SVR (default 0.1) 78 | */ 79 | private double p = 0.1f; 80 | 81 | /** 82 | * The identifier of the representation to be considered for the training 83 | * step 84 | */ 85 | private String representation; 86 | 87 | /** 88 | * @param label 89 | * The regression property to be learned 90 | * @param c 91 | * The regularization parameter 92 | * 93 | * @param p 94 | * The The epsilon in loss function of SVR 95 | * 96 | * @param representationName 97 | * The identifier of the representation to be considered for the 98 | * training step 99 | */ 100 | public LibLinearRegression(Label label, double c, double p, 101 | String representationName) { 102 | this(); 103 | 104 | this.setLabel(label); 105 | this.c = c; 106 | this.p = p; 107 | this.setRepresentation(representationName); 108 | } 109 | 110 | /** 111 | * @param c 112 | * The regularization parameter 113 | * 114 | * @param representationName 115 | * The identifier of the representation to be considered for the 116 | * training step 117 | */ 118 | public LibLinearRegression(double c, double p, String representationName) { 119 | this(); 120 | this.c = c; 121 | this.p = p; 122 | this.setRepresentation(representationName); 123 | } 124 | 125 | 
public LibLinearRegression() { 126 | this.regressionFunction = new UnivariateLinearRegressionFunction(); 127 | this.regressionFunction.setModel(new BinaryLinearModel()); 128 | } 129 | 130 | /** 131 | * @return the regularization parameter 132 | */ 133 | public double getC() { 134 | return c; 135 | } 136 | 137 | /** 138 | * @param c 139 | * the regularization parameter 140 | */ 141 | public void setC(double c) { 142 | this.c = c; 143 | } 144 | 145 | /** 146 | * @return the epsilon in loss function 147 | */ 148 | public double getP() { 149 | return p; 150 | } 151 | 152 | /** 153 | * @param p 154 | * the epsilon in loss function 155 | */ 156 | public void setP(double p) { 157 | this.p = p; 158 | } 159 | 160 | /* 161 | * (non-Javadoc) 162 | * 163 | * @see 164 | * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#getRepresentation() 165 | */ 166 | @Override 167 | public String getRepresentation() { 168 | return representation; 169 | } 170 | 171 | /* 172 | * (non-Javadoc) 173 | * 174 | * @see 175 | * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#setRepresentation 176 | * (java.lang.String) 177 | */ 178 | @Override 179 | public void setRepresentation(String representation) { 180 | this.representation = representation; 181 | BinaryLinearModel model = this.regressionFunction.getModel(); 182 | model.setRepresentation(representation); 183 | } 184 | 185 | /* 186 | * (non-Javadoc) 187 | * 188 | * @see 189 | * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#setLabels(java 190 | * .util.List) 191 | */ 192 | @Override 193 | public void setLabels(List labels) { 194 | if (labels.size() != 1) { 195 | throw new IllegalArgumentException( 196 | "LibLinear algorithm is a binary method which can learn a single Label"); 197 | } else { 198 | this.label = labels.get(0); 199 | this.regressionFunction.setLabels(labels); 200 | } 201 | } 202 | 203 | /* 204 | * (non-Javadoc) 205 | * 206 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#getLabels() 207 | */ 208 | 
@Override 209 | public List getLabels() { 210 | return Arrays.asList(label); 211 | } 212 | 213 | /* 214 | * (non-Javadoc) 215 | * 216 | * @see 217 | * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#getLabel() 218 | */ 219 | @Override 220 | public Label getLabel() { 221 | return this.label; 222 | } 223 | 224 | /* 225 | * (non-Javadoc) 226 | * 227 | * @see 228 | * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#setLabel 229 | * (it.uniroma2.sag.kelp.data.label.Label) 230 | */ 231 | @Override 232 | public void setLabel(Label label) { 233 | this.setLabels(Arrays.asList(label)); 234 | } 235 | 236 | /* 237 | * (non-Javadoc) 238 | * 239 | * @see 240 | * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#learn(it.uniroma2 241 | * .sag.kelp.data.dataset.Dataset) 242 | */ 243 | @Override 244 | public void learn(Dataset dataset) { 245 | 246 | double eps = 0.001; 247 | 248 | int l = dataset.getNumberOfExamples(); 249 | 250 | double[] C = new double[l]; 251 | for (int i = 0; i < l; i++) { 252 | C[i] = c; 253 | } 254 | 255 | Problem problem = new Problem(dataset, representation, label, 256 | LibLinearSolverType.REGRESSION); 257 | 258 | L2R_L2_SvcFunction fun_obj = new L2R_L2_SvrFunction(problem, C, p); 259 | 260 | Tron tron = new Tron(fun_obj, eps); 261 | 262 | double[] w = new double[problem.n]; 263 | tron.tron(w); 264 | 265 | this.regressionFunction.getModel().setHyperplane(problem.getW(w)); 266 | this.regressionFunction.getModel().setRepresentation(representation); 267 | this.regressionFunction.getModel().setBias(0); 268 | } 269 | 270 | /* 271 | * (non-Javadoc) 272 | * 273 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#duplicate() 274 | */ 275 | @Override 276 | public LibLinearRegression duplicate() { 277 | LibLinearRegression copy = new LibLinearRegression(); 278 | copy.setRepresentation(representation); 279 | copy.setC(c); 280 | copy.setP(p); 281 | return copy; 282 | } 283 | 284 | /* 285 | * (non-Javadoc) 286 | * 287 | 
* @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#reset() 288 | */ 289 | @Override 290 | public void reset() { 291 | this.regressionFunction.reset(); 292 | } 293 | 294 | @Override 295 | public UnivariateLinearRegressionFunction getPredictionFunction() { 296 | return regressionFunction; 297 | } 298 | 299 | @Override 300 | public void setPredictionFunction(PredictionFunction predictionFunction) { 301 | this.regressionFunction = (UnivariateLinearRegressionFunction) predictionFunction; 302 | } 303 | 304 | } 305 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/KernelizedPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.kernel.Kernel; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateKernelMachineRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (kernel machine version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("kernelizedPA-R") 37 | public class KernelizedPassiveAggressiveRegression extends PassiveAggressiveRegression implements KernelMethod{ 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPassiveAggressiveRegression(){ 42 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 43 | } 44 | 45 | public KernelizedPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, Kernel kernel, Label label){ 46 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 47 | this.setC(aggressiveness); 48 | this.setEpsilon(epsilon); 49 | this.setPolicy(policy); 50 | this.setKernel(kernel); 51 | this.setLabel(label); 52 | } 53 | 54 | @Override 55 | public Kernel getKernel(){ 56 | return kernel; 57 | } 58 | 59 | @Override 60 | public void setKernel(Kernel kernel) { 61 | this.kernel = kernel; 62 | this.getPredictionFunction().getModel().setKernel(kernel); 63 | } 64 | 65 | @Override 66 | public KernelizedPassiveAggressiveRegression duplicate() { 67 | KernelizedPassiveAggressiveRegression copy = new KernelizedPassiveAggressiveRegression(); 
68 | copy.setC(this.c); 69 | copy.setKernel(this.kernel); 70 | copy.setPolicy(this.policy); 71 | copy.setEpsilon(epsilon); 72 | return copy; 73 | } 74 | 75 | @Override 76 | public UnivariateKernelMachineRegressionFunction getPredictionFunction(){ 77 | return (UnivariateKernelMachineRegressionFunction) this.regressor; 78 | } 79 | 80 | @Override 81 | public void setPredictionFunction(PredictionFunction predictionFunction) { 82 | this.regressor = (UnivariateKernelMachineRegressionFunction) predictionFunction; 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/LinearPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (linear version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("linearPA-R") 37 | public class LinearPassiveAggressiveRegression extends PassiveAggressiveRegression implements LinearMethod{ 38 | 39 | private String representation; 40 | 41 | public LinearPassiveAggressiveRegression(){ 42 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 43 | regressor.setModel(new BinaryLinearModel()); 44 | this.regressor = regressor; 45 | 46 | } 47 | 48 | public LinearPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, String representation, Label label){ 49 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 50 | regressor.setModel(new BinaryLinearModel()); 51 | this.regressor = regressor; 52 | this.setC(aggressiveness); 53 | this.setEpsilon(epsilon); 54 | this.setPolicy(policy); 55 | this.setRepresentation(representation); 56 | this.setLabel(label); 57 | } 58 | 59 | @Override 60 | public LinearPassiveAggressiveRegression duplicate() { 61 | LinearPassiveAggressiveRegression copy = new LinearPassiveAggressiveRegression(); 62 | 
copy.setC(this.c); 63 | copy.setRepresentation(this.representation); 64 | copy.setPolicy(this.policy); 65 | copy.setEpsilon(epsilon); 66 | return copy; 67 | } 68 | 69 | @Override 70 | public String getRepresentation() { 71 | return representation; 72 | } 73 | 74 | @Override 75 | public void setRepresentation(String representation) { 76 | this.representation = representation; 77 | this.getPredictionFunction().getModel().setRepresentation(representation); 78 | } 79 | 80 | @Override 81 | public UnivariateLinearRegressionFunction getPredictionFunction(){ 82 | return (UnivariateLinearRegressionFunction) this.regressor; 83 | } 84 | 85 | @Override 86 | public void setPredictionFunction(PredictionFunction predictionFunction) { 87 | this.regressor = (UnivariateLinearRegressionFunction) predictionFunction; 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/PassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput; 23 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionFunction; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for regression tasks. 29 | * 30 | * reference: 31 | * 32 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 33 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 34 | * 35 | * @author Simone Filice 36 | */ 37 | public abstract class PassiveAggressiveRegression extends PassiveAggressive implements RegressionLearningAlgorithm{ 38 | 39 | @JsonIgnore 40 | protected UnivariateRegressionFunction regressor; 41 | 42 | protected float epsilon; 43 | 44 | /** 45 | * Returns epsilon, i.e. the accepted distance between the predicted and the real regression values 46 | * 47 | * @return the epsilon 48 | */ 49 | public float getEpsilon() { 50 | return epsilon; 51 | } 52 | 53 | /** 54 | * Sets epsilon, i.e. 
the accepted distance between the predicted and the real regression values 55 | * 56 | * @param epsilon the epsilon to set 57 | */ 58 | public void setEpsilon(float epsilon) { 59 | this.epsilon = epsilon; 60 | } 61 | 62 | @Override 63 | public UnivariateRegressionFunction getPredictionFunction() { 64 | return this.regressor; 65 | } 66 | 67 | @Override 68 | public void learn(Dataset dataset){ 69 | 70 | while(dataset.hasNextExample()){ 71 | Example example = dataset.getNextExample(); 72 | this.learn(example); 73 | } 74 | dataset.reset(); 75 | } 76 | 77 | @Override 78 | public UnivariateRegressionOutput learn(Example example){ 79 | UnivariateRegressionOutput prediction=this.regressor.predict(example); 80 | float difference = example.getRegressionValue(label) - prediction.getScore(label); 81 | float lossValue = Math.abs(difference) - epsilon;//it represents the distance from the correct semi-space 82 | if(lossValue>0){ 83 | float exampleSquaredNorm = this.regressor.getModel().getSquaredNorm(example); 84 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm, c); 85 | if(difference<0){ 86 | weight = -weight; 87 | } 88 | this.regressor.getModel().addExample(weight, example); 89 | } 90 | return prediction; 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/linearization/LinearizationFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.linearization; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.representation.Vector; 22 | 23 | /** 24 | * This interface allows implementing function to linearized examples through 25 | * linear representations, i.e. vectors 26 | * 27 | * 28 | * @author Danilo Croce 29 | * 30 | */ 31 | public interface LinearizationFunction { 32 | 33 | /** 34 | * Given an input Example, this method generates a linear 35 | * Representation>, i.e. a Vector. 36 | * 37 | * @param example 38 | * The input example. 39 | * @return The linearized representation of the input example. 40 | */ 41 | public Vector getLinearRepresentation(Example example); 42 | 43 | /** 44 | * This method linearizes an input example, providing a new example 45 | * containing only a representation with a specific name, provided as input. 46 | * The produced example inherits the labels of the input example. 47 | * 48 | * @param example 49 | * The input example. 50 | * @param vectorName 51 | * The name of the linear representation inside the new example 52 | * @return 53 | */ 54 | public Example getLinearizedExample(Example example, String representationName); 55 | 56 | /** 57 | * This method linearizes all the examples in the input dataset 58 | * , generating a corresponding linearized dataset. 
The produced examples 59 | * inherit the labels of the corresponding input examples. 60 | * 61 | * @param dataset 62 | * The input dataset 63 | * @param representationName 64 | * The name of the linear representation inside the new examples 65 | * @return 66 | */ 67 | public SimpleDataset getLinearizedDataset(Dataset dataset, String representationName); 68 | 69 | /** 70 | * @return the size of the resulting embedding, i.e. the number of resulting 71 | * vector dimensions 72 | */ 73 | public int getEmbeddingSize(); 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/SequencePrediction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.example.SequencePath; 22 | import it.uniroma2.sag.kelp.data.label.Label; 23 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 24 | 25 | /** 26 | * It is a output provided by a machine learning systems on a sequence. 
This 27 | * specific implementation allows to assign multiple labelings to single 28 | * sequence, useful for some labeling strategies, such as Beam Search. Notice 29 | * that each labeling requires a score to select the more promising labeling. 30 | * 31 | * @author Danilo Croce 32 | * 33 | */ 34 | public class SequencePrediction implements Prediction { 35 | 36 | /** 37 | * 38 | */ 39 | private static final long serialVersionUID = -1040539866977906008L; 40 | /** 41 | * This list contains multiple labelings to be assigned to a single sequence 42 | */ 43 | private List paths; 44 | 45 | public SequencePrediction() { 46 | paths = new ArrayList(); 47 | } 48 | 49 | /** 50 | * @return The best path, i.e., the labeling with the highest score in the 51 | * list of labelings provided by a classifier 52 | */ 53 | public SequencePath bestPath() { 54 | return paths.get(0); 55 | } 56 | 57 | /** 58 | * @return a list containing multiple labelings to be assigned to a single 59 | * sequence 60 | */ 61 | public List getPaths() { 62 | return paths; 63 | } 64 | 65 | @Override 66 | public Float getScore(Label label) { 67 | return null; 68 | } 69 | 70 | /** 71 | * @param paths 72 | * a list contains multiple labelings to be assigned to a single 73 | * sequence 74 | */ 75 | public void setPaths(List paths) { 76 | this.paths = paths; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | StringBuilder sb = new StringBuilder(); 82 | for (int i = 0; i < paths.size(); i++) { 83 | if (i == 0) 84 | sb.append("Best Path\t"); 85 | else 86 | sb.append("Altern. 
Path\t"); 87 | SequencePath sequencePath = paths.get(i); 88 | sb.append(sequencePath + "\n"); 89 | } 90 | return sb.toString(); 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/model/SequenceModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction.model; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 20 | 21 | /** 22 | * This class implements a model produced by a 23 | * SequenceClassificationLearningAlgorithm 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class SequenceModel implements Model { 29 | 30 | /** 31 | * 32 | */ 33 | private static final long serialVersionUID = -2749198158786953940L; 34 | 35 | /** 36 | * The prediction function producing the emission scores to be considered in 37 | * the Viterbi Decoding 38 | */ 39 | private PredictionFunction basePredictionFunction; 40 | 41 | private SequenceExampleGenerator sequenceExampleGenerator; 42 | 43 | public SequenceModel() { 44 | super(); 45 | } 46 | 47 | public SequenceModel(PredictionFunction basePredictionFunction, SequenceExampleGenerator sequenceExampleGenerator) { 48 | super(); 49 | this.basePredictionFunction = basePredictionFunction; 50 | this.sequenceExampleGenerator = sequenceExampleGenerator; 51 | } 52 | 53 | public PredictionFunction getBasePredictionFunction() { 54 | return basePredictionFunction; 55 | } 56 | 57 | public SequenceExampleGenerator getSequenceExampleGenerator() { 58 | return sequenceExampleGenerator; 59 | } 60 | 61 | @Override 62 | public void reset() { 63 | } 64 | 65 | public void setBasePredictionFunction(PredictionFunction basePredictionFunction) { 66 | this.basePredictionFunction = basePredictionFunction; 67 | } 68 | 69 | public void setSequenceExampleGenerator(SequenceExampleGenerator sequenceExampleGenerator) { 70 | this.sequenceExampleGenerator = sequenceExampleGenerator; 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/ClusteringEvaluator.java: -------------------------------------------------------------------------------- 
1 | package it.uniroma2.sag.kelp.utils.evaluation; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashSet; 5 | import java.util.TreeMap; 6 | 7 | import it.uniroma2.sag.kelp.data.clustering.Cluster; 8 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 9 | import it.uniroma2.sag.kelp.data.clustering.ClusterList; 10 | import it.uniroma2.sag.kelp.data.example.Example; 11 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 12 | import it.uniroma2.sag.kelp.data.label.Label; 13 | import it.uniroma2.sag.kelp.data.label.StringLabel; 14 | import it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans.KernelBasedKMeansExample; 15 | 16 | /** 17 | * 18 | * Implements Evaluation methods for clustering algorithms. 19 | * 20 | * More details about Purity and NMI can be found here: 21 | * 22 | * https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1. 23 | * html 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class ClusteringEvaluator { 29 | 30 | public static float getPurity(ClusterList clusters) { 31 | 32 | float res = 0; 33 | int k = clusters.size(); 34 | 35 | for (int clustId = 0; clustId < k; clustId++) { 36 | 37 | TreeMap classSizes = new TreeMap(); 38 | 39 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 40 | HashSet labels = vce.getExample().getClassificationLabels(); 41 | for (Label label : labels) 42 | if (!classSizes.containsKey(label)) 43 | classSizes.put(label, 1); 44 | else 45 | classSizes.put(label, classSizes.get(label) + 1); 46 | } 47 | 48 | int maxSize = 0; 49 | for (int size : classSizes.values()) { 50 | if (size > maxSize) { 51 | maxSize = size; 52 | } 53 | } 54 | res += maxSize; 55 | } 56 | 57 | return res / (float) clusters.getNumberOfExamples(); 58 | } 59 | 60 | public static float getMI(ClusterList clusters) { 61 | 62 | float res = 0; 63 | 64 | float N = clusters.getNumberOfExamples(); 65 | 66 | int k = clusters.size(); 67 | 68 | TreeMap classCardinality = 
getClassCardinality(clusters); 69 | 70 | for (int clustId = 0; clustId < k; clustId++) { 71 | 72 | TreeMap classSizes = getClassCardinalityWithinCluster(clusters, clustId); 73 | 74 | for (Label className : classSizes.keySet()) { 75 | int wSize = classSizes.get(className); 76 | res += ((float) wSize / N) * myLog(N * (float) wSize 77 | / (clusters.get(clustId).getExamples().size() * (float) classCardinality.get(className))); 78 | } 79 | 80 | } 81 | 82 | return res; 83 | 84 | } 85 | 86 | private static TreeMap getClassCardinalityWithinCluster(ClusterList clusters, int clustId) { 87 | 88 | TreeMap classSizes = new TreeMap(); 89 | 90 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 91 | HashSet labels = vce.getExample().getClassificationLabels(); 92 | for (Label label : labels) 93 | if (!classSizes.containsKey(label)) 94 | classSizes.put(label, 1); 95 | else 96 | classSizes.put(label, classSizes.get(label) + 1); 97 | } 98 | 99 | return classSizes; 100 | } 101 | 102 | private static float getClusterEntropy(ClusterList clusters) { 103 | 104 | float res = 0; 105 | float N = clusters.getNumberOfExamples(); 106 | int k = clusters.size(); 107 | 108 | for (int clustId = 0; clustId < k; clustId++) { 109 | int clusterElementSize = clusters.get(clustId).getExamples().size(); 110 | if (clusterElementSize != 0) 111 | res -= ((float) clusterElementSize / N) * myLog((float) clusterElementSize / N); 112 | } 113 | return res; 114 | 115 | } 116 | 117 | private static float getClassEntropy(ClusterList clusters) { 118 | 119 | float res = 0; 120 | float N = clusters.getNumberOfExamples(); 121 | 122 | TreeMap classCardinality = getClassCardinality(clusters); 123 | 124 | for (int classSize : classCardinality.values()) { 125 | res -= ((float) classSize / N) * myLog((float) classSize / N); 126 | } 127 | return res; 128 | 129 | } 130 | 131 | private static float myLog(float f) { 132 | return (float) (Math.log(f) / Math.log(2f)); 133 | } 134 | 135 | private static TreeMap 
getClassCardinality(ClusterList clusters) { 136 | TreeMap classSizes = new TreeMap(); 137 | 138 | int k = clusters.size(); 139 | 140 | for (int clustId = 0; clustId < k; clustId++) { 141 | 142 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 143 | HashSet labels = vce.getExample().getClassificationLabels(); 144 | for (Label label : labels) 145 | if (!classSizes.containsKey(label)) 146 | classSizes.put(label, 1); 147 | else 148 | classSizes.put(label, classSizes.get(label) + 1); 149 | } 150 | } 151 | return classSizes; 152 | } 153 | 154 | public static float getNMI(ClusterList clusters) { 155 | return getMI(clusters) / ((getClusterEntropy(clusters) + getClassEntropy(clusters)) / 2f); 156 | } 157 | 158 | public static String getStatistics(ClusterList clusters) { 159 | StringBuilder sb = new StringBuilder(); 160 | 161 | sb.append("Purity:\t" + getPurity(clusters) + "\n"); 162 | sb.append("Mutual Information:\t" + getMI(clusters) + "\n"); 163 | sb.append("Cluster Entropy:\t" + getClusterEntropy(clusters) + "\n"); 164 | sb.append("Class Entropy:\t" + getClassEntropy(clusters) + "\n"); 165 | sb.append("NMI:\t" + getNMI(clusters)); 166 | 167 | return sb.toString(); 168 | } 169 | 170 | public static void main(String[] args) { 171 | ClusterList clusters = new ClusterList(); 172 | 173 | Cluster c1 = new Cluster("C1"); 174 | ArrayList list1 = new ArrayList(); 175 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 176 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 177 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 178 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 179 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 180 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 181 | for (Example e : list1) { 182 | c1.add(new KernelBasedKMeansExample(e, 1f)); 183 | } 184 
| 185 | Cluster c2 = new Cluster("C2"); 186 | ArrayList list2 = new ArrayList(); 187 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 188 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 189 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 190 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 191 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 192 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 193 | for (Example e : list2) { 194 | c2.add(new KernelBasedKMeansExample(e, 1f)); 195 | } 196 | 197 | Cluster c3 = new Cluster("C3"); 198 | ArrayList list3 = new ArrayList(); 199 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 200 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 201 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 202 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 203 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 204 | for (Example e : list3) { 205 | c3.add(new KernelBasedKMeansExample(e, 1f)); 206 | } 207 | 208 | clusters.add(c1); 209 | clusters.add(c2); 210 | clusters.add(c3); 211 | 212 | System.out.println(ClusteringEvaluator.getStatistics(clusters)); 213 | 214 | //From https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html 215 | //Purity = 0.71 216 | //NMI = 0.36 217 | 218 | } 219 | 220 | } 221 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/MulticlassSequenceClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * 
Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.utils.evaluation;

import java.util.List;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.SequenceEmission;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction;

/**
 * This is an instance of an Evaluator. It allows to compute some common
 * measures for classification tasks acting over SequenceExamples. It
 * computes precision, recall, f1s for each class, and a global accuracy.
 *
 * @author Danilo Croce
 */
public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator {

	/**
	 * Initialize a new evaluator that will work on the specified classes
	 *
	 * @param labels the classes to be evaluated
	 */
	public MulticlassSequenceClassificationEvaluator(List<Label> labels) {
		super(labels);
	}

	public void addCount(Example test, Prediction prediction) {
		addCount((SequenceExample) test, (SequencePrediction) prediction);
	}

	/**
	 * Updates the counters used to compute the performance measures. Each
	 * element of the test sequence is compared against the corresponding
	 * label in the best path returned by the classifier.
	 *
	 * @param test
	 *            the test example
	 * @param predicted
	 *            the prediction of the system
	 */
	public void addCount(SequenceExample test, SequencePrediction predicted) {

		SequencePath bestPath = predicted.bestPath();

		for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) {

			Example testItem = test.getExample(seqIdx);
			SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx);

			for (Label l : this.labels) {
				ClassStats stats = this.classStats.get(l);
				if (testItem.isExampleOf(l)) {
					if (sequenceLabel.getLabel().equals(l)) {
						stats.tp++;
						totalTp++;
					} else {
						stats.fn++;
						totalFn++;
					}
				} else {
					if (sequenceLabel.getLabel().equals(l)) {
						stats.fp++;
						totalFp++;
					} else {
						stats.tn++;
						totalTn++;
					}
				}

			}

			// TODO: check (i) whether it is correct to evaluate the accuracy on the
			// single elements of the sequence rather than on the complete sequence;
			// (ii) the multi-label case should be handled
			total++;

			if (testItem.isExampleOf(sequenceLabel.getLabel())) {
				correct++;
			}

			// invalidate any cached aggregate measures after updating counters
			this.computed = false;
		}
	}

}

--------------------------------------------------------------------------------
/src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 
33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = 
new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | 
evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
[Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 31 | * 32 | * @author Simone Filice 33 | * 34 | */ 35 | @JsonTypeName("kernelizedPerceptron") 36 | public class KernelizedPerceptron extends Perceptron implements KernelMethod{ 37 | 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPerceptron(){ 42 | this.classifier = new BinaryKernelMachineClassifier(); 43 | this.classifier.setModel(new BinaryKernelMachineModel()); 44 | } 45 | 46 | public KernelizedPerceptron(float alpha, float margin, boolean unbiased, Kernel kernel, Label label){ 47 | this.classifier = new BinaryKernelMachineClassifier(); 48 | this.classifier.setModel(new BinaryKernelMachineModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setKernel(kernel); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public Kernel getKernel() { 58 | return kernel; 59 | } 60 | 61 | @Override 62 | public void setKernel(Kernel kernel) { 63 | this.kernel = kernel; 64 | this.getPredictionFunction().getModel().setKernel(kernel); 65 | } 66 | 67 | @Override 68 | public KernelizedPerceptron duplicate(){ 69 | KernelizedPerceptron copy = new KernelizedPerceptron(); 70 | copy.setKernel(this.kernel); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setUnbiased(this.unbiased); 74 | return copy; 75 | } 76 | 77 | @Override 78 | public BinaryKernelMachineClassifier getPredictionFunction(){ 79 | return (BinaryKernelMachineClassifier) this.classifier; 80 | } 81 | 82 | @Override 83 | public void setPredictionFunction(PredictionFunction predictionFunction) { 84 | this.classifier = (BinaryKernelMachineClassifier) predictionFunction; 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/LinearPerceptron.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | 19 | import com.fasterxml.jackson.annotation.JsonTypeName; 20 | 21 | import it.uniroma2.sag.kelp.data.label.Label; 22 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier; 25 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 26 | 27 | /** 28 | * The perceptron learning algorithm algorithm for classification tasks (linear version). Reference: 29 | *
[Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 30 | * 31 | * @author Simone Filice 32 | * 33 | */ 34 | @JsonTypeName("linearPerceptron") 35 | public class LinearPerceptron extends Perceptron implements LinearMethod{ 36 | 37 | 38 | private String representation; 39 | 40 | 41 | public LinearPerceptron(){ 42 | this.classifier = new BinaryLinearClassifier(); 43 | this.classifier.setModel(new BinaryLinearModel()); 44 | } 45 | 46 | public LinearPerceptron(float alpha, float margin, boolean unbiased, String representation, Label label){ 47 | this.classifier = new BinaryLinearClassifier(); 48 | this.classifier.setModel(new BinaryLinearModel()); 49 | this.setAlpha(alpha); 50 | this.setMargin(margin); 51 | this.setUnbiased(unbiased); 52 | this.setRepresentation(representation); 53 | this.setLabel(label); 54 | } 55 | 56 | @Override 57 | public String getRepresentation() { 58 | return representation; 59 | } 60 | 61 | @Override 62 | public void setRepresentation(String representation) { 63 | this.representation = representation; 64 | BinaryLinearModel model = (BinaryLinearModel) this.classifier.getModel(); 65 | model.setRepresentation(representation); 66 | } 67 | 68 | @Override 69 | public LinearPerceptron duplicate(){ 70 | LinearPerceptron copy = new LinearPerceptron(); 71 | copy.setAlpha(this.alpha); 72 | copy.setMargin(this.margin); 73 | copy.setRepresentation(representation); 74 | copy.setUnbiased(this.unbiased); 75 | return copy; 76 | } 77 | 78 | @Override 79 | public BinaryLinearClassifier getPredictionFunction(){ 80 | return (BinaryLinearClassifier) this.classifier; 81 | } 82 | 83 | @Override 84 | public void setPredictionFunction(PredictionFunction predictionFunction) { 85 | this.classifier = (BinaryLinearClassifier) predictionFunction; 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/Perceptron.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron; 17 | 18 | import java.util.Arrays; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 22 | import it.uniroma2.sag.kelp.data.example.Example; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 25 | import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm; 26 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 27 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier; 28 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput; 29 | 30 | import com.fasterxml.jackson.annotation.JsonIgnore; 31 | 32 | /** 33 | * The perceptron learning algorithm algorithm for classification tasks. Reference: 34 | *
[Rosenblatt1957] F. Rosenblatt. The Perceptron – a perceiving and recognizing automaton. Report 85-460-1, Cornell Aeronautical Laboratory (1957) 35 | * 36 | * @author Simone Filice 37 | * 38 | */ 39 | public abstract class Perceptron implements ClassificationLearningAlgorithm, OnlineLearningAlgorithm, BinaryLearningAlgorithm{ 40 | 41 | @JsonIgnore 42 | protected BinaryClassifier classifier; 43 | 44 | protected Label label; 45 | 46 | protected float alpha=1; 47 | protected float margin = 1; 48 | protected boolean unbiased=false; 49 | 50 | /** 51 | * Returns the learning rate, i.e. the weight associated to misclassified examples during the learning process 52 | * 53 | * @return the learning rate 54 | */ 55 | public float getAlpha() { 56 | return alpha; 57 | } 58 | 59 | /** 60 | * Sets the learning rate, i.e. the weight associated to misclassified examples during the learning process 61 | * 62 | * @param alpha the learning rate to set 63 | */ 64 | public void setAlpha(float alpha) { 65 | if(alpha<=0 || alpha>1){ 66 | throw new IllegalArgumentException("Invalid learning rate for the perceptron algorithm: valid alphas in (0,1]"); 67 | } 68 | this.alpha = alpha; 69 | } 70 | 71 | /** 72 | * Returns the desired margin, i.e. the minimum distance from the hyperplane that an example must have 73 | * in order to be not considered misclassified 74 | * 75 | * @return the margin 76 | */ 77 | public float getMargin() { 78 | return margin; 79 | } 80 | 81 | /** 82 | * Sets the desired margin, i.e. the minimum distance from the hyperplane that an example must have 83 | * in order to be not considered misclassified 84 | * 85 | * @param margin the margin to set 86 | */ 87 | public void setMargin(float margin) { 88 | this.margin = margin; 89 | } 90 | 91 | /** 92 | * Returns whether the bias, i.e. 
the constant term of the hyperplane, is always 0, or can be modified during 93 | * the learning process 94 | * 95 | * @return the unbiased 96 | */ 97 | public boolean isUnbiased() { 98 | return unbiased; 99 | } 100 | 101 | /** 102 | * Sets whether the bias, i.e. the constant term of the hyperplane, is always 0, or can be modified during 103 | * the learning process 104 | * 105 | * @param unbiased the unbiased to set 106 | */ 107 | public void setUnbiased(boolean unbiased) { 108 | this.unbiased = unbiased; 109 | } 110 | 111 | 112 | @Override 113 | public void learn(Dataset dataset) { 114 | 115 | while(dataset.hasNextExample()){ 116 | Example example = dataset.getNextExample(); 117 | this.learn(example); 118 | } 119 | dataset.reset(); 120 | } 121 | 122 | @Override 123 | public BinaryMarginClassifierOutput learn(Example example){ 124 | BinaryMarginClassifierOutput prediction = this.classifier.predict(example); 125 | 126 | float predValue = prediction.getScore(label); 127 | if(Math.abs(predValue) labels){ 154 | if(labels.size()!=1){ 155 | throw new IllegalArgumentException("The Perceptron algorithm is a binary method which can learn a single Label"); 156 | } 157 | else{ 158 | this.label=labels.get(0); 159 | this.classifier.setLabels(labels); 160 | } 161 | } 162 | 163 | 164 | @Override 165 | public List getLabels() { 166 | 167 | return Arrays.asList(label); 168 | } 169 | 170 | @Override 171 | public Label getLabel(){ 172 | return this.label; 173 | } 174 | 175 | @Override 176 | public void setLabel(Label label){ 177 | this.setLabels(Arrays.asList(label)); 178 | } 179 | 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/BinaryPlattNormalizer.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import 
it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput;

/**
 * Sigmoid (Platt) normalizer for binary classifiers: maps a raw margin score
 * into a probability-like value via 1 / (1 + e^(A*score + B)).
 * The parameters A and B are estimated by PlattMethod.
 */
public class BinaryPlattNormalizer {

	// Slope of the sigmoid
	private float A;
	// Offset of the sigmoid
	private float B;

	public BinaryPlattNormalizer() {

	}

	/**
	 * @param a the slope parameter of the sigmoid
	 * @param b the offset parameter of the sigmoid
	 */
	public BinaryPlattNormalizer(float a, float b) {
		super();
		A = a;
		B = b;
	}

	/**
	 * Applies the Platt sigmoid to a raw classification score.
	 *
	 * @param nonNormalizedScore the raw margin produced by the classifier
	 * @return the normalized score in (0, 1)
	 */
	public float normalizeScore(float nonNormalizedScore) {
		return (float) (1.0 / (1.0 + Math.exp(A * nonNormalizedScore + B)));
	}

	public float getA() {
		return A;
	}

	public float getB() {
		return B;
	}

	public void setA(float a) {
		A = a;
	}

	public void setB(float b) {
		B = b;
	}

	@Override
	public String toString() {
		// fixed: the message previously reported the obsolete class name "PlattSigmoidFunction"
		return "BinaryPlattNormalizer [A=" + A + ", B=" + B + "]";
	}

	/**
	 * Returns a copy of the input prediction whose score for the positive
	 * class has been normalized through the Platt sigmoid.
	 *
	 * @param binaryMarginClassifierOutput the raw binary prediction
	 * @return the normalized prediction
	 */
	public BinaryMarginClassifierOutput getNormalizedScore(BinaryMarginClassifierOutput binaryMarginClassifierOutput) {

		Label positiveLabel = binaryMarginClassifierOutput.getAllClasses().get(0);

		Float nonNormalizedScore = binaryMarginClassifierOutput.getScore(positiveLabel);

		BinaryMarginClassifierOutput res = new BinaryMarginClassifierOutput(positiveLabel,
				normalizeScore(nonNormalizedScore));

		return res;
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/MulticlassPlattNormalizer.java:
--------------------------------------------------------------------------------
package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt;

import java.util.HashMap;

import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput;

/**
 * Applies per-class Platt normalization to the scores of a One-vs-All
 * classification output, using one BinaryPlattNormalizer per label.
 */
public class MulticlassPlattNormalizer {

	// One sigmoid per label; lazily created on the first addBinaryPlattNormalizer call
	private HashMap<Label, BinaryPlattNormalizer> binaryPlattNormalizers;

	/**
	 * Registers the Platt normalizer estimated for a specific label.
	 *
	 * @param label the label the normalizer refers to
	 * @param binaryPlattNormalizer the normalizer for that label
	 */
	public void addBinaryPlattNormalizer(Label label, BinaryPlattNormalizer binaryPlattNormalizer) {
		if (binaryPlattNormalizers == null) {
			binaryPlattNormalizers = new HashMap<Label, BinaryPlattNormalizer>();
		}
		binaryPlattNormalizers.put(label, binaryPlattNormalizer);
	}

	/**
	 * Normalizes the score of every class through the corresponding Platt sigmoid.
	 * NOTE(review): assumes a normalizer has been registered for every label of
	 * the input prediction; a missing one raises a NullPointerException.
	 *
	 * @param oneVsAllClassificationOutput the raw One-vs-All prediction
	 * @return a new prediction with normalized scores
	 */
	public OneVsAllClassificationOutput getNormalizedScores(OneVsAllClassificationOutput oneVsAllClassificationOutput) {
		OneVsAllClassificationOutput res = new OneVsAllClassificationOutput();

		for (Label l : oneVsAllClassificationOutput.getAllClasses()) {
			float nonNormalizedScore = oneVsAllClassificationOutput.getScore(l);
			BinaryPlattNormalizer binaryPlattNormalizer = binaryPlattNormalizers.get(l);
			float normalizedScore = binaryPlattNormalizer.normalizeScore(nonNormalizedScore);

			res.addBinaryPrediction(l, normalizedScore);
		}

		return res;
	}

	/**
	 * Applies a softmax over the scores of all the classes.
	 *
	 * @param oneVsAllClassificationOutput the raw One-vs-All prediction
	 * @return a new prediction whose scores sum to 1
	 */
	public static OneVsAllClassificationOutput softmax(OneVsAllClassificationOutput oneVsAllClassificationOutput) {
		OneVsAllClassificationOutput res = new OneVsAllClassificationOutput();

		// Subtract the maximum score before exponentiating: mathematically
		// equivalent, but prevents Math.exp from overflowing for large scores.
		float max = Float.NEGATIVE_INFINITY;
		for (Label l : oneVsAllClassificationOutput.getAllClasses()) {
			float score = oneVsAllClassificationOutput.getScore(l);
			if (score > max) {
				max = score;
			}
		}

		float denom = 0;
		for (Label l : oneVsAllClassificationOutput.getAllClasses()) {
			float score = oneVsAllClassificationOutput.getScore(l);
			denom += Math.exp(score - max);
		}

		for (Label l : oneVsAllClassificationOutput.getAllClasses()) {
			float score = oneVsAllClassificationOutput.getScore(l);
			float newScore = (float) Math.exp(score - max) / denom;

			res.addBinaryPrediction(l, newScore);
		}

		return res;
	}

	public HashMap<Label, BinaryPlattNormalizer> getBinaryPlattNormalizers() {
		return binaryPlattNormalizers;
	}

	public void setBinaryPlattNormalizers(HashMap<Label, BinaryPlattNormalizer> binaryPlattNormalizers) {
		this.binaryPlattNormalizers = binaryPlattNormalizers;
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputElement.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | public class PlattInputElement { 4 | 5 | private int label; 6 | private float value; 7 | 8 | public PlattInputElement(int label, float value) { 9 | super(); 10 | this.label = label; 11 | this.value = value; 12 | } 13 | 14 | public int getLabel() { 15 | return label; 16 | } 17 | 18 | public float getValue() { 19 | return value; 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattInputList.java: -------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.Vector; 4 | 5 | public class PlattInputList { 6 | 7 | private Vector list; 8 | private int positiveElement; 9 | private int negativeElement; 10 | 11 | public PlattInputList() { 12 | list = new Vector(); 13 | } 14 | 15 | public void add(PlattInputElement arg0) { 16 | if (arg0.getLabel() > 0) 17 | positiveElement++; 18 | else 19 | negativeElement++; 20 | 21 | list.add(arg0); 22 | } 23 | 24 | public PlattInputElement get(int index) { 25 | return list.get(index); 26 | } 27 | 28 | public int size() { 29 | return list.size(); 30 | } 31 | 32 | public int getPositiveElement() { 33 | return positiveElement; 34 | } 35 | 36 | public int getNegativeElement() { 37 | return negativeElement; 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/probabilityestimator/platt/PlattMethod.java: 
-------------------------------------------------------------------------------- 1 | package it.uniroma2.sag.kelp.learningalgorithm.classification.probabilityestimator.platt; 2 | 3 | import java.util.HashMap; 4 | 5 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 6 | import it.uniroma2.sag.kelp.data.example.Example; 7 | import it.uniroma2.sag.kelp.data.label.Label; 8 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 9 | import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm; 10 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 11 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 12 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 13 | 14 | public class PlattMethod { 15 | 16 | /** 17 | * Input parameters: 18 | * 19 | * deci = array of SVM decision values 20 | * 21 | * label = array of booleans: is the example labeled +1? 22 | * 23 | * prior1 = number of positive examples 24 | * 25 | * prior0 = number of negative examples 26 | * 27 | * Outputs: 28 | * 29 | * A, B = parameters of sigmoid 30 | * 31 | * @return 32 | **/ 33 | private static BinaryPlattNormalizer estimateSigmoid(float[] deci, float[] label, int prior1, int prior0) { 34 | 35 | /** 36 | * Parameter setting 37 | */ 38 | // Maximum number of iterations 39 | int maxiter = 100; 40 | // Minimum step taken in line search 41 | // minstep=1e-10; 42 | double minstep = 1e-10; 43 | double stopping = 1e-5; 44 | // Sigma: Set to any value > 0 45 | double sigma = 1e-12; 46 | // Construct initial values: target support in array t, 47 | // initial function value in fval 48 | double hiTarget = ((double) prior1 + 1.0f) / ((double) prior1 + 2.0f); 49 | double loTarget = 1 / (prior0 + 2.0f); 50 | 51 | int len = prior1 + prior0; // Total number of data 52 | double A; 53 | double B; 54 | 55 | double t[] = new double[len]; 56 | 57 | for (int i = 0; i < len; i++) { 58 | if (label[i] > 0) 59 | t[i] = hiTarget; 
60 | else 61 | t[i] = loTarget; 62 | } 63 | 64 | A = 0; 65 | B = Math.log((prior0 + 1.0) / (prior1 + 1.0)); 66 | double fval = 0f; 67 | 68 | for (int i = 0; i < len; i++) { 69 | double fApB = deci[i] * A + B; 70 | if (fApB >= 0) 71 | fval += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 72 | else 73 | fval += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); 74 | } 75 | 76 | int it = 1; 77 | for (it = 1; it <= maxiter; it++) { 78 | // Update Gradient and Hessian (use H� = H + sigma I) 79 | double h11 = sigma; 80 | double h22 = sigma; 81 | double h21 = 0; 82 | double g1 = 0; 83 | double g2 = 0; 84 | for (int i = 0; i < len; i++) { 85 | double fApB = deci[i] * A + B; 86 | double p; 87 | double q; 88 | if (fApB >= 0) { 89 | p = (Math.exp(-fApB) / (1.0 + Math.exp(-fApB))); 90 | q = (1.0 / (1.0 + Math.exp(-fApB))); 91 | } else { 92 | p = 1.0 / (1.0 + Math.exp(fApB)); 93 | q = Math.exp(fApB) / (1.0 + Math.exp(fApB)); 94 | } 95 | double d2 = p * q; 96 | h11 += deci[i] * deci[i] * d2; 97 | h22 += d2; 98 | h21 += deci[i] * d2; 99 | double d1 = t[i] - p; 100 | g1 += deci[i] * d1; 101 | g2 += d1; 102 | } 103 | if (Math.abs(g1) < stopping && Math.abs(g2) < stopping) // Stopping 104 | // criteria 105 | break; 106 | 107 | // Compute modified Newton directions 108 | double det = h11 * h22 - h21 * h21; 109 | double dA = -(h22 * g1 - h21 * g2) / det; 110 | double dB = -(-h21 * g1 + h11 * g2) / det; 111 | double gd = g1 * dA + g2 * dB; 112 | double stepsize = 1; 113 | 114 | while (stepsize >= minstep) { // Line search 115 | double newA = A + stepsize * dA; 116 | double newB = B + stepsize * dB; 117 | double newf = 0.0; 118 | for (int i = 0; i < len; i++) { 119 | double fApB = deci[i] * newA + newB; 120 | if (fApB >= 0) 121 | newf += t[i] * fApB + Math.log(1 + Math.exp(-fApB)); 122 | else 123 | newf += (t[i] - 1) * fApB + Math.log(1 + Math.exp(fApB)); 124 | } 125 | 126 | if (newf < fval + 1e-4 * stepsize * gd) { 127 | A = newA; 128 | B = newB; 129 | fval = newf; 130 | break; // 
Sufficient decrease satisfied 131 | } else 132 | stepsize /= 2.0; 133 | } 134 | if (stepsize < minstep) { 135 | System.out.println("Line search fails"); 136 | break; 137 | } 138 | } 139 | if (it >= maxiter) 140 | System.out.println("Reaching maximum iterations"); 141 | 142 | return new BinaryPlattNormalizer((float) A, (float) B); 143 | 144 | } 145 | 146 | public static BinaryPlattNormalizer esitmateSigmoid(SimpleDataset dataset, 147 | BinaryLearningAlgorithm binaryLearningAlgorithm, int nFolds) { 148 | 149 | PlattInputList plattInputList = new PlattInputList(); 150 | 151 | Label positiveLabel = binaryLearningAlgorithm.getLabel(); 152 | 153 | SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds); 154 | 155 | for (int f = 0; f < folds.length; f++) { 156 | 157 | SimpleDataset fold = folds[f]; 158 | 159 | SimpleDataset localTrainDataset = new SimpleDataset(); 160 | SimpleDataset localTestDataset = new SimpleDataset(); 161 | for (int i = 0; i < folds.length; i++) { 162 | if (i != f) { 163 | localTrainDataset.addExamples(fold); 164 | } else { 165 | localTestDataset.addExamples(fold); 166 | } 167 | } 168 | 169 | LearningAlgorithm duplicatedLearningAlgorithm = binaryLearningAlgorithm.duplicate(); 170 | 171 | duplicatedLearningAlgorithm.learn(fold); 172 | 173 | PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction(); 174 | 175 | for (Example example : localTestDataset.getExamples()) { 176 | Prediction predict = predictionFunction.predict(example); 177 | 178 | float value = predict.getScore(positiveLabel); 179 | 180 | int label = 1; 181 | if (!example.isExampleOf(positiveLabel)) 182 | label = -1; 183 | plattInputList.add(new PlattInputElement(label, value)); 184 | } 185 | } 186 | 187 | return estimateSigmoid(plattInputList); 188 | } 189 | 190 | public static MulticlassPlattNormalizer esitmateSigmoid(SimpleDataset dataset, OneVsAllLearning oneVsAllLearning, 191 | int nFolds) { 192 | 193 | HashMap plattInputLists = new 
HashMap(); 194 | for(Label label: dataset.getClassificationLabels()){ 195 | plattInputLists.put(label, new PlattInputList()); 196 | } 197 | 198 | SimpleDataset[] folds = dataset.getShuffledDataset().nFolding(nFolds); 199 | 200 | MulticlassPlattNormalizer res = new MulticlassPlattNormalizer(); 201 | 202 | for (int f = 0; f < folds.length; f++) { 203 | 204 | SimpleDataset fold = folds[f]; 205 | 206 | SimpleDataset localTrainDataset = new SimpleDataset(); 207 | SimpleDataset localTestDataset = new SimpleDataset(); 208 | for (int i = 0; i < folds.length; i++) { 209 | if (i != f) { 210 | localTrainDataset.addExamples(fold); 211 | } else { 212 | localTestDataset.addExamples(fold); 213 | } 214 | } 215 | 216 | LearningAlgorithm duplicatedLearningAlgorithm = oneVsAllLearning.duplicate(); 217 | 218 | duplicatedLearningAlgorithm.learn(fold); 219 | 220 | PredictionFunction predictionFunction = duplicatedLearningAlgorithm.getPredictionFunction(); 221 | 222 | for (Example example : localTestDataset.getExamples()) { 223 | Prediction predict = predictionFunction.predict(example); 224 | 225 | for (Label label : dataset.getClassificationLabels()) { 226 | 227 | float valueOfLabel = predict.getScore(label); 228 | 229 | int binaryLabel = 1; 230 | if (!example.isExampleOf(label)) 231 | binaryLabel = -1; 232 | plattInputLists.get(label).add(new PlattInputElement(binaryLabel, valueOfLabel)); 233 | } 234 | } 235 | } 236 | 237 | for (Label label : dataset.getClassificationLabels()) { 238 | res.addBinaryPlattNormalizer(label, estimateSigmoid(plattInputLists.get(label))); 239 | } 240 | 241 | return res; 242 | } 243 | 244 | protected static BinaryPlattNormalizer estimateSigmoid(PlattInputList inputList) { 245 | float[] deci = new float[inputList.size()]; 246 | float[] label = new float[inputList.size()]; 247 | int prior1 = inputList.getPositiveElement(); 248 | int prior0 = inputList.getNegativeElement(); 249 | 250 | for (int i = 0; i < inputList.size(); i++) { 251 | deci[i] = 
inputList.get(i).getValue(); 252 | label[i] = inputList.get(i).getLabel(); 253 | } 254 | 255 | return estimateSigmoid(deci, label, prior1, prior0); 256 | } 257 | 258 | } 259 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/classification/scw/SCWType.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.scw; 17 | 18 | /** 19 | * The two types of Soft Confidence-Weighted implemented variants 20 | * 21 | * @author Danilo Croce 22 | * 23 | */ 24 | public enum SCWType { 25 | 26 | SCW_I, SCW_II 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/clustering/kernelbasedkmeans/KernelBasedKMeansExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 21 | import it.uniroma2.sag.kelp.data.example.Example; 22 | 23 | @JsonTypeName("kernelbasedkmeansexample") 24 | public class KernelBasedKMeansExample extends ClusterExample { 25 | 26 | /** 27 | * 28 | */ 29 | private static final long serialVersionUID = -5368757832244686390L; 30 | 31 | public KernelBasedKMeansExample() { 32 | super(); 33 | } 34 | 35 | public KernelBasedKMeansExample(Example e, float dist) { 36 | super(e, dist); 37 | } 38 | 39 | @Override 40 | public Example getExample() { 41 | return example; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/liblinear/LibLinearRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2015 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.label.Label; 20 | import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvcFunction; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.L2R_L2_SvrFunction; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem; 25 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Problem.LibLinearSolverType; 26 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.solver.Tron; 27 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 28 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 29 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 30 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 31 | 32 | import java.util.Arrays; 33 | import java.util.List; 34 | 35 | import com.fasterxml.jackson.annotation.JsonIgnore; 36 | import com.fasterxml.jackson.annotation.JsonTypeName; 37 | 38 | /** 39 | * This class implements linear SVM regression trained using a coordinate descent 40 | * algorithm [Fan et al, 2008]. It operates in an explicit feature space (i.e. 
41 | * it does not relies on any kernel). This code has been adapted from the Java 42 | * port of the original LIBLINEAR C++ sources. 43 | * 44 | * Further details can be found in: 45 | * 46 | * [Fan et al, 2008] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J. 47 | * Lin. LIBLINEAR: A Library for Large Linear Classification, Journal of Machine 48 | * Learning Research 9(2008), 1871-1874. Software available at 49 | * 50 | * The original LIBLINEAR code: 51 | * http://www.csie.ntu.edu.tw/~cjlin/liblinear 52 | * 53 | * The original JAVA porting (v 1.94): http://liblinear.bwaldvogel.de 54 | * 55 | * @author Danilo Croce 56 | */ 57 | @JsonTypeName("liblinearregression") 58 | public class LibLinearRegression implements LinearMethod, 59 | RegressionLearningAlgorithm, BinaryLearningAlgorithm { 60 | 61 | /** 62 | * The property corresponding to the variable to be learned 63 | */ 64 | private Label label; 65 | /** 66 | * The regularization parameter 67 | */ 68 | private double c = 1; 69 | 70 | /** 71 | * The regressor to be returned 72 | */ 73 | @JsonIgnore 74 | private UnivariateLinearRegressionFunction regressionFunction; 75 | 76 | /** 77 | * The epsilon in loss function of SVR (default 0.1) 78 | */ 79 | private double p = 0.1f; 80 | 81 | /** 82 | * The identifier of the representation to be considered for the training 83 | * step 84 | */ 85 | private String representation; 86 | 87 | /** 88 | * @param label 89 | * The regression property to be learned 90 | * @param c 91 | * The regularization parameter 92 | * 93 | * @param p 94 | * The The epsilon in loss function of SVR 95 | * 96 | * @param representationName 97 | * The identifier of the representation to be considered for the 98 | * training step 99 | */ 100 | public LibLinearRegression(Label label, double c, double p, 101 | String representationName) { 102 | this(); 103 | 104 | this.setLabel(label); 105 | this.c = c; 106 | this.p = p; 107 | this.setRepresentation(representationName); 108 | } 109 | 110 | /** 
111 | * @param c 112 | * The regularization parameter 113 | * 114 | * @param representationName 115 | * The identifier of the representation to be considered for the 116 | * training step 117 | */ 118 | public LibLinearRegression(double c, double p, String representationName) { 119 | this(); 120 | this.c = c; 121 | this.p = p; 122 | this.setRepresentation(representationName); 123 | } 124 | 125 | public LibLinearRegression() { 126 | this.regressionFunction = new UnivariateLinearRegressionFunction(); 127 | this.regressionFunction.setModel(new BinaryLinearModel()); 128 | } 129 | 130 | /** 131 | * @return the regularization parameter 132 | */ 133 | public double getC() { 134 | return c; 135 | } 136 | 137 | /** 138 | * @param c 139 | * the regularization parameter 140 | */ 141 | public void setC(double c) { 142 | this.c = c; 143 | } 144 | 145 | /** 146 | * @return the epsilon in loss function 147 | */ 148 | public double getP() { 149 | return p; 150 | } 151 | 152 | /** 153 | * @param p 154 | * the epsilon in loss function 155 | */ 156 | public void setP(double p) { 157 | this.p = p; 158 | } 159 | 160 | /* 161 | * (non-Javadoc) 162 | * 163 | * @see 164 | * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#getRepresentation() 165 | */ 166 | @Override 167 | public String getRepresentation() { 168 | return representation; 169 | } 170 | 171 | /* 172 | * (non-Javadoc) 173 | * 174 | * @see 175 | * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#setRepresentation 176 | * (java.lang.String) 177 | */ 178 | @Override 179 | public void setRepresentation(String representation) { 180 | this.representation = representation; 181 | BinaryLinearModel model = this.regressionFunction.getModel(); 182 | model.setRepresentation(representation); 183 | } 184 | 185 | /* 186 | * (non-Javadoc) 187 | * 188 | * @see 189 | * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#setLabels(java 190 | * .util.List) 191 | */ 192 | @Override 193 | public void setLabels(List labels) { 194 | if 
(labels.size() != 1) { 195 | throw new IllegalArgumentException( 196 | "LibLinear algorithm is a binary method which can learn a single Label"); 197 | } else { 198 | this.label = labels.get(0); 199 | this.regressionFunction.setLabels(labels); 200 | } 201 | } 202 | 203 | /* 204 | * (non-Javadoc) 205 | * 206 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#getLabels() 207 | */ 208 | @Override 209 | public List getLabels() { 210 | return Arrays.asList(label); 211 | } 212 | 213 | /* 214 | * (non-Javadoc) 215 | * 216 | * @see 217 | * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#getLabel() 218 | */ 219 | @Override 220 | public Label getLabel() { 221 | return this.label; 222 | } 223 | 224 | /* 225 | * (non-Javadoc) 226 | * 227 | * @see 228 | * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#setLabel 229 | * (it.uniroma2.sag.kelp.data.label.Label) 230 | */ 231 | @Override 232 | public void setLabel(Label label) { 233 | this.setLabels(Arrays.asList(label)); 234 | } 235 | 236 | /* 237 | * (non-Javadoc) 238 | * 239 | * @see 240 | * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#learn(it.uniroma2 241 | * .sag.kelp.data.dataset.Dataset) 242 | */ 243 | @Override 244 | public void learn(Dataset dataset) { 245 | 246 | double eps = 0.001; 247 | 248 | int l = dataset.getNumberOfExamples(); 249 | 250 | double[] C = new double[l]; 251 | for (int i = 0; i < l; i++) { 252 | C[i] = c; 253 | } 254 | 255 | Problem problem = new Problem(dataset, representation, label, 256 | LibLinearSolverType.REGRESSION); 257 | 258 | L2R_L2_SvcFunction fun_obj = new L2R_L2_SvrFunction(problem, C, p); 259 | 260 | Tron tron = new Tron(fun_obj, eps); 261 | 262 | double[] w = new double[problem.n]; 263 | tron.tron(w); 264 | 265 | this.regressionFunction.getModel().setHyperplane(problem.getW(w)); 266 | this.regressionFunction.getModel().setRepresentation(representation); 267 | this.regressionFunction.getModel().setBias(0); 268 | } 269 | 270 | /* 271 | 
* (non-Javadoc) 272 | * 273 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#duplicate() 274 | */ 275 | @Override 276 | public LibLinearRegression duplicate() { 277 | LibLinearRegression copy = new LibLinearRegression(); 278 | copy.setRepresentation(representation); 279 | copy.setC(c); 280 | copy.setP(p); 281 | return copy; 282 | } 283 | 284 | /* 285 | * (non-Javadoc) 286 | * 287 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#reset() 288 | */ 289 | @Override 290 | public void reset() { 291 | this.regressionFunction.reset(); 292 | } 293 | 294 | @Override 295 | public UnivariateLinearRegressionFunction getPredictionFunction() { 296 | return regressionFunction; 297 | } 298 | 299 | @Override 300 | public void setPredictionFunction(PredictionFunction predictionFunction) { 301 | this.regressionFunction = (UnivariateLinearRegressionFunction) predictionFunction; 302 | } 303 | 304 | } 305 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/KernelizedPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.kernel.Kernel; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateKernelMachineRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (kernel machine version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("kernelizedPA-R") 37 | public class KernelizedPassiveAggressiveRegression extends PassiveAggressiveRegression implements KernelMethod{ 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPassiveAggressiveRegression(){ 42 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 43 | } 44 | 45 | public KernelizedPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, Kernel kernel, Label label){ 46 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 47 | this.setC(aggressiveness); 48 | this.setEpsilon(epsilon); 49 | this.setPolicy(policy); 50 | this.setKernel(kernel); 51 | this.setLabel(label); 52 | } 53 | 54 | @Override 55 | public Kernel getKernel(){ 56 | return kernel; 57 | } 58 | 59 | @Override 60 | public void setKernel(Kernel kernel) { 61 | this.kernel = kernel; 62 | this.getPredictionFunction().getModel().setKernel(kernel); 63 | } 64 | 65 | @Override 66 | public KernelizedPassiveAggressiveRegression duplicate() { 67 | KernelizedPassiveAggressiveRegression copy = new KernelizedPassiveAggressiveRegression(); 
68 | copy.setC(this.c); 69 | copy.setKernel(this.kernel); 70 | copy.setPolicy(this.policy); 71 | copy.setEpsilon(epsilon); 72 | return copy; 73 | } 74 | 75 | @Override 76 | public UnivariateKernelMachineRegressionFunction getPredictionFunction(){ 77 | return (UnivariateKernelMachineRegressionFunction) this.regressor; 78 | } 79 | 80 | @Override 81 | public void setPredictionFunction(PredictionFunction predictionFunction) { 82 | this.regressor = (UnivariateKernelMachineRegressionFunction) predictionFunction; 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/LinearPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (linear version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("linearPA-R") 37 | public class LinearPassiveAggressiveRegression extends PassiveAggressiveRegression implements LinearMethod{ 38 | 39 | private String representation; 40 | 41 | public LinearPassiveAggressiveRegression(){ 42 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 43 | regressor.setModel(new BinaryLinearModel()); 44 | this.regressor = regressor; 45 | 46 | } 47 | 48 | public LinearPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, String representation, Label label){ 49 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 50 | regressor.setModel(new BinaryLinearModel()); 51 | this.regressor = regressor; 52 | this.setC(aggressiveness); 53 | this.setEpsilon(epsilon); 54 | this.setPolicy(policy); 55 | this.setRepresentation(representation); 56 | this.setLabel(label); 57 | } 58 | 59 | @Override 60 | public LinearPassiveAggressiveRegression duplicate() { 61 | LinearPassiveAggressiveRegression copy = new LinearPassiveAggressiveRegression(); 62 | 
copy.setC(this.c); 63 | copy.setRepresentation(this.representation); 64 | copy.setPolicy(this.policy); 65 | copy.setEpsilon(epsilon); 66 | return copy; 67 | } 68 | 69 | @Override 70 | public String getRepresentation() { 71 | return representation; 72 | } 73 | 74 | @Override 75 | public void setRepresentation(String representation) { 76 | this.representation = representation; 77 | this.getPredictionFunction().getModel().setRepresentation(representation); 78 | } 79 | 80 | @Override 81 | public UnivariateLinearRegressionFunction getPredictionFunction(){ 82 | return (UnivariateLinearRegressionFunction) this.regressor; 83 | } 84 | 85 | @Override 86 | public void setPredictionFunction(PredictionFunction predictionFunction) { 87 | this.regressor = (UnivariateLinearRegressionFunction) predictionFunction; 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/PassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput; 23 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionFunction; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for regression tasks. 29 | * 30 | * reference: 31 | * 32 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 33 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 34 | * 35 | * @author Simone Filice 36 | */ 37 | public abstract class PassiveAggressiveRegression extends PassiveAggressive implements RegressionLearningAlgorithm{ 38 | 39 | @JsonIgnore 40 | protected UnivariateRegressionFunction regressor; 41 | 42 | protected float epsilon; 43 | 44 | /** 45 | * Returns epsilon, i.e. the accepted distance between the predicted and the real regression values 46 | * 47 | * @return the epsilon 48 | */ 49 | public float getEpsilon() { 50 | return epsilon; 51 | } 52 | 53 | /** 54 | * Sets epsilon, i.e. 
the accepted distance between the predicted and the real regression values 55 | * 56 | * @param epsilon the epsilon to set 57 | */ 58 | public void setEpsilon(float epsilon) { 59 | this.epsilon = epsilon; 60 | } 61 | 62 | @Override 63 | public UnivariateRegressionFunction getPredictionFunction() { 64 | return this.regressor; 65 | } 66 | 67 | @Override 68 | public void learn(Dataset dataset){ 69 | 70 | while(dataset.hasNextExample()){ 71 | Example example = dataset.getNextExample(); 72 | this.learn(example); 73 | } 74 | dataset.reset(); 75 | } 76 | 77 | @Override 78 | public UnivariateRegressionOutput learn(Example example){ 79 | UnivariateRegressionOutput prediction=this.regressor.predict(example); 80 | float difference = example.getRegressionValue(label) - prediction.getScore(label); 81 | float lossValue = Math.abs(difference) - epsilon;//it represents the distance from the correct semi-space 82 | if(lossValue>0){ 83 | float exampleSquaredNorm = this.regressor.getModel().getSquaredNorm(example); 84 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm, c); 85 | if(difference<0){ 86 | weight = -weight; 87 | } 88 | this.regressor.getModel().addExample(weight, example); 89 | } 90 | return prediction; 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/linearization/LinearizationFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.linearization; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.representation.Vector; 22 | 23 | /** 24 | * This interface allows implementing function to linearized examples through 25 | * linear representations, i.e. vectors 26 | * 27 | * 28 | * @author Danilo Croce 29 | * 30 | */ 31 | public interface LinearizationFunction { 32 | 33 | /** 34 | * Given an input Example, this method generates a linear 35 | * Representation>, i.e. a Vector. 36 | * 37 | * @param example 38 | * The input example. 39 | * @return The linearized representation of the input example. 40 | */ 41 | public Vector getLinearRepresentation(Example example); 42 | 43 | /** 44 | * This method linearizes an input example, providing a new example 45 | * containing only a representation with a specific name, provided as input. 46 | * The produced example inherits the labels of the input example. 47 | * 48 | * @param example 49 | * The input example. 50 | * @param vectorName 51 | * The name of the linear representation inside the new example 52 | * @return 53 | */ 54 | public Example getLinearizedExample(Example example, String representationName); 55 | 56 | /** 57 | * This method linearizes all the examples in the input dataset 58 | * , generating a corresponding linearized dataset. 
The produced examples 59 | * inherit the labels of the corresponding input examples. 60 | * 61 | * @param dataset 62 | * The input dataset 63 | * @param representationName 64 | * The name of the linear representation inside the new examples 65 | * @return 66 | */ 67 | public SimpleDataset getLinearizedDataset(Dataset dataset, String representationName); 68 | 69 | /** 70 | * @return the size of the resulting embedding, i.e. the number of resulting 71 | * vector dimensions 72 | */ 73 | public int getEmbeddingSize(); 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/SequencePrediction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.example.SequencePath; 22 | import it.uniroma2.sag.kelp.data.label.Label; 23 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 24 | 25 | /** 26 | * It is a output provided by a machine learning systems on a sequence. 
This 27 | * specific implementation allows to assign multiple labelings to single 28 | * sequence, useful for some labeling strategies, such as Beam Search. Notice 29 | * that each labeling requires a score to select the more promising labeling. 30 | * 31 | * @author Danilo Croce 32 | * 33 | */ 34 | public class SequencePrediction implements Prediction { 35 | 36 | /** 37 | * 38 | */ 39 | private static final long serialVersionUID = -1040539866977906008L; 40 | /** 41 | * This list contains multiple labelings to be assigned to a single sequence 42 | */ 43 | private List paths; 44 | 45 | public SequencePrediction() { 46 | paths = new ArrayList(); 47 | } 48 | 49 | /** 50 | * @return The best path, i.e., the labeling with the highest score in the 51 | * list of labelings provided by a classifier 52 | */ 53 | public SequencePath bestPath() { 54 | return paths.get(0); 55 | } 56 | 57 | /** 58 | * @return a list containing multiple labelings to be assigned to a single 59 | * sequence 60 | */ 61 | public List getPaths() { 62 | return paths; 63 | } 64 | 65 | @Override 66 | public Float getScore(Label label) { 67 | return null; 68 | } 69 | 70 | /** 71 | * @param paths 72 | * a list contains multiple labelings to be assigned to a single 73 | * sequence 74 | */ 75 | public void setPaths(List paths) { 76 | this.paths = paths; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | StringBuilder sb = new StringBuilder(); 82 | for (int i = 0; i < paths.size(); i++) { 83 | if (i == 0) 84 | sb.append("Best Path\t"); 85 | else 86 | sb.append("Altern. 
Path\t"); 87 | SequencePath sequencePath = paths.get(i); 88 | sb.append(sequencePath + "\n"); 89 | } 90 | return sb.toString(); 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/model/SequenceModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction.model; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 20 | 21 | /** 22 | * This class implements a model produced by a 23 | * SequenceClassificationLearningAlgorithm 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class SequenceModel implements Model { 29 | 30 | /** 31 | * 32 | */ 33 | private static final long serialVersionUID = -2749198158786953940L; 34 | 35 | /** 36 | * The prediction function producing the emission scores to be considered in 37 | * the Viterbi Decoding 38 | */ 39 | private PredictionFunction basePredictionFunction; 40 | 41 | private SequenceExampleGenerator sequenceExampleGenerator; 42 | 43 | public SequenceModel() { 44 | super(); 45 | } 46 | 47 | public SequenceModel(PredictionFunction basePredictionFunction, SequenceExampleGenerator sequenceExampleGenerator) { 48 | super(); 49 | this.basePredictionFunction = basePredictionFunction; 50 | this.sequenceExampleGenerator = sequenceExampleGenerator; 51 | } 52 | 53 | public PredictionFunction getBasePredictionFunction() { 54 | return basePredictionFunction; 55 | } 56 | 57 | public SequenceExampleGenerator getSequenceExampleGenerator() { 58 | return sequenceExampleGenerator; 59 | } 60 | 61 | @Override 62 | public void reset() { 63 | } 64 | 65 | public void setBasePredictionFunction(PredictionFunction basePredictionFunction) { 66 | this.basePredictionFunction = basePredictionFunction; 67 | } 68 | 69 | public void setSequenceExampleGenerator(SequenceExampleGenerator sequenceExampleGenerator) { 70 | this.sequenceExampleGenerator = sequenceExampleGenerator; 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/ClusteringEvaluator.java: -------------------------------------------------------------------------------- 
1 | package it.uniroma2.sag.kelp.utils.evaluation; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashSet; 5 | import java.util.TreeMap; 6 | 7 | import it.uniroma2.sag.kelp.data.clustering.Cluster; 8 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 9 | import it.uniroma2.sag.kelp.data.clustering.ClusterList; 10 | import it.uniroma2.sag.kelp.data.example.Example; 11 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 12 | import it.uniroma2.sag.kelp.data.label.Label; 13 | import it.uniroma2.sag.kelp.data.label.StringLabel; 14 | import it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans.KernelBasedKMeansExample; 15 | 16 | /** 17 | * 18 | * Implements Evaluation methods for clustering algorithms. 19 | * 20 | * More details about Purity and NMI can be found here: 21 | * 22 | * https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1. 23 | * html 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class ClusteringEvaluator { 29 | 30 | public static float getPurity(ClusterList clusters) { 31 | 32 | float res = 0; 33 | int k = clusters.size(); 34 | 35 | for (int clustId = 0; clustId < k; clustId++) { 36 | 37 | TreeMap classSizes = new TreeMap(); 38 | 39 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 40 | HashSet labels = vce.getExample().getClassificationLabels(); 41 | for (Label label : labels) 42 | if (!classSizes.containsKey(label)) 43 | classSizes.put(label, 1); 44 | else 45 | classSizes.put(label, classSizes.get(label) + 1); 46 | } 47 | 48 | int maxSize = 0; 49 | for (int size : classSizes.values()) { 50 | if (size > maxSize) { 51 | maxSize = size; 52 | } 53 | } 54 | res += maxSize; 55 | } 56 | 57 | return res / (float) clusters.getNumberOfExamples(); 58 | } 59 | 60 | public static float getMI(ClusterList clusters) { 61 | 62 | float res = 0; 63 | 64 | float N = clusters.getNumberOfExamples(); 65 | 66 | int k = clusters.size(); 67 | 68 | TreeMap classCardinality = 
getClassCardinality(clusters); 69 | 70 | for (int clustId = 0; clustId < k; clustId++) { 71 | 72 | TreeMap classSizes = getClassCardinalityWithinCluster(clusters, clustId); 73 | 74 | for (Label className : classSizes.keySet()) { 75 | int wSize = classSizes.get(className); 76 | res += ((float) wSize / N) * myLog(N * (float) wSize 77 | / (clusters.get(clustId).getExamples().size() * (float) classCardinality.get(className))); 78 | } 79 | 80 | } 81 | 82 | return res; 83 | 84 | } 85 | 86 | private static TreeMap getClassCardinalityWithinCluster(ClusterList clusters, int clustId) { 87 | 88 | TreeMap classSizes = new TreeMap(); 89 | 90 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 91 | HashSet labels = vce.getExample().getClassificationLabels(); 92 | for (Label label : labels) 93 | if (!classSizes.containsKey(label)) 94 | classSizes.put(label, 1); 95 | else 96 | classSizes.put(label, classSizes.get(label) + 1); 97 | } 98 | 99 | return classSizes; 100 | } 101 | 102 | private static float getClusterEntropy(ClusterList clusters) { 103 | 104 | float res = 0; 105 | float N = clusters.getNumberOfExamples(); 106 | int k = clusters.size(); 107 | 108 | for (int clustId = 0; clustId < k; clustId++) { 109 | int clusterElementSize = clusters.get(clustId).getExamples().size(); 110 | if (clusterElementSize != 0) 111 | res -= ((float) clusterElementSize / N) * myLog((float) clusterElementSize / N); 112 | } 113 | return res; 114 | 115 | } 116 | 117 | private static float getClassEntropy(ClusterList clusters) { 118 | 119 | float res = 0; 120 | float N = clusters.getNumberOfExamples(); 121 | 122 | TreeMap classCardinality = getClassCardinality(clusters); 123 | 124 | for (int classSize : classCardinality.values()) { 125 | res -= ((float) classSize / N) * myLog((float) classSize / N); 126 | } 127 | return res; 128 | 129 | } 130 | 131 | private static float myLog(float f) { 132 | return (float) (Math.log(f) / Math.log(2f)); 133 | } 134 | 135 | private static TreeMap 
getClassCardinality(ClusterList clusters) { 136 | TreeMap classSizes = new TreeMap(); 137 | 138 | int k = clusters.size(); 139 | 140 | for (int clustId = 0; clustId < k; clustId++) { 141 | 142 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 143 | HashSet labels = vce.getExample().getClassificationLabels(); 144 | for (Label label : labels) 145 | if (!classSizes.containsKey(label)) 146 | classSizes.put(label, 1); 147 | else 148 | classSizes.put(label, classSizes.get(label) + 1); 149 | } 150 | } 151 | return classSizes; 152 | } 153 | 154 | public static float getNMI(ClusterList clusters) { 155 | return getMI(clusters) / ((getClusterEntropy(clusters) + getClassEntropy(clusters)) / 2f); 156 | } 157 | 158 | public static String getStatistics(ClusterList clusters) { 159 | StringBuilder sb = new StringBuilder(); 160 | 161 | sb.append("Purity:\t" + getPurity(clusters) + "\n"); 162 | sb.append("Mutual Information:\t" + getMI(clusters) + "\n"); 163 | sb.append("Cluster Entropy:\t" + getClusterEntropy(clusters) + "\n"); 164 | sb.append("Class Entropy:\t" + getClassEntropy(clusters) + "\n"); 165 | sb.append("NMI:\t" + getNMI(clusters)); 166 | 167 | return sb.toString(); 168 | } 169 | 170 | public static void main(String[] args) { 171 | ClusterList clusters = new ClusterList(); 172 | 173 | Cluster c1 = new Cluster("C1"); 174 | ArrayList list1 = new ArrayList(); 175 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 176 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 177 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 178 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 179 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 180 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 181 | for (Example e : list1) { 182 | c1.add(new KernelBasedKMeansExample(e, 1f)); 183 | } 184 
| 185 | Cluster c2 = new Cluster("C2"); 186 | ArrayList list2 = new ArrayList(); 187 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 188 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 189 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 190 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 191 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 192 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 193 | for (Example e : list2) { 194 | c2.add(new KernelBasedKMeansExample(e, 1f)); 195 | } 196 | 197 | Cluster c3 = new Cluster("C3"); 198 | ArrayList list3 = new ArrayList(); 199 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 200 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 201 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 202 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 203 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 204 | for (Example e : list3) { 205 | c3.add(new KernelBasedKMeansExample(e, 1f)); 206 | } 207 | 208 | clusters.add(c1); 209 | clusters.add(c2); 210 | clusters.add(c3); 211 | 212 | System.out.println(ClusteringEvaluator.getStatistics(clusters)); 213 | 214 | //From https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html 215 | //Purity = 0.71 216 | //NMI = 0.36 217 | 218 | } 219 | 220 | } 221 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/MulticlassSequenceClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.utils.evaluation; 17 | 18 | import java.util.List; 19 | 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 22 | import it.uniroma2.sag.kelp.data.example.SequencePath; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.data.label.SequenceEmission; 25 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 26 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 27 | 28 | /** 29 | * This is an instance of an Evaluator. It allows to compute the some common 30 | * measure for classification tasks acting over SequenceExamples. It 31 | * computes precision, recall, f1s for each class, and a global accuracy. 
32 | * 33 | * @author Danilo Croce 34 | */ 35 | public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator{ 36 | 37 | /** 38 | * Initialize a new F1Evaluator that will work on the specified classes 39 | * 40 | * @param labels 41 | */ 42 | public MulticlassSequenceClassificationEvaluator(List labels) { 43 | super(labels); 44 | } 45 | 46 | public void addCount(Example test, Prediction prediction) { 47 | addCount((SequenceExample) test, (SequencePrediction) prediction); 48 | } 49 | 50 | /** 51 | * This method should be implemented in the subclasses to update counters 52 | * useful to compute the performance measure 53 | * 54 | * @param test 55 | * the test example 56 | * @param predicted 57 | * the prediction of the system 58 | */ 59 | public void addCount(SequenceExample test, SequencePrediction predicted) { 60 | 61 | SequencePath bestPath = predicted.bestPath(); 62 | 63 | for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) { 64 | 65 | Example testItem = test.getExample(seqIdx); 66 | SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx); 67 | 68 | for (Label l : this.labels) { 69 | ClassStats stats = this.classStats.get(l); 70 | if(testItem.isExampleOf(l)){ 71 | if(sequenceLabel.getLabel().equals(l)){ 72 | stats.tp++; 73 | totalTp++; 74 | }else{ 75 | stats.fn++; 76 | totalFn++; 77 | } 78 | }else{ 79 | if(sequenceLabel.getLabel().equals(l)){ 80 | stats.fp++; 81 | totalFp++; 82 | }else{ 83 | stats.tn++; 84 | totalTn++; 85 | } 86 | } 87 | 88 | } 89 | 90 | //TODO: check (i) e' giusto valutare l'accuracy dei singoli elementi della sequenza e non della sequenza completa 91 | //(ii) va considerato il caso multilabel 92 | total++; 93 | 94 | if (testItem.isExampleOf(sequenceLabel.getLabel())) { 95 | correct++; 96 | } 97 | 98 | this.computed = false; 99 | } 100 | } 101 | 102 | } 103 | -------------------------------------------------------------------------------- 
/src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 
33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = 
new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | 
evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
*/

package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.zip.GZIPInputStream;

import org.junit.Assert;
import org.junit.Test;

import it.uniroma2.sag.kelp.data.dataset.SequenceDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.StringLabel;
import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss;
import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction;
import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction;

/**
 * Regression test for linear sequence classification: trains a linear
 * sequence-labeling model (DCD base learner + Viterbi-style decoding) on the
 * Declaration of Independence corpus, tags the Gettysburg Address, and checks
 * that labels and path scores match the stored reference predictions.
 */
public class SequenceLearningLinearTest {

	/** Numeric tolerance when comparing path scores against the reference file. */
	private static final Float TOLERANCE = 0.001f;

	public static void main(String[] args) throws Exception {
		// Intentionally empty; kept for backward compatibility / manual runs.
	}

	@Test
	public void testLinear() {

		String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz";
		String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz";
		String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt";

		/*
		 * Given a targeted item in the sequence, this variable determines the
		 * number of previous examples considered in the learning/labeling
		 * process.
		 *
		 * NOTE: if this variable is set to 0, the learning process corresponds
		 * to a traditional multi-class classification schema
		 */
		int transitionsOrder = 1;

		/*
		 * This variable determines the importance of the transition-based
		 * features during the learning process. Higher values will assign more
		 * importance to the transitions.
		 */
		float weight = 1f;

		/*
		 * The size of the beam to be used in the decoding process. This number
		 * determines the number of possible sequences produced in the labeling
		 * process. It will also increase the process complexity.
		 */
		int beamSize = 5;

		/*
		 * During the labeling process, each item is classified with respect to
		 * the target classes. To reduce the complexity of the labeling process,
		 * this variable determines the number of classes that received the
		 * highest classification scores to be considered after the
		 * classification step in the Viterbi Decoding.
		 */
		int maxEmissionCandidates = 3;

		/*
		 * This representation contains the feature vector representing items in
		 * the sequence
		 */
		String originalRepresentationName = "rep";

		/*
		 * Loading the training dataset
		 */
		SequenceDataset sequenceTrainDataset = new SequenceDataset();
		try {
			sequenceTrainDataset.populate(inputTrainFilePath);
		} catch (Exception e) {
			e.printStackTrace();
			Assert.fail(e.getMessage());
		}

		/*
		 * Instance classifier
		 */
		float cSVM = 1f;
		DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1,
				false, 50, originalRepresentationName);

		/*
		 * Sequence classifier.
		 */
		SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null;
		try {
			sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm(
					instanceClassifierLearningAlgorithm, transitionsOrder, weight);
			sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates);
			sequenceClassificationLearningAlgorithm.setBeamSize(beamSize);

			sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset);
		} catch (Exception e1) {
			e1.printStackTrace();
			Assert.fail(e1.getMessage());
		}

		SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm
				.getPredictionFunction();

		/*
		 * Load the test set
		 */
		SequenceDataset sequenceTestDataset = new SequenceDataset();
		try {
			sequenceTestDataset.populate(inputTestFilePath);
		} catch (Exception e) {
			e.printStackTrace();
			Assert.fail(e.getMessage());
		}

		/*
		 * Tagging and evaluating. The commented PrintStream lines can be used
		 * to regenerate the reference prediction file.
		 */
		// PrintStream ps = new PrintStream(scoreFilePath);
		ArrayList<Label> labels = new ArrayList<Label>();
		ArrayList<Double> scores = new ArrayList<Double>();
		for (Example example : sequenceTestDataset.getExamples()) {

			SequenceExample sequenceExample = (SequenceExample) example;
			SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample);

			SequencePath bestPath = sequencePrediction.bestPath();
			// NOTE: getLenght() is the (misspelled) KeLP API method name.
			for (int i = 0; i < sequenceExample.getLenght(); i++) {
				// ps.println(bestPath.getAssignedLabel(i) + "\t" +
				// bestPath.getScore());
				labels.add(bestPath.getAssignedLabel(i));
				scores.add(bestPath.getScore());
			}

		}
		// ps.close();

		ArrayList<Double> oldScores = loadScores(scoreFilePath);
		ArrayList<Label> oldLabels = loadLabels(scoreFilePath);

		for (int i = 0; i < oldScores.size(); i++) {
			Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE);
			Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString());
		}

	}

	/**
	 * Opens a UTF-8 reader over a plain or gzipped text file.
	 */
	private static BufferedReader openReader(String filepath) throws IOException {
		InputStream stream = new FileInputStream(filepath);
		if (filepath.endsWith(".gz")) {
			stream = new GZIPInputStream(stream);
		}
		return new BufferedReader(new InputStreamReader(stream, "UTF-8"));
	}

	/**
	 * Loads the reference path scores (second tab-separated column) from the
	 * given prediction file.
	 *
	 * @return the scores, or null if the file could not be read (the test is
	 *         failed via Assert in that case)
	 */
	public static ArrayList<Double> loadScores(String filepath) {
		// try-with-resources guarantees the reader is closed even on error
		// (the original code leaked it when an exception occurred before close()).
		try (BufferedReader in = openReader(filepath)) {
			ArrayList<Double> scores = new ArrayList<Double>();
			String str;
			while ((str = in.readLine()) != null) {
				scores.add(Double.parseDouble(str.split("\t")[1]));
			}
			return scores;
		} catch (IOException e) {
			e.printStackTrace();
			Assert.fail(e.getMessage());
		}
		return null;
	}

	/**
	 * Loads the reference labels (first tab-separated column) from the given
	 * prediction file.
	 *
	 * @return the labels, or null if the file could not be read (the test is
	 *         failed via Assert in that case)
	 */
	public static ArrayList<Label> loadLabels(String filepath) {
		try (BufferedReader in = openReader(filepath)) {
			ArrayList<Label> res = new ArrayList<Label>();
			String str;
			while ((str = in.readLine()) != null) {
				res.add(new StringLabel(str.split("\t")[0]));
			}
			return res;
		} catch (IOException e) {
			e.printStackTrace();
			Assert.fail(e.getMessage());
		}
		return null;
	}

}
--------------------------------------------------------------------------------
/src/test/resources/sequence_learning/README.txt:
--------------------------------------------------------------------------------
The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation.

The original dataset can be downloaded at:
http://download.joachims.org/svm_hmm/examples/example7.tar.gz
while its description is reported at:
https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html
--------------------------------------------------------------------------------
/src/test/resources/sequence_learning/declaration_of_independence.klp.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz
--------------------------------------------------------------------------------
/src/test/resources/sequence_learning/gettysburg_address.klp.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz
--------------------------------------------------------------------------------
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
44 | * Further details can be found in: 45 | *
46 | * [Fan et al, 2008] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J. 47 | * Lin. LIBLINEAR: A Library for Large Linear Classification, Journal of Machine 48 | * Learning Research 9(2008), 1871-1874. Software available at 49 | *
50 | * The original LIBLINEAR code: 51 | * http://www.csie.ntu.edu.tw/~cjlin/liblinear 52 | *
http://www.csie.ntu.edu.tw/~cjlin/liblinear
53 | * The original JAVA porting (v 1.94): http://liblinear.bwaldvogel.de 54 | * 55 | * @author Danilo Croce 56 | */ 57 | @JsonTypeName("liblinearregression") 58 | public class LibLinearRegression implements LinearMethod, 59 | RegressionLearningAlgorithm, BinaryLearningAlgorithm { 60 | 61 | /** 62 | * The property corresponding to the variable to be learned 63 | */ 64 | private Label label; 65 | /** 66 | * The regularization parameter 67 | */ 68 | private double c = 1; 69 | 70 | /** 71 | * The regressor to be returned 72 | */ 73 | @JsonIgnore 74 | private UnivariateLinearRegressionFunction regressionFunction; 75 | 76 | /** 77 | * The epsilon in loss function of SVR (default 0.1) 78 | */ 79 | private double p = 0.1f; 80 | 81 | /** 82 | * The identifier of the representation to be considered for the training 83 | * step 84 | */ 85 | private String representation; 86 | 87 | /** 88 | * @param label 89 | * The regression property to be learned 90 | * @param c 91 | * The regularization parameter 92 | * 93 | * @param p 94 | * The The epsilon in loss function of SVR 95 | * 96 | * @param representationName 97 | * The identifier of the representation to be considered for the 98 | * training step 99 | */ 100 | public LibLinearRegression(Label label, double c, double p, 101 | String representationName) { 102 | this(); 103 | 104 | this.setLabel(label); 105 | this.c = c; 106 | this.p = p; 107 | this.setRepresentation(representationName); 108 | } 109 | 110 | /** 111 | * @param c 112 | * The regularization parameter 113 | * 114 | * @param representationName 115 | * The identifier of the representation to be considered for the 116 | * training step 117 | */ 118 | public LibLinearRegression(double c, double p, String representationName) { 119 | this(); 120 | this.c = c; 121 | this.p = p; 122 | this.setRepresentation(representationName); 123 | } 124 | 125 | public LibLinearRegression() { 126 | this.regressionFunction = new UnivariateLinearRegressionFunction(); 127 | 
this.regressionFunction.setModel(new BinaryLinearModel()); 128 | } 129 | 130 | /** 131 | * @return the regularization parameter 132 | */ 133 | public double getC() { 134 | return c; 135 | } 136 | 137 | /** 138 | * @param c 139 | * the regularization parameter 140 | */ 141 | public void setC(double c) { 142 | this.c = c; 143 | } 144 | 145 | /** 146 | * @return the epsilon in loss function 147 | */ 148 | public double getP() { 149 | return p; 150 | } 151 | 152 | /** 153 | * @param p 154 | * the epsilon in loss function 155 | */ 156 | public void setP(double p) { 157 | this.p = p; 158 | } 159 | 160 | /* 161 | * (non-Javadoc) 162 | * 163 | * @see 164 | * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#getRepresentation() 165 | */ 166 | @Override 167 | public String getRepresentation() { 168 | return representation; 169 | } 170 | 171 | /* 172 | * (non-Javadoc) 173 | * 174 | * @see 175 | * it.uniroma2.sag.kelp.learningalgorithm.LinearMethod#setRepresentation 176 | * (java.lang.String) 177 | */ 178 | @Override 179 | public void setRepresentation(String representation) { 180 | this.representation = representation; 181 | BinaryLinearModel model = this.regressionFunction.getModel(); 182 | model.setRepresentation(representation); 183 | } 184 | 185 | /* 186 | * (non-Javadoc) 187 | * 188 | * @see 189 | * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#setLabels(java 190 | * .util.List) 191 | */ 192 | @Override 193 | public void setLabels(List labels) { 194 | if (labels.size() != 1) { 195 | throw new IllegalArgumentException( 196 | "LibLinear algorithm is a binary method which can learn a single Label"); 197 | } else { 198 | this.label = labels.get(0); 199 | this.regressionFunction.setLabels(labels); 200 | } 201 | } 202 | 203 | /* 204 | * (non-Javadoc) 205 | * 206 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#getLabels() 207 | */ 208 | @Override 209 | public List getLabels() { 210 | return Arrays.asList(label); 211 | } 212 | 213 | /* 214 | * 
(non-Javadoc) 215 | * 216 | * @see 217 | * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#getLabel() 218 | */ 219 | @Override 220 | public Label getLabel() { 221 | return this.label; 222 | } 223 | 224 | /* 225 | * (non-Javadoc) 226 | * 227 | * @see 228 | * it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm#setLabel 229 | * (it.uniroma2.sag.kelp.data.label.Label) 230 | */ 231 | @Override 232 | public void setLabel(Label label) { 233 | this.setLabels(Arrays.asList(label)); 234 | } 235 | 236 | /* 237 | * (non-Javadoc) 238 | * 239 | * @see 240 | * it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#learn(it.uniroma2 241 | * .sag.kelp.data.dataset.Dataset) 242 | */ 243 | @Override 244 | public void learn(Dataset dataset) { 245 | 246 | double eps = 0.001; 247 | 248 | int l = dataset.getNumberOfExamples(); 249 | 250 | double[] C = new double[l]; 251 | for (int i = 0; i < l; i++) { 252 | C[i] = c; 253 | } 254 | 255 | Problem problem = new Problem(dataset, representation, label, 256 | LibLinearSolverType.REGRESSION); 257 | 258 | L2R_L2_SvcFunction fun_obj = new L2R_L2_SvrFunction(problem, C, p); 259 | 260 | Tron tron = new Tron(fun_obj, eps); 261 | 262 | double[] w = new double[problem.n]; 263 | tron.tron(w); 264 | 265 | this.regressionFunction.getModel().setHyperplane(problem.getW(w)); 266 | this.regressionFunction.getModel().setRepresentation(representation); 267 | this.regressionFunction.getModel().setBias(0); 268 | } 269 | 270 | /* 271 | * (non-Javadoc) 272 | * 273 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#duplicate() 274 | */ 275 | @Override 276 | public LibLinearRegression duplicate() { 277 | LibLinearRegression copy = new LibLinearRegression(); 278 | copy.setRepresentation(representation); 279 | copy.setC(c); 280 | copy.setP(p); 281 | return copy; 282 | } 283 | 284 | /* 285 | * (non-Javadoc) 286 | * 287 | * @see it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm#reset() 288 | */ 289 | @Override 290 | 
public void reset() { 291 | this.regressionFunction.reset(); 292 | } 293 | 294 | @Override 295 | public UnivariateLinearRegressionFunction getPredictionFunction() { 296 | return regressionFunction; 297 | } 298 | 299 | @Override 300 | public void setPredictionFunction(PredictionFunction predictionFunction) { 301 | this.regressionFunction = (UnivariateLinearRegressionFunction) predictionFunction; 302 | } 303 | 304 | } 305 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/KernelizedPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.kernel.Kernel; 22 | import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod; 23 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateKernelMachineRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (kernel machine version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("kernelizedPA-R") 37 | public class KernelizedPassiveAggressiveRegression extends PassiveAggressiveRegression implements KernelMethod{ 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPassiveAggressiveRegression(){ 42 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 43 | } 44 | 45 | public KernelizedPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, Kernel kernel, Label label){ 46 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 47 | this.setC(aggressiveness); 48 | this.setEpsilon(epsilon); 49 | this.setPolicy(policy); 50 | this.setKernel(kernel); 51 | this.setLabel(label); 52 | } 53 | 54 | @Override 55 | public Kernel getKernel(){ 56 | return kernel; 57 | } 58 | 59 | @Override 60 | public void setKernel(Kernel kernel) { 61 | this.kernel = kernel; 62 | this.getPredictionFunction().getModel().setKernel(kernel); 63 | } 64 | 65 | @Override 66 | public KernelizedPassiveAggressiveRegression duplicate() { 67 | KernelizedPassiveAggressiveRegression copy = new KernelizedPassiveAggressiveRegression(); 
68 | copy.setC(this.c); 69 | copy.setKernel(this.kernel); 70 | copy.setPolicy(this.policy); 71 | copy.setEpsilon(epsilon); 72 | return copy; 73 | } 74 | 75 | @Override 76 | public UnivariateKernelMachineRegressionFunction getPredictionFunction(){ 77 | return (UnivariateKernelMachineRegressionFunction) this.regressor; 78 | } 79 | 80 | @Override 81 | public void setPredictionFunction(PredictionFunction predictionFunction) { 82 | this.regressor = (UnivariateKernelMachineRegressionFunction) predictionFunction; 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/LinearPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (linear version). 28 | * 29 | * reference: 30 | * 31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("linearPA-R") 37 | public class LinearPassiveAggressiveRegression extends PassiveAggressiveRegression implements LinearMethod{ 38 | 39 | private String representation; 40 | 41 | public LinearPassiveAggressiveRegression(){ 42 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 43 | regressor.setModel(new BinaryLinearModel()); 44 | this.regressor = regressor; 45 | 46 | } 47 | 48 | public LinearPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, String representation, Label label){ 49 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 50 | regressor.setModel(new BinaryLinearModel()); 51 | this.regressor = regressor; 52 | this.setC(aggressiveness); 53 | this.setEpsilon(epsilon); 54 | this.setPolicy(policy); 55 | this.setRepresentation(representation); 56 | this.setLabel(label); 57 | } 58 | 59 | @Override 60 | public LinearPassiveAggressiveRegression duplicate() { 61 | LinearPassiveAggressiveRegression copy = new LinearPassiveAggressiveRegression(); 62 | 
copy.setC(this.c); 63 | copy.setRepresentation(this.representation); 64 | copy.setPolicy(this.policy); 65 | copy.setEpsilon(epsilon); 66 | return copy; 67 | } 68 | 69 | @Override 70 | public String getRepresentation() { 71 | return representation; 72 | } 73 | 74 | @Override 75 | public void setRepresentation(String representation) { 76 | this.representation = representation; 77 | this.getPredictionFunction().getModel().setRepresentation(representation); 78 | } 79 | 80 | @Override 81 | public UnivariateLinearRegressionFunction getPredictionFunction(){ 82 | return (UnivariateLinearRegressionFunction) this.regressor; 83 | } 84 | 85 | @Override 86 | public void setPredictionFunction(PredictionFunction predictionFunction) { 87 | this.regressor = (UnivariateLinearRegressionFunction) predictionFunction; 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/PassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput; 23 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionFunction; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for regression tasks. 29 | * 30 | * reference: 31 | * 32 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 33 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 34 | * 35 | * @author Simone Filice 36 | */ 37 | public abstract class PassiveAggressiveRegression extends PassiveAggressive implements RegressionLearningAlgorithm{ 38 | 39 | @JsonIgnore 40 | protected UnivariateRegressionFunction regressor; 41 | 42 | protected float epsilon; 43 | 44 | /** 45 | * Returns epsilon, i.e. the accepted distance between the predicted and the real regression values 46 | * 47 | * @return the epsilon 48 | */ 49 | public float getEpsilon() { 50 | return epsilon; 51 | } 52 | 53 | /** 54 | * Sets epsilon, i.e. 
the accepted distance between the predicted and the real regression values 55 | * 56 | * @param epsilon the epsilon to set 57 | */ 58 | public void setEpsilon(float epsilon) { 59 | this.epsilon = epsilon; 60 | } 61 | 62 | @Override 63 | public UnivariateRegressionFunction getPredictionFunction() { 64 | return this.regressor; 65 | } 66 | 67 | @Override 68 | public void learn(Dataset dataset){ 69 | 70 | while(dataset.hasNextExample()){ 71 | Example example = dataset.getNextExample(); 72 | this.learn(example); 73 | } 74 | dataset.reset(); 75 | } 76 | 77 | @Override 78 | public UnivariateRegressionOutput learn(Example example){ 79 | UnivariateRegressionOutput prediction=this.regressor.predict(example); 80 | float difference = example.getRegressionValue(label) - prediction.getScore(label); 81 | float lossValue = Math.abs(difference) - epsilon;//it represents the distance from the correct semi-space 82 | if(lossValue>0){ 83 | float exampleSquaredNorm = this.regressor.getModel().getSquaredNorm(example); 84 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm, c); 85 | if(difference<0){ 86 | weight = -weight; 87 | } 88 | this.regressor.getModel().addExample(weight, example); 89 | } 90 | return prediction; 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/linearization/LinearizationFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 
5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.linearization; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.representation.Vector; 22 | 23 | /** 24 | * This interface allows implementing function to linearized examples through 25 | * linear representations, i.e. vectors 26 | * 27 | * 28 | * @author Danilo Croce 29 | * 30 | */ 31 | public interface LinearizationFunction { 32 | 33 | /** 34 | * Given an input Example, this method generates a linear 35 | * Representation>, i.e. a Vector. 36 | * 37 | * @param example 38 | * The input example. 39 | * @return The linearized representation of the input example. 40 | */ 41 | public Vector getLinearRepresentation(Example example); 42 | 43 | /** 44 | * This method linearizes an input example, providing a new example 45 | * containing only a representation with a specific name, provided as input. 46 | * The produced example inherits the labels of the input example. 47 | * 48 | * @param example 49 | * The input example. 50 | * @param vectorName 51 | * The name of the linear representation inside the new example 52 | * @return 53 | */ 54 | public Example getLinearizedExample(Example example, String representationName); 55 | 56 | /** 57 | * This method linearizes all the examples in the input dataset 58 | * , generating a corresponding linearized dataset. 
The produced examples 59 | * inherit the labels of the corresponding input examples. 60 | * 61 | * @param dataset 62 | * The input dataset 63 | * @param representationName 64 | * The name of the linear representation inside the new examples 65 | * @return 66 | */ 67 | public SimpleDataset getLinearizedDataset(Dataset dataset, String representationName); 68 | 69 | /** 70 | * @return the size of the resulting embedding, i.e. the number of resulting 71 | * vector dimensions 72 | */ 73 | public int getEmbeddingSize(); 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/SequencePrediction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | 21 | import it.uniroma2.sag.kelp.data.example.SequencePath; 22 | import it.uniroma2.sag.kelp.data.label.Label; 23 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 24 | 25 | /** 26 | * It is a output provided by a machine learning systems on a sequence. 
This 27 | * specific implementation allows to assign multiple labelings to single 28 | * sequence, useful for some labeling strategies, such as Beam Search. Notice 29 | * that each labeling requires a score to select the more promising labeling. 30 | * 31 | * @author Danilo Croce 32 | * 33 | */ 34 | public class SequencePrediction implements Prediction { 35 | 36 | /** 37 | * 38 | */ 39 | private static final long serialVersionUID = -1040539866977906008L; 40 | /** 41 | * This list contains multiple labelings to be assigned to a single sequence 42 | */ 43 | private List paths; 44 | 45 | public SequencePrediction() { 46 | paths = new ArrayList(); 47 | } 48 | 49 | /** 50 | * @return The best path, i.e., the labeling with the highest score in the 51 | * list of labelings provided by a classifier 52 | */ 53 | public SequencePath bestPath() { 54 | return paths.get(0); 55 | } 56 | 57 | /** 58 | * @return a list containing multiple labelings to be assigned to a single 59 | * sequence 60 | */ 61 | public List getPaths() { 62 | return paths; 63 | } 64 | 65 | @Override 66 | public Float getScore(Label label) { 67 | return null; 68 | } 69 | 70 | /** 71 | * @param paths 72 | * a list contains multiple labelings to be assigned to a single 73 | * sequence 74 | */ 75 | public void setPaths(List paths) { 76 | this.paths = paths; 77 | } 78 | 79 | @Override 80 | public String toString() { 81 | StringBuilder sb = new StringBuilder(); 82 | for (int i = 0; i < paths.size(); i++) { 83 | if (i == 0) 84 | sb.append("Best Path\t"); 85 | else 86 | sb.append("Altern. 
Path\t"); 87 | SequencePath sequencePath = paths.get(i); 88 | sb.append(sequencePath + "\n"); 89 | } 90 | return sb.toString(); 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/predictionfunction/model/SequenceModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.predictionfunction.model; 17 | 18 | import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator; 19 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 20 | 21 | /** 22 | * This class implements a model produced by a 23 | * SequenceClassificationLearningAlgorithm 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class SequenceModel implements Model { 29 | 30 | /** 31 | * 32 | */ 33 | private static final long serialVersionUID = -2749198158786953940L; 34 | 35 | /** 36 | * The prediction function producing the emission scores to be considered in 37 | * the Viterbi Decoding 38 | */ 39 | private PredictionFunction basePredictionFunction; 40 | 41 | private SequenceExampleGenerator sequenceExampleGenerator; 42 | 43 | public SequenceModel() { 44 | super(); 45 | } 46 | 47 | public SequenceModel(PredictionFunction basePredictionFunction, SequenceExampleGenerator sequenceExampleGenerator) { 48 | super(); 49 | this.basePredictionFunction = basePredictionFunction; 50 | this.sequenceExampleGenerator = sequenceExampleGenerator; 51 | } 52 | 53 | public PredictionFunction getBasePredictionFunction() { 54 | return basePredictionFunction; 55 | } 56 | 57 | public SequenceExampleGenerator getSequenceExampleGenerator() { 58 | return sequenceExampleGenerator; 59 | } 60 | 61 | @Override 62 | public void reset() { 63 | } 64 | 65 | public void setBasePredictionFunction(PredictionFunction basePredictionFunction) { 66 | this.basePredictionFunction = basePredictionFunction; 67 | } 68 | 69 | public void setSequenceExampleGenerator(SequenceExampleGenerator sequenceExampleGenerator) { 70 | this.sequenceExampleGenerator = sequenceExampleGenerator; 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/ClusteringEvaluator.java: -------------------------------------------------------------------------------- 
1 | package it.uniroma2.sag.kelp.utils.evaluation; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashSet; 5 | import java.util.TreeMap; 6 | 7 | import it.uniroma2.sag.kelp.data.clustering.Cluster; 8 | import it.uniroma2.sag.kelp.data.clustering.ClusterExample; 9 | import it.uniroma2.sag.kelp.data.clustering.ClusterList; 10 | import it.uniroma2.sag.kelp.data.example.Example; 11 | import it.uniroma2.sag.kelp.data.example.SimpleExample; 12 | import it.uniroma2.sag.kelp.data.label.Label; 13 | import it.uniroma2.sag.kelp.data.label.StringLabel; 14 | import it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans.KernelBasedKMeansExample; 15 | 16 | /** 17 | * 18 | * Implements Evaluation methods for clustering algorithms. 19 | * 20 | * More details about Purity and NMI can be found here: 21 | * 22 | * https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1. 23 | * html 24 | * 25 | * @author Danilo Croce 26 | * 27 | */ 28 | public class ClusteringEvaluator { 29 | 30 | public static float getPurity(ClusterList clusters) { 31 | 32 | float res = 0; 33 | int k = clusters.size(); 34 | 35 | for (int clustId = 0; clustId < k; clustId++) { 36 | 37 | TreeMap classSizes = new TreeMap(); 38 | 39 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 40 | HashSet labels = vce.getExample().getClassificationLabels(); 41 | for (Label label : labels) 42 | if (!classSizes.containsKey(label)) 43 | classSizes.put(label, 1); 44 | else 45 | classSizes.put(label, classSizes.get(label) + 1); 46 | } 47 | 48 | int maxSize = 0; 49 | for (int size : classSizes.values()) { 50 | if (size > maxSize) { 51 | maxSize = size; 52 | } 53 | } 54 | res += maxSize; 55 | } 56 | 57 | return res / (float) clusters.getNumberOfExamples(); 58 | } 59 | 60 | public static float getMI(ClusterList clusters) { 61 | 62 | float res = 0; 63 | 64 | float N = clusters.getNumberOfExamples(); 65 | 66 | int k = clusters.size(); 67 | 68 | TreeMap classCardinality = 
getClassCardinality(clusters); 69 | 70 | for (int clustId = 0; clustId < k; clustId++) { 71 | 72 | TreeMap classSizes = getClassCardinalityWithinCluster(clusters, clustId); 73 | 74 | for (Label className : classSizes.keySet()) { 75 | int wSize = classSizes.get(className); 76 | res += ((float) wSize / N) * myLog(N * (float) wSize 77 | / (clusters.get(clustId).getExamples().size() * (float) classCardinality.get(className))); 78 | } 79 | 80 | } 81 | 82 | return res; 83 | 84 | } 85 | 86 | private static TreeMap getClassCardinalityWithinCluster(ClusterList clusters, int clustId) { 87 | 88 | TreeMap classSizes = new TreeMap(); 89 | 90 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 91 | HashSet labels = vce.getExample().getClassificationLabels(); 92 | for (Label label : labels) 93 | if (!classSizes.containsKey(label)) 94 | classSizes.put(label, 1); 95 | else 96 | classSizes.put(label, classSizes.get(label) + 1); 97 | } 98 | 99 | return classSizes; 100 | } 101 | 102 | private static float getClusterEntropy(ClusterList clusters) { 103 | 104 | float res = 0; 105 | float N = clusters.getNumberOfExamples(); 106 | int k = clusters.size(); 107 | 108 | for (int clustId = 0; clustId < k; clustId++) { 109 | int clusterElementSize = clusters.get(clustId).getExamples().size(); 110 | if (clusterElementSize != 0) 111 | res -= ((float) clusterElementSize / N) * myLog((float) clusterElementSize / N); 112 | } 113 | return res; 114 | 115 | } 116 | 117 | private static float getClassEntropy(ClusterList clusters) { 118 | 119 | float res = 0; 120 | float N = clusters.getNumberOfExamples(); 121 | 122 | TreeMap classCardinality = getClassCardinality(clusters); 123 | 124 | for (int classSize : classCardinality.values()) { 125 | res -= ((float) classSize / N) * myLog((float) classSize / N); 126 | } 127 | return res; 128 | 129 | } 130 | 131 | private static float myLog(float f) { 132 | return (float) (Math.log(f) / Math.log(2f)); 133 | } 134 | 135 | private static TreeMap 
getClassCardinality(ClusterList clusters) { 136 | TreeMap classSizes = new TreeMap(); 137 | 138 | int k = clusters.size(); 139 | 140 | for (int clustId = 0; clustId < k; clustId++) { 141 | 142 | for (ClusterExample vce : clusters.get(clustId).getExamples()) { 143 | HashSet labels = vce.getExample().getClassificationLabels(); 144 | for (Label label : labels) 145 | if (!classSizes.containsKey(label)) 146 | classSizes.put(label, 1); 147 | else 148 | classSizes.put(label, classSizes.get(label) + 1); 149 | } 150 | } 151 | return classSizes; 152 | } 153 | 154 | public static float getNMI(ClusterList clusters) { 155 | return getMI(clusters) / ((getClusterEntropy(clusters) + getClassEntropy(clusters)) / 2f); 156 | } 157 | 158 | public static String getStatistics(ClusterList clusters) { 159 | StringBuilder sb = new StringBuilder(); 160 | 161 | sb.append("Purity:\t" + getPurity(clusters) + "\n"); 162 | sb.append("Mutual Information:\t" + getMI(clusters) + "\n"); 163 | sb.append("Cluster Entropy:\t" + getClusterEntropy(clusters) + "\n"); 164 | sb.append("Class Entropy:\t" + getClassEntropy(clusters) + "\n"); 165 | sb.append("NMI:\t" + getNMI(clusters)); 166 | 167 | return sb.toString(); 168 | } 169 | 170 | public static void main(String[] args) { 171 | ClusterList clusters = new ClusterList(); 172 | 173 | Cluster c1 = new Cluster("C1"); 174 | ArrayList list1 = new ArrayList(); 175 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 176 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 177 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 178 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 179 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 180 | list1.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 181 | for (Example e : list1) { 182 | c1.add(new KernelBasedKMeansExample(e, 1f)); 183 | } 184 
| 185 | Cluster c2 = new Cluster("C2"); 186 | ArrayList list2 = new ArrayList(); 187 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 188 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 189 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 190 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 191 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null)); 192 | list2.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 193 | for (Example e : list2) { 194 | c2.add(new KernelBasedKMeansExample(e, 1f)); 195 | } 196 | 197 | Cluster c3 = new Cluster("C3"); 198 | ArrayList list3 = new ArrayList(); 199 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 200 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 201 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null)); 202 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 203 | list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null)); 204 | for (Example e : list3) { 205 | c3.add(new KernelBasedKMeansExample(e, 1f)); 206 | } 207 | 208 | clusters.add(c1); 209 | clusters.add(c2); 210 | clusters.add(c3); 211 | 212 | System.out.println(ClusteringEvaluator.getStatistics(clusters)); 213 | 214 | //From https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html 215 | //Purity = 0.71 216 | //NMI = 0.36 217 | 218 | } 219 | 220 | } 221 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/utils/evaluation/MulticlassSequenceClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.utils.evaluation; 17 | 18 | import java.util.List; 19 | 20 | import it.uniroma2.sag.kelp.data.example.Example; 21 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 22 | import it.uniroma2.sag.kelp.data.example.SequencePath; 23 | import it.uniroma2.sag.kelp.data.label.Label; 24 | import it.uniroma2.sag.kelp.data.label.SequenceEmission; 25 | import it.uniroma2.sag.kelp.predictionfunction.Prediction; 26 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 27 | 28 | /** 29 | * This is an instance of an Evaluator. It allows to compute the some common 30 | * measure for classification tasks acting over SequenceExamples. It 31 | * computes precision, recall, f1s for each class, and a global accuracy. 
32 | * 33 | * @author Danilo Croce 34 | */ 35 | public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator{ 36 | 37 | /** 38 | * Initialize a new F1Evaluator that will work on the specified classes 39 | * 40 | * @param labels 41 | */ 42 | public MulticlassSequenceClassificationEvaluator(List labels) { 43 | super(labels); 44 | } 45 | 46 | public void addCount(Example test, Prediction prediction) { 47 | addCount((SequenceExample) test, (SequencePrediction) prediction); 48 | } 49 | 50 | /** 51 | * This method should be implemented in the subclasses to update counters 52 | * useful to compute the performance measure 53 | * 54 | * @param test 55 | * the test example 56 | * @param predicted 57 | * the prediction of the system 58 | */ 59 | public void addCount(SequenceExample test, SequencePrediction predicted) { 60 | 61 | SequencePath bestPath = predicted.bestPath(); 62 | 63 | for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) { 64 | 65 | Example testItem = test.getExample(seqIdx); 66 | SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx); 67 | 68 | for (Label l : this.labels) { 69 | ClassStats stats = this.classStats.get(l); 70 | if(testItem.isExampleOf(l)){ 71 | if(sequenceLabel.getLabel().equals(l)){ 72 | stats.tp++; 73 | totalTp++; 74 | }else{ 75 | stats.fn++; 76 | totalFn++; 77 | } 78 | }else{ 79 | if(sequenceLabel.getLabel().equals(l)){ 80 | stats.fp++; 81 | totalFp++; 82 | }else{ 83 | stats.tn++; 84 | totalTn++; 85 | } 86 | } 87 | 88 | } 89 | 90 | //TODO: check (i) e' giusto valutare l'accuracy dei singoli elementi della sequenza e non della sequenza completa 91 | //(ii) va considerato il caso multilabel 92 | total++; 93 | 94 | if (testItem.isExampleOf(sequenceLabel.getLabel())) { 95 | correct++; 96 | } 97 | 98 | this.computed = false; 99 | } 100 | } 101 | 102 | } 103 | -------------------------------------------------------------------------------- 
/src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 
33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = 
new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | 
evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
 */

package it.uniroma2.sag.kelp.algorithms.incrementalTrain;

import java.io.IOException;
import java.util.Random;

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.StringLabel;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache;
import it.uniroma2.sag.kelp.kernel.vector.LinearKernel;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.ObjectSerializer;

/**
 * Tests that a KernelizedPerceptron can be serialized to JSON, reloaded, and
 * trained further (on two folds of the training set) while producing
 * predictions consistent (within tolerance) with the reference classifier
 * {@code f} trained in {@link #learnModel()}.
 */
public class IncrementalTrainTest {
	// Reference prediction function, built once in learnModel().
	private static Classifier f = null;
	private static SimpleDataset trainingSet;
	private static SimpleDataset testSet;
	// 2-fold split of trainingSet used by the incremental-training tests.
	private static SimpleDataset [] folds;
	private static ObjectSerializer serializer = new JacksonSerializerWrapper();
	// Shared learner: trained on the full trainingSet in learnModel() and then
	// serialized (already trained) by each test.
	private static KernelizedPerceptron learner;

	private static Label positiveClass = new StringLabel("+1");

	/**
	 * Loads the train/test datasets, builds the 2-fold split, trains the
	 * shared perceptron on the full training set and stores its prediction
	 * function as the reference classifier.
	 */
	@BeforeClass
	public static void learnModel() {
		trainingSet = new SimpleDataset();
		testSet = new SimpleDataset();
		try {
			trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp");
			// NOTE(review): unseeded Random makes the shuffle (and hence the
			// learned model) different on each run; both classifiers compared
			// by the tests are derived from the same shuffled data, though.
			trainingSet.shuffleExamples(new Random());
			// Read a dataset into a test variable
			testSet.populate("src/test/resources/svmTest/binary/binary_test.klp");
		} catch (Exception e) {
			e.printStackTrace();
			Assert.assertTrue(false);
		}

		folds = trainingSet.nFolding(2);

		// define the kernel
		Kernel kernel = new LinearKernel("0");

		// add a cache
		kernel.setKernelCache(new FixSizeKernelCache(trainingSet
				.getNumberOfExamples()));

		// define the learning algorithm
		learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass);

		// learn and get the prediction function
		learner.learn(trainingSet);
		f = learner.getPredictionFunction();
	}

	/**
	 * Round-trips the (already trained) learner through JSON, continues its
	 * training on the two folds, and checks its scores on the test set match
	 * those of the reference classifier within 0.001.
	 */
	@Test
	public void incrementalTrain() throws IOException{
		String jsonSerialization = serializer.writeValueAsString(learner);
		System.out.println(jsonSerialization);
		ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class);
		jsonAlgo.learn(folds[0]);
		jsonAlgo.learn(folds[1]);
		Classifier jsonClassifier = jsonAlgo.getPredictionFunction();

		for(Example ex : testSet.getExamples()){
			ClassificationOutput p = f.predict(ex);
			Float score = p.getScore(positiveClass);
			ClassificationOutput pJson = jsonClassifier.predict(ex);
			Float scoreJson = pJson.getScore(positiveClass);
			Assert.assertEquals(scoreJson.floatValue(), score.floatValue(),
					0.001f);
		}
	}

	/**
	 * Trains on fold 0, serializes the resulting classifier, reloads both the
	 * learner and the classifier model, re-attaches the model to a brand new
	 * learner, trains on fold 1, and checks the final scores match the
	 * reference classifier within 0.001.
	 */
	@Test
	public void reloadAndContinueTraining() throws IOException{
		String jsonLearnerSerialization = serializer.writeValueAsString(learner);
		System.out.println(jsonLearnerSerialization);
		KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class);
		jsonAlgo.learn(folds[0]);
		String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction());
		jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); //Brand new classifier
		BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class);
		jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel());
		jsonAlgo.learn(folds[1]);
		jsonClassifier = jsonAlgo.getPredictionFunction();

		for(Example ex : testSet.getExamples()){
			ClassificationOutput p = f.predict(ex);
			Float score = p.getScore(positiveClass);
			ClassificationOutput pJson = jsonClassifier.predict(ex);
			Float scoreJson = pJson.getScore(positiveClass);
			Assert.assertEquals(scoreJson.floatValue(), score.floatValue(),
					0.001f);
		}
	}

}
--------------------------------------------------------------------------------
/src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("kernelizedPA-R") 37 | public class KernelizedPassiveAggressiveRegression extends PassiveAggressiveRegression implements KernelMethod{ 38 | 39 | private Kernel kernel; 40 | 41 | public KernelizedPassiveAggressiveRegression(){ 42 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 43 | } 44 | 45 | public KernelizedPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, Kernel kernel, Label label){ 46 | this.regressor = new UnivariateKernelMachineRegressionFunction(); 47 | this.setC(aggressiveness); 48 | this.setEpsilon(epsilon); 49 | this.setPolicy(policy); 50 | this.setKernel(kernel); 51 | this.setLabel(label); 52 | } 53 | 54 | @Override 55 | public Kernel getKernel(){ 56 | return kernel; 57 | } 58 | 59 | @Override 60 | public void setKernel(Kernel kernel) { 61 | this.kernel = kernel; 62 | this.getPredictionFunction().getModel().setKernel(kernel); 63 | } 64 | 65 | @Override 66 | public KernelizedPassiveAggressiveRegression duplicate() { 67 | KernelizedPassiveAggressiveRegression copy = new KernelizedPassiveAggressiveRegression(); 68 | copy.setC(this.c); 69 | copy.setKernel(this.kernel); 70 | copy.setPolicy(this.policy); 71 | copy.setEpsilon(epsilon); 72 | return copy; 73 | } 74 | 75 | @Override 76 | public UnivariateKernelMachineRegressionFunction getPredictionFunction(){ 77 | return (UnivariateKernelMachineRegressionFunction) this.regressor; 78 | } 79 | 80 | @Override 81 | public void setPredictionFunction(PredictionFunction predictionFunction) { 82 | this.regressor = (UnivariateKernelMachineRegressionFunction) predictionFunction; 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- 
/src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/LinearPassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import com.fasterxml.jackson.annotation.JsonTypeName; 19 | 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod; 22 | import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction; 23 | import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel; 24 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateLinearRegressionFunction; 25 | 26 | /** 27 | * Online Passive-Aggressive Learning Algorithm for regression tasks (linear version). 28 | * 29 | * reference: 30 | *
31 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 32 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 33 | * 34 | * @author Simone Filice 35 | */ 36 | @JsonTypeName("linearPA-R") 37 | public class LinearPassiveAggressiveRegression extends PassiveAggressiveRegression implements LinearMethod{ 38 | 39 | private String representation; 40 | 41 | public LinearPassiveAggressiveRegression(){ 42 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 43 | regressor.setModel(new BinaryLinearModel()); 44 | this.regressor = regressor; 45 | 46 | } 47 | 48 | public LinearPassiveAggressiveRegression(float aggressiveness, float epsilon, Policy policy, String representation, Label label){ 49 | UnivariateLinearRegressionFunction regressor = new UnivariateLinearRegressionFunction(); 50 | regressor.setModel(new BinaryLinearModel()); 51 | this.regressor = regressor; 52 | this.setC(aggressiveness); 53 | this.setEpsilon(epsilon); 54 | this.setPolicy(policy); 55 | this.setRepresentation(representation); 56 | this.setLabel(label); 57 | } 58 | 59 | @Override 60 | public LinearPassiveAggressiveRegression duplicate() { 61 | LinearPassiveAggressiveRegression copy = new LinearPassiveAggressiveRegression(); 62 | copy.setC(this.c); 63 | copy.setRepresentation(this.representation); 64 | copy.setPolicy(this.policy); 65 | copy.setEpsilon(epsilon); 66 | return copy; 67 | } 68 | 69 | @Override 70 | public String getRepresentation() { 71 | return representation; 72 | } 73 | 74 | @Override 75 | public void setRepresentation(String representation) { 76 | this.representation = representation; 77 | this.getPredictionFunction().getModel().setRepresentation(representation); 78 | } 79 | 80 | @Override 81 | public UnivariateLinearRegressionFunction getPredictionFunction(){ 82 | return (UnivariateLinearRegressionFunction) this.regressor; 83 | } 84 | 85 | @Override 86 | public void 
setPredictionFunction(PredictionFunction predictionFunction) { 87 | this.regressor = (UnivariateLinearRegressionFunction) predictionFunction; 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/it/uniroma2/sag/kelp/learningalgorithm/regression/passiveaggressive/PassiveAggressiveRegression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.regression.passiveaggressive; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.Dataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.learningalgorithm.PassiveAggressive; 21 | import it.uniroma2.sag.kelp.learningalgorithm.regression.RegressionLearningAlgorithm; 22 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionOutput; 23 | import it.uniroma2.sag.kelp.predictionfunction.regressionfunction.UnivariateRegressionFunction; 24 | 25 | import com.fasterxml.jackson.annotation.JsonIgnore; 26 | 27 | /** 28 | * Online Passive-Aggressive Learning Algorithm for regression tasks. 29 | * 30 | * reference: 31 | *
32 | * [CrammerJLMR2006] Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz and Yoram Singer 33 | * Online Passive-Aggressive Algorithms. Journal of Machine Learning Research (2006) 34 | * 35 | * @author Simone Filice 36 | */ 37 | public abstract class PassiveAggressiveRegression extends PassiveAggressive implements RegressionLearningAlgorithm{ 38 | 39 | @JsonIgnore 40 | protected UnivariateRegressionFunction regressor; 41 | 42 | protected float epsilon; 43 | 44 | /** 45 | * Returns epsilon, i.e. the accepted distance between the predicted and the real regression values 46 | * 47 | * @return the epsilon 48 | */ 49 | public float getEpsilon() { 50 | return epsilon; 51 | } 52 | 53 | /** 54 | * Sets epsilon, i.e. the accepted distance between the predicted and the real regression values 55 | * 56 | * @param epsilon the epsilon to set 57 | */ 58 | public void setEpsilon(float epsilon) { 59 | this.epsilon = epsilon; 60 | } 61 | 62 | @Override 63 | public UnivariateRegressionFunction getPredictionFunction() { 64 | return this.regressor; 65 | } 66 | 67 | @Override 68 | public void learn(Dataset dataset){ 69 | 70 | while(dataset.hasNextExample()){ 71 | Example example = dataset.getNextExample(); 72 | this.learn(example); 73 | } 74 | dataset.reset(); 75 | } 76 | 77 | @Override 78 | public UnivariateRegressionOutput learn(Example example){ 79 | UnivariateRegressionOutput prediction=this.regressor.predict(example); 80 | float difference = example.getRegressionValue(label) - prediction.getScore(label); 81 | float lossValue = Math.abs(difference) - epsilon;//it represents the distance from the correct semi-space 82 | if(lossValue>0){ 83 | float exampleSquaredNorm = this.regressor.getModel().getSquaredNorm(example); 84 | float weight = this.computeWeight(example, lossValue, exampleSquaredNorm, c); 85 | if(difference<0){ 86 | weight = -weight; 87 | } 88 | this.regressor.getModel().addExample(weight, example); 89 | } 90 | return prediction; 91 | } 92 | 93 | } 94 | 
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/linearization/LinearizationFunction.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.linearization;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.representation.Vector;

/**
 * This interface allows implementing functions to linearize examples through
 * linear representations, i.e. vectors.
 *
 * @author Danilo Croce
 */
public interface LinearizationFunction {

	/**
	 * Given an input Example, this method generates a linear
	 * Representation, i.e. a Vector.
	 *
	 * @param example
	 *            The input example.
	 * @return The linearized representation of the input example.
	 */
	public Vector getLinearRepresentation(Example example);

	/**
	 * This method linearizes an input example, providing a new example
	 * containing only a representation with a specific name, provided as input.
	 * The produced example inherits the labels of the input example.
	 *
	 * @param example
	 *            The input example.
	 * @param representationName
	 *            The name of the linear representation inside the new example
	 * @return A new example containing only the linearized representation
	 */
	public Example getLinearizedExample(Example example, String representationName);

	/**
	 * This method linearizes all the examples in the input dataset, generating
	 * a corresponding linearized dataset. The produced examples inherit the
	 * labels of the corresponding input examples.
	 *
	 * @param dataset
	 *            The input dataset
	 * @param representationName
	 *            The name of the linear representation inside the new examples
	 * @return A new dataset containing the linearized examples
	 */
	public SimpleDataset getLinearizedDataset(Dataset dataset, String representationName);

	/**
	 * @return the size of the resulting embedding, i.e. the number of resulting
	 *         vector dimensions
	 */
	public int getEmbeddingSize();

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/predictionfunction/SequencePrediction.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.predictionfunction;

import java.util.ArrayList;
import java.util.List;

import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;

/**
 * It is the output provided by a machine learning system on a sequence. This
 * specific implementation allows to assign multiple labelings to a single
 * sequence, useful for some labeling strategies, such as Beam Search. Notice
 * that each labeling requires a score to select the more promising labeling.
 *
 * @author Danilo Croce
 */
public class SequencePrediction implements Prediction {

	private static final long serialVersionUID = -1040539866977906008L;

	/**
	 * This list contains multiple labelings to be assigned to a single sequence
	 */
	private List<SequencePath> paths;

	public SequencePrediction() {
		paths = new ArrayList<SequencePath>();
	}

	/**
	 * @return The best path, i.e., the labeling with the highest score in the
	 *         list of labelings provided by a classifier, or
	 *         <code>null</code> if no labeling has been assigned
	 */
	public SequencePath bestPath() {
		// Guard against an empty prediction: without this check an
		// IndexOutOfBoundsException would be raised on paths.get(0)
		if (paths.isEmpty())
			return null;
		return paths.get(0);
	}

	/**
	 * @return a list containing multiple labelings to be assigned to a single
	 *         sequence
	 */
	public List<SequencePath> getPaths() {
		return paths;
	}

	@Override
	public Float getScore(Label label) {
		// Scores are attached to entire paths, not to individual labels, so
		// no per-label score is available for a sequence prediction
		return null;
	}

	/**
	 * @param paths
	 *            a list containing multiple labelings to be assigned to a
	 *            single sequence
	 */
	public void setPaths(List<SequencePath> paths) {
		this.paths = paths;
	}

	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < paths.size(); i++) {
			if (i == 0)
				sb.append("Best Path\t");
			else
				sb.append("Altern. Path\t");
			SequencePath sequencePath = paths.get(i);
			sb.append(sequencePath + "\n");
		}
		return sb.toString();
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/predictionfunction/model/SequenceModel.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.predictionfunction.model;

import it.uniroma2.sag.kelp.data.examplegenerator.SequenceExampleGenerator;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;

/**
 * This class implements a model produced by a
 * SequenceClassificationLearningAlgorithm
 *
 * @author Danilo Croce
 */
public class SequenceModel implements Model {

	private static final long serialVersionUID = -2749198158786953940L;

	/**
	 * The prediction function producing the emission scores to be considered in
	 * the Viterbi Decoding
	 */
	private PredictionFunction basePredictionFunction;

	// Strategy used to enrich sequence items with transition information
	private SequenceExampleGenerator sequenceExampleGenerator;

	public SequenceModel() {
		super();
	}

	public SequenceModel(PredictionFunction basePredictionFunction, SequenceExampleGenerator sequenceExampleGenerator) {
		super();
		this.basePredictionFunction = basePredictionFunction;
		this.sequenceExampleGenerator = sequenceExampleGenerator;
	}

	public PredictionFunction getBasePredictionFunction() {
		return basePredictionFunction;
	}

	public SequenceExampleGenerator getSequenceExampleGenerator() {
		return sequenceExampleGenerator;
	}

	@Override
	public void reset() {
		// NOTE(review): intentionally a no-op; presumably resetting the state
		// of basePredictionFunction is handled by its own model -- confirm
	}

	public void setBasePredictionFunction(PredictionFunction basePredictionFunction) {
		this.basePredictionFunction = basePredictionFunction;
	}

	public void setSequenceExampleGenerator(SequenceExampleGenerator sequenceExampleGenerator) {
		this.sequenceExampleGenerator = sequenceExampleGenerator;
	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/utils/evaluation/ClusteringEvaluator.java:
--------------------------------------------------------------------------------
package it.uniroma2.sag.kelp.utils.evaluation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.TreeMap;

import it.uniroma2.sag.kelp.data.clustering.Cluster;
import it.uniroma2.sag.kelp.data.clustering.ClusterExample;
import it.uniroma2.sag.kelp.data.clustering.ClusterList;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SimpleExample;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.StringLabel;
import it.uniroma2.sag.kelp.learningalgorithm.clustering.kernelbasedkmeans.KernelBasedKMeansExample;

/**
 *
 * Implements Evaluation methods for clustering algorithms.
 *
 * More details about Purity and NMI can be found here:
 *
 * https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.
 * html
 *
 * @author Danilo Croce
 *
 */
public class ClusteringEvaluator {

	/**
	 * Computes the Purity of a clustering: each cluster is assigned to its
	 * most frequent gold class, and the number of correctly assigned examples
	 * is divided by the total number of examples.
	 *
	 * @param clusters
	 *            the clustering to be evaluated
	 * @return the purity, in [0, 1]
	 */
	public static float getPurity(ClusterList clusters) {

		float res = 0;
		int k = clusters.size();

		for (int clustId = 0; clustId < k; clustId++) {

			// Reuse the shared per-cluster class counting instead of
			// duplicating the counting loop here
			TreeMap<Label, Integer> classSizes = getClassCardinalityWithinCluster(clusters, clustId);

			int maxSize = 0;
			for (int size : classSizes.values()) {
				if (size > maxSize) {
					maxSize = size;
				}
			}
			res += maxSize;
		}

		return res / (float) clusters.getNumberOfExamples();
	}

	/**
	 * Computes the Mutual Information between the cluster assignment and the
	 * gold class assignment.
	 *
	 * @param clusters
	 *            the clustering to be evaluated
	 * @return the mutual information (in bits)
	 */
	public static float getMI(ClusterList clusters) {

		float res = 0;

		float N = clusters.getNumberOfExamples();

		int k = clusters.size();

		TreeMap<Label, Integer> classCardinality = getClassCardinality(clusters);

		for (int clustId = 0; clustId < k; clustId++) {

			TreeMap<Label, Integer> classSizes = getClassCardinalityWithinCluster(clusters, clustId);

			for (Label className : classSizes.keySet()) {
				int wSize = classSizes.get(className);
				res += ((float) wSize / N) * myLog(N * (float) wSize
						/ (clusters.get(clustId).getExamples().size() * (float) classCardinality.get(className)));
			}

		}

		return res;

	}

	/**
	 * Counts, for a single cluster, how many of its examples belong to each
	 * gold class.
	 */
	private static TreeMap<Label, Integer> getClassCardinalityWithinCluster(ClusterList clusters, int clustId) {

		TreeMap<Label, Integer> classSizes = new TreeMap<Label, Integer>();

		for (ClusterExample vce : clusters.get(clustId).getExamples()) {
			HashSet<Label> labels = vce.getExample().getClassificationLabels();
			for (Label label : labels)
				if (!classSizes.containsKey(label))
					classSizes.put(label, 1);
				else
					classSizes.put(label, classSizes.get(label) + 1);
		}

		return classSizes;
	}

	/**
	 * Computes the entropy of the cluster assignment (in bits). Empty clusters
	 * are skipped to avoid log(0).
	 */
	private static float getClusterEntropy(ClusterList clusters) {

		float res = 0;
		float N = clusters.getNumberOfExamples();
		int k = clusters.size();

		for (int clustId = 0; clustId < k; clustId++) {
			int clusterElementSize = clusters.get(clustId).getExamples().size();
			if (clusterElementSize != 0)
				res -= ((float) clusterElementSize / N) * myLog((float) clusterElementSize / N);
		}
		return res;

	}

	/**
	 * Computes the entropy of the gold class assignment (in bits).
	 */
	private static float getClassEntropy(ClusterList clusters) {

		float res = 0;
		float N = clusters.getNumberOfExamples();

		TreeMap<Label, Integer> classCardinality = getClassCardinality(clusters);

		for (int classSize : classCardinality.values()) {
			res -= ((float) classSize / N) * myLog((float) classSize / N);
		}
		return res;

	}

	// Base-2 logarithm, as conventionally used for entropy-based measures
	private static float myLog(float f) {
		return (float) (Math.log(f) / Math.log(2f));
	}

	/**
	 * Counts, over the whole clustering, how many examples belong to each gold
	 * class.
	 */
	private static TreeMap<Label, Integer> getClassCardinality(ClusterList clusters) {
		TreeMap<Label, Integer> classSizes = new TreeMap<Label, Integer>();

		int k = clusters.size();

		for (int clustId = 0; clustId < k; clustId++) {

			for (ClusterExample vce : clusters.get(clustId).getExamples()) {
				HashSet<Label> labels = vce.getExample().getClassificationLabels();
				for (Label label : labels)
					if (!classSizes.containsKey(label))
						classSizes.put(label, 1);
					else
						classSizes.put(label, classSizes.get(label) + 1);
			}
		}
		return classSizes;
	}

	/**
	 * Computes the Normalized Mutual Information: the Mutual Information
	 * normalized by the average of cluster entropy and class entropy.
	 *
	 * @param clusters
	 *            the clustering to be evaluated
	 * @return the NMI, in [0, 1]
	 */
	public static float getNMI(ClusterList clusters) {
		return getMI(clusters) / ((getClusterEntropy(clusters) + getClassEntropy(clusters)) / 2f);
	}

	/**
	 * @return a human-readable report of Purity, MI, entropies and NMI
	 */
	public static String getStatistics(ClusterList clusters) {
		StringBuilder sb = new StringBuilder();

		sb.append("Purity:\t" + getPurity(clusters) + "\n");
		sb.append("Mutual Information:\t" + getMI(clusters) + "\n");
		sb.append("Cluster Entropy:\t" + getClusterEntropy(clusters) + "\n");
		sb.append("Class Entropy:\t" + getClassEntropy(clusters) + "\n");
		sb.append("NMI:\t" + getNMI(clusters));

		return sb.toString();
	}

	/**
	 * Small self-check reproducing the worked example of the Stanford IR book.
	 */
	public static void main(String[] args) {
		ClusterList clusters = new ClusterList();

		Cluster c1 = new Cluster("C1");
		ArrayList<Example> list1 = new ArrayList<Example>();
		list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list1.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list1.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null));
		for (Example e : list1) {
			c1.add(new KernelBasedKMeansExample(e, 1f));
		}

		Cluster c2 = new Cluster("C2");
		ArrayList<Example> list2 = new ArrayList<Example>();
		list2.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null));
		list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null));
		list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null));
		list2.add(new SimpleExample(new StringLabel[] { new StringLabel("o") }, null));
		list2.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null));
		for (Example e : list2) {
			c2.add(new KernelBasedKMeansExample(e, 1f));
		}

		Cluster c3 = new Cluster("C3");
		ArrayList<Example> list3 = new ArrayList<Example>();
		list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null));
		list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null));
		list3.add(new SimpleExample(new StringLabel[] { new StringLabel("q") }, null));
		list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		list3.add(new SimpleExample(new StringLabel[] { new StringLabel("x") }, null));
		for (Example e : list3) {
			c3.add(new KernelBasedKMeansExample(e, 1f));
		}

		clusters.add(c1);
		clusters.add(c2);
		clusters.add(c3);

		System.out.println(ClusteringEvaluator.getStatistics(clusters));

		// From https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
		// Purity = 0.71
		// NMI = 0.36

	}

}
--------------------------------------------------------------------------------
/src/main/java/it/uniroma2/sag/kelp/utils/evaluation/MulticlassSequenceClassificationEvaluator.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package it.uniroma2.sag.kelp.utils.evaluation;

import java.util.List;

import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.example.SequenceExample;
import it.uniroma2.sag.kelp.data.example.SequencePath;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.label.SequenceEmission;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction;

/**
 * This is an instance of an Evaluator. It allows to compute some common
 * measures for classification tasks acting over SequenceExamples. It
 * computes precision, recall, f1s for each class, and a global accuracy.
 *
 * @author Danilo Croce
 */
public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator {

	/**
	 * Initialize a new evaluator that will work on the specified classes
	 *
	 * @param labels
	 *            the classes to be evaluated
	 */
	public MulticlassSequenceClassificationEvaluator(List<Label> labels) {
		super(labels);
	}

	public void addCount(Example test, Prediction prediction) {
		addCount((SequenceExample) test, (SequencePrediction) prediction);
	}

	/**
	 * Updates the counters used to compute the performance measures: the best
	 * path of the prediction is compared, element by element, against the gold
	 * labels of the test sequence.
	 *
	 * @param test
	 *            the test sequence example
	 * @param predicted
	 *            the prediction of the system over the whole sequence
	 */
	public void addCount(SequenceExample test, SequencePrediction predicted) {

		SequencePath bestPath = predicted.bestPath();

		for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) {

			Example testItem = test.getExample(seqIdx);
			SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx);

			for (Label l : this.labels) {
				ClassStats stats = this.classStats.get(l);
				if (testItem.isExampleOf(l)) {
					if (sequenceLabel.getLabel().equals(l)) {
						stats.tp++;
						totalTp++;
					} else {
						stats.fn++;
						totalFn++;
					}
				} else {
					if (sequenceLabel.getLabel().equals(l)) {
						stats.fp++;
						totalFp++;
					} else {
						stats.tn++;
						totalTn++;
					}
				}

			}

			// TODO: check (i) whether it is correct to evaluate the accuracy of
			// the individual elements of the sequence rather than of the
			// complete sequence; (ii) the multi-label case should be considered
			total++;

			if (testItem.isExampleOf(sequenceLabel.getLabel())) {
				correct++;
			}

			this.computed = false;
		}
	}

}
--------------------------------------------------------------------------------
/src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 
33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = 
new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | 
evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
Representation>
Vector
dataset
SequenceExamples. It
 * computes precision, recall, f1s for each class, and a global accuracy.
 * 
 * @author Danilo Croce
 */
public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator {

	/**
	 * Initialize a new evaluator that will work on the specified classes.
	 * 
	 * @param labels
	 *            the classes with respect to which the evaluation is carried out
	 */
	public MulticlassSequenceClassificationEvaluator(List<Label> labels) {
		super(labels);
	}

	/**
	 * Dispatches the generic example/prediction pair to the sequence-aware
	 * counter after casting them to their sequence counterparts.
	 */
	public void addCount(Example test, Prediction prediction) {
		addCount((SequenceExample) test, (SequencePrediction) prediction);
	}

	/**
	 * Updates the per-class and global counters: for each item of the
	 * sequence, the label assigned by the best predicted path is compared
	 * with the gold label of the corresponding test item.
	 * 
	 * @param test
	 *            the test sequence
	 * @param predicted
	 *            the prediction of the system
	 */
	public void addCount(SequenceExample test, SequencePrediction predicted) {

		SequencePath bestPath = predicted.bestPath();

		for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) {

			Example testItem = test.getExample(seqIdx);
			SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx);
			// The label assigned to this item by the best predicted path.
			Label predictedLabel = sequenceLabel.getLabel();

			for (Label l : this.labels) {
				ClassStats stats = this.classStats.get(l);
				if (testItem.isExampleOf(l)) {
					if (predictedLabel.equals(l)) {
						stats.tp++;
						totalTp++;
					} else {
						stats.fn++;
						totalFn++;
					}
				} else {
					if (predictedLabel.equals(l)) {
						stats.fp++;
						totalFp++;
					} else {
						stats.tn++;
						totalTn++;
					}
				}

			}

			// TODO: check (i) whether it is correct to measure the accuracy on
			// the single items of the sequence rather than on the complete
			// sequence; (ii) the multilabel case should be considered
			total++;

			if (testItem.isExampleOf(predictedLabel)) {
				correct++;
			}
		}

		// Invalidate any cached measure: it must be recomputed after new counts.
		this.computed = false;
	}

}

-------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 
31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = 
new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = 
f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------
s. It 31 | * computes precision, recall, f1s for each class, and a global accuracy. 32 | * 33 | * @author Danilo Croce 34 | */ 35 | public class MulticlassSequenceClassificationEvaluator extends MulticlassClassificationEvaluator{ 36 | 37 | /** 38 | * Initialize a new F1Evaluator that will work on the specified classes 39 | * 40 | * @param labels 41 | */ 42 | public MulticlassSequenceClassificationEvaluator(List labels) { 43 | super(labels); 44 | } 45 | 46 | public void addCount(Example test, Prediction prediction) { 47 | addCount((SequenceExample) test, (SequencePrediction) prediction); 48 | } 49 | 50 | /** 51 | * This method should be implemented in the subclasses to update counters 52 | * useful to compute the performance measure 53 | * 54 | * @param test 55 | * the test example 56 | * @param predicted 57 | * the prediction of the system 58 | */ 59 | public void addCount(SequenceExample test, SequencePrediction predicted) { 60 | 61 | SequencePath bestPath = predicted.bestPath(); 62 | 63 | for (int seqIdx = 0; seqIdx < test.getLenght(); seqIdx++) { 64 | 65 | Example testItem = test.getExample(seqIdx); 66 | SequenceEmission sequenceLabel = bestPath.getAssignedSequnceLabels().get(seqIdx); 67 | 68 | for (Label l : this.labels) { 69 | ClassStats stats = this.classStats.get(l); 70 | if(testItem.isExampleOf(l)){ 71 | if(sequenceLabel.getLabel().equals(l)){ 72 | stats.tp++; 73 | totalTp++; 74 | }else{ 75 | stats.fn++; 76 | totalFn++; 77 | } 78 | }else{ 79 | if(sequenceLabel.getLabel().equals(l)){ 80 | stats.fp++; 81 | totalFp++; 82 | }else{ 83 | stats.tn++; 84 | totalTn++; 85 | } 86 | } 87 | 88 | } 89 | 90 | //TODO: check (i) e' giusto valutare l'accuracy dei singoli elementi della sequenza e non della sequenza completa 91 | //(ii) va considerato il caso multilabel 92 | total++; 93 | 94 | if (testItem.isExampleOf(sequenceLabel.getLabel())) { 95 | correct++; 96 | } 97 | 98 | this.computed = false; 99 | } 100 | } 101 | 102 | } 103 | 
-------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/binary/liblinear/LibLinearDenseVsSparseClassificationEvaluator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2014 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | package it.uniroma2.sag.kelp.algorithms.binary.liblinear; 17 | 18 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 19 | import it.uniroma2.sag.kelp.data.example.Example; 20 | import it.uniroma2.sag.kelp.data.label.Label; 21 | import it.uniroma2.sag.kelp.data.manipulator.NormalizationManipolator; 22 | import it.uniroma2.sag.kelp.data.manipulator.VectorConcatenationManipulator; 23 | import it.uniroma2.sag.kelp.learningalgorithm.classification.liblinear.LibLinearLearningAlgorithm; 24 | import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning; 25 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassificationOutput; 26 | import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier; 27 | import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator; 28 | import it.uniroma2.sag.kelp.utils.exception.NoSuchPerformanceMeasureException; 29 | 30 | import java.io.FileNotFoundException; 
31 | import java.io.UnsupportedEncodingException; 32 | import java.util.ArrayList; 33 | import java.util.List; 34 | 35 | import org.junit.Assert; 36 | import org.junit.Test; 37 | 38 | public class LibLinearDenseVsSparseClassificationEvaluator { 39 | 40 | private static List sparseScores = new ArrayList(); 41 | private static List denseScores = new ArrayList(); 42 | 43 | @Test 44 | public void testConsistency() { 45 | try { 46 | String inputFilePath = "src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz"; 47 | 48 | SimpleDataset dataset = new SimpleDataset(); 49 | dataset.populate(inputFilePath); 50 | SimpleDataset[] split = dataset.split(0.5f); 51 | 52 | SimpleDataset trainingSet = split[0]; 53 | SimpleDataset testSet = split[1]; 54 | float c = 1.0f; 55 | float f1Dense = testDense(trainingSet, c, testSet); 56 | float f1Sparse = testSparse(trainingSet, c, testSet); 57 | 58 | Assert.assertEquals(f1Sparse, f1Dense, 0.000001); 59 | 60 | for (int i = 0; i < sparseScores.size(); i++) { 61 | Assert.assertEquals(sparseScores.get(i), denseScores.get(i), 62 | 0.000001); 63 | } 64 | } catch (FileNotFoundException e) { 65 | e.printStackTrace(); 66 | Assert.assertTrue(false); 67 | } catch (UnsupportedEncodingException e) { 68 | e.printStackTrace(); 69 | Assert.assertTrue(false); 70 | } catch (NoSuchPerformanceMeasureException e) { 71 | e.printStackTrace(); 72 | Assert.assertTrue(false); 73 | } catch (Exception e) { 74 | e.printStackTrace(); 75 | Assert.assertTrue(false); 76 | } 77 | } 78 | 79 | private static float testSparse(SimpleDataset trainingSet, float c, 80 | SimpleDataset testSet) throws FileNotFoundException, 81 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 82 | List classes = trainingSet.getClassificationLabels(); 83 | NormalizationManipolator norma = new NormalizationManipolator(); 84 | trainingSet.manipulate(norma); 85 | testSet.manipulate(norma); 86 | List repr = new ArrayList(); 87 | repr.add("WS"); 88 | List reprW = 
new ArrayList(); 89 | reprW.add(1.0f); 90 | VectorConcatenationManipulator man = new VectorConcatenationManipulator( 91 | "WS0", repr, reprW); 92 | trainingSet.manipulate(man); 93 | testSet.manipulate(man); 94 | 95 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 96 | svmSolver.setCn(c); 97 | svmSolver.setCp(c); 98 | svmSolver.setRepresentation("WS0"); 99 | 100 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 101 | ovaLearner.setBaseAlgorithm(svmSolver); 102 | ovaLearner.setLabels(classes); 103 | ovaLearner.learn(trainingSet); 104 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 105 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 106 | trainingSet.getClassificationLabels()); 107 | for (Example e : testSet.getExamples()) { 108 | OneVsAllClassificationOutput predict = f.predict(e); 109 | Label l = predict.getPredictedClasses().get(0); 110 | evaluator.addCount(e, predict); 111 | sparseScores.add(predict.getScore(l)); 112 | } 113 | 114 | return evaluator.getMacroF1(); 115 | } 116 | 117 | private static float testDense(SimpleDataset trainingSet, float c, 118 | SimpleDataset testSet) throws FileNotFoundException, 119 | UnsupportedEncodingException, NoSuchPerformanceMeasureException { 120 | List classes = trainingSet.getClassificationLabels(); 121 | 122 | LibLinearLearningAlgorithm svmSolver = new LibLinearLearningAlgorithm(); 123 | svmSolver.setCn(c); 124 | svmSolver.setCp(c); 125 | svmSolver.setRepresentation("WS"); 126 | 127 | OneVsAllLearning ovaLearner = new OneVsAllLearning(); 128 | ovaLearner.setBaseAlgorithm(svmSolver); 129 | ovaLearner.setLabels(classes); 130 | ovaLearner.learn(trainingSet); 131 | OneVsAllClassifier f = ovaLearner.getPredictionFunction(); 132 | MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator( 133 | trainingSet.getClassificationLabels()); 134 | for (Example e : testSet.getExamples()) { 135 | OneVsAllClassificationOutput predict = 
f.predict(e); 136 | Label l = predict.getPredictedClasses().get(0); 137 | evaluator.addCount(e, predict); 138 | denseScores.add(predict.getScore(l)); 139 | } 140 | 141 | return evaluator.getMacroF1(); 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/algorithms/incrementalTrain/IncrementalTrainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Simone Filice and Giuseppe Castellucci and Danilo Croce 3 | * and Giovanni Da San Martino and Alessandro Moschitti and Roberto Basili 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package it.uniroma2.sag.kelp.algorithms.incrementalTrain; 18 | 19 | import java.io.IOException; 20 | import java.util.Random; 21 | 22 | import org.junit.Assert; 23 | import org.junit.BeforeClass; 24 | import org.junit.Test; 25 | 26 | import it.uniroma2.sag.kelp.data.dataset.SimpleDataset; 27 | import it.uniroma2.sag.kelp.data.example.Example; 28 | import it.uniroma2.sag.kelp.data.label.Label; 29 | import it.uniroma2.sag.kelp.data.label.StringLabel; 30 | import it.uniroma2.sag.kelp.kernel.Kernel; 31 | import it.uniroma2.sag.kelp.kernel.cache.FixSizeKernelCache; 32 | import it.uniroma2.sag.kelp.kernel.vector.LinearKernel; 33 | import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm; 34 | import it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron.KernelizedPerceptron; 35 | import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier; 36 | import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput; 37 | import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier; 38 | import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper; 39 | import it.uniroma2.sag.kelp.utils.ObjectSerializer; 40 | 41 | public class IncrementalTrainTest { 42 | private static Classifier f = null; 43 | private static SimpleDataset trainingSet; 44 | private static SimpleDataset testSet; 45 | private static SimpleDataset [] folds; 46 | private static ObjectSerializer serializer = new JacksonSerializerWrapper(); 47 | private static KernelizedPerceptron learner; 48 | 49 | private static Label positiveClass = new StringLabel("+1"); 50 | 51 | @BeforeClass 52 | public static void learnModel() { 53 | trainingSet = new SimpleDataset(); 54 | testSet = new SimpleDataset(); 55 | try { 56 | trainingSet.populate("src/test/resources/svmTest/binary/binary_train.klp"); 57 | trainingSet.shuffleExamples(new Random()); 58 | // Read a dataset into a test variable 59 | 
testSet.populate("src/test/resources/svmTest/binary/binary_test.klp"); 60 | } catch (Exception e) { 61 | e.printStackTrace(); 62 | Assert.assertTrue(false); 63 | } 64 | 65 | folds = trainingSet.nFolding(2); 66 | 67 | // define the kernel 68 | Kernel kernel = new LinearKernel("0"); 69 | 70 | // add a cache 71 | kernel.setKernelCache(new FixSizeKernelCache(trainingSet 72 | .getNumberOfExamples())); 73 | 74 | // define the learning algorithm 75 | learner = new KernelizedPerceptron(0.2f, 1f, false, kernel, positiveClass); 76 | 77 | // learn and get the prediction function 78 | learner.learn(trainingSet); 79 | f = learner.getPredictionFunction(); 80 | } 81 | 82 | @Test 83 | public void incrementalTrain() throws IOException{ 84 | String jsonSerialization = serializer.writeValueAsString(learner); 85 | System.out.println(jsonSerialization); 86 | ClassificationLearningAlgorithm jsonAlgo = serializer.readValue(jsonSerialization, ClassificationLearningAlgorithm.class); 87 | jsonAlgo.learn(folds[0]); 88 | jsonAlgo.learn(folds[1]); 89 | Classifier jsonClassifier = jsonAlgo.getPredictionFunction(); 90 | 91 | for(Example ex : testSet.getExamples()){ 92 | ClassificationOutput p = f.predict(ex); 93 | Float score = p.getScore(positiveClass); 94 | ClassificationOutput pJson = jsonClassifier.predict(ex); 95 | Float scoreJson = pJson.getScore(positiveClass); 96 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 97 | 0.001f); 98 | } 99 | } 100 | 101 | @Test 102 | public void reloadAndContinueTraining() throws IOException{ 103 | String jsonLearnerSerialization = serializer.writeValueAsString(learner); 104 | System.out.println(jsonLearnerSerialization); 105 | KernelizedPerceptron jsonAlgo = serializer.readValue(jsonLearnerSerialization, KernelizedPerceptron.class); 106 | jsonAlgo.learn(folds[0]); 107 | String jsonClassifierSerialization = serializer.writeValueAsString(jsonAlgo.getPredictionFunction()); 108 | jsonAlgo = serializer.readValue(jsonLearnerSerialization, 
KernelizedPerceptron.class); //Brand new classifier 109 | BinaryKernelMachineClassifier jsonClassifier = serializer.readValue(jsonClassifierSerialization, BinaryKernelMachineClassifier.class); 110 | jsonAlgo.getPredictionFunction().setModel(jsonClassifier.getModel()); 111 | jsonAlgo.learn(folds[1]); 112 | jsonClassifier = jsonAlgo.getPredictionFunction(); 113 | 114 | for(Example ex : testSet.getExamples()){ 115 | ClassificationOutput p = f.predict(ex); 116 | Float score = p.getScore(positiveClass); 117 | ClassificationOutput pJson = jsonClassifier.predict(ex); 118 | Float scoreJson = pJson.getScore(positiveClass); 119 | Assert.assertEquals(scoreJson.floatValue(), score.floatValue(), 120 | 0.001f); 121 | } 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/test/java/it/uniroma2/sag/kelp/learningalgorithm/classification/hmm/SequenceLearningLinearTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2016 Simone Filice and Giuseppe Castellucci and Danilo Croce and Roberto Basili 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | package it.uniroma2.sag.kelp.learningalgorithm.classification.hmm; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.io.UnsupportedEncodingException; 24 | import java.util.ArrayList; 25 | import java.util.zip.GZIPInputStream; 26 | 27 | import org.junit.Assert; 28 | import org.junit.Test; 29 | 30 | import it.uniroma2.sag.kelp.data.dataset.SequenceDataset; 31 | import it.uniroma2.sag.kelp.data.example.Example; 32 | import it.uniroma2.sag.kelp.data.example.ParsingExampleException; 33 | import it.uniroma2.sag.kelp.data.example.SequenceExample; 34 | import it.uniroma2.sag.kelp.data.example.SequencePath; 35 | import it.uniroma2.sag.kelp.data.label.Label; 36 | import it.uniroma2.sag.kelp.data.label.StringLabel; 37 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm; 38 | import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss; 39 | import it.uniroma2.sag.kelp.predictionfunction.SequencePrediction; 40 | import it.uniroma2.sag.kelp.predictionfunction.SequencePredictionFunction; 41 | 42 | public class SequenceLearningLinearTest { 43 | 44 | private static final Float TOLERANCE = 0.001f; 45 | 46 | public static void main(String[] args) throws Exception { 47 | 48 | } 49 | 50 | @Test 51 | public void testLinear() { 52 | 53 | String inputTrainFilePath = "src/test/resources/sequence_learning/declaration_of_independence.klp.gz"; 54 | String inputTestFilePath = "src/test/resources/sequence_learning/gettysburg_address.klp.gz"; 55 | String scoreFilePath = "src/test/resources/sequence_learning/prediction_test_linear.txt"; 56 | 57 | /* 58 | * Given a targeted item in the sequence, this variable determines the 59 | * number of previous example considered in the learning/labeling 60 | * process. 
61 | * 62 | * NOTE: if this variable is set to 0, the learning process corresponds 63 | * to a traditional multi-class classification schema 64 | */ 65 | int transitionsOrder = 1; 66 | 67 | /* 68 | * This variable determines the importance of the transition-based 69 | * features during the learning process. Higher valuers will assign more 70 | * importance to the transitions. 71 | */ 72 | float weight = 1f; 73 | 74 | /* 75 | * The size of the beam to be used in the decoding process. This number 76 | * determines the number of possible sequences produced in the labeling 77 | * process. It will also increase the process complexity. 78 | */ 79 | int beamSize = 5; 80 | 81 | /* 82 | * During the labeling process, each item is classified with respect to 83 | * the target classes. To reduce the complexity of the labeling process, 84 | * this variable determines the number of classes that received the 85 | * highest classification scores to be considered after the 86 | * classification step in the Viterbi Decoding. 
87 | */ 88 | int maxEmissionCandidates = 3; 89 | 90 | /* 91 | * This representation contains the feature vector representing items in 92 | * the sequence 93 | */ 94 | String originalRepresentationName = "rep"; 95 | 96 | /* 97 | * Loading the training dataset 98 | */ 99 | SequenceDataset sequenceTrainDataset = new SequenceDataset(); 100 | try { 101 | sequenceTrainDataset.populate(inputTrainFilePath); 102 | } catch (IOException e) { 103 | e.printStackTrace(); 104 | Assert.assertTrue(false); 105 | } catch (InstantiationException e) { 106 | e.printStackTrace(); 107 | Assert.assertTrue(false); 108 | } catch (ParsingExampleException e) { 109 | e.printStackTrace(); 110 | Assert.assertTrue(false); 111 | } catch (Exception e) { 112 | e.printStackTrace(); 113 | Assert.assertTrue(false); 114 | } 115 | 116 | /* 117 | * Instance classifier 118 | */ 119 | float cSVM = 1f; 120 | DCDLearningAlgorithm instanceClassifierLearningAlgorithm = new DCDLearningAlgorithm(cSVM, cSVM, DCDLoss.L1, 121 | false, 50, originalRepresentationName); 122 | 123 | /* 124 | * Sequence classifier. 
125 | */ 126 | SequenceClassificationLearningAlgorithm sequenceClassificationLearningAlgorithm = null; 127 | try { 128 | sequenceClassificationLearningAlgorithm = new SequenceClassificationLinearLearningAlgorithm( 129 | instanceClassifierLearningAlgorithm, transitionsOrder, weight); 130 | sequenceClassificationLearningAlgorithm.setMaxEmissionCandidates(maxEmissionCandidates); 131 | sequenceClassificationLearningAlgorithm.setBeamSize(beamSize); 132 | 133 | sequenceClassificationLearningAlgorithm.learn(sequenceTrainDataset); 134 | } catch (Exception e1) { 135 | e1.printStackTrace(); 136 | Assert.assertTrue(false); 137 | } 138 | 139 | SequencePredictionFunction predictionFunction = (SequencePredictionFunction) sequenceClassificationLearningAlgorithm 140 | .getPredictionFunction(); 141 | 142 | /* 143 | * Load the test set 144 | */ 145 | SequenceDataset sequenceTestDataset = new SequenceDataset(); 146 | try { 147 | sequenceTestDataset.populate(inputTestFilePath); 148 | } catch (IOException e) { 149 | e.printStackTrace(); 150 | Assert.assertTrue(false); 151 | } catch (InstantiationException e) { 152 | e.printStackTrace(); 153 | Assert.assertTrue(false); 154 | } catch (ParsingExampleException e) { 155 | e.printStackTrace(); 156 | Assert.assertTrue(false); 157 | } 158 | 159 | /* 160 | * Tagging and evaluating 161 | */ 162 | // PrintStream ps = new PrintStream(scoreFilePath); 163 | ArrayList labels = new ArrayList(); 164 | ArrayList scores = new ArrayList(); 165 | for (Example example : sequenceTestDataset.getExamples()) { 166 | 167 | SequenceExample sequenceExample = (SequenceExample) example; 168 | SequencePrediction sequencePrediction = (SequencePrediction) predictionFunction.predict(sequenceExample); 169 | 170 | SequencePath bestPath = sequencePrediction.bestPath(); 171 | for (int i = 0; i < sequenceExample.getLenght(); i++) { 172 | // ps.println(bestPath.getAssignedLabel(i) + "\t" + 173 | // bestPath.getScore()); 174 | labels.add(bestPath.getAssignedLabel(i)); 175 | 
scores.add(bestPath.getScore()); 176 | } 177 | 178 | } 179 | // ps.close(); 180 | 181 | ArrayList oldScores = loadScores(scoreFilePath); 182 | ArrayList oldLabels = loadLabels(scoreFilePath); 183 | 184 | for (int i = 0; i < oldScores.size(); i++) { 185 | Assert.assertEquals(oldScores.get(i), scores.get(i), TOLERANCE); 186 | Assert.assertEquals(labels.get(i).toString(), oldLabels.get(i).toString()); 187 | } 188 | 189 | } 190 | 191 | public static ArrayList loadScores(String filepath) { 192 | try { 193 | ArrayList scores = new ArrayList(); 194 | BufferedReader in = null; 195 | String encoding = "UTF-8"; 196 | if (filepath.endsWith(".gz")) { 197 | in = new BufferedReader( 198 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 199 | } else { 200 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 201 | } 202 | 203 | String str = ""; 204 | while ((str = in.readLine()) != null) { 205 | scores.add(Double.parseDouble(str.split("\t")[1])); 206 | } 207 | 208 | in.close(); 209 | 210 | return scores; 211 | 212 | } catch (UnsupportedEncodingException e) { 213 | e.printStackTrace(); 214 | Assert.assertTrue(false); 215 | } catch (FileNotFoundException e) { 216 | e.printStackTrace(); 217 | Assert.assertTrue(false); 218 | } catch (IOException e) { 219 | e.printStackTrace(); 220 | Assert.assertTrue(false); 221 | } 222 | 223 | return null; 224 | } 225 | 226 | public static ArrayList loadLabels(String filepath) { 227 | try { 228 | ArrayList res = new ArrayList(); 229 | BufferedReader in = null; 230 | String encoding = "UTF-8"; 231 | if (filepath.endsWith(".gz")) { 232 | in = new BufferedReader( 233 | new InputStreamReader(new GZIPInputStream(new FileInputStream(filepath)), encoding)); 234 | } else { 235 | in = new BufferedReader(new InputStreamReader(new FileInputStream(filepath), encoding)); 236 | } 237 | 238 | String str = ""; 239 | while ((str = in.readLine()) != null) { 240 | res.add(new 
StringLabel(str.split("\t")[0])); 241 | } 242 | 243 | in.close(); 244 | 245 | return res; 246 | 247 | } catch (UnsupportedEncodingException e) { 248 | e.printStackTrace(); 249 | Assert.assertTrue(false); 250 | } catch (FileNotFoundException e) { 251 | e.printStackTrace(); 252 | Assert.assertTrue(false); 253 | } catch (IOException e) { 254 | e.printStackTrace(); 255 | Assert.assertTrue(false); 256 | } 257 | 258 | return null; 259 | } 260 | 261 | } 262 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/README.txt: -------------------------------------------------------------------------------- 1 | The datasets reported in this folder have been created starting from the dataset produced by Thorsten Joachims as an example problem for his SVM^{hmm} implementation. 2 | 3 | The original dataset can be downloaded at: 4 | http://download.joachims.org/svm_hmm/examples/example7.tar.gz 5 | while its description is reported at: 6 | https://www.cs.cornell.edu/people/tj/svm_light/svm_hmm.html -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/declaration_of_independence.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/declaration_of_independence.klp.gz -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/gettysburg_address.klp.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/sequence_learning/gettysburg_address.klp.gz -------------------------------------------------------------------------------- 
/src/test/resources/sequence_learning/prediction_test_kernel.txt: -------------------------------------------------------------------------------- 1 | 1 -58.85170393685728 2 | 22 -58.85170393685728 3 | 3 -58.85170393685728 4 | 8 -58.85170393685728 5 | 15 -58.85170393685728 6 | 30 -58.85170393685728 7 | 20 -58.85170393685728 8 | 15 -58.85170393685728 9 | 11 -58.85170393685728 10 | 12 -58.85170393685728 11 | 8 -58.85170393685728 12 | 7 -58.85170393685728 13 | 12 -58.85170393685728 14 | 7 -58.85170393685728 15 | 9 -58.85170393685728 16 | 12 -58.85170393685728 17 | 6 -58.85170393685728 18 | 27 -58.85170393685728 19 | 8 -58.85170393685728 20 | 13 -58.85170393685728 21 | 3 -58.85170393685728 22 | 27 -58.85170393685728 23 | 25 -58.85170393685728 24 | 7 -58.85170393685728 25 | 12 -58.85170393685728 26 | 7 -58.85170393685728 27 | 7 -58.85170393685728 28 | 15 -58.85170393685728 29 | 30 -58.85170393685728 30 | 29 -58.85170393685728 31 | 9 -58.85170393685728 32 | 17 -58.85170393685728 33 | 21 -49.958351223707055 34 | 19 -49.958351223707055 35 | 30 -49.958351223707055 36 | 29 -49.958351223707055 37 | 8 -49.958351223707055 38 | 7 -49.958351223707055 39 | 13 -49.958351223707055 40 | 9 -49.958351223707055 41 | 12 -49.958351223707055 42 | 6 -49.958351223707055 43 | 28 -49.958351223707055 44 | 21 -49.958351223707055 45 | 7 -49.958351223707055 46 | 12 -49.958351223707055 47 | 8 -49.958351223707055 48 | 7 -49.958351223707055 49 | 12 -49.958351223707055 50 | 8 -49.958351223707055 51 | 29 -49.958351223707055 52 | 3 -49.958351223707055 53 | 21 -49.958351223707055 54 | 27 -49.958351223707055 55 | 9 -49.958351223707055 56 | 9 -49.958351223707055 57 | 12 -49.958351223707055 58 | 17 -49.958351223707055 59 | 19 -19.14935390144825 60 | 30 -19.14935390144825 61 | 29 -19.14935390144825 62 | 8 -19.14935390144825 63 | 7 -19.14935390144825 64 | 13 -19.14935390144825 65 | 13 -19.14935390144825 66 | 8 -19.14935390144825 67 | 7 -19.14935390144825 68 | 12 -19.14935390144825 69 | 17 -19.14935390144825 
70 | 19 -51.68865761583535 71 | 30 -51.68865761583535 72 | 9 -51.68865761583535 73 | 25 -51.68865761583535 74 | 26 -51.68865761583535 75 | 7 -51.68865761583535 76 | 12 -51.68865761583535 77 | 8 -51.68865761583535 78 | 7 -51.68865761583535 79 | 12 -51.68865761583535 80 | 8 -51.68865761583535 81 | 7 -51.68865761583535 82 | 9 -51.68865761583535 83 | 12 -51.68865761583535 84 | 8 -51.68865761583535 85 | 7 -51.68865761583535 86 | 9 -51.68865761583535 87 | 12 -51.68865761583535 88 | 30 -51.68865761583535 89 | 20 -51.68865761583535 90 | 15 -51.68865761583535 91 | 7 -51.68865761583535 92 | 7 -51.68865761583535 93 | 12 -51.68865761583535 94 | 12 -51.68865761583535 95 | 30 -51.68865761583535 96 | 17 -51.68865761583535 97 | 19 -20.546289531993914 98 | 31 -20.546289531993914 99 | 21 -20.546289531993914 100 | 28 -20.546289531993914 101 | 3 -20.546289531993914 102 | 8 -20.546289531993914 103 | 7 -20.546289531993914 104 | 19 -20.546289531993914 105 | 11 -20.546289531993914 106 | 26 -20.546289531993914 107 | 7 -20.546289531993914 108 | 17 -20.546289531993914 109 | 3 -39.18519755398995 110 | 8 -39.18519755398995 111 | 7 -39.18519755398995 112 | 9 -39.18519755398995 113 | 12 -39.18519755398995 114 | 6 -39.18519755398995 115 | 19 -39.18519755398995 116 | 11 -39.18519755398995 117 | 26 -39.18519755398995 118 | 6 -39.18519755398995 119 | 19 -39.18519755398995 120 | 11 -39.18519755398995 121 | 26 -39.18519755398995 122 | 6 -39.18519755398995 123 | 19 -39.18519755398995 124 | 11 -39.18519755398995 125 | 26 -39.18519755398995 126 | 7 -39.18519755398995 127 | 12 -39.18519755398995 128 | 17 -39.18519755398995 129 | 7 -46.958838324933005 130 | 9 -46.958838324933005 131 | 15 -46.958838324933005 132 | 6 -46.958838324933005 133 | 28 -46.958838324933005 134 | 3 -46.958838324933005 135 | 12 -46.958838324933005 136 | 21 -46.958838324933005 137 | 29 -46.958838324933005 138 | 21 -46.958838324933005 139 | 30 -46.958838324933005 140 | 29 -46.958838324933005 141 | 19 -46.958838324933005 142 | 20 
-46.958838324933005 143 | 12 -46.958838324933005 144 | 20 -46.958838324933005 145 | 9 -46.958838324933005 146 | 12 -46.958838324933005 147 | 25 -46.958838324933005 148 | 26 -46.958838324933005 149 | 3 -46.958838324933005 150 | 9 -46.958838324933005 151 | 17 -46.958838324933005 152 | 7 -49.424689389703104 153 | 12 -49.424689389703104 154 | 11 -49.424689389703104 155 | 26 -49.424689389703104 156 | 9 -49.424689389703104 157 | 3 -49.424689389703104 158 | 9 -49.424689389703104 159 | 12 -49.424689389703104 160 | 8 -49.424689389703104 161 | 19 -49.424689389703104 162 | 11 -49.424689389703104 163 | 21 -49.424689389703104 164 | 6 -49.424689389703104 165 | 3 -49.424689389703104 166 | 19 -49.424689389703104 167 | 8 -49.424689389703104 168 | 12 -49.424689389703104 169 | 26 -49.424689389703104 170 | 7 -49.424689389703104 171 | 19 -49.424689389703104 172 | 30 -49.424689389703104 173 | 21 -49.424689389703104 174 | 17 -49.424689389703104 175 | 19 -52.114114669781316 176 | 31 -52.114114669781316 177 | 8 -52.114114669781316 178 | 19 -52.114114669781316 179 | 7 -52.114114669781316 180 | 28 -52.114114669781316 181 | 21 -52.114114669781316 182 | 25 -52.114114669781316 183 | 26 -52.114114669781316 184 | 27 -52.114114669781316 185 | 21 -52.114114669781316 186 | 25 -52.114114669781316 187 | 7 -52.114114669781316 188 | 9 -52.114114669781316 189 | 12 -52.114114669781316 190 | 32 -52.114114669781316 191 | 19 -52.114114669781316 192 | 30 -52.114114669781316 193 | 11 -52.114114669781316 194 | 21 -52.114114669781316 195 | 30 -52.114114669781316 196 | 21 -52.114114669781316 197 | 29 -52.114114669781316 198 | 21 -52.114114669781316 199 | 21 -52.114114669781316 200 | 29 -52.114114669781316 201 | 17 -52.114114669781316 202 | 19 -157.40459068974272 203 | 31 -157.40459068974272 204 | 21 -157.40459068974272 205 | 8 -157.40459068974272 206 | 19 -157.40459068974272 207 | 25 -157.40459068974272 208 | 26 -157.40459068974272 209 | 21 -157.40459068974272 210 | 27 -157.40459068974272 211 | 25 
-157.40459068974272 212 | 7 -157.40459068974272 213 | 13 -157.40459068974272 214 | 13 -157.40459068974272 215 | 28 -157.40459068974272 216 | 21 -157.40459068974272 217 | 19 -157.40459068974272 218 | 5 -157.40459068974272 219 | 7 -157.40459068974272 220 | 8 -157.40459068974272 221 | 7 -157.40459068974272 222 | 9 -157.40459068974272 223 | 12 -157.40459068974272 224 | 19 -157.40459068974272 225 | 30 -157.40459068974272 226 | 9 -157.40459068974272 227 | 12 -157.40459068974272 228 | 25 -157.40459068974272 229 | 7 -157.40459068974272 230 | 12 -157.40459068974272 231 | 8 -157.40459068974272 232 | 32 -157.40459068974272 233 | 19 -157.40459068974272 234 | 30 -157.40459068974272 235 | 7 -157.40459068974272 236 | 23 -157.40459068974272 237 | 9 -157.40459068974272 238 | 12 -157.40459068974272 239 | 8 -157.40459068974272 240 | 12 -157.40459068974272 241 | 8 -157.40459068974272 242 | 7 -157.40459068974272 243 | 19 -157.40459068974272 244 | 21 -157.40459068974272 245 | 21 -157.40459068974272 246 | 26 -157.40459068974272 247 | 7 -157.40459068974272 248 | 7 -157.40459068974272 249 | 12 -157.40459068974272 250 | 11 -157.40459068974272 251 | 21 -157.40459068974272 252 | 30 -157.40459068974272 253 | 29 -157.40459068974272 254 | 8 -157.40459068974272 255 | 9 -157.40459068974272 256 | 6 -157.40459068974272 257 | 7 -157.40459068974272 258 | 7 -157.40459068974272 259 | 12 -157.40459068974272 260 | 8 -157.40459068974272 261 | 13 -157.40459068974272 262 | 11 -157.40459068974272 263 | 30 -157.40459068974272 264 | 7 -157.40459068974272 265 | 9 -157.40459068974272 266 | 12 -157.40459068974272 267 | 8 -157.40459068974272 268 | 12 -157.40459068974272 269 | 6 -157.40459068974272 270 | 3 -157.40459068974272 271 | 7 -157.40459068974272 272 | 12 -157.40459068974272 273 | 8 -157.40459068974272 274 | 7 -157.40459068974272 275 | 15 -157.40459068974272 276 | 6 -157.40459068974272 277 | 8 -157.40459068974272 278 | 7 -157.40459068974272 279 | 15 -157.40459068974272 280 | 6 -157.40459068974272 281 | 8 
-157.40459068974272 282 | 7 -157.40459068974272 283 | 15 -157.40459068974272 284 | 11 -157.40459068974272 285 | 21 -157.40459068974272 286 | 26 -157.40459068974272 287 | 8 -157.40459068974272 288 | 7 -157.40459068974272 289 | 12 -157.40459068974272 290 | 17 -157.40459068974272 291 | -------------------------------------------------------------------------------- /src/test/resources/sequence_learning/prediction_test_linear.txt: -------------------------------------------------------------------------------- 1 | 1 -61.552865965064605 2 | 22 -61.552865965064605 3 | 3 -61.552865965064605 4 | 8 -61.552865965064605 5 | 15 -61.552865965064605 6 | 30 -61.552865965064605 7 | 20 -61.552865965064605 8 | 15 -61.552865965064605 9 | 11 -61.552865965064605 10 | 12 -61.552865965064605 11 | 8 -61.552865965064605 12 | 7 -61.552865965064605 13 | 12 -61.552865965064605 14 | 7 -61.552865965064605 15 | 9 -61.552865965064605 16 | 12 -61.552865965064605 17 | 6 -61.552865965064605 18 | 27 -61.552865965064605 19 | 8 -61.552865965064605 20 | 13 -61.552865965064605 21 | 3 -61.552865965064605 22 | 27 -61.552865965064605 23 | 25 -61.552865965064605 24 | 7 -61.552865965064605 25 | 12 -61.552865965064605 26 | 7 -61.552865965064605 27 | 7 -61.552865965064605 28 | 15 -61.552865965064605 29 | 30 -61.552865965064605 30 | 29 -61.552865965064605 31 | 9 -61.552865965064605 32 | 17 -61.552865965064605 33 | 21 -50.586976361817456 34 | 19 -50.586976361817456 35 | 30 -50.586976361817456 36 | 29 -50.586976361817456 37 | 8 -50.586976361817456 38 | 7 -50.586976361817456 39 | 13 -50.586976361817456 40 | 9 -50.586976361817456 41 | 12 -50.586976361817456 42 | 6 -50.586976361817456 43 | 28 -50.586976361817456 44 | 21 -50.586976361817456 45 | 7 -50.586976361817456 46 | 12 -50.586976361817456 47 | 3 -50.586976361817456 48 | 7 -50.586976361817456 49 | 12 -50.586976361817456 50 | 21 -50.586976361817456 51 | 21 -50.586976361817456 52 | 3 -50.586976361817456 53 | 21 -50.586976361817456 54 | 21 -50.586976361817456 55 | 7 
-50.586976361817456 56 | 9 -50.586976361817456 57 | 12 -50.586976361817456 58 | 17 -50.586976361817456 59 | 19 -19.745991163812985 60 | 30 -19.745991163812985 61 | 26 -19.745991163812985 62 | 8 -19.745991163812985 63 | 7 -19.745991163812985 64 | 13 -19.745991163812985 65 | 13 -19.745991163812985 66 | 8 -19.745991163812985 67 | 7 -19.745991163812985 68 | 12 -19.745991163812985 69 | 17 -19.745991163812985 70 | 19 -54.528993898737625 71 | 30 -54.528993898737625 72 | 9 -54.528993898737625 73 | 25 -54.528993898737625 74 | 26 -54.528993898737625 75 | 7 -54.528993898737625 76 | 12 -54.528993898737625 77 | 8 -54.528993898737625 78 | 7 -54.528993898737625 79 | 12 -54.528993898737625 80 | 8 -54.528993898737625 81 | 7 -54.528993898737625 82 | 9 -54.528993898737625 83 | 12 -54.528993898737625 84 | 8 -54.528993898737625 85 | 7 -54.528993898737625 86 | 30 -54.528993898737625 87 | 21 -54.528993898737625 88 | 30 -54.528993898737625 89 | 20 -54.528993898737625 90 | 15 -54.528993898737625 91 | 7 -54.528993898737625 92 | 7 -54.528993898737625 93 | 12 -54.528993898737625 94 | 12 -54.528993898737625 95 | 30 -54.528993898737625 96 | 17 -54.528993898737625 97 | 19 -21.642864657110263 98 | 31 -21.642864657110263 99 | 21 -21.642864657110263 100 | 28 -21.642864657110263 101 | 3 -21.642864657110263 102 | 8 -21.642864657110263 103 | 7 -21.642864657110263 104 | 19 -21.642864657110263 105 | 11 -21.642864657110263 106 | 26 -21.642864657110263 107 | 7 -21.642864657110263 108 | 17 -21.642864657110263 109 | 3 -41.163958681094705 110 | 8 -41.163958681094705 111 | 7 -41.163958681094705 112 | 9 -41.163958681094705 113 | 12 -41.163958681094705 114 | 6 -41.163958681094705 115 | 19 -41.163958681094705 116 | 11 -41.163958681094705 117 | 26 -41.163958681094705 118 | 6 -41.163958681094705 119 | 19 -41.163958681094705 120 | 11 -41.163958681094705 121 | 26 -41.163958681094705 122 | 6 -41.163958681094705 123 | 19 -41.163958681094705 124 | 11 -41.163958681094705 125 | 26 -41.163958681094705 126 | 7 
-41.163958681094705 127 | 12 -41.163958681094705 128 | 17 -41.163958681094705 129 | 7 -47.69120077996808 130 | 9 -47.69120077996808 131 | 15 -47.69120077996808 132 | 6 -47.69120077996808 133 | 28 -47.69120077996808 134 | 3 -47.69120077996808 135 | 12 -47.69120077996808 136 | 21 -47.69120077996808 137 | 29 -47.69120077996808 138 | 21 -47.69120077996808 139 | 30 -47.69120077996808 140 | 29 -47.69120077996808 141 | 19 -47.69120077996808 142 | 12 -47.69120077996808 143 | 30 -47.69120077996808 144 | 20 -47.69120077996808 145 | 9 -47.69120077996808 146 | 12 -47.69120077996808 147 | 25 -47.69120077996808 148 | 26 -47.69120077996808 149 | 3 -47.69120077996808 150 | 12 -47.69120077996808 151 | 17 -47.69120077996808 152 | 7 -49.193906625295206 153 | 12 -49.193906625295206 154 | 11 -49.193906625295206 155 | 26 -49.193906625295206 156 | 9 -49.193906625295206 157 | 3 -49.193906625295206 158 | 9 -49.193906625295206 159 | 12 -49.193906625295206 160 | 7 -49.193906625295206 161 | 19 -49.193906625295206 162 | 11 -49.193906625295206 163 | 21 -49.193906625295206 164 | 6 -49.193906625295206 165 | 3 -49.193906625295206 166 | 19 -49.193906625295206 167 | 8 -49.193906625295206 168 | 12 -49.193906625295206 169 | 26 -49.193906625295206 170 | 7 -49.193906625295206 171 | 19 -49.193906625295206 172 | 30 -49.193906625295206 173 | 21 -49.193906625295206 174 | 17 -49.193906625295206 175 | 19 -52.77654733531991 176 | 31 -52.77654733531991 177 | 8 -52.77654733531991 178 | 19 -52.77654733531991 179 | 7 -52.77654733531991 180 | 28 -52.77654733531991 181 | 21 -52.77654733531991 182 | 25 -52.77654733531991 183 | 26 -52.77654733531991 184 | 27 -52.77654733531991 185 | 21 -52.77654733531991 186 | 25 -52.77654733531991 187 | 7 -52.77654733531991 188 | 9 -52.77654733531991 189 | 12 -52.77654733531991 190 | 32 -52.77654733531991 191 | 19 -52.77654733531991 192 | 30 -52.77654733531991 193 | 11 -52.77654733531991 194 | 21 -52.77654733531991 195 | 30 -52.77654733531991 196 | 21 -52.77654733531991 197 | 12 
-52.77654733531991 198 | 21 -52.77654733531991 199 | 21 -52.77654733531991 200 | 29 -52.77654733531991 201 | 17 -52.77654733531991 202 | 19 -162.86896772139426 203 | 31 -162.86896772139426 204 | 21 -162.86896772139426 205 | 8 -162.86896772139426 206 | 19 -162.86896772139426 207 | 25 -162.86896772139426 208 | 26 -162.86896772139426 209 | 21 -162.86896772139426 210 | 27 -162.86896772139426 211 | 25 -162.86896772139426 212 | 7 -162.86896772139426 213 | 13 -162.86896772139426 214 | 13 -162.86896772139426 215 | 28 -162.86896772139426 216 | 21 -162.86896772139426 217 | 19 -162.86896772139426 218 | 5 -162.86896772139426 219 | 7 -162.86896772139426 220 | 8 -162.86896772139426 221 | 7 -162.86896772139426 222 | 9 -162.86896772139426 223 | 12 -162.86896772139426 224 | 19 -162.86896772139426 225 | 30 -162.86896772139426 226 | 9 -162.86896772139426 227 | 12 -162.86896772139426 228 | 25 -162.86896772139426 229 | 7 -162.86896772139426 230 | 12 -162.86896772139426 231 | 8 -162.86896772139426 232 | 32 -162.86896772139426 233 | 19 -162.86896772139426 234 | 30 -162.86896772139426 235 | 7 -162.86896772139426 236 | 23 -162.86896772139426 237 | 9 -162.86896772139426 238 | 12 -162.86896772139426 239 | 8 -162.86896772139426 240 | 12 -162.86896772139426 241 | 5 -162.86896772139426 242 | 7 -162.86896772139426 243 | 19 -162.86896772139426 244 | 21 -162.86896772139426 245 | 21 -162.86896772139426 246 | 26 -162.86896772139426 247 | 7 -162.86896772139426 248 | 7 -162.86896772139426 249 | 12 -162.86896772139426 250 | 11 -162.86896772139426 251 | 21 -162.86896772139426 252 | 30 -162.86896772139426 253 | 29 -162.86896772139426 254 | 8 -162.86896772139426 255 | 9 -162.86896772139426 256 | 6 -162.86896772139426 257 | 7 -162.86896772139426 258 | 7 -162.86896772139426 259 | 12 -162.86896772139426 260 | 8 -162.86896772139426 261 | 13 -162.86896772139426 262 | 11 -162.86896772139426 263 | 30 -162.86896772139426 264 | 7 -162.86896772139426 265 | 9 -162.86896772139426 266 | 12 -162.86896772139426 267 | 8 
-162.86896772139426 268 | 12 -162.86896772139426 269 | 6 -162.86896772139426 270 | 3 -162.86896772139426 271 | 8 -162.86896772139426 272 | 12 -162.86896772139426 273 | 8 -162.86896772139426 274 | 7 -162.86896772139426 275 | 15 -162.86896772139426 276 | 6 -162.86896772139426 277 | 8 -162.86896772139426 278 | 7 -162.86896772139426 279 | 15 -162.86896772139426 280 | 6 -162.86896772139426 281 | 8 -162.86896772139426 282 | 7 -162.86896772139426 283 | 15 -162.86896772139426 284 | 11 -162.86896772139426 285 | 21 -162.86896772139426 286 | 26 -162.86896772139426 287 | 8 -162.86896772139426 288 | 7 -162.86896772139426 289 | 12 -162.86896772139426 290 | 17 -162.86896772139426 291 | -------------------------------------------------------------------------------- /src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAG-KeLP/kelp-additional-algorithms/63b396f358e54c2f5e87652d8209a017dce21791/src/test/resources/svmTest/binary/liblinear/polarity_sparse_dense_repr.txt.gz --------------------------------------------------------------------------------